This commit is contained in:
Felix
2025-11-30 08:45:42 +08:00
parent 1fc981d4a2
commit 5f356648b1
11 changed files with 84 additions and 608 deletions

5
.gitignore vendored
View File

@@ -6,5 +6,10 @@ backend/.env
backend/alembic/versions/
.venv/
.vscode/
.cloudbase/
.history/
.trae/
*.log
.ruff_cache/
backend/.DS_Store
assets/

View File

@@ -3,5 +3,4 @@
from backend.app.admin.crud.file_crud import file_dao
from backend.app.admin.crud.daily_summary_crud import daily_summary_dao
from backend.app.admin.crud.points_crud import points_dao, points_log_dao
from backend.app.admin.crud.free_usage_crud import sharing_event_dao, usage_reset_log_dao
from backend.app.admin.crud.points_crud import points_dao, points_log_dao

View File

@@ -1,48 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Optional, List
from sqlalchemy import select, update, delete, and_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy_crud_plus import CRUDPlus
from backend.app.admin.model.order import SharingEvent, UsageResetLog
class SharingEventDao(CRUDPlus[SharingEvent]):
async def record_sharing_event(
self, db: AsyncSession, user_id: int, platform: str, success: bool = True
) -> SharingEvent:
"""
记录社交分享事件
"""
sharing_event = self.model(
user_id=user_id,
platform=platform,
success=success,
reset_usage=success
)
db.add(sharing_event)
return sharing_event
class UsageResetLogDao(CRUDPlus[UsageResetLog]):
async def record_usage_reset(
self, db: AsyncSession, user_id: int, reset_type: str, service_types: List[str], previous_counts: dict
) -> UsageResetLog:
"""
记录使用次数重置事件
"""
reset_log = self.model(
user_id=user_id,
reset_type=reset_type,
service_types=service_types,
previous_counts=previous_counts
)
db.add(reset_log)
return reset_log
# 创建DAO实例
sharing_event_dao = SharingEventDao(SharingEvent)
usage_reset_log_dao = UsageResetLogDao(UsageResetLog)

View File

@@ -1,38 +0,0 @@
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus
from backend.app.admin.model.order import FreezeLog
class FreezeLogDao(CRUDPlus[FreezeLog]):
async def get_by_id(self, db: AsyncSession, freeze_id: int) -> Optional[FreezeLog]:
"""
根据ID获取冻结记录
"""
stmt = select(FreezeLog).where(FreezeLog.id == freeze_id)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def get_by_order_id(self, db: AsyncSession, order_id: int) -> Optional[FreezeLog]:
"""
根据订单ID获取冻结记录
"""
stmt = select(FreezeLog).where(FreezeLog.order_id == order_id)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def get_pending_by_user(self, db: AsyncSession, user_id: int) -> list[FreezeLog]:
"""
获取用户所有待处理的冻结记录
"""
stmt = select(FreezeLog).where(
FreezeLog.user_id == user_id,
FreezeLog.status == "pending"
)
result = await db.execute(stmt)
return result.scalars().all()
freeze_log_dao = FreezeLogDao(FreezeLog)

View File

@@ -1,61 +0,0 @@
from typing import Optional, List, Dict, Any
from sqlalchemy import select, and_
from sqlalchemy.ext.asyncio import AsyncSession
from backend.app.admin.model.order import UsageLog
from sqlalchemy_crud_plus import CRUDPlus
class UsageLogDao(CRUDPlus[UsageLog]):
async def get_by_id(self, db: AsyncSession, log_id: int) -> Optional[UsageLog]:
"""
根据ID获取使用日志
"""
stmt = select(UsageLog).where(UsageLog.id == log_id)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def get_by_user_id(self, db: AsyncSession, user_id: int, limit: int = 100) -> List[UsageLog]:
"""
根据用户ID获取使用日志列表
"""
stmt = select(UsageLog).where(
UsageLog.user_id == user_id
).order_by(UsageLog.created_at.desc()).limit(limit)
result = await db.execute(stmt)
return result.scalars().all()
async def get_by_action(self, db: AsyncSession, user_id: int, action: str, limit: int = 50) -> List[UsageLog]:
"""
根据动作类型获取使用日志
"""
stmt = select(UsageLog).where(
and_(
UsageLog.user_id == user_id,
UsageLog.action == action
)
).order_by(UsageLog.created_at.desc()).limit(limit)
result = await db.execute(stmt)
return result.scalars().all()
async def get_balance_history(self, db: AsyncSession, user_id: int, limit: int = 100) -> List[UsageLog]:
"""
获取用户余额变动历史
"""
stmt = select(UsageLog).where(
UsageLog.user_id == user_id
).order_by(UsageLog.created_at.desc()).limit(limit)
result = await db.execute(stmt)
return result.scalars().all()
async def add_log(self, db: AsyncSession, log_data: Dict[str, Any]) -> UsageLog:
"""
添加使用日志
"""
log = UsageLog(**log_data)
db.add(log)
await db.flush()
return log
usage_log_dao = UsageLogDao(UsageLog)

