add recognize_audio

This commit is contained in:
Felix
2026-01-27 10:21:18 +08:00
parent 898a7e902b
commit 2a7e07c2dc
4 changed files with 216 additions and 2 deletions

View File

@@ -20,6 +20,8 @@ from backend.app.ai.schema.qa import (
ConversationReplyResponse,
ConversationLatestResponse,
ConversationListResponse,
ConversationRecognitionRequest,
ConversationRecognitionResponse,
)
from backend.common.response.response_schema import response_base, ResponseSchemaModel
from backend.common.security.jwt import DependsJwtAuth
@@ -74,6 +76,12 @@ async def reply_conversation(request: Request, session_id: int, obj: Conversatio
return response_base.success(data=ConversationReplyResponse(**res))
@router.post('/conversations/{session_id}/recognize_audio', summary='识别音频内容', dependencies=[DependsJwtAuth])
async def recognize_audio(request: Request, session_id: int, obj: ConversationRecognitionRequest) -> ResponseSchemaModel[ConversationRecognitionResponse]:
    """Recognize the content of an uploaded audio file and attach it to the session.

    `session_id` comes from the path (it was missing from the signature, so the
    body's `session_id` reference raised NameError); `obj.file_id` is the id of
    the previously uploaded audio file.
    """
    res = await qa_service.recognize_audio(file_id=int(obj.file_id), user_id=request.user.id, session_id=session_id)
    return response_base.success(data=ConversationRecognitionResponse(**res))
@router.get('/conversations/{image_id}/list', summary='获取图片自由对话列表', dependencies=[DependsJwtAuth])
async def list_conversations(request: Request, image_id: int, page: int = Query(1, ge=1), page_size: int = Query(10, ge=1, le=100)) -> ResponseSchemaModel[ConversationListResponse]:
res = await qa_service.list_conversations_by_image(

View File

@@ -180,8 +180,8 @@ class BilingualItem(SchemaBase):
class ConversationStartRequest(SchemaBase):
    """Payload to start a free conversation about an image.

    All context hints (scene, event, style, roles) are optional bilingual items.
    The earlier required `scene`/`event` declarations were dead code — they were
    immediately shadowed by the Optional redefinitions below — so they are removed.
    """
    image_id: int
    scene: Optional[List[BilingualItem]] = None
    event: Optional[List[BilingualItem]] = None
    style: Optional[BilingualItem] = None
    user_role: Optional[BilingualItem] = None
    assistant_role: Optional[BilingualItem] = None
@@ -233,6 +233,14 @@ class ConversationReplyResponse(SchemaBase):
session_id: Optional[str] = None
class ConversationRecognitionRequest(SchemaBase):
    """Request body for audio recognition."""
    # Id of the uploaded audio file; sent as a string and cast to int by the router.
    file_id: str
class ConversationRecognitionResponse(SchemaBase):
    """Response body for audio recognition."""
    # Recognized transcript of the audio.
    text: str
class ConversationLatestResponse(SchemaBase):
    """Latest-state snapshot of a conversation: its session id and message list."""
    session_id: str
    messages: List[ConversationMessageSchema]

View File

@@ -891,6 +891,87 @@ class QaService:
"page_size": page_size,
}
async def recognize_audio(self, file_id: int, user_id: int, session_id: int) -> Dict[str, Any]:
    """Recognize an uploaded audio file via Qwen ASR and attach the transcript to a session.

    :param file_id: id of the uploaded audio file
    :param user_id: id of the requesting user; must own the session
    :param session_id: id of the QA session the audio belongs to
    :return: ``{"text": <recognized text>}``
    :raises errors.NotFoundError: session or file does not exist
    :raises errors.ForbiddenError: the session is not owned by ``user_id``
    :raises errors.ServerError: the ASR call reported failure
    """
    # Hoisted imports (the original imported os twice and tempfile mid-body).
    import os
    import tempfile

    from backend.middleware.qwen import Qwen

    # 1. Validate the session and remember its exercise id for step 3.
    exercise_id = None
    async with async_db_session.begin() as db:
        session = await qa_session_dao.get(db, session_id)
        if not session:
            raise errors.NotFoundError(msg="Session not found")
        if session.starter_user_id != user_id:
            raise errors.ForbiddenError(msg="Forbidden")
        exercise_id = session.exercise_id

    # 2. Resolve the audio to something Qwen can read (URL or local path).
    file_obj = await file_service.get_file(file_id)
    if not file_obj:
        raise errors.NotFoundError(msg="文件不存在")

    audio_url = ""
    temp_file_path = None
    try:
        if file_obj.storage_type == "cos":
            # Object storage: hand Qwen a presigned download URL.
            audio_url = await file_service.get_presigned_download_url(file_id, user_id)
        else:
            # Database storage: spill the bytes to a temp file. Keep a file
            # extension so Qwen can detect the audio format.
            content, _, _ = await file_service.download_file(file_id)
            file_ext = "mp3"
            if file_obj.content_type:
                file_ext = file_service._mime_to_ext(file_obj.content_type, file_obj.file_name)
            tf = tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_ext}")
            tf.write(content)
            tf.close()
            temp_file_path = tf.name
            audio_url = temp_file_path

        # Call Qwen ASR.
        res = await Qwen.recognize_speech(audio_url, user_id=user_id)
        if not res.get("success"):
            raise errors.ServerError(msg=res.get("error") or "ASR failed")
        text = res.get("text", "")

        # 3. Append the recognition to the latest LLM question's ext payload.
        #    (question_ext renamed from `ext`, which shadowed the file-extension
        #    variable in the original.)
        async with async_db_session.begin() as db:
            last_question = await qa_question_dao.get_latest_by_exercise_id(db, exercise_id)
            if last_question:
                question_ext = dict(last_question.ext or {})
                audio_recognitions = question_ext.get('audio_recognitions')
                if not isinstance(audio_recognitions, list):
                    audio_recognitions = []
                audio_recognitions.append({
                    'audio_id': str(file_id),
                    'text': text
                })
                question_ext['audio_recognitions'] = audio_recognitions
                last_question.ext = question_ext
                await db.flush()
        return {"text": text}
    finally:
        # Best-effort cleanup of the temp file; never mask the in-flight result.
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
            except OSError:
                pass
