change points model and service

This commit is contained in:
felix
2025-11-27 13:38:45 +08:00
parent 3d96f3d777
commit 2abbe776f8
9 changed files with 106 additions and 700 deletions

View File

@@ -5,6 +5,7 @@ from sqlalchemy import select, update, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus
from backend.app.admin.model.points import Points, PointsLog
from backend.common.const import FREE_TRIAL_BALANCE, POINTS_ACTION_SYSTEM_GIFT
class PointsDao(CRUDPlus[Points]):
@@ -25,7 +26,7 @@ class PointsDao(CRUDPlus[Points]):
await db.flush()
return points
async def add_points_atomic(self, db: AsyncSession, user_id: int, amount: int, extend_expiration: bool = False) -> bool:
async def add_points_atomic(self, db: AsyncSession, user_id: int, amount: int) -> bool:
"""
原子性增加用户积分
"""
@@ -40,30 +41,12 @@ class PointsDao(CRUDPlus[Points]):
"total_earned": Points.total_earned + amount
}
# 如果需要延期,则更新过期时间
if extend_expiration:
update_values["expired_time"] = datetime.now() + timedelta(days=30)
stmt = update(Points).where(
Points.user_id == user_id
).values(**update_values)
result = await db.execute(stmt)
return result.rowcount > 0
async def deduct_points_atomic(self, db: AsyncSession, user_id: int, amount: int) -> bool:
"""
原子性扣减用户积分(确保不超扣)
"""
stmt = update(Points).where(
Points.user_id == user_id,
Points.balance >= amount
).values(
balance=Points.balance - amount,
total_spent=Points.total_spent + amount
)
result = await db.execute(stmt)
return result.rowcount > 0
async def get_balance(self, db: AsyncSession, user_id: int) -> int:
"""
获取用户积分余额
@@ -73,50 +56,33 @@ class PointsDao(CRUDPlus[Points]):
return 0
return points_account.balance
async def check_and_clear_expired_points(self, db: AsyncSession, user_id: int) -> bool:
"""
检查并清空过期积分
"""
stmt = update(Points).where(
Points.user_id == user_id,
Points.expired_time < datetime.now(),
Points.balance > 0
).values(
balance=0,
total_spent=Points.total_spent + Points.balance
)
result = await db.execute(stmt)
return result.rowcount > 0
async def create_new_user_account(self, db: AsyncSession, user_id: int) -> Points:
"""
为新用户创建账户(包含免费试用
为新用户创建账户(根据新需求直接在balance中增加初始积分
"""
# 设置免费试用期3天
trial_expires_at = datetime.now() + timedelta(days=3)
# 为新用户创建积分账户初始赠送30积分
initial_points = FREE_TRIAL_BALANCE # 使用原来的常量作为初始积分
points = Points(
user_id=user_id,
balance=30, # 初始30次免费次数
free_trial_balance=30,
free_trial_expires_at=trial_expires_at,
free_trial_used=True # 标记为已使用因为已经给了30次
balance=initial_points, # 直接在balance中增加初始积分
total_earned=initial_points # 同时更新累计获得积分
)
db.add(points)
await db.flush()
return points
async def update_balance_atomic(self, db: AsyncSession, user_id: int, amount: int) -> bool:
"""
原子性更新用户余额
"""
stmt = update(Points).where(
Points.user_id == user_id
).values(
balance=Points.balance + amount
# 创建积分变动日志,记录新用户获赠积分
log = PointsLog(
user_id=user_id,
action=POINTS_ACTION_SYSTEM_GIFT,
amount=initial_points,
balance_after=initial_points,
details={"message": "新用户注册赠送积分"}
)
result = await db.execute(stmt)
return result.rowcount > 0
db.add(log)
await db.flush()
return points
async def deduct_balance_atomic(self, db: AsyncSession, user_id: int, amount: int) -> bool:
"""
@@ -126,49 +92,12 @@ class PointsDao(CRUDPlus[Points]):
Points.user_id == user_id,
Points.balance >= amount
).values(
balance=Points.balance - amount
balance=Points.balance - amount,
total_spent=Points.total_spent + amount
)
result = await db.execute(stmt)
return result.rowcount > 0
async def get_frozen_balance(self, db: AsyncSession, user_id: int) -> int:
"""
获取用户被冻结的次数
"""
# This is a placeholder since we don't have a FreezeLog model in the points context yet
# In a full implementation, this would query a freeze log table
return 0
async def get_available_balance(self, db: AsyncSession, user_id: int) -> int:
"""
获取用户可用余额(总余额减去冻结余额)
"""
points_account = await self.get_by_user_id(db, user_id)
if not points_account:
return 0
frozen_balance = await self.get_frozen_balance(db, user_id)
return max(0, points_account.balance - frozen_balance)
async def check_free_trial_valid(self, db: AsyncSession, user_id: int) -> bool:
"""
检查用户免费试用是否仍然有效
"""
points_account = await self.get_by_user_id(db, user_id)
if not points_account or not points_account.free_trial_expires_at:
return False
return (points_account.free_trial_expires_at > datetime.now() and
points_account.free_trial_balance > 0)
async def update(self, db: AsyncSession, points_id: int, update_data: dict) -> bool:
"""
更新积分账户信息
"""
stmt = update(Points).where(Points.id == points_id).values(**update_data)
result = await db.execute(stmt)
return result.rowcount > 0
class PointsLogDao(CRUDPlus[PointsLog]):
async def add_log(self, db: AsyncSession, log_data: Dict[str, Any]) -> PointsLog:

View File

@@ -1,9 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from datetime import datetime, timedelta
from typing import Optional
from sqlalchemy import String, BigInteger, DateTime, func, Index, Boolean
from sqlalchemy import String, BigInteger, DateTime, func, Index, Boolean, ForeignKey
from sqlalchemy.dialects.mysql import JSON as MySQLJSON # Changed from postgresql.JSONB to mysql.JSON
from sqlalchemy.orm import Mapped, mapped_column
@@ -14,21 +12,10 @@ class Points(Base):
__tablename__ = 'points'
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger, unique=True, nullable=False, comment='关联的用户ID')
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), nullable=False, comment='关联的用户ID')
balance: Mapped[int] = mapped_column(BigInteger, default=0, comment='当前积分余额')
total_earned: Mapped[int] = mapped_column(BigInteger, default=0, comment='累计获得积分')
total_spent: Mapped[int] = mapped_column(BigInteger, default=0, comment='累计消费积分')
expired_time: Mapped[datetime] = mapped_column(DateTime, default=datetime.now() + timedelta(days=30), comment="过期时间")
# Subscription and free trial fields (moved from UserAccount)
subscription_type: Mapped[Optional[str]] = mapped_column(String(32), default=None, comment='订阅类型monthly/quarterly/half_yearly/yearly')
subscription_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=None, comment='订阅到期时间')
is_subscribed: Mapped[bool] = mapped_column(Boolean, default=False, comment='是否为订阅用户')
carryover_balance: Mapped[int] = mapped_column(BigInteger, default=0, comment='上期未使用的次数')
# 新用户免费次数相关
free_trial_balance: Mapped[int] = mapped_column(BigInteger, default=30, comment='新用户免费试用次数')
free_trial_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=None, comment='免费试用期结束时间')
free_trial_used: Mapped[bool] = mapped_column(Boolean, default=False, comment='是否已使用免费试用')
# 索引优化
__table_args__ = (
@@ -42,15 +29,16 @@ class PointsLog(Base):
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger, nullable=False, comment='用户ID')
action: Mapped[str] = mapped_column(String(32), comment='动作earn/spend')
action: Mapped[str] = mapped_column(String(32), comment='积分变更类型system_gift/recharge/purchase/etc')
amount: Mapped[int] = mapped_column(BigInteger, comment='变动数量')
balance_after: Mapped[int] = mapped_column(BigInteger, comment='变动后余额')
related_id: Mapped[Optional[int]] = mapped_column(BigInteger, default=None, comment='关联ID')
details: Mapped[Optional[dict]] = mapped_column(MySQLJSON, default=None, comment='附加信息')
related_id: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, default=None, comment='关联ID')
details: Mapped[Optional[dict]] = mapped_column(MySQLJSON, nullable=True, default=None, comment='附加信息')
# 索引优化
__table_args__ = (
Index('idx_points_log_user_action', 'user_id', 'action'),
Index('idx_points_log_related', 'related_id'),
Index('idx_points_log_user_time', 'user_id', 'created_time'),
{'comment': '积分变动日志表'}
)

