1773 lines
78 KiB
Python
1773 lines
78 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
import asyncio
|
||
import json
|
||
import math
|
||
import aiohttp
|
||
import io
|
||
import hashlib
|
||
from fastapi import UploadFile
|
||
from backend.app.admin.service.file_service import file_service
|
||
from backend.app.admin.schema.file import AddFileParam, FileMetadata, UpdateFileParam
|
||
from backend.app.admin.crud.file_crud import file_dao
|
||
from backend.middleware.cos_client import CosClient
|
||
from typing import Optional, List, Dict, Any, Tuple
|
||
from datetime import datetime
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from backend.database.db import async_db_session, background_db_session
|
||
from backend.app.ai.crud.qa_crud import qa_exercise_dao, qa_question_dao, qa_attempt_dao, qa_session_dao
|
||
from backend.app.ai.crud.image_task_crud import image_task_dao
|
||
from backend.app.ai.crud.image_curd import image_dao
|
||
from backend.app.ai.model.image_task import ImageTaskStatus
|
||
from backend.app.ai.schema.image_task import CreateImageTaskParam
|
||
from backend.app.admin.service.points_service import points_service
|
||
from backend.app.ai.service.rate_limit_service import rate_limit_service
|
||
from backend.common.exception import errors
|
||
from backend.core.llm import LLMFactory, AuditLogCallbackHandler
|
||
from langchain_core.messages import SystemMessage, HumanMessage
|
||
from backend.core.conf import settings
|
||
from backend.app.ai.service.recording_service import recording_service
|
||
from backend.common.const import EXERCISE_TYPE_CHOICE, EXERCISE_TYPE_CLOZE, EXERCISE_TYPE_FREE_TEXT, LLM_CHAT_COST, POINTS_ACTION_SPEND, IMAGE_GENERATION_COST
|
||
from backend.app.admin.schema.wx import DictLevel
|
||
from backend.app.ai.service.image_task_service import TaskProcessor, image_task_service
|
||
from backend.app.ai.model.image_task import ImageProcessingTask
|
||
from backend.app.ai.model.qa import QaQuestion, QaPracticeSession
|
||
|
||
from backend.core.prompts.qa_exercise import get_qa_exercise_prompt
|
||
from backend.core.prompts.recognition import get_conversation_prompt_for_image_dialogue
|
||
from backend.core.prompts.free_conversation import get_free_conversation_start_prompt, get_free_conversation_reply_prompt
|
||
from backend.app.ai.tools.qa_tool import SceneVariationGenerator, Illustrator
|
||
|
||
class QaExerciseProcessor(TaskProcessor):
    """Task processor that generates QA questions for an image via the LLM.

    Reads the image's stored recognition result, prompts the LLM for a list
    of questions, persists them, publishes the exercise when at least one
    question was created, and pre-creates a practice session for the task
    owner if none exists yet.
    """

    async def process(self, db: AsyncSession, task: ImageProcessingTask) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Generate questions for the exercise in ``task.ref_id``.

        :param db: active database session (flushed, not committed, here)
        :param task: processing task; ``ref_id`` is the exercise id
        :return: ``(result, token_usage)`` — points deduction and the final
            task status are handled by ``image_task_service``, not here.
        """
        image = await image_dao.get(db, task.image_id)
        exercise = await qa_exercise_dao.get(db, task.ref_id)

        # Extract a plain-text description from the recognition result; the
        # stored value may be a string or a list of strings.
        rr = (image.details or {}).get('recognition_result') or {}
        description = ''
        try:
            d = rr.get('description')
            if isinstance(d, str):
                description = d
            elif isinstance(d, list) and d:
                description = d[0] if isinstance(d[0], str) else ''
        except Exception:
            description = ''

        payload = {'description': description}
        prompt = get_qa_exercise_prompt(payload)
        res = await self._call_llm_chat(prompt=prompt, image_id=image.id, user_id=task.user_id, chat_type='qa_exercise')
        if not res.get('success'):
            raise Exception(res.get('error') or "LLM call failed")

        token_usage = res.get('token_usage') or {}

        # The LLM may return a JSON object carrying 'qa_list' or a bare list;
        # anything unparseable degrades to an empty item list.
        items = []
        try:
            parsed = json.loads(res.get('result')) if isinstance(res.get('result'), str) else res.get('result')
            if isinstance(parsed, dict):
                items = parsed.get('qa_list') or []
            elif isinstance(parsed, list):
                items = parsed
        except Exception:
            items = []

        created = 0
        for it in items:
            # Persist each generated question. (Fix: the created row was
            # previously bound to an unused local.)
            await qa_question_dao.create(db, {
                'exercise_id': exercise.id,
                'image_id': image.id,
                'question': it.get('question') or '',
                'payload': None,
                'user_id': task.user_id,
                'ext': {
                    'dimension': it.get('dimension'),
                    'key_pronunciation_words': it.get('key_pronunciation_words'),
                    'answers': it.get('answers'),
                    'cloze': it.get('cloze'),
                    'correct_options': it.get('correct_options'),
                    'incorrect_options': it.get('incorrect_options'),
                },
            })
            created += 1

        exercise.question_count = created
        exercise.status = 'published' if created > 0 else 'draft'
        await db.flush()

        # Pre-create a practice session for the task owner if none exists yet.
        if created > 0:
            existing_session = await qa_session_dao.get_latest_by_user_exercise(db, task.user_id, exercise.id)
            if not existing_session:
                prog = {'current_index': 0, 'answered': 0, 'correct': 0, 'attempts': [], 'total_questions': created}
                await qa_session_dao.create(db, {
                    'exercise_id': exercise.id,
                    'starter_user_id': task.user_id,
                    'share_id': None,
                    'status': 'ongoing',
                    'started_at': datetime.now(),
                    'completed_at': None,
                    'progress': prog,
                    'score': None,
                    'ext': None,
                })
                await db.flush()

        # Return result and token_usage.
        # Note: image_task_service handles points deduction and final status update.
        result = {'token_usage': token_usage, 'count': created}
        return result, token_usage

    async def _call_llm_chat(self, prompt: str, image_id: int, user_id: int, chat_type: str) -> Dict[str, Any]:
        """Invoke the configured LLM with an audit-log callback.

        :return: ``{'success': True, 'result': str, 'token_usage': dict}`` on
            success, ``{'success': False, 'error': str}`` on any failure.
        """
        messages = [
            SystemMessage(content="You are a helpful assistant."),
            HumanMessage(content=prompt)
        ]

        # Metadata is consumed by the audit-log callback handler.
        metadata = {
            "image_id": image_id,
            "user_id": user_id,
            "api_type": chat_type,
            "model_name": settings.LLM_MODEL_TYPE
        }

        try:
            llm = LLMFactory.create_llm(settings.LLM_MODEL_TYPE)
            res = await llm.ainvoke(
                messages,
                config={"callbacks": [AuditLogCallbackHandler(metadata=metadata)]}
            )

            content = res.content
            if not isinstance(content, str):
                content = str(content)

            # Provider responses vary: token counts may sit under either key.
            token_usage = {}
            if res.response_metadata:
                token_usage = res.response_metadata.get("token_usage") or res.response_metadata.get("usage") or {}

            return {
                "success": True,
                "result": content,
                "token_usage": token_usage
            }
        except Exception as e:
            return {"success": False, "error": str(e)}
|
||
|
||
class SceneVariationProcessor(TaskProcessor):
    """Task processor that delegates scene-variation generation to qa_service."""

    async def process(self, db: AsyncSession, task: ImageProcessingTask) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Generate scene variations for the exercise in ``task.ref_id``.

        Adds the per-image generation cost to ``token_usage`` as extra points
        so downstream billing can account for the produced images.
        """
        generated_count, token_usage = await qa_service.generate_scene_variations(task.ref_id, task.user_id, db=db)

        # Calculate extra points for generated images
        token_usage['extra_points'] = generated_count * IMAGE_GENERATION_COST
        token_usage['extra_details'] = {
            'image_count': generated_count,
            'image_unit_price': IMAGE_GENERATION_COST,
            'source': 'scene_variation_generation'
        }

        result = {'count': generated_count, 'token_usage': token_usage}
        return result, token_usage
