104 lines
3.4 KiB
Python
104 lines
3.4 KiB
Python
from datetime import datetime, timedelta
|
|
from typing import Optional, Dict, Any
|
|
|
|
from sqlalchemy import select, update, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy_crud_plus import CRUDPlus
|
|
from backend.app.admin.model.points import Points, PointsLog
|
|
|
|
|
|
class PointsDao(CRUDPlus[Points]):
    """Data-access helpers for the per-user ``Points`` account row.

    The balance-changing methods each issue a single ``UPDATE`` whose new
    values are computed by the database from the current column values, so
    no balance is read into Python and written back.
    """

    async def get_by_user_id(self, db: AsyncSession, user_id: int) -> Optional[Points]:
        """Return the points account belonging to ``user_id``.

        :param db: session used to run the query
        :param user_id: owner of the account
        :return: the matching ``Points`` row, or ``None`` when none exists
        """
        stmt = select(Points).where(Points.user_id == user_id)
        result = await db.execute(stmt)
        return result.scalar_one_or_none()

    async def create_user_points(self, db: AsyncSession, user_id: int) -> Points:
        """Create a points account for ``user_id`` and flush it.

        Only ``user_id`` is set explicitly; every other column is left to the
        model. The session is flushed but not committed here.

        :param db: session the new row is added to
        :param user_id: owner of the new account
        :return: the newly added ``Points`` instance
        """
        points = Points(user_id=user_id)
        db.add(points)
        await db.flush()
        return points

    async def add_points_atomic(
        self,
        db: AsyncSession,
        user_id: int,
        amount: int,
        extend_expiration: bool = False,
    ) -> bool:
        """Add ``amount`` to the user's balance and lifetime-earned total.

        An account is created first when the user does not have one yet.

        :param db: session used to run the statements
        :param user_id: owner of the account
        :param amount: number of points to add
        :param extend_expiration: when true, also set ``expired_time`` to
            30 days after the current ``datetime.now()``
        :return: ``True`` when the ``UPDATE`` matched at least one row
        """
        # NOTE(review): amount is not validated here; a negative value would
        # lower balance and total_earned -- confirm callers guard against it.

        # Make sure the user has an account row for the UPDATE to hit.
        # NOTE(review): this is check-then-insert; two concurrent first-time
        # calls could both see no account -- presumably a unique constraint
        # on user_id guards this, confirm against the model.
        points_account = await self.get_by_user_id(db, user_id)
        if not points_account:
            points_account = await self.create_user_points(db, user_id)

        # Column expressions, so the arithmetic happens inside the database.
        update_values = {
            "balance": Points.balance + amount,
            "total_earned": Points.total_earned + amount,
        }

        # Optionally push the expiry out by 30 days from now.
        if extend_expiration:
            update_values["expired_time"] = datetime.now() + timedelta(days=30)

        stmt = update(Points).where(
            Points.user_id == user_id
        ).values(**update_values)
        result = await db.execute(stmt)
        return result.rowcount > 0

    async def deduct_points_atomic(self, db: AsyncSession, user_id: int, amount: int) -> bool:
        """Deduct ``amount`` from the user's balance without overdrawing.

        The ``balance >= amount`` condition is part of the ``UPDATE`` itself,
        so the row is only changed when the balance is sufficient.

        :param db: session used to run the statement
        :param user_id: owner of the account
        :param amount: number of points to deduct
        :return: ``True`` when a row was updated; ``False`` when no row
            matched (no account, or balance below ``amount``)
        """
        # NOTE(review): amount is not validated here; a negative value always
        # satisfies the balance check and would raise the balance -- confirm
        # callers guard against it.
        stmt = update(Points).where(
            Points.user_id == user_id,
            Points.balance >= amount,
        ).values(
            balance=Points.balance - amount,
            total_spent=Points.total_spent + amount,
        )
        result = await db.execute(stmt)
        return result.rowcount > 0

    async def get_balance(self, db: AsyncSession, user_id: int) -> int:
        """Return the user's current balance, or ``0`` when no account exists.

        :param db: session used to run the query
        :param user_id: owner of the account
        :return: the ``balance`` of the account, or ``0``
        """
        points_account = await self.get_by_user_id(db, user_id)
        if not points_account:
            return 0
        return points_account.balance

    async def check_and_clear_expired_points(self, db: AsyncSession, user_id: int) -> bool:
        """Zero the user's balance if it has expired, counting it as spent.

        Only rows whose ``expired_time`` is earlier than ``datetime.now()``
        and whose balance is positive are touched.

        :param db: session used to run the statement
        :param user_id: owner of the account
        :return: ``True`` when an expired balance was cleared
        """
        # The SET order is pinned on purpose: MySQL/MariaDB evaluate
        # single-table UPDATE assignments left to right, so total_spent must
        # absorb the old balance BEFORE balance is zeroed. Backends that
        # evaluate every right-hand side against the original row are
        # unaffected by the ordering.
        stmt = (
            update(Points)
            .where(
                Points.user_id == user_id,
                Points.expired_time < datetime.now(),
                Points.balance > 0,
            )
            .ordered_values(
                (Points.total_spent, Points.total_spent + Points.balance),
                (Points.balance, 0),
            )
        )
        result = await db.execute(stmt)
        return result.rowcount > 0
|
|
|
|
|
|
class PointsLogDao(CRUDPlus[PointsLog]):
    """Data-access helper for ``PointsLog`` rows."""

    async def add_log(self, db: AsyncSession, log_data: Dict[str, Any]) -> PointsLog:
        """Add one points-change log entry and flush it.

        :param db: session the entry is added to (flushed, not committed)
        :param log_data: keyword arguments passed straight to the
            ``PointsLog`` constructor
        :return: the newly added ``PointsLog`` instance
        """
        entry = PointsLog(**log_data)
        db.add(entry)
        await db.flush()
        return entry
|
|
|
|
|
|
# Module-level DAO instances, each bound to the model class it operates on.
points_dao = PointsDao(Points)
points_log_dao = PointsLogDao(PointsLog)