fix code
This commit is contained in:
5
.gitignore
vendored
5
.gitignore
vendored
@@ -6,5 +6,10 @@ backend/.env
|
||||
backend/alembic/versions/
|
||||
.venv/
|
||||
.vscode/
|
||||
.cloudbase/
|
||||
.history/
|
||||
.trae/
|
||||
*.log
|
||||
.ruff_cache/
|
||||
backend/.DS_Store
|
||||
assets/
|
||||
@@ -3,5 +3,4 @@
|
||||
|
||||
from backend.app.admin.crud.file_crud import file_dao
|
||||
from backend.app.admin.crud.daily_summary_crud import daily_summary_dao
|
||||
from backend.app.admin.crud.points_crud import points_dao, points_log_dao
|
||||
from backend.app.admin.crud.free_usage_crud import sharing_event_dao, usage_reset_log_dao
|
||||
from backend.app.admin.crud.points_crud import points_dao, points_log_dao
|
||||
@@ -1,48 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Optional, List
|
||||
from sqlalchemy import select, update, delete, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy_crud_plus import CRUDPlus
|
||||
|
||||
from backend.app.admin.model.order import SharingEvent, UsageResetLog
|
||||
|
||||
|
||||
class SharingEventDao(CRUDPlus[SharingEvent]):
|
||||
async def record_sharing_event(
|
||||
self, db: AsyncSession, user_id: int, platform: str, success: bool = True
|
||||
) -> SharingEvent:
|
||||
"""
|
||||
记录社交分享事件
|
||||
"""
|
||||
sharing_event = self.model(
|
||||
user_id=user_id,
|
||||
platform=platform,
|
||||
success=success,
|
||||
reset_usage=success
|
||||
)
|
||||
db.add(sharing_event)
|
||||
return sharing_event
|
||||
|
||||
|
||||
class UsageResetLogDao(CRUDPlus[UsageResetLog]):
|
||||
async def record_usage_reset(
|
||||
self, db: AsyncSession, user_id: int, reset_type: str, service_types: List[str], previous_counts: dict
|
||||
) -> UsageResetLog:
|
||||
"""
|
||||
记录使用次数重置事件
|
||||
"""
|
||||
reset_log = self.model(
|
||||
user_id=user_id,
|
||||
reset_type=reset_type,
|
||||
service_types=service_types,
|
||||
previous_counts=previous_counts
|
||||
)
|
||||
db.add(reset_log)
|
||||
return reset_log
|
||||
|
||||
|
||||
# 创建DAO实例
|
||||
sharing_event_dao = SharingEventDao(SharingEvent)
|
||||
usage_reset_log_dao = UsageResetLogDao(UsageResetLog)
|
||||
@@ -1,38 +0,0 @@
|
||||
from typing import Optional
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy_crud_plus import CRUDPlus
|
||||
from backend.app.admin.model.order import FreezeLog
|
||||
|
||||
|
||||
class FreezeLogDao(CRUDPlus[FreezeLog]):
|
||||
|
||||
async def get_by_id(self, db: AsyncSession, freeze_id: int) -> Optional[FreezeLog]:
|
||||
"""
|
||||
根据ID获取冻结记录
|
||||
"""
|
||||
stmt = select(FreezeLog).where(FreezeLog.id == freeze_id)
|
||||
result = await db.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_order_id(self, db: AsyncSession, order_id: int) -> Optional[FreezeLog]:
|
||||
"""
|
||||
根据订单ID获取冻结记录
|
||||
"""
|
||||
stmt = select(FreezeLog).where(FreezeLog.order_id == order_id)
|
||||
result = await db.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_pending_by_user(self, db: AsyncSession, user_id: int) -> list[FreezeLog]:
|
||||
"""
|
||||
获取用户所有待处理的冻结记录
|
||||
"""
|
||||
stmt = select(FreezeLog).where(
|
||||
FreezeLog.user_id == user_id,
|
||||
FreezeLog.status == "pending"
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
freeze_log_dao = FreezeLogDao(FreezeLog)
|
||||
@@ -1,61 +0,0 @@
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy import select, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from backend.app.admin.model.order import UsageLog
|
||||
from sqlalchemy_crud_plus import CRUDPlus
|
||||
|
||||
|
||||
class UsageLogDao(CRUDPlus[UsageLog]):
|
||||
|
||||
async def get_by_id(self, db: AsyncSession, log_id: int) -> Optional[UsageLog]:
|
||||
"""
|
||||
根据ID获取使用日志
|
||||
"""
|
||||
stmt = select(UsageLog).where(UsageLog.id == log_id)
|
||||
result = await db.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_user_id(self, db: AsyncSession, user_id: int, limit: int = 100) -> List[UsageLog]:
|
||||
"""
|
||||
根据用户ID获取使用日志列表
|
||||
"""
|
||||
stmt = select(UsageLog).where(
|
||||
UsageLog.user_id == user_id
|
||||
).order_by(UsageLog.created_at.desc()).limit(limit)
|
||||
result = await db.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_by_action(self, db: AsyncSession, user_id: int, action: str, limit: int = 50) -> List[UsageLog]:
|
||||
"""
|
||||
根据动作类型获取使用日志
|
||||
"""
|
||||
stmt = select(UsageLog).where(
|
||||
and_(
|
||||
UsageLog.user_id == user_id,
|
||||
UsageLog.action == action
|
||||
)
|
||||
).order_by(UsageLog.created_at.desc()).limit(limit)
|
||||
result = await db.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_balance_history(self, db: AsyncSession, user_id: int, limit: int = 100) -> List[UsageLog]:
|
||||
"""
|
||||
获取用户余额变动历史
|
||||
"""
|
||||
stmt = select(UsageLog).where(
|
||||
UsageLog.user_id == user_id
|
||||
).order_by(UsageLog.created_at.desc()).limit(limit)
|
||||
result = await db.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
async def add_log(self, db: AsyncSession, log_data: Dict[str, Any]) -> UsageLog:
|
||||
"""
|
||||
添加使用日志
|
||||
"""
|
||||
log = UsageLog(**log_data)
|
||||
db.add(log)
|
||||
await db.flush()
|
||||
return log
|
||||
|
||||
|
||||
usage_log_dao = UsageLogDao(UsageLog)
|
||||
@@ -4,9 +4,6 @@ from backend.common.model import MappedBase # noqa: I
|
||||
from backend.app.admin.model.wx_user import WxUser
|
||||
from backend.app.admin.model.audit_log import AuditLog, DailySummary
|
||||
from backend.app.admin.model.file import File
|
||||
# from backend.app.admin.model.dict import DictionaryEntry, DictionaryMedia
|
||||
from backend.app.admin.model.order import UsageLog, SharingEvent, UsageResetLog
|
||||
from backend.app.admin.model.coupon import Coupon, CouponUsage
|
||||
# from backend.app.admin.model.notification import Notification, UserNotification
|
||||
from backend.app.admin.model.points import Points, PointsLog
|
||||
from backend.app.ai.model import Image
|
||||
@@ -2,4 +2,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from backend.app.admin.service.points_service import PointsService
|
||||
from backend.app.admin.service.sharing_service import sharing_service
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
import aiofiles
|
||||
@@ -11,7 +12,7 @@ from backend.core.conf import settings
|
||||
|
||||
class StorageProvider(ABC):
|
||||
"""存储提供者抽象基类"""
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def save(self, file_id: int, content: bytes, file_name: str) -> str:
|
||||
"""保存文件"""
|
||||
@@ -27,10 +28,20 @@ class StorageProvider(ABC):
|
||||
"""删除文件"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def compress_image(self, file_id: int) -> str:
|
||||
"""压缩图片并返回压缩后路径或键"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def object_url(self, cos_key: str) -> str:
|
||||
"""根据存储键返回可访问的URL或路径"""
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseStorage(StorageProvider):
|
||||
"""数据库存储提供者"""
|
||||
|
||||
|
||||
async def save(self, file_id: int, content: bytes, file_name: str) -> str:
|
||||
"""数据库存储不需要实际保存文件,直接返回空字符串"""
|
||||
return ""
|
||||
@@ -43,10 +54,18 @@ class DatabaseStorage(StorageProvider):
|
||||
"""数据库存储不需要删除文件"""
|
||||
return True
|
||||
|
||||
async def compress_image(self, file_id: int) -> str:
|
||||
"""数据库存储不涉及文件压缩,返回空字符串"""
|
||||
return ""
|
||||
|
||||
async def object_url(self, cos_key: str) -> str:
|
||||
"""数据库存储不提供URL,返回空字符串"""
|
||||
return ""
|
||||
|
||||
|
||||
class LocalStorage(StorageProvider):
|
||||
"""本地文件系统存储提供者"""
|
||||
|
||||
|
||||
def __init__(self, base_path: str = settings.STORAGE_PATH):
|
||||
self.base_path = base_path
|
||||
|
||||
@@ -78,16 +97,27 @@ class LocalStorage(StorageProvider):
|
||||
except:
|
||||
return False
|
||||
|
||||
async def compress_image(self, file_id: int) -> str:
|
||||
"""本地压缩占位实现,返回预期压缩文件路径"""
|
||||
dir_name = str(file_id // 1000)
|
||||
file_dir = os.path.join(self.base_path, dir_name)
|
||||
os.makedirs(file_dir, exist_ok=True)
|
||||
return os.path.join(file_dir, f"{file_id}.heif")
|
||||
|
||||
async def object_url(self, cos_key: str) -> str:
|
||||
"""本地存储返回本地路径作为访问地址"""
|
||||
return cos_key
|
||||
|
||||
|
||||
class CosStorage(StorageProvider):
|
||||
"""腾讯云COS存储提供者"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
"""初始化COS客户端"""
|
||||
import logging
|
||||
# Reduce verbosity of COS client logging
|
||||
logging.getLogger('qcloud_cos').setLevel(logging.WARNING)
|
||||
|
||||
|
||||
secret_id = settings.COS_SECRET_ID
|
||||
secret_key = settings.COS_SECRET_KEY
|
||||
region = settings.COS_REGION
|
||||
@@ -104,15 +134,31 @@ class CosStorage(StorageProvider):
|
||||
async def save(self, file_id: int, content: bytes, file_name: str) -> str:
|
||||
"""保存文件到COS"""
|
||||
key = self._get_key(file_id)
|
||||
|
||||
# res = compress_image(2106103703009361920)
|
||||
# 上传文件到COS
|
||||
response = self.client.put_object(
|
||||
Bucket=self.bucket,
|
||||
Body=content,
|
||||
Key=key,
|
||||
StorageClass='STANDARD',
|
||||
EnableMD5=False
|
||||
)
|
||||
pic_ops = {
|
||||
"is_pic_info": 1,
|
||||
"rules": [
|
||||
{
|
||||
"bucket": self.bucket,
|
||||
"fileid": f"{key}",
|
||||
"rule": "imageMogr2/format/avif",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.client.put_object(
|
||||
Bucket=self.bucket,
|
||||
Body=content,
|
||||
Key=key,
|
||||
StorageClass='STANDARD',
|
||||
EnableMD5=False,
|
||||
PicOperations=json.dumps(pic_ops)
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"详细错误: {e}")
|
||||
print(f"错误类型: {type(e)}")
|
||||
|
||||
# 返回存储路径(即对象键)
|
||||
return key
|
||||
@@ -139,6 +185,24 @@ class CosStorage(StorageProvider):
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def compress_image(self, file_id: int) -> str:
|
||||
key = self._get_key(file_id)
|
||||
res_path=f"{key}.heif"
|
||||
response = self.client.ci_download_compress_image(
|
||||
Bucket=self.bucket,
|
||||
Key=key,
|
||||
DestImagePath=res_path,
|
||||
CompressType='avif'
|
||||
)
|
||||
return res_path
|
||||
|
||||
async def object_url(self, cos_key: str) -> str:
|
||||
url = self.client.get_object_url(
|
||||
Bucket=self.bucket,
|
||||
Key=cos_key,
|
||||
)
|
||||
return url
|
||||
|
||||
|
||||
def get_storage_provider(provider_type: str) -> StorageProvider:
|
||||
"""根据配置获取存储提供者"""
|
||||
@@ -152,4 +216,4 @@ def get_storage_provider(provider_type: str) -> StorageProvider:
|
||||
|
||||
def calculate_file_hash(content: bytes) -> str:
|
||||
"""计算文件的SHA256哈希值"""
|
||||
return hashlib.sha256(content).hexdigest()
|
||||
return hashlib.sha256(content).hexdigest()
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import List
|
||||
from datetime import datetime
|
||||
|
||||
from backend.app.admin.crud.free_usage_crud import sharing_event_dao, usage_reset_log_dao
|
||||
from backend.database.db import async_db_session
|
||||
from backend.app.ai.service.rate_limit_service import RateLimitService
|
||||
|
||||
|
||||
class SharingService:
|
||||
"""社交分享服务类"""
|
||||
|
||||
@staticmethod
|
||||
async def grant_times_by_sharing(user_id: int, platform: str) -> bool:
|
||||
"""
|
||||
通过社交分享授予使用次数
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
platform: 分享平台
|
||||
|
||||
Returns:
|
||||
是否成功授予
|
||||
"""
|
||||
# 记录分享事件
|
||||
await RateLimitService.record_sharing_event(user_id, platform, success=True)
|
||||
|
||||
# 重置免费使用次数
|
||||
await RateLimitService.reset_free_usage_by_sharing(user_id)
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def get_user_sharing_history(user_id: int, limit: int = 10) -> List[dict]:
|
||||
"""
|
||||
获取用户分享历史
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
limit: 返回记录数量限制
|
||||
|
||||
Returns:
|
||||
分享历史记录列表
|
||||
"""
|
||||
async with async_db_session() as db:
|
||||
# 获取用户的分享事件记录
|
||||
sharing_events = await sharing_event_dao.get_list(
|
||||
db,
|
||||
limit=limit,
|
||||
user_id=user_id,
|
||||
order_by="created_at DESC"
|
||||
)
|
||||
|
||||
history = []
|
||||
for event in sharing_events:
|
||||
history.append({
|
||||
"id": event.id,
|
||||
"platform": event.platform,
|
||||
"success": event.success,
|
||||
"reset_usage": event.reset_usage,
|
||||
"created_at": event.created_at.isoformat() if event.created_at else None
|
||||
})
|
||||
|
||||
return history
|
||||
|
||||
@staticmethod
|
||||
async def get_user_reset_history(user_id: int, limit: int = 10) -> List[dict]:
|
||||
"""
|
||||
获取用户使用次数重置历史
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
limit: 返回记录数量限制
|
||||
|
||||
Returns:
|
||||
重置历史记录列表
|
||||
"""
|
||||
async with async_db_session() as db:
|
||||
# 获取用户的重置记录
|
||||
reset_logs = await usage_reset_log_dao.get_list(
|
||||
db,
|
||||
limit=limit,
|
||||
user_id=user_id,
|
||||
order_by="reset_at DESC"
|
||||
)
|
||||
|
||||
history = []
|
||||
for log in reset_logs:
|
||||
history.append({
|
||||
"id": log.id,
|
||||
"reset_type": log.reset_type,
|
||||
"service_types": log.service_types,
|
||||
"previous_counts": log.previous_counts,
|
||||
"reset_at": log.reset_at.isoformat() if log.reset_at else None
|
||||
})
|
||||
|
||||
return history
|
||||
|
||||
|
||||
# 创建单例实例
|
||||
sharing_service = SharingService()
|
||||
@@ -1,166 +0,0 @@
|
||||
from decimal import Decimal, ROUND_HALF_UP
|
||||
from backend.app.admin.crud.usage_log_crud import usage_log_dao
|
||||
from backend.app.admin.crud.order_crud import order_dao
|
||||
from backend.app.admin.model.order import Order
|
||||
from backend.app.admin.model.points import Points
|
||||
from backend.app.admin.schema.usage import PurchaseRequest
|
||||
from backend.app.admin.service.points_service import PointsService
|
||||
from backend.common.exception import errors
|
||||
from backend.database.db import async_db_session
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy import func, select, update
|
||||
|
||||
|
||||
class UsageService:
|
||||
|
||||
@staticmethod
|
||||
async def get_user_account(user_id: int) -> dict:
|
||||
return await PointsService.get_user_account_details(user_id)
|
||||
|
||||
@staticmethod
|
||||
def calculate_purchase_times_safe(amount_cents: int) -> int:
|
||||
"""
|
||||
安全的充值次数计算(使用Decimal避免浮点数精度问题)
|
||||
"""
|
||||
if amount_cents <= 0:
|
||||
raise ValueError("充值金额必须大于0")
|
||||
|
||||
# 限制最大充值金额(防止溢出)
|
||||
if amount_cents > 10000000: # 10万元
|
||||
raise ValueError("单次充值金额不能超过10万元")
|
||||
|
||||
amount_yuan = Decimal(amount_cents) / Decimal(100)
|
||||
base_times = amount_yuan * Decimal(10)
|
||||
|
||||
# 计算优惠比例(每10元增加10%,最多100%)
|
||||
tens = (amount_yuan // Decimal(10))
|
||||
bonus_percent = min(tens * Decimal('0.1'), Decimal('1.0'))
|
||||
|
||||
total_times = base_times * (Decimal('1') + bonus_percent)
|
||||
return int(total_times.quantize(Decimal('1'), rounding=ROUND_HALF_UP))
|
||||
|
||||
@staticmethod
|
||||
async def purchase_times(user_id: int, request: PurchaseRequest):
|
||||
# 输入验证
|
||||
if request.amount_cents < 100: # 最少1元
|
||||
raise errors.RequestError(msg="充值金额不能少于1元")
|
||||
if request.amount_cents > 10000000: # 最多10万元
|
||||
raise errors.RequestError(msg="单次充值金额不能超过10万元")
|
||||
|
||||
async with async_db_session.begin() as db:
|
||||
# For now, we'll need to get the points account directly since we don't have a full replacement
|
||||
# This is a temporary solution until we can fully refactor this method
|
||||
from backend.app.admin.crud.points_crud import points_dao
|
||||
points_account = await points_dao.get_by_user_id(db, user_id)
|
||||
if not points_account:
|
||||
points_account = await points_dao.create_user_points(db, user_id)
|
||||
|
||||
account = await UsageService.get_user_account(user_id)
|
||||
times = UsageService.calculate_purchase_times_safe(request.amount_cents)
|
||||
|
||||
order = Order(
|
||||
user_id=user_id,
|
||||
order_type="purchase",
|
||||
amount_cents=request.amount_cents,
|
||||
amount_times=times,
|
||||
status="pending"
|
||||
)
|
||||
await order_dao.add(db, order)
|
||||
await db.flush() # 获取order.id
|
||||
|
||||
# 原子性更新账户(防止并发问题)
|
||||
# 如果用户之前不是订阅用户,现在购买了次数,将其标记为订阅用户
|
||||
# We need to update the points account directly
|
||||
result = await db.execute(
|
||||
update(Points)
|
||||
.where(Points.id == points_account.id)
|
||||
.values(
|
||||
balance=Points.balance + times,
|
||||
total_earned=Points.total_earned + times,
|
||||
is_subscribed=True # Always set to subscribed when purchasing
|
||||
)
|
||||
)
|
||||
|
||||
if result.rowcount == 0:
|
||||
raise errors.ServerError(msg="账户更新失败")
|
||||
|
||||
# 更新订单状态
|
||||
order.status = "completed"
|
||||
order.processed_at = datetime.now()
|
||||
await order_dao.update(db, order.id, order)
|
||||
|
||||
await usage_log_dao.add(db, {
|
||||
"user_id": user_id,
|
||||
"action": "purchase",
|
||||
"amount": times,
|
||||
"balance_after": points_account.balance + times,
|
||||
"related_id": order.id,
|
||||
"details": {"amount_cents": request.amount_cents}
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
async def update_subscription_status(user_id: int) -> bool:
|
||||
"""
|
||||
更新用户的订阅状态
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
bool: 是否为订阅用户
|
||||
"""
|
||||
return await PointsService.update_subscription_status(user_id)
|
||||
|
||||
@staticmethod
|
||||
async def use_times_atomic(user_id: int, count: int = 1):
|
||||
"""
|
||||
原子性扣减次数,支持免费试用优先使用
|
||||
"""
|
||||
if count <= 0:
|
||||
raise ValueError("扣减次数必须大于0")
|
||||
|
||||
# Use the points service to deduct points/balance
|
||||
# This is a simplified implementation - in a full refactor, we would need to implement
|
||||
# the free trial logic in the points service as well
|
||||
from backend.app.admin.crud.points_crud import points_dao
|
||||
async with async_db_session.begin() as db:
|
||||
points_account = await points_dao.get_by_user_id(db, user_id)
|
||||
if not points_account:
|
||||
raise errors.ForbiddenError(msg="用户账户不存在")
|
||||
|
||||
# For now, we'll just deduct from the balance directly
|
||||
# A full implementation would need to handle free trial logic as well
|
||||
if points_account.balance < count:
|
||||
raise errors.ForbiddenError(msg="余额不足")
|
||||
|
||||
result = await db.execute(
|
||||
update(Points)
|
||||
.where(Points.id == points_account.id)
|
||||
.values(balance=Points.balance - count)
|
||||
)
|
||||
|
||||
if result.rowcount == 0:
|
||||
raise errors.ForbiddenError(msg="余额不足")
|
||||
|
||||
# Update the points account object for logging
|
||||
points_account.balance -= count
|
||||
|
||||
# 记录使用日志
|
||||
await usage_log_dao.add(db, {
|
||||
"user_id": user_id,
|
||||
"action": "use",
|
||||
"amount": -count,
|
||||
"balance_after": points_account.balance,
|
||||
"metadata_": {
|
||||
"used_at": datetime.now().isoformat(),
|
||||
"is_free_trial": False # Simplified implementation
|
||||
}
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
async def get_account_info(user_id: int) -> dict:
|
||||
"""
|
||||
获取用户账户详细信息
|
||||
"""
|
||||
# Use the points service method which already provides this functionality
|
||||
return await PointsService.get_user_account_details(user_id)
|
||||
@@ -10,7 +10,6 @@ from backend.app.ai.crud.image_task_crud import image_task_dao
|
||||
from backend.database.db import async_db_session
|
||||
from backend.app.ai.model.image_task import ImageTaskStatus
|
||||
from backend.app.admin.service.points_service import PointsService
|
||||
from backend.app.admin.crud.free_usage_crud import sharing_event_dao, usage_reset_log_dao
|
||||
from datetime import date
|
||||
|
||||
|
||||
@@ -197,178 +196,6 @@ class RateLimitService:
|
||||
"""
|
||||
await RateLimitService.decrement_user_task_count(user_id)
|
||||
|
||||
@staticmethod
|
||||
async def reset_free_usage_by_sharing(user_id: int, service_types: list = None) -> bool:
|
||||
"""
|
||||
通过社交分享重置免费使用次数
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
service_types: 要重置的服务类型列表,None表示重置所有服务
|
||||
|
||||
Returns:
|
||||
是否成功重置
|
||||
"""
|
||||
if service_types is None:
|
||||
service_types = [IMAGE_RECOGNITION_SERVICE, SPEECH_ASSESSMENT_SERVICE]
|
||||
|
||||
# 记录重置前的状态
|
||||
previous_counts = {}
|
||||
today = date.today().isoformat()
|
||||
|
||||
async with async_db_session.begin() as db:
|
||||
for service_type in service_types:
|
||||
usage_key = RateLimitService._get_daily_usage_key(user_id, service_type)
|
||||
current_usage = await redis_client.get(usage_key)
|
||||
previous_counts[service_type] = int(current_usage) if current_usage else 0
|
||||
|
||||
# 重置Redis中的使用计数
|
||||
for service_type in service_types:
|
||||
usage_key = RateLimitService._get_daily_usage_key(user_id, service_type)
|
||||
await redis_client.delete(usage_key)
|
||||
|
||||
# 记录重置事件
|
||||
async with async_db_session.begin() as db:
|
||||
reset_log = await usage_reset_log_dao.record_usage_reset(
|
||||
db, user_id, "sharing", service_types, previous_counts
|
||||
)
|
||||
await db.flush()
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def record_sharing_event(user_id: int, platform: str, success: bool = True) -> bool:
|
||||
"""
|
||||
记录社交分享事件
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
platform: 分享平台
|
||||
success: 分享是否成功
|
||||
|
||||
Returns:
|
||||
是否成功记录
|
||||
"""
|
||||
async with async_db_session.begin() as db:
|
||||
sharing_event = await sharing_event_dao.record_sharing_event(
|
||||
db, user_id, platform, success
|
||||
)
|
||||
await db.flush()
|
||||
|
||||
# 如果分享成功,重置使用次数
|
||||
if success:
|
||||
await RateLimitService.reset_free_usage_by_sharing(user_id)
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def increment_free_usage_count(user_id: int, service_type: str) -> bool:
|
||||
"""
|
||||
增加用户免费使用次数
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
service_type: 服务类型
|
||||
|
||||
Returns:
|
||||
是否成功增加
|
||||
"""
|
||||
today = date.today().isoformat()
|
||||
|
||||
# 增加Redis中的使用计数
|
||||
usage_key = RateLimitService._get_daily_usage_key(user_id, service_type)
|
||||
await redis_client.incr(usage_key)
|
||||
|
||||
# 设置过期时间到今天结束
|
||||
now = datetime.now()
|
||||
end_of_day = datetime(now.year, now.month, now.day, 23, 59, 59)
|
||||
expire_seconds = int((end_of_day - now).total_seconds())
|
||||
await redis_client.expire(usage_key, expire_seconds)
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def get_user_free_usage_status(user_id: int) -> dict:
|
||||
"""
|
||||
获取用户免费使用状态
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
用户免费使用状态信息
|
||||
"""
|
||||
services = [IMAGE_RECOGNITION_SERVICE, SPEECH_ASSESSMENT_SERVICE]
|
||||
status = {}
|
||||
|
||||
for service in services:
|
||||
usage_key = RateLimitService._get_daily_usage_key(user_id, service)
|
||||
current_usage = await redis_client.get(usage_key)
|
||||
status[service] = {
|
||||
"current_usage": int(current_usage) if current_usage else 0,
|
||||
"max_usage": 3 # 或从配置中获取
|
||||
}
|
||||
|
||||
return status
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def increment_daily_usage(user_id: int, service_type: str, max_usage: int = 3) -> bool:
|
||||
"""
|
||||
增加用户每日使用次数
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
service_type: 服务类型
|
||||
|
||||
Returns:
|
||||
是否成功增加(未超过限制)
|
||||
"""
|
||||
|
||||
# 使用新的免费使用次数跟踪系统
|
||||
await RateLimitService.increment_free_usage_count(user_id, service_type)
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def reset_daily_usage(user_id: int) -> bool:
|
||||
"""
|
||||
重置用户每日使用次数
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
recording_id: 录音ID(可选,用于重置特定录音的语音评测次数)
|
||||
|
||||
Returns:
|
||||
是否成功重置
|
||||
"""
|
||||
# 检查今日重置次数限制
|
||||
reset_key = RateLimitService._get_daily_reset_key(user_id)
|
||||
current_resets = await redis_client.get(reset_key)
|
||||
|
||||
if current_resets is not None and int(current_resets) >= 1:
|
||||
# 今日已达到重置次数限制
|
||||
return False
|
||||
|
||||
# 增加重置次数
|
||||
await redis_client.incr(reset_key)
|
||||
|
||||
# 设置过期时间到今天结束
|
||||
now = datetime.now()
|
||||
end_of_day = datetime(now.year, now.month, now.day, 23, 59, 59)
|
||||
expire_seconds = int((end_of_day - now).total_seconds())
|
||||
await redis_client.expire(reset_key, expire_seconds)
|
||||
|
||||
# 重置图像识别使用次数
|
||||
image_usage_key = RateLimitService._get_daily_usage_key(user_id, IMAGE_RECOGNITION_SERVICE)
|
||||
await redis_client.delete(image_usage_key)
|
||||
# 重置所有语音评测次数(匹配模式)
|
||||
pattern = RateLimitService._get_daily_usage_key(user_id, f"{SPEECH_ASSESSMENT_SERVICE}_*")
|
||||
async for key in redis_client.scan_iter(match=pattern):
|
||||
await redis_client.delete(key)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# 创建单例实例
|
||||
rate_limit_service = RateLimitService()
|
||||
|
||||
Reference in New Issue
Block a user