|
||
|
||
|
||
class ConversationInitProcessor(TaskProcessor):
    """Task processor that asks the LLM to analyze an image for conversation.

    Stores the resulting ``image_analysis`` dict on the image under
    ``details['conversation_analysis']``.
    """

    async def process(self, db: AsyncSession, task: ImageProcessingTask) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Run the conversation analysis for ``task.image_id`` and persist it.

        :return: ``(result, token_usage)``
        :raises NotFoundError: when the image does not exist
        :raises Exception: on LLM failure or a malformed LLM response
        """
        image = await image_dao.get(db, task.image_id)
        if not image:
            raise errors.NotFoundError(msg="Image not found")
        details = dict(image.details or {})
        rr = (details.get("recognition_result") or {}) if details else {}
        # Extract a plain-text description; the stored value may be a string
        # or a list of strings.
        description = ""
        scene_tags: List[str] = []
        try:
            d = rr.get("description")
            if isinstance(d, str):
                description = d
            elif isinstance(d, list) and d:
                description = d[0] if isinstance(d[0], str) else ""
        except Exception:
            description = ""
        # Normalize scene tags to a list of strings (accepts list or scalar).
        try:
            tags = rr.get("scene_tag")
            if isinstance(tags, list):
                scene_tags = [str(t) for t in tags]
            elif isinstance(tags, str):
                scene_tags = [tags]
        except Exception:
            scene_tags = []
        payload = {
            "description": description,
            "scene_tags": scene_tags,
        }
        prompt = get_conversation_prompt_for_image_dialogue(payload)
        res = await self._call_llm_chat(prompt=prompt, image_id=image.id, user_id=task.user_id, chat_type="image_conversation_analysis")
        if not res.get("success"):
            raise Exception(res.get("error") or "LLM call failed")
        token_usage = res.get("token_usage") or {}
        # The LLM result may arrive as a JSON string or an already-parsed object.
        try:
            parsed = json.loads(res.get("result")) if isinstance(res.get("result"), str) else res.get("result")
        except Exception:
            parsed = {}
        image_analysis = parsed.get("image_analysis") if isinstance(parsed, dict) else None
        if not isinstance(image_analysis, dict):
            raise Exception("Invalid image_analysis structure")
        # Replace `details` wholesale with a new dict so the change is a new
        # object assignment rather than in-place mutation.
        new_details = dict(details)
        new_details["conversation_analysis"] = {
            "image_analysis": image_analysis,
        }
        image.details = new_details
        try:
            from sqlalchemy.orm.attributes import flag_modified

            # Explicitly mark the JSON column dirty so the ORM persists it.
            flag_modified(image, "details")
        except Exception:
            pass
        await db.flush()
        result = {"image_analysis": image_analysis, "token_usage": token_usage}
        return result, token_usage


    async def _call_llm_chat(self, prompt: str, image_id: int, user_id: int, chat_type: str) -> Dict[str, Any]:
        """Invoke the configured LLM with an audit-log callback.

        Duplicated across the processor classes in this module. Returns
        ``{'success': True, 'result': str, 'token_usage': dict}`` on success,
        ``{'success': False, 'error': str}`` on any failure.
        """
        messages = [
            SystemMessage(content="You are a helpful assistant."),
            HumanMessage(content=prompt)
        ]

        # Metadata is consumed by the audit-log callback handler.
        metadata = {
            "image_id": image_id,
            "user_id": user_id,
            "api_type": chat_type,
            "model_name": settings.LLM_MODEL_TYPE
        }

        try:
            llm = LLMFactory.create_llm(settings.LLM_MODEL_TYPE)
            res = await llm.ainvoke(
                messages,
                config={"callbacks": [AuditLogCallbackHandler(metadata=metadata)]}
            )

            content = res.content
            if not isinstance(content, str):
                content = str(content)

            # Provider responses vary: token counts may sit under either key.
            token_usage = {}
            if res.response_metadata:
                token_usage = res.response_metadata.get("token_usage") or res.response_metadata.get("usage") or {}

            return {
                "success": True,
                "result": content,
                "token_usage": token_usage
            }
        except Exception as e:
            return {"success": False, "error": str(e)}
|
||
|
||
class ConversationStartProcessor(TaskProcessor):
    """Task processor that generates the opening AI message of a free conversation."""

    async def process(self, db: AsyncSession, task: ImageProcessingTask) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Produce the first AI message for the exercise in ``task.ref_id``.

        Publishes the exercise, re-activates (or creates) the session, and
        stores the opening message as the first question.

        :return: ``(result, token_usage)``
        :raises NotFoundError: when the exercise or its image is missing
        :raises Exception: on LLM failure or malformed LLM response
        """
        # task.ref_id is exercise_id
        exercise_id = task.ref_id
        exercise = await qa_exercise_dao.get(db, exercise_id)
        if not exercise:
            raise errors.NotFoundError(msg="Exercise not found")

        image = await image_dao.get(db, exercise.image_id)
        if not image:
            raise errors.NotFoundError(msg="Image not found")

        # Parse recognition result for description (string or list of strings).
        rr = (image.details or {}).get('recognition_result') or {}
        description = ''
        try:
            d = rr.get('description')
            if isinstance(d, str):
                description = d
            elif isinstance(d, list) and d:
                description = d[0] if isinstance(d[0], str) else ''
        except Exception:
            description = ''

        # Conversation configuration saved on the exercise at creation time.
        params = exercise.ext or {}

        # Helper to extract 'en' from bilingual items
        def get_en_list(items):
            if not items:
                return []
            return [item.get('en') for item in items if isinstance(item, dict) and item.get('en')]

        def get_en_item(item):
            if not item or not isinstance(item, dict):
                return None
            return item.get('en')

        prompt = get_free_conversation_start_prompt(
            scene=get_en_list(params.get('scene')),
            event=get_en_list(params.get('event')),
            user_role=get_en_item(params.get('user_role')),
            assistant_role=get_en_item(params.get('assistant_role')),
            style=get_en_item(params.get('style')),
            level=params.get('level'),
            info=params.get('info'),
            description=description,
        )

        res = await self._call_llm_chat(prompt=prompt, image_id=image.id, user_id=task.user_id, chat_type='conversation_start')
        if not res.get('success'):
            raise Exception(res.get('error') or "LLM call failed")

        token_usage = res.get('token_usage') or {}
        # The LLM result may arrive as a JSON string or an already-parsed object.
        try:
            parsed = json.loads(res.get('result')) if isinstance(res.get('result'), str) else res.get('result')
        except Exception:
            parsed = {}

        if not parsed or not isinstance(parsed, dict):
            raise Exception("Invalid LLM response format")

        # Update Exercise Status
        exercise.status = 'published'
        exercise.question_count = 1

        # Get or Create Session. A session is normally pre-created (status
        # 'initializing') by QaService.start_conversation.
        session = await qa_session_dao.get_latest_by_user_exercise(db, task.user_id, exercise.id)
        if session:
            session.status = 'ongoing'
            prog = dict(session.progress or {})
            prog['total_questions'] = 1
            session.progress = prog
        else:
            prog = {'current_index': 0, 'answered': 0, 'correct': 0, 'attempts': [], 'total_questions': 1}
            session = await qa_session_dao.create(db, {
                'exercise_id': exercise.id,
                'starter_user_id': task.user_id,
                'status': 'ongoing',
                'started_at': datetime.now(),
                'progress': prog,
                'ext': None,
            })

        # Create First Question (AI Message)
        question_content = parsed.get('response_en') or ''
        question_ext = {
            'role': 'assistant',
            'response_zh': parsed.get('response_zh'),
            'prompt_en': parsed.get('prompt_en'),
            'prompt_zh': parsed.get('prompt_zh'),
            'alternative_responses': parsed.get('alternative_responses'),
            'correction': parsed.get('correction'),
        }

        await qa_question_dao.create(db, {
            'exercise_id': exercise.id,
            'image_id': image.id,
            'question': question_content,
            'user_id': task.user_id,
            'payload': None,
            'ext': question_ext,
        })

        await db.flush()

        result = {
            'exercise_id': str(exercise.id),
            'session_id': str(session.id),
            'token_usage': token_usage
        }
        return result, token_usage

    async def _call_llm_chat(self, prompt: str, image_id: int, user_id: int, chat_type: str) -> Dict[str, Any]:
        """Invoke the configured LLM with an audit-log callback.

        Duplicated across the processor classes in this module. Returns
        ``{'success': True, 'result': str, 'token_usage': dict}`` on success,
        ``{'success': False, 'error': str}`` on any failure.
        """
        messages = [
            SystemMessage(content="You are a helpful assistant."),
            HumanMessage(content=prompt)
        ]
        # Metadata is consumed by the audit-log callback handler.
        metadata = {
            "image_id": image_id,
            "user_id": user_id,
            "api_type": chat_type,
            "model_name": settings.LLM_MODEL_TYPE
        }
        try:
            llm = LLMFactory.create_llm(settings.LLM_MODEL_TYPE)
            res = await llm.ainvoke(
                messages,
                config={"callbacks": [AuditLogCallbackHandler(metadata=metadata)]}
            )
            content = res.content
            if not isinstance(content, str):
                content = str(content)
            # Provider responses vary: token counts may sit under either key.
            token_usage = {}
            if res.response_metadata:
                token_usage = res.response_metadata.get("token_usage") or res.response_metadata.get("usage") or {}
            return {"success": True, "result": content, "token_usage": token_usage}
        except Exception as e:
            return {"success": False, "error": str(e)}
