This commit is contained in:
Felix
2026-01-23 11:48:08 +08:00
parent 7ade571e13
commit 898a7e902b
4 changed files with 163 additions and 32 deletions

View File

@@ -19,6 +19,7 @@ from backend.app.ai.schema.qa import (
ConversationReplyRequest,
ConversationReplyResponse,
ConversationLatestResponse,
ConversationListResponse,
)
from backend.common.response.response_schema import response_base, ResponseSchemaModel
from backend.common.security.jwt import DependsJwtAuth
@@ -51,11 +52,11 @@ async def start_conversation(request: Request, obj: ConversationStartRequest) ->
res = await qa_service.start_conversation(
image_id=obj.image_id,
user_id=request.user.id,
scene=obj.scene,
event=obj.event,
style=obj.style,
user_role=obj.user_role,
assistant_role=obj.assistant_role,
scene=[item.model_dump() for item in obj.scene],
event=[item.model_dump() for item in obj.event],
style=obj.style.model_dump() if obj.style else None,
user_role=obj.user_role.model_dump() if obj.user_role else None,
assistant_role=obj.assistant_role.model_dump() if obj.assistant_role else None,
level=obj.level,
info=obj.info,
)
@@ -73,6 +74,17 @@ async def reply_conversation(request: Request, session_id: int, obj: Conversatio
return response_base.success(data=ConversationReplyResponse(**res))
@router.get('/conversations/{image_id}/list', summary='获取图片自由对话列表', dependencies=[DependsJwtAuth])
async def list_conversations(request: Request, image_id: int, page: int = Query(1, ge=1), page_size: int = Query(10, ge=1, le=100)) -> ResponseSchemaModel[ConversationListResponse]:
res = await qa_service.list_conversations_by_image(
image_id=image_id,
user_id=request.user.id,
page=page,
page_size=page_size,
)
return response_base.success(data=ConversationListResponse(**res))
@router.get('/conversations/{session_id}/latest', summary='获取图片自由对话最新消息', dependencies=[DependsJwtAuth])
async def get_conversation_latest(request: Request, session_id: int) -> ResponseSchemaModel[ConversationLatestResponse]:
res = await qa_service.get_latest_messages(session_id=session_id, user_id=request.user.id)

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Optional, List
from sqlalchemy import select, and_
from typing import Optional, List, Tuple
from sqlalchemy import select, and_, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus
from backend.app.ai.model.qa import QaExercise, QaQuestion, QaQuestionAttempt, QaPracticeSession
@@ -156,6 +156,35 @@ class QaPracticeSessionCRUD(CRUDPlus[QaPracticeSession]):
result = await db.execute(stmt)
return result.scalars().first()
async def get_list_by_image(
self,
db: AsyncSession,
image_id: int,
user_id: int,
page: int = 1,
page_size: int = 10
) -> Tuple[int, List[Tuple[QaPracticeSession, QaExercise]]]:
stmt = (
select(QaPracticeSession, QaExercise)
.join(QaExercise, QaPracticeSession.exercise_id == QaExercise.id)
.where(
and_(
QaPracticeSession.starter_user_id == user_id,
QaExercise.image_id == image_id,
QaExercise.type == 'free_conversation'
)
)
)
# Count
count_stmt = select(func.count()).select_from(stmt.subquery())
total = await db.scalar(count_stmt) or 0
# Pagination
stmt = stmt.order_by(QaPracticeSession.id.desc()).offset((page - 1) * page_size).limit(page_size)
result = await db.execute(stmt)
return total, result.all()
qa_session_dao = QaPracticeSessionCRUD(QaPracticeSession)
qa_attempt_dao = QaQuestionAttemptCRUD(QaQuestionAttempt)

View File

@@ -23,6 +23,7 @@ class QaExerciseSchema(SchemaBase):
description: Optional[str] = None
status: str
question_count: int
ext: Optional[Dict[str, Any]] = None
class QaQuestionSchema(SchemaBase):
@@ -172,13 +173,18 @@ class ImageConversationInitResponse(SchemaBase):
latest_session: Optional[Dict[str, Any]] = None
class BilingualItem(SchemaBase):
en: str
zh: str
class ConversationStartRequest(SchemaBase):
image_id: int
scene: List[str]
event: List[str]
style: Optional[str] = None
user_role: Optional[str] = None
assistant_role: Optional[str] = None
scene: List[BilingualItem]
event: List[BilingualItem]
style: Optional[BilingualItem] = None
user_role: Optional[BilingualItem] = None
assistant_role: Optional[BilingualItem] = None
level: Optional[str] = None
info: Optional[str] = None
@@ -213,6 +219,7 @@ class ConversationStartResponse(SchemaBase):
task_id: str
status: str
exercise_id: Optional[str] = None
session_id: Optional[str] = None
class ConversationReplyRequest(SchemaBase):
@@ -239,6 +246,22 @@ class ConversationSessionSchema(SchemaBase):
messages: List[ConversationMessageSchema] = []
class ConversationListItemSchema(SchemaBase):
session_id: str
scene: List[BilingualItem]
event: List[BilingualItem]
user_role: Optional[BilingualItem] = None
style: Optional[BilingualItem] = None
created_at: Optional[str] = None
class ConversationListResponse(SchemaBase):
total: int
items: List[ConversationListItemSchema]
page: int
page_size: int
CreateAttemptTaskResponse.model_rebuild()
AttemptResultResponse.model_rebuild()
QuestionEvaluationResponse.model_rebuild()

View File

@@ -285,12 +285,24 @@ class ConversationStartProcessor(TaskProcessor):
description = ''
params = exercise.ext or {}
# Helper to extract 'en' from bilingual items
def get_en_list(items):
if not items:
return []
return [item.get('en') for item in items if isinstance(item, dict) and item.get('en')]
def get_en_item(item):
if not item or not isinstance(item, dict):
return None
return item.get('en')
prompt = get_free_conversation_start_prompt(
scene=params.get('scene'),
event=params.get('event'),
user_role=params.get('user_role'),
assistant_role=params.get('assistant_role'),
style=params.get('style'),
scene=get_en_list(params.get('scene')),
event=get_en_list(params.get('event')),
user_role=get_en_item(params.get('user_role')),
assistant_role=get_en_item(params.get('assistant_role')),
style=get_en_item(params.get('style')),
level=params.get('level'),
info=params.get('info'),
description=description,
@@ -476,14 +488,25 @@ class ConversationReplyProcessor(TaskProcessor):
params = exercise.ext or {}
# Helper to extract 'en' from bilingual items
def get_en_list(items):
if not items:
return []
return [item.get('en') for item in items if isinstance(item, dict) and item.get('en')]
def get_en_item(item):
if not item or not isinstance(item, dict):
return None
return item.get('en')
prompt = get_free_conversation_reply_prompt(
history=history,
user_input=user_input,
scene=params.get('scene'),
event=params.get('event'),
user_role=params.get('user_role'),
assistant_role=params.get('assistant_role'),
style=params.get('style'),
scene=get_en_list(params.get('scene')),
event=get_en_list(params.get('event')),
user_role=get_en_item(params.get('user_role')),
assistant_role=get_en_item(params.get('assistant_role')),
style=get_en_item(params.get('style')),
level=params.get('level'),
info=params.get('info'),
description=description,
@@ -611,11 +634,11 @@ class QaService:
self,
image_id: int,
user_id: int,
scene: List[str],
event: List[str],
style: Optional[str] = None,
user_role: Optional[str] = None,
assistant_role: Optional[str] = None,
scene: List[Dict[str, str]],
event: List[Dict[str, str]],
style: Optional[Dict[str, str]] = None,
user_role: Optional[Dict[str, str]] = None,
assistant_role: Optional[Dict[str, str]] = None,
level: Optional[str] = None,
info: Optional[str] = None,
) -> Dict[str, Any]:
@@ -751,12 +774,15 @@ class QaService:
questions = await qa_question_dao.get_by_exercise_id(db, exercise_id)
messages = []
for q in questions:
# AI Message
total_questions = len(questions)
for idx, q in enumerate(questions):
is_last = (idx == total_questions - 1)
ext = q.ext or {}
messages.append({
"role": "assistant",
"content": {
# AI Message
if is_last:
# Full content for the last message
content = {
"response_en": q.question,
"response_zh": ext.get("response_zh"),
"prompt_en": ext.get("prompt_en"),
@@ -764,6 +790,16 @@ class QaService:
"alternative_responses": ext.get("alternative_responses"),
"correction": ext.get("correction"),
}
else:
# Simplified content for historical messages
content = {
"response_en": q.question,
"response_zh": ext.get("response_zh"),
}
messages.append({
"role": "assistant",
"content": content
})
# User Reply (Attempt)
@@ -824,6 +860,37 @@ class QaService:
"messages": messages,
}
async def list_conversations_by_image(
self,
image_id: int,
user_id: int,
page: int = 1,
page_size: int = 10
) -> Dict[str, Any]:
async with async_db_session() as db:
total, results = await qa_session_dao.get_list_by_image(
db, image_id, user_id, page, page_size
)
items = []
for session, exercise in results:
ext = exercise.ext or {}
items.append({
"session_id": str(session.id),
"scene": ext.get("scene") or [],
"event": ext.get("event") or [],
"user_role": ext.get("user_role") or {},
"style": ext.get("style") or {},
"created_at": (session.created_time.strftime("%Y-%m-%d %H:%M:%S") if session.created_time else None),
})
return {
"total": total,
"items": items,
"page": page,
"page_size": page_size,
}
async def create_exercise_task(self, image_id: int, user_id: int, type: Optional[str] = "scene_basic") -> Dict[str, Any]:
is_conversation_init = type == 'init_conversion'