fix code
This commit is contained in:
@@ -19,6 +19,7 @@ from backend.app.ai.schema.qa import (
|
||||
ConversationReplyRequest,
|
||||
ConversationReplyResponse,
|
||||
ConversationLatestResponse,
|
||||
ConversationListResponse,
|
||||
)
|
||||
from backend.common.response.response_schema import response_base, ResponseSchemaModel
|
||||
from backend.common.security.jwt import DependsJwtAuth
|
||||
@@ -51,11 +52,11 @@ async def start_conversation(request: Request, obj: ConversationStartRequest) ->
|
||||
res = await qa_service.start_conversation(
|
||||
image_id=obj.image_id,
|
||||
user_id=request.user.id,
|
||||
scene=obj.scene,
|
||||
event=obj.event,
|
||||
style=obj.style,
|
||||
user_role=obj.user_role,
|
||||
assistant_role=obj.assistant_role,
|
||||
scene=[item.model_dump() for item in obj.scene],
|
||||
event=[item.model_dump() for item in obj.event],
|
||||
style=obj.style.model_dump() if obj.style else None,
|
||||
user_role=obj.user_role.model_dump() if obj.user_role else None,
|
||||
assistant_role=obj.assistant_role.model_dump() if obj.assistant_role else None,
|
||||
level=obj.level,
|
||||
info=obj.info,
|
||||
)
|
||||
@@ -73,6 +74,17 @@ async def reply_conversation(request: Request, session_id: int, obj: Conversatio
|
||||
return response_base.success(data=ConversationReplyResponse(**res))
|
||||
|
||||
|
||||
@router.get('/conversations/{image_id}/list', summary='获取图片自由对话列表', dependencies=[DependsJwtAuth])
|
||||
async def list_conversations(request: Request, image_id: int, page: int = Query(1, ge=1), page_size: int = Query(10, ge=1, le=100)) -> ResponseSchemaModel[ConversationListResponse]:
|
||||
res = await qa_service.list_conversations_by_image(
|
||||
image_id=image_id,
|
||||
user_id=request.user.id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return response_base.success(data=ConversationListResponse(**res))
|
||||
|
||||
|
||||
@router.get('/conversations/{session_id}/latest', summary='获取图片自由对话最新消息', dependencies=[DependsJwtAuth])
|
||||
async def get_conversation_latest(request: Request, session_id: int) -> ResponseSchemaModel[ConversationLatestResponse]:
|
||||
res = await qa_service.get_latest_messages(session_id=session_id, user_id=request.user.id)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Optional, List
|
||||
from sqlalchemy import select, and_
|
||||
from typing import Optional, List, Tuple
|
||||
from sqlalchemy import select, and_, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy_crud_plus import CRUDPlus
|
||||
from backend.app.ai.model.qa import QaExercise, QaQuestion, QaQuestionAttempt, QaPracticeSession
|
||||
@@ -156,6 +156,35 @@ class QaPracticeSessionCRUD(CRUDPlus[QaPracticeSession]):
|
||||
result = await db.execute(stmt)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_list_by_image(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
image_id: int,
|
||||
user_id: int,
|
||||
page: int = 1,
|
||||
page_size: int = 10
|
||||
) -> Tuple[int, List[Tuple[QaPracticeSession, QaExercise]]]:
|
||||
stmt = (
|
||||
select(QaPracticeSession, QaExercise)
|
||||
.join(QaExercise, QaPracticeSession.exercise_id == QaExercise.id)
|
||||
.where(
|
||||
and_(
|
||||
QaPracticeSession.starter_user_id == user_id,
|
||||
QaExercise.image_id == image_id,
|
||||
QaExercise.type == 'free_conversation'
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Count
|
||||
count_stmt = select(func.count()).select_from(stmt.subquery())
|
||||
total = await db.scalar(count_stmt) or 0
|
||||
|
||||
# Pagination
|
||||
stmt = stmt.order_by(QaPracticeSession.id.desc()).offset((page - 1) * page_size).limit(page_size)
|
||||
result = await db.execute(stmt)
|
||||
return total, result.all()
|
||||
|
||||
|
||||
qa_session_dao = QaPracticeSessionCRUD(QaPracticeSession)
|
||||
qa_attempt_dao = QaQuestionAttemptCRUD(QaQuestionAttempt)
|
||||
|
||||
@@ -23,6 +23,7 @@ class QaExerciseSchema(SchemaBase):
|
||||
description: Optional[str] = None
|
||||
status: str
|
||||
question_count: int
|
||||
ext: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class QaQuestionSchema(SchemaBase):
|
||||
@@ -172,13 +173,18 @@ class ImageConversationInitResponse(SchemaBase):
|
||||
latest_session: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class BilingualItem(SchemaBase):
|
||||
en: str
|
||||
zh: str
|
||||
|
||||
|
||||
class ConversationStartRequest(SchemaBase):
|
||||
image_id: int
|
||||
scene: List[str]
|
||||
event: List[str]
|
||||
style: Optional[str] = None
|
||||
user_role: Optional[str] = None
|
||||
assistant_role: Optional[str] = None
|
||||
scene: List[BilingualItem]
|
||||
event: List[BilingualItem]
|
||||
style: Optional[BilingualItem] = None
|
||||
user_role: Optional[BilingualItem] = None
|
||||
assistant_role: Optional[BilingualItem] = None
|
||||
level: Optional[str] = None
|
||||
info: Optional[str] = None
|
||||
|
||||
@@ -213,6 +219,7 @@ class ConversationStartResponse(SchemaBase):
|
||||
task_id: str
|
||||
status: str
|
||||
exercise_id: Optional[str] = None
|
||||
session_id: Optional[str] = None
|
||||
|
||||
|
||||
class ConversationReplyRequest(SchemaBase):
|
||||
@@ -239,6 +246,22 @@ class ConversationSessionSchema(SchemaBase):
|
||||
messages: List[ConversationMessageSchema] = []
|
||||
|
||||
|
||||
class ConversationListItemSchema(SchemaBase):
|
||||
session_id: str
|
||||
scene: List[BilingualItem]
|
||||
event: List[BilingualItem]
|
||||
user_role: Optional[BilingualItem] = None
|
||||
style: Optional[BilingualItem] = None
|
||||
created_at: Optional[str] = None
|
||||
|
||||
|
||||
class ConversationListResponse(SchemaBase):
|
||||
total: int
|
||||
items: List[ConversationListItemSchema]
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
|
||||
CreateAttemptTaskResponse.model_rebuild()
|
||||
AttemptResultResponse.model_rebuild()
|
||||
QuestionEvaluationResponse.model_rebuild()
|
||||
|
||||
@@ -285,12 +285,24 @@ class ConversationStartProcessor(TaskProcessor):
|
||||
description = ''
|
||||
|
||||
params = exercise.ext or {}
|
||||
|
||||
# Helper to extract 'en' from bilingual items
|
||||
def get_en_list(items):
|
||||
if not items:
|
||||
return []
|
||||
return [item.get('en') for item in items if isinstance(item, dict) and item.get('en')]
|
||||
|
||||
def get_en_item(item):
|
||||
if not item or not isinstance(item, dict):
|
||||
return None
|
||||
return item.get('en')
|
||||
|
||||
prompt = get_free_conversation_start_prompt(
|
||||
scene=params.get('scene'),
|
||||
event=params.get('event'),
|
||||
user_role=params.get('user_role'),
|
||||
assistant_role=params.get('assistant_role'),
|
||||
style=params.get('style'),
|
||||
scene=get_en_list(params.get('scene')),
|
||||
event=get_en_list(params.get('event')),
|
||||
user_role=get_en_item(params.get('user_role')),
|
||||
assistant_role=get_en_item(params.get('assistant_role')),
|
||||
style=get_en_item(params.get('style')),
|
||||
level=params.get('level'),
|
||||
info=params.get('info'),
|
||||
description=description,
|
||||
@@ -476,14 +488,25 @@ class ConversationReplyProcessor(TaskProcessor):
|
||||
|
||||
params = exercise.ext or {}
|
||||
|
||||
# Helper to extract 'en' from bilingual items
|
||||
def get_en_list(items):
|
||||
if not items:
|
||||
return []
|
||||
return [item.get('en') for item in items if isinstance(item, dict) and item.get('en')]
|
||||
|
||||
def get_en_item(item):
|
||||
if not item or not isinstance(item, dict):
|
||||
return None
|
||||
return item.get('en')
|
||||
|
||||
prompt = get_free_conversation_reply_prompt(
|
||||
history=history,
|
||||
user_input=user_input,
|
||||
scene=params.get('scene'),
|
||||
event=params.get('event'),
|
||||
user_role=params.get('user_role'),
|
||||
assistant_role=params.get('assistant_role'),
|
||||
style=params.get('style'),
|
||||
scene=get_en_list(params.get('scene')),
|
||||
event=get_en_list(params.get('event')),
|
||||
user_role=get_en_item(params.get('user_role')),
|
||||
assistant_role=get_en_item(params.get('assistant_role')),
|
||||
style=get_en_item(params.get('style')),
|
||||
level=params.get('level'),
|
||||
info=params.get('info'),
|
||||
description=description,
|
||||
@@ -611,11 +634,11 @@ class QaService:
|
||||
self,
|
||||
image_id: int,
|
||||
user_id: int,
|
||||
scene: List[str],
|
||||
event: List[str],
|
||||
style: Optional[str] = None,
|
||||
user_role: Optional[str] = None,
|
||||
assistant_role: Optional[str] = None,
|
||||
scene: List[Dict[str, str]],
|
||||
event: List[Dict[str, str]],
|
||||
style: Optional[Dict[str, str]] = None,
|
||||
user_role: Optional[Dict[str, str]] = None,
|
||||
assistant_role: Optional[Dict[str, str]] = None,
|
||||
level: Optional[str] = None,
|
||||
info: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
@@ -751,12 +774,15 @@ class QaService:
|
||||
questions = await qa_question_dao.get_by_exercise_id(db, exercise_id)
|
||||
|
||||
messages = []
|
||||
for q in questions:
|
||||
# AI Message
|
||||
total_questions = len(questions)
|
||||
for idx, q in enumerate(questions):
|
||||
is_last = (idx == total_questions - 1)
|
||||
ext = q.ext or {}
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": {
|
||||
|
||||
# AI Message
|
||||
if is_last:
|
||||
# Full content for the last message
|
||||
content = {
|
||||
"response_en": q.question,
|
||||
"response_zh": ext.get("response_zh"),
|
||||
"prompt_en": ext.get("prompt_en"),
|
||||
@@ -764,6 +790,16 @@ class QaService:
|
||||
"alternative_responses": ext.get("alternative_responses"),
|
||||
"correction": ext.get("correction"),
|
||||
}
|
||||
else:
|
||||
# Simplified content for historical messages
|
||||
content = {
|
||||
"response_en": q.question,
|
||||
"response_zh": ext.get("response_zh"),
|
||||
}
|
||||
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": content
|
||||
})
|
||||
|
||||
# User Reply (Attempt)
|
||||
@@ -824,6 +860,37 @@ class QaService:
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
async def list_conversations_by_image(
|
||||
self,
|
||||
image_id: int,
|
||||
user_id: int,
|
||||
page: int = 1,
|
||||
page_size: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
async with async_db_session() as db:
|
||||
total, results = await qa_session_dao.get_list_by_image(
|
||||
db, image_id, user_id, page, page_size
|
||||
)
|
||||
|
||||
items = []
|
||||
for session, exercise in results:
|
||||
ext = exercise.ext or {}
|
||||
items.append({
|
||||
"session_id": str(session.id),
|
||||
"scene": ext.get("scene") or [],
|
||||
"event": ext.get("event") or [],
|
||||
"user_role": ext.get("user_role") or {},
|
||||
"style": ext.get("style") or {},
|
||||
"created_at": (session.created_time.strftime("%Y-%m-%d %H:%M:%S") if session.created_time else None),
|
||||
})
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"items": items,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
|
||||
async def create_exercise_task(self, image_id: int, user_id: int, type: Optional[str] = "scene_basic") -> Dict[str, Any]:
|
||||
|
||||
is_conversation_init = type == 'init_conversion'
|
||||
|
||||
Reference in New Issue
Block a user