|
||
|
||
|
||
class ConversationReplyProcessor(TaskProcessor):
    """Task processor that generates the AI reply to a user's conversation attempt."""

    async def process(self, db: AsyncSession, task: ImageProcessingTask) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Generate the next AI message for the attempt in ``task.ref_id``.

        Builds the prompt from the exercise config, the image description and
        the full dialogue history, stores any correction on the attempt,
        creates the new AI question, and bumps the session progress.

        :return: ``(result, token_usage)``
        :raises NotFoundError: when attempt, exercise or image is missing
        :raises Exception: on LLM failure or malformed LLM response
        """
        # task.ref_id is attempt_id
        attempt_id = task.ref_id
        attempt = await qa_attempt_dao.get(db, attempt_id)
        if not attempt:
            raise errors.NotFoundError(msg="Attempt not found")

        exercise = await qa_exercise_dao.get(db, attempt.exercise_id)
        if not exercise:
            raise errors.NotFoundError(msg="Exercise not found")

        image = await image_dao.get(db, exercise.image_id)
        # Fix: a missing image previously crashed with AttributeError on
        # image.details below; fail explicitly like the sibling processors.
        if not image:
            raise errors.NotFoundError(msg="Image not found")

        # Session is updated at the end when present. It should exist in the
        # normal flow; a missing one is tolerated (guarded below).
        session = await qa_session_dao.get_latest_by_user_exercise(db, task.user_id, exercise.id)

        # Parse recognition result for description (string or list of strings).
        rr = (image.details or {}).get('recognition_result') or {}
        description = ''
        try:
            d = rr.get('description')
            if isinstance(d, str):
                description = d
            elif isinstance(d, list) and d:
                description = d[0] if isinstance(d[0], str) else ''
        except Exception:
            description = ''

        # Build dialogue history for the prompt. `questions` holds every AI
        # message in order; the current attempt replies to the LAST one, so
        # that message goes into history while the attempt's text becomes the
        # "user's new input" in the prompt template.
        questions = await qa_question_dao.get_by_exercise_id(db, exercise.id)
        history = []
        for i, q in enumerate(questions):
            # AI message
            history.append({
                'role': 'assistant',
                'content': q.question
            })

            # The user's reply to the last question IS the new input — do not
            # add an attempt for it to the history.
            if i == len(questions) - 1:
                break

            # For earlier questions, append the user's completed reply if any.
            prev_attempt = await qa_attempt_dao.get_latest_completed_by_user_question(db, task.user_id, q.id)
            if prev_attempt:
                history.append({
                    'role': 'user',
                    'content': prev_attempt.input_text
                })

        user_input = attempt.input_text or ''

        # Conversation configuration saved on the exercise at creation time.
        params = exercise.ext or {}

        # Helpers to pull the English side out of bilingual config items.
        def get_en_list(items):
            if not items:
                return []
            return [item.get('en') for item in items if isinstance(item, dict) and item.get('en')]

        def get_en_item(item):
            if not item or not isinstance(item, dict):
                return None
            return item.get('en')

        prompt = get_free_conversation_reply_prompt(
            history=history,
            user_input=user_input,
            scene=get_en_list(params.get('scene')),
            event=get_en_list(params.get('event')),
            user_role=get_en_item(params.get('user_role')),
            assistant_role=get_en_item(params.get('assistant_role')),
            style=get_en_item(params.get('style')),
            level=params.get('level'),
            info=params.get('info'),
            description=description,
        )

        res = await self._call_llm_chat(prompt=prompt, image_id=image.id, user_id=task.user_id, chat_type='conversation_reply')
        if not res.get('success'):
            raise Exception(res.get('error') or "LLM call failed")

        token_usage = res.get('token_usage') or {}
        # The LLM result may arrive as a JSON string or an already-parsed object.
        try:
            parsed = json.loads(res.get('result')) if isinstance(res.get('result'), str) else res.get('result')
        except Exception:
            parsed = {}

        if not parsed or not isinstance(parsed, dict):
            raise Exception("Invalid LLM response format")

        # Store the correction (if any) on the attempt.
        correction = parsed.get('correction')
        if correction:
            new_ext = dict(attempt.ext or {})
            new_ext['correction'] = correction
            attempt.ext = new_ext
            attempt.evaluation = {'correction': correction}

        attempt.status = 'completed'  # It was pending

        # Create New Question (AI Response)
        question_content = parsed.get('response_en') or ''
        question_ext = {
            'role': 'assistant',
            'response_zh': parsed.get('response_zh'),
            'prompt_en': parsed.get('prompt_en'),
            'prompt_zh': parsed.get('prompt_zh'),
            'alternative_responses': parsed.get('alternative_responses'),
            'correction': correction,
        }

        new_question = await qa_question_dao.create(db, {
            'exercise_id': exercise.id,
            'image_id': image.id,
            'question': question_content,
            'user_id': task.user_id,
            'payload': None,
            'ext': question_ext,
        })

        # Update session timestamp/progress.
        if session:
            session.updated_at = datetime.now()
            prog = dict(session.progress or {})
            prog['total_questions'] = (prog.get('total_questions') or 0) + 1
            session.progress = prog

        await db.flush()

        result = {
            'session_id': str(session.id) if session else '',
            'new_question_id': str(new_question.id),
            'token_usage': token_usage
        }
        return result, token_usage

    async def _call_llm_chat(self, prompt: str, image_id: int, user_id: int, chat_type: str) -> Dict[str, Any]:
        """Invoke the configured LLM with an audit-log callback.

        Duplicated across the processor classes in this module. Returns
        ``{'success': True, 'result': str, 'token_usage': dict}`` on success,
        ``{'success': False, 'error': str}`` on any failure.
        """
        messages = [
            SystemMessage(content="You are a helpful assistant."),
            HumanMessage(content=prompt)
        ]
        # Metadata is consumed by the audit-log callback handler.
        metadata = {
            "image_id": image_id,
            "user_id": user_id,
            "api_type": chat_type,
            "model_name": settings.LLM_MODEL_TYPE
        }
        try:
            llm = LLMFactory.create_llm(settings.LLM_MODEL_TYPE)
            res = await llm.ainvoke(
                messages,
                config={"callbacks": [AuditLogCallbackHandler(metadata=metadata)]}
            )
            content = res.content
            if not isinstance(content, str):
                content = str(content)
            # Provider responses vary: token counts may sit under either key.
            token_usage = {}
            if res.response_metadata:
                token_usage = res.response_metadata.get("token_usage") or res.response_metadata.get("usage") or {}
            return {"success": True, "result": content, "token_usage": token_usage}
        except Exception as e:
            return {"success": False, "error": str(e)}
|
||
|
||
class QaService:
|
||
    async def get_conversation_setting(self, image_id: int, user_id: int) -> Optional[Dict[str, Any]]:
        """Return the stored conversation analysis for an image, if any.

        :raises ForbiddenError: when the image's task belongs to another user
        :raises NotFoundError: when the image does not exist
        :return: setting dict plus latest free-conversation session info, or
            ``None`` when no analysis has been stored yet.
        """
        async with async_db_session() as db:
            task = await image_task_dao.get_by_image_id(db, image_id)
            if not task or task.user_id != user_id:
                raise errors.ForbiddenError(msg="Forbidden")
            image = await image_dao.get(db, image_id)
            if not image:
                raise errors.NotFoundError(msg="Image not found")
            details = dict(image.details or {})
            existing = details.get("conversation_analysis") or {}
            existing_analysis = existing.get("image_analysis")
            if not isinstance(existing_analysis, dict):
                return None

            # Find latest conversation session for this user/image, if any.
            latest_session_info = None
            session = await qa_session_dao.get_latest_session_by_image_user(db, user_id, image_id, exercise_type='free_conversation')
            if session:
                latest_session_info = {
                    'session_id': str(session.id),
                    'status': session.status,
                    # Prefer completion time; fall back to start time.
                    'updated_at': session.completed_at.isoformat() if session.completed_at else (session.started_at.isoformat() if session.started_at else None),
                    'exercise_id': str(session.exercise_id),
                }

            return {
                "image_id": image_id,
                "setting": existing_analysis,
                "latest_session": latest_session_info,
            }
|
||
|
||
    async def start_conversation(
        self,
        image_id: int,
        user_id: int,
        scene: List[Dict[str, str]],
        event: List[Dict[str, str]],
        style: Optional[Dict[str, str]] = None,
        user_role: Optional[Dict[str, str]] = None,
        assistant_role: Optional[Dict[str, str]] = None,
        level: Optional[str] = None,
        info: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Create a free-conversation exercise + session and dispatch the opener.

        The exercise stores the conversation configuration in ``ext``; a
        session is pre-created (status 'initializing') so its id can be
        returned immediately, and a background task runs
        ``ConversationStartProcessor`` to generate the first AI message.

        :raises ForbiddenError: insufficient points or task-slot limit reached
        :raises NotFoundError: when the image does not exist
        """
        # Check points and rate limit
        if not await points_service.check_sufficient_points(user_id, LLM_CHAT_COST):
            raise errors.ForbiddenError(msg='积分不足,请获取积分后继续使用')

        slot_acquired = await rate_limit_service.acquire_task_slot(user_id)
        if not slot_acquired:
            max_tasks = await rate_limit_service.get_user_task_limit(user_id)
            raise errors.ForbiddenError(msg=f'用户同时最多只能运行 {max_tasks} 个任务,请等待现有任务完成后再试')
        # NOTE(review): if an error is raised below, the acquired slot is not
        # released here — presumably handled by rate_limit_service (e.g. TTL)
        # or by the task pipeline; confirm.

        async with async_db_session.begin() as db:
            image = await image_dao.get(db, image_id)
            if not image:
                raise errors.NotFoundError(msg="Image not found")

            # Create Exercise
            exercise = await qa_exercise_dao.create(db, {
                "image_id": image_id,
                "created_by": user_id,
                "type": "free_conversation",
                "description": None,
                "status": "ongoing",
                "ext": {
                    "scene": scene,
                    "event": event,
                    "user_role": user_role,
                    "assistant_role": assistant_role,
                    "style": style,
                    "level": level,
                    "info": info,
                },
            })
            await db.flush()

            # Create Session (Pre-create to return session_id immediately)
            prog = {'current_index': 0, 'answered': 0, 'correct': 0, 'attempts': [], 'total_questions': 0}
            session = await qa_session_dao.create(db, {
                'exercise_id': exercise.id,
                'starter_user_id': user_id,
                'status': 'initializing',
                'started_at': datetime.now(),
                'progress': prog,
                'ext': None,
            })
            await db.flush()

            # Create Task
            task = await image_task_dao.create_task(db, CreateImageTaskParam(
                image_id=image_id,
                user_id=user_id,
                # Fall back to LEVEL1 when the image carries no dict level.
                dict_level=(getattr(getattr(image, 'dict_level', None), 'name', None) or 'LEVEL1'),
                ref_type='qa_exercise',
                ref_id=exercise.id,
                status=ImageTaskStatus.PENDING,
            ))
            await db.flush()
            task_id = task.id

        # Dispatch Task
        asyncio.create_task(image_task_service.process_task(task_id, user_id, ConversationStartProcessor()))

        return {
            "task_id": str(task_id),
            "status": "processing",
            "exercise_id": str(exercise.id),
            "session_id": str(session.id)
        }
