Files
backend/backend/app/ai/service/qa_service.py
2026-01-27 19:58:37 +08:00

1773 lines
78 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import asyncio
import json
import math
import aiohttp
import io
import hashlib
from fastapi import UploadFile
from backend.app.admin.service.file_service import file_service
from backend.app.admin.schema.file import AddFileParam, FileMetadata, UpdateFileParam
from backend.app.admin.crud.file_crud import file_dao
from backend.middleware.cos_client import CosClient
from typing import Optional, List, Dict, Any, Tuple
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from backend.database.db import async_db_session, background_db_session
from backend.app.ai.crud.qa_crud import qa_exercise_dao, qa_question_dao, qa_attempt_dao, qa_session_dao
from backend.app.ai.crud.image_task_crud import image_task_dao
from backend.app.ai.crud.image_curd import image_dao
from backend.app.ai.model.image_task import ImageTaskStatus
from backend.app.ai.schema.image_task import CreateImageTaskParam
from backend.app.admin.service.points_service import points_service
from backend.app.ai.service.rate_limit_service import rate_limit_service
from backend.common.exception import errors
from backend.core.llm import LLMFactory, AuditLogCallbackHandler
from langchain_core.messages import SystemMessage, HumanMessage
from backend.core.conf import settings
from backend.app.ai.service.recording_service import recording_service
from backend.common.const import EXERCISE_TYPE_CHOICE, EXERCISE_TYPE_CLOZE, EXERCISE_TYPE_FREE_TEXT, LLM_CHAT_COST, POINTS_ACTION_SPEND, IMAGE_GENERATION_COST
from backend.app.admin.schema.wx import DictLevel
from backend.app.ai.service.image_task_service import TaskProcessor, image_task_service
from backend.app.ai.model.image_task import ImageProcessingTask
from backend.app.ai.model.qa import QaQuestion, QaPracticeSession
from backend.core.prompts.qa_exercise import get_qa_exercise_prompt
from backend.core.prompts.recognition import get_conversation_prompt_for_image_dialogue
from backend.core.prompts.free_conversation import get_free_conversation_start_prompt, get_free_conversation_reply_prompt
from backend.app.ai.tools.qa_tool import SceneVariationGenerator, Illustrator
class QaExerciseProcessor(TaskProcessor):
    """Generate the question list of a ``qa_exercise`` task from an image description.

    ``task.image_id`` is the source image and ``task.ref_id`` the exercise that
    receives the questions. Each item of the LLM's ``qa_list`` becomes a
    ``QaQuestion``. The exercise is ``published`` when at least one question
    was created (``draft`` otherwise), and a first ``ongoing`` session is
    created for the user if none exists yet. Points deduction and the final
    task status are handled by ``image_task_service``.
    """

    @staticmethod
    def _extract_description(image: Any) -> str:
        """Return the image's first recognition description, or ``''``.

        ``details['recognition_result']['description']`` may be a string or a
        list whose first element is a string; any other shape yields ``''``.
        """
        rr = (image.details or {}).get('recognition_result') or {}
        try:
            d = rr.get('description')
        except AttributeError:
            # recognition_result is present but is not a mapping.
            return ''
        if isinstance(d, str):
            return d
        if isinstance(d, list) and d and isinstance(d[0], str):
            return d[0]
        return ''

    @staticmethod
    def _parse_items(raw: Any) -> List[Dict[str, Any]]:
        """Turn the raw LLM output into a list of question dicts.

        Accepts a JSON string or an already-decoded value, shaped either as
        ``{"qa_list": [...]}`` or as a bare list. Unparseable output and
        non-dict entries are dropped rather than raising, so a bad response
        leaves the exercise as an empty draft instead of failing the task.
        """
        try:
            parsed = json.loads(raw) if isinstance(raw, str) else raw
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return []
        if isinstance(parsed, dict):
            items = parsed.get('qa_list') or []
        elif isinstance(parsed, list):
            items = parsed
        else:
            items = []
        if not isinstance(items, list):
            return []
        # Skip entries the model emitted as plain strings/numbers: they have no .get().
        return [it for it in items if isinstance(it, dict)]

    async def process(self, db: AsyncSession, task: ImageProcessingTask) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Create the questions (and a starter session) for ``task``.

        Returns:
            ``(result, token_usage)`` where ``result`` carries the token usage
            and ``count``, the number of questions created.

        Raises:
            errors.NotFoundError: the image or the exercise does not exist.
            Exception: the LLM call reported failure.
        """
        image = await image_dao.get(db, task.image_id)
        if not image:
            raise errors.NotFoundError(msg="Image not found")
        exercise = await qa_exercise_dao.get(db, task.ref_id)
        if not exercise:
            raise errors.NotFoundError(msg="Exercise not found")

        prompt = get_qa_exercise_prompt({'description': self._extract_description(image)})
        res = await self._call_llm_chat(prompt=prompt, image_id=image.id, user_id=task.user_id, chat_type='qa_exercise')
        if not res.get('success'):
            raise Exception(res.get('error') or "LLM call failed")
        token_usage = res.get('token_usage') or {}

        created = 0
        for it in self._parse_items(res.get('result')):
            await qa_question_dao.create(db, {
                'exercise_id': exercise.id,
                'image_id': image.id,
                'question': it.get('question') or '',
                'payload': None,
                'user_id': task.user_id,
                'ext': {
                    'dimension': it.get('dimension'),
                    'key_pronunciation_words': it.get('key_pronunciation_words'),
                    'answers': it.get('answers'),
                    'cloze': it.get('cloze'),
                    'correct_options': it.get('correct_options'),
                    'incorrect_options': it.get('incorrect_options'),
                },
            })
            created += 1

        exercise.question_count = created
        exercise.status = 'published' if created > 0 else 'draft'
        await db.flush()

        if created > 0:
            existing_session = await qa_session_dao.get_latest_by_user_exercise(db, task.user_id, exercise.id)
            if not existing_session:
                prog = {'current_index': 0, 'answered': 0, 'correct': 0, 'attempts': [], 'total_questions': created}
                await qa_session_dao.create(db, {
                    'exercise_id': exercise.id,
                    'starter_user_id': task.user_id,
                    'share_id': None,
                    'status': 'ongoing',
                    'started_at': datetime.now(),
                    'completed_at': None,
                    'progress': prog,
                    'score': None,
                    'ext': None,
                })
                await db.flush()

        # image_task_service handles points deduction and the final status update.
        result = {'token_usage': token_usage, 'count': created}
        return result, token_usage

    async def _call_llm_chat(self, prompt: str, image_id: int, user_id: int, chat_type: str) -> Dict[str, Any]:
        """Send ``prompt`` to the configured LLM with audit-log callbacks.

        Returns ``{"success": True, "result": <text>, "token_usage": <dict>}``
        on success, or ``{"success": False, "error": <message>}`` if creating
        or invoking the model raised.
        """
        messages = [
            SystemMessage(content="You are a helpful assistant."),
            HumanMessage(content=prompt)
        ]
        metadata = {
            "image_id": image_id,
            "user_id": user_id,
            "api_type": chat_type,
            "model_name": settings.LLM_MODEL_TYPE
        }
        try:
            llm = LLMFactory.create_llm(settings.LLM_MODEL_TYPE)
            res = await llm.ainvoke(
                messages,
                config={"callbacks": [AuditLogCallbackHandler(metadata=metadata)]}
            )
            content = res.content
            if not isinstance(content, str):
                content = str(content)
            token_usage = {}
            if res.response_metadata:
                # Providers report usage under different keys.
                token_usage = res.response_metadata.get("token_usage") or res.response_metadata.get("usage") or {}
            return {
                "success": True,
                "result": content,
                "token_usage": token_usage
            }
        except Exception as e:
            return {"success": False, "error": str(e)}
class SceneVariationProcessor(TaskProcessor):
    """Run scene-variation generation for a task and price the generated images.

    Delegates to ``qa_service.generate_scene_variations`` (``task.ref_id`` is
    passed as its first argument). The image cost is recorded as
    ``extra_points`` and ``extra_details`` inside the returned token-usage dict,
    next to whatever usage the generator reported. Presumably
    ``image_task_service`` reads these keys for points deduction; confirm there.
    """

    async def process(self, db: AsyncSession, task: ImageProcessingTask) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Generate variations and return ``(result, token_usage)``.

        ``result`` carries ``count`` (number of generated images) and the same
        ``token_usage`` dict that is returned as the second element.
        """
        count, token_usage = await qa_service.generate_scene_variations(task.ref_id, task.user_id, db=db)
        count = count or 0
        # Work on a copy: tolerate a None usage and avoid mutating a dict the
        # generator may still hold a reference to.
        token_usage = dict(token_usage or {})
        image_points = count * IMAGE_GENERATION_COST
        token_usage['extra_points'] = image_points
        token_usage['extra_details'] = {
            'image_count': count,
            'image_unit_price': IMAGE_GENERATION_COST,
            'source': 'scene_variation_generation'
        }
        return {'count': count, 'token_usage': token_usage}, token_usage
