add qwen tts
This commit is contained in:
@@ -17,10 +17,12 @@ async def get_user_points_info(
|
||||
"""
|
||||
根据用户ID获取对应的积分和过期时间
|
||||
"""
|
||||
points_info = await points_service.get_user_points(request.user.id)
|
||||
if points_info:
|
||||
balance_info = PointsBalanceInfo(
|
||||
balance=points_info.balance,
|
||||
)
|
||||
return response_base.success(data=balance_info)
|
||||
return response_base.success(data=None)
|
||||
details = await points_service.get_user_account_details(request.user.id)
|
||||
balance_info = PointsBalanceInfo(
|
||||
balance=int(details.get("balance") or 0),
|
||||
available_balance=int(details.get("available_balance") or 0),
|
||||
frozen_balance=int(details.get("frozen_balance") or 0),
|
||||
total_purchased=int(details.get("total_purchased") or 0),
|
||||
total_refunded=int(details.get("total_refunded") or 0),
|
||||
)
|
||||
return response_base.success(data=balance_info)
|
||||
|
||||
@@ -93,11 +93,11 @@ async def get_current_user(request: Request, response: Response) -> ResponseSche
|
||||
)
|
||||
settings_obj = WxSettings(dict_level=dict_level)
|
||||
|
||||
pts = await points_service.get_user_points(user.id)
|
||||
points_brief = PointsBrief(
|
||||
balance=pts.balance if pts else 0,
|
||||
expired_time=pts.expired_time.isoformat() if pts and getattr(pts, "expired_time", None) else None,
|
||||
)
|
||||
pts = await points_service.get_user_points(user.id)
|
||||
points_brief = PointsBrief(
|
||||
balance=(max(0, (pts.balance or 0) - (pts.frozen_balance or 0)) if pts else 0),
|
||||
expired_time=pts.expired_time.isoformat() if pts and getattr(pts, "expired_time", None) else None,
|
||||
)
|
||||
|
||||
data = GetWxLoginToken(
|
||||
access_token=access.access_token,
|
||||
|
||||
@@ -18,6 +18,10 @@ class PointsSchema(BaseModel):
|
||||
class PointsBalanceInfo(BaseModel):
    """Snapshot of a user's points account balances.

    Every field is an integer that defaults to 0 when no value is supplied.
    """

    # Current total balance.
    balance: int = Field(
        default=0,
        description="当前积分余额",
    )
    # Spendable balance; per the field description this is balance minus frozen.
    available_balance: int = Field(
        default=0,
        description="当前可用余额(余额-冻结)",
    )
    # Points that are currently frozen.
    frozen_balance: int = Field(
        default=0,
        description="当前冻结积分",
    )
    # Cumulative points obtained.
    total_purchased: int = Field(
        default=0,
        description="累计获得积分",
    )
    # Cumulative points refunded.
    total_refunded: int = Field(
        default=0,
        description="累计退款积分",
    )
|
||||
|
||||
|
||||
class PointsLogSchema(BaseModel):
|
||||
@@ -46,4 +50,4 @@ class DeductPointsRequest(BaseModel):
|
||||
user_id: int = Field(description="用户ID")
|
||||
amount: int = Field(gt=0, description="扣减的积分数量")
|
||||
related_id: Optional[int] = Field(default=None, description="关联ID")
|
||||
details: Optional[dict] = Field(default=None, description="附加信息")
|
||||
details: Optional[dict] = Field(default=None, description="附加信息")
|
||||
|
||||
@@ -116,15 +116,15 @@ class WxAuthService:
|
||||
province=(user.profile or {}).get("province") if isinstance(user.profile, dict) else None,
|
||||
city=(user.profile or {}).get("city") if isinstance(user.profile, dict) else None,
|
||||
language=(user.profile or {}).get("language") if isinstance(user.profile, dict) else None,
|
||||
),
|
||||
settings=WxSettings(dict_level=dict_level),
|
||||
points=PointsBrief(
|
||||
balance=(await points_service.get_user_balance(user.id)),
|
||||
expired_time=None,
|
||||
),
|
||||
server_time=int(timezone.to_utc(timezone.now()).timestamp()),
|
||||
)
|
||||
return data
|
||||
),
|
||||
settings=WxSettings(dict_level=dict_level),
|
||||
points=PointsBrief(
|
||||
balance=(await points_service.get_user_account_details(user.id)).get("available_balance", 0),
|
||||
expired_time=None,
|
||||
),
|
||||
server_time=int(timezone.to_utc(timezone.now()).timestamp()),
|
||||
)
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
async def update_user_settings(
|
||||
|
||||
@@ -19,6 +19,8 @@ from backend.database.db import async_db_session
|
||||
from backend.middleware.tencent_cloud import TencentCloud
|
||||
from backend.common.const import SPEECH_ASSESSMENT_COST, POINTS_ACTION_SPEECH_ASSESSMENT
|
||||
from backend.app.admin.service.points_service import points_service
|
||||
from backend.core.conf import settings
|
||||
from backend.middleware.qwen import Qwen
|
||||
|
||||
# Import the recording_dao for accessing recording CRUD methods
|
||||
from backend.app.ai.crud.recording_crud import recording_dao
|
||||
@@ -68,14 +70,22 @@ class RecordingService:
|
||||
if not image_text:
|
||||
return None
|
||||
try:
|
||||
from backend.middleware.tencent_cloud import TencentCloud
|
||||
tts = TencentCloud()
|
||||
await tts.text_to_speak(
|
||||
content=image_text.content,
|
||||
image_text_id=text_id,
|
||||
image_id=image_text.image_id,
|
||||
user_id=user_id
|
||||
)
|
||||
model_type = (getattr(settings, "LLM_MODEL_TYPE", "") or "").lower()
|
||||
if model_type == "qwen":
|
||||
await Qwen.text_to_speak(
|
||||
content=image_text.content,
|
||||
image_text_id=text_id,
|
||||
image_id=image_text.image_id,
|
||||
user_id=user_id
|
||||
)
|
||||
else:
|
||||
tts = TencentCloud()
|
||||
await tts.text_to_speak(
|
||||
content=image_text.content,
|
||||
image_text_id=text_id,
|
||||
image_id=image_text.image_id,
|
||||
user_id=user_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"On-demand TTS generation failed for text_id={text_id}: {e}")
|
||||
return None
|
||||
|
||||
@@ -7,6 +7,7 @@ import time
|
||||
import json
|
||||
import dashscope
|
||||
from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
|
||||
import httpx
|
||||
|
||||
from backend.app.admin.model.audit_log import AuditLog
|
||||
from backend.app.admin.schema.audit_log import CreateAuditLogParam
|
||||
@@ -30,6 +31,153 @@ class Qwen:
|
||||
RECOGNITION_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation"
|
||||
EMBEDDING_URL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/multimodal-embedding"
|
||||
|
||||
@staticmethod
async def text_to_speak(content: str, image_text_id: int | None = None, image_id: int | None = None, user_id: int | None = None) -> Dict[str, Any]:
    """Synthesize speech for ``content`` with Qwen TTS and store it as a standard recording.

    The blocking DashScope SDK call runs on ``Qwen._executor``. On a 200
    response the audio is taken from ``output.audio.url`` (downloaded) or,
    failing that, from base64 in ``output.audio.data``; it is then uploaded
    through ``file_service`` and registered through ``recording_service``.
    Download, decode and persistence are best-effort: a failure there is
    logged and leaves ``recording_id`` as ``None`` without failing the call.
    Every outcome is written to the audit log via ``Qwen._log_audit``.

    Args:
        content: Text to synthesize; also stored as the recording's ``ref_text``.
        image_text_id: Passed to the recording record and used in the file name.
        image_id: Passed to the recording record and the audit log (0 if None).
        user_id: Passed to the recording record and the audit log (0 if None).

    Returns:
        On success: ``{"success": True, "output", "recording_id", "audio_url"}``.
        On a non-200 response: ``{"success": False, "error", "status_code"}``.
        On an exception: ``{"success": False, "error"}``.
    """
    api_key = settings.QWEN_API_KEY
    model_name = "qwen3-tts-flash"
    voice = "Jennifer"
    language_type = "English"
    stream = False
    start_time = time.time()
    start_at = datetime.now()
    status_code = 500
    response_data: Dict[str, Any] = {}

    def _audit(payload: Dict[str, Any], elapsed: float, error: str | None) -> None:
        # Single place for the audit call; reads status_code at call time.
        Qwen._log_audit(
            api_type="tts",
            model_name=model_name,
            request_data={"content": content, "image_text_id": image_text_id},
            response_data=payload,
            duration=elapsed,
            status_code=status_code,
            error_message=error,
            image_id=image_id or 0,
            user_id=user_id or 0,
            called_at=start_at,
        )

    try:
        dashscope.base_http_api_url = "https://dashscope.aliyuncs.com/api/v1"
        # We are inside a coroutine, so the running loop is the one to use.
        loop = asyncio.get_running_loop()
        resp = await loop.run_in_executor(
            Qwen._executor,
            lambda: dashscope.audio.qwen_tts.SpeechSynthesizer.call(
                model=model_name,
                api_key=api_key,
                text=content,
                voice=voice,
                language_type=language_type,
                stream=stream,
            ),
        )
        status_code = getattr(resp, "status_code", getattr(resp, "code", 500))
        if hasattr(resp, "output"):
            response_data = {
                "output": getattr(resp, "output", {}),
                "usage": getattr(resp, "usage", {}),
                "code": getattr(resp, "code", None),
                "message": getattr(resp, "message", None),
            }
        elif hasattr(resp, "__dict__"):
            response_data = resp.__dict__
        else:
            response_data = {"output": {}, "message": str(resp)}
        # Duration covers the synthesis call only, not download/upload.
        duration = time.time() - start_time

        if status_code != 200:
            error_message = f"TTS error: {response_data.get('message') or 'Unknown error'}"
            _audit({"error": error_message}, time.time() - start_time, error_message)
            return {"success": False, "error": error_message, "status_code": status_code}

        audio_data: bytes | None = None
        audio_url: str | None = None
        audio_b64 = None
        # Prefer output.audio.url; fall back to base64 in output.audio.data.
        out = response_data.get("output") or {}
        audio_obj = out.get("audio") or {}
        if isinstance(audio_obj, dict):
            audio_url = audio_obj.get("url")
            audio_b64 = audio_obj.get("data")

        # Download the remote audio.
        if audio_url:
            try:
                async with httpx.AsyncClient(timeout=60.0) as client:
                    r = await client.get(audio_url)
                    if r.status_code == 200:
                        audio_data = r.content
                    else:
                        logger.warning(f"Fetch audio failed: {r.status_code}")
            except Exception as _e:
                logger.warning(f"Fetch audio exception: {str(_e)}")

        # No URL or download failed: try the embedded base64 payload.
        if not audio_data and audio_b64:
            try:
                from base64 import b64decode
                audio_data = b64decode(audio_b64)
            except Exception as _e:
                logger.warning(f"Decode audio base64 failed: {str(_e)}")
                audio_data = None

        recording_id: int | None = None
        if audio_data:
            try:
                from urllib.parse import urlsplit
                from fastapi import UploadFile
                import io
                from backend.app.admin.service.file_service import file_service
                from backend.app.ai.service.recording_service import recording_service

                # Pick the file type from the URL *path* (default mp3). Testing
                # the whole URL would miss ".wav" whenever a query string follows.
                ext = "mp3"
                ct = "audio/mp3"
                try:
                    if audio_url and urlsplit(audio_url).path.lower().endswith(".wav"):
                        ext = "wav"
                        ct = "audio/wav"
                except ValueError:
                    # Malformed URL: keep the mp3 default.
                    pass

                file_name = f"{image_text_id}_std.{ext}"
                upload_file = UploadFile(filename=file_name, file=io.BytesIO(audio_data), headers={}, size=len(audio_data))
                file_response = await file_service.upload_file_with_content_type(file=upload_file, content_type=ct, metadata={"is_standard_audio": True})
                recording_id = await recording_service.create_recording_record_with_details(
                    file_id=int(file_response.id),
                    ref_text=content,
                    image_id=image_id,
                    image_text_id=image_text_id,
                    eval_mode=1,
                    user_id=user_id,
                    details=response_data,
                    is_standard=True,
                )
            except Exception as _e:
                # Best-effort persistence: report the failure instead of hiding it.
                logger.error(f"Save TTS audio failed for image_text_id={image_text_id}: {str(_e)}")
                recording_id = None

        _audit(response_data, duration, None)
        return {"success": True, "output": response_data.get("output", {}), "recording_id": recording_id, "audio_url": audio_url}
    except Exception as e:
        error_message = str(e)
        logger.error(f"TTS error: {error_message}")
        _audit({"error": error_message}, time.time() - start_time, error_message)
        return {"success": False, "error": error_message}
|
||||
|
||||
@staticmethod
|
||||
async def chat(messages: List[Dict[str, str]], image_id: int = 0, user_id: int = 0) -> Dict[str, Any]:
|
||||
api_key = settings.QWEN_API_KEY
|
||||
|
||||
Reference in New Issue
Block a user