add qwen tts

This commit is contained in:
Felix
2025-12-23 17:33:36 +08:00
parent 711fecbff4
commit 5e2de52da9
6 changed files with 194 additions and 30 deletions

View File

@@ -17,10 +17,12 @@ async def get_user_points_info(
"""
根据用户ID获取对应的积分和过期时间
"""
points_info = await points_service.get_user_points(request.user.id)
if points_info:
balance_info = PointsBalanceInfo(
balance=points_info.balance,
)
return response_base.success(data=balance_info)
return response_base.success(data=None)
details = await points_service.get_user_account_details(request.user.id)
balance_info = PointsBalanceInfo(
balance=int(details.get("balance") or 0),
available_balance=int(details.get("available_balance") or 0),
frozen_balance=int(details.get("frozen_balance") or 0),
total_purchased=int(details.get("total_purchased") or 0),
total_refunded=int(details.get("total_refunded") or 0),
)
return response_base.success(data=balance_info)

View File

@@ -93,11 +93,11 @@ async def get_current_user(request: Request, response: Response) -> ResponseSche
)
settings_obj = WxSettings(dict_level=dict_level)
pts = await points_service.get_user_points(user.id)
points_brief = PointsBrief(
balance=pts.balance if pts else 0,
expired_time=pts.expired_time.isoformat() if pts and getattr(pts, "expired_time", None) else None,
)
pts = await points_service.get_user_points(user.id)
points_brief = PointsBrief(
balance=(max(0, (pts.balance or 0) - (pts.frozen_balance or 0)) if pts else 0),
expired_time=pts.expired_time.isoformat() if pts and getattr(pts, "expired_time", None) else None,
)
data = GetWxLoginToken(
access_token=access.access_token,

View File

@@ -18,6 +18,10 @@ class PointsSchema(BaseModel):
class PointsBalanceInfo(BaseModel):
"""积分余额和过期时间信息"""
balance: int = Field(default=0, description="当前积分余额")
available_balance: int = Field(default=0, description="当前可用余额(余额-冻结)")
frozen_balance: int = Field(default=0, description="当前冻结积分")
total_purchased: int = Field(default=0, description="累计获得积分")
total_refunded: int = Field(default=0, description="累计退款积分")
class PointsLogSchema(BaseModel):
@@ -46,4 +50,4 @@ class DeductPointsRequest(BaseModel):
user_id: int = Field(description="用户ID")
amount: int = Field(gt=0, description="扣减的积分数量")
related_id: Optional[int] = Field(default=None, description="关联ID")
details: Optional[dict] = Field(default=None, description="附加信息")
details: Optional[dict] = Field(default=None, description="附加信息")

View File

@@ -116,15 +116,15 @@ class WxAuthService:
province=(user.profile or {}).get("province") if isinstance(user.profile, dict) else None,
city=(user.profile or {}).get("city") if isinstance(user.profile, dict) else None,
language=(user.profile or {}).get("language") if isinstance(user.profile, dict) else None,
),
settings=WxSettings(dict_level=dict_level),
points=PointsBrief(
balance=(await points_service.get_user_balance(user.id)),
expired_time=None,
),
server_time=int(timezone.to_utc(timezone.now()).timestamp()),
)
return data
),
settings=WxSettings(dict_level=dict_level),
points=PointsBrief(
balance=(await points_service.get_user_account_details(user.id)).get("available_balance", 0),
expired_time=None,
),
server_time=int(timezone.to_utc(timezone.now()).timestamp()),
)
return data
@staticmethod
async def update_user_settings(

View File

@@ -19,6 +19,8 @@ from backend.database.db import async_db_session
from backend.middleware.tencent_cloud import TencentCloud
from backend.common.const import SPEECH_ASSESSMENT_COST, POINTS_ACTION_SPEECH_ASSESSMENT
from backend.app.admin.service.points_service import points_service
from backend.core.conf import settings
from backend.middleware.qwen import Qwen
# Import the recording_dao for accessing recording CRUD methods
from backend.app.ai.crud.recording_crud import recording_dao
@@ -68,14 +70,22 @@ class RecordingService:
if not image_text:
return None
try:
from backend.middleware.tencent_cloud import TencentCloud
tts = TencentCloud()
await tts.text_to_speak(
content=image_text.content,
image_text_id=text_id,
image_id=image_text.image_id,
user_id=user_id
)
model_type = (getattr(settings, "LLM_MODEL_TYPE", "") or "").lower()
if model_type == "qwen":
await Qwen.text_to_speak(
content=image_text.content,
image_text_id=text_id,
image_id=image_text.image_id,
user_id=user_id
)
else:
tts = TencentCloud()
await tts.text_to_speak(
content=image_text.content,
image_text_id=text_id,
image_id=image_text.image_id,
user_id=user_id
)
except Exception as e:
logger.error(f"On-demand TTS generation failed for text_id={text_id}: {e}")
return None

View File

@@ -7,6 +7,7 @@ import time
import json
import dashscope
from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
import httpx
from backend.app.admin.model.audit_log import AuditLog
from backend.app.admin.schema.audit_log import CreateAuditLogParam
@@ -30,6 +31,153 @@ class Qwen:
RECOGNITION_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation"
EMBEDDING_URL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/multimodal-embedding"
@staticmethod
async def text_to_speak(content: str, image_text_id: int | None = None, image_id: int | None = None, user_id: int | None = None) -> Dict[str, Any]:
api_key = settings.QWEN_API_KEY
model_name = "qwen3-tts-flash"
voice = "Jennifer"
language_type = "English"
stream = False
start_time = time.time()
start_at = datetime.now()
error_message = ""
status_code = 500
response_data: Dict[str, Any] = {}
try:
dashscope.base_http_api_url = "https://dashscope.aliyuncs.com/api/v1"
loop = asyncio.get_event_loop()
resp = await loop.run_in_executor(
Qwen._executor,
lambda: dashscope.audio.qwen_tts.SpeechSynthesizer.call(
model=model_name,
api_key=api_key,
text=content,
voice=voice,
language_type=language_type,
stream=stream,
)
)
status_code = getattr(resp, "status_code", getattr(resp, "code", 500))
if hasattr(resp, "output"):
response_data = {
"output": getattr(resp, "output", {}),
"usage": getattr(resp, "usage", {}),
"code": getattr(resp, "code", None),
"message": getattr(resp, "message", None),
}
else:
response_data = {}
if hasattr(resp, "__dict__"):
response_data = resp.__dict__
else:
response_data = {"output": {}, "message": str(resp)}
duration = time.time() - start_time
if status_code == 200:
audio_data: bytes | None = None
audio_url: str | None = None
# 优先从输出结构的 audio.url 下载音频,其次尝试 audio.data 的 base64
out = response_data.get("output") or {}
audio_obj = out.get("audio") or {}
if isinstance(audio_obj, dict):
audio_url = audio_obj.get("url")
audio_b64 = audio_obj.get("data")
else:
audio_url = None
audio_b64 = None
# 下载远程音频
if audio_url:
try:
async with httpx.AsyncClient(timeout=60.0) as client:
r = await client.get(audio_url)
if r.status_code == 200:
audio_data = r.content
else:
logger.warning(f"Fetch audio failed: {r.status_code}")
except Exception as _e:
logger.warning(f"Fetch audio exception: {str(_e)}")
# 若无URL或下载失败尝试解析内嵌的base64
if not audio_data and audio_b64:
try:
from base64 import b64decode
audio_data = b64decode(audio_b64)
except Exception:
audio_data = None
recording_id: int | None = None
if audio_data:
try:
from fastapi import UploadFile
import io
from backend.app.admin.service.file_service import file_service
from backend.app.ai.service.recording_service import recording_service
# 根据URL后缀判断文件类型默认mp3
ext = "mp3"
ct = "audio/mp3"
try:
if audio_url and audio_url.lower().endswith(".wav"):
ext = "wav"
ct = "audio/wav"
except Exception:
pass
file_name = f"{image_text_id}_std.{ext}"
upload_file = UploadFile(filename=file_name, file=io.BytesIO(audio_data), headers={}, size=len(audio_data))
file_response = await file_service.upload_file_with_content_type(file=upload_file, content_type=ct, metadata={"is_standard_audio": True})
recording_id = await recording_service.create_recording_record_with_details(
file_id=int(file_response.id),
ref_text=content,
image_id=image_id,
image_text_id=image_text_id,
eval_mode=1,
user_id=user_id,
details=response_data,
is_standard=True,
)
except Exception:
recording_id = None
Qwen._log_audit(
api_type="tts",
model_name=model_name,
request_data={"content": content, "image_text_id": image_text_id},
response_data=response_data,
duration=duration,
status_code=status_code,
error_message=None,
image_id=image_id or 0,
user_id=user_id or 0,
called_at=start_at,
)
return {"success": True, "output": response_data.get("output", {}), "recording_id": recording_id, "audio_url": audio_url}
else:
error_message = f"TTS error: {response_data.get('message') or 'Unknown error'}"
Qwen._log_audit(
api_type="tts",
model_name=model_name,
request_data={"content": content, "image_text_id": image_text_id},
response_data={"error": error_message},
duration=time.time() - start_time,
status_code=status_code,
error_message=error_message,
image_id=image_id or 0,
user_id=user_id or 0,
called_at=start_at,
)
return {"success": False, "error": error_message, "status_code": status_code}
except Exception as e:
error_message = str(e)
logger.error(f"TTS error: {error_message}")
Qwen._log_audit(
api_type="tts",
model_name=model_name,
request_data={"content": content, "image_text_id": image_text_id},
response_data={"error": error_message},
duration=time.time() - start_time,
status_code=status_code,
error_message=error_message,
image_id=image_id or 0,
user_id=user_id or 0,
called_at=start_at,
)
return {"success": False, "error": error_message}
@staticmethod
async def chat(messages: List[Dict[str, str]], image_id: int = 0, user_id: int = 0) -> Dict[str, Any]:
api_key = settings.QWEN_API_KEY