Files
backend/backend/app/admin/crud/points_crud.py
2025-11-22 18:50:52 +08:00

104 lines
3.4 KiB
Python

from datetime import datetime, timedelta
from typing import Optional, Dict, Any
from sqlalchemy import select, update, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus
from backend.app.admin.model.points import Points, PointsLog
class PointsDao(CRUDPlus[Points]):
    """Data-access helpers for per-user ``Points`` account rows."""

    async def get_by_user_id(self, db: AsyncSession, user_id: int) -> Optional[Points]:
        """
        Fetch the points account that belongs to a user.

        :param db: async session used to run the query
        :param user_id: value matched against ``Points.user_id``
        :return: the matching ``Points`` row, or ``None`` when there is none
        """
        stmt = select(Points).where(Points.user_id == user_id)
        result = await db.execute(stmt)
        return result.scalar_one_or_none()

    async def create_user_points(self, db: AsyncSession, user_id: int) -> Points:
        """
        Create a points account for a user.

        The new row is added to the session and flushed; no commit is issued here.

        :param db: async session the new row is added to
        :param user_id: owner of the new account
        :return: the newly created ``Points`` instance
        """
        points = Points(user_id=user_id)
        db.add(points)
        await db.flush()
        return points

    async def add_points_atomic(self, db: AsyncSession, user_id: int, amount: int, extend_expiration: bool = False) -> bool:
        """
        Increase a user's balance with a single UPDATE statement.

        An account is created first when the user has none. That lookup/creation
        is a separate step from the UPDATE, so only the increment itself is done
        in one statement. ``amount`` is not validated here.

        :param db: async session used to run the statements
        :param user_id: owner of the account to credit
        :param amount: number of points added to ``balance`` and ``total_earned``
        :param extend_expiration: when true, ``expired_time`` is reset to 30 days
            from the current local time
        :return: ``True`` if the UPDATE matched at least one row
        """
        # Make sure the user has an account before issuing the UPDATE.
        points_account = await self.get_by_user_id(db, user_id)
        if not points_account:
            points_account = await self.create_user_points(db, user_id)

        # Column expressions keep the arithmetic inside the database.
        update_values = {
            "balance": Points.balance + amount,
            "total_earned": Points.total_earned + amount,
        }
        # Push the expiry out only when the caller asks for it.
        if extend_expiration:
            update_values["expired_time"] = datetime.now() + timedelta(days=30)

        stmt = update(Points).where(
            Points.user_id == user_id
        ).values(**update_values)
        result = await db.execute(stmt)
        return result.rowcount > 0

    async def deduct_points_atomic(self, db: AsyncSession, user_id: int, amount: int) -> bool:
        """
        Decrease a user's balance with a single guarded UPDATE statement.

        The ``balance >= amount`` predicate is part of the UPDATE itself, so the
        row is only changed when the balance covers the deduction. ``amount`` is
        not validated here.

        :param db: async session used to run the statement
        :param user_id: owner of the account to debit
        :param amount: number of points moved from ``balance`` to ``total_spent``
        :return: ``True`` if a row was updated; ``False`` when no row satisfied
            both the user and balance conditions
        """
        stmt = update(Points).where(
            Points.user_id == user_id,
            Points.balance >= amount
        ).values(
            balance=Points.balance - amount,
            total_spent=Points.total_spent + amount
        )
        result = await db.execute(stmt)
        return result.rowcount > 0

    async def get_balance(self, db: AsyncSession, user_id: int) -> int:
        """
        Return the user's current points balance.

        :param db: async session used to run the query
        :param user_id: owner of the account to read
        :return: the account's ``balance``, or ``0`` when the user has no account
        """
        points_account = await self.get_by_user_id(db, user_id)
        if not points_account:
            return 0
        return points_account.balance

    async def check_and_clear_expired_points(self, db: AsyncSession, user_id: int) -> bool:
        """
        Zero out the user's balance if the account has expired.

        The remaining balance is added to ``total_spent`` and ``balance`` is set
        to ``0`` in one UPDATE, restricted to rows whose ``expired_time`` is
        earlier than the current local time and whose balance is positive.

        :param db: async session used to run the statement
        :param user_id: owner of the account to check
        :return: ``True`` if an expired, non-empty account was cleared
        """
        stmt = (
            update(Points)
            .where(
                Points.user_id == user_id,
                Points.expired_time < datetime.now(),
                Points.balance > 0,
            )
            # The SET order is pinned on purpose: MySQL evaluates single-table
            # UPDATE assignments left to right, so total_spent has to read the
            # balance before the balance is reset to 0. Plain .values() does
            # not guarantee the rendered order.
            .ordered_values(
                (Points.total_spent, Points.total_spent + Points.balance),
                (Points.balance, 0),
            )
        )
        result = await db.execute(stmt)
        return result.rowcount > 0
class PointsLogDao(CRUDPlus[PointsLog]):
    """Data-access helper for ``PointsLog`` rows."""

    async def add_log(self, db: AsyncSession, log_data: Dict[str, Any]) -> PointsLog:
        """
        Record one points-change log entry.

        :param db: async session the new row is added to
        :param log_data: keyword arguments forwarded to the ``PointsLog`` constructor
        :return: the new ``PointsLog`` instance, after the session has been flushed
        """
        entry = PointsLog(**log_data)
        db.add(entry)
        # Flush only; no commit is issued here.
        await db.flush()
        return entry
# Module-level DAO instances, each bound to the model class it operates on.
points_dao = PointsDao(Points)
points_log_dao = PointsLogDao(PointsLog)