async def create_exercise_task(self, image_id: int, user_id: int, type: Optional[str] = "scene_basic") -> Dict[str, Any]:
is_conversation_init = type == 'init_conversion'

View File

@@ -316,6 +316,123 @@ class Qwen:
input=[{'image': image_data}]
)
@staticmethod
async def recognize_speech(file_path: str, user_id: int = 0) -> Dict[str, Any]:
    """
    Recognize speech in an audio file via the Qwen (DashScope) ASR API.

    :param file_path: audio location — an http(s) URL or a local file path
    :param user_id: id of the requesting user, recorded in the audit log
    :return: on success ``{"success": True, "text": ..., "token_usage": ...}``;
             on API error ``{"success": False, "error": ..., "status_code": ...}``;
             on exception ``{"success": False, "error": ...}``
    """
    api_key = _get_primary_qwen_api_key()
    model_name = "qwen3-asr-flash"
    start_time = time.time()
    start_at = datetime.now()
    error_message = ""
    status_code = 500
    response_data: Dict[str, Any] = {}
    try:
        # Local paths are passed with the file:// scheme; the DashScope SDK
        # resolves file:// URIs itself.
        audio_url = file_path if file_path.startswith("http") else f"file://{file_path}"
        messages = [
            {"role": "system", "content": [{"text": ""}]},
            {"role": "user", "content": [{"audio": audio_url}]}
        ]
        # NOTE(review): asyncio.get_event_loop() is deprecated inside coroutines
        # since Python 3.10 — consider asyncio.get_running_loop().
        loop = asyncio.get_event_loop()
        # The DashScope call is blocking, so run it on the shared executor.
        response = await loop.run_in_executor(
            Qwen._executor,
            lambda: dashscope.MultiModalConversation.call(
                api_key=api_key,
                model=model_name,
                messages=messages,
                result_format="message",
                asr_options={
                    "language": "en",
                    "enable_itn": False
                }
            )
        )
        # DashScope may expose the outcome via status_code or code; fall back
        # accordingly. NOTE(review): `code` can be a string error code, which
        # would never equal the int 200 below — confirm against the SDK.
        status_code = getattr(response, 'status_code', getattr(response, 'code', 500))
        response_data = {
            "output": getattr(response, 'output', None),
            "usage": getattr(response, 'usage', {}),
            "code": getattr(response, 'code', None),
            "message": getattr(response, 'message', None)
        }
        duration = time.time() - start_time
        # Audit the call (runs for both success and API-level error paths).
        audit_log = CreateAuditLogParam(
            api_type="asr",
            model_name=model_name,
            response_data=response_data,
            request_data={"audio": audio_url},
            token_usage=response_data.get("usage", {}),
            duration=duration,
            status_code=status_code,
            error_message=None,
            called_at=start_at,
            image_id=0,  # ASR is not tied to a specific image
            user_id=user_id,
            cost=0,
            api_version=settings.FASTAPI_API_V1_PATH,
            dict_level=None,
        )
        Qwen._audit_log("asr", audit_log)
        if status_code == 200:
            content = ""
            # Expected shape: output.choices[0].message.content[0]["text"];
            # fall back to the raw content field for other response structures.
            try:
                content = response.output.choices[0].message.content[0]["text"]
            except Exception:
                try:
                    content = response.output.choices[0].message.content
                except Exception:
                    pass
            return {
                "success": True,
                "text": content,
                "token_usage": response_data.get("usage", {})
            }
        else:
            error_message = response_data.get("message") or "API error"
            logger.error(f"ASR API error: {status_code} - {error_message}")
            return {
                "success": False,
                "error": error_message,
                "status_code": status_code
            }
    except Exception as e:
        error_message = str(e)
        logger.exception(f"ASR API exception: {error_message}")
        return {
            "success": False,
            "error": error_message
        }
    finally:
        # NOTE(review): an API-level error (status != 200) was already audited
        # above via Qwen._audit_log, so this records it a second time; also the
        # helper here is _log_audit vs _audit_log on the success path — confirm
        # which is canonical and whether the double write is intended.
        if error_message:
            Qwen._log_audit(
                api_type="asr",
                dict_level=None,
                model_name=model_name,
                request_data={"audio": file_path},
                response_data={"error": error_message},
                duration=time.time() - start_time,
                status_code=status_code,
                error_message=error_message,
                image_id=0,
                user_id=user_id,
                called_at=start_at
            )
@staticmethod
async def _call_api(api_type: str, image_id: int, user_id: int, input: Union[Dict, List], dict_level: str | None = None) -> \
Dict[str, Any]: