fix code
This commit is contained in:
@@ -2,6 +2,8 @@ import json
|
||||
from typing import Optional, Set, Dict, Any, Tuple, List
|
||||
import asyncio
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.app.admin.schema.dict import (
|
||||
DictWordResponse, SimpleDictEntry, SimpleSense, SimpleDefinition, SimpleExample,
|
||||
SimpleCrossReference, SimpleFrequency, SimplePronunciation, WordMetaData
|
||||
@@ -651,6 +653,50 @@ class DictService:
|
||||
# 获取单词音标信息并构建desc_ipa
|
||||
await DictService._build_desc_ipa_for_recognition_result(image, recognition_result, word_phonetics)
|
||||
|
||||
@staticmethod
async def process_lookup_word_with_db(task_id: int, db: AsyncSession) -> None:
    """
    Look up the words of an image's recognition result and build its IPA data.

    Variant of the word-lookup step (migrated from ImageService) that works on
    a database session supplied by the caller instead of opening its own.

    The method returns early, without raising, when the task or image cannot
    be found, when the image has no recognition result, or when no words can
    be extracted from it.

    Args:
        task_id: ID of the image-processing task.
        db: Database session provided by the caller.
    """
    # Load the task.
    task = await image_task_dao.get(db, task_id)
    if not task:
        logger.error(f"Task {task_id} not found")
        return

    # Load the image the task refers to.
    image = await image_dao.get(db, task.image_id)
    if not image:
        # Log the reason, like the missing-task branch above, so an aborted
        # lookup can be diagnosed from the logs.
        logger.error(f"Image {task.image_id} not found for task {task_id}")
        return

    # NOTE(review): this commits on the caller-provided session before the
    # word lookup starts. Presumably it is meant to end the current
    # transaction before the potentially slow lookup below -- confirm this is
    # intended when the caller expects a single enclosing transaction.
    await db.commit()

    # Nothing to do when the image carries no recognition result.
    if not image.details or "recognition_result" not in image.details:
        logger.info(f"No recognition result found for image {image.id}")
        return

    recognition_result = image.details["recognition_result"]

    # Collect the words to look up.
    words = DictService._extract_words_from_recognition_result(recognition_result)
    if not words:
        logger.info("No words extracted from recognition result")
        return

    logger.info(f"Extracted {len(words)} unique words for lookup")

    # Resolve the word -> phonetic mapping through the concurrent lookup helper.
    word_phonetics = await DictService._process_words_concurrently(words)

    # Build desc_ipa for the recognition result from the phonetics.
    await DictService._build_desc_ipa_for_recognition_result_with_db(image, recognition_result, word_phonetics, db)
|
||||
|
||||
@staticmethod
|
||||
def _extract_words_from_recognition_result(recognition_result: Dict[str, Any]) -> Set[str]:
|
||||
"""
|
||||
@@ -748,4 +794,65 @@ class DictService:
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
@staticmethod
async def _build_desc_ipa_for_recognition_result_with_db(image: Image, recognition_result: Dict[str, Any], word_phonetics: Dict[str, str], db: AsyncSession) -> None:
    """
    Fill in the ``desc_ipa`` list of every level of a recognition result.

    Works on the database session supplied by the caller. Each level's
    ``desc_ipa`` is rebuilt in place so that it lines up, entry by entry,
    with that level's ``desc_en`` list.

    Args:
        image: Image object whose details are updated.
        recognition_result: Recognition result; mutated in place.
        word_phonetics: Mapping from a word to its phonetic transcription.
        db: Database session provided by the caller.
    """

    def transcribe(sentence):
        """Return the space-joined phonetics for one English sentence."""
        pieces = []
        # Words are delimited by whitespace.
        for token in sentence.split():
            # Normalise the token: keep alphanumeric characters only, lowercased.
            key = ''.join(filter(str.isalnum, token)).lower()
            if not key or key not in word_phonetics:
                # Placeholder for a word without a known transcription.
                pieces.append("*")
                continue
            ipa = word_phonetics[key]
            # When several variants are separated by ';', keep the first one.
            if ';' in ipa:
                ipa = ipa.split(';')[0].strip()
            pieces.append(ipa)
        return " ".join(pieces)

    for level_data in recognition_result.values():
        # Only dict-shaped levels with a list-valued desc_en are processed.
        if not isinstance(level_data, dict):
            continue
        sentences = level_data.get("desc_en", [])
        if not isinstance(sentences, list):
            continue

        # Non-string entries get an empty string so the positions stay aligned.
        level_data["desc_ipa"] = [
            transcribe(sentence) if isinstance(sentence, str) else ""
            for sentence in sentences
        ]

    # Persist the updated recognition result.
    if image.details:
        image.details["recognition_result"] = recognition_result
        new_details = UpdateImageParam(details=image.details)
        await image_dao.update(db, image.id, new_details)
        await db.commit()
|
||||
|
||||
dict_service = DictService()
|
||||
@@ -146,7 +146,10 @@ class ImageTaskCRUD(CRUDPlus[ImageProcessingTask]):
|
||||
# 如果任务状态从进行中变为完成或失败,更新用户任务计数缓存
|
||||
if task and task.status in [ImageTaskStatus.PENDING, ImageTaskStatus.PROCESSING] and \
|
||||
status in [ImageTaskStatus.COMPLETED, ImageTaskStatus.FAILED]:
|
||||
await rate_limit_service.decrement_user_task_count(task.user_id)
|
||||
try:
|
||||
await rate_limit_service.decrement_user_task_count(task.user_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return db_result.rowcount
|
||||
|
||||
@@ -195,4 +198,4 @@ class ImageTaskCRUD(CRUDPlus[ImageProcessingTask]):
|
||||
raise e
|
||||
|
||||
|
||||
image_task_dao: ImageTaskCRUD = ImageTaskCRUD(ImageProcessingTask)
|
||||
image_task_dao: ImageTaskCRUD = ImageTaskCRUD(ImageProcessingTask)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Optional, List
|
||||
from sqlalchemy import select, and_
|
||||
from sqlalchemy import select, and_, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy_crud_plus import CRUDPlus
|
||||
|
||||
@@ -30,6 +30,25 @@ class ImageTextCRUD(CRUDPlus[ImageText]):
|
||||
result = await db.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_image_id_and_level(self, db: AsyncSession, image_id: int, dict_level: str) -> List[ImageText]:
    """Return the text records of an image that belong to a dictionary level.

    A truthy ``dict_level`` is matched case-insensitively; a falsy one
    selects the records whose ``dict_level`` is NULL.
    """
    if dict_level:
        # Compare both sides lowercased for a case-insensitive match.
        level_filter = func.lower(self.model.dict_level) == func.lower(dict_level)
    else:
        level_filter = self.model.dict_level.is_(None)

    query = select(self.model).where(
        and_(self.model.image_id == image_id, level_filter)
    )
    rows = await db.execute(query)
    return list(rows.scalars().all())
|
||||
|
||||
async def get_by_standard_audio_id(self, db: AsyncSession, standard_audio_id: int) -> Optional[ImageText]:
|
||||
"""根据标准音频文件ID获取文本记录"""
|
||||
stmt = select(self.model).where(self.model.standard_audio_id == standard_audio_id)
|
||||
@@ -45,4 +64,4 @@ class ImageTextCRUD(CRUDPlus[ImageText]):
|
||||
return await self.delete_model(db, id)
|
||||
|
||||
|
||||
image_text_dao: ImageTextCRUD = ImageTextCRUD(ImageText)
|
||||
image_text_dao: ImageTextCRUD = ImageTextCRUD(ImageText)
|
||||
|
||||
@@ -21,7 +21,7 @@ class ImageText(Base):
|
||||
image_id: Mapped[Optional[int]] = mapped_column(BigInteger, ForeignKey('image.id'), nullable=True, comment="关联的图片ID")
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False, comment="文本内容")
|
||||
standard_audio_id: Mapped[Optional[int]] = mapped_column(BigInteger, ForeignKey('file.id'), nullable=True, comment="标准朗读音频文件ID")
|
||||
ipa: Mapped[Optional[str]] = mapped_column(String(100), default=None, comment="ipa")
|
||||
ipa: Mapped[Optional[str]] = mapped_column(String(1000), default=None, comment="ipa")
|
||||
zh: Mapped[Optional[str]] = mapped_column(String(100), default=None, comment="中文")
|
||||
position: Mapped[Optional[dict]] = mapped_column(MySQLJSON, default=None, comment="文本在图片中的位置信息或文章中的位置信息")
|
||||
dict_level: Mapped[Optional[str]] = mapped_column(String(20), default=None, comment="词典等级")
|
||||
|
||||
@@ -395,51 +395,139 @@ class ImageService:
|
||||
@staticmethod
|
||||
async def _process_image_with_limiting(task_id: int, user_id: int) -> None:
|
||||
"""带限流控制的后台处理图片识别任务"""
|
||||
task_processing_success = False
|
||||
points_deducted = False
|
||||
try:
|
||||
# 执行图片处理任务
|
||||
await ImageService._process_image(task_id)
|
||||
# 任务成功完成后更新状态
|
||||
await ImageService._update_task_status(task_id, ImageTaskStatus.COMPLETED)
|
||||
# Step 1: Execute image recognition (includes external API call)
|
||||
await ImageService._process_image_recognition(task_id)
|
||||
|
||||
# Step 2: Process all remaining steps in a single database transaction for consistency
|
||||
async with background_db_session() as db:
|
||||
await db.begin()
|
||||
try:
|
||||
# Step 2: Process lookup word
|
||||
await dict_service.process_lookup_word_with_db(task_id, db)
|
||||
|
||||
# Step 3: Initialize image text
|
||||
from backend.app.ai import ImageTextService
|
||||
result = await ImageTextService.init_image_text_by_task_with_db(task_id, db)
|
||||
if not result:
|
||||
raise Exception("Failed to initialize image text")
|
||||
|
||||
# Step 4: Deduct user points
|
||||
task = await image_task_dao.get(db, task_id)
|
||||
if task:
|
||||
image = await image_dao.get(db, task.image_id)
|
||||
if image:
|
||||
points_deducted = await points_service.deduct_points_with_db(
|
||||
user_id=task.user_id,
|
||||
amount=IMAGE_RECOGNITION_COST,
|
||||
db=db,
|
||||
related_id=image.id,
|
||||
details={"task_id": task_id},
|
||||
action=POINTS_ACTION_IMAGE_RECOGNITION
|
||||
)
|
||||
if not points_deducted:
|
||||
logger.error(f"Failed to deduct points for user {task.user_id} for task {task_id}")
|
||||
raise Exception("Failed to deduct points")
|
||||
|
||||
# Step 5: Update task status to completed
|
||||
await ImageService._update_task_status_with_db(task_id, ImageTaskStatus.COMPLETED, db)
|
||||
|
||||
# All steps completed successfully
|
||||
task_processing_success = True
|
||||
await db.commit()
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing image task {task_id}: {str(e)}")
|
||||
# 任务失败时更新状态
|
||||
await ImageService._update_task_status(task_id, ImageTaskStatus.FAILED, str(e))
|
||||
|
||||
# Try to compensate for partial completion
|
||||
try:
|
||||
async with background_db_session() as db:
|
||||
await db.begin()
|
||||
try:
|
||||
# If points were deducted but task failed, try to refund points
|
||||
if points_deducted and task_processing_success:
|
||||
task = await image_task_dao.get(db, task_id)
|
||||
if task:
|
||||
image = await image_dao.get(db, task.image_id)
|
||||
if image:
|
||||
# Try to refund points by adding them back
|
||||
refund_success = await points_service.add_points(
|
||||
user_id=task.user_id,
|
||||
amount=IMAGE_RECOGNITION_COST,
|
||||
related_id=image.id,
|
||||
details={"task_id": task_id, "refund_for": "task_failure"},
|
||||
action=POINTS_ACTION_IMAGE_RECOGNITION + "_REFUND"
|
||||
)
|
||||
if not refund_success:
|
||||
logger.error(f"Failed to refund points for user {task.user_id} for task {task_id}")
|
||||
|
||||
# Update task status to failed
|
||||
await ImageService._update_task_status_with_db(task_id, ImageTaskStatus.FAILED, db, str(e))
|
||||
await db.commit()
|
||||
except Exception as compensation_error:
|
||||
await db.rollback()
|
||||
logger.error(f"Failed to compensate for task {task_id} failure: {str(compensation_error)}")
|
||||
# Try to update task status with separate connection as last resort
|
||||
try:
|
||||
await ImageService._update_task_status(task_id, ImageTaskStatus.FAILED, str(e))
|
||||
except Exception as final_error:
|
||||
logger.error(f"Failed to update task {task_id} status to FAILED: {str(final_error)}")
|
||||
except Exception as compensation_setup_error:
|
||||
logger.error(f"Failed to setup compensation for task {task_id}: {str(compensation_setup_error)}")
|
||||
# Try to update task status with separate connection as last resort
|
||||
try:
|
||||
await ImageService._update_task_status(task_id, ImageTaskStatus.FAILED, str(e))
|
||||
except Exception as final_error:
|
||||
logger.error(f"Failed to update task {task_id} status to FAILED: {str(final_error)}")
|
||||
finally:
|
||||
# 释放槽位
|
||||
# Release slot
|
||||
try:
|
||||
await rate_limit_service.release_task_slot(user_id)
|
||||
except Exception as slot_error:
|
||||
logger.error(f"Failed to release task slot for user {user_id}: {str(slot_error)}")
|
||||
|
||||
@staticmethod
|
||||
async def _update_task_status(task_id: int, status: ImageTaskStatus, error_message: str = None) -> None:
|
||||
"""更新任务状态"""
|
||||
async def _update_task_status_with_db(task_id: int, status: ImageTaskStatus, db: AsyncSession, error_message: str = None) -> None:
|
||||
"""使用提供的数据库连接更新任务状态"""
|
||||
try:
|
||||
async with background_db_session.begin() as db:
|
||||
task = await image_task_dao.get(db, task_id)
|
||||
if not task:
|
||||
logger.warning(f"Task {task_id} not found when updating status")
|
||||
return
|
||||
|
||||
# 只有当任务不是最终状态时才更新
|
||||
if task.status not in [ImageTaskStatus.COMPLETED, ImageTaskStatus.FAILED]:
|
||||
update_data = {"status": status}
|
||||
if error_message:
|
||||
update_data["error_message"] = error_message
|
||||
|
||||
await image_task_dao.update(db, task_id, update_data)
|
||||
await db.commit()
|
||||
from backend.app.ai.tasks import update_task_status_with_retry
|
||||
await update_task_status_with_retry(
|
||||
db, task_id, status, error_message=error_message
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update task {task_id} status to {status}: {str(e)}")
|
||||
raise
|
||||
|
||||
@staticmethod
async def _update_task_status(task_id: int, status: ImageTaskStatus, error_message: str = None) -> None:
    """
    Update a task's status using a session opened by this method.

    Kept for compatibility with callers that have no database session of
    their own.

    Args:
        task_id: ID of the image-processing task.
        status: Status to store on the task.
        error_message: Optional error text stored together with the status.

    Raises:
        Exception: Any failure is logged and re-raised to the caller.
    """
    try:
        from backend.app.ai.tasks import update_task_status_with_retry

        async with background_db_session() as db:
            await update_task_status_with_retry(
                db, task_id, status, error_message=error_message
            )
            # The session is opened without a begin() block, so the update is
            # only persisted by an explicit commit. The other call sites in
            # this change commit right after update_task_status_with_retry,
            # which presumably does not commit on its own -- verify.
            await db.commit()
    except Exception as e:
        logger.error(f"Failed to update task {task_id} status to {status}: {str(e)}")
        raise
|
||||
|
||||
@staticmethod
|
||||
async def _process_image(task_id: int) -> None:
|
||||
await ImageService._process_image_recognition(task_id)
|
||||
await dict_service.process_lookup_word(task_id)
|
||||
# This method is no longer used as we've moved the logic to _process_image_with_limiting
|
||||
# Keeping it for backward compatibility
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
async def _process_image_recognition(task_id: int) -> None:
|
||||
"""后台处理图片识别任务"""
|
||||
"""后台处理图片识别任务 - compatible version for task processor"""
|
||||
# This is maintained for backward compatibility with the task processor
|
||||
# It creates its own database connection like the original implementation
|
||||
from backend.app.ai.tasks import update_task_status_with_retry, increment_retry_count_with_retry
|
||||
|
||||
max_retries = 5
|
||||
@@ -459,6 +547,7 @@ class ImageService:
|
||||
db, task_id, ImageTaskStatus.FAILED,
|
||||
error_message="Image not found"
|
||||
)
|
||||
await db.commit()
|
||||
return
|
||||
|
||||
# 获取exclude_words
|
||||
@@ -470,9 +559,11 @@ class ImageService:
|
||||
for word in section.get('ref_word', [])
|
||||
if isinstance(section.get('ref_word'), list)
|
||||
])
|
||||
|
||||
# 提交当前事务,释放锁
|
||||
await db.commit()
|
||||
|
||||
# 下载文件
|
||||
# 下载文件(在数据库事务外执行)
|
||||
file_content, file_name, content_type = await file_service.download_file(task.file_id)
|
||||
image_format = image_service.detect_image_format(file_content)
|
||||
image_format_str = image_format.value
|
||||
@@ -495,8 +586,10 @@ class ImageService:
|
||||
recognize_response = await Qwen.recognize_image(recognize_params)
|
||||
recognition_result = recognize_response.get("result").strip().replace("```json", "").replace("```", "").strip()
|
||||
|
||||
async with background_db_session.begin() as db:
|
||||
|
||||
# 使用新的数据库会话处理API响应
|
||||
async with background_db_session() as db:
|
||||
await db.begin()
|
||||
|
||||
if recognize_response.get("error"):
|
||||
# 增加重试次数
|
||||
await increment_retry_count_with_retry(db, task_id)
|
||||
@@ -513,6 +606,7 @@ class ImageService:
|
||||
await update_task_status_with_retry(
|
||||
db, task_id, ImageTaskStatus.PENDING
|
||||
)
|
||||
await db.commit()
|
||||
return
|
||||
|
||||
# Improve JSON parsing with better error handling
|
||||
@@ -548,6 +642,7 @@ class ImageService:
|
||||
await update_task_status_with_retry(
|
||||
db, task_id, ImageTaskStatus.PENDING
|
||||
)
|
||||
await db.commit()
|
||||
return
|
||||
|
||||
# Transform the data structure from array of objects to grouped arrays
|
||||
@@ -576,22 +671,13 @@ class ImageService:
|
||||
)
|
||||
)
|
||||
|
||||
# 更新任务状态为完成
|
||||
# 更新任务状态为处理中(不扣减积分)
|
||||
await update_task_status_with_retry(
|
||||
db, task_id, ImageTaskStatus.PROCESSING,
|
||||
result=transformed_result
|
||||
)
|
||||
|
||||
# 扣减用户积分
|
||||
points_deducted = await points_service.deduct_points_with_db(
|
||||
user_id=task.user_id,
|
||||
amount=IMAGE_RECOGNITION_COST, # 扣减图片识别费用
|
||||
db=db,
|
||||
related_id=image.id,
|
||||
details={"task_id": task_id},
|
||||
action=POINTS_ACTION_IMAGE_RECOGNITION
|
||||
)
|
||||
|
||||
# 提交事务
|
||||
await db.commit()
|
||||
|
||||
# If we reach here, the operation was successful
|
||||
@@ -599,22 +685,69 @@ class ImageService:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing image recognition task {task_id}: {str(e)}")
|
||||
async with background_db_session.begin() as db:
|
||||
# 增加重试次数
|
||||
await increment_retry_count_with_retry(db, task_id)
|
||||
task = await image_task_dao.get(db, task_id)
|
||||
# 如果重试次数超过限制,标记为失败
|
||||
if task and task.retry_count >= max_retries:
|
||||
await update_task_status_with_retry(
|
||||
db, task_id, ImageTaskStatus.FAILED,
|
||||
error_message=str(e)
|
||||
)
|
||||
else:
|
||||
# 重置为待处理状态以便重试
|
||||
await update_task_status_with_retry(
|
||||
db, task_id, ImageTaskStatus.PENDING
|
||||
)
|
||||
await db.commit()
|
||||
# Handle database operations in a separate, isolated session to avoid transaction conflicts
|
||||
try:
|
||||
async with background_db_session() as db:
|
||||
await db.begin()
|
||||
# 增加重试次数
|
||||
await increment_retry_count_with_retry(db, task_id)
|
||||
task = await image_task_dao.get(db, task_id)
|
||||
# 如果重试次数超过限制,标记为失败
|
||||
if task and task.retry_count >= max_retries:
|
||||
await update_task_status_with_retry(
|
||||
db, task_id, ImageTaskStatus.FAILED,
|
||||
error_message=str(e)
|
||||
)
|
||||
else:
|
||||
# 重置为待处理状态以便重试
|
||||
await update_task_status_with_retry(
|
||||
db, task_id, ImageTaskStatus.PENDING
|
||||
)
|
||||
await db.commit()
|
||||
except Exception as db_error:
|
||||
logger.error(f"Failed to update task {task_id} after error: {str(db_error)}")
|
||||
|
||||
@staticmethod
async def _process_image_recognition_with_db(task_id: int, db: AsyncSession) -> None:
    """
    Run the database part of image recognition on a caller-provided session.

    This method does not call the external recognition API and does not
    deduct points. It loads the task and its image, marks the task FAILED
    when the image is missing, and flushes pending changes without committing.

    Args:
        task_id: ID of the image-processing task.
        db: Database session provided by the caller.

    Raises:
        Exception: If the task does not exist.
    """
    from backend.app.ai.tasks import update_task_status_with_retry

    # Load the task.
    task = await image_task_dao.get(db, task_id)
    if not task:
        logger.error(f"Task {task_id} not found")
        raise Exception(f"Task {task_id} not found")

    # Load the image; without it the task cannot proceed.
    image = await image_dao.get(db, task.image_id)
    if not image:
        await update_task_status_with_retry(
            db, task_id, ImageTaskStatus.FAILED,
            error_message="Image not found"
        )
        return

    # Collect the words already referenced by a previous recognition result.
    # NOTE(review): exclude_words is gathered but not consumed anywhere in
    # this method -- confirm whether it should be returned or passed on.
    exclude_words = []
    if image.details and "recognition_result" in image.details:
        recognition_result = image.details["recognition_result"]
        if isinstance(recognition_result, dict):
            for section in recognition_result.values():
                # Skip malformed sections instead of raising on them.
                if not isinstance(section, dict):
                    continue
                ref_words = section.get('ref_word')
                if isinstance(ref_words, list):
                    exclude_words.extend(ref_words)

    # Only flush: the caller owns the enclosing transaction, so it must not
    # be committed from here.
    await db.flush()

    # No external API call is made here, because this runs inside the
    # caller's database transaction; such calls belong outside of it. The
    # caller is expected to have finished the external recognition call and
    # stored its result in the database before invoking this method.
|
||||
|
||||
@staticmethod
|
||||
async def get_task_status(task_id: int) -> dict:
|
||||
@@ -872,4 +1005,4 @@ class ImageService:
|
||||
return image
|
||||
|
||||
|
||||
image_service: ImageService = ImageService()
|
||||
image_service: ImageService = ImageService()
|
||||
|
||||
@@ -4,6 +4,8 @@ import logging
|
||||
import asyncio
|
||||
from typing import Optional, List
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.app.ai.crud import recording_dao
|
||||
from backend.app.ai.crud.image_text_crud import image_text_dao
|
||||
from backend.app.ai.model.image_text import ImageText
|
||||
@@ -108,18 +110,60 @@ class ImageTextService:
|
||||
:return: 图片文本记录列表
|
||||
"""
|
||||
async with async_db_session() as db:
|
||||
# First, check if there's an ImageProcessingTask for this image_id
|
||||
# and wait until it reaches a final state (COMPLETED or FAILED)
|
||||
image_task = await image_task_dao.get_by_image_id(db, image_id)
|
||||
if image_task:
|
||||
while image_task.status not in [ImageTaskStatus.COMPLETED, ImageTaskStatus.FAILED]:
|
||||
# Wait for 1 second before checking again
|
||||
await asyncio.sleep(1)
|
||||
# Refresh the task status
|
||||
await db.refresh(image_task)
|
||||
task = await image_task_dao.get_by_image_id(db, image_id)
|
||||
if not task or task.user_id != user_id:
|
||||
raise errors.ForbiddenError(msg="Forbidden")
|
||||
|
||||
if image_task.status == ImageTaskStatus.FAILED:
|
||||
raise errors.ServerError(msg="Image task failed")
|
||||
image = await image_dao.get(db, image_id)
|
||||
if not image:
|
||||
raise ValueError(f"Image with id {image_id} not found")
|
||||
|
||||
filtered_texts = await image_text_dao.get_by_image_id_and_level(db, image_id, dict_level)
|
||||
|
||||
assessments = []
|
||||
for text in filtered_texts:
|
||||
latest_recording_details = None
|
||||
latest_recording_file = None
|
||||
latest_recording = await recording_dao.get_latest_by_text_id(db, text.id)
|
||||
if latest_recording:
|
||||
latest_recording_details = latest_recording.details
|
||||
latest_recording_file = str(latest_recording.file_id)
|
||||
assessment = ImageTextAssessmentSchema(
|
||||
id=str(text.id),
|
||||
ipa=text.ipa.replace('/', '') if text.ipa else '',
|
||||
zh=text.zh,
|
||||
content=text.content,
|
||||
file_id=latest_recording_file,
|
||||
details=latest_recording_details,
|
||||
)
|
||||
assessments.append(assessment)
|
||||
|
||||
return ImageTextInitResponseSchema(
|
||||
image_file_id=str(image.file_id),
|
||||
assessments=assessments
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def init_image_text_by_task(task_id: int) -> bool:
|
||||
"""
|
||||
初始化图片文本记录
|
||||
根据dict_level从image的recognition_result中提取文本,如果不存在则创建,如果已存在则直接返回
|
||||
|
||||
:param user_id: 用户ID
|
||||
:param image_id: 图片ID
|
||||
:param dict_level: 词典等级
|
||||
:param background_tasks: 后台任务对象,用于异步生成标准发音音频
|
||||
:return: 图片文本记录列表
|
||||
"""
|
||||
async with async_db_session() as db:
|
||||
image_task = await image_task_dao.get(db, task_id)
|
||||
if not image_task:
|
||||
logger.error(f"Task {task_id} not found")
|
||||
return False
|
||||
|
||||
image_id = image_task.image_id
|
||||
user_id = image_task.user_id
|
||||
|
||||
# 获取图片记录
|
||||
image = await image_dao.get(db, image_id)
|
||||
@@ -131,52 +175,44 @@ class ImageTextService:
|
||||
raise ValueError("Image recognition result not found")
|
||||
|
||||
recognition_result = image.details["recognition_result"]
|
||||
lower_dict_level = dict_level.lower()
|
||||
|
||||
if lower_dict_level not in recognition_result:
|
||||
raise ValueError(f"Dict level {dict_level} not found in recognition result")
|
||||
|
||||
level_data = recognition_result[lower_dict_level]
|
||||
|
||||
# 收集该等级下的所有文本及来源
|
||||
texts_with_source_and_ipa = [] # [(content, source, ipa, zh), ...]
|
||||
# 遍历 recognition_result 中的所有等级数据
|
||||
for level_key, level_data in recognition_result.items():
|
||||
if "desc_en" in level_data and "desc_ipa" in level_data:
|
||||
descriptions = level_data["desc_en"]
|
||||
desc_ipas = level_data["desc_ipa"]
|
||||
desc_zhs = level_data["desc_zh"]
|
||||
|
||||
# 提取description中的文本和对应的IPA
|
||||
if "desc_en" in level_data and "desc_ipa" in level_data:
|
||||
descriptions = level_data["desc_en"]
|
||||
desc_ipas = level_data["desc_ipa"]
|
||||
desc_zhs = level_data["desc_zh"]
|
||||
|
||||
# 如果description是列表,处理每个元素及对应的IPA
|
||||
if isinstance(descriptions, list):
|
||||
for i, desc in enumerate(descriptions):
|
||||
if isinstance(desc, str):
|
||||
# 获取对应的IPA,如果存在的话
|
||||
ipa = None
|
||||
zh = None
|
||||
if isinstance(desc_ipas, list) and i < len(desc_ipas):
|
||||
ipa_value = desc_ipas[i]
|
||||
if isinstance(ipa_value, str):
|
||||
ipa = ipa_value
|
||||
if isinstance(desc_zhs, list) and i < len(desc_zhs):
|
||||
zh_value = desc_zhs[i]
|
||||
if isinstance(zh_value, str):
|
||||
zh = zh_value
|
||||
texts_with_source_and_ipa.append((desc, "desc_en", ipa, zh))
|
||||
# 如果description是字符串,直接添加
|
||||
elif isinstance(descriptions, str):
|
||||
# 获取对应的IPA,如果存在的话
|
||||
ipa = None
|
||||
zh = None
|
||||
if isinstance(desc_ipas, list) and len(desc_ipas) > 0:
|
||||
ipa_value = desc_ipas[0]
|
||||
if isinstance(ipa_value, str):
|
||||
ipa = ipa_value
|
||||
if isinstance(desc_zhs, list) and len(desc_zhs) > 0:
|
||||
zh_value = desc_zhs[0]
|
||||
if isinstance(zh_value, str):
|
||||
zh = zh_value
|
||||
texts_with_source_and_ipa.append((descriptions, "desc_en", ipa, zh))
|
||||
# 如果description是列表,处理每个元素及对应的IPA
|
||||
if isinstance(descriptions, list):
|
||||
for i, desc in enumerate(descriptions):
|
||||
if isinstance(desc, str):
|
||||
# 获取对应的IPA,如果存在的话
|
||||
ipa = None
|
||||
zh = None
|
||||
if isinstance(desc_ipas, list) and i < len(desc_ipas):
|
||||
ipa_value = desc_ipas[i]
|
||||
if isinstance(ipa_value, str):
|
||||
ipa = ipa_value
|
||||
if isinstance(desc_zhs, list) and i < len(desc_zhs):
|
||||
zh_value = desc_zhs[i]
|
||||
if isinstance(zh_value, str):
|
||||
zh = zh_value
|
||||
texts_with_source_and_ipa.append((desc, "desc_en", ipa, zh, str(level_key).lower()))
|
||||
# 如果description是字符串,直接添加
|
||||
elif isinstance(descriptions, str):
|
||||
# 获取对应的IPA,如果存在的话
|
||||
ipa = None
|
||||
zh = None
|
||||
if isinstance(desc_ipas, list) and len(desc_ipas) > 0:
|
||||
ipa_value = desc_ipas[0]
|
||||
if isinstance(ipa_value, str):
|
||||
ipa = ipa_value
|
||||
if isinstance(desc_zhs, list) and len(desc_zhs) > 0:
|
||||
zh_value = desc_zhs[0]
|
||||
if isinstance(zh_value, str):
|
||||
zh = zh_value
|
||||
texts_with_source_and_ipa.append((descriptions, "desc_en", ipa, zh, str(level_key).lower()))
|
||||
|
||||
# 获取已存在的文本记录
|
||||
existing_texts = await image_text_dao.get_by_image_id(db, image_id)
|
||||
@@ -185,7 +221,7 @@ class ImageTextService:
|
||||
# 创建新的文本记录(如果不存在)
|
||||
created_texts = []
|
||||
newly_created_texts = []
|
||||
for text_content, source, ipa, zh in texts_with_source_and_ipa:
|
||||
for text_content, source, ipa, zh, level_key in texts_with_source_and_ipa:
|
||||
if text_content in existing_text_map:
|
||||
# 已存在的文本记录
|
||||
created_texts.append(existing_text_map[text_content])
|
||||
@@ -196,7 +232,7 @@ class ImageTextService:
|
||||
content=text_content,
|
||||
standard_audio_id=None,
|
||||
source=source,
|
||||
dict_level=dict_level,
|
||||
dict_level=level_key,
|
||||
ipa=ipa,
|
||||
zh=zh,
|
||||
)
|
||||
@@ -213,42 +249,147 @@ class ImageTextService:
|
||||
await db.refresh(text)
|
||||
|
||||
# 为新创建的文本记录生成标准发音音频(使用后台任务)
|
||||
if background_tasks and newly_created_texts:
|
||||
if newly_created_texts:
|
||||
from backend.middleware.tencent_cloud import TencentCloud
|
||||
tencent_cloud = TencentCloud()
|
||||
for text in newly_created_texts:
|
||||
# 添加后台任务来生成标准发音音频
|
||||
background_tasks.add_task(
|
||||
tencent_cloud.text_to_speak,
|
||||
await tencent_cloud.text_to_speak(
|
||||
image_id=text.image_id,
|
||||
content=text.content,
|
||||
image_text_id=text.id,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# 构造返回结构
|
||||
assessments = []
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def init_image_text_by_task_with_db(task_id: int, db: AsyncSession) -> bool:
|
||||
"""
|
||||
根据任务ID初始化图片文本记录(使用提供的数据库连接)
|
||||
|
||||
:param task_id: 任务ID
|
||||
:param db: 数据库连接
|
||||
:return: 是否成功
|
||||
"""
|
||||
try:
|
||||
image_task = await image_task_dao.get(db, task_id)
|
||||
if not image_task:
|
||||
logger.error(f"Task {task_id} not found")
|
||||
return False
|
||||
|
||||
image_id = image_task.image_id
|
||||
user_id = image_task.user_id
|
||||
|
||||
# 获取图片记录
|
||||
image = await image_dao.get(db, image_id)
|
||||
if not image:
|
||||
raise ValueError(f"Image with id {image_id} not found")
|
||||
|
||||
# Explicitly refresh the image to ensure we have the latest details
|
||||
await db.refresh(image)
|
||||
|
||||
# 检查details和recognition_result是否存在
|
||||
if not image.details or "recognition_result" not in image.details:
|
||||
raise ValueError("Image recognition result not found")
|
||||
|
||||
recognition_result = image.details["recognition_result"]
|
||||
|
||||
# Validate that recognition_result is not empty
|
||||
if not recognition_result or len(recognition_result) == 0:
|
||||
raise ValueError("Image recognition result is empty")
|
||||
|
||||
texts_with_source_and_ipa = [] # [(content, source, ipa, zh), ...]
|
||||
# 遍历 recognition_result 中的所有等级数据
|
||||
for level_key, level_data in recognition_result.items():
|
||||
if "desc_en" in level_data and "desc_ipa" in level_data:
|
||||
descriptions = level_data["desc_en"]
|
||||
desc_ipas = level_data["desc_ipa"]
|
||||
desc_zhs = level_data["desc_zh"]
|
||||
|
||||
# 如果description是列表,处理每个元素及对应的IPA
|
||||
if isinstance(descriptions, list):
|
||||
for i, desc in enumerate(descriptions):
|
||||
if isinstance(desc, str):
|
||||
# 获取对应的IPA,如果存在的话
|
||||
ipa = None
|
||||
zh = None
|
||||
if isinstance(desc_ipas, list) and i < len(desc_ipas):
|
||||
ipa_value = desc_ipas[i]
|
||||
if isinstance(ipa_value, str):
|
||||
ipa = ipa_value
|
||||
if isinstance(desc_zhs, list) and i < len(desc_zhs):
|
||||
zh_value = desc_zhs[i]
|
||||
if isinstance(zh_value, str):
|
||||
zh = zh_value
|
||||
texts_with_source_and_ipa.append((desc, "desc_en", ipa, zh, str(level_key).lower()))
|
||||
# 如果description是字符串,直接添加
|
||||
elif isinstance(descriptions, str):
|
||||
# 获取对应的IPA,如果存在的话
|
||||
ipa = None
|
||||
zh = None
|
||||
if isinstance(desc_ipas, list) and len(desc_ipas) > 0:
|
||||
ipa_value = desc_ipas[0]
|
||||
if isinstance(ipa_value, str):
|
||||
ipa = ipa_value
|
||||
if isinstance(desc_zhs, list) and len(desc_zhs) > 0:
|
||||
zh_value = desc_zhs[0]
|
||||
if isinstance(zh_value, str):
|
||||
zh = zh_value
|
||||
texts_with_source_and_ipa.append((descriptions, "desc_en", ipa, zh, str(level_key).lower()))
|
||||
|
||||
# 获取已存在的文本记录
|
||||
existing_texts = await image_text_dao.get_by_image_id(db, image_id)
|
||||
existing_text_map = {text.content: text for text in existing_texts}
|
||||
|
||||
# 创建新的文本记录(如果不存在)
|
||||
created_texts = []
|
||||
newly_created_texts = []
|
||||
for text_content, source, ipa, zh, level_key in texts_with_source_and_ipa:
|
||||
if text_content in existing_text_map:
|
||||
# 已存在的文本记录
|
||||
created_texts.append(existing_text_map[text_content])
|
||||
else:
|
||||
# 创建新的文本记录
|
||||
new_text = ImageText(
|
||||
image_id=image_id,
|
||||
content=text_content,
|
||||
standard_audio_id=None,
|
||||
source=source,
|
||||
dict_level=level_key,
|
||||
ipa=ipa,
|
||||
zh=zh,
|
||||
)
|
||||
db.add(new_text)
|
||||
created_texts.append(new_text)
|
||||
newly_created_texts.append(new_text)
|
||||
|
||||
# 提交事务
|
||||
await db.commit()
|
||||
|
||||
# 刷新创建的文本记录
|
||||
for text in created_texts:
|
||||
latest_recording_details = None
|
||||
latest_recording_file = None
|
||||
latest_recording = await recording_dao.get_latest_by_text_id(db, text.id)
|
||||
if latest_recording:
|
||||
latest_recording_details = latest_recording.details
|
||||
latest_recording_file = str(latest_recording.file_id)
|
||||
assessment = ImageTextAssessmentSchema(
|
||||
id=str(text.id),
|
||||
ipa=text.ipa.replace('/', ''),
|
||||
zh=text.zh,
|
||||
content=text.content,
|
||||
file_id=latest_recording_file,
|
||||
details=latest_recording_details,
|
||||
)
|
||||
assessments.append(assessment)
|
||||
if text.id is None: # 只刷新新创建的记录
|
||||
await db.refresh(text)
|
||||
|
||||
return ImageTextInitResponseSchema(
|
||||
image_file_id=str(image.file_id),
|
||||
assessments=assessments
|
||||
)
|
||||
# 为新创建的文本记录生成标准发音音频(使用后台任务)
|
||||
if newly_created_texts:
|
||||
from backend.middleware.tencent_cloud import TencentCloud
|
||||
tencent_cloud = TencentCloud()
|
||||
for text in newly_created_texts:
|
||||
# 添加后台任务来生成标准发音音频
|
||||
await tencent_cloud.text_to_speak(
|
||||
image_id=text.image_id,
|
||||
content=text.content,
|
||||
image_text_id=text.id,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing image text for task {task_id}: {str(e)}")
|
||||
await db.rollback()
|
||||
raise
|
||||
|
||||
|
||||
image_text_service = ImageTextService()
|
||||
image_text_service = ImageTextService()
|
||||
|
||||
@@ -48,6 +48,13 @@ async def process_pending_tasks():
|
||||
logger.info(f"Recovering processing task {task.id}")
|
||||
# 重置为待处理状态 with retry mechanism
|
||||
await update_task_status_with_retry(db, task.id, ImageTaskStatus.PENDING)
|
||||
await db.commit()
|
||||
# 释放之前分配的任务槽位,因为任务将重新开始
|
||||
try:
|
||||
from backend.app.ai.service.rate_limit_service import rate_limit_service
|
||||
await rate_limit_service.release_task_slot(task.user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to release slot for interrupted task {task.id}: {str(e)}")
|
||||
# 创建包装任务来处理完成和异常情况
|
||||
task_wrapper = create_task_wrapper(image_service._process_image_recognition, task.id)
|
||||
asyncio.create_task(task_wrapper)
|
||||
@@ -84,6 +91,7 @@ async def create_task_wrapper(process_func, task_id):
|
||||
db, task_id, ImageTaskStatus.FAILED,
|
||||
error_message=f"Task failed with exception: {str(e)}"
|
||||
)
|
||||
await db.commit()
|
||||
except Exception as update_error:
|
||||
logger.error(f"Failed to update task {task_id} status to FAILED: {str(update_error)}")
|
||||
|
||||
@@ -166,4 +174,4 @@ async def start_task_processor():
|
||||
await asyncio.sleep(30) # 30秒检查一次
|
||||
except Exception as e:
|
||||
logger.error(f"Error in task processor: {str(e)}")
|
||||
await asyncio.sleep(30) # 出错后也等待一段时间再重试
|
||||
await asyncio.sleep(30) # 出错后也等待一段时间再重试
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user