|
||
|
||
async def reply_conversation(
|
||
self,
|
||
session_id: int,
|
||
user_id: int,
|
||
input_text: str,
|
||
) -> Dict[str, Any]:
|
||
# Check points and rate limit
|
||
if not await points_service.check_sufficient_points(user_id, LLM_CHAT_COST):
|
||
raise errors.ForbiddenError(msg='积分不足,请获取积分后继续使用')
|
||
|
||
slot_acquired = await rate_limit_service.acquire_task_slot(user_id)
|
||
if not slot_acquired:
|
||
max_tasks = await rate_limit_service.get_user_task_limit(user_id)
|
||
raise errors.ForbiddenError(msg=f'用户同时最多只能运行 {max_tasks} 个任务,请等待现有任务完成后再试')
|
||
|
||
async with async_db_session.begin() as db:
|
||
session = await qa_session_dao.get(db, session_id)
|
||
if not session:
|
||
raise errors.NotFoundError(msg="Session not found")
|
||
if session.starter_user_id != user_id:
|
||
raise errors.ForbiddenError(msg="Forbidden")
|
||
|
||
exercise = await qa_exercise_dao.get(db, session.exercise_id)
|
||
|
||
# Create Attempt (User Input)
|
||
# Link to the last question (the one AI asked)
|
||
last_question = await qa_question_dao.get_latest_by_exercise_id(db, exercise.id)
|
||
if not last_question:
|
||
raise errors.ServerError(msg="No question to reply to")
|
||
|
||
attempt = await qa_attempt_dao.create(db, {
|
||
"user_id": user_id,
|
||
"question_id": last_question.id,
|
||
"exercise_id": exercise.id,
|
||
"input_text": input_text,
|
||
"status": "pending",
|
||
"evaluation": None,
|
||
"ext": None
|
||
})
|
||
await db.flush()
|
||
|
||
# Create Task
|
||
task = await image_task_dao.create_task(db, CreateImageTaskParam(
|
||
image_id=exercise.image_id,
|
||
user_id=user_id,
|
||
dict_level='LEVEL1', # Default or fetch from image
|
||
ref_type='qa_attempt',
|
||
ref_id=attempt.id,
|
||
status=ImageTaskStatus.PENDING,
|
||
))
|
||
await db.flush()
|
||
task_id = task.id
|
||
|
||
asyncio.create_task(image_task_service.process_task(task_id, user_id, ConversationReplyProcessor()))
|
||
|
||
return {
|
||
"task_id": str(task_id),
|
||
"status": "processing",
|
||
"session_id": str(session.id)
|
||
}
|
||
|
||
async def _get_messages_for_session(self, db: AsyncSession, exercise_id: int, user_id: int) -> List[Dict[str, Any]]:
|
||
questions = await qa_question_dao.get_by_exercise_id(db, exercise_id)
|
||
|
||
messages = []
|
||
total_questions = len(questions)
|
||
for idx, q in enumerate(questions):
|
||
is_last = (idx == total_questions - 1)
|
||
ext = q.ext or {}
|
||
|
||
# AI Message
|
||
if is_last:
|
||
# Full content for the last message
|
||
content = {
|
||
"response_en": q.question,
|
||
"response_zh": ext.get("response_zh"),
|
||
"prompt_en": ext.get("prompt_en"),
|
||
"prompt_zh": ext.get("prompt_zh"),
|
||
"alternative_responses": ext.get("alternative_responses"),
|
||
"correction": ext.get("correction"),
|
||
}
|
||
else:
|
||
# Simplified content for historical messages
|
||
content = {
|
||
"response_en": q.question,
|
||
"response_zh": ext.get("response_zh"),
|
||
}
|
||
|
||
messages.append({
|
||
"role": "assistant",
|
||
"content": content
|
||
})
|
||
|
||
# User Reply (Attempt)
|
||
attempt = await qa_attempt_dao.get_latest_completed_by_user_question(db, user_id, q.id)
|
||
if attempt:
|
||
messages.append({
|
||
"role": "user",
|
||
"content": {
|
||
"text": attempt.input_text,
|
||
"correction": (attempt.evaluation or {}).get("correction")
|
||
}
|
||
})
|
||
return messages
|
||
|
||
async def get_latest_messages(self, session_id: int, user_id: int) -> Dict[str, Any]:
|
||
async with async_db_session() as db:
|
||
session = await qa_session_dao.get(db, session_id)
|
||
if not session or session.starter_user_id != user_id:
|
||
raise errors.NotFoundError(msg="Session not found")
|
||
|
||
# Optimization: Directly fetch the latest question from DB to avoid loading full history
|
||
latest_q = await qa_question_dao.get_latest_by_exercise_id(db, session.exercise_id)
|
||
|
||
latest_messages = []
|
||
if latest_q:
|
||
ext = latest_q.ext or {}
|
||
latest_messages.append({
|
||
"role": "assistant",
|
||
"content": {
|
||
"response_en": latest_q.question,
|
||
"response_zh": ext.get("response_zh"),
|
||
"prompt_en": ext.get("prompt_en"),
|
||
"prompt_zh": ext.get("prompt_zh"),
|
||
"alternative_responses": ext.get("alternative_responses"),
|
||
"correction": ext.get("correction"),
|
||
}
|
||
})
|
||
|
||
return {
|
||
"session_id": str(session_id),
|
||
"messages": latest_messages
|
||
}
|
||
|
||
async def get_conversation_session(self, session_id: int, user_id: int) -> Dict[str, Any]:
|
||
async with async_db_session() as db:
|
||
session = await qa_session_dao.get(db, session_id)
|
||
if not session or session.starter_user_id != user_id:
|
||
raise errors.NotFoundError(msg="Session not found")
|
||
|
||
exercise = await qa_exercise_dao.get(db, session.exercise_id)
|
||
messages = await self._get_messages_for_session(db, session.exercise_id, user_id)
|
||
|
||
return {
|
||
"exercise_id": str(exercise.id),
|
||
"session_id": str(session.id),
|
||
"status": session.status,
|
||
"updated_at": (exercise.updated_time.isoformat() if getattr(exercise, "updated_time", None) else None),
|
||
"messages": messages,
|
||
}
|
||
|
||
async def list_conversations_by_image(
|
||
self,
|
||
image_id: int,
|
||
user_id: int,
|
||
page: int = 1,
|
||
page_size: int = 10
|
||
) -> Dict[str, Any]:
|
||
async with async_db_session() as db:
|
||
total, results = await qa_session_dao.get_list_by_image(
|
||
db, image_id, user_id, page, page_size
|
||
)
|
||
|
||
items = []
|
||
for session, exercise in results:
|
||
ext = exercise.ext or {}
|
||
items.append({
|
||
"session_id": str(session.id),
|
||
"scene": ext.get("scene") or [],
|
||
"event": ext.get("event") or [],
|
||
"user_role": ext.get("user_role") or {},
|
||
"style": ext.get("style") or {},
|
||
"created_at": (session.created_time.strftime("%Y-%m-%d %H:%M:%S") if session.created_time else None),
|
||
})
|
||
|
||
return {
|
||
"total": total,
|
||
"items": items,
|
||
"page": page,
|
||
"page_size": page_size,
|
||
}
|
||
|
||
async def recognize_audio(self, file_id: int, user_id: int, session_id: int) -> Dict[str, Any]:
    """Recognize the audio content and attach the transcript to the session's latest question.

    :param file_id: id of the uploaded audio file
    :param user_id: caller; must be the session starter
    :param session_id: conversation session the audio belongs to
    :return: ``{"text": <recognized text>}``
    :raises NotFoundError: session or file does not exist
    :raises ForbiddenError: caller did not start the session
    :raises ServerError: the ASR backend reported failure
    """
    # Imported lazily — presumably to avoid a module-level import cycle; TODO confirm.
    from backend.middleware.qwen import Qwen

    # 1. Validate the session and resolve its exercise/image ids.
    exercise_id = None
    image_id = None
    async with async_db_session.begin() as db:
        session = await qa_session_dao.get(db, session_id)
        if not session:
            raise errors.NotFoundError(msg="Session not found")
        if session.starter_user_id != user_id:
            raise errors.ForbiddenError(msg="Forbidden")
        exercise_id = session.exercise_id

        exercise = await qa_exercise_dao.get(db, exercise_id)
        if exercise:
            image_id = exercise.image_id

    # 2. Fetch file metadata.
    file_obj = await file_service.get_file(file_id)
    if not file_obj:
        raise errors.NotFoundError(msg="文件不存在")

    # Resolve a file access path or URL for the ASR call.
    audio_url = ""
    temp_file_path = None

    try:
        if file_obj.storage_type == "cos":
            # COS storage: hand the ASR backend a presigned download URL.
            audio_url = await file_service.get_presigned_download_url(file_id, user_id)
        else:
            # Database storage: write bytes to a temp file the backend can read.
            content, _, _ = await file_service.download_file(file_id)
            import tempfile
            import os
            # The suffix must be kept so Qwen can detect the audio format.
            ext = "mp3"
            if file_obj.content_type:
                # NOTE(review): _mime_to_ext is a private helper of file_service — confirm it is stable API.
                ext = file_service._mime_to_ext(file_obj.content_type, file_obj.file_name)

            # delete=False: the file must outlive this block; cleaned up in `finally`.
            tf = tempfile.NamedTemporaryFile(delete=False, suffix=f".{ext}")
            tf.write(content)
            tf.close()
            temp_file_path = tf.name
            audio_url = temp_file_path

        # Call Qwen ASR.
        res = await Qwen.recognize_speech(audio_url, user_id=user_id, image_id=image_id)

        if not res.get("success"):
            raise errors.ServerError(msg=res.get("error") or "ASR failed")

        text = res.get("text", "")

        # 3. Append the transcript onto the session's most recent LLM question.
        async with async_db_session.begin() as db:
            # Find latest question (LLM response).
            last_question = await qa_question_dao.get_latest_by_exercise_id(db, exercise_id)
            if last_question:
                # Copy ext so the ORM sees a new dict and flushes the change.
                ext = dict(last_question.ext or {})
                audio_recognitions = ext.get('audio_recognitions')
                if not isinstance(audio_recognitions, list):
                    audio_recognitions = []

                audio_recognitions.append({
                    'audio_id': str(file_id),
                    'text': text
                })
                ext['audio_recognitions'] = audio_recognitions
                last_question.ext = ext
                await db.flush()

        return {"text": text}

    finally:
        # Clean up the temp file regardless of success/failure; best-effort removal.
        if temp_file_path:
            import os
            if os.path.exists(temp_file_path):
                try:
                    os.remove(temp_file_path)
                except Exception:
                    pass