class ConversationInitProcessor(TaskProcessor):
    """Analyse an image for conversation practice and cache the result on the image.

    The image's recognition description and scene tags are sent to the LLM.
    The returned ``image_analysis`` dict is stored under
    ``image.details['conversation_analysis']['image_analysis']``.
    """

    @staticmethod
    def _description_of(recognition: Any) -> str:
        """Pick the description string (or first list entry) from a recognition result."""
        try:
            value = recognition.get("description")
            if isinstance(value, str):
                return value
            if isinstance(value, list) and value:
                head = value[0]
                return head if isinstance(head, str) else ""
        except Exception:
            pass
        return ""

    @staticmethod
    def _tags_of(recognition: Any) -> List[str]:
        """Normalise ``scene_tag`` (list or single string) into a list of strings."""
        try:
            raw_tags = recognition.get("scene_tag")
            if isinstance(raw_tags, list):
                return [str(tag) for tag in raw_tags]
            if isinstance(raw_tags, str):
                return [raw_tags]
        except Exception:
            pass
        return []

    async def process(self, db: AsyncSession, task: ImageProcessingTask) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Run the analysis for ``task.image_id`` and return ``(result, token_usage)``.

        Raises:
            errors.NotFoundError: the image does not exist.
            Exception: the LLM call failed or returned no ``image_analysis`` dict.
        """
        image = await image_dao.get(db, task.image_id)
        if not image:
            raise errors.NotFoundError(msg="Image not found")

        details = dict(image.details or {})
        recognition = details.get("recognition_result") or {}
        prompt = get_conversation_prompt_for_image_dialogue({
            "description": self._description_of(recognition),
            "scene_tags": self._tags_of(recognition),
        })

        llm_res = await self._call_llm_chat(prompt=prompt, image_id=image.id, user_id=task.user_id, chat_type="image_conversation_analysis")
        if not llm_res.get("success"):
            raise Exception(llm_res.get("error") or "LLM call failed")
        token_usage = llm_res.get("token_usage") or {}

        raw_result = llm_res.get("result")
        try:
            decoded = json.loads(raw_result) if isinstance(raw_result, str) else raw_result
        except Exception:
            decoded = {}
        analysis = decoded.get("image_analysis") if isinstance(decoded, dict) else None
        if not isinstance(analysis, dict):
            raise Exception("Invalid image_analysis structure")

        # Assign a fresh dict so the JSON column change is detected.
        image.details = {**details, "conversation_analysis": {"image_analysis": analysis}}
        try:
            from sqlalchemy.orm.attributes import flag_modified
            flag_modified(image, "details")
        except Exception:
            pass
        await db.flush()

        return {"image_analysis": analysis, "token_usage": token_usage}, token_usage

    async def _call_llm_chat(self, prompt: str, image_id: int, user_id: int, chat_type: str) -> Dict[str, Any]:
        """Invoke the configured LLM once; return a success/result dict or a failure/error dict."""
        audit_metadata = {
            "image_id": image_id,
            "user_id": user_id,
            "api_type": chat_type,
            "model_name": settings.LLM_MODEL_TYPE
        }
        conversation = [
            SystemMessage(content="You are a helpful assistant."),
            HumanMessage(content=prompt)
        ]
        try:
            model = LLMFactory.create_llm(settings.LLM_MODEL_TYPE)
            reply = await model.ainvoke(
                conversation,
                config={"callbacks": [AuditLogCallbackHandler(metadata=audit_metadata)]}
            )
            body = reply.content
            text = body if isinstance(body, str) else str(body)
            meta = reply.response_metadata
            usage = (meta.get("token_usage") or meta.get("usage") or {}) if meta else {}
            return {"success": True, "result": text, "token_usage": usage}
        except Exception as e:
            return {"success": False, "error": str(e)}
class ConversationStartProcessor(TaskProcessor):
    """Generate the opening AI turn of a free conversation.

    ``task.ref_id`` is the exercise id. On success the exercise becomes
    ``published`` with one question. The user's latest session for it is
    switched to ``ongoing``, or created if none exists. The AI's first message
    is stored as a new question whose ``ext`` holds the auxiliary LLM fields.
    """

    # Fields copied verbatim from the LLM JSON into the question's ext.
    _ASSISTANT_EXT_KEYS = ('response_zh', 'prompt_en', 'prompt_zh', 'alternative_responses', 'correction')

    @staticmethod
    def _english_values(entries: Any) -> List[Any]:
        """Return the truthy ``'en'`` values of the dict entries in ``entries`` (``[]`` if falsy)."""
        if not entries:
            return []
        return [entry.get('en') for entry in entries if isinstance(entry, dict) and entry.get('en')]

    @staticmethod
    def _english_value(entry: Any) -> Optional[Any]:
        """Return ``entry.get('en')`` for a non-empty dict, otherwise ``None``."""
        if entry and isinstance(entry, dict):
            return entry.get('en')
        return None

    @staticmethod
    def _scene_description(image: Any) -> str:
        """Return the recognition description (string or first list entry) or ``''``."""
        recognition = (image.details or {}).get('recognition_result') or {}
        try:
            raw = recognition.get('description')
            if isinstance(raw, str):
                return raw
            if isinstance(raw, list) and raw:
                return raw[0] if isinstance(raw[0], str) else ''
        except Exception:
            pass
        return ''

    async def process(self, db: AsyncSession, task: ImageProcessingTask) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Create the first AI message and return ``(result, token_usage)``.

        Raises:
            errors.NotFoundError: the exercise or its image does not exist.
            Exception: the LLM call failed or did not return a non-empty JSON object.
        """
        exercise = await qa_exercise_dao.get(db, task.ref_id)
        if not exercise:
            raise errors.NotFoundError(msg="Exercise not found")
        image = await image_dao.get(db, exercise.image_id)
        if not image:
            raise errors.NotFoundError(msg="Image not found")

        description = self._scene_description(image)
        conv_params = exercise.ext or {}
        prompt = get_free_conversation_start_prompt(
            scene=self._english_values(conv_params.get('scene')),
            event=self._english_values(conv_params.get('event')),
            user_role=self._english_value(conv_params.get('user_role')),
            assistant_role=self._english_value(conv_params.get('assistant_role')),
            style=self._english_value(conv_params.get('style')),
            level=conv_params.get('level'),
            info=conv_params.get('info'),
            description=description,
        )

        llm_res = await self._call_llm_chat(prompt=prompt, image_id=image.id, user_id=task.user_id, chat_type='conversation_start')
        if not llm_res.get('success'):
            raise Exception(llm_res.get('error') or "LLM call failed")
        token_usage = llm_res.get('token_usage') or {}

        raw_result = llm_res.get('result')
        try:
            reply = json.loads(raw_result) if isinstance(raw_result, str) else raw_result
        except Exception:
            reply = {}
        if not isinstance(reply, dict) or not reply:
            raise Exception("Invalid LLM response format")

        exercise.status = 'published'
        exercise.question_count = 1

        session = await qa_session_dao.get_latest_by_user_exercise(db, task.user_id, exercise.id)
        if not session:
            session = await qa_session_dao.create(db, {
                'exercise_id': exercise.id,
                'starter_user_id': task.user_id,
                'status': 'ongoing',
                'started_at': datetime.now(),
                'progress': {'current_index': 0, 'answered': 0, 'correct': 0, 'attempts': [], 'total_questions': 1},
                'ext': None,
            })
        else:
            session.status = 'ongoing'
            progress = dict(session.progress or {})
            progress['total_questions'] = 1
            session.progress = progress

        opening_ext = {'role': 'assistant'}
        for key in self._ASSISTANT_EXT_KEYS:
            opening_ext[key] = reply.get(key)
        await qa_question_dao.create(db, {
            'exercise_id': exercise.id,
            'image_id': image.id,
            'question': reply.get('response_en') or '',
            'user_id': task.user_id,
            'payload': None,
            'ext': opening_ext,
        })
        await db.flush()

        outcome = {
            'exercise_id': str(exercise.id),
            'session_id': str(session.id),
            'token_usage': token_usage
        }
        return outcome, token_usage

    async def _call_llm_chat(self, prompt: str, image_id: int, user_id: int, chat_type: str) -> Dict[str, Any]:
        """Invoke the configured LLM once; return a success/result dict or a failure/error dict."""
        audit_metadata = {
            "image_id": image_id,
            "user_id": user_id,
            "api_type": chat_type,
            "model_name": settings.LLM_MODEL_TYPE
        }
        conversation = [
            SystemMessage(content="You are a helpful assistant."),
            HumanMessage(content=prompt)
        ]
        try:
            model = LLMFactory.create_llm(settings.LLM_MODEL_TYPE)
            answer = await model.ainvoke(
                conversation,
                config={"callbacks": [AuditLogCallbackHandler(metadata=audit_metadata)]}
            )
            body = answer.content
            text = body if isinstance(body, str) else str(body)
            meta = answer.response_metadata
            usage = (meta.get("token_usage") or meta.get("usage") or {}) if meta else {}
            return {"success": True, "result": text, "token_usage": usage}
        except Exception as e:
            return {"success": False, "error": str(e)}