View File

@@ -1,247 +0,0 @@
from backend.app.admin.crud.usage_log_crud import usage_log_dao
from backend.app.admin.crud.points_crud import points_dao
from backend.database.db import async_db_session
from backend.common.exception import errors
from datetime import datetime, timedelta
from backend.common.log import log as logger
# 导入 Redis 客户端
from backend.database.redis import redis_client
class AdShareService:
# 每日限制配置
DAILY_AD_LIMIT = 5 # 每日广告观看限制
DAILY_SHARE_LIMIT = 3 # 每日分享限制
AD_REWARD_TIMES = 3 # 每次广告奖励次数
SHARE_REWARD_TIMES = 3 # 每次分享奖励次数
@staticmethod
def _get_redis_key(user_id: int, action_type: str, date_str: str = None) -> str:
"""
生成 Redis key
Args:
user_id: 用户ID
action_type: 动作类型 (ad/share)
date_str: 日期字符串 (YYYY-MM-DD)
Returns:
str: Redis key
"""
if date_str is None:
date_str = datetime.now().strftime('%Y-%m-%d')
return f"user:{user_id}:{action_type}:count:{date_str}"
@staticmethod
async def _check_daily_limit(user_id: int, action_type: str, limit: int) -> bool:
"""
检查每日限制
Args:
user_id: 用户ID
action_type: 动作类型
limit: 限制次数
Returns:
bool: 是否超过限制
"""
try:
# 确保 Redis 连接
await redis_client.ping()
# 获取今天的计数
key = AdShareService._get_redis_key(user_id, action_type)
current_count = await redis_client.get(key)
if current_count is None:
current_count = 0
else:
current_count = int(current_count)
return current_count >= limit
except Exception as e:
logger.error(f"检查每日限制失败: {str(e)}")
# Redis 异常时允许继续操作(避免服务中断)
return False
@staticmethod
async def _increment_daily_count(user_id: int, action_type: str) -> int:
"""
增加每日计数
Args:
user_id: 用户ID
action_type: 动作类型
Returns:
int: 当前计数
"""
try:
# 确保 Redis 连接
await redis_client.ping()
key = AdShareService._get_redis_key(user_id, action_type)
# 增加计数
current_count = await redis_client.incr(key)
# 设置过期时间(第二天自动清除)
tomorrow = datetime.now() + timedelta(days=1)
tomorrow_midnight = tomorrow.replace(hour=0, minute=0, second=0, microsecond=0)
expire_seconds = int((tomorrow_midnight - datetime.now()).total_seconds())
if expire_seconds > 0:
await redis_client.expire(key, expire_seconds)
return current_count
except Exception as e:
logger.error(f"增加每日计数失败: {str(e)}")
raise errors.ServerError(msg="系统繁忙,请稍后重试")
@staticmethod
async def _get_daily_count(user_id: int, action_type: str) -> int:
"""
获取今日计数
Args:
user_id: 用户ID
action_type: 动作类型
Returns:
int: 当前计数
"""
try:
# 确保 Redis 连接
await redis_client.ping()
key = AdShareService._get_redis_key(user_id, action_type)
count = await redis_client.get(key)
return int(count) if count is not None else 0
except Exception as e:
logger.error(f"获取每日计数失败: {str(e)}")
return 0
@staticmethod
async def grant_times_by_ad(user_id: int):
"""
通过观看广告获得次数
Args:
user_id: 用户ID
"""
# 检查每日限制
if await AdShareService._check_daily_limit(user_id, "ad", AdShareService.DAILY_AD_LIMIT):
raise errors.ForbiddenError(msg=f"今日广告观看次数已达上限({AdShareService.DAILY_AD_LIMIT}次)")
async with async_db_session.begin() as db:
try:
# 增加计数
current_count = await AdShareService._increment_daily_count(user_id, "ad")
# 增加用户余额
result = await points_dao.update_balance_atomic(db, user_id, AdShareService.AD_REWARD_TIMES)
if not result:
# 回滚计数
key = AdShareService._get_redis_key(user_id, "ad")
await redis_client.decr(key)
raise errors.ServerError(msg="账户更新失败")
# 记录使用日志
account = await points_dao.get_by_user_id(db, user_id)
await usage_log_dao.add(db, {
"user_id": user_id,
"action": "ad",
"amount": AdShareService.AD_REWARD_TIMES,
"balance_after": account.balance if account else AdShareService.AD_REWARD_TIMES,
"metadata_": {
"daily_count": current_count,
"max_limit": AdShareService.DAILY_AD_LIMIT
}
})
except Exception as e:
if not isinstance(e, errors.ForbiddenError):
logger.error(f"广告奖励处理失败: {str(e)}")
raise
@staticmethod
async def grant_times_by_share(user_id: int):
"""
通过分享获得次数
Args:
user_id: 用户ID
"""
# 检查每日限制
if await AdShareService._check_daily_limit(user_id, "share", AdShareService.DAILY_SHARE_LIMIT):
raise errors.ForbiddenError(msg=f"今日分享次数已达上限({AdShareService.DAILY_SHARE_LIMIT}次)")
async with async_db_session.begin() as db:
try:
# 增加计数
current_count = await AdShareService._increment_daily_count(user_id, "share")
# 增加用户余额
result = await points_dao.update_balance_atomic(db, user_id, AdShareService.SHARE_REWARD_TIMES)
if not result:
# 回滚计数
key = AdShareService._get_redis_key(user_id, "share")
await redis_client.decr(key)
raise errors.ServerError(msg="账户更新失败")
# 记录使用日志
account = await points_dao.get_by_user_id(db, user_id)
await usage_log_dao.add(db, {
"user_id": user_id,
"action": "share",
"amount": AdShareService.SHARE_REWARD_TIMES,
"balance_after": account.balance if account else AdShareService.SHARE_REWARD_TIMES,
"metadata_": {
"daily_count": current_count,
"max_limit": AdShareService.DAILY_SHARE_LIMIT
}
})
except Exception as e:
if not isinstance(e, errors.ForbiddenError):
logger.error(f"分享奖励处理失败: {str(e)}")
raise
@staticmethod
async def get_daily_stats(user_id: int) -> dict:
"""
获取用户今日统计信息
Args:
user_id: 用户ID
Returns:
dict: 统计信息
"""
try:
ad_count = await AdShareService._get_daily_count(user_id, "ad")
share_count = await AdShareService._get_daily_count(user_id, "share")
return {
"ad_count": ad_count,
"ad_limit": AdShareService.DAILY_AD_LIMIT,
"share_count": share_count,
"share_limit": AdShareService.DAILY_SHARE_LIMIT,
"can_watch_ad": ad_count < AdShareService.DAILY_AD_LIMIT,
"can_share": share_count < AdShareService.DAILY_SHARE_LIMIT
}
except Exception as e:
logger.error(f"获取每日统计失败: {str(e)}")
return {
"ad_count": 0,
"ad_limit": AdShareService.DAILY_AD_LIMIT,
"share_count": 0,
"share_limit": AdShareService.DAILY_SHARE_LIMIT,
"can_watch_ad": True,
"can_share": True
}

