1020 lines
45 KiB
Python
Executable File
1020 lines
45 KiB
Python
Executable File
import asyncio
|
||
import re
|
||
import base64
|
||
import hashlib
|
||
import imghdr
|
||
import io
|
||
import json
|
||
from datetime import datetime
|
||
from typing import Optional, Dict, Any, List, Set, Tuple
|
||
|
||
from PIL import Image as PILImage, ExifTags
|
||
from fastapi import UploadFile, Request
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from starlette.background import BackgroundTasks
|
||
|
||
from backend.app.ai.crud.image_curd import image_dao
|
||
from backend.app.ai.model import Image
|
||
from backend.app.ai.schema.image import ImageFormat, ImageMetadata, ColorMode, ImageRecognizeRes, UpdateImageParam, \
|
||
ProcessImageRequest
|
||
from backend.app.admin.schema.qwen import QwenRecognizeImageParams
|
||
from backend.app.admin.service.file_service import file_service
|
||
from backend.app.ai.service.rate_limit_service import rate_limit_service
|
||
from backend.common.enums import FileType
|
||
from backend.common.exception import errors
|
||
from backend.common.const import IMAGE_RECOGNITION_COST, POINTS_ACTION_IMAGE_RECOGNITION
|
||
from backend.core.conf import settings
|
||
from backend.common.log import log as logger
|
||
|
||
from backend.database.db import async_db_session, background_db_session
|
||
from backend.middleware.qwen import Qwen
|
||
|
||
# Import the new task components
|
||
from backend.app.ai.model.image_task import ImageTaskStatus, ImageProcessingTask
|
||
from backend.app.ai.crud.image_task_crud import image_task_dao
|
||
from backend.app.ai.schema.image_task import CreateImageTaskParam
|
||
|
||
# Import Youdao API and related components for dictionary lookup
|
||
from backend.middleware.youdao import YoudaoAPI, YoudaoWordAPI
|
||
from backend.app.admin.model.dict import YdDictLanguage, YdDictType, DictCategory
|
||
from backend.app.admin.service.yd_dict_service import yd_dict_service
|
||
from backend.app.admin.service.points_service import points_service
|
||
from backend.app.admin.service.dict_service import dict_service
|
||
from backend.database.redis import redis_client
|
||
|
||
|
||
# Presumably the per-user daily cap on image recognition requests.
# NOTE(review): not referenced in the visible part of this file - confirm it is still used.
DAILY_IMAGE_RECOGNITION_MAX_TIMES = 3
|
||
|
||
class ImageService:
|
||
|
||
@staticmethod
def file_verify(file: UploadFile) -> None:
    """
    Validate an uploaded file before it is accepted.

    :param file: FastAPI upload file object
    :return: None
    :raises errors.ForbiddenError: when the file has no usable extension, the
        extension is not in ``settings.UPLOAD_IMAGE_EXT_INCLUDE``, or the file
        is larger than ``settings.UPLOAD_IMAGE_SIZE_MAX``
    """
    # ``filename`` can be None on an UploadFile; treat that like an empty name.
    filename = file.filename or ''

    # Only text after a real dot counts as an extension. A plain split('.')
    # would return the whole name for dot-less files and let them through.
    _, dot, suffix = filename.rpartition('.')
    file_ext = suffix.lower() if dot else ''
    if not file_ext:
        raise errors.ForbiddenError(msg='未知的文件类型')

    # NOTE(review): this compares the *extension* with FileType.image, which
    # looks like a file-category value. If so, the checks below almost never
    # run - confirm against FileType's definition before relying on them.
    if file_ext == FileType.image:
        if file_ext not in settings.UPLOAD_IMAGE_EXT_INCLUDE:
            raise errors.ForbiddenError(msg=f'[{file_ext}] 此图片格式暂不支持')
        # ``size`` may be None when the client did not report a length.
        if file.size is not None and file.size > settings.UPLOAD_IMAGE_SIZE_MAX:
            raise errors.ForbiddenError(msg='图片超出最大限制,请重新选择')
|
||
|
||
@staticmethod
async def load_image_by_id(self, id: int) -> Image:
    """
    Load an image record by primary key.

    NOTE(review): the method is a ``@staticmethod`` yet still declares ``self``,
    so callers must pass a placeholder first argument. The signature is left
    untouched to stay compatible with existing callers.

    :param self: unused placeholder
    :param id: image primary key
    :return: the image record
    :raises errors.NotFoundError: when no image with that id exists
    """
    async with async_db_session() as db:
        image = await image_dao.get(db, id)
        if not image:
            # The error used to be constructed but never raised, so callers
            # silently received None.
            raise errors.NotFoundError(msg='图片不存在')
        return image
|
||
|
||
@staticmethod
def detect_image_format(image_bytes: bytes) -> ImageFormat:
    """
    Identify the image format from the raw bytes.

    :param image_bytes: binary content of the image
    :return: the matching ``ImageFormat`` member, or ``ImageFormat.UNKNOWN``
        when ``imghdr`` does not recognise the data or reports a name that is
        not in the lookup table
    """
    # Sniff the payload header; no file name is involved.
    sniffed = imghdr.what(None, h=image_bytes)

    # Translation table from imghdr names to the project's enum.
    known_formats = {
        'jpeg': ImageFormat.JPEG,
        'jpg': ImageFormat.JPEG,
        'png': ImageFormat.PNG,
        'gif': ImageFormat.GIF,
        'bmp': ImageFormat.BMP,
        'webp': ImageFormat.WEBP,
        'tiff': ImageFormat.TIFF,
        'svg': ImageFormat.SVG,
    }

    if sniffed in known_formats:
        return known_formats[sniffed]
    return ImageFormat.UNKNOWN