class ConversationReplyProcessor(TaskProcessor):
    """Produce the AI's next turn for a pending user reply in a free conversation.

    ``task.ref_id`` is the id of the pending attempt created by
    ``QaService.reply_conversation``. The processor asks the LLM for a reply.
    It stores the returned correction on the attempt and marks it
    ``completed``. It appends the AI's answer as a new question and bumps the
    session's ``total_questions``.
    """

    @staticmethod
    def _extract_description(image: Any) -> str:
        """Return the recognition description (string or first list entry) or ``''``."""
        rr = (image.details or {}).get('recognition_result') or {}
        try:
            d = rr.get('description')
        except AttributeError:
            # recognition_result is present but is not a mapping.
            return ''
        if isinstance(d, str):
            return d
        if isinstance(d, list) and d and isinstance(d[0], str):
            return d[0]
        return ''

    @staticmethod
    def _en_list(items: Any) -> List[Any]:
        """Return the truthy ``'en'`` values of the dict entries in ``items`` (``[]`` if falsy)."""
        if not items:
            return []
        return [item.get('en') for item in items if isinstance(item, dict) and item.get('en')]

    @staticmethod
    def _en_item(item: Any) -> Optional[Any]:
        """Return ``item.get('en')`` for a non-empty dict, otherwise ``None``."""
        if not item or not isinstance(item, dict):
            return None
        return item.get('en')

    @staticmethod
    async def _build_history(db: AsyncSession, user_id: int, questions: List[Any], reply_to_question_id: Any) -> List[Dict[str, Any]]:
        """Build prompt history ending with the AI message being replied to.

        Every earlier AI message is followed by the user's latest completed
        attempt for it, if one exists. No attempt is looked up for the message
        being replied to, because the pending attempt is passed to the prompt
        separately as the new input.

        If ``reply_to_question_id`` is not among ``questions``, the last
        question is treated as the one being replied to.
        """
        if not questions:
            return []
        question_ids = [q.id for q in questions]
        # Anchor on the attempt's own question. If another AI turn was appended
        # meanwhile, it must not be presented as the message being answered.
        stop = question_ids.index(reply_to_question_id) if reply_to_question_id in question_ids else len(questions) - 1
        history: List[Dict[str, Any]] = []
        for q in questions[:stop]:
            history.append({'role': 'assistant', 'content': q.question})
            prev_attempt = await qa_attempt_dao.get_latest_completed_by_user_question(db, user_id, q.id)
            if prev_attempt:
                history.append({'role': 'user', 'content': prev_attempt.input_text or ''})
        history.append({'role': 'assistant', 'content': questions[stop].question})
        return history

    async def process(self, db: AsyncSession, task: ImageProcessingTask) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Generate the AI reply for the attempt ``task.ref_id``.

        Returns:
            ``(result, token_usage)``. ``result`` holds the session id (``''``
            if no session was found), the new question id and the token usage.

        Raises:
            errors.NotFoundError: the attempt, exercise or image does not exist.
            Exception: the LLM call failed or did not return a non-empty JSON object.
        """
        attempt = await qa_attempt_dao.get(db, task.ref_id)
        if not attempt:
            raise errors.NotFoundError(msg="Attempt not found")
        exercise = await qa_exercise_dao.get(db, attempt.exercise_id)
        if not exercise:
            raise errors.NotFoundError(msg="Exercise not found")
        image = await image_dao.get(db, exercise.image_id)
        if not image:
            raise errors.NotFoundError(msg="Image not found")

        # The session is only used to update bookkeeping; a missing session is
        # tolerated (result then carries an empty session_id).
        session = await qa_session_dao.get_latest_by_user_exercise(db, task.user_id, exercise.id)

        description = self._extract_description(image)
        questions = await qa_question_dao.get_by_exercise_id(db, exercise.id)
        history = await self._build_history(db, task.user_id, questions, attempt.question_id)
        user_input = attempt.input_text or ''
        params = exercise.ext or {}

        prompt = get_free_conversation_reply_prompt(
            history=history,
            user_input=user_input,
            scene=self._en_list(params.get('scene')),
            event=self._en_list(params.get('event')),
            user_role=self._en_item(params.get('user_role')),
            assistant_role=self._en_item(params.get('assistant_role')),
            style=self._en_item(params.get('style')),
            level=params.get('level'),
            info=params.get('info'),
            description=description,
        )
        res = await self._call_llm_chat(prompt=prompt, image_id=image.id, user_id=task.user_id, chat_type='conversation_reply')
        if not res.get('success'):
            raise Exception(res.get('error') or "LLM call failed")
        token_usage = res.get('token_usage') or {}

        try:
            parsed = json.loads(res.get('result')) if isinstance(res.get('result'), str) else res.get('result')
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            parsed = {}
        if not parsed or not isinstance(parsed, dict):
            raise Exception("Invalid LLM response format")

        # Attach the correction of the user's input to the attempt.
        correction = parsed.get('correction')
        if correction:
            new_ext = dict(attempt.ext or {})
            new_ext['correction'] = correction
            attempt.ext = new_ext
            attempt.evaluation = {'correction': correction}
        attempt.status = 'completed'  # was 'pending'

        # Store the AI response as the next question.
        new_question = await qa_question_dao.create(db, {
            'exercise_id': exercise.id,
            'image_id': image.id,
            'question': parsed.get('response_en') or '',
            'user_id': task.user_id,
            'payload': None,
            'ext': {
                'role': 'assistant',
                'response_zh': parsed.get('response_zh'),
                'prompt_en': parsed.get('prompt_en'),
                'prompt_zh': parsed.get('prompt_zh'),
                'alternative_responses': parsed.get('alternative_responses'),
                'correction': correction,
            },
        })

        if session:
            session.updated_at = datetime.now()
            prog = dict(session.progress or {})
            prog['total_questions'] = (prog.get('total_questions') or 0) + 1
            session.progress = prog
        await db.flush()

        result = {
            'session_id': str(session.id) if session else '',
            'new_question_id': str(new_question.id),
            'token_usage': token_usage
        }
        return result, token_usage

    async def _call_llm_chat(self, prompt: str, image_id: int, user_id: int, chat_type: str) -> Dict[str, Any]:
        """Send ``prompt`` to the configured LLM with audit-log callbacks.

        Returns ``{"success": True, "result": <text>, "token_usage": <dict>}``
        or ``{"success": False, "error": <message>}`` if the call raised.
        """
        messages = [
            SystemMessage(content="You are a helpful assistant."),
            HumanMessage(content=prompt)
        ]
        metadata = {
            "image_id": image_id,
            "user_id": user_id,
            "api_type": chat_type,
            "model_name": settings.LLM_MODEL_TYPE
        }
        try:
            llm = LLMFactory.create_llm(settings.LLM_MODEL_TYPE)
            res = await llm.ainvoke(
                messages,
                config={"callbacks": [AuditLogCallbackHandler(metadata=metadata)]}
            )
            content = res.content
            if not isinstance(content, str):
                content = str(content)
            token_usage = {}
            if res.response_metadata:
                token_usage = res.response_metadata.get("token_usage") or res.response_metadata.get("usage") or {}
            return {"success": True, "result": content, "token_usage": token_usage}
        except Exception as e:
            return {"success": False, "error": str(e)}
class QaService:
async def get_conversation_setting(self, image_id: int, user_id: int) -> Optional[Dict[str, Any]]:
    """Return the cached conversation analysis of an image and the user's latest session.

    Returns ``None`` when ``details['conversation_analysis']['image_analysis']``
    is not a dict yet. ``latest_session`` is ``None`` when the user has no
    ``free_conversation`` session for this image.

    Raises:
        errors.ForbiddenError: no task exists for the image, or it belongs to another user.
        errors.NotFoundError: the image does not exist.
    """
    async with async_db_session() as db:
        owner_task = await image_task_dao.get_by_image_id(db, image_id)
        if not owner_task or owner_task.user_id != user_id:
            raise errors.ForbiddenError(msg="Forbidden")

        image = await image_dao.get(db, image_id)
        if not image:
            raise errors.NotFoundError(msg="Image not found")

        cached = dict(image.details or {}).get("conversation_analysis") or {}
        analysis = cached.get("image_analysis")
        if not isinstance(analysis, dict):
            return None

        latest = await qa_session_dao.get_latest_session_by_image_user(db, user_id, image_id, exercise_type='free_conversation')
        latest_info = None
        if latest:
            # Completion time wins; otherwise fall back to the start time.
            stamp = latest.completed_at or latest.started_at
            latest_info = {
                'session_id': str(latest.id),
                'status': latest.status,
                'updated_at': stamp.isoformat() if stamp else None,
                'exercise_id': str(latest.exercise_id),
            }

        return {
            "image_id": image_id,
            "setting": analysis,
            "latest_session": latest_info,
        }
async def start_conversation(
    self,
    image_id: int,
    user_id: int,
    scene: List[Dict[str, str]],
    event: List[Dict[str, str]],
    style: Optional[Dict[str, str]] = None,
    user_role: Optional[Dict[str, str]] = None,
    assistant_role: Optional[Dict[str, str]] = None,
    level: Optional[str] = None,
    info: Optional[str] = None,
) -> Dict[str, Any]:
    """Create a free-conversation exercise and queue the AI's opening turn.

    The call creates three things:

    - a ``free_conversation`` exercise with the settings stored in ``ext``;
    - a session pre-created in ``initializing`` state, so the caller gets a
      ``session_id`` immediately;
    - an image task, processed in the background by ``ConversationStartProcessor``.

    Returns:
        ``task_id``, ``status`` (always ``"processing"``), ``exercise_id`` and
        ``session_id``, all as strings.

    Raises:
        errors.ForbiddenError: insufficient points or concurrent-task limit reached.
        errors.NotFoundError: the image does not exist.
    """
    if not await points_service.check_sufficient_points(user_id, LLM_CHAT_COST):
        raise errors.ForbiddenError(msg='积分不足,请获取积分后继续使用')
    slot_acquired = await rate_limit_service.acquire_task_slot(user_id)
    if not slot_acquired:
        max_tasks = await rate_limit_service.get_user_task_limit(user_id)
        raise errors.ForbiddenError(msg=f'用户同时最多只能运行 {max_tasks} 个任务,请等待现有任务完成后再试')
    # NOTE(review): if anything below raises (e.g. image not found), the slot
    # acquired above is not handed back here. Confirm whether
    # rate_limit_service exposes a release call that should run on failure.

    async with async_db_session.begin() as db:
        image = await image_dao.get(db, image_id)
        if not image:
            raise errors.NotFoundError(msg="Image not found")

        exercise = await qa_exercise_dao.create(db, {
            "image_id": image_id,
            "created_by": user_id,
            "type": "free_conversation",
            "description": None,
            "status": "ongoing",
            "ext": {
                "scene": scene,
                "event": event,
                "user_role": user_role,
                "assistant_role": assistant_role,
                "style": style,
                "level": level,
                "info": info,
            },
        })
        await db.flush()

        # Pre-create the session so session_id can be returned right away.
        prog = {'current_index': 0, 'answered': 0, 'correct': 0, 'attempts': [], 'total_questions': 0}
        session = await qa_session_dao.create(db, {
            'exercise_id': exercise.id,
            'starter_user_id': user_id,
            'status': 'initializing',
            'started_at': datetime.now(),
            'progress': prog,
            'ext': None,
        })
        await db.flush()

        task = await image_task_dao.create_task(db, CreateImageTaskParam(
            image_id=image_id,
            user_id=user_id,
            dict_level=(getattr(getattr(image, 'dict_level', None), 'name', None) or 'LEVEL1'),
            ref_type='qa_exercise',
            ref_id=exercise.id,
            status=ImageTaskStatus.PENDING,
        ))
        await db.flush()

        # Read plain ids while the transaction is open, so nothing below
        # depends on ORM attribute state after commit.
        task_id = task.id
        exercise_id = exercise.id
        session_id = session.id

    # Dispatch only after the block above has committed. A worker using its
    # own DB session could otherwise look the task up before it is visible.
    asyncio.create_task(image_task_service.process_task(task_id, user_id, ConversationStartProcessor()))
    return {
        "task_id": str(task_id),
        "status": "processing",
        "exercise_id": str(exercise_id),
        "session_id": str(session_id)
    }
async def reply_conversation(
    self,
    session_id: int,
    user_id: int,
    input_text: str,
) -> Dict[str, Any]:
    """Record the user's reply in a conversation and queue the AI's answer.

    A ``pending`` attempt is created against the exercise's latest question
    (the last AI message). A task processed by ``ConversationReplyProcessor``
    is then queued in the background.

    Returns:
        ``task_id``, ``status`` (always ``"processing"``) and ``session_id``, as strings.

    Raises:
        errors.ForbiddenError: insufficient points, concurrent-task limit
            reached, or the session belongs to another user.
        errors.NotFoundError: the session or its exercise does not exist.
        errors.ServerError: the conversation has no AI message to reply to yet.
    """
    if not await points_service.check_sufficient_points(user_id, LLM_CHAT_COST):
        raise errors.ForbiddenError(msg='积分不足,请获取积分后继续使用')
    slot_acquired = await rate_limit_service.acquire_task_slot(user_id)
    if not slot_acquired:
        max_tasks = await rate_limit_service.get_user_task_limit(user_id)
        raise errors.ForbiddenError(msg=f'用户同时最多只能运行 {max_tasks} 个任务,请等待现有任务完成后再试')
    # NOTE(review): errors raised below do not hand the acquired slot back
    # here. Confirm whether rate_limit_service has a release call for this.

    async with async_db_session.begin() as db:
        session = await qa_session_dao.get(db, session_id)
        if not session:
            raise errors.NotFoundError(msg="Session not found")
        if session.starter_user_id != user_id:
            raise errors.ForbiddenError(msg="Forbidden")
        exercise = await qa_exercise_dao.get(db, session.exercise_id)
        if not exercise:
            raise errors.NotFoundError(msg="Exercise not found")

        # The attempt answers the most recent AI message.
        last_question = await qa_question_dao.get_latest_by_exercise_id(db, exercise.id)
        if not last_question:
            raise errors.ServerError(msg="No question to reply to")
        attempt = await qa_attempt_dao.create(db, {
            "user_id": user_id,
            "question_id": last_question.id,
            "exercise_id": exercise.id,
            "input_text": input_text,
            "status": "pending",
            "evaluation": None,
            "ext": None
        })
        await db.flush()

        task = await image_task_dao.create_task(db, CreateImageTaskParam(
            image_id=exercise.image_id,
            user_id=user_id,
            dict_level='LEVEL1',  # default; not derived from the image here
            ref_type='qa_attempt',
            ref_id=attempt.id,
            status=ImageTaskStatus.PENDING,
        ))
        await db.flush()

        task_id = task.id
        resolved_session_id = session.id

    # Dispatch only after commit so the background worker can see the new rows.
    asyncio.create_task(image_task_service.process_task(task_id, user_id, ConversationReplyProcessor()))
    return {
        "task_id": str(task_id),
        "status": "processing",
        "session_id": str(resolved_session_id)
    }
async def _get_messages_for_session(self, db: AsyncSession, exercise_id: int, user_id: int) -> List[Dict[str, Any]]:
    """Build the chat transcript of an exercise for ``user_id``.

    Each question becomes an ``assistant`` message. Only the newest one
    carries the prompt, alternatives and correction fields; older ones carry
    just the English/Chinese text. Each is followed by the user's latest
    completed attempt for that question, if any, as a ``user`` message.
    """
    questions = await qa_question_dao.get_by_exercise_id(db, exercise_id)
    last_index = len(questions) - 1
    transcript: List[Dict[str, Any]] = []
    for index, question in enumerate(questions):
        meta = question.ext or {}
        assistant_content = {
            "response_en": question.question,
            "response_zh": meta.get("response_zh"),
        }
        if index == last_index:
            assistant_content.update(
                prompt_en=meta.get("prompt_en"),
                prompt_zh=meta.get("prompt_zh"),
                alternative_responses=meta.get("alternative_responses"),
                correction=meta.get("correction"),
            )
        transcript.append({"role": "assistant", "content": assistant_content})

        reply = await qa_attempt_dao.get_latest_completed_by_user_question(db, user_id, question.id)
        if reply:
            transcript.append({
                "role": "user",
                "content": {
                    "text": reply.input_text,
                    "correction": (reply.evaluation or {}).get("correction"),
                },
            })
    return transcript
async def get_latest_messages(self, session_id: int, user_id: int) -> Dict[str, Any]:
    """Return only the newest AI message of a session.

    Only the latest question is fetched, not the whole history. ``messages``
    is empty when the exercise has no question yet.

    Raises:
        errors.NotFoundError: the session does not exist or is not owned by ``user_id``.
    """
    async with async_db_session() as db:
        session = await qa_session_dao.get(db, session_id)
        if not session or session.starter_user_id != user_id:
            raise errors.NotFoundError(msg="Session not found")

        newest = await qa_question_dao.get_latest_by_exercise_id(db, session.exercise_id)
        messages: List[Dict[str, Any]] = []
        if newest:
            extra = newest.ext or {}
            content = {"response_en": newest.question}
            for key in ("response_zh", "prompt_en", "prompt_zh", "alternative_responses", "correction"):
                content[key] = extra.get(key)
            messages.append({"role": "assistant", "content": content})

        return {
            "session_id": str(session_id),
            "messages": messages
        }
async def get_conversation_session(self, session_id: int, user_id: int) -> Dict[str, Any]:
    """Return a conversation session together with its full message transcript.

    ``updated_at`` is the exercise's ``updated_time`` in ISO format, or
    ``None`` if the exercise has no such value.

    Raises:
        errors.NotFoundError: the session does not exist, is not owned by
            ``user_id``, or its exercise no longer exists.
    """
    async with async_db_session() as db:
        session = await qa_session_dao.get(db, session_id)
        if not session or session.starter_user_id != user_id:
            raise errors.NotFoundError(msg="Session not found")
        exercise = await qa_exercise_dao.get(db, session.exercise_id)
        if not exercise:
            # Report a missing exercise explicitly instead of failing on ``exercise.id``.
            raise errors.NotFoundError(msg="Exercise not found")
        messages = await self._get_messages_for_session(db, session.exercise_id, user_id)
        updated_time = getattr(exercise, "updated_time", None)
        return {
            "exercise_id": str(exercise.id),
            "session_id": str(session.id),
            "status": session.status,
            "updated_at": updated_time.isoformat() if updated_time else None,
            "messages": messages,
        }
async def list_conversations_by_image(
    self,
    image_id: int,
    user_id: int,
    page: int = 1,
    page_size: int = 10
) -> Dict[str, Any]:
    """Return one page of the user's conversation sessions for an image.

    Each item summarises a session using its exercise's settings (``scene``,
    ``event``, ``user_role``, ``style``) and its creation time formatted as
    ``%Y-%m-%d %H:%M:%S``.
    """
    async with async_db_session() as db:
        total, rows = await qa_session_dao.get_list_by_image(
            db, image_id, user_id, page, page_size
        )

        def summarize(session: Any, exercise: Any) -> Dict[str, Any]:
            conv_ext = exercise.ext or {}
            created = session.created_time
            return {
                "session_id": str(session.id),
                "scene": conv_ext.get("scene") or [],
                "event": conv_ext.get("event") or [],
                "user_role": conv_ext.get("user_role") or {},
                "style": conv_ext.get("style") or {},
                "created_at": created.strftime("%Y-%m-%d %H:%M:%S") if created else None,
            }

        items = [summarize(session, exercise) for session, exercise in rows]
        return {
            "total": total,
            "items": items,
            "page": page,
            "page_size": page_size,
        }
async def recognize_audio(self, file_id: int, user_id: int, session_id: int) -> Dict[str, Any]:
    """Transcribe an uploaded audio file and attach the text to the session's latest AI message.

    For COS-stored files a presigned download URL is passed to Qwen ASR.
    Otherwise the bytes are written to a temporary file (keeping an extension
    so the format can be recognised), which is always removed afterwards.
    The transcript is appended to ``ext['audio_recognitions']`` of the
    exercise's latest question.

    Returns:
        ``{"text": <transcript>}``.

    Raises:
        errors.NotFoundError: the session or the file does not exist.
        errors.ForbiddenError: the session belongs to another user.
        errors.ServerError: speech recognition failed.
    """
    import os
    import tempfile

    from backend.middleware.qwen import Qwen

    # 1. Validate the session and resolve its exercise / image.
    image_id = None
    async with async_db_session.begin() as db:
        session = await qa_session_dao.get(db, session_id)
        if not session:
            raise errors.NotFoundError(msg="Session not found")
        if session.starter_user_id != user_id:
            raise errors.ForbiddenError(msg="Forbidden")
        exercise_id = session.exercise_id
        exercise = await qa_exercise_dao.get(db, exercise_id)
        if exercise:
            image_id = exercise.image_id

    # 2. Load file metadata.
    file_obj = await file_service.get_file(file_id)
    if not file_obj:
        raise errors.NotFoundError(msg="文件不存在")
    # NOTE(review): the non-COS branch downloads by file_id without checking
    # that the file belongs to user_id. Confirm ownership is enforced elsewhere.

    temp_file_path = None
    try:
        if file_obj.storage_type == "cos":
            audio_url = await file_service.get_presigned_download_url(file_id, user_id)
        else:
            content, _, _ = await file_service.download_file(file_id)
            # Keep a file extension so Qwen can detect the audio format.
            suffix = "mp3"
            if file_obj.content_type:
                suffix = file_service._mime_to_ext(file_obj.content_type, file_obj.file_name) or "mp3"
            with tempfile.NamedTemporaryFile(delete=False, suffix=f".{suffix}") as tf:
                # Record the path before writing so `finally` cleans up even if the write fails.
                temp_file_path = tf.name
                tf.write(content)
            audio_url = temp_file_path

        res = await Qwen.recognize_speech(audio_url, user_id=user_id, image_id=image_id)
        if not res.get("success"):
            raise errors.ServerError(msg=res.get("error") or "ASR failed")
        text = res.get("text", "")

        # 3. Attach the transcript to the latest AI message of the session.
        async with async_db_session.begin() as db:
            last_question = await qa_question_dao.get_latest_by_exercise_id(db, exercise_id)
            if last_question:
                question_ext = dict(last_question.ext or {})
                recognitions = question_ext.get('audio_recognitions')
                # Copy the list: appending in place would also mutate the
                # committed JSON value, so the ORM could see old == new and
                # skip the UPDATE.
                recognitions = list(recognitions) if isinstance(recognitions, list) else []
                recognitions.append({
                    'audio_id': str(file_id),
                    'text': text
                })
                question_ext['audio_recognitions'] = recognitions
                last_question.ext = question_ext
                await db.flush()
        return {"text": text}
    finally:
        # Best-effort cleanup of the temporary audio file.
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
            except OSError:
                pass
async def create_exercise_task(self, image_id: int, user_id: int, type: Optional[str] = "scene_basic") -> Dict[str, Any]:
    """Queue a background generation task for an image and return its id.

    ``type == 'init_conversion'`` queues the conversation analysis of the
    image (``ConversationInitProcessor``, ref = the image itself). Any other
    type first creates a ``draft`` exercise of that type, then queues
    ``SceneVariationProcessor`` for ``'scene_variation'`` or
    ``QaExerciseProcessor`` otherwise.

    If an active task of the same ref type already exists for this user and
    image, it is returned as-is, without points or slot checks.

    Raises:
        errors.ForbiddenError: insufficient points or concurrent-task limit reached.
        errors.NotFoundError: the image does not exist.
    """
    conversation_init = type == 'init_conversion'
    lookup_ref_type = 'image_conversation_analysis' if conversation_init else 'qa_exercise'

    # Reuse an in-flight task instead of queuing a duplicate.
    async with async_db_session.begin() as db:
        in_flight = await image_task_dao.get_latest_active_task(db, user_id, image_id, lookup_ref_type)
        if in_flight:
            return {'task_id': str(in_flight.id), 'status': in_flight.status}

    if not await points_service.check_sufficient_points(user_id, LLM_CHAT_COST):
        raise errors.ForbiddenError(msg='积分不足,请获取积分后继续使用')
    if not await rate_limit_service.acquire_task_slot(user_id):
        max_tasks = await rate_limit_service.get_user_task_limit(user_id)
        raise errors.ForbiddenError(msg=f'用户同时最多只能运行 {max_tasks} 个任务,请等待现有任务完成后再试')

    async with async_db_session.begin() as db:
        image = await image_dao.get(db, image_id)
        if not image:
            raise errors.NotFoundError(msg='Image not found')

        if conversation_init:
            ref_type, ref_id = 'image_conversation_analysis', image_id
        else:
            exercise = await qa_exercise_dao.create(db, {
                'image_id': image_id,
                'created_by': user_id,
                'type': type,
                'description': None,
                'status': 'draft',
                'ext': None
            })
            await db.flush()
            ref_type, ref_id = 'qa_exercise', exercise.id

        task = await image_task_dao.create_task(db, CreateImageTaskParam(
            image_id=image_id,
            user_id=user_id,
            dict_level=(getattr(getattr(image, 'dict_level', None), 'name', None) or 'LEVEL1'),
            ref_type=ref_type,
            ref_id=ref_id,
            status=ImageTaskStatus.PENDING,
        ))
        await db.flush()
        task_id = task.id
        # Explicit commit before leaving the block; the context manager would
        # also commit on exit, so this only makes the ordering obvious.
        await db.commit()

    if conversation_init:
        processor = ConversationInitProcessor()
    elif type == 'scene_variation':
        processor = SceneVariationProcessor()
    else:
        processor = QaExerciseProcessor()
    asyncio.create_task(image_task_service.process_task(task_id, user_id, processor))
    return {'task_id': str(task_id), 'status': 'accepted'}
async def get_task_status(self, task_id: int) -> Dict[str, Any]:
    """Return a serialisable status snapshot of an image task (ids as strings).

    NOTE(review): no ownership check is done here. Confirm the caller
    restricts which task ids a user may query.

    Raises:
        errors.NotFoundError: no task with ``task_id`` exists.
    """
    async with async_db_session() as db:
        found = await image_task_dao.get(db, task_id)
        if not found:
            raise errors.NotFoundError(msg='Task not found')
        snapshot: Dict[str, Any] = {}
        snapshot['task_id'] = str(found.id)
        snapshot['image_id'] = str(found.image_id)
        snapshot['ref_type'] = found.ref_type
        snapshot['ref_id'] = str(found.ref_id)
        snapshot['status'] = found.status
        snapshot['error_message'] = found.error_message
        return snapshot
async def list_exercises_by_image(self, image_id: int, user_id: Optional[int] = None, type: Optional[str] = "scene_basic") -> Optional[Dict[str, Any]]:
    """Return the latest exercise of ``type`` for an image, with its questions.

    Returns ``None`` when the image or a matching exercise does not exist.
    When ``user_id`` is given, ``session`` holds that user's latest session
    for the exercise (id, start time, progress). Otherwise, or if there is no
    such session, it is ``None``.
    """
    async with async_db_session() as db:
        if not await image_dao.get(db, image_id):
            return None
        exercise = await qa_exercise_dao.get_latest_by_image_id(db, image_id, type=type)
        if not exercise:
            return None

        questions = await qa_question_dao.get_by_exercise_id(db, exercise.id)

        session_info = None
        if user_id:
            latest = await qa_session_dao.get_latest_by_user_exercise(db, user_id, exercise.id)
            if latest:
                session_info = {
                    'id': str(latest.id),
                    'started_at': latest.started_at.isoformat() if latest.started_at else None,
                    'progress': latest.progress,
                }

        exercise_info = {
            'id': str(exercise.id),
            'image_id': str(exercise.image_id),
            'type': exercise.type,
            'description': exercise.description,
            'status': exercise.status,
            'question_count': exercise.question_count,
        }
        question_list = []
        for question in questions:
            question_list.append({
                'id': str(question.id),
                'exercise_id': str(question.exercise_id),
                'image_id': str(question.image_id),
                'question': question.question,
                'ext': question.ext,
            })
        return {
            'exercise': exercise_info,
            'session': session_info,
            'questions': question_list,
        }
def _evaluate_choice(self, q: QaQuestion, selected_options: List[str]) -> Tuple[Dict[str, Any], str, List[str]]:
ext = q.ext or {}
raw_correct = ext.get('correct_options') or []
raw_incorrect = ext.get('incorrect_options') or []
def _norm(v):
try:
return str(v).strip().lower()
except Exception:
return str(v)
correct_set = set(_norm(o.get('content') if isinstance(o, dict) else o) for o in raw_correct)
incorrect_map = {}
for o in raw_incorrect:
c = _norm(o.get('content') if isinstance(o, dict) else o)
if isinstance(o, dict):
incorrect_map[c] = {
'content': o.get('content'),
'error_type': o.get('error_type'),
'error_reason': o.get('error_reason')
}
else:
incorrect_map[c] = {'content': o, 'error_type': None, 'error_reason': None}
selected_list = list(selected_options or [])
selected = set(_norm(s) for s in selected_list)
if not selected:
is_correct = 'incorrect'
result_text = '完全错误'
evaluation = {'type': 'choice', 'result': result_text, 'detail': 'no selection', 'selected': {'correct': [], 'incorrect': []}, 'missing_correct': [o.get('content') if isinstance(o, dict) else o for o in raw_correct]}
else:
selected_correct = []
for o in raw_correct:
c = _norm(o.get('content') if isinstance(o, dict) else o)
if c in selected:
selected_correct.append(o.get('content') if isinstance(o, dict) else o)
selected_incorrect = []
for s in selected_list:
ns = _norm(s)
if ns not in correct_set:
detail = incorrect_map.get(ns)
if detail:
selected_incorrect.append(detail)
else:
selected_incorrect.append({'content': s, 'error_type': 'unknown', 'error_reason': None})
missing_correct = []
for o in raw_correct:
c = _norm(o.get('content') if isinstance(o, dict) else o)
if c not in selected:
missing_correct.append(o.get('content') if isinstance(o, dict) else o)
if selected == correct_set and not selected_incorrect:
is_correct = 'correct'
result_text = '完全匹配'
evaluation = {'type': 'choice', 'result': result_text, 'detail': is_correct, 'selected': {'correct': selected_correct, 'incorrect': []}, 'missing_correct': []}
elif selected_correct:
is_correct = 'partial'
result_text = '部分匹配'
evaluation = {'type': 'choice', 'result': result_text, 'detail': is_correct, 'selected': {'correct': selected_correct, 'incorrect': selected_incorrect}, 'missing_correct': missing_correct}
else:
is_correct = 'incorrect'
result_text = '完全错误'
evaluation = {'type': 'choice', 'result': result_text, 'detail': is_correct, 'selected': {'correct': [], 'incorrect': selected_incorrect}, 'missing_correct': [o.get('content') if isinstance(o, dict) else o for o in raw_correct]}
return evaluation, is_correct, selected_list
def _evaluate_cloze(self, q: QaQuestion, cloze_options: Optional[List[str]]) -> Tuple[Dict[str, Any], str, str]:
ext = q.ext or {}
cloze = ext.get('cloze') or {}
correct_word = cloze.get('correct_word')
# Support multiple selections: treat as correct if any selected matches a correct answer
selection_list = [s for s in (cloze_options or []) if isinstance(s, str) and s.strip()]
input_str = selection_list[0] if selection_list else ''
def _norm(v):
try:
return str(v).strip().lower()
except Exception:
return str(v)
# correct answers may be a single string or a list
correct_candidates = []
if isinstance(correct_word, list):
correct_candidates = [cw for cw in correct_word if isinstance(cw, str) and cw.strip()]
elif isinstance(correct_word, str) and correct_word.strip():
correct_candidates = [correct_word]
correct_set = set(_norm(cw) for cw in correct_candidates)
user_correct = []
user_incorrect = []
for s in selection_list:
if _norm(s) in correct_set:
user_correct.append(s)
else:
user_incorrect.append({'content': s, 'error_type': None, 'error_reason': None})
if user_correct and not user_incorrect:
is_correct = 'correct'
result_text = '完全匹配'
evaluation = {'type': 'cloze', 'result': result_text, 'detail': is_correct, 'selected': {'correct': user_correct, 'incorrect': []}, 'missing_correct': []}
elif user_correct:
is_correct = 'partial'
result_text = '部分匹配'
evaluation = {'type': 'cloze', 'result': result_text, 'detail': is_correct, 'selected': {'correct': user_correct, 'incorrect': user_incorrect}, 'missing_correct': []}
else:
is_correct = 'incorrect'
result_text = '完全错误'
evaluation = {'type': 'cloze', 'result': result_text, 'detail': is_correct, 'selected': {'correct': [], 'incorrect': user_incorrect}, 'missing_correct': [cw for cw in correct_candidates]}
return evaluation, is_correct, input_str
async def submit_attempt(self, question_id: int, exercise_id: int, user_id: int, mode: str, selected_options: Optional[List[str]] = None, input_text: Optional[str] = None, cloze_options: Optional[List[str]] = None, session_id: Optional[int] = None, is_trial: bool = False) -> Dict[str, Any]:
    """Record and grade a user's answer to a QA question.

    Trial-mode choice/cloze answers are graded in memory and never persisted. Every other
    submission reuses the user's latest attempt for the question (or creates one), records
    it in the latest session's progress (non-trial only), and then:

    * free text: commits, queues an evaluation task and returns immediately with
      ``evaluation: None``; ``_process_attempt_evaluation`` fills in the result later.
    * choice / cloze / ``'variation'``: grades synchronously, marks the attempt completed,
      updates the session's correct-answer tally and returns the evaluation.

    Args:
        question_id: Question being answered.
        exercise_id: Exercise the question must belong to.
        user_id: Answering user.
        mode: One of the EXERCISE_TYPE_* constants or ``'variation'``.
        selected_options: Chosen options (choice) or the chosen file id at index 0 (variation).
        input_text: Free-text answer; also used as a single cloze answer fallback.
        cloze_options: Cloze selections; only the first one is stored on the attempt.
        session_id: Client-supplied session id stamped on the attempt's ``ext``.
        is_trial: Trial answers are kept out of session progress.

    Returns:
        ``{'session_id': str | None, 'type': <mode>, <mode>: {...}}``. For an unrecognised
        ``mode`` the attempt is committed as ``'pending'`` and ``None`` is returned.

    Raises:
        errors.NotFoundError: If the question does not exist or is not part of ``exercise_id``.
    """
    async with async_db_session.begin() as db:
        q = await qa_question_dao.get(db, question_id)
        if not q or q.exercise_id != exercise_id:
            raise errors.NotFoundError(msg='Question not found')

        # Trial choice/cloze answers are graded without touching the database.
        if is_trial:
            if mode == EXERCISE_TYPE_CHOICE:
                evaluation, _, selected_list = self._evaluate_choice(q, selected_options)
                return {
                    'session_id': None,
                    'type': 'choice',
                    'choice': {'options': selected_list, 'evaluation': evaluation},
                }
            if mode == EXERCISE_TYPE_CLOZE:
                c_opts = cloze_options
                if not c_opts and input_text:
                    c_opts = [input_text]
                evaluation, _, input_str = self._evaluate_cloze(q, c_opts)
                return {
                    'session_id': None,
                    'type': 'cloze',
                    'cloze': {'input': input_str, 'evaluation': evaluation},
                }

        attempt = await qa_attempt_dao.get_latest_by_user_question(db, user_id=user_id, question_id=question_id)
        if attempt:
            # Reuse the latest attempt and overwrite only the answer field for this mode.
            attempt.task_id = None
            if mode == EXERCISE_TYPE_CHOICE:
                attempt.choice_options = selected_options
            if mode == EXERCISE_TYPE_CLOZE:
                if isinstance(cloze_options, list) and cloze_options:
                    attempt.cloze_options = cloze_options[0]
                elif input_text:
                    attempt.cloze_options = input_text
            if mode == EXERCISE_TYPE_FREE_TEXT:
                attempt.input_text = input_text
            attempt.status = 'pending'
            # Copy before editing: mutating the loaded JSON dict in place and re-assigning
            # the same object is not seen as a change by SQLAlchemy.
            ext0 = dict(attempt.ext or {})
            if session_id:
                ext0['session_id'] = session_id
            if is_trial:
                ext0['is_trial'] = True
            else:
                ext0.pop('is_trial', None)
            attempt.ext = ext0
            await db.flush()
        else:
            new_ext: Dict[str, Any] = {}
            if session_id:
                new_ext['session_id'] = session_id
            if is_trial:
                new_ext['is_trial'] = True
            cloze_value = None
            if mode == EXERCISE_TYPE_CLOZE:
                if isinstance(cloze_options, list) and cloze_options:
                    cloze_value = cloze_options[0]
                else:
                    cloze_value = input_text or None
            attempt = await qa_attempt_dao.create(db, {
                'question_id': question_id,
                'exercise_id': exercise_id,
                'user_id': user_id,
                'task_id': None,
                'recording_id': None,
                'choice_options': selected_options if mode == EXERCISE_TYPE_CHOICE else None,
                'cloze_options': cloze_value,
                'input_text': input_text if mode == EXERCISE_TYPE_FREE_TEXT else None,
                'status': 'pending',
                'evaluation': None,
                'ext': new_ext or None,
            })
            # Ensure the primary key is assigned before it is written into session progress.
            await db.flush()

        if not is_trial:
            await self._track_attempt_in_session(db, user_id, exercise_id, question_id, mode, attempt)

        if mode == EXERCISE_TYPE_FREE_TEXT:
            text = attempt.input_text or ''
            attempt.ext = {**(attempt.ext or {}), 'type': 'free_text', 'free_text': {'text': text, 'evaluation': None}}
            await db.flush()
            task = await image_task_dao.create_task(db, CreateImageTaskParam(
                image_id=q.image_id,
                user_id=user_id,
                dict_level=DictLevel.LEVEL1.value,
                ref_type='qa_question_attempt',
                ref_id=attempt.id,
                status=ImageTaskStatus.PENDING,
            ))
            await db.flush()
            task_id = task.id
            session_id_val = (attempt.ext or {}).get('session_id')
            # Commit BEFORE scheduling: the evaluator runs in its own session and must be
            # able to read both the task and the attempt. Scheduling inside an uncommitted
            # transaction let it race the commit and silently find nothing.
            await db.commit()
            # NOTE(review): asyncio keeps only weak references to tasks; consider holding a
            # reference so the evaluation cannot be garbage-collected mid-run.
            asyncio.create_task(self._process_attempt_evaluation(task_id, user_id))
            return {
                'session_id': str(session_id_val) if session_id_val is not None else None,
                'type': 'free_text',
                'free_text': {'text': text, 'evaluation': None},
            }

        if mode == EXERCISE_TYPE_CHOICE:
            evaluation, is_correct, selected_list = self._evaluate_choice(q, attempt.choice_options)
            return await self._finalize_graded_attempt(
                db, attempt, 'choice', {'options': selected_list, 'evaluation': evaluation},
                is_correct, user_id, exercise_id, is_trial,
            )

        if mode == EXERCISE_TYPE_CLOZE:
            # Explicit selections win, then the free-text fallback, then the stored answer.
            if cloze_options:
                c_opts = cloze_options
            elif input_text:
                c_opts = [input_text]
            elif attempt.cloze_options:
                c_opts = [attempt.cloze_options]
            else:
                c_opts = []
            evaluation, is_correct, input_str = self._evaluate_cloze(q, c_opts)
            return await self._finalize_graded_attempt(
                db, attempt, 'cloze', {'input': input_str, 'evaluation': evaluation},
                is_correct, user_id, exercise_id, is_trial,
            )

        if mode == 'variation':
            correct_file_id = (q.ext or {}).get('file_id')
            user_file_id = selected_options[0] if selected_options else None
            is_correct = 'incorrect'
            if user_file_id is not None and correct_file_id is not None:
                try:
                    if int(user_file_id) == int(correct_file_id):
                        is_correct = 'correct'
                except (ValueError, TypeError):
                    # A malformed file id is a wrong answer, not a server error.
                    pass
            evaluation = {'type': 'variation', 'detail': is_correct, 'result': is_correct, 'correct_file_id': correct_file_id, 'user_file_id': user_file_id}
            return await self._finalize_graded_attempt(
                db, attempt, 'variation', {'file_id': user_file_id, 'evaluation': evaluation},
                is_correct, user_id, exercise_id, is_trial,
            )
    # NOTE(review): unrecognised mode; the attempt stays 'pending' and no result is produced.
    return None

async def _track_attempt_in_session(self, db: "AsyncSession", user_id: int, exercise_id: int, question_id: int, mode: str, attempt) -> None:
    """Upsert ``attempt`` into the progress of the user's latest session for the exercise.

    One entry is kept per (question, mode). A resubmission replaces that entry, keeping its
    previous ``is_correct`` so the tally can be adjusted when the new grade is known. Only a
    first answer increments ``answered``. The session id is also stamped on ``attempt.ext``.
    """
    s = await qa_session_dao.get_latest_by_user_exercise(db, user_id, exercise_id)
    if not s or s.exercise_id != exercise_id:
        return
    prog = dict(s.progress or {})
    attempts = list(prog.get('attempts') or [])
    entry = {
        'attempt_id': attempt.id,
        'question_id': str(question_id),
        'mode': mode,
        'created_at': datetime.now().isoformat(),
        'is_correct': None,
    }
    for idx, existing in enumerate(attempts):
        # Entries store question_id as a string. Comparing it with the int argument never
        # matched, so every resubmission appended a duplicate entry.
        if str(existing.get('question_id')) == str(question_id) and existing.get('mode') == mode:
            entry['is_correct'] = existing.get('is_correct')
            attempts[idx] = entry
            break
    else:
        attempts.append(entry)
        prog['answered'] = int(prog.get('answered') or 0) + 1
    prog['attempts'] = attempts
    s.progress = prog
    attempt.ext = {**(attempt.ext or {}), 'session_id': s.id}
    await db.flush()

async def _apply_attempt_correctness(self, db: "AsyncSession", user_id: int, exercise_id: int, attempt, is_correct: Optional[str]) -> None:
    """Store ``is_correct`` on the attempt's session progress entry and adjust ``correct``."""
    s = await qa_session_dao.get_latest_by_user_exercise(db, user_id, exercise_id)
    if not s or s.exercise_id != attempt.exercise_id:
        return
    prog = dict(s.progress or {})
    attempts = list(prog.get('attempts') or [])
    prev = None
    for idx, existing in enumerate(attempts):
        if existing.get('attempt_id') == attempt.id:
            prev = existing.get('is_correct')
            # Replace instead of mutating: inner dicts are shared with the previously
            # flushed progress value, so an in-place edit can compare equal and be dropped.
            attempts[idx] = {**existing, 'is_correct': is_correct}
            break
    prev_correct = 1 if prev == 'correct' else 0
    new_correct = 1 if is_correct == 'correct' else 0
    prog['attempts'] = attempts
    prog['correct'] = int(prog.get('correct') or 0) + new_correct - prev_correct
    s.progress = prog
    await db.flush()

