fix code
This commit is contained in:
@@ -1,15 +1,17 @@
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
|
||||
from backend.app.admin.service.coupon_service import CouponService
|
||||
from backend.common.response.response_schema import response_base
|
||||
from backend.common.security.jwt import DependsJwtAuth
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
class RedeemCouponRequest(BaseModel):
|
||||
code: str = Field(..., min_length=1, max_length=32, description="兑换码")
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
|
||||
from backend.app.admin.service.coupon_service import CouponService
|
||||
from backend.common.response.response_schema import response_base
|
||||
from backend.common.security.jwt import DependsJwtAuth
|
||||
from backend.core.conf import settings
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
class RedeemCouponRequest(BaseModel):
|
||||
code: str = Field(..., min_length=1, max_length=32, description="兑换码")
|
||||
|
||||
class CreateCouponRequest(BaseModel):
|
||||
duration: int = Field(..., gt=0, description="兑换时长(分钟)")
|
||||
@@ -49,19 +51,30 @@ async def generate_coupons_api(
|
||||
request: Request,
|
||||
create_request: CreateCouponRequest
|
||||
):
|
||||
"""
|
||||
批量生成兑换券(管理员接口)
|
||||
"""
|
||||
# 这里应该添加管理员权限验证
|
||||
# 为简化示例,暂时省略权限验证
|
||||
|
||||
coupons = await CouponService.batch_create_coupons(
|
||||
create_request.count,
|
||||
create_request.duration,
|
||||
create_request.expires_days
|
||||
)
|
||||
|
||||
return response_base.success(data={
|
||||
"count": len(coupons),
|
||||
"message": f"成功生成{len(coupons)}个兑换券"
|
||||
})
|
||||
"""
|
||||
批量生成兑换券(管理员接口)
|
||||
"""
|
||||
# 这里应该添加管理员权限验证
|
||||
# 为简化示例,暂时省略权限验证
|
||||
|
||||
coupons = await CouponService.batch_create_coupons(
|
||||
create_request.count,
|
||||
create_request.duration,
|
||||
create_request.expires_days
|
||||
)
|
||||
|
||||
return response_base.success(data={
|
||||
"count": len(coupons),
|
||||
"message": f"成功生成{len(coupons)}个兑换券"
|
||||
})
|
||||
|
||||
class InitCouponsResponse(BaseModel):
|
||||
count: int = Field(..., description="生成数量")
|
||||
|
||||
@router.get("/init", summary="初始化兑换券")
|
||||
async def init_coupons(request: Request, prefix: str = "VIP", count: int = 10):
|
||||
t = request.query_params.get('t')
|
||||
if not t or t == '' or t != settings.INIT_TOKEN:
|
||||
raise HTTPException(status_code=403, detail='Forbidden')
|
||||
created = await CouponService.init_coupons(prefix, count)
|
||||
return response_base.success(data=InitCouponsResponse(count=created))
|
||||
|
||||
@@ -6,22 +6,22 @@ from backend.app.admin.model.coupon import Coupon, CouponUsage
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class CouponDao(CRUDPlus[Coupon]):
|
||||
|
||||
async def get(self, db: AsyncSession, id: int) -> Optional[Coupon]:
|
||||
"""
|
||||
根据ID获取兑换券
|
||||
"""
|
||||
return await self.select_model(db, id)
|
||||
|
||||
async def get_by_code(self, db: AsyncSession, code: str) -> Optional[Coupon]:
|
||||
"""
|
||||
根据兑换码获取兑换券
|
||||
"""
|
||||
stmt = select(Coupon).where(Coupon.code == code)
|
||||
result = await db.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
class CouponDao(CRUDPlus[Coupon]):
|
||||
|
||||
async def get(self, db: AsyncSession, id: int) -> Optional[Coupon]:
|
||||
"""
|
||||
根据ID获取兑换券
|
||||
"""
|
||||
return await self.select_model(db, id)
|
||||
|
||||
async def get_by_code(self, db: AsyncSession, code: str) -> Optional[Coupon]:
|
||||
"""
|
||||
根据兑换码获取兑换券
|
||||
"""
|
||||
stmt = select(Coupon).where(Coupon.code == code)
|
||||
result = await db.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_unused_coupon_by_code(self, db: AsyncSession, code: str) -> Optional[Coupon]:
|
||||
"""
|
||||
根据兑换码获取未使用的兑换券
|
||||
@@ -45,15 +45,21 @@ class CouponDao(CRUDPlus[Coupon]):
|
||||
await db.flush()
|
||||
return coupon
|
||||
|
||||
async def create_coupons(self, db: AsyncSession, coupons_data: List[dict]) -> List[Coupon]:
|
||||
"""
|
||||
批量创建兑换券
|
||||
"""
|
||||
coupons = [Coupon(**data) for data in coupons_data]
|
||||
db.add_all(coupons)
|
||||
await db.flush()
|
||||
return coupons
|
||||
|
||||
async def create_coupons(self, db: AsyncSession, coupons_data: List[dict]) -> List[Coupon]:
|
||||
"""
|
||||
批量创建兑换券
|
||||
"""
|
||||
coupons = [Coupon(**data) for data in coupons_data]
|
||||
db.add_all(coupons)
|
||||
await db.flush()
|
||||
return coupons
|
||||
|
||||
async def list_codes_by_prefix(self, db: AsyncSession, prefix: str) -> List[str]:
|
||||
stmt = select(Coupon.code).where(Coupon.code.like(f"{prefix}%"))
|
||||
result = await db.execute(stmt)
|
||||
rows = result.all()
|
||||
return [r[0] for r in rows]
|
||||
|
||||
async def mark_as_used(self, db: AsyncSession, user_id: int, coupon: Coupon) -> bool:
|
||||
"""
|
||||
标记兑换券为已使用并创建使用记录
|
||||
@@ -110,4 +116,4 @@ class CouponUsageDao(CRUDPlus[CouponUsage]):
|
||||
|
||||
|
||||
coupon_dao = CouponDao(Coupon)
|
||||
coupon_usage_dao = CouponUsageDao(CouponUsage)
|
||||
coupon_usage_dao = CouponUsageDao(CouponUsage)
|
||||
|
||||
@@ -9,22 +9,22 @@ from backend.common.exception import errors
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
class CouponService:
|
||||
|
||||
@staticmethod
|
||||
def generate_unique_code(length: int = 6) -> str:
|
||||
"""
|
||||
生成唯一的兑换码
|
||||
"""
|
||||
characters = string.ascii_uppercase + string.digits
|
||||
# 移除容易混淆的字符
|
||||
characters = characters.replace('0', '').replace('O', '').replace('I', '').replace('1')
|
||||
|
||||
while True:
|
||||
code = ''.join(random.choice(characters) for _ in range(length))
|
||||
# 确保生成的兑换码不包含敏感词汇或重复模式
|
||||
if not CouponService._has_sensitive_pattern(code):
|
||||
return code
|
||||
class CouponService:
|
||||
|
||||
@staticmethod
|
||||
def generate_unique_code(length: int = 6) -> str:
|
||||
"""
|
||||
生成唯一的兑换码
|
||||
"""
|
||||
characters = string.ascii_uppercase + string.digits
|
||||
# 移除容易混淆的字符
|
||||
characters = characters.replace('0', '').replace('O', '').replace('I', '').replace('1')
|
||||
|
||||
while True:
|
||||
code = ''.join(random.choice(characters) for _ in range(length))
|
||||
# 确保生成的兑换码不包含敏感词汇或重复模式
|
||||
if not CouponService._has_sensitive_pattern(code):
|
||||
return code
|
||||
|
||||
@staticmethod
|
||||
def _has_sensitive_pattern(code: str) -> bool:
|
||||
@@ -71,39 +71,76 @@ class CouponService:
|
||||
return coupon
|
||||
|
||||
@staticmethod
|
||||
async def batch_create_coupons(count: int, points: int, expires_days: Optional[int] = None) -> List[Coupon]:
|
||||
"""
|
||||
批量创建兑换券
|
||||
"""
|
||||
async with async_db_session.begin() as db:
|
||||
coupons_data = []
|
||||
|
||||
# 生成唯一兑换码列表
|
||||
codes = set()
|
||||
while len(codes) < count:
|
||||
code = CouponService.generate_unique_code()
|
||||
if code not in codes:
|
||||
# 检查数据库中是否已存在该兑换码
|
||||
existing_coupon = await coupon_dao.get_by_code(db, code)
|
||||
if not existing_coupon:
|
||||
codes.add(code)
|
||||
|
||||
# 设置过期时间
|
||||
expires_at = None
|
||||
if expires_days:
|
||||
expires_at = datetime.now() + timedelta(days=expires_days)
|
||||
|
||||
# 准备数据
|
||||
for code in codes:
|
||||
coupons_data.append({
|
||||
'code': code,
|
||||
'points': points,
|
||||
'expires_at': expires_at
|
||||
})
|
||||
|
||||
coupons = await coupon_dao.create_coupons(db, coupons_data)
|
||||
return coupons
|
||||
|
||||
async def batch_create_coupons(count: int, points: int, expires_days: Optional[int] = None) -> List[Coupon]:
|
||||
"""
|
||||
批量创建兑换券
|
||||
"""
|
||||
async with async_db_session.begin() as db:
|
||||
coupons_data = []
|
||||
|
||||
# 生成唯一兑换码列表
|
||||
codes = set()
|
||||
while len(codes) < count:
|
||||
code = CouponService.generate_unique_code()
|
||||
if code not in codes:
|
||||
# 检查数据库中是否已存在该兑换码
|
||||
existing_coupon = await coupon_dao.get_by_code(db, code)
|
||||
if not existing_coupon:
|
||||
codes.add(code)
|
||||
|
||||
# 设置过期时间
|
||||
expires_at = None
|
||||
if expires_days:
|
||||
expires_at = datetime.now() + timedelta(days=expires_days)
|
||||
|
||||
# 准备数据
|
||||
for code in codes:
|
||||
coupons_data.append({
|
||||
'code': code,
|
||||
'points': points,
|
||||
'expires_at': expires_at
|
||||
})
|
||||
|
||||
coupons = await coupon_dao.create_coupons(db, coupons_data)
|
||||
return coupons
|
||||
|
||||
@staticmethod
|
||||
async def init_coupons(prefix: str, count: int) -> int:
|
||||
if not prefix or not any(ch.isalpha() for ch in prefix):
|
||||
raise errors.BadRequestError(msg='前缀至少包含一个字母')
|
||||
prefix = ''.join([ch for ch in prefix.upper() if ch.isalpha()])
|
||||
if len(prefix) not in (3, 4):
|
||||
raise errors.BadRequestError(msg='前缀长度必须为3或4')
|
||||
digits = 6 - len(prefix)
|
||||
max_serial = 10 ** digits - 1
|
||||
async with async_db_session.begin() as db:
|
||||
existing_codes = await coupon_dao.list_codes_by_prefix(db, prefix)
|
||||
current_max = -1
|
||||
for c in existing_codes:
|
||||
if len(c) == 6 and c.startswith(prefix):
|
||||
suffix = c[len(prefix):]
|
||||
if len(suffix) == digits and suffix.isdigit():
|
||||
n = int(suffix)
|
||||
if n > current_max:
|
||||
current_max = n
|
||||
start = current_max + 1
|
||||
if start > max_serial:
|
||||
from backend.common.log import log as logger
|
||||
logger.warning(f"{prefix} 前缀已达到最大序号 {max_serial},不再生成兑换券")
|
||||
return 0
|
||||
to_generate = min(max(0, count), max_serial - start + 1)
|
||||
coupons_data = []
|
||||
for n in range(start, start + to_generate):
|
||||
code = f"{prefix}{str(n).zfill(digits)}"
|
||||
coupons_data.append({
|
||||
'code': code,
|
||||
'type': prefix,
|
||||
'points': 0,
|
||||
'expires_at': None
|
||||
})
|
||||
await coupon_dao.create_coupons(db, coupons_data)
|
||||
return to_generate
|
||||
|
||||
@staticmethod
|
||||
async def redeem_coupon(code: str, user_id: int) -> dict:
|
||||
"""
|
||||
|
||||
@@ -192,9 +192,9 @@ class ImageService:
|
||||
current_user = request.user
|
||||
file_id = int(params.file_id)
|
||||
type = params.type
|
||||
dict_level = params.dict_level.name
|
||||
dict_level = params.dict_level.value
|
||||
if not dict_level:
|
||||
dict_level = current_user.dict_level.name
|
||||
dict_level = current_user.dict_level.value
|
||||
|
||||
# 检查用户积分是否足够(现在积分没有过期概念)
|
||||
if not await points_service.check_sufficient_points(current_user.id, IMAGE_RECOGNITION_COST):
|
||||
|
||||
Reference in New Issue
Block a user