134 lines
6.4 KiB
Python
134 lines
6.4 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
from fastapi import APIRouter, Request, Query
|
|
from backend.app.ai.schema.qa import (
|
|
CreateQaExerciseRequest,
|
|
CreateQaExerciseTaskResponse,
|
|
QaExerciseSchema,
|
|
QaExerciseWithQuestionsSchema,
|
|
QaQuestionSchema,
|
|
QaSessionSchema,
|
|
CreateAttemptRequest,
|
|
TaskStatusResponse,
|
|
QuestionLatestResultResponse,
|
|
ImageConversationInitRequest,
|
|
ImageConversationInitResponse,
|
|
ConversationStartRequest,
|
|
ConversationStartResponse,
|
|
ConversationSessionSchema,
|
|
ConversationReplyRequest,
|
|
ConversationReplyResponse,
|
|
ConversationLatestResponse,
|
|
)
|
|
from backend.common.response.response_schema import response_base, ResponseSchemaModel
|
|
from backend.common.security.jwt import DependsJwtAuth
|
|
from backend.app.ai.service.qa_service import qa_service
|
|
|
|
# Router collecting the image QA / conversation endpoints defined below.
router: APIRouter = APIRouter()
|
|
|
|
|
|
# Ask the QA service to create an exercise task for the given image on behalf of
# the authenticated user, and wrap the service's result mapping in the task
# response schema.
@router.post('/exercises/tasks', summary='创建练习任务', dependencies=[DependsJwtAuth])
async def create_exercise_task(request: Request, obj: CreateQaExerciseRequest) -> ResponseSchemaModel[CreateQaExerciseTaskResponse]:
    task = await qa_service.create_exercise_task(
        image_id=obj.image_id,
        user_id=request.user.id,
        type=obj.type,
    )
    return response_base.success(data=CreateQaExerciseTaskResponse(**task))
|
|
|
|
|
|
# Return the free-conversation setting for an image for the current user.
# When the service returns a falsy result, the response data is None.
# Otherwise its "image_id" and "setting" keys are required, while
# "latest_session" is optional and defaults to None.
@router.post('/conversations/setting', summary='获取图片自由对话配置', dependencies=[DependsJwtAuth])
async def get_conversation_setting(request: Request, obj: ImageConversationInitRequest) -> ResponseSchemaModel[ImageConversationInitResponse | None]:
    setting_info = await qa_service.get_conversation_setting(image_id=obj.image_id, user_id=request.user.id)
    payload = (
        ImageConversationInitResponse(
            image_id=setting_info["image_id"],
            setting=setting_info["setting"],
            latest_session=setting_info.get("latest_session"),
        )
        if setting_info
        else None
    )
    return response_base.success(data=payload)
|
|
|
|
|
|
# Start a free-form conversation about an image. The conversation parameters in
# the request body are forwarded unchanged to the QA service together with the
# authenticated user's id. The returned mapping is wrapped in
# ConversationStartResponse.
@router.post('/conversations/start', summary='启动图片自由对话', dependencies=[DependsJwtAuth])
async def start_conversation(request: Request, obj: ConversationStartRequest) -> ResponseSchemaModel[ConversationStartResponse]:
    started = await qa_service.start_conversation(
        image_id=obj.image_id, user_id=request.user.id,
        scene=obj.scene, event=obj.event, style=obj.style,
        user_role=obj.user_role, assistant_role=obj.assistant_role,
        level=obj.level, info=obj.info,
    )
    return response_base.success(data=ConversationStartResponse(**started))
|
|
|
|
|
|
# Send the user's message (request body field `content`, passed to the service
# as `input_text`) to an existing conversation session and return the reply
# payload.
@router.post('/conversations/{session_id}/reply', summary='回复图片自由对话', dependencies=[DependsJwtAuth])
async def reply_conversation(request: Request, session_id: int, obj: ConversationReplyRequest) -> ResponseSchemaModel[ConversationReplyResponse]:
    user_id = request.user.id
    reply = await qa_service.reply_conversation(session_id=session_id, user_id=user_id, input_text=obj.content)
    return response_base.success(data=ConversationReplyResponse(**reply))
|
|
|
|
|
|
# Fetch the latest messages of a conversation session for the current user.
@router.get('/conversations/{session_id}/latest', summary='获取图片自由对话最新消息', dependencies=[DependsJwtAuth])
async def get_conversation_latest(request: Request, session_id: int) -> ResponseSchemaModel[ConversationLatestResponse]:
    latest = await qa_service.get_latest_messages(
        session_id=session_id,
        user_id=request.user.id,
    )
    payload = ConversationLatestResponse(**latest)
    return response_base.success(data=payload)