async def _finalize_graded_attempt(self, db: "AsyncSession", attempt, kind: str, payload: Dict[str, Any], is_correct: str, user_id: int, exercise_id: int, is_trial: bool) -> Dict[str, Any]:
    """Persist a synchronously graded attempt, update session progress, commit, and build the response."""
    attempt.ext = {**(attempt.ext or {}), 'type': kind, kind: dict(payload)}
    await db.flush()
    merged_eval = dict(attempt.evaluation or {})
    merged_eval[kind] = dict(payload)
    await qa_attempt_dao.update_status(db, attempt.id, 'completed', merged_eval)
    if not is_trial:
        await self._apply_attempt_correctness(db, user_id, exercise_id, attempt, is_correct)
    session_id_val = (attempt.ext or {}).get('session_id')
    await db.commit()
    return {
        'session_id': str(session_id_val) if session_id_val is not None else None,
        'type': kind,
        kind: payload,
    }
async def _process_attempt_evaluation(self, task_id: int, user_id: int):
    """Grade a free-text attempt with the LLM; scheduled as a background task.

    Steps:
        1. Load the task; ``task.ref_id`` is the attempt id. A missing task is ignored.
        2. Mark the task PROCESSING.
        3. Ask the LLM to compare the attempt's text with the question's ``ext['answers']``.
        4. Store the evaluation on the attempt and mark the task COMPLETED.
        5. For non-trial attempts, record the verdict in the user's latest session progress.

    Any failure, including unexpected exceptions, is recorded on the task as FAILED.

    Args:
        task_id: Image task created by ``submit_attempt`` (ref_type ``'qa_question_attempt'``).
        user_id: Owner of the attempt; used for the LLM call and the session lookup.
    """
    def parse_llm_json(raw) -> Dict[str, Any]:
        """Best-effort parse of the LLM reply into a dict; anything unusable becomes {}."""
        if not isinstance(raw, str):
            return raw if isinstance(raw, dict) else {}
        text = raw.strip()
        # Models often wrap JSON in a Markdown code fence such as ```json ... ```.
        if text.startswith('```'):
            text = text[3:]
            if text[:4].lower() == 'json':
                text = text[4:]
            text = text.rstrip()
            if text.endswith('```'):
                text = text[:-3]
        try:
            value = json.loads(text)
        except ValueError:
            return {}
        return value if isinstance(value, dict) else {}

    async with background_db_session() as db:
        try:
            task = await image_task_dao.get(db, task_id)
            if not task:
                return
            await image_task_dao.update_task_status(db, task_id, ImageTaskStatus.PROCESSING)
            attempt = await qa_attempt_dao.get(db, task.ref_id)
            if not attempt:
                await image_task_dao.update_task_status(db, task_id, ImageTaskStatus.FAILED, error_message='Attempt not found')
                await db.commit()
                return
            q = await qa_question_dao.get(db, attempt.question_id)
            if not q:
                await image_task_dao.update_task_status(db, task_id, ImageTaskStatus.FAILED, error_message='Question not found')
                await db.commit()
                return
            is_trial = (attempt.ext or {}).get('is_trial', False)
            user_text = attempt.input_text or ''
            answers = (q.ext or {}).get('answers') or {}
            prompt = (
                '根据给定标准答案判断用户回答是否正确输出JSON{is_correct: correct|partial|incorrect, feedback: string}。'
                '标准答案:' + json.dumps(answers, ensure_ascii=False) +
                '用户回答:' + user_text
            )
            res = await self._call_llm_chat(prompt=prompt, image_id=q.image_id, user_id=user_id, chat_type='qa_attempt')
            if not res.get('success'):
                await image_task_dao.update_task_status(db, task_id, ImageTaskStatus.FAILED, error_message=res.get('error'))
                await db.commit()
                return

            parsed = parse_llm_json(res.get('result'))
            verdict = parsed.get('is_correct')
            evaluation = {'type': 'free_text', 'result': verdict, 'feedback': parsed.get('feedback')}
            free_text = {'text': attempt.input_text or '', 'evaluation': evaluation}
            attempt.ext = {**(attempt.ext or {}), 'type': 'free_text', 'free_text': free_text}
            await db.flush()
            merged_eval = dict(attempt.evaluation or {})
            merged_eval['free_text'] = dict(free_text)
            await qa_attempt_dao.update_status(db, attempt.id, 'completed', merged_eval)
            await image_task_dao.update_task_status(db, task_id, ImageTaskStatus.COMPLETED, result={'mode': 'free_text', 'token_usage': res.get('token_usage') or {}})

            if not is_trial:
                s = await qa_session_dao.get_latest_by_user_exercise(db, user_id, attempt.exercise_id)
                if s and s.exercise_id == attempt.exercise_id:
                    prog = dict(s.progress or {})
                    attempts = list(prog.get('attempts') or [])
                    prev = None
                    for idx, entry in enumerate(attempts):
                        if entry.get('attempt_id') == attempt.id:
                            prev = entry.get('is_correct')
                            # New dict rather than in-place edit, so the JSON change is detected.
                            attempts[idx] = {**entry, 'is_correct': verdict}
                            break
                    prev_correct = 1 if prev == 'correct' else 0
                    new_correct = 1 if verdict == 'correct' else 0
                    prog['attempts'] = attempts
                    prog['correct'] = int(prog.get('correct') or 0) + new_correct - prev_correct
                    s.progress = prog
                    await db.flush()
            await db.commit()
        except Exception as exc:
            # This coroutine is fire-and-forget. Without this handler the error would only
            # surface as "Task exception was never retrieved", and the task would stay PROCESSING.
            await db.rollback()
            await image_task_dao.update_task_status(db, task_id, ImageTaskStatus.FAILED, error_message=str(exc) or type(exc).__name__)
            await db.commit()
