fix code
This commit is contained in:
@@ -629,7 +629,7 @@ class DictService:
|
||||
image = await image_dao.get(db, task.image_id)
|
||||
if not image:
|
||||
return
|
||||
await db.commit()
|
||||
await db.flush()
|
||||
|
||||
# 检查图片是否有识别结果
|
||||
if not image.details or "recognition_result" not in image.details:
|
||||
@@ -673,7 +673,7 @@ class DictService:
|
||||
image = await image_dao.get(db, task.image_id)
|
||||
if not image:
|
||||
return
|
||||
await db.commit()
|
||||
await db.flush()
|
||||
|
||||
# 检查图片是否有识别结果
|
||||
if not image.details or "recognition_result" not in image.details:
|
||||
@@ -792,7 +792,7 @@ class DictService:
|
||||
db, image.id,
|
||||
UpdateImageParam(details=image.details)
|
||||
)
|
||||
await db.commit()
|
||||
await db.flush()
|
||||
|
||||
@staticmethod
|
||||
async def _build_desc_ipa_for_recognition_result_with_db(image: Image, recognition_result: Dict[str, Any], word_phonetics: Dict[str, str], db: AsyncSession) -> None:
|
||||
@@ -853,6 +853,6 @@ class DictService:
|
||||
db, image.id,
|
||||
UpdateImageParam(details=image.details)
|
||||
)
|
||||
await db.commit()
|
||||
await db.flush()
|
||||
|
||||
dict_service = DictService()
|
||||
dict_service = DictService()
|
||||
|
||||
@@ -15,6 +15,11 @@ class ImageCRUD(CRUDPlus[Image]):
|
||||
async def get(self, db: AsyncSession, id: int) -> Image | None:
|
||||
return await self.select_model(db, id)
|
||||
|
||||
async def get_for_update(self, db: AsyncSession, id: int) -> Image | None:
|
||||
stmt = select(Image).where(Image.id == id).with_for_update()
|
||||
result = await db.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_file_id(self, db: AsyncSession, file_id: int) -> Image | None:
|
||||
return await self.select_model_by_column(db, file_id=file_id)
|
||||
|
||||
|
||||
@@ -16,6 +16,11 @@ class ImageTaskCRUD(CRUDPlus[ImageProcessingTask]):
|
||||
async def get(self, db: AsyncSession, id: int) -> ImageProcessingTask | None:
|
||||
return await self.select_model(db, id)
|
||||
|
||||
async def get_for_update(self, db: AsyncSession, id: int) -> ImageProcessingTask | None:
|
||||
stmt = select(self.model).where(self.model.id == id).with_for_update()
|
||||
result = await db.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_image_id(self, db: AsyncSession, image_id: int) -> ImageProcessingTask | None:
|
||||
"""
|
||||
根据图片ID获取处理任务
|
||||
@@ -72,6 +77,24 @@ class ImageTaskCRUD(CRUDPlus[ImageProcessingTask]):
|
||||
result = await db.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_processing_tasks_before_time(self, db: AsyncSession, before_time: datetime, limit: int = 100) -> List[ImageProcessingTask]:
|
||||
"""
|
||||
获取指定时间之前处于processing状态的任务(用于重启恢复)
|
||||
|
||||
:param db: 数据库会话
|
||||
:param before_time: 时间界限
|
||||
:param limit: 限制返回的任务数量
|
||||
:return: 任务列表
|
||||
"""
|
||||
stmt = select(ImageProcessingTask).where(
|
||||
and_(
|
||||
ImageProcessingTask.status == ImageTaskStatus.PROCESSING,
|
||||
ImageProcessingTask.created_time < before_time
|
||||
)
|
||||
).limit(limit)
|
||||
result = await db.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
async def count_tasks_by_status(self, db: AsyncSession, user_id: int, statuses: List[ImageTaskStatus]) -> int:
|
||||
"""
|
||||
根据状态统计用户任务数量
|
||||
|
||||
@@ -398,13 +398,25 @@ class ImageService:
|
||||
task_processing_success = False
|
||||
points_deducted = False
|
||||
try:
|
||||
# Step 1: Execute image recognition (includes external API call)
|
||||
await ImageService._process_image_recognition(task_id)
|
||||
# Step 1: Execute image recognition only if not already done
|
||||
need_recognition = True
|
||||
try:
|
||||
async with background_db_session() as db_check:
|
||||
task_check = await image_task_dao.get(db_check, task_id)
|
||||
if task_check and task_check.result:
|
||||
need_recognition = False
|
||||
except Exception:
|
||||
pass
|
||||
if need_recognition:
|
||||
await ImageService._process_image_recognition(task_id)
|
||||
|
||||
# Step 2: Process all remaining steps in a single database transaction for consistency
|
||||
async with background_db_session() as db:
|
||||
await db.begin()
|
||||
try:
|
||||
task_locked = await image_task_dao.get_for_update(db, task_id)
|
||||
if task_locked:
|
||||
await image_dao.get_for_update(db, task_locked.image_id)
|
||||
# Step 2: Process lookup word
|
||||
await dict_service.process_lookup_word_with_db(task_id, db)
|
||||
|
||||
@@ -434,9 +446,8 @@ class ImageService:
|
||||
# Step 5: Update task status to completed
|
||||
await ImageService._update_task_status_with_db(task_id, ImageTaskStatus.COMPLETED, db)
|
||||
|
||||
# All steps completed successfully
|
||||
task_processing_success = True
|
||||
await db.commit()
|
||||
task_processing_success = True
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
|
||||
@@ -240,8 +240,7 @@ class ImageTextService:
|
||||
created_texts.append(new_text)
|
||||
newly_created_texts.append(new_text)
|
||||
|
||||
# 提交事务
|
||||
await db.commit()
|
||||
await db.flush()
|
||||
|
||||
# 刷新创建的文本记录
|
||||
for text in created_texts:
|
||||
|
||||
@@ -34,13 +34,14 @@ async def process_pending_tasks():
|
||||
async with background_db_session() as db:
|
||||
# 获取所有启动前创建的待处理的任务
|
||||
pending_tasks = await image_task_dao.get_pending_tasks_before_time(db, startup_time, limit=100)
|
||||
processing_tasks = await image_task_dao.get_processing_tasks(db, limit=100)
|
||||
# 仅恢复在启动前已处于processing的任务,避免正常运行时重复触发
|
||||
processing_tasks = await image_task_dao.get_processing_tasks_before_time(db, startup_time, limit=100)
|
||||
|
||||
# 处理待处理的任务
|
||||
for task in pending_tasks:
|
||||
logger.info(f"Processing pending task {task.id} (created at {task.created_time})")
|
||||
# 创建包装任务来处理完成和异常情况
|
||||
task_wrapper = create_task_wrapper(image_service._process_image_recognition, task.id)
|
||||
task_wrapper = create_task_wrapper(image_service._process_image_with_limiting, task.id, task.user_id)
|
||||
asyncio.create_task(task_wrapper)
|
||||
|
||||
# 处理之前标记为处理中但可能因服务器重启而中断的任务
|
||||
@@ -56,7 +57,7 @@ async def process_pending_tasks():
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to release slot for interrupted task {task.id}: {str(e)}")
|
||||
# 创建包装任务来处理完成和异常情况
|
||||
task_wrapper = create_task_wrapper(image_service._process_image_recognition, task.id)
|
||||
task_wrapper = create_task_wrapper(image_service._process_image_with_limiting, task.id, task.user_id)
|
||||
asyncio.create_task(task_wrapper)
|
||||
|
||||
# If we reach here, the operation was successful
|
||||
@@ -76,12 +77,12 @@ async def process_pending_tasks():
|
||||
raise
|
||||
|
||||
|
||||
async def create_task_wrapper(process_func, task_id):
|
||||
async def create_task_wrapper(process_func, task_id, user_id):
|
||||
"""
|
||||
创建一个包装任务,确保即使发生异常也能正确更新任务状态
|
||||
"""
|
||||
try:
|
||||
await process_func(task_id)
|
||||
await process_func(task_id, user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing task {task_id}: {str(e)}")
|
||||
# 尝试更新任务状态为失败
|
||||
|
||||
Reference in New Issue
Block a user