|
||
|
||
@staticmethod
def extract_metadata(image_bytes: bytes, additional_info: Optional[Dict[str, Any]] = None) -> ImageMetadata:
    """
    Extract metadata (size, colour mode, channels, DPI, EXIF) from image bytes.

    :param image_bytes: binary content of the image
    :param additional_info: optional overrides; each key that already exists as
        an attribute on the resulting metadata object is assigned onto it
    :return: populated ``ImageMetadata``. If the image cannot be parsed, a
        minimal object with zero dimensions and an ``error`` message is
        returned instead of raising.
    """
    date_tags = ('DateTime', 'DateTimeOriginal', 'DateTimeDigitized')
    try:
        with PILImage.open(io.BytesIO(image_bytes)) as img:
            # Basic geometry and colour information.
            width, height = img.size
            color_mode = ColorMode(img.mode) if img.mode in ColorMode.__members__.values() else ColorMode.UNKNOWN

            # Parse EXIF once and reuse the result.
            exif_data = {}
            raw_exif = img._getexif() if hasattr(img, '_getexif') else None
            if raw_exif:
                for tag, value in raw_exif.items():
                    decoded_tag = ExifTags.TAGS.get(tag, tag)
                    # Normalise EXIF timestamps to ISO-8601; keep the raw value
                    # when it does not match the expected layout.
                    if decoded_tag in date_tags:
                        try:
                            value = datetime.strptime(value, "%Y:%m:%d %H:%M:%S").isoformat()
                        except (ValueError, TypeError):
                            pass
                    # Convert IFDRational-like values to float to avoid JSON
                    # serialization issues.
                    if hasattr(value, 'numerator') and hasattr(value, 'denominator'):
                        try:
                            value = float(value)
                        except (ValueError, TypeError, ZeroDivisionError, OverflowError):
                            pass
                    exif_data[decoded_tag] = value

            # Number of colour bands.
            channels = len(img.getbands())

            # DPI, converted to a tuple of plain floats when possible.
            dpi = img.info.get('dpi')
            if dpi:
                try:
                    dpi = tuple(float(d) for d in dpi)
                except (ValueError, TypeError, ZeroDivisionError, OverflowError):
                    pass

            metadata = ImageMetadata(
                format=ImageService.detect_image_format(image_bytes),
                width=width,
                height=height,
                color_mode=color_mode,
                file_size=len(image_bytes),
                channels=channels,
                dpi=dpi,
                exif=exif_data
            )

            # Apply caller-supplied overrides for attributes that exist.
            if additional_info:
                for key, value in additional_info.items():
                    if hasattr(metadata, key):
                        setattr(metadata, key, value)

            return metadata
    except Exception as e:
        # Unparseable image: fall back to the minimum that can be known.
        return ImageMetadata(
            format=ImageService.detect_image_format(image_bytes),
            width=0,
            height=0,
            color_mode=ColorMode.UNKNOWN,
            file_size=len(image_bytes),
            error=f"Metadata extraction failed: {str(e)}"
        )
|
||
|
||
@staticmethod
|
||
def calculate_image_hash(image_bytes: bytes) -> str:
|
||
"""计算图片的SHA256哈希值"""
|
||
return hashlib.sha256(image_bytes).hexdigest()
|
||
|
||
@staticmethod
async def generate_thumbnail(image_id: int, file_id: int) -> None:
    """
    Create a thumbnail for a stored file and attach it to the image record.

    The source file is downloaded, a thumbnail is rendered (falling back to
    the original bytes when rendering yields nothing), uploaded through the
    file service, and its id stored in the image's ``thumbnail_id``. Any
    failure is logged and swallowed so the main flow is not affected.

    :param image_id: id of the image record to update
    :param file_id: id of the source file to download
    :return: None
    """

    class MockUploadFile:
        """In-memory object exposing ``filename``, ``size``, ``read`` and ``seek`` for the upload call."""

        def __init__(self, filename, content):
            self.filename = filename
            self._file = io.BytesIO(content)
            self.size = len(content)

        async def read(self):
            """Return the full content, always from the beginning."""
            self._file.seek(0)
            return self._file.read()

        async def seek(self, position):
            """Move the internal cursor."""
            self._file.seek(position)

    try:
        # Fetch the original image.
        source_bytes, source_name, source_type = await file_service.download_file(file_id)

        # Render the thumbnail; use the original bytes if nothing came back.
        thumb_bytes = await ImageService._create_thumbnail(source_bytes)
        if not thumb_bytes:
            thumb_bytes = source_bytes

        # Upload it, passing the source content type explicitly.
        upload_result = await file_service.upload_file_with_content_type(
            MockUploadFile(filename=f"thumbnail_{source_name}", content=thumb_bytes),
            content_type=source_type
        )
        stored_thumbnail_id = int(upload_result.id)

        # Link the uploaded thumbnail to the image record.
        async with async_db_session.begin() as db:
            await image_dao.update(
                db,
                image_id,
                UpdateImageParam(thumbnail_id=stored_thumbnail_id)
            )
    except Exception as e:
        # Deliberately not re-raised: thumbnails are best-effort.
        logger.error(f"生成缩略图失败: {str(e)}")