View File

@@ -4,9 +4,6 @@ from backend.common.model import MappedBase # noqa: I
from backend.app.admin.model.wx_user import WxUser
from backend.app.admin.model.audit_log import AuditLog, DailySummary
from backend.app.admin.model.file import File
# from backend.app.admin.model.dict import DictionaryEntry, DictionaryMedia
from backend.app.admin.model.order import UsageLog, SharingEvent, UsageResetLog
from backend.app.admin.model.coupon import Coupon, CouponUsage
# from backend.app.admin.model.notification import Notification, UserNotification
from backend.app.admin.model.points import Points, PointsLog
from backend.app.ai.model import Image

View File

@@ -2,4 +2,3 @@
# -*- coding: utf-8 -*-
from backend.app.admin.service.points_service import PointsService
from backend.app.admin.service.sharing_service import sharing_service

View File

@@ -1,4 +1,5 @@
import hashlib
import json
import os
from abc import ABC, abstractmethod
import aiofiles
@@ -11,7 +12,7 @@ from backend.core.conf import settings
class StorageProvider(ABC):
"""存储提供者抽象基类"""
@abstractmethod
async def save(self, file_id: int, content: bytes, file_name: str) -> str:
"""保存文件"""
@@ -27,10 +28,20 @@ class StorageProvider(ABC):
"""删除文件"""
pass
@abstractmethod
async def compress_image(self, file_id: int) -> str:
"""压缩图片并返回压缩后路径或键"""
pass
@abstractmethod
async def object_url(self, cos_key: str) -> str:
"""根据存储键返回可访问的URL或路径"""
pass
class DatabaseStorage(StorageProvider):
"""数据库存储提供者"""
async def save(self, file_id: int, content: bytes, file_name: str) -> str:
"""数据库存储不需要实际保存文件,直接返回空字符串"""
return ""
@@ -43,10 +54,18 @@ class DatabaseStorage(StorageProvider):
"""数据库存储不需要删除文件"""
return True
async def compress_image(self, file_id: int) -> str:
"""数据库存储不涉及文件压缩,返回空字符串"""
return ""
async def object_url(self, cos_key: str) -> str:
"""数据库存储不提供URL返回空字符串"""
return ""
class LocalStorage(StorageProvider):
"""本地文件系统存储提供者"""
def __init__(self, base_path: str = settings.STORAGE_PATH):
self.base_path = base_path
@@ -78,16 +97,27 @@ class LocalStorage(StorageProvider):
except:
return False
async def compress_image(self, file_id: int) -> str:
"""本地压缩占位实现,返回预期压缩文件路径"""
dir_name = str(file_id // 1000)
file_dir = os.path.join(self.base_path, dir_name)
os.makedirs(file_dir, exist_ok=True)
return os.path.join(file_dir, f"{file_id}.heif")
async def object_url(self, cos_key: str) -> str:
"""本地存储返回本地路径作为访问地址"""
return cos_key
class CosStorage(StorageProvider):
"""腾讯云COS存储提供者"""
def __init__(self):
"""初始化COS客户端"""
import logging
# Reduce verbosity of COS client logging
logging.getLogger('qcloud_cos').setLevel(logging.WARNING)
secret_id = settings.COS_SECRET_ID
secret_key = settings.COS_SECRET_KEY
region = settings.COS_REGION
@@ -104,15 +134,31 @@ class CosStorage(StorageProvider):
async def save(self, file_id: int, content: bytes, file_name: str) -> str:
"""保存文件到COS"""
key = self._get_key(file_id)
# res = compress_image(2106103703009361920)
# 上传文件到COS
response = self.client.put_object(
Bucket=self.bucket,
Body=content,
Key=key,
StorageClass='STANDARD',
EnableMD5=False
)
pic_ops = {
"is_pic_info": 1,
"rules": [
{
"bucket": self.bucket,
"fileid": f"{key}",
"rule": "imageMogr2/format/avif",
}
],
}
try:
response = self.client.put_object(
Bucket=self.bucket,
Body=content,
Key=key,
StorageClass='STANDARD',
EnableMD5=False,
PicOperations=json.dumps(pic_ops)
)
except Exception as e:
print(f"详细错误: {e}")
print(f"错误类型: {type(e)}")
# 返回存储路径(即对象键)
return key
@@ -139,6 +185,24 @@ class CosStorage(StorageProvider):
except Exception:
return False
async def compress_image(self, file_id: int) -> str:
key = self._get_key(file_id)
res_path=f"{key}.heif"
response = self.client.ci_download_compress_image(
Bucket=self.bucket,
Key=key,
DestImagePath=res_path,
CompressType='avif'
)
return res_path
async def object_url(self, cos_key: str) -> str:
url = self.client.get_object_url(
Bucket=self.bucket,
Key=cos_key,
)
return url
def get_storage_provider(provider_type: str) -> StorageProvider:
"""根据配置获取存储提供者"""
@@ -152,4 +216,4 @@ def get_storage_provider(provider_type: str) -> StorageProvider:
def calculate_file_hash(content: bytes) -> str:
"""计算文件的SHA256哈希值"""
return hashlib.sha256(content).hexdigest()
return hashlib.sha256(content).hexdigest()

View File

@@ -1,102 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List
from datetime import datetime
from backend.app.admin.crud.free_usage_crud import sharing_event_dao, usage_reset_log_dao
from backend.database.db import async_db_session
from backend.app.ai.service.rate_limit_service import RateLimitService
class SharingService:
"""社交分享服务类"""
@staticmethod
async def grant_times_by_sharing(user_id: int, platform: str) -> bool:
"""
通过社交分享授予使用次数
Args:
user_id: 用户ID
platform: 分享平台
Returns:
是否成功授予
"""
# 记录分享事件
await RateLimitService.record_sharing_event(user_id, platform, success=True)
# 重置免费使用次数
await RateLimitService.reset_free_usage_by_sharing(user_id)
return True
@staticmethod
async def get_user_sharing_history(user_id: int, limit: int = 10) -> List[dict]:
"""
获取用户分享历史
Args:
user_id: 用户ID
limit: 返回记录数量限制
Returns:
分享历史记录列表
"""
async with async_db_session() as db:
# 获取用户的分享事件记录
sharing_events = await sharing_event_dao.get_list(
db,
limit=limit,
user_id=user_id,
order_by="created_at DESC"
)
history = []
for event in sharing_events:
history.append({
"id": event.id,
"platform": event.platform,
"success": event.success,
"reset_usage": event.reset_usage,
"created_at": event.created_at.isoformat() if event.created_at else None
})
return history
@staticmethod
async def get_user_reset_history(user_id: int, limit: int = 10) -> List[dict]:
"""
获取用户使用次数重置历史
Args:
user_id: 用户ID
limit: 返回记录数量限制
Returns:
重置历史记录列表
"""
async with async_db_session() as db:
# 获取用户的重置记录
reset_logs = await usage_reset_log_dao.get_list(
db,
limit=limit,
user_id=user_id,
order_by="reset_at DESC"
)
history = []
for log in reset_logs:
history.append({
"id": log.id,
"reset_type": log.reset_type,
"service_types": log.service_types,
"previous_counts": log.previous_counts,
"reset_at": log.reset_at.isoformat() if log.reset_at else None
})
return history
# 创建单例实例
sharing_service = SharingService()

View File

@@ -1,166 +0,0 @@
from decimal import Decimal, ROUND_HALF_UP
from backend.app.admin.crud.usage_log_crud import usage_log_dao
from backend.app.admin.crud.order_crud import order_dao
from backend.app.admin.model.order import Order
from backend.app.admin.model.points import Points
from backend.app.admin.schema.usage import PurchaseRequest
from backend.app.admin.service.points_service import PointsService
from backend.common.exception import errors
from backend.database.db import async_db_session
from datetime import datetime, timedelta
from sqlalchemy import func, select, update
class UsageService:
@staticmethod
async def get_user_account(user_id: int) -> dict:
return await PointsService.get_user_account_details(user_id)
@staticmethod
def calculate_purchase_times_safe(amount_cents: int) -> int:
"""
安全的充值次数计算使用Decimal避免浮点数精度问题
"""
if amount_cents <= 0:
raise ValueError("充值金额必须大于0")
# 限制最大充值金额(防止溢出)
if amount_cents > 10000000: # 10万元
raise ValueError("单次充值金额不能超过10万元")
amount_yuan = Decimal(amount_cents) / Decimal(100)
base_times = amount_yuan * Decimal(10)
# 计算优惠比例每10元增加10%最多100%
tens = (amount_yuan // Decimal(10))
bonus_percent = min(tens * Decimal('0.1'), Decimal('1.0'))
total_times = base_times * (Decimal('1') + bonus_percent)
return int(total_times.quantize(Decimal('1'), rounding=ROUND_HALF_UP))
@staticmethod
async def purchase_times(user_id: int, request: PurchaseRequest):
# 输入验证
if request.amount_cents < 100: # 最少1元
raise errors.RequestError(msg="充值金额不能少于1元")
if request.amount_cents > 10000000: # 最多10万元
raise errors.RequestError(msg="单次充值金额不能超过10万元")
async with async_db_session.begin() as db:
# For now, we'll need to get the points account directly since we don't have a full replacement
# This is a temporary solution until we can fully refactor this method
from backend.app.admin.crud.points_crud import points_dao
points_account = await points_dao.get_by_user_id(db, user_id)
if not points_account:
points_account = await points_dao.create_user_points(db, user_id)
account = await UsageService.get_user_account(user_id)
times = UsageService.calculate_purchase_times_safe(request.amount_cents)
order = Order(
user_id=user_id,
order_type="purchase",
amount_cents=request.amount_cents,
amount_times=times,
status="pending"
)
await order_dao.add(db, order)
await db.flush() # 获取order.id
# 原子性更新账户(防止并发问题)
# 如果用户之前不是订阅用户,现在购买了次数,将其标记为订阅用户
# We need to update the points account directly
result = await db.execute(
update(Points)
.where(Points.id == points_account.id)
.values(
balance=Points.balance + times,
total_earned=Points.total_earned + times,
is_subscribed=True # Always set to subscribed when purchasing
)
)
if result.rowcount == 0:
raise errors.ServerError(msg="账户更新失败")
# 更新订单状态
order.status = "completed"
order.processed_at = datetime.now()
await order_dao.update(db, order.id, order)
await usage_log_dao.add(db, {
"user_id": user_id,
"action": "purchase",
"amount": times,
"balance_after": points_account.balance + times,
"related_id": order.id,
"details": {"amount_cents": request.amount_cents}
})
@staticmethod
async def update_subscription_status(user_id: int) -> bool:
"""
更新用户的订阅状态
Args:
user_id: 用户ID
Returns:
bool: 是否为订阅用户
"""
return await PointsService.update_subscription_status(user_id)
@staticmethod
async def use_times_atomic(user_id: int, count: int = 1):
"""
原子性扣减次数,支持免费试用优先使用
"""
if count <= 0:
raise ValueError("扣减次数必须大于0")
# Use the points service to deduct points/balance
# This is a simplified implementation - in a full refactor, we would need to implement
# the free trial logic in the points service as well
from backend.app.admin.crud.points_crud import points_dao
async with async_db_session.begin() as db:
points_account = await points_dao.get_by_user_id(db, user_id)
if not points_account:
raise errors.ForbiddenError(msg="用户账户不存在")
# For now, we'll just deduct from the balance directly
# A full implementation would need to handle free trial logic as well
if points_account.balance < count:
raise errors.ForbiddenError(msg="余额不足")
result = await db.execute(
update(Points)
.where(Points.id == points_account.id)
.values(balance=Points.balance - count)
)
if result.rowcount == 0:
raise errors.ForbiddenError(msg="余额不足")
# Update the points account object for logging
points_account.balance -= count
# 记录使用日志
await usage_log_dao.add(db, {
"user_id": user_id,
"action": "use",
"amount": -count,
"balance_after": points_account.balance,
"metadata_": {
"used_at": datetime.now().isoformat(),
"is_free_trial": False # Simplified implementation
}
})
@staticmethod
async def get_account_info(user_id: int) -> dict:
"""
获取用户账户详细信息
"""
# Use the points service method which already provides this functionality
return await PointsService.get_user_account_details(user_id)

View File

@@ -10,7 +10,6 @@ from backend.app.ai.crud.image_task_crud import image_task_dao
from backend.database.db import async_db_session
from backend.app.ai.model.image_task import ImageTaskStatus
from backend.app.admin.service.points_service import PointsService
from backend.app.admin.crud.free_usage_crud import sharing_event_dao, usage_reset_log_dao
from datetime import date
@@ -197,178 +196,6 @@ class RateLimitService:
"""
await RateLimitService.decrement_user_task_count(user_id)
@staticmethod
async def reset_free_usage_by_sharing(user_id: int, service_types: list = None) -> bool:
"""
通过社交分享重置免费使用次数
Args:
user_id: 用户ID
service_types: 要重置的服务类型列表None表示重置所有服务
Returns:
是否成功重置
"""
if service_types is None:
service_types = [IMAGE_RECOGNITION_SERVICE, SPEECH_ASSESSMENT_SERVICE]
# 记录重置前的状态
previous_counts = {}
today = date.today().isoformat()
async with async_db_session.begin() as db:
for service_type in service_types:
usage_key = RateLimitService._get_daily_usage_key(user_id, service_type)
current_usage = await redis_client.get(usage_key)
previous_counts[service_type] = int(current_usage) if current_usage else 0
# 重置Redis中的使用计数
for service_type in service_types:
usage_key = RateLimitService._get_daily_usage_key(user_id, service_type)
await redis_client.delete(usage_key)
# 记录重置事件
async with async_db_session.begin() as db:
reset_log = await usage_reset_log_dao.record_usage_reset(
db, user_id, "sharing", service_types, previous_counts
)
await db.flush()
return True
@staticmethod
async def record_sharing_event(user_id: int, platform: str, success: bool = True) -> bool:
"""
记录社交分享事件
Args:
user_id: 用户ID
platform: 分享平台
success: 分享是否成功
Returns:
是否成功记录
"""
async with async_db_session.begin() as db:
sharing_event = await sharing_event_dao.record_sharing_event(
db, user_id, platform, success
)
await db.flush()
# 如果分享成功,重置使用次数
if success:
await RateLimitService.reset_free_usage_by_sharing(user_id)
return True
@staticmethod
async def increment_free_usage_count(user_id: int, service_type: str) -> bool:
"""
增加用户免费使用次数
Args:
user_id: 用户ID
service_type: 服务类型
Returns:
是否成功增加
"""
today = date.today().isoformat()
# 增加Redis中的使用计数
usage_key = RateLimitService._get_daily_usage_key(user_id, service_type)
await redis_client.incr(usage_key)
# 设置过期时间到今天结束
now = datetime.now()
end_of_day = datetime(now.year, now.month, now.day, 23, 59, 59)
expire_seconds = int((end_of_day - now).total_seconds())
await redis_client.expire(usage_key, expire_seconds)
return True
@staticmethod
async def get_user_free_usage_status(user_id: int) -> dict:
"""
获取用户免费使用状态
Args:
user_id: 用户ID
Returns:
用户免费使用状态信息
"""
services = [IMAGE_RECOGNITION_SERVICE, SPEECH_ASSESSMENT_SERVICE]
status = {}
for service in services:
usage_key = RateLimitService._get_daily_usage_key(user_id, service)
current_usage = await redis_client.get(usage_key)
status[service] = {
"current_usage": int(current_usage) if current_usage else 0,
"max_usage": 3 # 或从配置中获取
}
return status
@staticmethod
async def increment_daily_usage(user_id: int, service_type: str, max_usage: int = 3) -> bool:
"""
增加用户每日使用次数
Args:
user_id: 用户ID
service_type: 服务类型
Returns:
是否成功增加(未超过限制)
"""
# 使用新的免费使用次数跟踪系统
await RateLimitService.increment_free_usage_count(user_id, service_type)
return True
@staticmethod
async def reset_daily_usage(user_id: int) -> bool:
"""
重置用户每日使用次数
Args:
user_id: 用户ID
recording_id: 录音ID可选用于重置特定录音的语音评测次数
Returns:
是否成功重置
"""
# 检查今日重置次数限制
reset_key = RateLimitService._get_daily_reset_key(user_id)
current_resets = await redis_client.get(reset_key)
if current_resets is not None and int(current_resets) >= 1:
# 今日已达到重置次数限制
return False
# 增加重置次数
await redis_client.incr(reset_key)
# 设置过期时间到今天结束
now = datetime.now()
end_of_day = datetime(now.year, now.month, now.day, 23, 59, 59)
expire_seconds = int((end_of_day - now).total_seconds())
await redis_client.expire(reset_key, expire_seconds)
# 重置图像识别使用次数
image_usage_key = RateLimitService._get_daily_usage_key(user_id, IMAGE_RECOGNITION_SERVICE)
await redis_client.delete(image_usage_key)
# 重置所有语音评测次数(匹配模式)
pattern = RateLimitService._get_daily_usage_key(user_id, f"{SPEECH_ASSESSMENT_SERVICE}_*")
async for key in redis_client.scan_iter(match=pattern):
await redis_client.delete(key)
return True
# 创建单例实例
rate_limit_service = RateLimitService()