121 lines
4.7 KiB
Python
121 lines
4.7 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
import math
|
|
from abc import ABC, abstractmethod
|
|
from typing import Dict, Any, Tuple, Optional
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from backend.app.ai.model.image_task import ImageTaskStatus, ImageProcessingTask
|
|
from backend.app.ai.crud.image_task_crud import image_task_dao
|
|
from backend.app.admin.service.points_service import points_service
|
|
from backend.app.ai.service.rate_limit_service import rate_limit_service
|
|
from backend.database.db import background_db_session
|
|
from backend.common.const import LLM_CHAT_COST
|
|
from backend.common.log import log as logger
|
|
from backend.app.ai.tasks import update_task_status_with_retry
|
|
|
|
|
|
class TaskProcessor(ABC):
    """Interface for the task-specific step that ``ImageTaskService.process_task`` runs."""

    @abstractmethod
    async def process(
        self,
        db: AsyncSession,
        task: ImageProcessingTask,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Run the business logic that is specific to this kind of task.

        Args:
            db: Database session.
            task: The task object to work on.

        Returns:
            A ``(result, token_usage)`` pair:
            - result: the data to be stored on the task
            - token_usage: token usage information used for points deduction
        """
|
|
|
|
|
|
class ImageTaskService:
    """Runs image tasks through one shared lifecycle around a pluggable processor."""

    # Tokens covered by a single billing unit (one unit costs LLM_CHAT_COST).
    _TOKENS_PER_UNIT = 1000

    @staticmethod
    def _extract_total_tokens(token_usage: Any) -> int:
        """
        Return the total token count reported by a processor.

        Two shapes are understood: the direct ``{"total_tokens": n}`` and the
        nested legacy ``{"token_usage": {"total_tokens": n}}``. Anything that
        is not a dict (at either level), or a missing/empty count, yields 0.

        Raises:
            ValueError / TypeError: if the stored count cannot be converted
                with ``int()``.
        """
        if not isinstance(token_usage, dict):
            return 0
        if "total_tokens" in token_usage:
            return int(token_usage.get("total_tokens") or 0)
        nested = token_usage.get("token_usage")
        # A malformed nested value is treated like "no usage reported",
        # the same way a non-dict top-level value is.
        if not isinstance(nested, dict):
            return 0
        return int(nested.get("total_tokens") or 0)

    @classmethod
    def _billing_units(cls, total_tokens: int) -> int:
        """
        Return how many billing units to charge for ``total_tokens``.

        A non-positive count is charged as one unit; otherwise the count is
        rounded up to whole blocks of ``_TOKENS_PER_UNIT`` tokens.
        """
        if total_tokens <= 0:
            return 1
        return math.ceil(total_tokens / cls._TOKENS_PER_UNIT)

    async def _mark_failed(self, task_id: int, error_message: str) -> None:
        """
        Best-effort: record FAILED status and the error message for a task.

        Opens its own session via ``background_db_session``. Any error raised
        here is logged and swallowed so the caller's cleanup still runs.
        """
        try:
            async with background_db_session() as db:
                await update_task_status_with_retry(
                    db, task_id, ImageTaskStatus.FAILED,
                    error_message=error_message
                )
                await db.commit()
        except Exception as mark_err:
            logger.error(f"Failed to mark task {task_id} as FAILED: {mark_err}")

    async def _release_slot(self, user_id: int) -> None:
        """Best-effort: release the user's task slot; errors are logged, not raised."""
        try:
            await rate_limit_service.release_task_slot(user_id)
        except Exception as release_err:
            logger.error(f"Failed to release task slot for user {user_id}: {release_err}")

    async def process_task(self, task_id: int, user_id: int, processor: "TaskProcessor"):
        """
        Generic method to process an image task with standard lifecycle management:
        1. Status update (PROCESSING)
        2. Business logic execution (via processor)
        3. Points deduction
        4. Status update (COMPLETED/FAILED)
        5. Task slot release

        Args:
            task_id: Id of the task to load and process.
            user_id: User whose task slot is released when processing ends.
                Points are deducted from ``task.user_id``, not this value.
            processor: Supplies the task-specific logic; its ``process`` must
                return ``(result, token_usage)``.

        Returns:
            None. Exceptions derived from ``Exception`` are caught, logged and
            recorded on the task as FAILED instead of being propagated.
        """
        try:
            async with background_db_session() as db:
                task = await image_task_dao.get(db, task_id)
                if not task:
                    logger.warning(f"Task {task_id} not found during processing")
                    return

                # NOTE(review): whether PROCESSING becomes visible to other
                # sessions before the final commit depends on
                # update_task_status - confirm.
                await image_task_dao.update_task_status(db, task_id, ImageTaskStatus.PROCESSING)

                # Execute specific business logic; the processor returns the
                # final result dict and the token usage info.
                result, token_usage = await processor.process(db, task)

                # Calculate and deduct points.
                total_tokens = self._extract_total_tokens(token_usage)
                deduct_amount = self._billing_units(total_tokens) * LLM_CHAT_COST

                # Use ref_id as the related_id for the points record.
                points_deducted = await points_service.deduct_points_with_db(
                    user_id=task.user_id,
                    amount=deduct_amount,
                    db=db,
                    related_id=task.ref_id,
                    details={
                        "task_id": task_id,
                        "ref_type": task.ref_type,
                        "token_usage": total_tokens
                    },
                    action=task.ref_type
                )
                if not points_deducted:
                    raise RuntimeError("Failed to deduct points")

                # Keep the usage info alongside the result unless the
                # processor already put it there.
                if isinstance(result, dict) and 'token_usage' not in result:
                    result['token_usage'] = token_usage

                await update_task_status_with_retry(
                    db, task_id, ImageTaskStatus.COMPLETED,
                    result=result
                )
                await db.commit()

        except Exception as e:
            logger.error(f"Error processing task {task_id}: {str(e)}")
            # The session opened above has already exited here, so the
            # failure is recorded through a separate one.
            await self._mark_failed(task_id, str(e))
        finally:
            # Runs on every path, including the early "task not found" return.
            await self._release_slot(user_id)
|
|
|
|
|
|
# Module-level ImageTaskService instance, created when this module is imported.
image_task_service = ImageTaskService()
|