This commit is contained in:
felix
2025-11-21 11:28:22 +08:00
parent fd0118e5c1
commit 717d9dca5f
13 changed files with 49 additions and 166 deletions

View File

@@ -71,6 +71,7 @@ class CouponDao(CRUDPlus[Coupon]):
user_id=user_id,
points=coupon.points,
coupon_type=coupon.type,
used_at=datetime.now()
)
db.add(usage)
await db.flush()

View File

@@ -27,7 +27,7 @@ class AuditLog(Base):
dict_level: Mapped[Optional[str]] = mapped_column(String(20), comment="dict level")
api_version: Mapped[Optional[str]] = mapped_column(String(20), comment="API版本")
error_message: Mapped[Optional[str]] = mapped_column(Text, default=None, comment="错误信息")
called_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment="调用时间")
called_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now(), comment="调用时间")
# 索引优化
__table_args__ = (

View File

@@ -34,7 +34,7 @@ class CouponUsage(Base):
coupon_type: Mapped[str] = mapped_column(String(32), nullable=False, comment='兑换券类型')
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), nullable=False, comment='使用者ID')
points: Mapped[int] = mapped_column(BigInteger, nullable=False, comment='兑换积分')
used_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment='使用时间')
used_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now(), comment='使用时间')
__table_args__ = (
Index('idx_coupon_usage_user', 'user_id'),

View File

@@ -16,7 +16,7 @@ class Notification(Base):
title: Mapped[str] = mapped_column(String(255), nullable=False, comment='通知标题')
content: Mapped[str] = mapped_column(Text, nullable=False, comment='通知内容')
image_url: Mapped[Optional[str]] = mapped_column(String(512), default=None, comment='图片URL预留')
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment='创建时间')
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now(), comment='创建时间')
created_by: Mapped[Optional[int]] = mapped_column(BigInteger, default=None, comment='创建者ID')
is_active: Mapped[bool] = mapped_column(Boolean, default=True, comment='是否激活')
@@ -35,7 +35,7 @@ class UserNotification(Base):
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), nullable=False, comment='用户ID')
is_read: Mapped[bool] = mapped_column(Boolean, default=False, comment='是否已读')
read_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=None, comment='阅读时间')
received_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment='接收时间')
received_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now(), comment='接收时间')
__table_args__ = (
Index('idx_user_notification_user', 'user_id'),

View File

@@ -37,11 +37,10 @@ class PointsLog(Base):
balance_after: Mapped[int] = mapped_column(BigInteger, comment='变动后余额')
related_id: Mapped[Optional[int]] = mapped_column(BigInteger, default=None, comment='关联ID')
details: Mapped[Optional[dict]] = mapped_column(JSONB, default=None, comment='附加信息')
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment='创建时间')
# 索引优化
__table_args__ = (
Index('idx_points_log_user_action', 'user_id', 'action'),
Index('idx_points_log_user_time', 'user_id', 'created_at'),
Index('idx_points_log_user_time', 'user_id', 'created_time'),
{'comment': '积分变动日志表'}
)

View File

@@ -1,4 +1,5 @@
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from backend.app.admin.crud.points_crud import points_dao, points_log_dao
from backend.app.admin.model.points import Points
from backend.database.db import async_db_session
@@ -217,14 +218,22 @@ class PointsService:
return True
@staticmethod
async def initialize_user_points(user_id: int) -> Points:
async def initialize_user_points(user_id: int, db: AsyncSession = None) -> Points:
"""
为新用户初始化积分账户
"""
async with async_db_session.begin() as db:
if db is not None:
# Use the provided session (for nested transactions)
points_account = await points_dao.get_by_user_id(db, user_id)
if not points_account:
points_account = await points_dao.create_user_points(db, user_id)
return points_account
else:
# Create a new transaction (standalone usage)
async with async_db_session.begin() as db:
points_account = await points_dao.get_by_user_id(db, user_id)
if not points_account:
points_account = await points_dao.create_user_points(db, user_id)
return points_account
points_service:PointsService = PointsService()

View File

@@ -53,7 +53,7 @@ class WxAuthService:
await db.flush()
await db.refresh(user)
# initialize user points
await points_service.initialize_user_points(user_id=user.id)
await points_service.initialize_user_points(user_id=user.id, db=db)
else:
await wx_user_dao.update_session_key(db, user.id, session_key)

View File

