Files
backend/backend/app/ai/service/image_task_service.py
2025-12-30 20:37:49 +08:00

121 lines
4.7 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import math
from abc import ABC, abstractmethod
from typing import Dict, Any, Tuple, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from backend.app.ai.model.image_task import ImageTaskStatus, ImageProcessingTask
from backend.app.ai.crud.image_task_crud import image_task_dao
from backend.app.admin.service.points_service import points_service
from backend.app.ai.service.rate_limit_service import rate_limit_service
from backend.database.db import background_db_session
from backend.common.const import LLM_CHAT_COST
from backend.common.log import log as logger
from backend.app.ai.tasks import update_task_status_with_retry
class TaskProcessor(ABC):
    """Interface for the task-specific work executed inside an image task.

    Subclasses implement ``process``; the lifecycle around it (status
    transitions, points deduction, rate-limit slot release) is driven by
    ``ImageTaskService.process_task``.
    """

    @abstractmethod
    async def process(self, db: AsyncSession, task: ImageProcessingTask) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Run the business logic for ``task`` and report what it produced.

        Args:
            db: The async database session that the caller also uses for the
                task's status updates and points deduction.
            task: The task record being processed.

        Returns:
            A ``(result, token_usage)`` pair: ``result`` is the data to store
            on the task, ``token_usage`` is the usage information from which
            the points deduction is computed.
        """
        ...
class ImageTaskService:
    """Run image-processing tasks through a shared lifecycle.

    The task-specific work is delegated to a ``TaskProcessor``; this service
    owns the status transitions, the points deduction and the release of the
    user's rate-limit slot.
    """

    @staticmethod
    def _extract_total_tokens(token_usage: Any) -> int:
        """Return the total token count reported by a processor, or 0.

        Accepts either a flat mapping (``{"total_tokens": n}``) or the legacy
        nested shape (``{"token_usage": {"total_tokens": n}}``). A value that
        is not a mapping, a missing count, or a count that cannot be converted
        with ``int()`` yields 0, so a malformed usage report cannot fail a task
        whose processing already succeeded.
        """
        if not isinstance(token_usage, dict):
            return 0
        if "total_tokens" in token_usage:
            raw = token_usage.get("total_tokens")
        else:
            nested = token_usage.get("token_usage")
            # Ignore a nested value that is not a mapping instead of raising
            # AttributeError on ``.get``.
            raw = nested.get("total_tokens") if isinstance(nested, dict) else None
        try:
            return int(raw or 0)
        except (TypeError, ValueError):
            logger.warning(f"Ignoring non-numeric total_tokens value: {raw!r}")
            return 0

    @staticmethod
    def _calculate_deduct_amount(total_tokens: int, unit_cost):
        """Return the points to charge: ``unit_cost`` per started 1000 tokens.

        A task reporting no (or non-positive) token usage is still charged a
        single unit.
        """
        if total_tokens <= 0:
            return unit_cost
        return math.ceil(total_tokens / 1000) * unit_cost

    @staticmethod
    async def _mark_failed(task_id: int, error_message: str) -> None:
        """Best-effort: record FAILED status for ``task_id`` in a fresh session."""
        try:
            async with background_db_session() as db:
                await update_task_status_with_retry(
                    db, task_id, ImageTaskStatus.FAILED,
                    error_message=error_message
                )
                await db.commit()
        except Exception:
            # The task keeps whatever status it had; make that visible instead
            # of silently swallowing the failure.
            logger.exception(f"Failed to mark task {task_id} as FAILED")

    @staticmethod
    async def _release_slot(user_id: int) -> None:
        """Best-effort: release the user's rate-limit task slot, logging failures."""
        try:
            await rate_limit_service.release_task_slot(user_id)
        except Exception:
            logger.exception(f"Failed to release task slot for user {user_id}")

    async def process_task(self, task_id: int, user_id: int, processor: "TaskProcessor"):
        """Process an image task with standard lifecycle management.

        Steps:
            1. Status update (PROCESSING)
            2. Business logic execution (via ``processor``)
            3. Points deduction
            4. Status update (COMPLETED, or FAILED on any exception)
            5. Task slot release (always, even when the task is not found)

        Args:
            task_id: Id of the task to process.
            user_id: User whose rate-limit slot is released at the end. Points
                are charged to ``task.user_id``; presumably both are the same
                user — verify against callers.
            processor: Implementation of the task-specific work.

        Errors are not propagated: they are logged and recorded on the task as
        FAILED.
        """
        try:
            async with background_db_session() as db:
                task = await image_task_dao.get(db, task_id)
                if not task:
                    logger.warning(f"Task {task_id} not found during processing")
                    return
                await image_task_dao.update_task_status(db, task_id, ImageTaskStatus.PROCESSING)

                # The processor returns the final result dict and token usage info.
                result, token_usage = await processor.process(db, task)

                total_tokens = self._extract_total_tokens(token_usage)
                deduct_amount = self._calculate_deduct_amount(total_tokens, LLM_CHAT_COST)
                # ref_id is used as the related_id of the points record.
                points_deducted = await points_service.deduct_points_with_db(
                    user_id=task.user_id,
                    amount=deduct_amount,
                    db=db,
                    related_id=task.ref_id,
                    details={
                        "task_id": task_id,
                        "ref_type": task.ref_type,
                        "token_usage": total_tokens
                    },
                    action=task.ref_type
                )
                if not points_deducted:
                    raise RuntimeError("Failed to deduct points")

                # Attach the raw usage info unless the processor already did.
                if isinstance(result, dict):
                    result.setdefault('token_usage', token_usage)

                await update_task_status_with_retry(
                    db, task_id, ImageTaskStatus.COMPLETED,
                    result=result
                )
                # Points deduction and COMPLETED status share this one commit;
                # rollback on error presumably happens in
                # background_db_session — confirm.
                await db.commit()
        except Exception as e:
            # Some exceptions (e.g. a bare TimeoutError()) stringify to "".
            error_message = str(e) or type(e).__name__
            logger.exception(f"Error processing task {task_id}: {error_message}")
            await self._mark_failed(task_id, error_message)
        finally:
            await self._release_slot(user_id)


image_task_service = ImageTaskService()