|
||
|
||
@staticmethod
async def _create_thumbnail(image_bytes: bytes, size: tuple = (100, 100)) -> Optional[bytes]:
    """
    Render a centre-cropped square thumbnail.

    Images with transparency or a palette are flattened onto a white
    background first. The thumbnail is encoded in the source image's own
    format so it matches the content type the caller uploads it with; JPEG is
    used when the source format is unknown or cannot encode the result.

    :param image_bytes: binary content of the source image
    :param size: target ``(width, height)`` of the thumbnail
    :return: encoded thumbnail bytes, or ``None`` when the input is empty or
        the image cannot be processed
    """
    if not image_bytes:
        return None

    try:
        with PILImage.open(io.BytesIO(image_bytes)) as img:
            # Capture the format now: images produced by new()/crop()/resize()
            # have format None, so reading it after those calls always yields
            # the JPEG fallback.
            source_format = img.format

            # Flatten alpha/palette images onto a white RGB canvas.
            if img.mode in ("RGBA", "LA", "P"):
                background = PILImage.new("RGB", img.size, (255, 255, 255))
                if img.mode == "P":
                    img = img.convert("RGBA")
                background.paste(img, mask=img.split()[-1] if img.mode in ("RGBA", "LA") else None)
                img = background

            # Centre-crop to a square whose side is the shorter dimension.
            width, height = img.size
            side = min(width, height)
            left = (width - side) // 2
            top = (height - side) // 2
            img = img.crop((left, top, left + side, top + side))

            # Scale to the requested size.
            img = img.resize(size, PILImage.Resampling.LANCZOS)

            # Try the source format first, then fall back to JPEG.
            if source_format and source_format != "JPEG":
                native_buffer = io.BytesIO()
                try:
                    img.save(native_buffer, format=source_format)
                    return native_buffer.getvalue()
                except (OSError, KeyError, ValueError) as format_error:
                    logger.warning(
                        f"Thumbnail could not be saved as {source_format}, falling back to JPEG: {str(format_error)}"
                    )

            thumbnail_buffer = io.BytesIO()
            img.save(thumbnail_buffer, format="JPEG")
            return thumbnail_buffer.getvalue()
    except Exception as e:
        logger.error(f"创建缩略图失败: {str(e)}")
        # Signal failure to the caller, which falls back to the original image.
        return None
|
||
|
||
@staticmethod
async def process_image_from_file_async(
        params: ProcessImageRequest,
        background_tasks: BackgroundTasks,
        request: Request
) -> dict:
    """
    Queue an image recognition job and return its task id immediately.

    Steps: check the user's points, reserve a rate-limit slot, download the
    file, find or create the image row (scheduling thumbnail generation and
    storing metadata for new rows), create a PENDING task, commit, then start
    recognition in the background.

    :param params: request payload (``file_id``, ``type``, ``dict_level``)
    :param background_tasks: Starlette background task queue, used for thumbnail generation
    :param request: current request; ``request.user`` supplies the user
    :return: ``{"task_id", "status", "message"}``
    :raises errors.ForbiddenError: insufficient points, or no free task slot
    :raises errors.NotFoundError: the file cannot be downloaded
    """
    current_user = request.user
    file_id = params.file_id
    task_type = params.type

    # Resolve the dictionary level, falling back to the user's own level.
    # The old code dereferenced ``params.dict_level.name`` unconditionally,
    # which crashed on None before the fallback could run.
    dict_level = params.dict_level.name if params.dict_level else None
    if not dict_level:
        dict_level = current_user.dict_level.name

    # The user must have enough points before any work is queued.
    if not await points_service.check_sufficient_points(current_user.id, IMAGE_RECOGNITION_COST):
        raise errors.ForbiddenError(
            msg='积分不足,请充值以继续使用'
        )

    # Reserve a concurrency slot for this user.
    slot_acquired = await rate_limit_service.acquire_task_slot(current_user.id)
    if not slot_acquired:
        max_tasks = await rate_limit_service.get_user_task_limit(current_user.id)
        raise errors.ForbiddenError(
            msg=f'用户同时最多只能运行 {max_tasks} 个任务,请等待现有任务完成后再试'
        )

    # From here the slot is ours. The background worker releases it when it
    # finishes, so release it here if we fail before handing the work over.
    handed_off = False
    try:
        try:
            file_content, file_name, content_type = await file_service.download_file(file_id)
        except Exception as e:
            raise errors.NotFoundError(msg=f"文件不存在或无法读取: {str(e)}") from e

        async with async_db_session.begin() as db:
            # Reuse the image row for this file if one already exists.
            existing_image = await image_dao.get_by_file_id(db, file_id)
            if existing_image:
                image_id = existing_image.id
            else:
                new_image = Image(
                    file_id=file_id,
                )
                await image_dao.add(db, new_image)
                await db.flush()  # populate the generated id
                image_id = new_image.id

                # Thumbnail and metadata are only produced for new rows.
                background_tasks.add_task(ImageService.generate_thumbnail, image_id, file_id)

                additional_info = {
                    "file_name": file_name,
                    "content_type": content_type,
                    "file_size": len(file_content),
                }
                metadata = file_service.extract_metadata(file_content, additional_info)

                await image_dao.update(
                    db, image_id,
                    UpdateImageParam(
                        info=metadata or {},
                    )
                )

            # Register the task that the background worker will pick up.
            task_params = CreateImageTaskParam(
                image_id=image_id,
                file_id=file_id,
                user_id=current_user.id,
                status=ImageTaskStatus.PENDING,
                dict_level=dict_level,
                type=task_type
            )
            task = await image_task_dao.create_task(db, task_params)
            await db.flush()
            task_id = task.id
            # Ensure all database changes are committed before starting background tasks
            await db.commit()

        # NOTE(review): asyncio keeps only a weak reference to tasks; consider
        # retaining the returned Task so it cannot be garbage-collected
        # mid-flight.
        asyncio.create_task(ImageService._process_image_with_limiting(task_id, current_user.id))
        handed_off = True
    finally:
        if not handed_off:
            try:
                await rate_limit_service.release_task_slot(current_user.id)
            except Exception as slot_error:
                logger.error(f"Failed to release task slot for user {current_user.id}: {str(slot_error)}")

    return {
        "task_id": str(task_id),
        "status": "accepted",
        "message": "Image processing started"
    }