async def _call_llm_chat(self, prompt: str, image_id: int, user_id: int, chat_type: str) -> Dict[str, Any]:
    """Send a single-turn prompt to the configured LLM provider.

    ``settings.LLM_MODEL_TYPE == 'qwen'`` (case-insensitive) routes to Qwen with only the
    user message. Anything else routes to Hunyuan, which also receives a generic system message.

    Returns:
        On success, ``{'success': True, 'result': ..., 'token_usage': {...}}`` for Qwen, or the
        Hunyuan response unchanged.

        On failure, ``{'success': False, 'error': str}``. The error is the exception text if
        the provider call raised, else the provider's own ``error`` value if it returned one,
        else ``"LLM call failed"``.
    """
    model_type = (settings.LLM_MODEL_TYPE or "").lower()
    try:
        if model_type == 'qwen':
            response = await Qwen.chat(messages=[{'role': 'user', 'content': prompt}], image_id=image_id, user_id=user_id, api_type=chat_type)
            if response and response.get('success'):
                return {"success": True, "result": response.get("result"), "token_usage": response.get("token_usage") or {}}
        else:
            messages = [{"role": "system", "content": "You are a helpful assistant."}, {'role': 'user', 'content': prompt}]
            response = await Hunyuan.chat(messages=messages, image_id=image_id, user_id=user_id, system_prompt=None, chat_type=chat_type)
            if response and response.get('success'):
                return response
    except Exception as e:
        return {"success": False, "error": str(e)}
    # Keep the provider's error detail (previously discarded) so it ends up on the task record.
    provider_error = response.get('error') if response else None
    return {"success": False, "error": provider_error or "LLM call failed"}