|
||
|
||
async def create_exercise_task(self, image_id: int, user_id: int, type: Optional[str] = "scene_basic") -> Dict[str, Any]:
    """Create a background task that generates an exercise (or a conversation analysis) for an image.

    Reuses an already-active task when one exists; otherwise checks points and
    the per-user concurrent-task limit, creates the exercise/task rows, and
    schedules the matching processor in the background.

    :param type: exercise type; 'init_conversion' switches to conversation-init
        mode, 'scene_variation' selects the scene-variation processor.
    :return: ``{'task_id': str, 'status': ...}`` — existing task status or 'accepted'.
    :raises ForbiddenError: insufficient points or task-slot limit reached
    :raises NotFoundError: the image does not exist
    """
    is_conversation_init = type == 'init_conversion'

    async with async_db_session.begin() as db:
        # Check for existing active task — if found, return it instead of creating a duplicate.
        ref_type_for_lookup = 'image_conversation_analysis' if is_conversation_init else 'qa_exercise'
        latest_task = await image_task_dao.get_latest_active_task(db, user_id, image_id, ref_type_for_lookup)
        if latest_task:
            return {'task_id': str(latest_task.id), 'status': latest_task.status}

    # Gate on points balance and per-user concurrency before touching the DB again.
    if not await points_service.check_sufficient_points(user_id, LLM_CHAT_COST):
        raise errors.ForbiddenError(msg='积分不足,请获取积分后继续使用')
    slot_acquired = await rate_limit_service.acquire_task_slot(user_id)
    if not slot_acquired:
        max_tasks = await rate_limit_service.get_user_task_limit(user_id)
        raise errors.ForbiddenError(msg=f'用户同时最多只能运行 {max_tasks} 个任务,请等待现有任务完成后再试')

    # NOTE(review): if anything below raises, the acquired rate-limit slot is
    # never released — verify against rate_limit_service's release path.
    async with async_db_session.begin() as db:
        image = await image_dao.get(db, image_id)
        if not image:
            raise errors.NotFoundError(msg='Image not found')
        if is_conversation_init:
            # Conversation-init tasks reference the image directly, no exercise row.
            ref_type = 'image_conversation_analysis'
            ref_id = image_id
        else:
            exercise = await qa_exercise_dao.create(db, {
                'image_id': image_id,
                'created_by': user_id,
                'type': type,
                'description': None,
                'status': 'draft',
                'ext': None
            })
            await db.flush()
            ref_type = 'qa_exercise'
            ref_id = exercise.id
        task = await image_task_dao.create_task(db, CreateImageTaskParam(
            image_id=image_id,
            user_id=user_id,
            # Fall back to LEVEL1 when the image has no dict_level enum set.
            dict_level=(getattr(getattr(image, 'dict_level', None), 'name', None) or 'LEVEL1'),
            ref_type=ref_type,
            ref_id=ref_id,
            status=ImageTaskStatus.PENDING,
        ))
        await db.flush()
        task_id = task.id
        # NOTE(review): explicit commit inside async_db_session.begin(), which
        # already commits on exit — looks redundant; confirm intent.
        await db.commit()

    # Pick the processor matching the requested task flavor.
    if type == 'scene_variation':
        processor = SceneVariationProcessor()
    elif is_conversation_init:
        processor = ConversationInitProcessor()
    else:
        processor = QaExerciseProcessor()

    # Fire-and-forget: the task runs outside the request lifecycle.
    asyncio.create_task(image_task_service.process_task(task_id, user_id, processor))
    return {'task_id': str(task_id), 'status': 'accepted'}
|
||
|
||
async def get_task_status(self, task_id: int) -> Dict[str, Any]:
    """Look up an image-processing task and report its current state.

    :raises NotFoundError: no task exists with *task_id*.
    """
    async with async_db_session() as db:
        found = await image_task_dao.get(db, task_id)
        if found is None:
            raise errors.NotFoundError(msg='Task not found')
        return {
            'task_id': str(found.id),
            'image_id': str(found.image_id),
            'ref_type': found.ref_type,
            'ref_id': str(found.ref_id),
            'status': found.status,
            'error_message': found.error_message,
        }
|
||
|
||
async def list_exercises_by_image(self, image_id: int, user_id: Optional[int] = None, type: Optional[str] = "scene_basic") -> Optional[Dict[str, Any]]:
    """Fetch the newest exercise of *type* for an image, with its questions and the caller's latest session.

    :return: ``{'exercise', 'session', 'questions'}`` or ``None`` when the image
        or a matching exercise does not exist.
    """
    async with async_db_session() as db:
        if not await image_dao.get(db, image_id):
            return None
        exercise = await qa_exercise_dao.get_latest_by_image_id(db, image_id, type=type)
        if not exercise:
            return None

        questions = await qa_question_dao.get_by_exercise_id(db, exercise.id)

        # Session info is only resolved when a concrete user is asking.
        session_info = None
        if user_id:
            latest = await qa_session_dao.get_latest_by_user_exercise(db, user_id, exercise.id)
            if latest:
                session_info = {
                    'id': str(latest.id),
                    'started_at': latest.started_at.isoformat() if latest.started_at else None,
                    'progress': latest.progress,
                }

        return {
            'exercise': {
                'id': str(exercise.id),
                'image_id': str(exercise.image_id),
                'type': exercise.type,
                'description': exercise.description,
                'status': exercise.status,
                'question_count': exercise.question_count,
            },
            'session': session_info,
            'questions': [
                {
                    'id': str(item.id),
                    'exercise_id': str(item.exercise_id),
                    'image_id': str(item.image_id),
                    'question': item.question,
                    'ext': item.ext,
                }
                for item in questions
            ],
        }
|
||
|
||
def _evaluate_choice(self, q: QaQuestion, selected_options: List[str]) -> Tuple[Dict[str, Any], str, List[str]]:
    """Grade a multiple-choice answer against the question's option lists in ``q.ext``.

    :return: ``(evaluation, verdict, selected_list)`` where *verdict* is
        ``'correct'`` / ``'partial'`` / ``'incorrect'``.
    """
    def norm(value):
        # Case/whitespace-insensitive comparison key.
        try:
            return str(value).strip().lower()
        except Exception:
            return str(value)

    def content_of(option):
        # Options may be plain strings or dicts with a 'content' field plus metadata.
        return option.get('content') if isinstance(option, dict) else option

    meta = q.ext or {}
    raw_correct = meta.get('correct_options') or []
    raw_incorrect = meta.get('incorrect_options') or []

    correct_set = {norm(content_of(o)) for o in raw_correct}

    # Map normalized incorrect option -> explanation payload.
    incorrect_map = {}
    for option in raw_incorrect:
        key = norm(content_of(option))
        if isinstance(option, dict):
            incorrect_map[key] = {
                'content': option.get('content'),
                'error_type': option.get('error_type'),
                'error_reason': option.get('error_reason')
            }
        else:
            incorrect_map[key] = {'content': option, 'error_type': None, 'error_reason': None}

    selected_list = list(selected_options or [])
    selected = {norm(s) for s in selected_list}
    all_correct_contents = [content_of(o) for o in raw_correct]

    # Empty selection short-circuits to a full miss.
    if not selected:
        verdict = 'incorrect'
        evaluation = {
            'type': 'choice',
            'result': '完全错误',
            'detail': 'no selection',
            'selected': {'correct': [], 'incorrect': []},
            'missing_correct': all_correct_contents,
        }
        return evaluation, verdict, selected_list

    chosen_correct = [content_of(o) for o in raw_correct if norm(content_of(o)) in selected]
    chosen_incorrect = [
        incorrect_map.get(norm(s)) or {'content': s, 'error_type': 'unknown', 'error_reason': None}
        for s in selected_list
        if norm(s) not in correct_set
    ]
    missing = [content_of(o) for o in raw_correct if norm(content_of(o)) not in selected]

    if selected == correct_set and not chosen_incorrect:
        verdict = 'correct'
        evaluation = {
            'type': 'choice',
            'result': '完全匹配',
            'detail': verdict,
            'selected': {'correct': chosen_correct, 'incorrect': []},
            'missing_correct': [],
        }
    elif chosen_correct:
        verdict = 'partial'
        evaluation = {
            'type': 'choice',
            'result': '部分匹配',
            'detail': verdict,
            'selected': {'correct': chosen_correct, 'incorrect': chosen_incorrect},
            'missing_correct': missing,
        }
    else:
        verdict = 'incorrect'
        evaluation = {
            'type': 'choice',
            'result': '完全错误',
            'detail': verdict,
            'selected': {'correct': [], 'incorrect': chosen_incorrect},
            'missing_correct': all_correct_contents,
        }
    return evaluation, verdict, selected_list
