fix bug
This commit is contained in:
@@ -71,6 +71,7 @@ class CouponDao(CRUDPlus[Coupon]):
|
||||
user_id=user_id,
|
||||
points=coupon.points,
|
||||
coupon_type=coupon.type,
|
||||
used_at=datetime.now()
|
||||
)
|
||||
db.add(usage)
|
||||
await db.flush()
|
||||
|
||||
@@ -27,7 +27,7 @@ class AuditLog(Base):
|
||||
dict_level: Mapped[Optional[str]] = mapped_column(String(20), comment="dict level")
|
||||
api_version: Mapped[Optional[str]] = mapped_column(String(20), comment="API版本")
|
||||
error_message: Mapped[Optional[str]] = mapped_column(Text, default=None, comment="错误信息")
|
||||
called_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment="调用时间")
|
||||
called_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now(), comment="调用时间")
|
||||
|
||||
# 索引优化
|
||||
__table_args__ = (
|
||||
|
||||
@@ -34,7 +34,7 @@ class CouponUsage(Base):
|
||||
coupon_type: Mapped[str] = mapped_column(String(32), nullable=False, comment='兑换券类型')
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), nullable=False, comment='使用者ID')
|
||||
points: Mapped[int] = mapped_column(BigInteger, nullable=False, comment='兑换积分')
|
||||
used_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment='使用时间')
|
||||
used_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now(), comment='使用时间')
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_coupon_usage_user', 'user_id'),
|
||||
|
||||
@@ -16,7 +16,7 @@ class Notification(Base):
|
||||
title: Mapped[str] = mapped_column(String(255), nullable=False, comment='通知标题')
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False, comment='通知内容')
|
||||
image_url: Mapped[Optional[str]] = mapped_column(String(512), default=None, comment='图片URL(预留)')
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment='创建时间')
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now(), comment='创建时间')
|
||||
created_by: Mapped[Optional[int]] = mapped_column(BigInteger, default=None, comment='创建者ID')
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, comment='是否激活')
|
||||
|
||||
@@ -35,7 +35,7 @@ class UserNotification(Base):
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), nullable=False, comment='用户ID')
|
||||
is_read: Mapped[bool] = mapped_column(Boolean, default=False, comment='是否已读')
|
||||
read_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=None, comment='阅读时间')
|
||||
received_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment='接收时间')
|
||||
received_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now(), comment='接收时间')
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_user_notification_user', 'user_id'),
|
||||
|
||||
@@ -37,11 +37,10 @@ class PointsLog(Base):
|
||||
balance_after: Mapped[int] = mapped_column(BigInteger, comment='变动后余额')
|
||||
related_id: Mapped[Optional[int]] = mapped_column(BigInteger, default=None, comment='关联ID')
|
||||
details: Mapped[Optional[dict]] = mapped_column(JSONB, default=None, comment='附加信息')
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment='创建时间')
|
||||
|
||||
# 索引优化
|
||||
__table_args__ = (
|
||||
Index('idx_points_log_user_action', 'user_id', 'action'),
|
||||
Index('idx_points_log_user_time', 'user_id', 'created_at'),
|
||||
Index('idx_points_log_user_time', 'user_id', 'created_time'),
|
||||
{'comment': '积分变动日志表'}
|
||||
)
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from backend.app.admin.crud.points_crud import points_dao, points_log_dao
|
||||
from backend.app.admin.model.points import Points
|
||||
from backend.database.db import async_db_session
|
||||
@@ -217,14 +218,22 @@ class PointsService:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def initialize_user_points(user_id: int) -> Points:
|
||||
async def initialize_user_points(user_id: int, db: AsyncSession = None) -> Points:
|
||||
"""
|
||||
为新用户初始化积分账户
|
||||
"""
|
||||
async with async_db_session.begin() as db:
|
||||
if db is not None:
|
||||
# Use the provided session (for nested transactions)
|
||||
points_account = await points_dao.get_by_user_id(db, user_id)
|
||||
if not points_account:
|
||||
points_account = await points_dao.create_user_points(db, user_id)
|
||||
return points_account
|
||||
else:
|
||||
# Create a new transaction (standalone usage)
|
||||
async with async_db_session.begin() as db:
|
||||
points_account = await points_dao.get_by_user_id(db, user_id)
|
||||
if not points_account:
|
||||
points_account = await points_dao.create_user_points(db, user_id)
|
||||
return points_account
|
||||
|
||||
points_service:PointsService = PointsService()
|
||||
@@ -53,7 +53,7 @@ class WxAuthService:
|
||||
await db.flush()
|
||||
await db.refresh(user)
|
||||
# initialize user points
|
||||
await points_service.initialize_user_points(user_id=user.id)
|
||||
await points_service.initialize_user_points(user_id=user.id, db=db)
|
||||
else:
|
||||
await wx_user_dao.update_session_key(db, user.id, session_key)
|
||||
|
||||
|
||||
@@ -77,44 +77,6 @@ class ImageCRUD(CRUDPlus[Image]):
|
||||
result = await db.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
async def find_similar_images_by_dict_level(self, db: AsyncSession, embedding: List[float], dict_level: str,
|
||||
top_k: int = 3, threshold: float = 0.8) -> List[int]:
|
||||
"""
|
||||
根据向量和词典等级查找相似图片
|
||||
参数:
|
||||
db: 数据库会话
|
||||
embedding: 1024维向量
|
||||
dict_level: 词典等级
|
||||
top_k: 返回最相似的K个结果
|
||||
threshold: 相似度阈值 (0.0-1.0)
|
||||
"""
|
||||
# 确保向量是numpy数组
|
||||
if not isinstance(embedding, np.ndarray):
|
||||
embedding = np.array(embedding)
|
||||
|
||||
# 转换为 NumPy 数组提高性能
|
||||
target_embedding = np.array(embedding, dtype=np.float32)
|
||||
|
||||
# 构建查询
|
||||
cosine_distance_expr = Image.embedding.cosine_distance(target_embedding)
|
||||
similarity_expr = (1 - func.cast(cosine_distance_expr, Float)).label("similarity")
|
||||
# 构建查询
|
||||
stmt = select(
|
||||
Image.id,
|
||||
# Image.info,
|
||||
similarity_expr
|
||||
).where(
|
||||
and_(
|
||||
Image.dict_level == dict_level,
|
||||
similarity_expr >= threshold
|
||||
)
|
||||
).order_by(
|
||||
cosine_distance_expr
|
||||
).limit(top_k)
|
||||
|
||||
results = await db.execute(stmt)
|
||||
id_list: List[int] = results.scalars().all()
|
||||
return id_list
|
||||
|
||||
async def add(self, db: AsyncSession, new_image: Image) -> None:
|
||||
db.add(new_image)
|
||||
@@ -122,39 +84,5 @@ class ImageCRUD(CRUDPlus[Image]):
|
||||
async def update(self, db: AsyncSession, id: int, obj: UpdateImageParam) -> int:
|
||||
return await self.update_model(db, id, obj)
|
||||
|
||||
async def find_similar_image_ids(self, db: AsyncSession, embedding: List[float], top_k: int = 3,
|
||||
threshold: float = 0.8) -> List[int]:
|
||||
"""
|
||||
直接通过向量查找相似图片
|
||||
参数:
|
||||
embedding: 1024维向量
|
||||
top_k: 返回最相似的K个结果
|
||||
threshold: 相似度阈值 (0.0-1.0)
|
||||
"""
|
||||
# 确保向量是numpy数组
|
||||
if not isinstance(embedding, np.ndarray):
|
||||
embedding = np.array(embedding)
|
||||
|
||||
# 转换为 NumPy 数组提高性能
|
||||
target_embedding = np.array(embedding, dtype=np.float32)
|
||||
|
||||
# 构建查询
|
||||
cosine_distance_expr = Image.embedding.cosine_distance(target_embedding)
|
||||
similarity_expr = (1 - func.cast(cosine_distance_expr, Float)).label("similarity")
|
||||
# 构建查询
|
||||
stmt = select(
|
||||
Image.id,
|
||||
# Image.info,
|
||||
similarity_expr
|
||||
).where(
|
||||
similarity_expr >= threshold
|
||||
).order_by(
|
||||
cosine_distance_expr
|
||||
).limit(top_k)
|
||||
|
||||
results = await db.execute(stmt)
|
||||
id_list: List[int] = results.scalars().all()
|
||||
return id_list
|
||||
|
||||
|
||||
image_dao: ImageCRUD = ImageCRUD(Image)
|
||||
|
||||
@@ -18,20 +18,11 @@ class Image(Base):
|
||||
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
|
||||
file_id: Mapped[Optional[int]] = mapped_column(BigInteger, ForeignKey('file.id'), nullable=True, comment="关联的文件ID")
|
||||
thumbnail_id: Mapped[Optional[int]] = mapped_column(BigInteger, default=None, nullable=True, comment="缩略图ID")
|
||||
embedding: Mapped[Optional[list[float]]] = mapped_column(Vector(1024), default=None, nullable=True) # 1024 维向量
|
||||
info: Mapped[Optional[ImageMetadata]] = mapped_column(PydanticType(pydantic_type=ImageMetadata), default=None, comment="附加元数据") # 其他可能的字段(根据实际需求添加)
|
||||
details: Mapped[Optional[dict]] = mapped_column(JSONB(astext_type=Text()), default=None, comment="其他信息") # 其他信息
|
||||
|
||||
# 表参数 - 包含所有必要的约束
|
||||
__table_args__ = (
|
||||
# 为 embedding 字段添加 HNSW 索引 (pgvector)
|
||||
Index(
|
||||
'idx_image_embedding_hnsw',
|
||||
embedding,
|
||||
postgresql_using='hnsw',
|
||||
postgresql_with={'m': 16, 'ef_construction': 64},
|
||||
postgresql_ops={'embedding': 'vector_l2_ops'} # 修复在此
|
||||
),
|
||||
# 为 thumbnail_id 添加索引以优化查询
|
||||
Index('idx_image_thumbnail_id', 'thumbnail_id'),
|
||||
)
|
||||
|
||||
@@ -313,7 +313,7 @@ class ImageService:
|
||||
usage_allowed = await rate_limit_service.increment_daily_usage(current_user.id, IMAGE_RECOGNITION_SERVICE, DAILY_IMAGE_RECOGNITION_MAX_TIMES)
|
||||
if not usage_allowed:
|
||||
raise errors.ForbiddenError(
|
||||
msg='免费用户每天只能使用3次图像识别服务,请升级为订阅用户以解除限制'
|
||||
msg=f'未订阅用户每天只能使用 {DAILY_IMAGE_RECOGNITION_MAX_TIMES} 次图像识别服务,请升级为订阅用户以解除限制'
|
||||
)
|
||||
|
||||
# 尝试获取任务槽位
|
||||
@@ -414,6 +414,9 @@ class ImageService:
|
||||
|
||||
@staticmethod
|
||||
async def _process_image_with_limiting(task_id: int, user_id: int) -> None:
|
||||
|
||||
from backend.app.ai.tasks import update_task_status_with_retry
|
||||
|
||||
"""带限流控制的后台处理图片识别任务"""
|
||||
try:
|
||||
# 执行图片处理任务
|
||||
@@ -422,6 +425,12 @@ class ImageService:
|
||||
# 任务完成后释放槽位
|
||||
await rate_limit_service.release_task_slot(user_id)
|
||||
|
||||
async with background_db_session.begin() as db:
|
||||
await update_task_status_with_retry(
|
||||
db, task_id, ImageTaskStatus.COMPLETED
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
@staticmethod
|
||||
async def _process_image(task_id: int) -> None:
|
||||
await ImageService._process_image_recognition(task_id)
|
||||
@@ -1039,7 +1048,7 @@ class ImageService:
|
||||
|
||||
# 更新任务状态为完成
|
||||
await update_task_status_with_retry(
|
||||
db, task_id, ImageTaskStatus.COMPLETED,
|
||||
db, task_id, ImageTaskStatus.PROCESSING,
|
||||
result=transformed_result
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Optional, List
|
||||
|
||||
from soupsieve.util import lower
|
||||
@@ -11,8 +12,13 @@ from backend.app.ai.model.image_text import ImageText
|
||||
from backend.app.ai.schema.image_text import CreateImageTextParam, UpdateImageTextParam, ImageTextInitResponseSchema, \
|
||||
ImageTextAssessmentSchema
|
||||
from backend.app.ai.crud.image_curd import image_dao
|
||||
from backend.common.exception import errors
|
||||
from backend.database.db import async_db_session
|
||||
|
||||
# Add imports for ImageProcessingTask functionality
|
||||
from backend.app.ai.model.image_task import ImageTaskStatus, ImageProcessingTask
|
||||
from backend.app.ai.crud.image_task_crud import image_task_dao
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -104,6 +110,19 @@ class ImageTextService:
|
||||
:return: 图片文本记录列表
|
||||
"""
|
||||
async with async_db_session() as db:
|
||||
# First, check if there's an ImageProcessingTask for this image_id
|
||||
# and wait until it reaches a final state (COMPLETED or FAILED)
|
||||
image_task = await image_task_dao.get_by_image_id(db, image_id)
|
||||
if image_task:
|
||||
while image_task.status not in [ImageTaskStatus.COMPLETED, ImageTaskStatus.FAILED]:
|
||||
# Wait for 1 second before checking again
|
||||
await asyncio.sleep(1)
|
||||
# Refresh the task status
|
||||
await db.refresh(image_task)
|
||||
|
||||
if image_task.status == ImageTaskStatus.FAILED:
|
||||
raise errors.ServerError(msg="Image task failed")
|
||||
|
||||
# 获取图片记录
|
||||
image = await image_dao.get(db, image_id)
|
||||
if not image:
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
兑换券批量生成脚本
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
# 添加项目路径到sys.path
|
||||
sys.path.append('.')
|
||||
|
||||
from backend.app.admin.service.coupon_service import CouponService
|
||||
from backend.app.admin.model.coupon import Coupon
|
||||
|
||||
|
||||
async def generate_coupons(count: int, duration: int, expires_days: int = None) -> List[Coupon]:
|
||||
"""
|
||||
批量生成兑换券
|
||||
|
||||
:param count: 生成数量
|
||||
:param duration: 兑换时长(分钟)
|
||||
:param expires_days: 过期天数(可选)
|
||||
:return: 生成的兑换券列表
|
||||
"""
|
||||
print(f"开始生成 {count} 个兑换券,每个兑换券时长为 {duration} 分钟")
|
||||
if expires_days:
|
||||
print(f"兑换券将在 {expires_days} 天后过期")
|
||||
|
||||
coupons = await CouponService.batch_create_coupons(count, duration, expires_days)
|
||||
|
||||
print(f"成功生成 {len(coupons)} 个兑换券:")
|
||||
for coupon in coupons:
|
||||
print(f" - 兑换码: {coupon.code}, 时长: {coupon.duration} 分钟")
|
||||
|
||||
return coupons
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='批量生成兑换券')
|
||||
parser.add_argument('-c', '--count', type=int, required=True, help='生成数量')
|
||||
parser.add_argument('-d', '--duration', type=int, required=True, help='兑换时长(分钟)')
|
||||
parser.add_argument('-e', '--expires', type=int, help='过期天数(可选)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 验证参数
|
||||
if args.count <= 0 or args.count > 10000:
|
||||
print("错误: 兑换券数量必须在1-10000之间")
|
||||
return 1
|
||||
|
||||
if args.duration <= 0:
|
||||
print("错误: 兑换时长必须大于0")
|
||||
return 1
|
||||
|
||||
if args.expires and args.expires <= 0:
|
||||
print("错误: 过期天数必须大于0")
|
||||
return 1
|
||||
|
||||
# 运行异步函数
|
||||
try:
|
||||
asyncio.run(generate_coupons(args.count, args.duration, args.expires))
|
||||
print("兑换券生成完成!")
|
||||
return 0
|
||||
except Exception as e:
|
||||
print(f"生成兑换券时发生错误: {e}")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -33,4 +33,4 @@ async def http_limit_callback(request: Request, response: Response, expire: int)
|
||||
:return:
|
||||
"""
|
||||
expires = ceil(expire / 1000)
|
||||
raise errors.HTTPError(code=429, msg='请求过于频繁,请稍后重试', headers={'Retry-After': str(expires)})
|
||||
raise errors.HTTPError(code=429, msg=f'请求过于频繁,请 {str(expires)} 秒后重试', headers={'Retry-After': str(expires)})
|
||||
|
||||
Reference in New Issue
Block a user