change points model and service
This commit is contained in:
@@ -5,6 +5,7 @@ from sqlalchemy import select, update, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy_crud_plus import CRUDPlus
|
||||
from backend.app.admin.model.points import Points, PointsLog
|
||||
from backend.common.const import FREE_TRIAL_BALANCE, POINTS_ACTION_SYSTEM_GIFT
|
||||
|
||||
|
||||
class PointsDao(CRUDPlus[Points]):
|
||||
@@ -25,7 +26,7 @@ class PointsDao(CRUDPlus[Points]):
|
||||
await db.flush()
|
||||
return points
|
||||
|
||||
async def add_points_atomic(self, db: AsyncSession, user_id: int, amount: int, extend_expiration: bool = False) -> bool:
|
||||
async def add_points_atomic(self, db: AsyncSession, user_id: int, amount: int) -> bool:
|
||||
"""
|
||||
原子性增加用户积分
|
||||
"""
|
||||
@@ -40,30 +41,12 @@ class PointsDao(CRUDPlus[Points]):
|
||||
"total_earned": Points.total_earned + amount
|
||||
}
|
||||
|
||||
# 如果需要延期,则更新过期时间
|
||||
if extend_expiration:
|
||||
update_values["expired_time"] = datetime.now() + timedelta(days=30)
|
||||
|
||||
stmt = update(Points).where(
|
||||
Points.user_id == user_id
|
||||
).values(**update_values)
|
||||
result = await db.execute(stmt)
|
||||
return result.rowcount > 0
|
||||
|
||||
async def deduct_points_atomic(self, db: AsyncSession, user_id: int, amount: int) -> bool:
|
||||
"""
|
||||
原子性扣减用户积分(确保不超扣)
|
||||
"""
|
||||
stmt = update(Points).where(
|
||||
Points.user_id == user_id,
|
||||
Points.balance >= amount
|
||||
).values(
|
||||
balance=Points.balance - amount,
|
||||
total_spent=Points.total_spent + amount
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
return result.rowcount > 0
|
||||
|
||||
async def get_balance(self, db: AsyncSession, user_id: int) -> int:
|
||||
"""
|
||||
获取用户积分余额
|
||||
@@ -73,50 +56,33 @@ class PointsDao(CRUDPlus[Points]):
|
||||
return 0
|
||||
return points_account.balance
|
||||
|
||||
async def check_and_clear_expired_points(self, db: AsyncSession, user_id: int) -> bool:
|
||||
"""
|
||||
检查并清空过期积分
|
||||
"""
|
||||
stmt = update(Points).where(
|
||||
Points.user_id == user_id,
|
||||
Points.expired_time < datetime.now(),
|
||||
Points.balance > 0
|
||||
).values(
|
||||
balance=0,
|
||||
total_spent=Points.total_spent + Points.balance
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
return result.rowcount > 0
|
||||
|
||||
async def create_new_user_account(self, db: AsyncSession, user_id: int) -> Points:
|
||||
"""
|
||||
为新用户创建账户(包含免费试用)
|
||||
为新用户创建账户(根据新需求,直接在balance中增加初始积分)
|
||||
"""
|
||||
# 设置免费试用期(3天)
|
||||
trial_expires_at = datetime.now() + timedelta(days=3)
|
||||
|
||||
# 为新用户创建积分账户,初始赠送30积分
|
||||
initial_points = FREE_TRIAL_BALANCE # 使用原来的常量作为初始积分
|
||||
|
||||
points = Points(
|
||||
user_id=user_id,
|
||||
balance=30, # 初始30次免费次数
|
||||
free_trial_balance=30,
|
||||
free_trial_expires_at=trial_expires_at,
|
||||
free_trial_used=True # 标记为已使用(因为已经给了30次)
|
||||
balance=initial_points, # 直接在balance中增加初始积分
|
||||
total_earned=initial_points # 同时更新累计获得积分
|
||||
)
|
||||
db.add(points)
|
||||
await db.flush()
|
||||
return points
|
||||
|
||||
async def update_balance_atomic(self, db: AsyncSession, user_id: int, amount: int) -> bool:
|
||||
"""
|
||||
原子性更新用户余额
|
||||
"""
|
||||
stmt = update(Points).where(
|
||||
Points.user_id == user_id
|
||||
).values(
|
||||
balance=Points.balance + amount
|
||||
|
||||
# 创建积分变动日志,记录新用户获赠积分
|
||||
log = PointsLog(
|
||||
user_id=user_id,
|
||||
action=POINTS_ACTION_SYSTEM_GIFT,
|
||||
amount=initial_points,
|
||||
balance_after=initial_points,
|
||||
details={"message": "新用户注册赠送积分"}
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
return result.rowcount > 0
|
||||
db.add(log)
|
||||
await db.flush()
|
||||
|
||||
return points
|
||||
|
||||
async def deduct_balance_atomic(self, db: AsyncSession, user_id: int, amount: int) -> bool:
|
||||
"""
|
||||
@@ -126,49 +92,12 @@ class PointsDao(CRUDPlus[Points]):
|
||||
Points.user_id == user_id,
|
||||
Points.balance >= amount
|
||||
).values(
|
||||
balance=Points.balance - amount
|
||||
balance=Points.balance - amount,
|
||||
total_spent=Points.total_spent + amount
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
return result.rowcount > 0
|
||||
|
||||
async def get_frozen_balance(self, db: AsyncSession, user_id: int) -> int:
|
||||
"""
|
||||
获取用户被冻结的次数
|
||||
"""
|
||||
# This is a placeholder since we don't have a FreezeLog model in the points context yet
|
||||
# In a full implementation, this would query a freeze log table
|
||||
return 0
|
||||
|
||||
async def get_available_balance(self, db: AsyncSession, user_id: int) -> int:
|
||||
"""
|
||||
获取用户可用余额(总余额减去冻结余额)
|
||||
"""
|
||||
points_account = await self.get_by_user_id(db, user_id)
|
||||
if not points_account:
|
||||
return 0
|
||||
|
||||
frozen_balance = await self.get_frozen_balance(db, user_id)
|
||||
return max(0, points_account.balance - frozen_balance)
|
||||
|
||||
async def check_free_trial_valid(self, db: AsyncSession, user_id: int) -> bool:
|
||||
"""
|
||||
检查用户免费试用是否仍然有效
|
||||
"""
|
||||
points_account = await self.get_by_user_id(db, user_id)
|
||||
if not points_account or not points_account.free_trial_expires_at:
|
||||
return False
|
||||
|
||||
return (points_account.free_trial_expires_at > datetime.now() and
|
||||
points_account.free_trial_balance > 0)
|
||||
|
||||
async def update(self, db: AsyncSession, points_id: int, update_data: dict) -> bool:
|
||||
"""
|
||||
更新积分账户信息
|
||||
"""
|
||||
stmt = update(Points).where(Points.id == points_id).values(**update_data)
|
||||
result = await db.execute(stmt)
|
||||
return result.rowcount > 0
|
||||
|
||||
|
||||
class PointsLogDao(CRUDPlus[PointsLog]):
|
||||
async def add_log(self, db: AsyncSession, log_data: Dict[str, Any]) -> PointsLog:
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import String, BigInteger, DateTime, func, Index, Boolean
|
||||
from sqlalchemy import String, BigInteger, DateTime, func, Index, Boolean, ForeignKey
|
||||
from sqlalchemy.dialects.mysql import JSON as MySQLJSON # Changed from postgresql.JSONB to mysql.JSON
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
@@ -14,21 +12,10 @@ class Points(Base):
|
||||
__tablename__ = 'points'
|
||||
|
||||
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, unique=True, nullable=False, comment='关联的用户ID')
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), nullable=False, comment='关联的用户ID')
|
||||
balance: Mapped[int] = mapped_column(BigInteger, default=0, comment='当前积分余额')
|
||||
total_earned: Mapped[int] = mapped_column(BigInteger, default=0, comment='累计获得积分')
|
||||
total_spent: Mapped[int] = mapped_column(BigInteger, default=0, comment='累计消费积分')
|
||||
expired_time: Mapped[datetime] = mapped_column(DateTime, default=datetime.now() + timedelta(days=30), comment="过期时间")
|
||||
|
||||
# Subscription and free trial fields (moved from UserAccount)
|
||||
subscription_type: Mapped[Optional[str]] = mapped_column(String(32), default=None, comment='订阅类型:monthly/quarterly/half_yearly/yearly')
|
||||
subscription_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=None, comment='订阅到期时间')
|
||||
is_subscribed: Mapped[bool] = mapped_column(Boolean, default=False, comment='是否为订阅用户')
|
||||
carryover_balance: Mapped[int] = mapped_column(BigInteger, default=0, comment='上期未使用的次数')
|
||||
# 新用户免费次数相关
|
||||
free_trial_balance: Mapped[int] = mapped_column(BigInteger, default=30, comment='新用户免费试用次数')
|
||||
free_trial_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=None, comment='免费试用期结束时间')
|
||||
free_trial_used: Mapped[bool] = mapped_column(Boolean, default=False, comment='是否已使用免费试用')
|
||||
|
||||
# 索引优化
|
||||
__table_args__ = (
|
||||
@@ -42,15 +29,16 @@ class PointsLog(Base):
|
||||
|
||||
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False, comment='用户ID')
|
||||
action: Mapped[str] = mapped_column(String(32), comment='动作:earn/spend')
|
||||
action: Mapped[str] = mapped_column(String(32), comment='积分变更类型:system_gift/recharge/purchase/etc')
|
||||
amount: Mapped[int] = mapped_column(BigInteger, comment='变动数量')
|
||||
balance_after: Mapped[int] = mapped_column(BigInteger, comment='变动后余额')
|
||||
related_id: Mapped[Optional[int]] = mapped_column(BigInteger, default=None, comment='关联ID')
|
||||
details: Mapped[Optional[dict]] = mapped_column(MySQLJSON, default=None, comment='附加信息')
|
||||
related_id: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, default=None, comment='关联ID')
|
||||
details: Mapped[Optional[dict]] = mapped_column(MySQLJSON, nullable=True, default=None, comment='附加信息')
|
||||
|
||||
# 索引优化
|
||||
__table_args__ = (
|
||||
Index('idx_points_log_user_action', 'user_id', 'action'),
|
||||
Index('idx_points_log_related', 'related_id'),
|
||||
Index('idx_points_log_user_time', 'user_id', 'created_time'),
|
||||
{'comment': '积分变动日志表'}
|
||||
)
|
||||
@@ -1,247 +0,0 @@
|
||||
from backend.app.admin.crud.usage_log_crud import usage_log_dao
|
||||
from backend.app.admin.crud.points_crud import points_dao
|
||||
from backend.database.db import async_db_session
|
||||
from backend.common.exception import errors
|
||||
from datetime import datetime, timedelta
|
||||
from backend.common.log import log as logger
|
||||
|
||||
# 导入 Redis 客户端
|
||||
from backend.database.redis import redis_client
|
||||
|
||||
|
||||
class AdShareService:
|
||||
# 每日限制配置
|
||||
DAILY_AD_LIMIT = 5 # 每日广告观看限制
|
||||
DAILY_SHARE_LIMIT = 3 # 每日分享限制
|
||||
AD_REWARD_TIMES = 3 # 每次广告奖励次数
|
||||
SHARE_REWARD_TIMES = 3 # 每次分享奖励次数
|
||||
|
||||
@staticmethod
|
||||
def _get_redis_key(user_id: int, action_type: str, date_str: str = None) -> str:
|
||||
"""
|
||||
生成 Redis key
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
action_type: 动作类型 (ad/share)
|
||||
date_str: 日期字符串 (YYYY-MM-DD)
|
||||
|
||||
Returns:
|
||||
str: Redis key
|
||||
"""
|
||||
if date_str is None:
|
||||
date_str = datetime.now().strftime('%Y-%m-%d')
|
||||
return f"user:{user_id}:{action_type}:count:{date_str}"
|
||||
|
||||
@staticmethod
|
||||
async def _check_daily_limit(user_id: int, action_type: str, limit: int) -> bool:
|
||||
"""
|
||||
检查每日限制
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
action_type: 动作类型
|
||||
limit: 限制次数
|
||||
|
||||
Returns:
|
||||
bool: 是否超过限制
|
||||
"""
|
||||
try:
|
||||
# 确保 Redis 连接
|
||||
await redis_client.ping()
|
||||
|
||||
# 获取今天的计数
|
||||
key = AdShareService._get_redis_key(user_id, action_type)
|
||||
current_count = await redis_client.get(key)
|
||||
|
||||
if current_count is None:
|
||||
current_count = 0
|
||||
else:
|
||||
current_count = int(current_count)
|
||||
|
||||
return current_count >= limit
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检查每日限制失败: {str(e)}")
|
||||
# Redis 异常时允许继续操作(避免服务中断)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def _increment_daily_count(user_id: int, action_type: str) -> int:
|
||||
"""
|
||||
增加每日计数
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
action_type: 动作类型
|
||||
|
||||
Returns:
|
||||
int: 当前计数
|
||||
"""
|
||||
try:
|
||||
# 确保 Redis 连接
|
||||
await redis_client.ping()
|
||||
|
||||
key = AdShareService._get_redis_key(user_id, action_type)
|
||||
|
||||
# 增加计数
|
||||
current_count = await redis_client.incr(key)
|
||||
|
||||
# 设置过期时间(第二天自动清除)
|
||||
tomorrow = datetime.now() + timedelta(days=1)
|
||||
tomorrow_midnight = tomorrow.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
expire_seconds = int((tomorrow_midnight - datetime.now()).total_seconds())
|
||||
|
||||
if expire_seconds > 0:
|
||||
await redis_client.expire(key, expire_seconds)
|
||||
|
||||
return current_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"增加每日计数失败: {str(e)}")
|
||||
raise errors.ServerError(msg="系统繁忙,请稍后重试")
|
||||
|
||||
@staticmethod
|
||||
async def _get_daily_count(user_id: int, action_type: str) -> int:
|
||||
"""
|
||||
获取今日计数
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
action_type: 动作类型
|
||||
|
||||
Returns:
|
||||
int: 当前计数
|
||||
"""
|
||||
try:
|
||||
# 确保 Redis 连接
|
||||
await redis_client.ping()
|
||||
|
||||
key = AdShareService._get_redis_key(user_id, action_type)
|
||||
count = await redis_client.get(key)
|
||||
|
||||
return int(count) if count is not None else 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取每日计数失败: {str(e)}")
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
async def grant_times_by_ad(user_id: int):
|
||||
"""
|
||||
通过观看广告获得次数
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
"""
|
||||
# 检查每日限制
|
||||
if await AdShareService._check_daily_limit(user_id, "ad", AdShareService.DAILY_AD_LIMIT):
|
||||
raise errors.ForbiddenError(msg=f"今日广告观看次数已达上限({AdShareService.DAILY_AD_LIMIT}次)")
|
||||
|
||||
async with async_db_session.begin() as db:
|
||||
try:
|
||||
# 增加计数
|
||||
current_count = await AdShareService._increment_daily_count(user_id, "ad")
|
||||
|
||||
# 增加用户余额
|
||||
result = await points_dao.update_balance_atomic(db, user_id, AdShareService.AD_REWARD_TIMES)
|
||||
if not result:
|
||||
# 回滚计数
|
||||
key = AdShareService._get_redis_key(user_id, "ad")
|
||||
await redis_client.decr(key)
|
||||
raise errors.ServerError(msg="账户更新失败")
|
||||
|
||||
# 记录使用日志
|
||||
account = await points_dao.get_by_user_id(db, user_id)
|
||||
await usage_log_dao.add(db, {
|
||||
"user_id": user_id,
|
||||
"action": "ad",
|
||||
"amount": AdShareService.AD_REWARD_TIMES,
|
||||
"balance_after": account.balance if account else AdShareService.AD_REWARD_TIMES,
|
||||
"metadata_": {
|
||||
"daily_count": current_count,
|
||||
"max_limit": AdShareService.DAILY_AD_LIMIT
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
if not isinstance(e, errors.ForbiddenError):
|
||||
logger.error(f"广告奖励处理失败: {str(e)}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
async def grant_times_by_share(user_id: int):
|
||||
"""
|
||||
通过分享获得次数
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
"""
|
||||
# 检查每日限制
|
||||
if await AdShareService._check_daily_limit(user_id, "share", AdShareService.DAILY_SHARE_LIMIT):
|
||||
raise errors.ForbiddenError(msg=f"今日分享次数已达上限({AdShareService.DAILY_SHARE_LIMIT}次)")
|
||||
|
||||
async with async_db_session.begin() as db:
|
||||
try:
|
||||
# 增加计数
|
||||
current_count = await AdShareService._increment_daily_count(user_id, "share")
|
||||
|
||||
# 增加用户余额
|
||||
result = await points_dao.update_balance_atomic(db, user_id, AdShareService.SHARE_REWARD_TIMES)
|
||||
if not result:
|
||||
# 回滚计数
|
||||
key = AdShareService._get_redis_key(user_id, "share")
|
||||
await redis_client.decr(key)
|
||||
raise errors.ServerError(msg="账户更新失败")
|
||||
|
||||
# 记录使用日志
|
||||
account = await points_dao.get_by_user_id(db, user_id)
|
||||
await usage_log_dao.add(db, {
|
||||
"user_id": user_id,
|
||||
"action": "share",
|
||||
"amount": AdShareService.SHARE_REWARD_TIMES,
|
||||
"balance_after": account.balance if account else AdShareService.SHARE_REWARD_TIMES,
|
||||
"metadata_": {
|
||||
"daily_count": current_count,
|
||||
"max_limit": AdShareService.DAILY_SHARE_LIMIT
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
if not isinstance(e, errors.ForbiddenError):
|
||||
logger.error(f"分享奖励处理失败: {str(e)}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
async def get_daily_stats(user_id: int) -> dict:
|
||||
"""
|
||||
获取用户今日统计信息
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
dict: 统计信息
|
||||
"""
|
||||
try:
|
||||
ad_count = await AdShareService._get_daily_count(user_id, "ad")
|
||||
share_count = await AdShareService._get_daily_count(user_id, "share")
|
||||
|
||||
return {
|
||||
"ad_count": ad_count,
|
||||
"ad_limit": AdShareService.DAILY_AD_LIMIT,
|
||||
"share_count": share_count,
|
||||
"share_limit": AdShareService.DAILY_SHARE_LIMIT,
|
||||
"can_watch_ad": ad_count < AdShareService.DAILY_AD_LIMIT,
|
||||
"can_share": share_count < AdShareService.DAILY_SHARE_LIMIT
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取每日统计失败: {str(e)}")
|
||||
return {
|
||||
"ad_count": 0,
|
||||
"ad_limit": AdShareService.DAILY_AD_LIMIT,
|
||||
"share_count": 0,
|
||||
"share_limit": AdShareService.DAILY_SHARE_LIMIT,
|
||||
"can_watch_ad": True,
|
||||
"can_share": True
|
||||
}
|
||||
@@ -4,6 +4,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from backend.app.admin.crud.points_crud import points_dao, points_log_dao
|
||||
from backend.app.admin.model.points import Points
|
||||
from backend.database.db import async_db_session
|
||||
from backend.common.const import (
|
||||
POINTS_ACTION_RECHARGE,
|
||||
POINTS_ACTION_SPEND
|
||||
)
|
||||
|
||||
|
||||
class PointsService:
|
||||
@@ -17,9 +21,6 @@ class PointsService:
|
||||
points_account_before = await points_dao.get_by_user_id(db, user_id)
|
||||
balance_before = points_account_before.balance if points_account_before else 0
|
||||
|
||||
# 检查并清空过期积分
|
||||
expired_cleared = await points_dao.check_and_clear_expired_points(db, user_id)
|
||||
|
||||
# 如果清空了过期积分,记录日志
|
||||
if expired_cleared and balance_before > 0:
|
||||
await points_log_dao.add_log(db, {
|
||||
@@ -35,30 +36,13 @@ class PointsService:
|
||||
@staticmethod
|
||||
async def get_user_balance(user_id: int) -> int:
|
||||
"""
|
||||
获取用户积分余额(会检查并清空过期积分)
|
||||
获取用户积分余额
|
||||
"""
|
||||
async with async_db_session.begin() as db:
|
||||
# 获取当前积分余额(清空前)
|
||||
points_account_before = await points_dao.get_by_user_id(db, user_id)
|
||||
balance_before = points_account_before.balance if points_account_before else 0
|
||||
|
||||
# 检查并清空过期积分
|
||||
expired_cleared = await points_dao.check_and_clear_expired_points(db, user_id)
|
||||
|
||||
# 如果清空了过期积分,记录日志
|
||||
if expired_cleared and balance_before > 0:
|
||||
await points_log_dao.add_log(db, {
|
||||
"user_id": user_id,
|
||||
"action": "expire_clear",
|
||||
"balance_before": balance_before,
|
||||
"balance_after": 0,
|
||||
"details": {"message": "过期积分已清空", "cleared_amount": balance_before}
|
||||
})
|
||||
|
||||
return await points_dao.get_balance(db, user_id)
|
||||
|
||||
@staticmethod
|
||||
async def add_points(user_id: int, amount: int, extend_expiration: bool = False, related_id: Optional[int] = None, details: Optional[dict] = None) -> bool:
|
||||
async def add_points(user_id: int, amount: int, related_id: Optional[int] = None, details: Optional[dict] = None, action: Optional[str] = None) -> bool:
|
||||
"""
|
||||
为用户增加积分
|
||||
|
||||
@@ -84,21 +68,22 @@ class PointsService:
|
||||
current_balance = points_account.balance
|
||||
|
||||
# 原子性增加积分(可能延期过期时间)
|
||||
result = await points_dao.add_points_atomic(db, user_id, amount, extend_expiration)
|
||||
result = await points_dao.add_points_atomic(db, user_id, amount)
|
||||
if not result:
|
||||
return False
|
||||
|
||||
# 准备日志详情
|
||||
log_details = details or {}
|
||||
if extend_expiration:
|
||||
log_details["expiration_extended"] = True
|
||||
log_details["extension_days"] = 30
|
||||
|
||||
# 记录积分变动日志
|
||||
new_balance = current_balance + amount
|
||||
# 根据是否延期过期时间来确定action类型
|
||||
if not action:
|
||||
action = POINTS_ACTION_RECHARGE
|
||||
|
||||
await points_log_dao.add_log(db, {
|
||||
"user_id": user_id,
|
||||
"action": "earn",
|
||||
"action": action,
|
||||
"amount": amount,
|
||||
"balance_after": new_balance,
|
||||
"related_id": related_id,
|
||||
@@ -110,7 +95,7 @@ class PointsService:
|
||||
@staticmethod
|
||||
async def add_points_from_coupon(user_id: int, amount: int, coupon_id: int) -> bool:
|
||||
"""
|
||||
为用户增加积分(来自兑换券),积分有效期从兑换时间开始加30天
|
||||
为用户增加积分(来自兑换券)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
@@ -138,15 +123,14 @@ class PointsService:
|
||||
|
||||
# 记录积分变动日志,包含coupon id和区分action为coupon
|
||||
log_details = {
|
||||
"coupon_id": coupon_id,
|
||||
"expiration_days": 30
|
||||
"coupon_id": coupon_id
|
||||
}
|
||||
|
||||
# 记录积分变动日志
|
||||
new_balance = current_balance + amount
|
||||
await points_log_dao.add_log(db, {
|
||||
"user_id": user_id,
|
||||
"action": "coupon", # 区分action为coupon
|
||||
"action": POINTS_ACTION_COUPON, # 使用充值action类型
|
||||
"amount": amount,
|
||||
"balance_after": new_balance,
|
||||
"related_id": coupon_id, # 关联coupon id
|
||||
@@ -155,73 +139,10 @@ class PointsService:
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def deduct_points(user_id: int, amount: int, related_id: Optional[int] = None, details: Optional[dict] = None) -> bool:
|
||||
"""
|
||||
扣减用户积分(会检查并清空过期积分)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
amount: 扣减的积分数量
|
||||
related_id: 关联ID(可选)
|
||||
details: 附加信息(可选)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
if amount <= 0:
|
||||
raise ValueError("积分数量必须大于0")
|
||||
|
||||
async with async_db_session.begin() as db:
|
||||
# 获取当前积分余额(清空前)
|
||||
points_account_before = await points_dao.get_by_user_id(db, user_id)
|
||||
if not points_account_before:
|
||||
return False
|
||||
|
||||
balance_before = points_account_before.balance
|
||||
|
||||
# 检查并清空过期积分
|
||||
expired_cleared = await points_dao.check_and_clear_expired_points(db, user_id)
|
||||
|
||||
# 如果清空了过期积分,记录日志
|
||||
if expired_cleared and balance_before > 0:
|
||||
await points_log_dao.add_log(db, {
|
||||
"user_id": user_id,
|
||||
"action": "expire_clear",
|
||||
"amount": balance_before, # 记录清空前的积分数量
|
||||
"balance_after": 0,
|
||||
"details": {"message": "过期积分已清空", "cleared_amount": balance_before}
|
||||
})
|
||||
|
||||
# 重新获取账户信息(可能已被清空)
|
||||
points_account = await points_dao.get_by_user_id(db, user_id)
|
||||
if not points_account or points_account.balance < amount:
|
||||
return False
|
||||
|
||||
current_balance = points_account.balance
|
||||
|
||||
# 原子性扣减积分
|
||||
result = await points_dao.deduct_points_atomic(db, user_id, amount)
|
||||
if not result:
|
||||
return False
|
||||
|
||||
# 记录积分变动日志
|
||||
new_balance = current_balance - amount
|
||||
await points_log_dao.add_log(db, {
|
||||
"user_id": user_id,
|
||||
"action": "spend",
|
||||
"amount": amount,
|
||||
"balance_after": new_balance,
|
||||
"related_id": related_id,
|
||||
"details": details
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def check_sufficient_points(user_id: int, required_amount: int) -> bool:
|
||||
"""
|
||||
检查用户是否有足够的积分
|
||||
检查用户是否有足够的积分(简化逻辑,移除过期积分检查)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
@@ -233,25 +154,8 @@ class PointsService:
|
||||
if required_amount <= 0:
|
||||
return True
|
||||
|
||||
async with async_db_session.begin() as db:
|
||||
# 获取当前积分余额(清空前)
|
||||
points_account_before = await points_dao.get_by_user_id(db, user_id)
|
||||
balance_before = points_account_before.balance if points_account_before else 0
|
||||
|
||||
# 检查并清空过期积分
|
||||
expired_cleared = await points_dao.check_and_clear_expired_points(db, user_id)
|
||||
|
||||
# 如果清空了过期积分,记录日志
|
||||
if expired_cleared and balance_before > 0:
|
||||
await points_log_dao.add_log(db, {
|
||||
"user_id": user_id,
|
||||
"action": "expire_clear",
|
||||
"amount": balance_before, # 记录清空前的积分数量
|
||||
"balance_after": 0,
|
||||
"details": {"message": "过期积分已清空", "cleared_amount": balance_before}
|
||||
})
|
||||
|
||||
# 重新获取账户信息(可能已被清空)
|
||||
async with async_db_session() as db:
|
||||
# 直接获取用户积分余额(不再检查过期积分)
|
||||
points_account = await points_dao.get_by_user_id(db, user_id)
|
||||
if not points_account:
|
||||
return False
|
||||
@@ -259,9 +163,9 @@ class PointsService:
|
||||
return points_account.balance >= required_amount
|
||||
|
||||
@staticmethod
|
||||
async def deduct_points_with_db(user_id: int, amount: int, db: AsyncSession, related_id: Optional[int] = None, details: Optional[dict] = None) -> bool:
|
||||
async def deduct_points_with_db(user_id: int, amount: int, db: AsyncSession, related_id: Optional[int] = None, details: Optional[dict] = None, action: Optional[str] = None) -> bool:
|
||||
"""
|
||||
扣减用户积分(会检查并清空过期积分)- 使用提供的数据库连接
|
||||
扣减用户积分(简化逻辑,移除过期积分检查)- 使用提供的数据库连接
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
@@ -269,6 +173,7 @@ class PointsService:
|
||||
db: 数据库连接
|
||||
related_id: 关联ID(可选)
|
||||
details: 附加信息(可选)
|
||||
action: 积分操作类型(可选),默认为POINTS_ACTION_SPEND
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
@@ -276,27 +181,7 @@ class PointsService:
|
||||
if amount <= 0:
|
||||
raise ValueError("积分数量必须大于0")
|
||||
|
||||
# 获取当前积分余额(清空前)
|
||||
points_account_before = await points_dao.get_by_user_id(db, user_id)
|
||||
if not points_account_before:
|
||||
return False
|
||||
|
||||
balance_before = points_account_before.balance
|
||||
|
||||
# 检查并清空过期积分
|
||||
expired_cleared = await points_dao.check_and_clear_expired_points(db, user_id)
|
||||
|
||||
# 如果清空了过期积分,记录日志
|
||||
if expired_cleared and balance_before > 0:
|
||||
await points_log_dao.add_log(db, {
|
||||
"user_id": user_id,
|
||||
"action": "expire_clear",
|
||||
"amount": balance_before, # 记录清空前的积分数量
|
||||
"balance_after": 0,
|
||||
"details": {"message": "过期积分已清空", "cleared_amount": balance_before}
|
||||
})
|
||||
|
||||
# 重新获取账户信息(可能已被清空)
|
||||
# 直接获取用户积分账户(不再检查过期积分)
|
||||
points_account = await points_dao.get_by_user_id(db, user_id)
|
||||
if not points_account or points_account.balance < amount:
|
||||
return False
|
||||
@@ -304,15 +189,19 @@ class PointsService:
|
||||
current_balance = points_account.balance
|
||||
|
||||
# 原子性扣减积分
|
||||
result = await points_dao.deduct_points_atomic(db, user_id, amount)
|
||||
result = await points_dao.deduct_balance_atomic(db, user_id, amount)
|
||||
if not result:
|
||||
return False
|
||||
|
||||
# 记录积分变动日志
|
||||
new_balance = current_balance - amount
|
||||
# 如果action参数为空,则默认使用POINTS_ACTION_SPEND
|
||||
if action is None:
|
||||
action = POINTS_ACTION_SPEND
|
||||
|
||||
await points_log_dao.add_log(db, {
|
||||
"user_id": user_id,
|
||||
"action": "spend",
|
||||
"action": action,
|
||||
"amount": amount,
|
||||
"balance_after": new_balance,
|
||||
"related_id": related_id,
|
||||
@@ -324,58 +213,7 @@ class PointsService:
|
||||
@staticmethod
|
||||
async def initialize_user_points(user_id: int, db: AsyncSession = None) -> Points:
|
||||
"""
|
||||
为新用户初始化积分账户
|
||||
"""
|
||||
if db is not None:
|
||||
# Use the provided session (for nested transactions)
|
||||
points_account = await points_dao.get_by_user_id(db, user_id)
|
||||
if not points_account:
|
||||
points_account = await points_dao.create_user_points(db, user_id)
|
||||
return points_account
|
||||
else:
|
||||
# Create a new transaction (standalone usage)
|
||||
async with async_db_session.begin() as db:
|
||||
points_account = await points_dao.get_by_user_id(db, user_id)
|
||||
if not points_account:
|
||||
points_account = await points_dao.create_user_points(db, user_id)
|
||||
return points_account
|
||||
|
||||
@staticmethod
|
||||
async def is_subscribed_user(user_id: int) -> bool:
|
||||
"""
|
||||
检查用户是否为订阅用户(使用显式的订阅标志)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
是否为订阅用户
|
||||
"""
|
||||
try:
|
||||
async with async_db_session() as db:
|
||||
points_account = await points_dao.get_by_user_id(db, user_id)
|
||||
if points_account and points_account.is_subscribed:
|
||||
return True
|
||||
|
||||
# 如果没有显式的订阅标志,回退到旧的检查方式
|
||||
if points_account:
|
||||
# 检查是否有有效的订阅
|
||||
if points_account.subscription_type and points_account.subscription_expires_at:
|
||||
if points_account.subscription_expires_at > datetime.now():
|
||||
return True
|
||||
|
||||
# 检查是否有积分余额
|
||||
return points_account.balance > 0
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
# 如果出现异常,默认认为不是订阅用户
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def create_new_user_account(user_id: int, db: AsyncSession = None) -> Points:
|
||||
"""
|
||||
为新用户创建账户(包含免费试用)
|
||||
为新用户初始化积分账户(根据新需求,直接在balance中增加初始积分)
|
||||
"""
|
||||
if db is not None:
|
||||
# Use the provided session (for nested transactions)
|
||||
@@ -391,39 +229,6 @@ class PointsService:
|
||||
points_account = await points_dao.create_new_user_account(db, user_id)
|
||||
return points_account
|
||||
|
||||
@staticmethod
|
||||
async def update_subscription_status(user_id: int) -> bool:
|
||||
"""
|
||||
更新用户的订阅状态
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
bool: 是否为订阅用户
|
||||
"""
|
||||
async with async_db_session.begin() as db:
|
||||
points_account = await points_dao.get_by_user_id(db, user_id)
|
||||
if not points_account:
|
||||
return False
|
||||
|
||||
# 检查是否有有效的订阅
|
||||
is_currently_subscribed = False
|
||||
if points_account.subscription_type and points_account.subscription_expires_at:
|
||||
if points_account.subscription_expires_at > datetime.now():
|
||||
is_currently_subscribed = True
|
||||
|
||||
# 检查是否有余额
|
||||
if points_account.balance > 0:
|
||||
is_currently_subscribed = True
|
||||
|
||||
# 更新订阅标志
|
||||
if points_account.is_subscribed != is_currently_subscribed:
|
||||
points_account.is_subscribed = is_currently_subscribed
|
||||
await points_dao.update(db, points_account.id, {"is_subscribed": is_currently_subscribed})
|
||||
|
||||
return is_currently_subscribed
|
||||
|
||||
@staticmethod
|
||||
async def get_user_account_details(user_id: int) -> dict:
|
||||
"""
|
||||
@@ -434,44 +239,13 @@ class PointsService:
|
||||
if not points_account:
|
||||
return {}
|
||||
|
||||
frozen_balance = await points_dao.get_frozen_balance(db, user_id)
|
||||
available_balance = max(0, points_account.balance - frozen_balance)
|
||||
|
||||
# 检查免费试用状态
|
||||
is_free_trial_active = (
|
||||
points_account.free_trial_expires_at and
|
||||
points_account.free_trial_expires_at > datetime.now()
|
||||
)
|
||||
|
||||
return {
|
||||
"balance": points_account.balance,
|
||||
"available_balance": available_balance,
|
||||
"frozen_balance": frozen_balance,
|
||||
"total_purchased": points_account.total_earned, # Using total_earned as equivalent to total_purchased
|
||||
"subscription_type": points_account.subscription_type,
|
||||
"subscription_expires_at": points_account.subscription_expires_at,
|
||||
"carryover_balance": points_account.carryover_balance,
|
||||
|
||||
# 免费试用信息
|
||||
"free_trial_balance": points_account.free_trial_balance,
|
||||
"free_trial_expires_at": points_account.free_trial_expires_at,
|
||||
"free_trial_active": is_free_trial_active,
|
||||
"free_trial_used": points_account.free_trial_used,
|
||||
|
||||
# 计算剩余天数
|
||||
"free_trial_days_left": (
|
||||
max(0, (points_account.free_trial_expires_at - datetime.now()).days)
|
||||
if points_account.free_trial_expires_at else 0
|
||||
)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def check_free_trial_valid(user_id: int) -> bool:
|
||||
"""
|
||||
检查用户免费试用是否仍然有效
|
||||
"""
|
||||
async with async_db_session() as db:
|
||||
return await points_dao.check_free_trial_valid(db, user_id)
|
||||
|
||||
|
||||
points_service:PointsService = PointsService()
|
||||
@@ -52,10 +52,8 @@ class WxAuthService:
|
||||
await wx_user_dao.add(db, user)
|
||||
await db.flush()
|
||||
await db.refresh(user)
|
||||
# initialize user points
|
||||
# initialize user points with initial gift points (according to new requirements)
|
||||
await points_service.initialize_user_points(user_id=user.id, db=db)
|
||||
# initialize user account
|
||||
await points_service.create_new_user_account(user.id, db=db)
|
||||
else:
|
||||
await wx_user_dao.update_session_key(db, user.id, session_key)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from backend.app.admin.service.file_service import file_service
|
||||
from backend.app.ai.service.rate_limit_service import rate_limit_service, IMAGE_RECOGNITION_SERVICE
|
||||
from backend.common.enums import FileType
|
||||
from backend.common.exception import errors
|
||||
from backend.common.const import IMAGE_RECOGNITION_COST
|
||||
from backend.common.const import IMAGE_RECOGNITION_COST, POINTS_ACTION_IMAGE_RECOGNITION
|
||||
from backend.core.conf import settings
|
||||
from backend.common.log import log as logger
|
||||
|
||||
@@ -287,22 +287,11 @@ class ImageService:
|
||||
if not dict_level:
|
||||
dict_level = current_user.dict_level.name
|
||||
|
||||
# 优先检查是否订阅用户
|
||||
is_subscribed = await rate_limit_service.is_subscribed_user(current_user.id)
|
||||
|
||||
if is_subscribed:
|
||||
# 订阅用户检查积分是否足够
|
||||
if not await points_service.check_sufficient_points(current_user.id, IMAGE_RECOGNITION_COST):
|
||||
raise errors.ForbiddenError(
|
||||
msg=f'积分不足,请充值以继续使用'
|
||||
)
|
||||
else:
|
||||
# 非订阅用户检查是否有免费使用次数
|
||||
usage_allowed = await rate_limit_service.increment_daily_usage(current_user.id, IMAGE_RECOGNITION_SERVICE, DAILY_IMAGE_RECOGNITION_MAX_TIMES)
|
||||
if not usage_allowed:
|
||||
raise errors.ForbiddenError(
|
||||
msg=f'未订阅用户每天只能使用 {DAILY_IMAGE_RECOGNITION_MAX_TIMES} 次图像识别服务,请升级为订阅用户以解除限制'
|
||||
)
|
||||
# 检查用户积分是否足够(现在积分没有过期概念)
|
||||
if not await points_service.check_sufficient_points(current_user.id, IMAGE_RECOGNITION_COST):
|
||||
raise errors.ForbiddenError(
|
||||
msg=f'积分不足,请充值以继续使用'
|
||||
)
|
||||
|
||||
# 尝试获取任务槽位
|
||||
slot_acquired = await rate_limit_service.acquire_task_slot(current_user.id)
|
||||
@@ -1046,7 +1035,8 @@ class ImageService:
|
||||
amount=IMAGE_RECOGNITION_COST, # 扣减图片识别费用
|
||||
db=db,
|
||||
related_id=image.id,
|
||||
details={"action": "image_recognition", "task_id": task_id}
|
||||
details={"task_id": task_id},
|
||||
action=POINTS_ACTION_IMAGE_RECOGNITION
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
|
||||
@@ -47,15 +47,16 @@ class RateLimitService:
|
||||
@staticmethod
|
||||
async def is_subscribed_user(user_id: int) -> bool:
|
||||
"""
|
||||
检查用户是否为订阅用户(使用显式的订阅标志)
|
||||
检查用户是否为订阅用户(此功能已废弃,根据新需求统一使用积分系统)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
是否为订阅用户
|
||||
是否为订阅用户(始终返回False,因为订阅概念已废弃)
|
||||
"""
|
||||
return await PointsService.is_subscribed_user(user_id)
|
||||
# 根据新需求,订阅概念已废弃,统一使用积分系统
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def get_user_task_limit(user_id: int) -> int:
|
||||
@@ -310,41 +311,7 @@ class RateLimitService:
|
||||
|
||||
return status
|
||||
|
||||
@staticmethod
|
||||
async def check_daily_usage_limit(user_id: int, service_type: str, max_usage: int = 3) -> tuple[bool, int, int]:
|
||||
"""
|
||||
检查用户每日使用次数限制
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
service_type: 服务类型(如"image_recognition", "speech_assessment")
|
||||
max_usage: 最大使用次数(默认3次)
|
||||
|
||||
Returns:
|
||||
(是否超过限制, 当前使用次数, 最大使用次数)
|
||||
"""
|
||||
# 检查用户是否为订阅用户
|
||||
is_subscribed = await RateLimitService.is_subscribed_user(user_id)
|
||||
|
||||
# 订阅用户不受限制
|
||||
if is_subscribed:
|
||||
return False, 0, max_usage
|
||||
|
||||
# 获取每日使用次数键
|
||||
usage_key = RateLimitService._get_daily_usage_key(user_id, service_type)
|
||||
|
||||
# 获取当前使用次数
|
||||
current_usage = await redis_client.get(usage_key)
|
||||
if current_usage is None:
|
||||
current_usage = 0
|
||||
else:
|
||||
current_usage = int(current_usage)
|
||||
|
||||
# 检查是否超过限制
|
||||
is_limited = current_usage >= max_usage
|
||||
|
||||
return is_limited, current_usage, max_usage
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def increment_daily_usage(user_id: int, service_type: str, max_usage: int = 3) -> bool:
|
||||
"""
|
||||
@@ -357,12 +324,7 @@ class RateLimitService:
|
||||
Returns:
|
||||
是否成功增加(未超过限制)
|
||||
"""
|
||||
# 检查是否超过每日使用限制
|
||||
is_limited, current_usage, _max_usage = await RateLimitService.check_daily_usage_limit(user_id, service_type, max_usage)
|
||||
|
||||
if is_limited:
|
||||
return False
|
||||
|
||||
|
||||
# 使用新的免费使用次数跟踪系统
|
||||
await RateLimitService.increment_free_usage_count(user_id, service_type)
|
||||
|
||||
|
||||
@@ -17,6 +17,8 @@ from backend.app.admin.service.audit_log_service import audit_log_service
|
||||
from backend.app.ai.service.rate_limit_service import rate_limit_service, SPEECH_ASSESSMENT_SERVICE
|
||||
from backend.database.db import async_db_session
|
||||
from backend.middleware.tencent_cloud import TencentCloud
|
||||
from backend.common.const import SPEECH_ASSESSMENT_COST, POINTS_ACTION_SPEECH_ASSESSMENT
|
||||
from backend.app.admin.service.points_service import points_service
|
||||
|
||||
# Import the recording_dao for accessing recording CRUD methods
|
||||
from backend.app.ai.crud.recording_crud import recording_dao
|
||||
@@ -337,10 +339,9 @@ class RecordingService:
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to create recording record for file_id {file_id}: {str(e)}")
|
||||
|
||||
# 检查每日使用次数限制(针对每条recording记录)
|
||||
usage_allowed = await rate_limit_service.increment_daily_usage(user_id, f"{SPEECH_ASSESSMENT_SERVICE}_{recording.id}", 1)
|
||||
if not usage_allowed:
|
||||
raise RuntimeError('免费用户每天每条语音只能使用1次语音评测服务,请升级为订阅用户以解除限制')
|
||||
# 检查用户积分是否足够(现在积分没有过期概念)
|
||||
if not await points_service.check_sufficient_points(user_id, SPEECH_ASSESSMENT_COST):
|
||||
raise RuntimeError('积分不足,请充值以继续使用')
|
||||
|
||||
try:
|
||||
# 调用腾讯云SOE API进行语音评估
|
||||
@@ -354,6 +355,19 @@ class RecordingService:
|
||||
if not success:
|
||||
raise RuntimeError(f"Failed to update recording details for file_id {file_id}")
|
||||
|
||||
# 扣减用户积分
|
||||
async with async_db_session.begin() as db:
|
||||
points_deducted = await points_service.deduct_points_with_db(
|
||||
user_id=user_id,
|
||||
amount=SPEECH_ASSESSMENT_COST,
|
||||
db=db,
|
||||
related_id=recording.id,
|
||||
details={"recording_id": recording.id},
|
||||
action=POINTS_ACTION_SPEECH_ASSESSMENT
|
||||
)
|
||||
if not points_deducted:
|
||||
logger.warning(f"Failed to deduct points for user {user_id} for speech assessment")
|
||||
|
||||
# 计算耗时
|
||||
duration = time.time() - start_time
|
||||
|
||||
|
||||
@@ -44,14 +44,14 @@ class TestImageServiceRateLimit:
|
||||
dict_level=DictLevel.LEVEL1
|
||||
)
|
||||
|
||||
@patch('backend.app.ai.service.rate_limit_service.rate_limit_service.increment_daily_usage')
|
||||
@patch('backend.app.admin.service.points_service.points_service.check_sufficient_points')
|
||||
@patch('backend.app.admin.service.file_service.file_service.download_file')
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_image_from_file_async_daily_usage_limit_exceeded(self, mock_download_file, mock_increment_daily_usage):
|
||||
"""Test that process_image_from_file_async raises ForbiddenError when daily usage limit is exceeded"""
|
||||
async def test_process_image_from_file_async_daily_usage_limit_exceeded(self, mock_download_file, mock_check_sufficient_points):
|
||||
"""Test that process_image_from_file_async raises ForbiddenError when user has insufficient points"""
|
||||
# Arrange
|
||||
# Mock rate_limit_service.increment_daily_usage to return False (limit exceeded)
|
||||
mock_increment_daily_usage.return_value = False
|
||||
# Mock points_service.check_sufficient_points to return False (insufficient points)
|
||||
mock_check_sufficient_points.return_value = False
|
||||
|
||||
# Mock file_service.download_file to return dummy data
|
||||
mock_download_file.return_value = (b"dummy_image_data", "test.jpg", "image/jpeg")
|
||||
@@ -65,25 +65,24 @@ class TestImageServiceRateLimit:
|
||||
)
|
||||
|
||||
# Verify the error message
|
||||
assert "免费用户每天只能使用3次图像识别服务,请升级为订阅用户以解除限制" in str(exc_info.value.msg)
|
||||
assert "积分不足,请充值以继续使用" in str(exc_info.value.msg)
|
||||
|
||||
# Verify that increment_daily_usage was called with correct parameters
|
||||
mock_increment_daily_usage.assert_awaited_once_with(
|
||||
# Verify that check_sufficient_points was called with correct parameters
|
||||
mock_check_sufficient_points.assert_awaited_once_with(
|
||||
self.user_id,
|
||||
IMAGE_RECOGNITION_SERVICE,
|
||||
3 # DAILY_IMAGE_RECOGNITION_MAX_TIMES
|
||||
IMAGE_RECOGNITION_COST
|
||||
)
|
||||
|
||||
@patch('backend.app.ai.service.rate_limit_service.rate_limit_service.increment_daily_usage')
|
||||
@patch('backend.app.admin.service.points_service.points_service.check_sufficient_points')
|
||||
@patch('backend.app.ai.service.rate_limit_service.rate_limit_service.acquire_task_slot')
|
||||
@patch('backend.app.ai.service.rate_limit_service.rate_limit_service.get_user_task_limit')
|
||||
@patch('backend.app.admin.service.file_service.file_service.download_file')
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_image_from_file_async_task_slot_limit_exceeded(self, mock_download_file, mock_get_user_task_limit, mock_acquire_task_slot, mock_increment_daily_usage):
|
||||
async def test_process_image_from_file_async_task_slot_limit_exceeded(self, mock_download_file, mock_get_user_task_limit, mock_acquire_task_slot, mock_check_sufficient_points):
|
||||
"""Test that process_image_from_file_async raises ForbiddenError when task slot limit is exceeded"""
|
||||
# Arrange
|
||||
# Mock rate_limit_service.increment_daily_usage to return True (within limit)
|
||||
mock_increment_daily_usage.return_value = True
|
||||
# Mock points_service.check_sufficient_points to return True (sufficient points)
|
||||
mock_check_sufficient_points.return_value = True
|
||||
|
||||
# Mock rate_limit_service.acquire_task_slot to return False (no slot available)
|
||||
mock_acquire_task_slot.return_value = False
|
||||
@@ -106,11 +105,10 @@ class TestImageServiceRateLimit:
|
||||
assert "用户同时最多只能运行" in str(exc_info.value.msg)
|
||||
assert "个任务,请等待现有任务完成后再试" in str(exc_info.value.msg)
|
||||
|
||||
# Verify that increment_daily_usage was called
|
||||
mock_increment_daily_usage.assert_awaited_once_with(
|
||||
# Verify that check_sufficient_points was called
|
||||
mock_check_sufficient_points.assert_awaited_once_with(
|
||||
self.user_id,
|
||||
IMAGE_RECOGNITION_SERVICE,
|
||||
3 # DAILY_IMAGE_RECOGNITION_MAX_TIMES
|
||||
IMAGE_RECOGNITION_COST
|
||||
)
|
||||
|
||||
# Verify that acquire_task_slot was called
|
||||
|
||||
Reference in New Issue
Block a user