|
||
|
||
def _evaluate_cloze(self, q: QaQuestion, cloze_options: Optional[List[str]]) -> Tuple[Dict[str, Any], str, str]:
    """Grade a cloze (fill-in-the-blank) answer against ``q.ext['cloze']['correct_word']``.

    Multiple selections are supported; correctness is judged per selection.

    :return: ``(evaluation, verdict, first_selection)``.
    """
    def norm(value):
        # Case/whitespace-insensitive comparison key.
        try:
            return str(value).strip().lower()
        except Exception:
            return str(value)

    meta = q.ext or {}
    cloze_meta = meta.get('cloze') or {}
    answer_spec = cloze_meta.get('correct_word')

    # Keep only non-blank string selections; the first is echoed back to the caller.
    picks = [s for s in (cloze_options or []) if isinstance(s, str) and s.strip()]
    first_pick = picks[0] if picks else ''

    # The stored answer may be one word or a list of acceptable words.
    if isinstance(answer_spec, list):
        candidates = [w for w in answer_spec if isinstance(w, str) and w.strip()]
    elif isinstance(answer_spec, str) and answer_spec.strip():
        candidates = [answer_spec]
    else:
        candidates = []
    accepted = {norm(w) for w in candidates}

    right = [s for s in picks if norm(s) in accepted]
    wrong = [
        {'content': s, 'error_type': None, 'error_reason': None}
        for s in picks
        if norm(s) not in accepted
    ]

    if right and not wrong:
        verdict = 'correct'
        evaluation = {
            'type': 'cloze',
            'result': '完全匹配',
            'detail': verdict,
            'selected': {'correct': right, 'incorrect': []},
            'missing_correct': [],
        }
    elif right:
        verdict = 'partial'
        evaluation = {
            'type': 'cloze',
            'result': '部分匹配',
            'detail': verdict,
            'selected': {'correct': right, 'incorrect': wrong},
            'missing_correct': [],
        }
    else:
        verdict = 'incorrect'
        evaluation = {
            'type': 'cloze',
            'result': '完全错误',
            'detail': verdict,
            'selected': {'correct': [], 'incorrect': wrong},
            'missing_correct': list(candidates),
        }
    return evaluation, verdict, first_pick
|
||
|
||
async def submit_attempt(self, question_id: int, exercise_id: int, user_id: int, mode: str, selected_options: Optional[List[str]] = None, input_text: Optional[str] = None, cloze_options: Optional[List[str]] = None, session_id: Optional[int] = None, is_trial: bool = False) -> Dict[str, Any]:
    """Record and grade a user's answer to a question.

    Choice, cloze and variation answers are graded synchronously; free-text
    answers are queued for asynchronous LLM evaluation via an image task.

    :param mode: EXERCISE_TYPE_CHOICE / EXERCISE_TYPE_CLOZE /
        EXERCISE_TYPE_FREE_TEXT / 'variation'
    :param is_trial: trial answers skip persistence for choice/cloze and never
        touch session progress
    :return: mode-specific payload with the evaluation (evaluation is None for
        free_text, which is evaluated asynchronously)
    :raises NotFoundError: question missing or not part of *exercise_id*
    """
    async with async_db_session.begin() as db:
        q = await qa_question_dao.get(db, question_id)
        if not q or q.exercise_id != exercise_id:
            raise errors.NotFoundError(msg='Question not found')

        # Optimization: If trial mode and synchronous evaluation (Choice/Cloze), skip DB persistence
        if is_trial:
            if mode == EXERCISE_TYPE_CHOICE:
                evaluation, _, selected_list = self._evaluate_choice(q, selected_options)
                return {
                    'session_id': None,
                    'type': 'choice',
                    'choice': {
                        'options': selected_list,
                        'evaluation': evaluation
                    }
                }
            elif mode == EXERCISE_TYPE_CLOZE:
                # Fall back to input_text as a single cloze selection.
                c_opts = cloze_options
                if not c_opts and input_text:
                    c_opts = [input_text]
                evaluation, _, input_str = self._evaluate_cloze(q, c_opts)
                return {
                    'session_id': None,
                    'type': 'cloze',
                    'cloze': {
                        'input': input_str,
                        'evaluation': evaluation
                    }
                }

        # Reuse the user's latest attempt row for this question, or create one.
        recording_id = None
        attempt = await qa_attempt_dao.get_latest_by_user_question(db, user_id=user_id, question_id=question_id)
        if attempt:
            # Reset and overwrite only the fields relevant to the current mode.
            attempt.task_id = None
            attempt.choice_options = selected_options if mode == EXERCISE_TYPE_CHOICE else attempt.choice_options
            if mode == EXERCISE_TYPE_CLOZE:
                # Only the first cloze option is stored on the row.
                if isinstance(cloze_options, list) and cloze_options:
                    attempt.cloze_options = cloze_options[0]
                elif input_text:
                    attempt.cloze_options = input_text
            attempt.input_text = input_text if mode == EXERCISE_TYPE_FREE_TEXT else attempt.input_text
            attempt.status = 'pending'
            ext0 = attempt.ext or {}
            if session_id:
                ext0['session_id'] = session_id
            # Keep the trial flag in sync with this submission.
            if is_trial:
                ext0['is_trial'] = True
            elif 'is_trial' in ext0:
                del ext0['is_trial']
            attempt.ext = ext0
            await db.flush()
        else:
            attempt = await qa_attempt_dao.create(db, {
                'question_id': question_id,
                'exercise_id': exercise_id,
                'user_id': user_id,
                'task_id': None,
                'recording_id': recording_id,
                'choice_options': selected_options if mode == EXERCISE_TYPE_CHOICE else None,
                'cloze_options': ((cloze_options[0] if isinstance(cloze_options, list) and cloze_options else (input_text if input_text else None)) if mode == EXERCISE_TYPE_CLOZE else None),
                'input_text': input_text if mode == EXERCISE_TYPE_FREE_TEXT else None,
                'status': 'pending',
                'evaluation': None,
                'ext': {'is_trial': True} if is_trial else None,
            })

        # Non-trial submissions are reflected in the session's progress journal.
        if not is_trial:
            s = await qa_session_dao.get_latest_by_user_exercise(db, user_id, exercise_id)
            if s and s.exercise_id == exercise_id:
                prog = dict(s.progress or {})
                attempts = list(prog.get('attempts') or [])
                # Replace the existing entry for this (question, mode) pair if present.
                replaced = False
                for idx, a in enumerate(attempts):
                    if a.get('question_id') == question_id and a.get('mode') == mode:
                        attempts[idx] = {
                            'attempt_id': attempt.id,
                            'question_id': str(question_id),
                            'mode': mode,
                            'created_at': datetime.now().isoformat(),
                            'is_correct': a.get('is_correct'),
                        }
                        replaced = True
                        break
                if not replaced:
                    attempts.append({
                        'attempt_id': attempt.id,
                        'question_id': str(question_id),
                        'mode': mode,
                        'created_at': datetime.now().isoformat(),
                        'is_correct': None,
                    })
                # NOTE(review): 'answered' is incremented even when an existing
                # entry was replaced — re-answers inflate the counter; confirm intended.
                prog['answered'] = int(prog.get('answered') or 0) + 1
                prog['attempts'] = attempts
                s.progress = prog
                attempt.ext = {**(attempt.ext or {}), 'session_id': s.id}
                await db.flush()

        # Free text: persist placeholder, then hand off to async LLM evaluation.
        if mode == EXERCISE_TYPE_FREE_TEXT:
            attempt.ext = {**(attempt.ext or {}), 'type': 'free_text', 'free_text': {'text': attempt.input_text or '', 'evaluation': None}}
            await db.flush()
            # The evaluation task gets its own transaction/session.
            async with async_db_session.begin() as db2:
                task = await image_task_dao.create_task(db2, CreateImageTaskParam(
                    image_id=q.image_id,
                    user_id=user_id,
                    dict_level=DictLevel.LEVEL1.value,
                    ref_type='qa_question_attempt',
                    ref_id=attempt.id,
                    status=ImageTaskStatus.PENDING,
                ))
                await db2.flush()
            asyncio.create_task(self._process_attempt_evaluation(task.id, user_id))
            session_id_val = (attempt.ext or {}).get('session_id')
            return {
                'session_id': str(session_id_val) if session_id_val is not None else None,
                'type': 'free_text',
                'free_text': {
                    'text': attempt.input_text or '',
                    'evaluation': None
                }
            }

        # Synchronous evaluation for choice/cloze/variation
        if mode == EXERCISE_TYPE_CHOICE:
            evaluation, is_correct, selected_list = self._evaluate_choice(q, attempt.choice_options)
            attempt.ext = {**(attempt.ext or {}), 'type': 'choice', 'choice': {'options': selected_list, 'evaluation': evaluation}}
            await db.flush()
            # Merge into the attempt's accumulated evaluation payload.
            merged_eval = dict(attempt.evaluation or {})
            merged_eval['choice'] = {'options': selected_list, 'evaluation': evaluation}
            await qa_attempt_dao.update_status(db, attempt.id, 'completed', merged_eval)

            if not is_trial:
                # Update the session's correct-count delta for this attempt.
                s = await qa_session_dao.get_latest_by_user_exercise(db, user_id, exercise_id)
                if s and s.exercise_id == attempt.exercise_id:
                    prog = dict(s.progress or {})
                    attempts = list(prog.get('attempts') or [])
                    prev = None
                    for a in attempts:
                        if a.get('attempt_id') == attempt.id:
                            prev = a.get('is_correct')
                            a['is_correct'] = is_correct
                            break
                    prev_correct = 1 if prev == 'correct' else 0
                    new_correct = 1 if is_correct == 'correct' else 0
                    correct_inc = new_correct - prev_correct
                    prog['attempts'] = attempts
                    prog['correct'] = int(prog.get('correct') or 0) + correct_inc
                    s.progress = prog
                    await db.flush()
            # NOTE(review): explicit commit inside async_db_session.begin(),
            # which commits on exit — looks redundant; confirm intent.
            await db.commit()
            session_id_val = (attempt.ext or {}).get('session_id')
            return {
                'session_id': str(session_id_val) if session_id_val is not None else None,
                'type': 'choice',
                'choice': {
                    'options': selected_list,
                    'evaluation': evaluation
                }
            }

        if mode == EXERCISE_TYPE_CLOZE:
            # Resolve the selections, preferring explicit cloze_options.
            c_opts: List[str] = []
            if isinstance(cloze_options, list) and cloze_options:
                c_opts = cloze_options
            elif input_text:
                c_opts = [input_text]
            elif attempt.cloze_options:
                c_opts = [attempt.cloze_options]
            # NOTE(review): this final assignment makes the elif chain above
            # partially redundant — confirm whether it can be simplified.
            if cloze_options:
                c_opts = cloze_options

            evaluation, is_correct, input_str = self._evaluate_cloze(q, c_opts)
            attempt.ext = {**(attempt.ext or {}), 'type': 'cloze', 'cloze': {'input': input_str, 'evaluation': evaluation}}
            await db.flush()
            merged_eval = dict(attempt.evaluation or {})
            merged_eval['cloze'] = {'input': input_str, 'evaluation': evaluation}
            await qa_attempt_dao.update_status(db, attempt.id, 'completed', merged_eval)

            if not is_trial:
                # Same correct-count bookkeeping as the choice branch.
                s = await qa_session_dao.get_latest_by_user_exercise(db, user_id, exercise_id)
                if s and s.exercise_id == attempt.exercise_id:
                    prog = dict(s.progress or {})
                    attempts = list(prog.get('attempts') or [])
                    prev = None
                    for a in attempts:
                        if a.get('attempt_id') == attempt.id:
                            prev = a.get('is_correct')
                            a['is_correct'] = is_correct
                            break
                    prev_correct = 1 if prev == 'correct' else 0
                    new_correct = 1 if is_correct == 'correct' else 0
                    correct_inc = new_correct - prev_correct
                    prog['attempts'] = attempts
                    prog['correct'] = int(prog.get('correct') or 0) + correct_inc
                    s.progress = prog
                    await db.flush()
            await db.commit()
            session_id_val = (attempt.ext or {}).get('session_id')
            return {
                'session_id': str(session_id_val) if session_id_val is not None else None,
                'type': 'cloze',
                'cloze': {
                    'input': input_str,
                    'evaluation': evaluation
                }
            }

        if mode == 'variation':
            ext_q = q.ext or {}
            correct_file_id = ext_q.get('file_id')

            # Get user selected file_id from selected_options
            user_file_id = None
            if selected_options and len(selected_options) > 0:
                try:
                    user_file_id = selected_options[0]
                except (ValueError, TypeError):
                    user_file_id = None

            is_correct = 'incorrect'
            # NOTE(review): int() here is outside the try/except above — a
            # non-numeric selection would raise ValueError; confirm inputs.
            if user_file_id is not None and correct_file_id is not None and int(user_file_id) == int(correct_file_id):
                is_correct = 'correct'

            evaluation = {'type': 'variation', 'detail': is_correct, 'result': is_correct, 'correct_file_id': correct_file_id, 'user_file_id': user_file_id}
            attempt.ext = {**(attempt.ext or {}), 'type': 'variation', 'variation': {'file_id': user_file_id, 'evaluation': evaluation}}
            await db.flush()
            merged_eval = dict(attempt.evaluation or {})
            merged_eval['variation'] = {'file_id': user_file_id, 'evaluation': evaluation}
            await qa_attempt_dao.update_status(db, attempt.id, 'completed', merged_eval)

            if not is_trial:
                # Same correct-count bookkeeping as the choice branch.
                s = await qa_session_dao.get_latest_by_user_exercise(db, user_id, exercise_id)
                if s and s.exercise_id == attempt.exercise_id:
                    prog = dict(s.progress or {})
                    attempts = list(prog.get('attempts') or [])
                    prev = None
                    for a in attempts:
                        if a.get('attempt_id') == attempt.id:
                            prev = a.get('is_correct')
                            a['is_correct'] = is_correct
                            break
                    prev_correct = 1 if prev == 'correct' else 0
                    new_correct = 1 if is_correct == 'correct' else 0
                    correct_inc = new_correct - prev_correct
                    prog['attempts'] = attempts
                    prog['correct'] = int(prog.get('correct') or 0) + correct_inc
                    s.progress = prog
                    await db.flush()
            await db.commit()
            session_id_val = (attempt.ext or {}).get('session_id')
            return {
                'session_id': str(session_id_val) if session_id_val is not None else None,
                'type': 'variation',
                'variation': {
                    'file_id': user_file_id,
                    'evaluation': evaluation
                }
            }