async def get_attempt_task_status(self, task_id: int) -> Dict[str, Any]:
    """Return the status payload of an attempt-evaluation task.

    Attempt evaluations are stored as ordinary image tasks (ref_type
    ``'qa_question_attempt'``), so this simply delegates to ``get_task_status``.
    """
    status_payload = await self.get_task_status(task_id)
    return status_payload
async def get_question_evaluation(self, question_id: int, user_id: int) -> Dict[str, Any]:
    """Return the user's latest stored evaluation for a question.

    The lookup prefers the latest completed attempt and falls back to
    ``qa_attempt_dao.get_latest_valid_by_user_question``. Trial attempts are excluded from
    both lookups.

    Returns:
        ``{}`` when no attempt is found. Otherwise a dict with ``session_id`` (str or None),
        ``type``, and one key per graded mode present on the attempt: ``choice``, ``cloze``,
        ``free_text`` and/or ``variation``.
    """
    async with async_db_session() as db:
        # Exclude trial attempts so they don't pollute normal-mode history.
        latest = await qa_attempt_dao.get_latest_completed_by_user_question(db, user_id=user_id, question_id=question_id, exclude_trial=True)
        if not latest:
            latest = await qa_attempt_dao.get_latest_valid_by_user_question(db, user_id=user_id, question_id=question_id, exclude_trial=True)
        if not latest:
            return {}
        evaluation = latest.evaluation or {}
        attempt_ext = latest.ext or {}
        # submit_attempt stores session_id and type on attempt.ext, not in the evaluation
        # payload it passes to update_status. Fall back to ext when evaluation lacks them.
        session_id = evaluation.get('session_id')
        if session_id is None:
            session_id = attempt_ext.get('session_id')
        ret: Dict[str, Any] = {
            'session_id': str(session_id) if session_id is not None else None,
            'type': evaluation.get('type') or attempt_ext.get('type'),
        }
        if 'choice' in evaluation:
            ch = evaluation.get('choice') or {}
            ret['choice'] = {
                'options': ch.get('options') or [],
                'evaluation': ch.get('evaluation') or None,
            }
        if 'cloze' in evaluation:
            cz = evaluation.get('cloze') or {}
            ret['cloze'] = {
                'input': cz.get('input') or '',
                'evaluation': cz.get('evaluation') or None,
            }
        if 'free_text' in evaluation:
            ft = evaluation.get('free_text') or {}
            ret['free_text'] = {
                'text': ft.get('text') or '',
                'evaluation': ft.get('evaluation') or None,
            }
        if 'variation' in evaluation:
            va = evaluation.get('variation') or {}
            ret['variation'] = {
                'file_id': va.get('file_id'),
                'evaluation': va.get('evaluation') or None,
            }
        return ret