|
||
|
||
@staticmethod
async def _process_image_with_limiting(task_id: int, user_id: int) -> None:
    """
    Background worker for one recognition task; always frees the user's slot.

    Flow: run recognition unless the task already has a result, then, inside
    one transaction, lock the task and image rows, process lookup words,
    initialise image text, deduct points and mark the task COMPLETED. On any
    failure the transaction is rolled back and the task is marked FAILED.

    :param task_id: id of the image processing task
    :param user_id: id of the user whose rate-limit slot is released at the end
    :return: None (errors are logged, never raised)
    """
    task_processing_success = False
    points_deducted = False
    try:
        # Step 1: Execute image recognition only if not already done
        need_recognition = True
        try:
            async with background_db_session() as db_check:
                task_check = await image_task_dao.get(db_check, task_id)
                if task_check and task_check.result:
                    need_recognition = False
        except Exception as check_error:
            # Best-effort pre-check: if it fails, run recognition anyway, but
            # leave a trace instead of swallowing the error silently.
            logger.warning(
                f"Could not check existing result for task {task_id}, running recognition: {str(check_error)}"
            )
        if need_recognition:
            await ImageService._process_image_recognition(task_id)

        # Step 2: Process all remaining steps in a single database transaction for consistency
        async with background_db_session() as db:
            await db.begin()
            try:
                # Lock the task row, and its image row, for this transaction.
                task_locked = await image_task_dao.get_for_update(db, task_id)
                if task_locked:
                    await image_dao.get_for_update(db, task_locked.image_id)

                # Process lookup word
                await dict_service.process_lookup_word_with_db(task_id, db)

                # Step 3: Initialize image text (imported at call time, as in the original)
                from backend.app.ai import ImageTextService
                result = await ImageTextService.init_image_text_by_task_with_db(task_id, db)
                if not result:
                    raise Exception("Failed to initialize image text")

                # Step 4: Deduct user points
                task = await image_task_dao.get(db, task_id)
                if task:
                    image = await image_dao.get(db, task.image_id)
                    if image:
                        points_deducted = await points_service.deduct_points_with_db(
                            user_id=task.user_id,
                            amount=IMAGE_RECOGNITION_COST,
                            db=db,
                            related_id=image.id,
                            details={"task_id": task_id},
                            action=POINTS_ACTION_IMAGE_RECOGNITION
                        )
                        if not points_deducted:
                            logger.error(f"Failed to deduct points for user {task.user_id} for task {task_id}")
                            raise Exception("Failed to deduct points")

                # Step 5: Update task status to completed
                await ImageService._update_task_status_with_db(task_id, ImageTaskStatus.COMPLETED, db)

                await db.commit()
                task_processing_success = True
            except Exception:
                await db.rollback()
                raise

    except Exception as e:
        logger.error(f"Error processing image task {task_id}: {str(e)}")

        # Try to compensate for partial completion
        try:
            async with background_db_session() as db:
                await db.begin()
                try:
                    # NOTE(review): this refund only runs when the main
                    # transaction had already committed successfully, i.e. the
                    # failure happened while leaving the session context. The
                    # task is then refunded and marked FAILED although its
                    # results were committed. The condition is kept unchanged;
                    # confirm the intended semantics.
                    if points_deducted and task_processing_success:
                        task = await image_task_dao.get(db, task_id)
                        if task:
                            image = await image_dao.get(db, task.image_id)
                            if image:
                                # Try to refund points by adding them back
                                refund_success = await points_service.add_points(
                                    user_id=task.user_id,
                                    amount=IMAGE_RECOGNITION_COST,
                                    related_id=image.id,
                                    details={"task_id": task_id, "refund_for": "task_failure"},
                                    action=POINTS_ACTION_IMAGE_RECOGNITION + "_REFUND"
                                )
                                if not refund_success:
                                    logger.error(f"Failed to refund points for user {task.user_id} for task {task_id}")

                    # Update task status to failed
                    await ImageService._update_task_status_with_db(task_id, ImageTaskStatus.FAILED, db, str(e))
                    await db.commit()
                except Exception as compensation_error:
                    await db.rollback()
                    logger.error(f"Failed to compensate for task {task_id} failure: {str(compensation_error)}")
                    # Try to update task status with separate connection as last resort
                    try:
                        await ImageService._update_task_status(task_id, ImageTaskStatus.FAILED, str(e))
                    except Exception as final_error:
                        logger.error(f"Failed to update task {task_id} status to FAILED: {str(final_error)}")
        except Exception as compensation_setup_error:
            logger.error(f"Failed to setup compensation for task {task_id}: {str(compensation_setup_error)}")
            # Try to update task status with separate connection as last resort
            try:
                await ImageService._update_task_status(task_id, ImageTaskStatus.FAILED, str(e))
            except Exception as final_error:
                logger.error(f"Failed to update task {task_id} status to FAILED: {str(final_error)}")
    finally:
        # Release slot
        try:
            await rate_limit_service.release_task_slot(user_id)
        except Exception as slot_error:
            logger.error(f"Failed to release task slot for user {user_id}: {str(slot_error)}")
