fix code
This commit is contained in:
@@ -21,7 +21,6 @@ async def get_user_points_info(
|
||||
if points_info:
|
||||
balance_info = PointsBalanceInfo(
|
||||
balance=points_info.balance,
|
||||
expired_time=points_info.expired_time
|
||||
)
|
||||
return response_base.success(data=balance_info)
|
||||
return response_base.success(data=None)
|
||||
@@ -19,7 +19,6 @@ class PointsSchema(BaseModel):
|
||||
class PointsBalanceInfo(BaseModel):
|
||||
"""积分余额和过期时间信息"""
|
||||
balance: int = Field(default=0, description="当前积分余额")
|
||||
expired_time: datetime = Field(description="过期时间")
|
||||
|
||||
|
||||
class PointsLogSchema(BaseModel):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import base64
|
||||
import io
|
||||
import imghdr
|
||||
from datetime import datetime
|
||||
@@ -335,7 +336,7 @@ class FileService:
|
||||
return format_mapping.get(format_str, ImageFormat.UNKNOWN)
|
||||
|
||||
@staticmethod
|
||||
def extract_image_metadata(image_bytes: bytes, additional_info: Dict[str, Any] = None) -> ImageMetadata:
|
||||
def extract_metadata(image_bytes: bytes, additional_info: Dict[str, Any] = None) -> ImageMetadata:
|
||||
"""从图片二进制数据中提取元数据"""
|
||||
try:
|
||||
with PILImage.open(io.BytesIO(image_bytes)) as img:
|
||||
@@ -360,6 +361,12 @@ class FileService:
|
||||
value = float(value)
|
||||
except:
|
||||
pass
|
||||
# 确保值是可序列化的
|
||||
if isinstance(value, bytes):
|
||||
try:
|
||||
value = value.decode('utf-8')
|
||||
except:
|
||||
value = base64.b64encode(value).decode('utf-8')
|
||||
exif_data[decoded_tag] = value
|
||||
|
||||
# 获取颜色通道数
|
||||
@@ -374,9 +381,9 @@ class FileService:
|
||||
except:
|
||||
pass
|
||||
|
||||
# 创建元数据对象
|
||||
# 创建元数据对象,确保format是字符串值而不是枚举
|
||||
metadata = ImageMetadata(
|
||||
format=file_service.detect_image_format(image_bytes),
|
||||
format=file_service.detect_image_format(image_bytes), # 使用.value确保是字符串
|
||||
width=width,
|
||||
height=height,
|
||||
color_mode=color_mode,
|
||||
@@ -390,13 +397,19 @@ class FileService:
|
||||
if additional_info:
|
||||
for key, value in additional_info.items():
|
||||
if hasattr(metadata, key):
|
||||
# 确保设置的值是可序列化的
|
||||
if isinstance(value, bytes):
|
||||
try:
|
||||
value = value.decode('utf-8')
|
||||
except:
|
||||
value = base64.b64encode(value).decode('utf-8')
|
||||
setattr(metadata, key, value)
|
||||
|
||||
return metadata
|
||||
except Exception as e:
|
||||
# 无法解析图片时返回基础元数据
|
||||
return ImageMetadata(
|
||||
format=file_service.detect_image_format(image_bytes),
|
||||
format=file_service.detect_image_format(image_bytes), # 确保使用字符串值
|
||||
width=0,
|
||||
height=0,
|
||||
color_mode=ColorMode.UNKNOWN,
|
||||
|
||||
@@ -17,20 +17,6 @@ class PointsService:
|
||||
获取用户积分账户信息(会检查并清空过期积分)
|
||||
"""
|
||||
async with async_db_session.begin() as db:
|
||||
# 获取当前积分余额(清空前)
|
||||
points_account_before = await points_dao.get_by_user_id(db, user_id)
|
||||
balance_before = points_account_before.balance if points_account_before else 0
|
||||
|
||||
# 如果清空了过期积分,记录日志
|
||||
if expired_cleared and balance_before > 0:
|
||||
await points_log_dao.add_log(db, {
|
||||
"user_id": user_id,
|
||||
"action": "expire_clear",
|
||||
"amount": balance_before, # 记录清空前的积分数量
|
||||
"balance_after": 0,
|
||||
"details": {"message": "过期积分已清空", "cleared_amount": balance_before}
|
||||
})
|
||||
|
||||
return await points_dao.get_by_user_id(db, user_id)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -167,5 +167,32 @@ class ImageTaskCRUD(CRUDPlus[ImageProcessingTask]):
|
||||
result = await db.execute(stmt)
|
||||
return result.rowcount
|
||||
|
||||
async def update(self, db: AsyncSession, task_id: int, obj_in: dict) -> None:
|
||||
"""
|
||||
更新图像处理任务
|
||||
|
||||
:param db: 数据库会话
|
||||
:param task_id: 任务ID
|
||||
:param obj_in: 更新数据字典
|
||||
"""
|
||||
try:
|
||||
# 查询任务记录
|
||||
task = await db.get(self.model, task_id)
|
||||
if not task:
|
||||
raise ValueError(f"Image processing task with id {task_id} not found")
|
||||
|
||||
# 更新字段
|
||||
for field, value in obj_in.items():
|
||||
if hasattr(task, field):
|
||||
setattr(task, field, value)
|
||||
|
||||
# 提交更改
|
||||
await db.commit()
|
||||
await db.refresh(task)
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
raise e
|
||||
|
||||
|
||||
image_task_dao: ImageTaskCRUD = ImageTaskCRUD(ImageProcessingTask)
|
||||
@@ -10,6 +10,7 @@ from typing import Optional, Dict, Any, List, Set, Tuple
|
||||
|
||||
from PIL import Image as PILImage, ExifTags
|
||||
from fastapi import UploadFile, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from starlette.background import BackgroundTasks
|
||||
|
||||
from backend.app.ai.crud.image_curd import image_dao
|
||||
@@ -355,7 +356,7 @@ class ImageService:
|
||||
"content_type": content_type,
|
||||
"file_size": len(file_content),
|
||||
}
|
||||
metadata = file_service.extract_image_metadata(file_content, additional_info)
|
||||
metadata = file_service.extract_metadata(file_content, additional_info)
|
||||
|
||||
# 更新image记录
|
||||
await image_dao.update(
|
||||
@@ -393,24 +394,43 @@ class ImageService:
|
||||
|
||||
@staticmethod
|
||||
async def _process_image_with_limiting(task_id: int, user_id: int) -> None:
|
||||
|
||||
from backend.app.ai.tasks import update_task_status_with_retry
|
||||
|
||||
"""带限流控制的后台处理图片识别任务"""
|
||||
try:
|
||||
# 执行图片处理任务
|
||||
await ImageService._process_image(task_id)
|
||||
# 任务成功完成后更新状态
|
||||
await ImageService._update_task_status(task_id, ImageTaskStatus.COMPLETED)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing image task {task_id}: {str(e)}")
|
||||
# 任务失败时更新状态
|
||||
await ImageService._update_task_status(task_id, ImageTaskStatus.FAILED, str(e))
|
||||
finally:
|
||||
# 任务完成后释放槽位
|
||||
await rate_limit_service.release_task_slot(user_id)
|
||||
# 释放槽位
|
||||
try:
|
||||
await rate_limit_service.release_task_slot(user_id)
|
||||
except Exception as slot_error:
|
||||
logger.error(f"Failed to release task slot for user {user_id}: {str(slot_error)}")
|
||||
|
||||
@staticmethod
|
||||
async def _update_task_status(task_id: int, status: ImageTaskStatus, error_message: str = None) -> None:
|
||||
"""更新任务状态"""
|
||||
try:
|
||||
async with background_db_session.begin() as db:
|
||||
task = await image_task_dao.get(db, task_id)
|
||||
if task.status != ImageTaskStatus.FAILED:
|
||||
await update_task_status_with_retry(
|
||||
db, task_id, ImageTaskStatus.COMPLETED
|
||||
)
|
||||
if not task:
|
||||
logger.warning(f"Task {task_id} not found when updating status")
|
||||
return
|
||||
|
||||
# 只有当任务不是最终状态时才更新
|
||||
if task.status not in [ImageTaskStatus.COMPLETED, ImageTaskStatus.FAILED]:
|
||||
update_data = {"status": status}
|
||||
if error_message:
|
||||
update_data["error_message"] = error_message
|
||||
|
||||
await image_task_dao.update(db, task_id, update_data)
|
||||
await db.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update task {task_id} status to {status}: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
async def _process_image(task_id: int) -> None:
|
||||
@@ -498,12 +518,15 @@ class ImageService:
|
||||
# Improve JSON parsing with better error handling
|
||||
try:
|
||||
json_str_stripped = recognition_result.strip()
|
||||
pattern = r'(?<!\\)"(?:\\.|[^"\\])*"(*SKIP)(*FAIL)|,\s*([\]}])'
|
||||
# 循环替换:处理多层嵌套/连续多余逗号(如 [1,,2,,,])
|
||||
# Fixed regex pattern - removed unsupported (*SKIP)(*FAIL) syntax
|
||||
# This pattern finds trailing commas before closing brackets/braces
|
||||
# Handle multiple trailing commas and nested structures
|
||||
pattern = r',\s*(?=[\]}])'
|
||||
processed_str = json_str_stripped
|
||||
# Loop to handle multiple trailing commas
|
||||
while True:
|
||||
new_processed = re.sub(pattern, r'\1', processed_str, flags=re.DOTALL)
|
||||
if new_processed == processed_str: # 无更多替换则退出
|
||||
new_processed = re.sub(pattern, '', processed_str)
|
||||
if new_processed == processed_str: # No more changes
|
||||
break
|
||||
processed_str = new_processed
|
||||
result = json.loads(processed_str)
|
||||
|
||||
@@ -18,6 +18,7 @@ logger = logging.getLogger(__name__)
|
||||
# Store the startup time to filter tasks
|
||||
startup_time = datetime.now()
|
||||
|
||||
|
||||
async def process_pending_tasks():
|
||||
"""
|
||||
处理待处理的任务(用于服务器重启后的恢复)
|
||||
@@ -27,31 +28,33 @@ async def process_pending_tasks():
|
||||
|
||||
max_retries = 3
|
||||
retry_delay = 1 # seconds
|
||||
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
async with background_db_session() as db:
|
||||
# 获取所有启动前创建的待处理的任务
|
||||
pending_tasks = await image_task_dao.get_pending_tasks_before_time(db, startup_time, limit=100)
|
||||
processing_tasks = await image_task_dao.get_processing_tasks(db, limit=100)
|
||||
|
||||
|
||||
# 处理待处理的任务
|
||||
for task in pending_tasks:
|
||||
logger.info(f"Processing pending task {task.id} (created at {task.created_time})")
|
||||
# 在后台处理任务
|
||||
asyncio.create_task(image_service._process_image_recognition(task.id))
|
||||
|
||||
# 创建包装任务来处理完成和异常情况
|
||||
task_wrapper = create_task_wrapper(image_service._process_image_recognition, task.id)
|
||||
asyncio.create_task(task_wrapper)
|
||||
|
||||
# 处理之前标记为处理中但可能因服务器重启而中断的任务
|
||||
for task in processing_tasks:
|
||||
logger.info(f"Recovering processing task {task.id}")
|
||||
# 重置为待处理状态 with retry mechanism
|
||||
await update_task_status_with_retry(db, task.id, ImageTaskStatus.PENDING)
|
||||
# 在后台处理任务
|
||||
asyncio.create_task(image_service._process_image_recognition(task.id))
|
||||
|
||||
# 创建包装任务来处理完成和异常情况
|
||||
task_wrapper = create_task_wrapper(image_service._process_image_recognition, task.id)
|
||||
asyncio.create_task(task_wrapper)
|
||||
|
||||
# If we reach here, the operation was successful
|
||||
break
|
||||
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error processing pending tasks (attempt {attempt + 1}): {str(e)}")
|
||||
if attempt < max_retries - 1:
|
||||
@@ -66,6 +69,25 @@ async def process_pending_tasks():
|
||||
raise
|
||||
|
||||
|
||||
async def create_task_wrapper(process_func, task_id):
|
||||
"""
|
||||
创建一个包装任务,确保即使发生异常也能正确更新任务状态
|
||||
"""
|
||||
try:
|
||||
await process_func(task_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing task {task_id}: {str(e)}")
|
||||
# 尝试更新任务状态为失败
|
||||
try:
|
||||
async with background_db_session() as db:
|
||||
await update_task_status_with_retry(
|
||||
db, task_id, ImageTaskStatus.FAILED,
|
||||
error_message=f"Task failed with exception: {str(e)}"
|
||||
)
|
||||
except Exception as update_error:
|
||||
logger.error(f"Failed to update task {task_id} status to FAILED: {str(update_error)}")
|
||||
|
||||
|
||||
async def update_task_status_with_retry(db, task_id, status, result=None, error_message=None, max_retries=3):
|
||||
"""
|
||||
更新任务状态,带重试机制以处理序列化错误
|
||||
|
||||
Reference in New Issue
Block a user