|
||
|
||
async def _process_attempt_evaluation(self, task_id: int, user_id: int):
    """Background worker: evaluate a free-text attempt with the LLM and persist the verdict.

    Driven by an image task row (ref_type 'qa_question_attempt'); updates the
    task status through PROCESSING -> COMPLETED/FAILED and, for non-trial
    attempts, folds the result into the session's progress journal.
    """
    async with background_db_session() as db:
        task = await image_task_dao.get(db, task_id)
        if not task:
            # Task vanished — nothing to do.
            return
        await image_task_dao.update_task_status(db, task_id, ImageTaskStatus.PROCESSING)
        attempt = await qa_attempt_dao.get(db, task.ref_id)
        if not attempt:
            await image_task_dao.update_task_status(db, task_id, ImageTaskStatus.FAILED, error_message='Attempt not found')
            await db.commit()
            return

        is_trial = (attempt.ext or {}).get('is_trial', False)

        # Only async evaluation for free_text/audio attempts
        q = await qa_question_dao.get(db, attempt.question_id)
        user_text = attempt.input_text or ''
        answers = (q.ext or {}).get('answers') or {}
        # Prompt instructs the model to return JSON: {is_correct, feedback}.
        prompt = (
            '根据给定标准答案,判断用户回答是否正确,输出JSON:{is_correct: correct|partial|incorrect, feedback: string}。'
            '标准答案:' + json.dumps(answers, ensure_ascii=False) +
            '用户回答:' + user_text
        )
        res = await self._call_llm_chat(prompt=prompt, image_id=q.image_id, user_id=user_id, chat_type='qa_attempt')
        if not res.get('success'):
            await image_task_dao.update_task_status(db, task_id, ImageTaskStatus.FAILED, error_message=res.get('error'))
            await db.commit()
            return
        # Be tolerant of the LLM returning either a JSON string or a dict.
        try:
            parsed = json.loads(res.get('result')) if isinstance(res.get('result'), str) else res.get('result')
        except Exception:
            parsed = {}
        evaluation = {'type': 'free_text', 'result': parsed.get('is_correct'), 'feedback': parsed.get('feedback')}
        # update ext with free_text details
        attempt.ext = {**(attempt.ext or {}), 'type': 'free_text', 'free_text': {'text': attempt.input_text or '', 'evaluation': evaluation}}
        await db.flush()
        merged_eval = dict(attempt.evaluation or {})
        merged_eval['free_text'] = {'text': attempt.input_text or '', 'evaluation': evaluation}
        await qa_attempt_dao.update_status(db, attempt.id, 'completed', merged_eval)
        await image_task_dao.update_task_status(db, task_id, ImageTaskStatus.COMPLETED, result={'mode': 'free_text', 'token_usage': res.get('token_usage') or {}})

        # Non-trial attempts also adjust the session's correct counter by delta.
        if not is_trial:
            s = await qa_session_dao.get_latest_by_user_exercise(db, user_id, attempt.exercise_id)
            if s and s.exercise_id == attempt.exercise_id:
                prog = dict(s.progress or {})
                attempts = list(prog.get('attempts') or [])
                prev = None
                for a in attempts:
                    if a.get('attempt_id') == attempt.id:
                        prev = a.get('is_correct')
                        a['is_correct'] = parsed.get('is_correct')
                        break
                prev_correct = 1 if prev == 'correct' else 0
                new_correct = 1 if parsed.get('is_correct') == 'correct' else 0
                correct_inc = new_correct - prev_correct
                prog['attempts'] = attempts
                prog['correct'] = int(prog.get('correct') or 0) + correct_inc
                s.progress = prog
                await db.flush()
        await db.commit()
|
||
|
||
async def _call_llm_chat(self, prompt: str, image_id: int, user_id: int, chat_type: str) -> Dict[str, Any]:
    """Send *prompt* to the configured LLM backend (Qwen or Hunyuan).

    :return: ``{'success': True, 'result': ..., 'token_usage': {...}}`` on
        success, otherwise ``{'success': False, 'error': ...}``. Never raises.
    """
    model_type = (settings.LLM_MODEL_TYPE or "").lower()
    # The system message is only consumed by the Hunyuan path below.
    messages = [{"role": "system", "content": "You are a helpful assistant."}, {'role': 'user', 'content': prompt}]
    # NOTE(review): neither Qwen nor Hunyuan appears in this module's visible
    # imports — confirm both names resolve at runtime.
    if model_type == 'qwen':
        try:
            # Qwen gets only the user turn; no system prompt.
            qres = await Qwen.chat(messages=[{'role': 'user', 'content': prompt}], image_id=image_id, user_id=user_id, api_type=chat_type)
            if qres and qres.get('success'):
                return {"success": True, "result": qres.get("result"), "token_usage": qres.get("token_usage") or {}}
        except Exception as e:
            return {"success": False, "error": str(e)}
        return {"success": False, "error": "LLM call failed"}
    else:
        try:
            res = await Hunyuan.chat(messages=messages, image_id=image_id, user_id=user_id, system_prompt=None, chat_type=chat_type)
            if res and res.get('success'):
                return res
        except Exception as e:
            return {"success": False, "error": str(e)}
        return {"success": False, "error": "LLM call failed"}