|
|
|
|
|
|
# Return the stored information of one conversation session for the current user.
@router.get('/conversations/{session_id}', summary='获取图片自由对话会话信息', dependencies=[DependsJwtAuth])
async def get_conversation_session(request: Request, session_id: int) -> ResponseSchemaModel[ConversationSessionSchema]:
    session = await qa_service.get_conversation_session(
        session_id=session_id,
        user_id=request.user.id,
    )
    return response_base.success(data=ConversationSessionSchema(**session))
|
|
|
|
|
|
# Report the status of an exercise task by id.
# NOTE(review): unlike the other handlers in this router, no user id is passed to
# the service, so ownership of `task_id` is not checked here. Confirm that the
# service (or the id scheme) prevents one user from reading another user's tasks.
@router.get('/exercises/tasks/{task_id}/status', summary='查询练习任务状态', dependencies=[DependsJwtAuth])
async def get_exercise_task_status(task_id: int) -> ResponseSchemaModel[TaskStatusResponse]:
    status = await qa_service.get_task_status(task_id)
    payload = TaskStatusResponse(**status)
    return response_base.success(data=payload)
|
|
|
|
|
|
# Return the current user's exercise (with its questions) for an image.
# Parameters:
#   image_id: path parameter identifying the image.
#   type: optional `?type=` query filter forwarded to the service. It defaults to
#         None, so it is annotated `str | None`. The name shadows the builtin but
#         must stay because it is the public query-string key.
# Returns data=None when the service finds nothing (falsy result).
@router.get('/{image_id}/exercises', summary='根据图片获取练习', dependencies=[DependsJwtAuth])
async def list_exercises(
    request: Request,
    image_id: int,
    type: str | None = Query(None),  # noqa: A002 - public query parameter name
) -> ResponseSchemaModel[QaExerciseWithQuestionsSchema | None]:
    item = await qa_service.list_exercises_by_image(image_id, user_id=request.user.id, type=type)
    if not item:
        return response_base.success(data=None)
    return response_base.success(data=QaExerciseWithQuestionsSchema(**item))
|
|
|
|
|
|
# Submit an answer attempt for a question and return the resulting evaluation.
# Every answer-related field of the body is forwarded as-is. Which of them the
# service actually uses presumably depends on `mode` (TODO confirm in qa_service).
@router.post('/questions/{question_id}/attempts', summary='提交题目练习', dependencies=[DependsJwtAuth])
async def submit_attempt(request: Request, question_id: int, obj: CreateAttemptRequest) -> ResponseSchemaModel[QuestionLatestResultResponse]:
    result = await qa_service.submit_attempt(
        question_id=question_id, exercise_id=obj.exercise_id, user_id=request.user.id,
        mode=obj.mode,
        selected_options=obj.selected_options, input_text=obj.input_text, cloze_options=obj.cloze_options,
        session_id=obj.session_id, is_trial=obj.is_trial,
    )
    return response_base.success(data=QuestionLatestResultResponse(**result))
|
|
|
|
|
|
# Report the status of a question-attempt task by id.
# NOTE(review): as with the exercise task status endpoint, no user id reaches the
# service, so ownership of `task_id` is not verified here. Confirm this is intended.
@router.get('/question-tasks/{task_id}/status', summary='获取题目任务状态', dependencies=[DependsJwtAuth])
async def get_question_task_status(task_id: int) -> ResponseSchemaModel[TaskStatusResponse]:
    status = await qa_service.get_attempt_task_status(task_id)
    payload = TaskStatusResponse(**status)
    return response_base.success(data=payload)
|
|
|
|
|
|
# Return the current user's latest evaluation for a question.
@router.get('/questions/{question_id}/result', summary='获取题目最新结果', dependencies=[DependsJwtAuth])
async def get_question_latest_result(request: Request, question_id: int) -> ResponseSchemaModel[QuestionLatestResultResponse]:
    evaluation = await qa_service.get_question_evaluation(question_id, user_id=request.user.id)
    payload = QuestionLatestResultResponse(**evaluation)
    return response_base.success(data=payload)
|
|
|
|
# Return the file id of a question's reference audio for the current user.
# The id is serialized as a string (likely so large integer ids survive JSON
# clients; TODO confirm). A missing id is returned as null instead of the
# literal string 'None'.
@router.get('/questions/{question_id}/audio', summary='获取题目标准音频', dependencies=[DependsJwtAuth])
async def get_question_audio(request: Request, question_id: int) -> ResponseSchemaModel[dict]:
    # Imported lazily, presumably to avoid an import cycle at module load. TODO confirm.
    from backend.app.ai.service.recording_service import RecordingService

    file_id = await RecordingService.get_question_audio_file_id(question_id=question_id, user_id=request.user.id)
    # str(None) would yield 'None', which clients could mistake for a real id.
    return response_base.success(data={'file_id': None if file_id is None else str(file_id)})
|