@@ -77,44 +77,6 @@ class ImageCRUD(CRUDPlus[Image]):
result = await db.execute(stmt)
return result.scalars().all()
async def find_similar_images_by_dict_level(self, db: AsyncSession, embedding: List[float], dict_level: str,
top_k: int = 3, threshold: float = 0.8) -> List[int]:
"""
根据向量和词典等级查找相似图片
参数:
db: 数据库会话
embedding: 1024维向量
dict_level: 词典等级
top_k: 返回最相似的K个结果
threshold: 相似度阈值 (0.0-1.0)
"""
# 确保向量是numpy数组
if not isinstance(embedding, np.ndarray):
embedding = np.array(embedding)
# 转换为 NumPy 数组提高性能
target_embedding = np.array(embedding, dtype=np.float32)
# 构建查询
cosine_distance_expr = Image.embedding.cosine_distance(target_embedding)
similarity_expr = (1 - func.cast(cosine_distance_expr, Float)).label("similarity")
# 构建查询
stmt = select(
Image.id,
# Image.info,
similarity_expr
).where(
and_(
Image.dict_level == dict_level,
similarity_expr >= threshold
)
).order_by(
cosine_distance_expr
).limit(top_k)
results = await db.execute(stmt)
id_list: List[int] = results.scalars().all()
return id_list
async def add(self, db: AsyncSession, new_image: Image) -> None:
db.add(new_image)
@@ -122,39 +84,5 @@ class ImageCRUD(CRUDPlus[Image]):
async def update(self, db: AsyncSession, id: int, obj: UpdateImageParam) -> int:
return await self.update_model(db, id, obj)
async def find_similar_image_ids(self, db: AsyncSession, embedding: List[float], top_k: int = 3,
threshold: float = 0.8) -> List[int]:
"""
直接通过向量查找相似图片
参数:
embedding: 1024维向量
top_k: 返回最相似的K个结果
threshold: 相似度阈值 (0.0-1.0)
"""
# 确保向量是numpy数组
if not isinstance(embedding, np.ndarray):
embedding = np.array(embedding)
# 转换为 NumPy 数组提高性能
target_embedding = np.array(embedding, dtype=np.float32)
# 构建查询
cosine_distance_expr = Image.embedding.cosine_distance(target_embedding)
similarity_expr = (1 - func.cast(cosine_distance_expr, Float)).label("similarity")
# 构建查询
stmt = select(
Image.id,
# Image.info,
similarity_expr
).where(
similarity_expr >= threshold
).order_by(
cosine_distance_expr
).limit(top_k)
results = await db.execute(stmt)
id_list: List[int] = results.scalars().all()
return id_list
image_dao: ImageCRUD = ImageCRUD(Image)

View File

@@ -18,20 +18,11 @@ class Image(Base):
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
file_id: Mapped[Optional[int]] = mapped_column(BigInteger, ForeignKey('file.id'), nullable=True, comment="关联的文件ID")
thumbnail_id: Mapped[Optional[int]] = mapped_column(BigInteger, default=None, nullable=True, comment="缩略图ID")
embedding: Mapped[Optional[list[float]]] = mapped_column(Vector(1024), default=None, nullable=True) # 1024 维向量
info: Mapped[Optional[ImageMetadata]] = mapped_column(PydanticType(pydantic_type=ImageMetadata), default=None, comment="附加元数据") # 其他可能的字段(根据实际需求添加)
details: Mapped[Optional[dict]] = mapped_column(JSONB(astext_type=Text()), default=None, comment="其他信息") # 其他信息
# 表参数 - 包含所有必要的约束
__table_args__ = (
# 为 embedding 字段添加 HNSW 索引 (pgvector)
Index(
'idx_image_embedding_hnsw',
embedding,
postgresql_using='hnsw',
postgresql_with={'m': 16, 'ef_construction': 64},
postgresql_ops={'embedding': 'vector_l2_ops'} # 修复在此
),
# 为 thumbnail_id 添加索引以优化查询
Index('idx_image_thumbnail_id', 'thumbnail_id'),
)

View File