|
||
|
||
async def get_attempt_task_status(self, task_id: int) -> Dict[str, Any]:
    """Alias of :meth:`get_task_status`, kept for attempt-centric call sites."""
    status = await self.get_task_status(task_id)
    return status
|
||
|
||
|
||
async def get_question_evaluation(self, question_id: int, user_id: int) -> Dict[str, Any]:
    """Return the newest non-trial evaluation the user has for a question, or ``{}`` if none.

    The result always carries ``session_id`` and ``type``; one section per
    answered mode (choice / cloze / free_text / variation) is added when present.
    """
    async with async_db_session() as db:
        # Trial attempts are excluded so they never pollute normal-mode history.
        record = await qa_attempt_dao.get_latest_completed_by_user_question(
            db, user_id=user_id, question_id=question_id, exclude_trial=True
        )
        if not record:
            record = await qa_attempt_dao.get_latest_valid_by_user_question(
                db, user_id=user_id, question_id=question_id, exclude_trial=True
            )
        if not record:
            return {}

        data = record.evaluation or {}
        sid = data.get('session_id')
        out: Dict[str, Any] = {
            'session_id': str(sid) if sid is not None else None,
            'type': data.get('type'),
        }

        # Each mode stores one payload field plus a nested evaluation dict.
        builders = {
            'choice': lambda d: {'options': d.get('options') or [], 'evaluation': d.get('evaluation') or None},
            'cloze': lambda d: {'input': d.get('input') or '', 'evaluation': d.get('evaluation') or None},
            'free_text': lambda d: {'text': d.get('text') or '', 'evaluation': d.get('evaluation') or None},
            'variation': lambda d: {'file_id': d.get('file_id'), 'evaluation': d.get('evaluation') or None},
        }
        for mode, build in builders.items():
            if mode in data:
                out[mode] = build(data.get(mode) or {})
        return out
|
||
|
||
|
||
async def persist_image_from_url(self, image_url: str, user_id: int, filename: str = "generated_variation.png") -> int:
    """Download an image from *image_url* and persist it to COS-backed file storage.

    :param image_url: reachable URL of the generated image
    :param user_id: owner recorded in the file's details
    :param filename: stored file name; also embedded in the COS object key
    :return: id of the created file record
    :raises Exception: when the HTTP download does not return status 200
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(image_url) as response:
            if response.status != 200:
                raise Exception(f"Failed to download image: {response.status}")
            content = await response.read()

    file_hash = hashlib.sha256(content).hexdigest()
    content_type = "image/png"  # Default to png as per filename default

    # 1. Create DB record first (Pending state: size 0, no storage path yet)
    async with async_db_session.begin() as db:
        meta_init = FileMetadata(
            file_name=filename,
            content_type=content_type,
            file_size=0,
            extra=None,
        )
        t_params = AddFileParam(
            file_hash=file_hash,
            file_name=filename,
            content_type=content_type,
            file_size=0,
            storage_type="cos",
            storage_path=None,
            metadata_info=meta_init,
        )
        t_file = await file_dao.create(db, t_params)
        await db.flush()
        # Capture ID for use outside transaction
        file_id = t_file.id

    # 2. Upload to COS
    # Note: We download the image because COS standard PutObject requires a body (bytes/stream).
    # Direct fetch from URL (AsyncFetch) is asynchronous and not suitable for this synchronous flow.
    cos_client = CosClient()
    # BUG FIX: the key previously embedded the literal text "(unknown)" instead
    # of the real file name, producing extension-less, near-colliding object
    # keys. Use the actual filename so the object key keeps its extension.
    key = f"{file_id}_{filename}"
    cos_client.upload_object(key, content)

    # 3. Update DB record (Completed state: real size + storage key)
    async with async_db_session.begin() as db:
        meta = FileMetadata(
            file_name=filename,
            content_type=content_type,
            file_size=len(content),
            extra=None,
        )
        update_params = UpdateFileParam(
            file_hash=file_hash,
            storage_path=key,
            metadata_info=meta,
            details={
                "key": key,
                "source": "ai_generation",
                "user_id": user_id
            }
        )
        await file_dao.update(db, file_id, update_params)

    return int(file_id)
|
||
|
||
async def generate_scene_variations(self, exercise_id: int, user_id: int, db: AsyncSession = None) -> Tuple[int, Dict[str, Any]]:
    """
    Execute the advanced workflow:
    1. Generate variations text
    2. Generate images
    3. Persist images
    4. Update exercise

    :param db: optional externally-managed session (assumed to be inside a
        transaction); when omitted, each DB phase opens its own transaction.
    :return: ``(number of variation candidates, token_usage)``
    :raises NotFoundError: exercise or image missing
    :raises Exception: text generation failed or produced no variations
    """
    # If db is provided, use it (assumed to be in a transaction).
    # Otherwise create a new transaction.
    # However, to avoid code duplication, we'll implement a context manager helper or just branching logic.

    # Helper to get DB session
    from contextlib import asynccontextmanager

    @asynccontextmanager
    async def get_db():
        if db:
            yield db
        else:
            async with async_db_session.begin() as new_db:
                yield new_db

    async with get_db() as session:
        exercise = await qa_exercise_dao.get(session, exercise_id)
        if not exercise:
            raise errors.NotFoundError(msg='Exercise not found')

        image = await image_dao.get(session, exercise.image_id)
        if not image:
            raise errors.NotFoundError(msg='Image not found')

        # Prepare payload from image details
        rr = (image.details or {}).get('recognition_result') or {}
        payload = {
            'description': rr.get('description'),
            'core_vocab': rr.get('core_vocab'),
            'collocations': rr.get('collocations'),
            'scene_tag': rr.get('scene_tag')
        }

    # Run AI tasks outside transaction (to avoid long holding of DB connection if db was created here)
    # Note: If db was passed in from ImageTaskService, this is technically inside the outer transaction scope,
    # but since we are not executing SQL here, it's just holding the session object.
    gen_res = await SceneVariationGenerator.generate(payload, image.id, user_id)
    if not gen_res.get('success'):
        raise Exception(f"Variation generation failed: {gen_res.get('error')}")

    variations = gen_res.get('result', {}).get('new_descriptions', [])
    token_usage = gen_res.get('token_usage', {})

    if not variations:
        raise Exception("No variations generated")

    # Step 2: Generate images (Parallel)
    variations_with_images = await Illustrator.process_variations(image.file_id, user_id, variations)

    # Step 3: Persist images and update data
    for i, v in enumerate(variations_with_images):
        if v.get('success') and v.get('generated_image_url'):
            try:
                # Construct filename: exercise_{exercise_id}_variation_{image_id}.png
                img_id = v.get('image_id', i + 1)
                filename = f"exercise_{exercise_id}_variation_{img_id}.png"

                file_id = await self.persist_image_from_url(v['generated_image_url'], user_id, filename=filename)
                v['file_id'] = file_id
            except Exception as e:
                # Persist failure is recorded per-variation, not fatal to the batch.
                v['persist_error'] = str(e)

    # Step 4: Update exercise
    async with get_db() as session:
        exercise = await qa_exercise_dao.get(session, exercise_id)
        if not exercise:
            # Should not happen given previous check, but good for safety
            raise errors.NotFoundError(msg='Exercise not found')

        # Create questions from variations
        created = 0
        for v in variations_with_images:
            if v.get('success') and v.get('file_id'):
                await qa_question_dao.create(session, {
                    'exercise_id': exercise.id,
                    'image_id': exercise.image_id,
                    'question': v.get('desc_en') or '',
                    'user_id': user_id,
                    'ext': {
                        'file_id': str(v.get('file_id')),
                        'desc_zh': v.get('desc_zh'),
                        'modification_type': v.get('modification_type'),
                        'modification_point': v.get('modification_point'),
                        'core_vocab': v.get('core_vocab'),
                        'collocation': v.get('collocation'),
                        'learning_note': v.get('learning_note'),
                    },
                })
                created += 1

        # Store the full variation payload on the exercise ext.
        ext = dict(exercise.ext or {})
        ext['new_descriptions'] = variations_with_images
        exercise.ext = ext
        # Explicitly mark the JSON column dirty so SQLAlchemy flushes it.
        from sqlalchemy.orm.attributes import flag_modified
        flag_modified(exercise, "ext")

        exercise.question_count = created
        exercise.status = 'published' if created > 0 else 'draft'
        await session.flush()

        # Bootstrap a practice session for the creator if they have none yet.
        if created > 0:
            existing_session = await qa_session_dao.get_latest_by_user_exercise(session, user_id, exercise.id)
            if not existing_session:
                prog = {'current_index': 0, 'answered': 0, 'correct': 0, 'attempts': [], 'total_questions': created}
                await qa_session_dao.create(session, {
                    'exercise_id': exercise.id,
                    'starter_user_id': user_id,
                    'share_id': None,
                    'status': 'ongoing',
                    'started_at': datetime.now(),
                    'completed_at': None,
                    'progress': prog,
                    'score': None,
                    'ext': None,
                })
                await session.flush()

    return len(variations_with_images), token_usage
|
||
|
||
# Module-level singleton; import this instance rather than constructing QaService directly.
qa_service = QaService()
|