View File

@@ -4,6 +4,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
from backend.app.admin.crud.points_crud import points_dao, points_log_dao
from backend.app.admin.model.points import Points
from backend.database.db import async_db_session
from backend.common.const import (
POINTS_ACTION_RECHARGE,
POINTS_ACTION_SPEND
)
class PointsService:
@@ -17,9 +21,6 @@ class PointsService:
points_account_before = await points_dao.get_by_user_id(db, user_id)
balance_before = points_account_before.balance if points_account_before else 0
# 检查并清空过期积分
expired_cleared = await points_dao.check_and_clear_expired_points(db, user_id)
# 如果清空了过期积分,记录日志
if expired_cleared and balance_before > 0:
await points_log_dao.add_log(db, {
@@ -35,30 +36,13 @@ class PointsService:
@staticmethod
async def get_user_balance(user_id: int) -> int:
"""
获取用户积分余额(会检查并清空过期积分)
获取用户积分余额
"""
async with async_db_session.begin() as db:
# 获取当前积分余额(清空前)
points_account_before = await points_dao.get_by_user_id(db, user_id)
balance_before = points_account_before.balance if points_account_before else 0
# 检查并清空过期积分
expired_cleared = await points_dao.check_and_clear_expired_points(db, user_id)
# 如果清空了过期积分,记录日志
if expired_cleared and balance_before > 0:
await points_log_dao.add_log(db, {
"user_id": user_id,
"action": "expire_clear",
"balance_before": balance_before,
"balance_after": 0,
"details": {"message": "过期积分已清空", "cleared_amount": balance_before}
})
return await points_dao.get_balance(db, user_id)
@staticmethod
async def add_points(user_id: int, amount: int, extend_expiration: bool = False, related_id: Optional[int] = None, details: Optional[dict] = None) -> bool:
async def add_points(user_id: int, amount: int, related_id: Optional[int] = None, details: Optional[dict] = None, action: Optional[str] = None) -> bool:
"""
为用户增加积分
@@ -84,21 +68,22 @@ class PointsService:
current_balance = points_account.balance
# 原子性增加积分(可能延期过期时间)
result = await points_dao.add_points_atomic(db, user_id, amount, extend_expiration)
result = await points_dao.add_points_atomic(db, user_id, amount)
if not result:
return False
# 准备日志详情
log_details = details or {}
if extend_expiration:
log_details["expiration_extended"] = True
log_details["extension_days"] = 30
# 记录积分变动日志
new_balance = current_balance + amount
# 根据是否延期过期时间来确定action类型
if not action:
action = POINTS_ACTION_RECHARGE
await points_log_dao.add_log(db, {
"user_id": user_id,
"action": "earn",
"action": action,
"amount": amount,
"balance_after": new_balance,
"related_id": related_id,
@@ -110,7 +95,7 @@ class PointsService:
@staticmethod
async def add_points_from_coupon(user_id: int, amount: int, coupon_id: int) -> bool:
"""
为用户增加积分(来自兑换券)积分有效期从兑换时间开始加30天
为用户增加积分(来自兑换券)
Args:
user_id: 用户ID
@@ -138,15 +123,14 @@ class PointsService:
# 记录积分变动日志包含coupon id和区分action为coupon
log_details = {
"coupon_id": coupon_id,
"expiration_days": 30
"coupon_id": coupon_id
}
# 记录积分变动日志
new_balance = current_balance + amount
await points_log_dao.add_log(db, {
"user_id": user_id,
"action": "coupon", # 区分action为coupon
"action": POINTS_ACTION_COUPON, # 使用充值action类型
"amount": amount,
"balance_after": new_balance,
"related_id": coupon_id, # 关联coupon id
@@ -155,73 +139,10 @@ class PointsService:
return True
@staticmethod
async def deduct_points(user_id: int, amount: int, related_id: Optional[int] = None, details: Optional[dict] = None) -> bool:
"""
扣减用户积分(会检查并清空过期积分)
Args:
user_id: 用户ID
amount: 扣减的积分数量
related_id: 关联ID可选
details: 附加信息(可选)
Returns:
bool: 是否成功
"""
if amount <= 0:
raise ValueError("积分数量必须大于0")
async with async_db_session.begin() as db:
# 获取当前积分余额(清空前)
points_account_before = await points_dao.get_by_user_id(db, user_id)
if not points_account_before:
return False
balance_before = points_account_before.balance
# 检查并清空过期积分
expired_cleared = await points_dao.check_and_clear_expired_points(db, user_id)
# 如果清空了过期积分,记录日志
if expired_cleared and balance_before > 0:
await points_log_dao.add_log(db, {
"user_id": user_id,
"action": "expire_clear",
"amount": balance_before, # 记录清空前的积分数量
"balance_after": 0,
"details": {"message": "过期积分已清空", "cleared_amount": balance_before}
})
# 重新获取账户信息(可能已被清空)
points_account = await points_dao.get_by_user_id(db, user_id)
if not points_account or points_account.balance < amount:
return False
current_balance = points_account.balance
# 原子性扣减积分
result = await points_dao.deduct_points_atomic(db, user_id, amount)
if not result:
return False
# 记录积分变动日志
new_balance = current_balance - amount
await points_log_dao.add_log(db, {
"user_id": user_id,
"action": "spend",
"amount": amount,
"balance_after": new_balance,
"related_id": related_id,
"details": details
})
return True
@staticmethod
async def check_sufficient_points(user_id: int, required_amount: int) -> bool:
"""
检查用户是否有足够的积分
检查用户是否有足够的积分(简化逻辑,移除过期积分检查)
Args:
user_id: 用户ID
@@ -233,25 +154,8 @@ class PointsService:
if required_amount <= 0:
return True
async with async_db_session.begin() as db:
# 获取当前积分余额(清空前
points_account_before = await points_dao.get_by_user_id(db, user_id)
balance_before = points_account_before.balance if points_account_before else 0
# 检查并清空过期积分
expired_cleared = await points_dao.check_and_clear_expired_points(db, user_id)
# 如果清空了过期积分,记录日志
if expired_cleared and balance_before > 0:
await points_log_dao.add_log(db, {
"user_id": user_id,
"action": "expire_clear",
"amount": balance_before, # 记录清空前的积分数量
"balance_after": 0,
"details": {"message": "过期积分已清空", "cleared_amount": balance_before}
})
# 重新获取账户信息(可能已被清空)
async with async_db_session() as db:
# 直接获取用户积分余额(不再检查过期积分
points_account = await points_dao.get_by_user_id(db, user_id)
if not points_account:
return False
@@ -259,9 +163,9 @@ class PointsService:
return points_account.balance >= required_amount
@staticmethod
async def deduct_points_with_db(user_id: int, amount: int, db: AsyncSession, related_id: Optional[int] = None, details: Optional[dict] = None) -> bool:
async def deduct_points_with_db(user_id: int, amount: int, db: AsyncSession, related_id: Optional[int] = None, details: Optional[dict] = None, action: Optional[str] = None) -> bool:
"""
扣减用户积分(会检查并清空过期积分)- 使用提供的数据库连接
扣减用户积分(简化逻辑,移除过期积分检查- 使用提供的数据库连接
Args:
user_id: 用户ID
@@ -269,6 +173,7 @@ class PointsService:
db: 数据库连接
related_id: 关联ID可选
details: 附加信息(可选)
action: 积分操作类型可选默认为POINTS_ACTION_SPEND
Returns:
bool: 是否成功
@@ -276,27 +181,7 @@ class PointsService:
if amount <= 0:
raise ValueError("积分数量必须大于0")
# 获取当前积分余额(清空前
points_account_before = await points_dao.get_by_user_id(db, user_id)
if not points_account_before:
return False
balance_before = points_account_before.balance
# 检查并清空过期积分
expired_cleared = await points_dao.check_and_clear_expired_points(db, user_id)
# 如果清空了过期积分,记录日志
if expired_cleared and balance_before > 0:
await points_log_dao.add_log(db, {
"user_id": user_id,
"action": "expire_clear",
"amount": balance_before, # 记录清空前的积分数量
"balance_after": 0,
"details": {"message": "过期积分已清空", "cleared_amount": balance_before}
})
# 重新获取账户信息(可能已被清空)
# 直接获取用户积分账户(不再检查过期积分
points_account = await points_dao.get_by_user_id(db, user_id)
if not points_account or points_account.balance < amount:
return False
@@ -304,15 +189,19 @@ class PointsService:
current_balance = points_account.balance
# 原子性扣减积分
result = await points_dao.deduct_points_atomic(db, user_id, amount)
result = await points_dao.deduct_balance_atomic(db, user_id, amount)
if not result:
return False
# 记录积分变动日志
new_balance = current_balance - amount
# 如果action参数为空则默认使用POINTS_ACTION_SPEND
if action is None:
action = POINTS_ACTION_SPEND
await points_log_dao.add_log(db, {
"user_id": user_id,
"action": "spend",
"action": action,
"amount": amount,
"balance_after": new_balance,
"related_id": related_id,
@@ -324,58 +213,7 @@ class PointsService:
@staticmethod
async def initialize_user_points(user_id: int, db: AsyncSession = None) -> Points:
"""
为新用户初始化积分账户
"""
if db is not None:
# Use the provided session (for nested transactions)
points_account = await points_dao.get_by_user_id(db, user_id)
if not points_account:
points_account = await points_dao.create_user_points(db, user_id)
return points_account
else:
# Create a new transaction (standalone usage)
async with async_db_session.begin() as db:
points_account = await points_dao.get_by_user_id(db, user_id)
if not points_account:
points_account = await points_dao.create_user_points(db, user_id)
return points_account
@staticmethod
async def is_subscribed_user(user_id: int) -> bool:
"""
检查用户是否为订阅用户(使用显式的订阅标志)
Args:
user_id: 用户ID
Returns:
是否为订阅用户
"""
try:
async with async_db_session() as db:
points_account = await points_dao.get_by_user_id(db, user_id)
if points_account and points_account.is_subscribed:
return True
# 如果没有显式的订阅标志,回退到旧的检查方式
if points_account:
# 检查是否有有效的订阅
if points_account.subscription_type and points_account.subscription_expires_at:
if points_account.subscription_expires_at > datetime.now():
return True
# 检查是否有积分余额
return points_account.balance > 0
return False
except Exception as e:
# 如果出现异常,默认认为不是订阅用户
return False
@staticmethod
async def create_new_user_account(user_id: int, db: AsyncSession = None) -> Points:
"""
为新用户创建账户(包含免费试用)
为新用户初始化积分账户根据新需求直接在balance中增加初始积分
"""
if db is not None:
# Use the provided session (for nested transactions)
@@ -391,39 +229,6 @@ class PointsService:
points_account = await points_dao.create_new_user_account(db, user_id)
return points_account
@staticmethod
async def update_subscription_status(user_id: int) -> bool:
"""
更新用户的订阅状态
Args:
user_id: 用户ID
Returns:
bool: 是否为订阅用户
"""
async with async_db_session.begin() as db:
points_account = await points_dao.get_by_user_id(db, user_id)
if not points_account:
return False
# 检查是否有有效的订阅
is_currently_subscribed = False
if points_account.subscription_type and points_account.subscription_expires_at:
if points_account.subscription_expires_at > datetime.now():
is_currently_subscribed = True
# 检查是否有余额
if points_account.balance > 0:
is_currently_subscribed = True
# 更新订阅标志
if points_account.is_subscribed != is_currently_subscribed:
points_account.is_subscribed = is_currently_subscribed
await points_dao.update(db, points_account.id, {"is_subscribed": is_currently_subscribed})
return is_currently_subscribed
@staticmethod
async def get_user_account_details(user_id: int) -> dict:
"""
@@ -434,44 +239,13 @@ class PointsService:
if not points_account:
return {}
frozen_balance = await points_dao.get_frozen_balance(db, user_id)
available_balance = max(0, points_account.balance - frozen_balance)
# 检查免费试用状态
is_free_trial_active = (
points_account.free_trial_expires_at and
points_account.free_trial_expires_at > datetime.now()
)
return {
"balance": points_account.balance,
"available_balance": available_balance,
"frozen_balance": frozen_balance,
"total_purchased": points_account.total_earned, # Using total_earned as equivalent to total_purchased
"subscription_type": points_account.subscription_type,
"subscription_expires_at": points_account.subscription_expires_at,
"carryover_balance": points_account.carryover_balance,
# 免费试用信息
"free_trial_balance": points_account.free_trial_balance,
"free_trial_expires_at": points_account.free_trial_expires_at,
"free_trial_active": is_free_trial_active,
"free_trial_used": points_account.free_trial_used,
# 计算剩余天数
"free_trial_days_left": (
max(0, (points_account.free_trial_expires_at - datetime.now()).days)
if points_account.free_trial_expires_at else 0
)
}
@staticmethod
async def check_free_trial_valid(user_id: int) -> bool:
"""
检查用户免费试用是否仍然有效
"""
async with async_db_session() as db:
return await points_dao.check_free_trial_valid(db, user_id)
points_service:PointsService = PointsService()

View File

@@ -52,10 +52,8 @@ class WxAuthService:
await wx_user_dao.add(db, user)
await db.flush()
await db.refresh(user)
# initialize user points
# initialize user points with initial gift points (according to new requirements)
await points_service.initialize_user_points(user_id=user.id, db=db)
# initialize user account
await points_service.create_new_user_account(user.id, db=db)
else:
await wx_user_dao.update_session_key(db, user.id, session_key)

View File

@@ -20,7 +20,7 @@ from backend.app.admin.service.file_service import file_service
from backend.app.ai.service.rate_limit_service import rate_limit_service, IMAGE_RECOGNITION_SERVICE
from backend.common.enums import FileType
from backend.common.exception import errors
from backend.common.const import IMAGE_RECOGNITION_COST
from backend.common.const import IMAGE_RECOGNITION_COST, POINTS_ACTION_IMAGE_RECOGNITION
from backend.core.conf import settings
from backend.common.log import log as logger
@@ -287,22 +287,11 @@ class ImageService:
if not dict_level:
dict_level = current_user.dict_level.name
# 优先检查是否订阅用户
is_subscribed = await rate_limit_service.is_subscribed_user(current_user.id)
if is_subscribed:
# 订阅用户检查积分是否足够
if not await points_service.check_sufficient_points(current_user.id, IMAGE_RECOGNITION_COST):
raise errors.ForbiddenError(
msg=f'积分不足,请充值以继续使用'
)
else:
# 非订阅用户检查是否有免费使用次数
usage_allowed = await rate_limit_service.increment_daily_usage(current_user.id, IMAGE_RECOGNITION_SERVICE, DAILY_IMAGE_RECOGNITION_MAX_TIMES)
if not usage_allowed:
raise errors.ForbiddenError(
msg=f'未订阅用户每天只能使用 {DAILY_IMAGE_RECOGNITION_MAX_TIMES} 次图像识别服务,请升级为订阅用户以解除限制'
)
# 检查用户积分是否足够(现在积分没有过期概念)
if not await points_service.check_sufficient_points(current_user.id, IMAGE_RECOGNITION_COST):
raise errors.ForbiddenError(
msg=f'积分不足,请充值以继续使用'
)
# 尝试获取任务槽位
slot_acquired = await rate_limit_service.acquire_task_slot(current_user.id)
@@ -1046,7 +1035,8 @@ class ImageService:
amount=IMAGE_RECOGNITION_COST, # 扣减图片识别费用
db=db,
related_id=image.id,
details={"action": "image_recognition", "task_id": task_id}
details={"task_id": task_id},
action=POINTS_ACTION_IMAGE_RECOGNITION
)
await db.commit()

View File

@@ -47,15 +47,16 @@ class RateLimitService:
@staticmethod
async def is_subscribed_user(user_id: int) -> bool:
"""
检查用户是否为订阅用户(使用显式的订阅标志
检查用户是否为订阅用户(此功能已废弃,根据新需求统一使用积分系统
Args:
user_id: 用户ID
Returns:
是否为订阅用户
是否为订阅用户始终返回False因为订阅概念已废弃
"""
return await PointsService.is_subscribed_user(user_id)
# 根据新需求,订阅概念已废弃,统一使用积分系统
return False
@staticmethod
async def get_user_task_limit(user_id: int) -> int:
@@ -310,41 +311,7 @@ class RateLimitService:
return status
@staticmethod
async def check_daily_usage_limit(user_id: int, service_type: str, max_usage: int = 3) -> tuple[bool, int, int]:
"""
检查用户每日使用次数限制
Args:
user_id: 用户ID
service_type: 服务类型(如"image_recognition", "speech_assessment"
max_usage: 最大使用次数默认3次
Returns:
(是否超过限制, 当前使用次数, 最大使用次数)
"""
# 检查用户是否为订阅用户
is_subscribed = await RateLimitService.is_subscribed_user(user_id)
# 订阅用户不受限制
if is_subscribed:
return False, 0, max_usage
# 获取每日使用次数键
usage_key = RateLimitService._get_daily_usage_key(user_id, service_type)
# 获取当前使用次数
current_usage = await redis_client.get(usage_key)
if current_usage is None:
current_usage = 0
else:
current_usage = int(current_usage)
# 检查是否超过限制
is_limited = current_usage >= max_usage
return is_limited, current_usage, max_usage
@staticmethod
async def increment_daily_usage(user_id: int, service_type: str, max_usage: int = 3) -> bool:
"""
@@ -357,12 +324,7 @@ class RateLimitService:
Returns:
是否成功增加(未超过限制)
"""
# 检查是否超过每日使用限制
is_limited, current_usage, _max_usage = await RateLimitService.check_daily_usage_limit(user_id, service_type, max_usage)
if is_limited:
return False
# 使用新的免费使用次数跟踪系统
await RateLimitService.increment_free_usage_count(user_id, service_type)

View File

@@ -17,6 +17,8 @@ from backend.app.admin.service.audit_log_service import audit_log_service
from backend.app.ai.service.rate_limit_service import rate_limit_service, SPEECH_ASSESSMENT_SERVICE
from backend.database.db import async_db_session
from backend.middleware.tencent_cloud import TencentCloud
from backend.common.const import SPEECH_ASSESSMENT_COST, POINTS_ACTION_SPEECH_ASSESSMENT
from backend.app.admin.service.points_service import points_service
# Import the recording_dao for accessing recording CRUD methods
from backend.app.ai.crud.recording_crud import recording_dao
@@ -337,10 +339,9 @@ class RecordingService:
except Exception as e:
raise RuntimeError(f"Failed to create recording record for file_id {file_id}: {str(e)}")
# 检查每日使用次数限制针对每条recording记录
usage_allowed = await rate_limit_service.increment_daily_usage(user_id, f"{SPEECH_ASSESSMENT_SERVICE}_{recording.id}", 1)
if not usage_allowed:
raise RuntimeError('免费用户每天每条语音只能使用1次语音评测服务请升级为订阅用户以解除限制')
# 检查用户积分是否足够(现在积分没有过期概念
if not await points_service.check_sufficient_points(user_id, SPEECH_ASSESSMENT_COST):
raise RuntimeError('积分不足,请充值以继续使用')
try:
# 调用腾讯云SOE API进行语音评估
@@ -354,6 +355,19 @@ class RecordingService:
if not success:
raise RuntimeError(f"Failed to update recording details for file_id {file_id}")
# 扣减用户积分
async with async_db_session.begin() as db:
points_deducted = await points_service.deduct_points_with_db(
user_id=user_id,
amount=SPEECH_ASSESSMENT_COST,
db=db,
related_id=recording.id,
details={"recording_id": recording.id},
action=POINTS_ACTION_SPEECH_ASSESSMENT
)
if not points_deducted:
logger.warning(f"Failed to deduct points for user {user_id} for speech assessment")
# 计算耗时
duration = time.time() - start_time

View File

@@ -44,14 +44,14 @@ class TestImageServiceRateLimit:
dict_level=DictLevel.LEVEL1
)
@patch('backend.app.ai.service.rate_limit_service.rate_limit_service.increment_daily_usage')
@patch('backend.app.admin.service.points_service.points_service.check_sufficient_points')
@patch('backend.app.admin.service.file_service.file_service.download_file')
@pytest.mark.asyncio
async def test_process_image_from_file_async_daily_usage_limit_exceeded(self, mock_download_file, mock_increment_daily_usage):
"""Test that process_image_from_file_async raises ForbiddenError when daily usage limit is exceeded"""
async def test_process_image_from_file_async_daily_usage_limit_exceeded(self, mock_download_file, mock_check_sufficient_points):
"""Test that process_image_from_file_async raises ForbiddenError when user has insufficient points"""
# Arrange
# Mock rate_limit_service.increment_daily_usage to return False (limit exceeded)
mock_increment_daily_usage.return_value = False
# Mock points_service.check_sufficient_points to return False (insufficient points)
mock_check_sufficient_points.return_value = False
# Mock file_service.download_file to return dummy data
mock_download_file.return_value = (b"dummy_image_data", "test.jpg", "image/jpeg")
@@ -65,25 +65,24 @@ class TestImageServiceRateLimit:
)
# Verify the error message
assert "免费用户每天只能使用3次图像识别服务请升级为订阅用户以解除限制" in str(exc_info.value.msg)
assert "积分不足,请充值以继续使用" in str(exc_info.value.msg)
# Verify that increment_daily_usage was called with correct parameters
mock_increment_daily_usage.assert_awaited_once_with(
# Verify that check_sufficient_points was called with correct parameters
mock_check_sufficient_points.assert_awaited_once_with(
self.user_id,
IMAGE_RECOGNITION_SERVICE,
3 # DAILY_IMAGE_RECOGNITION_MAX_TIMES
IMAGE_RECOGNITION_COST
)
@patch('backend.app.ai.service.rate_limit_service.rate_limit_service.increment_daily_usage')
@patch('backend.app.admin.service.points_service.points_service.check_sufficient_points')
@patch('backend.app.ai.service.rate_limit_service.rate_limit_service.acquire_task_slot')
@patch('backend.app.ai.service.rate_limit_service.rate_limit_service.get_user_task_limit')
@patch('backend.app.admin.service.file_service.file_service.download_file')
@pytest.mark.asyncio
async def test_process_image_from_file_async_task_slot_limit_exceeded(self, mock_download_file, mock_get_user_task_limit, mock_acquire_task_slot, mock_increment_daily_usage):
async def test_process_image_from_file_async_task_slot_limit_exceeded(self, mock_download_file, mock_get_user_task_limit, mock_acquire_task_slot, mock_check_sufficient_points):
"""Test that process_image_from_file_async raises ForbiddenError when task slot limit is exceeded"""
# Arrange
# Mock rate_limit_service.increment_daily_usage to return True (within limit)
mock_increment_daily_usage.return_value = True
# Mock points_service.check_sufficient_points to return True (sufficient points)
mock_check_sufficient_points.return_value = True
# Mock rate_limit_service.acquire_task_slot to return False (no slot available)
mock_acquire_task_slot.return_value = False
@@ -106,11 +105,10 @@ class TestImageServiceRateLimit:
assert "用户同时最多只能运行" in str(exc_info.value.msg)
assert "个任务,请等待现有任务完成后再试" in str(exc_info.value.msg)
# Verify that increment_daily_usage was called
mock_increment_daily_usage.assert_awaited_once_with(
# Verify that check_sufficient_points was called
mock_check_sufficient_points.assert_awaited_once_with(
self.user_id,
IMAGE_RECOGNITION_SERVICE,
3 # DAILY_IMAGE_RECOGNITION_MAX_TIMES
IMAGE_RECOGNITION_COST
)
# Verify that acquire_task_slot was called