add recognize_audio
This commit is contained in:
@@ -20,6 +20,8 @@ from backend.app.ai.schema.qa import (
|
||||
ConversationReplyResponse,
|
||||
ConversationLatestResponse,
|
||||
ConversationListResponse,
|
||||
ConversationRecognitionRequest,
|
||||
ConversationRecognitionResponse,
|
||||
)
|
||||
from backend.common.response.response_schema import response_base, ResponseSchemaModel
|
||||
from backend.common.security.jwt import DependsJwtAuth
|
||||
@@ -74,6 +76,12 @@ async def reply_conversation(request: Request, session_id: int, obj: Conversatio
|
||||
return response_base.success(data=ConversationReplyResponse(**res))
|
||||
|
||||
|
||||
@router.post('/conversations/{session_id}/recognize_audio', summary='识别音频内容', dependencies=[DependsJwtAuth])
async def recognize_audio(
    request: Request, session_id: int, obj: ConversationRecognitionRequest
) -> ResponseSchemaModel[ConversationRecognitionResponse]:
    """Recognize the content of an uploaded audio file within a conversation session."""
    # Bug fix: the path declares {session_id} and the body uses it, but the original
    # signature never declared the parameter, so every request raised NameError.
    res = await qa_service.recognize_audio(file_id=int(obj.file_id), user_id=request.user.id, session_id=session_id)
    return response_base.success(data=ConversationRecognitionResponse(**res))
|
||||
|
||||
|
||||
@router.get('/conversations/{image_id}/list', summary='获取图片自由对话列表', dependencies=[DependsJwtAuth])
|
||||
async def list_conversations(request: Request, image_id: int, page: int = Query(1, ge=1), page_size: int = Query(10, ge=1, le=100)) -> ResponseSchemaModel[ConversationListResponse]:
|
||||
res = await qa_service.list_conversations_by_image(
|
||||
|
||||
@@ -180,8 +180,8 @@ class BilingualItem(SchemaBase):
|
||||
|
||||
class ConversationStartRequest(SchemaBase):
    """Request body for starting a free conversation about an image.

    Only ``image_id`` is required; all conversation-setup fields are optional.
    """

    # ID of the image the conversation is about.
    image_id: int
    # Optional scene descriptions (bilingual).
    # NOTE: the original span declared `scene`/`event` twice (a required and an
    # Optional version — leftover diff residue); only the Optional pair is kept.
    scene: Optional[List[BilingualItem]] = None
    # Optional event descriptions (bilingual).
    event: Optional[List[BilingualItem]] = None
    # Optional conversation style.
    style: Optional[BilingualItem] = None
    # Optional role played by the user.
    user_role: Optional[BilingualItem] = None
    # Optional role played by the assistant.
    assistant_role: Optional[BilingualItem] = None
||||
@@ -233,6 +233,14 @@ class ConversationReplyResponse(SchemaBase):
|
||||
session_id: Optional[str] = None
|
||||
|
||||
|
||||
class ConversationRecognitionRequest(SchemaBase):
    """Request body for the audio-recognition endpoint."""

    # ID of a previously uploaded audio file; sent as a string and cast to int by the handler.
    file_id: str
|
||||
|
||||
|
||||
class ConversationRecognitionResponse(SchemaBase):
    """Response body carrying the recognized transcript."""

    # Text recognized from the audio by the ASR backend.
    text: str
|
||||
|
||||
|
||||
class ConversationLatestResponse(SchemaBase):
    """Latest conversation state for a session."""

    # Session identifier (string form).
    session_id: str
    # Conversation messages; presumably in chronological order — verify against the service.
    messages: List[ConversationMessageSchema]
|
||||
|
||||
@@ -891,6 +891,87 @@ class QaService:
|
||||
"page_size": page_size,
|
||||
}
|
||||
|
||||
async def recognize_audio(self, file_id: int, user_id: int, session_id: int) -> Dict[str, Any]:
    """Recognize the audio content of an uploaded file and attach the result to a QA session.

    :param file_id: ID of the uploaded audio file.
    :param user_id: ID of the requesting user; must be the session starter.
    :param session_id: ID of the conversation session the audio belongs to.
    :return: ``{"text": <recognized text>}``.
    :raises errors.NotFoundError: session or file does not exist.
    :raises errors.ForbiddenError: session belongs to another user.
    :raises errors.ServerError: the ASR backend reported a failure.
    """
    # Hoisted: the original imported os twice (mid-body and inside finally)
    # and tempfile mid-body.
    import os
    import tempfile

    from backend.middleware.qwen import Qwen

    # 1. Validate the session and remember its exercise id for step 3.
    async with async_db_session.begin() as db:
        session = await qa_session_dao.get(db, session_id)
        if not session:
            raise errors.NotFoundError(msg="Session not found")
        if session.starter_user_id != user_id:
            raise errors.ForbiddenError(msg="Forbidden")
        exercise_id = session.exercise_id

    # 2. Resolve the file to something the ASR backend can read.
    file_obj = await file_service.get_file(file_id)
    if not file_obj:
        raise errors.NotFoundError(msg="文件不存在")

    temp_file_path = None
    try:
        if file_obj.storage_type == "cos":
            # Object storage: hand the ASR a presigned download URL.
            audio_url = await file_service.get_presigned_download_url(file_id, user_id)
        else:
            # Database storage: spill the bytes to a temp file.
            content, _, _ = await file_service.download_file(file_id)
            # Keep a real suffix so Qwen can detect the audio format.
            # (Renamed from `ext` to avoid shadowing the question-ext dict in step 3.)
            suffix = "mp3"
            if file_obj.content_type:
                suffix = file_service._mime_to_ext(file_obj.content_type, file_obj.file_name)

            tf = tempfile.NamedTemporaryFile(delete=False, suffix=f".{suffix}")
            tf.write(content)
            tf.close()
            temp_file_path = tf.name
            audio_url = temp_file_path

        # Call Qwen ASR.
        res = await Qwen.recognize_speech(audio_url, user_id=user_id)

        if not res.get("success"):
            raise errors.ServerError(msg=res.get("error") or "ASR failed")

        text = res.get("text", "")

        # 3. Append the recognition to the latest LLM question's ext payload.
        async with async_db_session.begin() as db:
            # Find the latest question (LLM response) for this exercise.
            last_question = await qa_question_dao.get_latest_by_exercise_id(db, exercise_id)
            if last_question:
                ext = dict(last_question.ext or {})
                audio_recognitions = ext.get('audio_recognitions')
                if not isinstance(audio_recognitions, list):
                    audio_recognitions = []

                audio_recognitions.append({
                    'audio_id': str(file_id),
                    'text': text
                })
                ext['audio_recognitions'] = audio_recognitions
                last_question.ext = ext
                await db.flush()

        return {"text": text}

    finally:
        # Best-effort cleanup of the temp file; never mask the real outcome.
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
            except OSError:
                pass
|
||||
|
||||
async def create_exercise_task(self, image_id: int, user_id: int, type: Optional[str] = "scene_basic") -> Dict[str, Any]:
|
||||
|
||||
is_conversation_init = type == 'init_conversion'
|
||||
|
||||
@@ -316,6 +316,123 @@ class Qwen:
|
||||
input=[{'image': image_data}]
|
||||
)
|
||||
|
||||
@staticmethod
async def recognize_speech(file_path: str, user_id: int = 0) -> Dict[str, Any]:
    """
    Recognize speech content via the Tongyi Qianwen (DashScope) ASR API.

    :param file_path: audio file path (local path, or an http(s) URL passed through as-is)
    :param user_id: user ID, recorded in the audit log
    :return: recognition result — on success ``{"success": True, "text": ..., "token_usage": ...}``,
        on failure ``{"success": False, "error": ...}`` (plus ``status_code`` for API-level errors)
    """
    api_key = _get_primary_qwen_api_key()
    model_name = "qwen3-asr-flash"
    start_time = time.time()
    start_at = datetime.now()
    error_message = ""
    status_code = 500  # pessimistic default; overwritten once a response arrives
    response_data: Dict[str, Any] = {}

    try:
        # Pass local paths via the file:// scheme.
        # NOTE: the DashScope SDK handles file:// paths itself.
        audio_url = file_path if file_path.startswith("http") else f"file://{file_path}"

        messages = [
            {"role": "system", "content": [{"text": ""}]},
            {"role": "user", "content": [{"audio": audio_url}]}
        ]

        # The DashScope call is blocking; run it on the class executor so the
        # event loop is not stalled.
        loop = asyncio.get_event_loop()
        response = await loop.run_in_executor(
            Qwen._executor,
            lambda: dashscope.MultiModalConversation.call(
                api_key=api_key,
                model=model_name,
                messages=messages,
                result_format="message",
                asr_options={
                    "language": "en",
                    "enable_itn": False
                }
            )
        )

        # NOTE(review): falls back to `code` and then 500; on some SDK versions
        # `code` is a string error code — confirm the == 200 comparison below.
        status_code = getattr(response, 'status_code', getattr(response, 'code', 500))

        response_data = {
            "output": getattr(response, 'output', None),
            "usage": getattr(response, 'usage', {}),
            "code": getattr(response, 'code', None),
            "message": getattr(response, 'message', None)
        }
        duration = time.time() - start_time

        # Audit every completed API round-trip (success and API-level error alike).
        # NOTE(review): error_message is hard-coded to None here even on non-200
        # responses; the finally block then logs a second record via _log_audit —
        # confirm the double-logging on API errors is intentional.
        audit_log = CreateAuditLogParam(
            api_type="asr",
            model_name=model_name,
            response_data=response_data,
            request_data={"audio": audio_url},
            token_usage=response_data.get("usage", {}),
            duration=duration,
            status_code=status_code,
            error_message=None,
            called_at=start_at,
            image_id=0,
            user_id=user_id,
            cost=0,
            api_version=settings.FASTAPI_API_V1_PATH,
            dict_level=None,
        )
        Qwen._audit_log("asr", audit_log)

        if status_code == 200:
            content = ""
            try:
                content = response.output.choices[0].message.content[0]["text"]
            except Exception:
                # Fall back to other possible response structures.
                try:
                    content = response.output.choices[0].message.content
                except Exception:
                    pass

            return {
                "success": True,
                "text": content,
                "token_usage": response_data.get("usage", {})
            }
        else:
            error_message = response_data.get("message") or "API error"
            logger.error(f"ASR API error: {status_code} - {error_message}")
            return {
                "success": False,
                "error": error_message,
                "status_code": status_code
            }

    except Exception as e:
        error_message = str(e)
        logger.exception(f"ASR API exception: {error_message}")
        return {
            "success": False,
            "error": error_message
        }
    finally:
        # Error path audit: only fires when error_message was set above.
        # NOTE(review): this uses Qwen._log_audit while the success path uses
        # Qwen._audit_log — confirm both helpers exist and are not a typo of one another.
        if error_message:
            Qwen._log_audit(
                api_type="asr",
                dict_level=None,
                model_name=model_name,
                request_data={"audio": file_path},
                response_data={"error": error_message},
                duration=time.time() - start_time,
                status_code=status_code,
                error_message=error_message,
                image_id=0,
                user_id=user_id,
                called_at=start_at
            )
|
||||
|
||||
@staticmethod
|
||||
async def _call_api(api_type: str, image_id: int, user_id: int, input: Union[Dict, List], dict_level: str | None = None) -> \
|
||||
Dict[str, Any]:
|
||||
|
||||
Reference in New Issue
Block a user