|
||
|
||
@staticmethod
async def _update_task_status_with_db(task_id: int, status: ImageTaskStatus, db: AsyncSession, error_message: str = None) -> None:
    """
    Update a task's status through a session owned by the caller.

    No commit is issued here; the caller controls the transaction.

    :param task_id: id of the task to update
    :param status: new status
    :param db: session supplied by the caller
    :param error_message: optional failure description stored with the status
    :return: None
    :raises Exception: any error from the import or the update is logged and re-raised
    """
    try:
        # Imported at call time, as in the original.
        from backend.app.ai.tasks import update_task_status_with_retry as persist_status

        await persist_status(db, task_id, status, error_message=error_message)
    except Exception as exc:
        logger.error(f"Failed to update task {task_id} status to {status}: {str(exc)}")
        raise
|
||
|
||
@staticmethod
async def _update_task_status(task_id: int, status: ImageTaskStatus, error_message: str = None) -> None:
    """
    Update a task's status using a dedicated background session.

    Kept as a standalone entry point; it is the last-resort path used when the
    transactional update fails.

    :param task_id: id of the task to update
    :param status: new status
    :param error_message: optional failure description stored with the status
    :return: None
    :raises Exception: any error from the update is logged and re-raised
    """
    try:
        from backend.app.ai.tasks import update_task_status_with_retry
        async with background_db_session() as db:
            await update_task_status_with_retry(
                db, task_id, status, error_message=error_message
            )
            # Every other background_db_session() write path in this file
            # commits explicitly after update_task_status_with_retry; without
            # it this last-resort update may be discarded on session close.
            await db.commit()
    except Exception as e:
        logger.error(f"Failed to update task {task_id} status to {status}: {str(e)}")
        raise
|
||
|
||
@staticmethod
|
||
async def _process_image(task_id: int) -> None:
|
||
# This method is no longer used as we've moved the logic to _process_image_with_limiting
|
||
# Keeping it for backward compatibility
|
||
pass
|
||
|
||
@staticmethod
async def _process_image_recognition(task_id: int) -> None:
    """
    Run image recognition for a task and store the result (task processor compatible).

    Each attempt reads the task and image in a short-lived session, downloads
    the file and calls the Qwen API without holding that session, then opens a
    new session to persist the result and set the task to PROCESSING. Points
    are not deducted here.

    Failure handling:
      * API error reported alongside a result, or unparseable JSON: bump the
        retry counter, mark FAILED once the budget is spent (PENDING
        otherwise), and return.
      * Any other exception: same bookkeeping, then the loop tries again, up
        to ``max_retries`` attempts.

    :param task_id: id of the image processing task
    :return: None (errors are logged and recorded on the task, never raised)
    """
    from backend.app.ai.tasks import update_task_status_with_retry, increment_retry_count_with_retry

    max_retries = 5

    async def _record_failed_attempt(db: AsyncSession, error_message: str) -> None:
        """Bump the retry counter, set FAILED (budget spent) or PENDING, and commit."""
        await increment_retry_count_with_retry(db, task_id)
        refreshed_task = await image_task_dao.get(db, task_id)
        if refreshed_task and refreshed_task.retry_count >= max_retries:
            await update_task_status_with_retry(
                db, task_id, ImageTaskStatus.FAILED,
                error_message=error_message
            )
        else:
            # Back to PENDING so the task can be retried.
            await update_task_status_with_retry(
                db, task_id, ImageTaskStatus.PENDING
            )
        await db.commit()

    for _ in range(max_retries):
        try:
            async with background_db_session() as db:
                # Load the task.
                task = await image_task_dao.get(db, task_id)
                if not task:
                    logger.error(f"Task {task_id} not found")
                    return

                # Load the image it refers to.
                image = await image_dao.get(db, task.image_id)
                if not image:
                    await update_task_status_with_retry(
                        db, task_id, ImageTaskStatus.FAILED,
                        error_message="Image not found"
                    )
                    await db.commit()
                    return

                # Words from earlier recognition results are excluded from the new request.
                exclude_words = []
                if image.details and "recognition_result" in image.details:
                    recognition_result = image.details["recognition_result"]
                    exclude_words.extend([
                        word for section in recognition_result.values()
                        for word in section.get('ref_word', [])
                        if isinstance(section.get('ref_word'), list)
                    ])

                # End the read transaction before the slow external work.
                await db.commit()

            # Download and encode the file outside any database transaction.
            file_content, file_name, _ = await file_service.download_file(task.file_id)
            image_format = image_service.detect_image_format(file_content)
            image_format_str = image_format.value
            base64_image = base64.b64encode(file_content).decode('utf-8')

            recognize_params = QwenRecognizeImageParams(
                user_id=task.user_id,
                image_id=task.image_id,
                file_name=file_name,
                format=image_format_str,
                data=base64_image,
                type=task.type,
                dict_level=task.dict_level,
                exclude_words=list(set(exclude_words))  # de-duplicate
            )

            # Make the external API call without holding a database connection
            recognize_response = await Qwen.recognize_image(recognize_params)

            # A missing result used to surface as an opaque AttributeError on
            # None. Raise a descriptive error instead; it takes the same
            # generic retry path below, but the stored message now names the
            # real cause.
            raw_result = recognize_response.get("result")
            if raw_result is None:
                raise RuntimeError(
                    f"Qwen API returned no result: {recognize_response.get('error')}"
                )
            recognition_result = raw_result.strip().replace("```json", "").replace("```", "").strip()

            # Persist the outcome in a new session.
            async with background_db_session() as db:
                await db.begin()

                if recognize_response.get("error"):
                    await _record_failed_attempt(db, recognize_response["error"])
                    return

                try:
                    processed_str = recognition_result.strip()
                    # Strip trailing commas before a closing bracket/brace.
                    # Repeat until stable, because one pass leaves the earlier
                    # commas of a run such as ",,}" in place.
                    pattern = r',\s*(?=[\]}])'
                    while True:
                        new_processed = re.sub(pattern, '', processed_str)
                        if new_processed == processed_str:
                            break
                        processed_str = new_processed
                    result = json.loads(processed_str)
                except json.JSONDecodeError:
                    # Log the actual response for debugging
                    logger.error(f"Invalid JSON response from Qwen API: {recognition_result}")
                    await _record_failed_attempt(db, "Invalid JSON response from Qwen API")
                    return

                # Normalise the top-level keys to lower case.
                transformed_result = {
                    str(level_key).lower(): level_data
                    for level_key, level_data in result.items()
                }

                # Merge into the image details, creating them if absent.
                if image.details:
                    details = image.details
                else:
                    details = {
                        "embedding_model": settings.QWEN_VISION_EMBEDDING_MODEL,
                        "recognize_model": settings.QWEN_VISION_MODEL,
                        "recognition_result": {},
                    }
                details["recognition_result"] = transformed_result
                details["updated_by"] = task.user_id

                await image_dao.update(
                    db, task.image_id,
                    UpdateImageParam(
                        details=details,
                    )
                )

                # Mark the task as in progress and store the result (no points deducted here).
                await update_task_status_with_retry(
                    db, task_id, ImageTaskStatus.PROCESSING,
                    result=transformed_result
                )

                await db.commit()

            # If we reach here, the operation was successful
            break

        except Exception as e:
            logger.error(f"Error processing image recognition task {task_id}: {str(e)}")
            # Handle database operations in a separate, isolated session to avoid transaction conflicts
            # NOTE(review): the loop continues even after the task has been
            # marked FAILED here; behaviour kept as in the original.
            try:
                async with background_db_session() as db:
                    await db.begin()
                    await _record_failed_attempt(db, str(e))
            except Exception as db_error:
                logger.error(f"Failed to update task {task_id} after error: {str(db_error)}")
|
||
|
||
@staticmethod
async def _process_image_recognition_with_db(task_id: int, db: AsyncSession) -> None:
    """
    Database-only part of recognition, run on a session owned by the caller.

    It loads the task and image, collects previously recognised reference
    words, and flushes the session. It makes no external API call, does not
    commit, and does not deduct points; the external call and storing of its
    result are expected to have happened before this method is invoked.

    :param task_id: id of the image processing task
    :param db: session (and surrounding transaction) supplied by the caller
    :return: None
    :raises Exception: when the task does not exist
    """
    from backend.app.ai.tasks import update_task_status_with_retry, increment_retry_count_with_retry

    # Load the task; a missing task is an error for the caller to handle.
    task = await image_task_dao.get(db, task_id)
    if not task:
        logger.error(f"Task {task_id} not found")
        raise Exception(f"Task {task_id} not found")

    # Load the image; without it the task is marked as failed.
    image = await image_dao.get(db, task.image_id)
    if not image:
        await update_task_status_with_retry(
            db, task_id, ImageTaskStatus.FAILED,
            error_message="Image not found"
        )
        return

    # Collect reference words from earlier recognition results.
    exclude_words = []
    stored_details = image.details
    if stored_details and "recognition_result" in stored_details:
        for section in stored_details["recognition_result"].values():
            for word in section.get('ref_word', []):
                if isinstance(section.get('ref_word'), list):
                    exclude_words.append(word)

    # Only flush: this runs inside a larger transaction that must not be
    # committed here.
    await db.flush()
|
||
|
||
@staticmethod
async def get_task_status(task_id: int) -> dict:
    """
    Report the current state of an image processing task.

    :param task_id: id of the task
    :return: dict with ``task_id``, ``image_id``, ``status``, ``result``
        (the stored result for COMPLETED tasks that have one, otherwise an
        empty dict) and ``error_message`` (set only for FAILED tasks)
    :raises errors.NotFoundError: when the task does not exist
    """
    async with async_db_session() as db:
        task = await image_task_dao.get(db, task_id)
        if not task:
            raise errors.NotFoundError(msg="Task not found")

        # Defaults for tasks that are neither completed nor failed.
        result_payload = {}
        failure_reason = None

        if task.status == ImageTaskStatus.COMPLETED:
            if task.result:
                result_payload = task.result
        elif task.status == ImageTaskStatus.FAILED:
            failure_reason = task.error_message

        return {
            "task_id": str(task.id),
            "image_id": str(task.image_id),
            "status": task.status,
            "result": result_payload,
            "error_message": failure_reason
        }
|
||
|
||
# @staticmethod
|
||
# async def process_image_from_file(
|
||
# params: ProcessImageRequest,
|
||
# background_tasks: BackgroundTasks,
|
||
# request: Request
|
||
# ) -> ImageRecognizeRes:
|
||
# """通过文件ID处理图片识别"""
|
||
#
|
||
# current_user = request.user
|
||
# file_id = params.file_id
|
||
# type = params.type
|
||
# dict_level = params.dict_level.name
|
||
# if not dict_level:
|
||
# dict_level = current_user.dict_level.name
|
||
#
|
||
# # 通过file_id读取文件内容
|
||
# try:
|
||
# file_content, file_name, content_type = await file_service.download_file(file_id)
|
||
# except Exception as e:
|
||
# raise errors.NotFoundError(msg=f"文件不存在或无法读取: {str(e)}")
|
||
#
|
||
# # 提前提取图片格式
|
||
# image_format = image_service.detect_image_format(file_content)
|
||
# image_format_str = image_format.value
|
||
#
|
||
# base64_image = base64.b64encode(file_content).decode('utf-8')
|
||
#
|
||
# async with async_db_session.begin() as db:
|
||
# # 获取exclude_words
|
||
# exclude_words = []
|
||
#
|
||
# # 检查是否在image表中已有记录(根据file_id和dict_level)
|
||
# existing_image = await image_dao.get_by_file_id(db, file_id)
|
||
# if existing_image:
|
||
# recognition_result = existing_image.details["recognition_result"]
|
||
# if recognition_result and dict_level in recognition_result:
|
||
# exist_result = recognition_result.get(dict_level)
|
||
# return image_service.wrap_exist_image_with_result(existing_image, exist_result)
|
||
# else:
|
||
# exclude_words.extend([
|
||
# word for section in recognition_result.values()
|
||
# for word in section.get('ref_word', [])
|
||
# if isinstance(section.get('ref_word'), list)
|
||
# ])
|
||
# image_id = existing_image.id
|
||
# else:
|
||
# # insert image
|
||
# new_image = Image(
|
||
# file_id=file_id,
|
||
# )
|
||
#
|
||
# await image_dao.add(db, new_image)
|
||
# await db.flush() # 获取ID
|
||
# image_id = new_image.id
|
||
#
|
||
# # 生成缩略图
|
||
# background_tasks.add_task(ImageService.generate_thumbnail, image_id, file_id)
|
||
# # await image_service.generate_thumbnail(image_id, file_id)
|
||
#
|
||
# # embedding
|
||
# embed_params = QwenEmbedImageParams(
|
||
# user_id=current_user.id,
|
||
# dict_level=dict_level,
|
||
# image_id=new_image.id,
|
||
# file_name=file_name,
|
||
# format=image_format_str,
|
||
# data=base64_image,
|
||
# )
|
||
# embed_response = await Qwen.embed_image(embed_params)
|
||
# if embed_response.get("error"):
|
||
# raise Exception(embed_response["error"])
|
||
#
|
||
# embedding = embed_response.get("embedding")
|
||
#
|
||
# # 提取元数据
|
||
# additional_info = {
|
||
# "file_name": file_name,
|
||
# "content_type": content_type,
|
||
# "file_size": len(file_content),
|
||
# }
|
||
# metadata = file_service.extract_image_metadata(file_content, additional_info)
|
||
#
|
||
# # 相似图片
|
||
# similar_image_ids = await image_dao.find_similar_image_ids(db, embedding)
|
||
#
|
||
# # 2. 获取相似图片中不同dict_level的图片,从中提取ref_word
|
||
# if similar_image_ids:
|
||
# similar_image = await image_dao.get(db, similar_image_ids[0])
|
||
# if similar_image:
|
||
# recognition_result = similar_image.details["recognition_result"]
|
||
# if recognition_result and dict_level in recognition_result:
|
||
# exist_result = recognition_result.get(dict_level)
|
||
# res = image_service.wrap_exist_image_with_result(similar_image, exist_result)
|
||
# new_image.details = {
|
||
# 'created_by': current_user.id,
|
||
# 'embedding_similar': similar_image_ids,
|
||
# 'recognition_result': {dict_level: exist_result},
|
||
# }
|
||
# await image_dao.update(
|
||
# db, new_image.id,
|
||
# UpdateImageParam(
|
||
# embedding=embedding,
|
||
# info=metadata or {},
|
||
# details=new_image.details
|
||
# )
|
||
# )
|
||
# return res
|
||
# else:
|
||
# new_image.details = {
|
||
# 'created_by': current_user.id,
|
||
# 'embedding_similar': similar_image_ids,
|
||
# 'recognition_result': {},
|
||
# }
|
||
# await image_dao.update(
|
||
# db, new_image.id,
|
||
# UpdateImageParam(
|
||
# embedding=embedding,
|
||
# info=metadata or {},
|
||
# details=new_image.details
|
||
# )
|
||
# )
|
||
# exclude_words.extend([
|
||
# word for section in recognition_result.values()
|
||
# for word in section.get('ref_word', [])
|
||
# if isinstance(section.get('ref_word'), list)
|
||
# ])
|
||
#
|
||
# # recognize
|
||
# recognize_params = QwenRecognizeImageParams(
|
||
# user_id=current_user.id,
|
||
# image_id=image_id,
|
||
# file_name=file_name,
|
||
# format=image_format_str,
|
||
# data=base64_image,
|
||
# type=type,
|
||
# dict_level=dict_level,
|
||
# exclude_words=list(set(exclude_words)) # 去重
|
||
# )
|
||
#
|
||
# recognize_response = await Qwen.recognize_image(recognize_params)
|
||
# if recognize_response.get("error"):
|
||
# raise Exception(recognize_response["error"])
|
||
#
|
||
# recognition_result = recognize_response.get("result").strip().replace("```json", "").replace("```", "").strip()
|
||
#
|
||
# result = json.loads(recognition_result)
|
||
#
|
||
# # Transform the data structure from array of objects to grouped arrays
|
||
# transformed_result = {}
|
||
# for level_key, level_data in result.items():
|
||
# upper_level_key = str(level_key).upper()
|
||
# if not isinstance(level_data, list):
|
||
# # If it's not a list, keep the original structure
|
||
# transformed_result[upper_level_key] = level_data
|
||
# continue
|
||
#
|
||
# # Initialize the structure with empty lists
|
||
# transformed_result[upper_level_key] = {
|
||
# "description": [],
|
||
# "desc_ipa": [],
|
||
# "ref_word": [],
|
||
# "word_ipa": []
|
||
# }
|
||
#
|
||
# # Populate the lists while maintaining order
|
||
# for item in level_data:
|
||
# if isinstance(item, dict):
|
||
# transformed_result[upper_level_key]["description"].append(item.get("description", ""))
|
||
# transformed_result[upper_level_key]["desc_ipa"].append(item.get("desc_ipa", ""))
|
||
# transformed_result[upper_level_key]["ref_word"].append(item.get("ref_word", ""))
|
||
# transformed_result[upper_level_key]["word_ipa"].append(item.get("word_ipa", ""))
|
||
#
|
||
# processed_result = {}
|
||
# # 处理ref_word中的复数单词
|
||
# for level_key, level_data in transformed_result.items():
|
||
# upper_level_key = str(level_key).upper()
|
||
# if not isinstance(level_data, dict):
|
||
# continue
|
||
#
|
||
# if "ref_word" in level_data:
|
||
# ref_words = level_data["ref_word"]
|
||
# if isinstance(ref_words, list):
|
||
# processed_ref_words = []
|
||
# for word in ref_words:
|
||
# if isinstance(word, str):
|
||
# # 调用异步方法获取处理后的单词(如转为单数)
|
||
# processed_word = await ImageService._get_linked_word(word)
|
||
# processed_ref_words.append(processed_word)
|
||
# else:
|
||
# processed_ref_words.append(word)
|
||
# level_data["ref_word"] = processed_ref_words
|
||
# processed_result[upper_level_key] = level_data
|
||
#
|
||
# # 保留原有的值
|
||
# if existing_image:
|
||
# details = existing_image.details
|
||
# details["recognition_result"] = processed_result
|
||
# details["updated_by"] = current_user.id
|
||
# else:
|
||
# details = {
|
||
# 'created_by': current_user.id,
|
||
# "embedding_model": settings.QWEN_VISION_EMBEDDING_MODEL,
|
||
# "recognize_model": settings.QWEN_VISION_MODEL,
|
||
# "recognition_result": processed_result,
|
||
# }
|
||
#
|
||
# await image_dao.update(
|
||
# db, image_id,
|
||
# UpdateImageParam(
|
||
# details=details,
|
||
# )
|
||
# )
|
||
#
|
||
# return ImageRecognizeRes(
|
||
# id=image_id,
|
||
# res=details["recognition_result"][dict_level]
|
||
# )
|
||
|
||
@staticmethod
async def find_image(id: int) -> Image:
    """
    Fetch an image record by primary key.

    :param id: image primary key
    :return: whatever ``image_dao.get`` returns for that id
    """
    async with async_db_session.begin() as session:
        # Look the record up in the image table.
        return await image_dao.get(session, id)
|
||
|
||
|
||
# Module-level ImageService instance. Code in this file reaches static helpers
# through it (e.g. image_service.detect_image_format).
image_service: ImageService = ImageService()
|