async def persist_image_from_url(self, image_url: str, user_id: int, filename: str = "generated_variation.png") -> int:
    """Download an image from a URL and store it as a COS-backed system file.

    Steps:
        1. Download the bytes.
        2. Create a placeholder file record (size 0, no storage path) to obtain an id.
        3. Upload the bytes to COS under the key ``"{file_id}_{filename}"``.
        4. Fill in the record's storage path, size and details.

    Args:
        image_url: URL to download.
        user_id: Recorded in the file's ``details`` as the owner of the generated image.
        filename: Name used for the record and the COS key suffix.

    Returns:
        The id of the stored file record.

    Raises:
        Exception: If the download does not return HTTP 200. Database and COS errors
            propagate unchanged. NOTE(review): a failed upload leaves the placeholder
            record behind.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(image_url) as response:
            if response.status != 200:
                raise Exception(f"Failed to download image: {response.status}")
            content = await response.read()
    file_hash = hashlib.sha256(content).hexdigest()
    content_type = "image/png"  # Default to png, matching the default filename.

    # 1. Create the DB record first (pending state) to obtain the file id used in the key.
    async with async_db_session.begin() as db:
        meta_init = FileMetadata(
            file_name=filename,
            content_type=content_type,
            file_size=0,
            extra=None,
        )
        t_params = AddFileParam(
            file_hash=file_hash,
            file_name=filename,
            content_type=content_type,
            file_size=0,
            storage_type="cos",
            storage_path=None,
            metadata_info=meta_init,
        )
        t_file = await file_dao.create(db, t_params)
        await db.flush()
        # Capture the id for use outside the transaction.
        file_id = t_file.id

    # 2. Upload to COS. The image is downloaded first because PutObject needs a body;
    # COS's URL fetch is asynchronous and does not fit this synchronous flow.
    cos_client = CosClient()
    key = f"{file_id}_{filename}"
    # upload_object is called without await, i.e. it is a blocking call. Run it in a worker
    # thread so the event loop keeps serving other requests during the upload.
    await asyncio.to_thread(cos_client.upload_object, key, content)

    # 3. Update the DB record (completed state).
    async with async_db_session.begin() as db:
        meta = FileMetadata(
            file_name=filename,
            content_type=content_type,
            file_size=len(content),
            extra=None,
        )
        update_params = UpdateFileParam(
            file_hash=file_hash,
            storage_path=key,
            metadata_info=meta,
            details={
                "key": key,
                "source": "ai_generation",
                "user_id": user_id
            }
        )
        await file_dao.update(db, file_id, update_params)
    return int(file_id)
async def generate_scene_variations(self, exercise_id: int, user_id: int, db: AsyncSession = None) -> Tuple[int, Dict[str, Any]]:
    """Generate scene-variation questions for an exercise.

    Workflow:
        1. Generate variation descriptions from the source image's recognition result.
        2. Generate one illustration per variation.
        3. Persist each generated illustration as a system file.
        4. Create one question per successfully persisted variation, store all variations in
           ``exercise.ext['new_descriptions']``, and publish the exercise if any question was
           created. Also open a session for the user if none exists.

    Args:
        exercise_id: Exercise to populate.
        user_id: User who owns the generated content and the session.
        db: Optional session assumed to be inside a transaction (e.g. from the task runner).
            When omitted, each database phase runs in its own ``async_db_session.begin()``.

    Returns:
        ``(number_of_variations, token_usage)``. The count includes variations whose image
        generation or persistence failed.

    Raises:
        errors.NotFoundError: If the exercise or its image does not exist.
        Exception: If variation generation fails or yields no variations.
    """
    from contextlib import asynccontextmanager

    @asynccontextmanager
    async def get_db():
        # Reuse the caller's session, or open a dedicated transaction for this phase.
        if db:
            yield db
        else:
            async with async_db_session.begin() as new_db:
                yield new_db

    async with get_db() as session:
        exercise = await qa_exercise_dao.get(session, exercise_id)
        if not exercise:
            raise errors.NotFoundError(msg='Exercise not found')
        image = await image_dao.get(session, exercise.image_id)
        if not image:
            raise errors.NotFoundError(msg='Image not found')
        # Prepare the generator payload from the image's recognition result.
        rr = (image.details or {}).get('recognition_result') or {}
        payload = {
            'description': rr.get('description'),
            'core_vocab': rr.get('core_vocab'),
            'collocations': rr.get('collocations'),
            'scene_tag': rr.get('scene_tag')
        }

    # Step 1: generate variation texts. This runs outside the phase transaction when we own
    # it. When `db` was passed in, the caller's transaction stays open meanwhile, but no SQL
    # is issued here.
    gen_res = await SceneVariationGenerator.generate(payload, image.id, user_id)
    if not gen_res.get('success'):
        raise Exception(f"Variation generation failed: {gen_res.get('error')}")
    result = gen_res.get('result')
    # A null or non-dict result used to raise AttributeError here instead of the
    # explicit "No variations generated" error below.
    if not isinstance(result, dict):
        result = {}
    variations = result.get('new_descriptions') or []
    token_usage = gen_res.get('token_usage') or {}
    if not variations:
        raise Exception("No variations generated")

    # Step 2: generate images.
    variations_with_images = await Illustrator.process_variations(image.file_id, user_id, variations)

    # Step 3: persist images. Failures are recorded per variation rather than aborting.
    for i, v in enumerate(variations_with_images):
        if v.get('success') and v.get('generated_image_url'):
            try:
                # Filename: exercise_{exercise_id}_variation_{image_id}.png
                img_id = v.get('image_id', i + 1)
                filename = f"exercise_{exercise_id}_variation_{img_id}.png"
                file_id = await self.persist_image_from_url(v['generated_image_url'], user_id, filename=filename)
                v['file_id'] = file_id
            except Exception as e:
                v['persist_error'] = str(e)

    # Step 4: create questions and update the exercise.
    async with get_db() as session:
        exercise = await qa_exercise_dao.get(session, exercise_id)
        if not exercise:
            # Should not happen given the earlier check, but kept for safety.
            raise errors.NotFoundError(msg='Exercise not found')
        created = 0
        for v in variations_with_images:
            if v.get('success') and v.get('file_id'):
                await qa_question_dao.create(session, {
                    'exercise_id': exercise.id,
                    'image_id': exercise.image_id,
                    'question': v.get('desc_en') or '',
                    'user_id': user_id,
                    'ext': {
                        'file_id': str(v.get('file_id')),
                        'desc_zh': v.get('desc_zh'),
                        'modification_type': v.get('modification_type'),
                        'modification_point': v.get('modification_point'),
                        'core_vocab': v.get('core_vocab'),
                        'collocation': v.get('collocation'),
                        'learning_note': v.get('learning_note'),
                    },
                })
                created += 1
        ext = dict(exercise.ext or {})
        ext['new_descriptions'] = variations_with_images
        exercise.ext = ext
        from sqlalchemy.orm.attributes import flag_modified
        flag_modified(exercise, "ext")
        exercise.question_count = created
        exercise.status = 'published' if created > 0 else 'draft'
        await session.flush()
        if created > 0:
            existing_session = await qa_session_dao.get_latest_by_user_exercise(session, user_id, exercise.id)
            if not existing_session:
                prog = {'current_index': 0, 'answered': 0, 'correct': 0, 'attempts': [], 'total_questions': created}
                await qa_session_dao.create(session, {
                    'exercise_id': exercise.id,
                    'starter_user_id': user_id,
                    'share_id': None,
                    'status': 'ongoing',
                    'started_at': datetime.now(),
                    'completed_at': None,
                    'progress': prog,
                    'score': None,
                    'ext': None,
                })
                await session.flush()
    return len(variations_with_images), token_usage
# Module-level QaService instance. Presumably other modules import and share this object
# rather than constructing QaService themselves; confirm against callers.
qa_service = QaService()