@@ -313,7 +313,7 @@ class ImageService:
usage_allowed = await rate_limit_service.increment_daily_usage(current_user.id, IMAGE_RECOGNITION_SERVICE, DAILY_IMAGE_RECOGNITION_MAX_TIMES)
if not usage_allowed:
raise errors.ForbiddenError(
msg='免费用户每天只能使用3次图像识别服务,请升级为订阅用户以解除限制'
msg=f'未订阅用户每天只能使用 {DAILY_IMAGE_RECOGNITION_MAX_TIMES} 次图像识别服务,请升级为订阅用户以解除限制'
)
# 尝试获取任务槽位
@@ -414,6 +414,9 @@ class ImageService:
@staticmethod
async def _process_image_with_limiting(task_id: int, user_id: int) -> None:
from backend.app.ai.tasks import update_task_status_with_retry
"""带限流控制的后台处理图片识别任务"""
try:
# 执行图片处理任务
@@ -422,6 +425,12 @@ class ImageService:
# 任务完成后释放槽位
await rate_limit_service.release_task_slot(user_id)
async with background_db_session.begin() as db:
await update_task_status_with_retry(
db, task_id, ImageTaskStatus.COMPLETED
)
await db.commit()
@staticmethod
async def _process_image(task_id: int) -> None:
await ImageService._process_image_recognition(task_id)
@@ -1039,7 +1048,7 @@ class ImageService:
# 更新任务状态为完成
await update_task_status_with_retry(
db, task_id, ImageTaskStatus.COMPLETED,
db, task_id, ImageTaskStatus.PROCESSING,
result=transformed_result
)
await db.commit()

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import logging
import asyncio
from typing import Optional, List
from soupsieve.util import lower
@@ -11,8 +12,13 @@ from backend.app.ai.model.image_text import ImageText
from backend.app.ai.schema.image_text import CreateImageTextParam, UpdateImageTextParam, ImageTextInitResponseSchema, \
ImageTextAssessmentSchema
from backend.app.ai.crud.image_curd import image_dao
from backend.common.exception import errors
from backend.database.db import async_db_session
# Add imports for ImageProcessingTask functionality
from backend.app.ai.model.image_task import ImageTaskStatus, ImageProcessingTask
from backend.app.ai.crud.image_task_crud import image_task_dao
logger = logging.getLogger(__name__)
@@ -104,6 +110,19 @@ class ImageTextService:
:return: 图片文本记录列表
"""
async with async_db_session() as db:
# First, check if there's an ImageProcessingTask for this image_id
# and wait until it reaches a final state (COMPLETED or FAILED)
image_task = await image_task_dao.get_by_image_id(db, image_id)
if image_task:
while image_task.status not in [ImageTaskStatus.COMPLETED, ImageTaskStatus.FAILED]:
# Wait for 1 second before checking again
await asyncio.sleep(1)
# Refresh the task status
await db.refresh(image_task)
if image_task.status == ImageTaskStatus.FAILED:
raise errors.ServerError(msg="Image task failed")
# 获取图片记录
image = await image_dao.get(db, image_id)
if not image:

View File

@@ -1,73 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
兑换券批量生成脚本
"""
import argparse
import asyncio
import sys
from typing import List
# 添加项目路径到sys.path
sys.path.append('.')
from backend.app.admin.service.coupon_service import CouponService
from backend.app.admin.model.coupon import Coupon
async def generate_coupons(count: int, duration: int, expires_days: int = None) -> List[Coupon]:
"""
批量生成兑换券
:param count: 生成数量
:param duration: 兑换时长(分钟)
:param expires_days: 过期天数(可选)
:return: 生成的兑换券列表
"""
print(f"开始生成 {count} 个兑换券,每个兑换券时长为 {duration} 分钟")
if expires_days:
print(f"兑换券将在 {expires_days} 天后过期")
coupons = await CouponService.batch_create_coupons(count, duration, expires_days)
print(f"成功生成 {len(coupons)} 个兑换券:")
for coupon in coupons:
print(f" - 兑换码: {coupon.code}, 时长: {coupon.duration} 分钟")
return coupons
def main():
parser = argparse.ArgumentParser(description='批量生成兑换券')
parser.add_argument('-c', '--count', type=int, required=True, help='生成数量')
parser.add_argument('-d', '--duration', type=int, required=True, help='兑换时长(分钟)')
parser.add_argument('-e', '--expires', type=int, help='过期天数(可选)')
args = parser.parse_args()
# 验证参数
if args.count <= 0 or args.count > 10000:
print("错误: 兑换券数量必须在1-10000之间")
return 1
if args.duration <= 0:
print("错误: 兑换时长必须大于0")
return 1
if args.expires and args.expires <= 0:
print("错误: 过期天数必须大于0")
return 1
# 运行异步函数
try:
asyncio.run(generate_coupons(args.count, args.duration, args.expires))
print("兑换券生成完成!")
return 0
except Exception as e:
print(f"生成兑换券时发生错误: {e}")
return 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -33,4 +33,4 @@ async def http_limit_callback(request: Request, response: Response, expire: int)
:return:
"""
expires = ceil(expire / 1000)
raise errors.HTTPError(code=429, msg='请求过于频繁,请后重试', headers={'Retry-After': str(expires)})
raise errors.HTTPError(code=429, msg=f'请求过于频繁,请 {str(expires)}后重试', headers={'Retry-After': str(expires)})