first commit

This commit is contained in:
felix
2025-10-18 10:54:08 +08:00
commit a35818e359
194 changed files with 20216 additions and 0 deletions

9
.gitignore vendored Executable file
View File

@@ -0,0 +1,9 @@
__pycache__/
.idea/
backend/.env
.venv/
venv/
backend/alembic/versions/
*.log
.ruff_cache/
backend/static

29
.pre-commit-config.yaml Executable file
View File

@@ -0,0 +1,29 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-added-large-files
- id: end-of-file-fixer
- id: check-yaml
- id: check-toml
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.11.4
hooks:
- id: ruff
args:
- '--config'
- '.ruff.toml'
- '--fix'
- '--unsafe-fixes'
- id: ruff-format
- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.6.14
hooks:
- id: uv-lock
- id: uv-export
args:
- '-o'
- 'requirements.txt'
- '--no-hashes'

35
.ruff.toml Executable file
View File

@@ -0,0 +1,35 @@
line-length = 120
unsafe-fixes = true
cache-dir = ".ruff_cache"
target-version = "py310"
[lint]
select = [
"E",
"F",
"W505",
"SIM101",
"SIM114",
"PGH004",
"PLE1142",
"RUF100",
"I002",
"F404",
"TC",
"UP007"
]
preview = true
[lint.isort]
lines-between-types = 1
order-by-type = true
[lint.per-file-ignores]
"**/api/v1/*.py" = ["TC"]
"**/model/*.py" = ["TC003"]
"**/model/__init__.py" = ["F401"]
[format]
preview = true
quote-style = "single"
docstring-code-format = true

30
Dockerfile Executable file
View File

@@ -0,0 +1,30 @@
FROM python:3.10-slim
WORKDIR /fsm
COPY . .
RUN sed -i 's/deb.debian.org/mirrors.ustc.edu.cn/g' /etc/apt/sources.list.d/debian.sources \
&& sed -i 's|security.debian.org/debian-security|mirrors.ustc.edu.cn/debian-security|g' /etc/apt/sources.list.d/debian.sources
RUN apt-get update \
&& apt-get install -y --no-install-recommends gcc python3-dev supervisor \
&& rm -rf /var/lib/apt/lists/* \
# 某些包可能存在同步不及时导致安装失败的情况可更改为官方源https://pypi.org/simple
&& pip install --upgrade pip -i https://mirrors.aliyun.com/pypi/simple \
&& pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple \
&& pip install gunicorn wait-for-it -i https://mirrors.aliyun.com/pypi/simple
ENV TZ="Asia/Shanghai"
RUN mkdir -p /var/log/fastapi_server \
&& mkdir -p /var/log/supervisor \
&& mkdir -p /etc/supervisor/conf.d
COPY deploy/supervisor.conf /etc/supervisor/supervisord.conf
COPY deploy/fastapi_server.conf /etc/supervisor/conf.d/
EXPOSE 8001
CMD ["uvicorn", "backend.main:app", "--host", "0.0.0.0", "--port", "8001"]

1144
assets/dict/dictionary_parser.py Executable file

File diff suppressed because it is too large Load Diff

14
backend/.env.example Executable file
View File

@@ -0,0 +1,14 @@
# Env: dev、pro
ENVIRONMENT='dev'
# Database
DATABASE_HOST='127.0.0.1'
DATABASE_PORT=3306
DATABASE_USER='root'
DATABASE_PASSWORD='123456'
# Redis
REDIS_HOST='127.0.0.1'
REDIS_PORT=6379
REDIS_PASSWORD=''
REDIS_DATABASE=0
# Token
TOKEN_SECRET_KEY='1VkVF75nsNABBjK_7-qz7GtzNy3AMvktc9TCPwKczCk'

2
backend/__init__.py Executable file
View File

@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

100
backend/alembic.ini Executable file
View File

@@ -0,0 +1,100 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = alembic
# template used to generate migration files
file_template = %%(year)d-%%(month).2d-%%(day).2d_%%(hour).2d-%%(minute).2d_%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python-dateutil library that can be
# installed by adding `alembic[tz]` to the pip requirements
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator"
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. Valid values are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # default: use os.pathsep
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = postgresql+asyncpg://root:root@127.0.0.1:5432/db
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

1
backend/alembic/README Executable file
View File

@@ -0,0 +1 @@
Generic single-database configuration with an async dbapi.

99
backend/alembic/env.py Executable file
View File

@@ -0,0 +1,99 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ruff: noqa: E402
import asyncio
import os
import sys
from logging.config import fileConfig
from sqlalchemy import engine_from_config, MetaData
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import AsyncEngine
from alembic import context
sys.path.append('../')
from backend.core import path_conf
if not os.path.exists(path_conf.ALEMBIC_VERSION_DIR):
os.makedirs(path_conf.ALEMBIC_VERSION_DIR)
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# https://alembic.sqlalchemy.org/en/latest/autogenerate.html#autogenerating-multiple-metadata-collections
from backend.app.admin.model import MappedBase
target_metadata = [
MappedBase.metadata,
]
# other values from the config, defined by the needs of env.py,
from backend.database.db import SQLALCHEMY_DATABASE_URL
config.set_main_option('sqlalchemy.url', SQLALCHEMY_DATABASE_URL)
def run_migrations_offline():
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option('sqlalchemy.url')
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
dialect_opts={'paramstyle': 'named'},
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection):
context.configure(connection=connection, target_metadata=target_metadata) # type: ignore
with context.begin_transaction():
context.run_migrations()
async def run_migrations_online():
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = AsyncEngine(
engine_from_config(
config.get_section(config.config_ini_section),
prefix='sqlalchemy.',
poolclass=pool.NullPool,
future=True,
)
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
if context.is_offline_mode():
run_migrations_offline()
else:
asyncio.run(run_migrations_online())

0
backend/alembic/hooks.py Executable file
View File

24
backend/alembic/script.py.mako Executable file
View File

@@ -0,0 +1,24 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
def upgrade():
${upgrades if upgrades else "pass"}
def downgrade():
${downgrades if downgrades else "pass"}

35
backend/app/__init__.py Executable file
View File

@@ -0,0 +1,35 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os.path
from backend.core.path_conf import BASE_PATH
from backend.utils.import_parse import get_model_objects
def get_app_models() -> list[type]:
"""获取 app 所有模型类"""
app_path = os.path.join(BASE_PATH, 'app')
list_dirs = os.listdir(app_path)
apps = []
for d in list_dirs:
if os.path.isdir(os.path.join(app_path, d)) and d != '__pycache__':
apps.append(d)
objs = []
for app in apps:
module_path = f'backend.app.{app}.model'
obj = get_model_objects(module_path)
if obj:
objs.extend(obj)
return objs
# import all app models for auto create db tables
for cls in get_app_models():
class_name = cls.__name__
if class_name not in globals():
globals()[class_name] = cls

2
backend/app/admin/__init__.py Executable file
View File

@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

View File

@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

26
backend/app/admin/api/router.py Executable file
View File

@@ -0,0 +1,26 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from fastapi import APIRouter
from backend.app.admin.api.v1.wx import router as wx_router
from backend.app.admin.api.v1.wxpay_callback import router as wx_pay_router
from backend.app.admin.api.v1.account import router as account_router
from backend.app.admin.api.v1.file import router as file_router
from backend.app.admin.api.v1.dict import router as dict_router
from backend.app.admin.api.v1.audit_log import router as audit_log_router
from backend.app.admin.api.v1.feedback import router as feedback_router
from backend.app.admin.api.v1.coupon import router as coupon_router
from backend.app.admin.api.v1.notification import router as notification_router
from backend.core.conf import settings
v1 = APIRouter(prefix=settings.FASTAPI_API_V1_PATH)
v1.include_router(account_router, prefix='/account', tags=['账户服务'])
v1.include_router(wx_pay_router, prefix="/wxpay", tags=["WeChat Pay Callback"])
v1.include_router(wx_router, prefix='/wx', tags=['微信服务'])
v1.include_router(file_router, prefix='/file', tags=['文件服务'])
v1.include_router(dict_router, prefix='/dict', tags=['字典服务'])
v1.include_router(audit_log_router, prefix='/audit', tags=['审计日志服务'])
v1.include_router(feedback_router, prefix='/feedback', tags=['反馈服务'])
v1.include_router(coupon_router, prefix='/coupon', tags=['兑换券服务'])
v1.include_router(notification_router, prefix='/notification', tags=['消息通知服务'])

View File

@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

View File

@@ -0,0 +1,54 @@
# routers/account.py
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from backend.app.admin.schema.usage import PurchaseRequest, SubscriptionRequest
from backend.app.admin.service.ad_share_service import AdShareService
from backend.app.admin.service.refund_service import RefundService
from backend.app.admin.service.subscription_service import SubscriptionService
from backend.app.admin.service.usage_service import UsageService
from backend.common.exception import errors
from backend.common.response.response_schema import response_base
from backend.common.security.jwt import DependsJwtAuth
router = APIRouter()
@router.post("/purchase", dependencies=[DependsJwtAuth])
async def purchase_times_api(
purchase_request: PurchaseRequest,
request: Request
):
await UsageService.purchase_times(request.user.id, purchase_request)
return response_base.success(data={"msg": "充值成功"})
@router.post("/subscribe", dependencies=[DependsJwtAuth])
async def subscribe_api(
sub_request: SubscriptionRequest,
request: Request
):
await SubscriptionService.subscribe(request.user.id, sub_request.plan)
return response_base.success(data={"msg": "订阅成功"})
@router.post("/ad", dependencies=[DependsJwtAuth])
async def ad_grant_api(request: Request):
await AdShareService.grant_times_by_ad(request.user.id)
return response_base.success(data={"msg": "已通过广告获得次数"})
@router.post("/share", dependencies=[DependsJwtAuth])
async def share_grant_api(request: Request):
await AdShareService.grant_times_by_share(request.user.id)
return response_base.success(data={"msg": "已通过分享获得次数"})
@router.post("/refund/{order_id}", dependencies=[DependsJwtAuth])
async def apply_refund(
order_id: int,
request: Request,
reason: str = "用户申请退款"
):
"""
申请退款
"""
try:
result = await RefundService.process_refund(request.user.id, order_id, reason)
return response_base.success(data={"msg": "退款申请已提交", "data": result})
except Exception as e:
raise errors.RequestError(msg=str(e))

View File

@@ -0,0 +1,114 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from fastapi import APIRouter, Depends, Request
from fastapi_pagination import Page
from fastapi_pagination.ext.sqlalchemy import paginate
from backend.app.admin.service.audit_log_service import audit_log_service
from backend.app.admin.schema.audit_log import AuditLogHistorySchema, AuditLogStatisticsSchema, DailySummaryPageSchema
from backend.app.admin.tasks import wx_user_index_history
from backend.common.response.response_schema import response_base, ResponseSchemaModel
from backend.common.security.jwt import DependsJwtAuth
router = APIRouter()
@router.get("/history", summary="获取用户识别历史记录", dependencies=[DependsJwtAuth])
async def get_user_recognition_history(
request:Request,
page: int = 1,
size: int = 20
):
"""
通过用户ID查询历史记录
"""
history_items, total = await audit_log_service.get_user_recognition_history(
user_id=request.user.id,
page=page,
size=size
)
# await wx_user_index_history()
# 创建分页结果
result = {
"items": [item.model_dump() for item in history_items],
"total": total,
"page": page,
"size": size
}
return response_base.success(data=result)
@router.get("/statistics", summary="获取用户识别统计信息", dependencies=[DependsJwtAuth])
async def get_user_recognition_statistics(request:Request):
"""
统计用户 recognition 类型的使用记录
返回历史总量和当天总量
"""
total_count, today_count, image_count = await audit_log_service.get_user_recognition_statistics(user_id=request.user.id)
result = AuditLogStatisticsSchema(
total_count=total_count,
today_count=today_count,
image_count=image_count
)
return response_base.success(data=result)
@router.get("/summary", summary="获取用户每日识别汇总记录", dependencies=[DependsJwtAuth])
async def get_user_daily_summary(
request: Request,
page: int = 1,
size: int = 20
):
"""
通过用户ID查询每日识别汇总记录按创建时间降序排列
"""
summary_items, total = await audit_log_service.get_user_daily_summaries(
user_id=request.user.id,
page=page,
size=size
)
# await wx_user_index_history()
# 创建分页结果
result = DailySummaryPageSchema(
items=summary_items,
total=total,
page=page,
size=size
)
return response_base.success(data=result)
@router.get("/today_summary", summary="获取用户今日识别汇总记录", dependencies=[DependsJwtAuth])
async def get_user_today_summary(
request: Request,
page: int = 1,
size: int = 20
):
"""
通过用户ID查询每日识别汇总记录按创建时间降序排列
"""
summary_items, total = await audit_log_service.get_user_today_summaries(
user_id=request.user.id,
page=page,
size=size
)
# await wx_user_index_history()
# 创建分页结果
result = DailySummaryPageSchema(
items=summary_items,
total=total,
page=page,
size=size
)
return response_base.success(data=result)

View File

@@ -0,0 +1,11 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from fastapi import APIRouter
from backend.app.admin.api.v1.auth.auth import router as auth_router
from backend.app.admin.api.v1.auth.captcha import router as captcha_router
router = APIRouter(prefix='/auth')
router.include_router(auth_router, tags=['授权'])
router.include_router(captcha_router, prefix='/captcha', tags=['验证码'])

View File

@@ -0,0 +1,30 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from fastapi import APIRouter, Depends, Request
from fastapi.security import OAuth2PasswordRequestForm
from backend.app.admin.service.auth_service import auth_service
from backend.common.security.jwt import DependsJwtAuth
from backend.common.response.response_schema import response_base, ResponseModel, ResponseSchemaModel
from backend.app.admin.schema.token import GetSwaggerToken, GetLoginToken
from backend.app.admin.schema.user import AuthLoginParam
router = APIRouter()
@router.post('/login/swagger', summary='swagger 调试专用', description='用于快捷进行 swagger 认证')
async def swagger_login(form_data: OAuth2PasswordRequestForm = Depends()) -> GetSwaggerToken:
token, user = await auth_service.swagger_login(form_data=form_data)
return GetSwaggerToken(access_token=token, user=user) # type: ignore
@router.post('/login', summary='验证码登录')
async def user_login(request: Request, obj: AuthLoginParam) -> ResponseSchemaModel[GetLoginToken]:
data = await auth_service.login(request=request, obj=obj)
return response_base.success(data=data)
@router.post('/logout', summary='用户登出', dependencies=[DependsJwtAuth])
async def user_logout() -> ResponseModel:
return response_base.success()

View File

@@ -0,0 +1,36 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from fast_captcha import img_captcha
from fastapi import APIRouter, Depends, Request
from fastapi_limiter.depends import RateLimiter
from starlette.concurrency import run_in_threadpool
from backend.app.admin.schema.captcha import GetCaptchaDetail
from backend.common.response.response_schema import ResponseSchemaModel, response_base
from backend.core.conf import settings
from backend.database.db import uuid4_str
from backend.database.redis import redis_client
router = APIRouter()
@router.get(
'',
summary='获取登录验证码',
dependencies=[Depends(RateLimiter(times=5, seconds=10))],
)
async def get_captcha(request: Request) -> ResponseSchemaModel[GetCaptchaDetail]:
"""
此接口可能存在性能损耗尽管是异步接口但是验证码生成是IO密集型任务使用线程池尽量减少性能损耗
"""
img_type: str = 'base64'
img, code = await run_in_threadpool(img_captcha, img_byte=img_type)
uuid = uuid4_str()
request.app.state.captcha_uuid = uuid
await redis_client.set(
f'{settings.CAPTCHA_LOGIN_REDIS_PREFIX}:{uuid}',
code,
ex=settings.CAPTCHA_LOGIN_EXPIRE_SECONDS,
)
data = GetCaptchaDetail(image_type=img_type, image=img)
return response_base.success(data=data)

View File

@@ -0,0 +1,67 @@
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel, Field
from typing import List, Optional
from backend.app.admin.service.coupon_service import CouponService
from backend.common.response.response_schema import response_base
from backend.common.security.jwt import DependsJwtAuth
router = APIRouter()
class RedeemCouponRequest(BaseModel):
code: str = Field(..., min_length=1, max_length=32, description="兑换码")
class CreateCouponRequest(BaseModel):
duration: int = Field(..., gt=0, description="兑换时长(分钟)")
count: int = Field(1, ge=1, le=1000, description="生成数量")
expires_days: Optional[int] = Field(None, ge=1, description="过期天数")
class CouponHistoryResponse(BaseModel):
code: str
duration: int
used_at: str
@router.post("/redeem", dependencies=[DependsJwtAuth])
async def redeem_coupon_api(
request: Request,
redeem_request: RedeemCouponRequest
):
"""
兑换兑换券
"""
result = await CouponService.redeem_coupon(redeem_request.code, request.user.id)
return response_base.success(data=result)
@router.get("/history", dependencies=[DependsJwtAuth])
async def get_coupon_history_api(
request: Request,
limit: int = 100
):
"""
获取用户兑换历史
"""
history = await CouponService.get_user_coupon_history(request.user.id, limit)
return response_base.success(data=history)
# 管理员接口,用于批量生成兑换券
@router.post("/generate", dependencies=[DependsJwtAuth])
async def generate_coupons_api(
request: Request,
create_request: CreateCouponRequest
):
"""
批量生成兑换券(管理员接口)
"""
# 这里应该添加管理员权限验证
# 为简化示例,暂时省略权限验证
coupons = await CouponService.batch_create_coupons(
create_request.count,
create_request.duration,
create_request.expires_days
)
return response_base.success(data={
"count": len(coupons),
"message": f"成功生成{len(coupons)}个兑换券"
})

View File

@@ -0,0 +1,71 @@
from fastapi import APIRouter, Query, Path, Response
from fastapi.responses import StreamingResponse
import io
from backend.app.admin.schema.dict import DictWordResponse
from backend.app.admin.service.dict_service import dict_service
from backend.common.response.response_schema import response_base, ResponseSchemaModel
from backend.common.security.jwt import DependsJwtAuth
from backend.middleware import tencent_cloud
router = APIRouter()
@router.get("/word/{word}", summary="根据单词查询字典信息", dependencies=[DependsJwtAuth])
async def get_word_info(
word: str = Path(..., description="要查询的单词", min_length=1, max_length=100)
) -> ResponseSchemaModel[DictWordResponse]:
"""
根据单词查询字典信息
Args:
word: 要查询的单词
Returns:
包含单词详细信息的响应,包括:
- senses: 义项列表
- frequency: 频率信息
- pronunciations: 发音信息
"""
# from backend.middleware.tencent_cloud import TencentCloud
# data = await TencentCloud().text_to_speak(image_id=2088374826979950592,content='green vegetable cut into small pieces and added to food',image_text_id=2088375029640331267,user_id=2083326996703739904)
result = await dict_service.get_word_info(word)
return response_base.success(data=result)
@router.get("/check/{word}", summary="检查单词是否存在", dependencies=[DependsJwtAuth])
async def check_word_exists(
word: str = Path(..., description="要检查的单词", min_length=1, max_length=100)
) -> ResponseSchemaModel[bool]:
"""
检查单词是否在字典中存在
Args:
word: 要检查的单词
Returns:
布尔值,表示单词是否存在
"""
result = await dict_service.check_word_exists(word)
return response_base.success(data=result)
@router.get("/audio/{file_name:path}", summary="播放音频文件", dependencies=[DependsJwtAuth])
async def play_audio(
file_name: str = Path(..., description="音频文件名", min_length=1, max_length=255)
) -> StreamingResponse:
"""
根据文件名播放音频文件
Args:
file_name: 音频文件名
Returns:
音频文件流
"""
audio_data = await dict_service.get_audio_data(file_name)
if not audio_data:
# 返回空的响应或默认音频
return StreamingResponse(io.BytesIO(b""), media_type="audio/mpeg")
return StreamingResponse(io.BytesIO(audio_data), media_type="audio/mpeg")

View File

@@ -0,0 +1,57 @@
from fastapi import APIRouter, Depends, Request
from backend.app.admin.schema.feedback import CreateFeedbackParam, UpdateFeedbackParam, FeedbackInfoSchema
from backend.app.admin.service.feedback_service import feedback_service
from backend.common.response.response_schema import response_base, ResponseModel, ResponseSchemaModel
from backend.common.security.jwt import DependsJwtAuth
router = APIRouter()
@router.post('', summary='创建用户反馈', dependencies=[DependsJwtAuth])
async def create_feedback(request: Request, obj: CreateFeedbackParam) -> ResponseSchemaModel[FeedbackInfoSchema]:
"""创建用户反馈"""
# 从JWT中获取用户ID
user_id = request.user.id
feedback = await feedback_service.create_feedback(user_id, obj)
return response_base.success(data=feedback)
@router.get('/{feedback_id}', summary='获取反馈详情', dependencies=[DependsJwtAuth])
async def get_feedback(feedback_id: int) -> ResponseSchemaModel[FeedbackInfoSchema]:
"""获取反馈详情"""
feedback = await feedback_service.get_feedback(feedback_id)
if not feedback:
return response_base.fail(msg='反馈不存在')
return response_base.success(data=feedback)
@router.get('', summary='获取反馈列表', dependencies=[DependsJwtAuth])
async def get_feedback_list(
user_id: int = None,
status: str = None,
category: str = None,
limit: int = 10,
offset: int = 0
) -> ResponseSchemaModel[list[FeedbackInfoSchema]]:
"""获取反馈列表"""
feedbacks = await feedback_service.get_feedback_list(user_id, status, category, limit, offset)
return response_base.success(data=feedbacks)
@router.put('/{feedback_id}', summary='更新反馈状态', dependencies=[DependsJwtAuth])
async def update_feedback(feedback_id: int, obj: UpdateFeedbackParam) -> ResponseModel:
"""更新反馈状态"""
success = await feedback_service.update_feedback(feedback_id, obj)
if not success:
return response_base.fail(msg='反馈不存在或更新失败')
return response_base.success()
@router.delete('/{feedback_id}', summary='删除反馈', dependencies=[DependsJwtAuth])
async def delete_feedback(feedback_id: int) -> ResponseModel:
"""删除反馈"""
success = await feedback_service.delete_feedback(feedback_id)
if not success:
return response_base.fail(msg='反馈不存在或删除失败')
return response_base.success()

View File

@@ -0,0 +1,41 @@
from fastapi import APIRouter, UploadFile, File, Query, Depends, Response
from backend.app.admin.schema.file import FileUploadResponse
from backend.app.admin.service.file_service import file_service
from backend.common.response.response_schema import response_base, ResponseSchemaModel
from backend.common.security.jwt import DependsJwtAuth
router = APIRouter()
@router.post("/upload", summary="上传文件", dependencies=[DependsJwtAuth])
async def upload_file(
file: UploadFile = File(...),
) -> ResponseSchemaModel[FileUploadResponse]:
"""上传文件"""
result = await file_service.upload_file(file)
return response_base.success(data=result)
@router.get("/{file_id}", summary="下载文件", dependencies=[DependsJwtAuth])
# @router.get("/{file_id}", summary="下载文件")
async def download_file(file_id: int) -> Response:
"""下载文件"""
try:
content, filename, content_type = await file_service.download_file(file_id)
headers = {
"Content-Disposition": f'attachment; filename="{filename}"',
"Content-Type": content_type
}
return Response(content=content, headers=headers)
except Exception as e:
return Response(content=str(e), status_code=404)
@router.delete("/{file_id}", summary="删除文件", dependencies=[DependsJwtAuth])
async def delete_file(file_id: int) -> ResponseSchemaModel[bool]:
"""删除文件"""
result = await file_service.delete_file(file_id)
if not result:
return await response_base.fail(message="文件不存在或删除失败")
return await response_base.success(data=result)

View File

@@ -0,0 +1,97 @@
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel, Field
from typing import List, Optional
from backend.app.admin.service.notification_service import NotificationService
from backend.common.response.response_schema import response_base
from backend.common.security.jwt import DependsJwtAuth
router = APIRouter()
class CreateNotificationRequest(BaseModel):
title: str = Field(..., min_length=1, max_length=255, description="通知标题")
content: str = Field(..., min_length=1, description="通知内容")
image_url: Optional[str] = Field(None, max_length=512, description="图片URL可选")
class MarkAsReadRequest(BaseModel):
notification_ids: List[int] = Field(..., description="要标记为已读的通知ID列表")
class NotificationResponse(BaseModel):
id: int
notification_id: int
title: str
content: str
image_url: Optional[str]
is_read: Optional[bool]
received_at: str
read_at: Optional[str]
@router.get("/list", dependencies=[DependsJwtAuth])
async def get_notifications_api(
request: Request,
limit: int = 100
):
"""
获取用户消息通知列表
"""
notifications = await NotificationService.get_user_notifications(request.user.id, limit)
return response_base.success(data=notifications)
@router.get("/unread", dependencies=[DependsJwtAuth])
async def get_unread_notifications_api(
request: Request,
limit: int = 100
):
"""
获取用户未读消息通知列表
"""
notifications = await NotificationService.get_unread_notifications(request.user.id, limit)
return response_base.success(data=notifications)
@router.get("/unread/count", dependencies=[DependsJwtAuth])
async def get_unread_count_api(
request: Request
):
"""
获取用户未读消息通知数量
"""
count = await NotificationService.get_unread_count(request.user.id)
return response_base.success(data={"count": count})
@router.post("/{notification_id}/read", dependencies=[DependsJwtAuth])
async def mark_as_read_api(
request: Request,
notification_id: int
):
"""
标记指定通知为已读
"""
success = await NotificationService.mark_notification_as_read(notification_id, request.user.id)
if success:
return response_base.success(data={"msg": "标记成功"})
else:
return response_base.fail(data={"msg": "标记失败"})
# 管理员接口,用于创建通知
@router.post("/create", dependencies=[DependsJwtAuth])
async def create_notification_api(
request: Request,
create_request: CreateNotificationRequest
):
"""
创建消息通知(管理员接口)
"""
# 这里应该添加管理员权限验证
# 为简化示例,暂时省略权限验证
notification = await NotificationService.create_notification(
create_request.title,
create_request.content,
create_request.image_url,
request.user.id
)
return response_base.success(data={
"id": notification.id,
"msg": "通知创建成功"
})

View File

@@ -0,0 +1,88 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Annotated
from fastapi import APIRouter, Query
from backend.common.security.jwt import DependsJwtAuth
from backend.common.pagination import paging_data, DependsPagination, PageData
from backend.common.response.response_schema import response_base, ResponseModel, ResponseSchemaModel
from backend.database.db import CurrentSession
from backend.app.admin.schema.user import (
RegisterUserParam,
GetUserInfoDetail,
ResetPassword,
UpdateUserParam,
AvatarParam,
)
from backend.app.admin.service.wx_user_service import WxUserService
router = APIRouter()
@router.post('/register', summary='用户注册')
async def user_register(obj: RegisterUserParam) -> ResponseModel:
await WxUserService.register(obj=obj)
return response_base.success()
@router.post('/password/reset', summary='密码重置', dependencies=[DependsJwtAuth])
async def password_reset(obj: ResetPassword) -> ResponseModel:
count = await WxUserService.pwd_reset(obj=obj)
if count > 0:
return response_base.success()
return response_base.fail()
@router.get('/{username}', summary='查看用户信息', dependencies=[DependsJwtAuth])
async def get_user(username: str) -> ResponseSchemaModel[GetUserInfoDetail]:
data = await WxUserService.get_userinfo(username=username)
return response_base.success(data=data)
@router.put('/{username}', summary='更新用户信息', dependencies=[DependsJwtAuth])
async def update_userinfo(username: str, obj: UpdateUserParam) -> ResponseModel:
count = await WxUserService.update(username=username, obj=obj)
if count > 0:
return response_base.success()
return response_base.fail()
@router.put('/{username}/avatar', summary='更新头像', dependencies=[DependsJwtAuth])
async def update_avatar(username: str, avatar: AvatarParam) -> ResponseModel:
count = await WxUserService.update_avatar(username=username, avatar=avatar)
if count > 0:
return response_base.success()
return response_base.fail()
@router.get(
'',
summary='(模糊条件)分页获取所有用户',
dependencies=[
DependsJwtAuth,
DependsPagination,
],
)
async def get_all_users(
db: CurrentSession,
username: Annotated[str | None, Query()] = None,
phone: Annotated[str | None, Query()] = None,
status: Annotated[int | None, Query()] = None,
) -> ResponseSchemaModel[PageData[GetUserInfoDetail]]:
user_select = await WxUserService.get_select(username=username, phone=phone, status=status)
page_data = await paging_data(db, user_select)
return response_base.success(data=page_data)
# @router.delete(
# path='/{username}',
# summary='用户注销',
# description='用户注销 != 用户登出,注销之后用户将从数据库删除',
# dependencies=[DependsJwtAuth],
# )
# async def delete_user(current_user: CurrentUser, username: str) -> ResponseModel:
# count = await WxUserService.delete(current_user=current_user, username=username)
# if count > 0:
# return response_base.success()
# return response_base.fail()

95
backend/app/admin/api/v1/wx.py Executable file
View File

@@ -0,0 +1,95 @@
# routers/wx.py
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi_limiter.depends import RateLimiter
from backend.app.admin.schema.token import GetWxLoginToken
from backend.app.admin.schema.wx import WxLoginRequest, TokenResponse, UserInfo, UpdateUserSettingsRequest, GetUserSettingsResponse, DictLevel
from backend.app.admin.service.wx_service import wx_service
from backend.common.response.response_schema import response_base, ResponseSchemaModel
from backend.core.wx_integration import verify_wx_code
router = APIRouter()
@router.post("/login",
summary="微信登录",
dependencies=[Depends(RateLimiter(times=5, minutes=1))])
async def wechat_login(
request: Request, response: Response,
wx_request: WxLoginRequest
) -> ResponseSchemaModel[GetWxLoginToken]:
"""
微信小程序登录接口
- **code**: 微信小程序前端获取的临时code
- **encrypted_data** (可选): 加密的用户信息
- **iv** (可选): 加密算法的初始向量
"""
# 验证微信code并获取用户信息
wx_result = await verify_wx_code(wx_request.code)
if not wx_result:
raise HTTPException(status_code=401, detail="微信认证失败")
# 处理用户登录逻辑
result = await wx_service.login(
request=request, response=response,
openid=wx_result.get("openid"),
session_key=wx_result.get("session_key"),
encrypted_data=wx_request.encrypted_data,
iv=wx_request.iv
)
return response_base.success(data=result)
# @router.put("/settings", summary="更新用户设置", dependencies=[DependsJwtAuth])
# async def update_user_settings(
# request: Request,
# settings: UpdateUserSettingsRequest
# ) -> ResponseSchemaModel[None]:
# """
# 更新用户设置
# """
#
# # 更新用户设置
# await wx_service.update_user_settings(
# user_id=request.user.id,
# dict_level=settings.dict_level
# )
#
# return response_base.success()
#
#
# @router.get("/settings", summary="获取用户设置", dependencies=[DependsJwtAuth])
# async def get_user_settings(
# request: Request
# ) -> ResponseSchemaModel[GetUserSettingsResponse]:
# """
# 获取用户设置
# """
# # 从请求中获取用户ID实际项目中应该从JWT token中获取
# user_id = getattr(request.state, 'user_id', None)
# if not user_id:
# raise HTTPException(status_code=401, detail="未授权访问")
#
# # 获取用户信息
# async with async_db_session() as db:
# user = await wx_user_dao.get(db, user_id)
# if not user:
# raise HTTPException(status_code=404, detail="用户不存在")
#
# # 提取设置信息
# dict_level = None
# if user.profile and isinstance(user.profile, dict):
# dict_level_value = user.profile.get("dict_level")
# if dict_level_value:
# # 将字符串值转换为枚举值
# try:
# dict_level = DictLevel(dict_level_value)
# except ValueError:
# pass # 无效值保持为None
#
# response_data = GetUserSettingsResponse(
# dict_level=dict_level
# )
#
# return response_base.success(data=response_data)

View File

@@ -0,0 +1,150 @@
from fastapi import APIRouter, Request, BackgroundTasks
from wechatpy.exceptions import InvalidSignatureException
from backend.app.admin.crud.order_crud import order_dao
from backend.app.admin.crud.user_account_crud import user_account_dao
from backend.app.admin.model import Order, UserAccount
from backend.app.admin.service.refund_service import RefundService
from backend.app.admin.service.subscription_service import SUBSCRIPTION_PLANS
from backend.database.db import async_db_session
from backend.common.log import log as logger
from backend.app.admin.crud.usage_log_crud import usage_log_dao
from sqlalchemy import select, update
from datetime import datetime
import hashlib
from backend.utils.wx_pay import wx_pay_utils
router = APIRouter()
def verify_wxpay_signature(data: dict, api_key: str) -> bool:
"""
验证微信支付回调签名
"""
try:
# 获取签名
sign = data.pop('sign', None)
if not sign:
return False
# 重新计算签名
sorted_keys = sorted(data.keys())
stringA = '&'.join(f"{key}={data[key]}" for key in sorted_keys if data[key])
stringSignTemp = f"{stringA}&key={api_key}"
calculated_sign = hashlib.md5(stringSignTemp.encode('utf-8')).hexdigest().upper()
return sign == calculated_sign
except Exception as e:
logger.error(f"签名验证异常: {str(e)}")
return False
@router.post("/notify")
async def wxpay_notify(request: Request):
"""
微信支付异步通知处理(增强安全性和幂等性)
"""
try:
# 读取原始数据
body = await request.body()
body_str = body.decode('utf-8')
# 解析微信回调数据
result = wx_pay_utils.parse_payment_result(body)
# 验证签名(增强安全性)
from backend.core.conf import settings
if not verify_wxpay_signature(result.copy(), settings.WECHAT_PAY_API_KEY):
logger.warning("微信支付回调签名验证失败")
return {"return_code": "FAIL", "return_msg": "签名验证失败"}
if result['return_code'] == 'SUCCESS' and result['result_code'] == 'SUCCESS':
out_trade_no = result['out_trade_no']
transaction_id = result['transaction_id']
async with async_db_session.begin() as db:
# 使用SELECT FOR UPDATE确保并发安全
stmt = select(Order).where(Order.id == int(out_trade_no)).with_for_update()
order_result = await db.execute(stmt)
order = order_result.scalar_one_or_none()
if not order:
logger.warning(f"订单不存在: {out_trade_no}")
return {"return_code": "SUCCESS"}
# 幂等性检查
if order.processed_at is not None:
logger.info(f"订单已处理过: {out_trade_no}")
return {"return_code": "SUCCESS"}
if order.status != 'pending':
logger.warning(f"订单状态异常: {out_trade_no}, status: {order.status}")
return {"return_code": "SUCCESS"}
# 更新订单状态
order.status = 'completed'
order.transaction_id = transaction_id
order.processed_at = datetime.now()
await order_dao.update(db, order.id, order)
# 如果是订阅订单,更新用户订阅信息
if order.order_type == 'subscription':
# 使用SELECT FOR UPDATE锁定用户账户
account_stmt = select(UserAccount).where(UserAccount.user_id == order.user_id).with_for_update()
account_result = await db.execute(account_stmt)
user_account = account_result.scalar_one_or_none()
if user_account:
plan = None
for key, value in SUBSCRIPTION_PLANS.items():
if value['price'] == order.amount_cents:
plan = key
break
if plan:
# 处理未用完次数(仅累计一个月)
new_balance = user_account.balance + user_account.carryover_balance
carryover = 0
# 如果是续费且当期次数未用完,则累计到下一期
if (user_account.subscription_type and
user_account.subscription_expires_at and
user_account.subscription_expires_at > datetime.now()):
# 计算当期剩余次数
remaining = max(0, user_account.balance)
carryover = min(remaining, SUBSCRIPTION_PLANS[plan]["times"])
# 更新订阅信息
user_account.subscription_type = plan
user_account.subscription_expires_at = datetime.now() + SUBSCRIPTION_PLANS[plan]["duration"]
user_account.balance = new_balance + SUBSCRIPTION_PLANS[plan]["times"]
user_account.carryover_balance = carryover
await user_account_dao.update(db, user_account.id, user_account)
# 记录使用日志
account = await user_account_dao.get_by_user_id(db, order.user_id)
await usage_log_dao.add(db, {
"user_id": order.user_id,
"action": "purchase" if order.order_type == "purchase" else "renewal",
"amount": order.amount_times,
"balance_after": account.balance if account else 0,
"related_id": order.id,
"details": {"transaction_id": transaction_id}
})
return {"return_code": "SUCCESS"}
else:
logger.error(f"微信支付回调失败: {result}")
return {"return_code": "FAIL", "return_msg": "处理失败"}
except Exception as e:
logger.error(f"微信支付回调处理异常: {str(e)}")
return {"return_code": "FAIL", "return_msg": "服务器异常"}
@router.post("/refund/notify")
async def wxpay_refund_notify(request: Request):
"""
微信退款异步通知处理
"""
return await RefundService.handle_refund_notify(request)

View File

@@ -0,0 +1,6 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from backend.app.admin.crud.file_crud import file_dao
from backend.app.admin.crud.daily_summary_crud import daily_summary_dao
from backend.app.admin.crud.points_crud import points_dao, points_log_dao

View File

@@ -0,0 +1,204 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from datetime import datetime, timedelta, date
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus
from backend.app.admin.model.audit_log import AuditLog, DailySummary
from backend.app.admin.schema.audit_log import CreateAuditLogParam, AuditLogStatisticsSchema, CreateDailySummaryParam
from backend.app.ai import Image
class CRUDAuditLog(CRUDPlus[AuditLog]):
async def create(self, db: AsyncSession, obj: CreateAuditLogParam) -> None:
"""
创建操作日志
:param db: 数据库会话
:param obj: 创建操作日志参数
:return:
"""
await self.create_model(db, obj)
async def get_user_recognition_history(self, db: AsyncSession, user_id: int, page: int = 1, size: int = 20):
"""
通过用户ID查询历史记录联合image表查找api_type='recognition'的记录
一个image可能有多个audit_log记录找出其中called_at最早的
支持分页查询
返回的数据有image.thumbnail_id, image.created_time, audit_log.dict_level
:param db: 数据库会话
:param user_id: 用户ID
:param page: 页码
:param size: 每页数量
:return: 查询结果和总数
"""
# 子查询找出每个image_id对应的最早called_at记录
subquery = (
select(
AuditLog.image_id,
func.min(AuditLog.called_at).label('earliest_called_at')
)
.where(AuditLog.user_id == user_id, AuditLog.api_type == 'recognition')
.group_by(AuditLog.image_id)
.subquery()
)
# 主查询关联image表和audit_log表获取所需字段
stmt = (
select(
Image.id,
Image.thumbnail_id,
Image.file_id,
Image.created_time,
AuditLog.dict_level
)
.join(Image, AuditLog.image_id == Image.id)
.join(
subquery,
(AuditLog.image_id == subquery.c.image_id) &
(AuditLog.called_at == subquery.c.earliest_called_at)
)
.where(AuditLog.user_id == user_id, AuditLog.api_type == 'recognition')
.order_by(AuditLog.called_at.desc(), AuditLog.id.desc())
.offset((page - 1) * size)
.limit(size)
)
result = await db.execute(stmt)
items = result.fetchall()
# 获取总数
count_stmt = (
select(func.count(func.distinct(AuditLog.image_id)))
.where(AuditLog.user_id == user_id, AuditLog.api_type == 'recognition')
)
total_result = await db.execute(count_stmt)
total = total_result.scalar()
return items, total
async def get_user_today_recognition_history(self, db: AsyncSession, user_id: int, page: int = 1, size: int = 20):
"""
通过用户ID查询当天的识别记录联合image表查找api_type='recognition'的记录
一个image可能有多个audit_log记录找出其中called_at最早的
支持分页查询
返回的数据有image.thumbnail_id, image.created_time, audit_log.dict_level
:param db: 数据库会话
:param user_id: 用户ID
:param page: 页码
:param size: 每页数量
:return: 查询结果和总数
"""
# 获取当天的开始时间
today = date.today()
today_start = datetime(today.year, today.month, today.day)
# 子查询找出每个image_id对应的最早called_at记录仅当天
subquery = (
select(
AuditLog.image_id,
func.min(AuditLog.called_at).label('earliest_called_at')
)
.where(
AuditLog.user_id == user_id,
AuditLog.api_type == 'recognition',
AuditLog.called_at >= today_start
)
.group_by(AuditLog.image_id)
.subquery()
)
# 主查询关联image表和audit_log表获取所需字段仅当天
stmt = (
select(
Image.id,
Image.thumbnail_id,
Image.file_id,
Image.created_time,
AuditLog.dict_level
)
.join(Image, AuditLog.image_id == Image.id)
.join(
subquery,
(AuditLog.image_id == subquery.c.image_id) &
(AuditLog.called_at == subquery.c.earliest_called_at)
)
.where(
AuditLog.user_id == user_id,
AuditLog.api_type == 'recognition',
AuditLog.called_at >= today_start
)
.order_by(AuditLog.called_at.desc(), AuditLog.id.desc())
.offset((page - 1) * size)
.limit(size)
)
result = await db.execute(stmt)
items = result.fetchall()
# 获取当天总数
count_stmt = (
select(func.count(func.distinct(AuditLog.image_id)))
.where(
AuditLog.user_id == user_id,
AuditLog.api_type == 'recognition',
AuditLog.called_at >= today_start
)
)
total_result = await db.execute(count_stmt)
total = total_result.scalar()
return items, total
async def get_user_recognition_statistics(self, db: AsyncSession, user_id: int):
"""
统计用户 recognition 类型的使用记录
返回历史总量和当天总量
:param db: 数据库会话
:param user_id: 用户ID
:return: (历史总量, 当天总量)
"""
# 获取历史总量
total_stmt = (
select(func.count())
.where(AuditLog.user_id == user_id, AuditLog.api_type == 'recognition')
)
total_result = await db.execute(total_stmt)
total_count = total_result.scalar()
# 获取当天总量
today = date.today()
today_start = datetime(today.year, today.month, today.day)
today_stmt = (
select(func.count())
.where(
AuditLog.user_id == user_id,
AuditLog.api_type == 'recognition',
AuditLog.called_at >= today_start
)
)
today_result = await db.execute(today_stmt)
today_count = today_result.scalar()
# 获取总数
count_stmt = (
select(func.count(func.distinct(AuditLog.image_id)))
.where(AuditLog.user_id == user_id, AuditLog.api_type == 'recognition')
)
count_stmt = (
select(func.count(func.distinct(AuditLog.image_id)))
.where(AuditLog.user_id == user_id, AuditLog.api_type == 'recognition')
)
count_result = await db.execute(count_stmt)
image_count = count_result.scalar()
return total_count, today_count, image_count
audit_log_dao: CRUDAuditLog = CRUDAuditLog(AuditLog)

View File

@@ -0,0 +1,105 @@
from typing import Optional, List
from sqlalchemy import select, and_, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus
from backend.app.admin.model.coupon import Coupon, CouponUsage
from datetime import datetime
class CouponDao(CRUDPlus[Coupon]):
async def get_by_code(self, db: AsyncSession, code: str) -> Optional[Coupon]:
"""
根据兑换码获取兑换券
"""
stmt = select(Coupon).where(Coupon.code == code)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def get_unused_coupon_by_code(self, db: AsyncSession, code: str) -> Optional[Coupon]:
"""
根据兑换码获取未使用的兑换券
"""
stmt = select(Coupon).where(
and_(
Coupon.code == code,
Coupon.is_used == False,
(Coupon.expires_at.is_(None)) | (Coupon.expires_at > datetime.now())
)
)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def create_coupon(self, db: AsyncSession, coupon_data: dict) -> Coupon:
"""
创建兑换券
"""
coupon = Coupon(**coupon_data)
db.add(coupon)
await db.flush()
return coupon
async def create_coupons(self, db: AsyncSession, coupons_data: List[dict]) -> List[Coupon]:
"""
批量创建兑换券
"""
coupons = [Coupon(**data) for data in coupons_data]
db.add_all(coupons)
await db.flush()
return coupons
async def mark_as_used(self, db: AsyncSession, coupon_id: int, user_id: int, duration: int) -> bool:
"""
标记兑换券为已使用并创建使用记录
"""
# 更新兑换券状态
stmt = update(Coupon).where(Coupon.id == coupon_id).values(is_used=True)
result = await db.execute(stmt)
if result.rowcount == 0:
return False
# 创建使用记录
usage = CouponUsage(
coupon_id=coupon_id,
user_id=user_id,
duration=duration
)
db.add(usage)
await db.flush()
return True
async def get_user_coupons(self, db: AsyncSession, user_id: int, limit: int = 100) -> List[CouponUsage]:
"""
获取用户使用的兑换券记录
"""
stmt = select(CouponUsage).where(
CouponUsage.user_id == user_id
).order_by(CouponUsage.used_at.desc()).limit(limit)
result = await db.execute(stmt)
return result.scalars().all()
class CouponUsageDao(CRUDPlus[CouponUsage]):
async def get_by_user_id(self, db: AsyncSession, user_id: int, limit: int = 100) -> List[CouponUsage]:
"""
根据用户ID获取兑换记录
"""
stmt = select(CouponUsage).where(
CouponUsage.user_id == user_id
).order_by(CouponUsage.used_at.desc()).limit(limit)
result = await db.execute(stmt)
return result.scalars().all()
async def get_by_coupon_id(self, db: AsyncSession, coupon_id: int) -> Optional[CouponUsage]:
"""
根据兑换券ID获取使用记录
"""
stmt = select(CouponUsage).where(CouponUsage.coupon_id == coupon_id)
result = await db.execute(stmt)
return result.scalar_one_or_none()
coupon_dao = CouponDao(Coupon)
coupon_usage_dao = CouponUsageDao(CouponUsage)

View File

@@ -0,0 +1,119 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Sequence
from sqlalchemy import Select, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus
from backend.app.admin.model import DataRule, DataScope
from backend.app.admin.schema.data_scope import CreateDataScopeParam, UpdateDataScopeParam, UpdateDataScopeRuleParam
class CRUDDataScope(CRUDPlus[DataScope]):
"""数据范围数据库操作类"""
async def get(self, db: AsyncSession, pk: int) -> DataScope | None:
"""
获取数据范围详情
:param db: 数据库会话
:param pk: 范围 ID
:return:
"""
return await self.select_model(db, pk)
async def get_by_name(self, db: AsyncSession, name: str) -> DataScope | None:
"""
通过名称获取数据范围
:param db: 数据库会话
:param name: 范围名称
:return:
"""
return await self.select_model_by_column(db, name=name)
async def get_with_relation(self, db: AsyncSession, pk: int) -> DataScope:
"""
获取数据范围关联数据
:param db: 数据库会话
:param pk: 范围 ID
:return:
"""
return await self.select_model(db, pk, load_strategies=['rules'])
async def get_all(self, db: AsyncSession) -> Sequence[DataScope]:
"""
获取所有数据范围
:param db: 数据库会话
:return:
"""
return await self.select_models(db)
async def get_list(self, name: str | None, status: int | None) -> Select:
"""
获取数据范围列表
:param name: 范围名称
:param status: 范围状态
:return:
"""
filters = {}
if name is not None:
filters['name__like'] = f'%{name}%'
if status is not None:
filters['status'] = status
return await self.select_order('id', load_strategies={'rules': 'noload', 'roles': 'noload'}, **filters)
async def create(self, db: AsyncSession, obj: CreateDataScopeParam) -> None:
"""
创建数据范围
:param db: 数据库会话
:param obj: 创建数据范围参数
:return:
"""
await self.create_model(db, obj)
async def update(self, db: AsyncSession, pk: int, obj: UpdateDataScopeParam) -> int:
"""
更新数据范围
:param db: 数据库会话
:param pk: 范围 ID
:param obj: 更新数据范围参数
:return:
"""
return await self.update_model(db, pk, obj)
async def update_rules(self, db: AsyncSession, pk: int, rule_ids: UpdateDataScopeRuleParam) -> int:
"""
更新数据范围规则
:param db: 数据库会话
:param pk: 范围 ID
:param rule_ids: 数据规则 ID 列表
:return:
"""
current_data_scope = await self.get_with_relation(db, pk)
stmt = select(DataRule).where(DataRule.id.in_(rule_ids.rules))
rules = await db.execute(stmt)
current_data_scope.rules = rules.scalars().all()
return len(current_data_scope.rules)
async def delete(self, db: AsyncSession, pks: list[int]) -> int:
"""
批量删除数据范围
:param db: 数据库会话
:param pks: 范围 ID 列表
:return:
"""
return await self.delete_model_by_column(db, allow_multiple=True, id__in=pks)
data_scope_dao: CRUDDataScope = CRUDDataScope(DataScope)

View File

@@ -0,0 +1,46 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus
from sqlalchemy import select, func, desc
from datetime import datetime, date
from backend.app.admin.model.audit_log import DailySummary
from backend.app.admin.schema.audit_log import DailySummarySchema
class DailySummaryCRUD(CRUDPlus[DailySummary]):
""" Daily Summary CRUD """
async def get_user_daily_summaries(self, db: AsyncSession, user_id: int, page: int = 1, size: int = 20):
"""
获取用户的每日汇总记录,按创建时间降序排列
:param db: 数据库会话
:param user_id: 用户ID
:param page: 页码
:param size: 每页数量
:return: 查询结果和总数
"""
# 主查询:获取用户每日汇总记录
stmt = (
select(DailySummary)
.where(DailySummary.user_id == user_id)
.order_by(desc(DailySummary.created_time))
.offset((page - 1) * size)
.limit(size)
)
result = await db.execute(stmt)
items = result.scalars().all()
# 获取总数
count_stmt = (
select(func.count())
.where(DailySummary.user_id == user_id)
)
total_result = await db.execute(count_stmt)
total = total_result.scalar()
return items, total
daily_summary_dao: DailySummaryCRUD = DailySummaryCRUD(DailySummary)

View File

@@ -0,0 +1,42 @@
from typing import Optional, List
from sqlalchemy import select, and_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from backend.app.admin.model.dict import DictionaryEntry, DictionaryMedia
class DictCRUD:
async def get_by_word(self, db: AsyncSession, word: str) -> Optional[DictionaryEntry]:
"""根据单词查询字典条目"""
query = select(DictionaryEntry).where(DictionaryEntry.word == word)
result = await db.execute(query)
return result.scalars().first()
async def get_by_id(self, db: AsyncSession, entry_id: int) -> Optional[DictionaryEntry]:
"""根据ID获取字典条目"""
return await db.get(DictionaryEntry, entry_id)
async def get_media_by_dict_id(self, db: AsyncSession, dict_id: int) -> List[DictionaryMedia]:
"""根据字典条目ID获取相关媒体文件"""
query = select(DictionaryMedia).where(DictionaryMedia.dict_id == dict_id)
result = await db.execute(query)
return result.scalars().all()
async def get_media_by_filename(self, db: AsyncSession, filename: str) -> Optional[DictionaryMedia]:
"""根据文件名获取媒体文件"""
query = select(DictionaryMedia).where(DictionaryMedia.file_name == filename)
result = await db.execute(query)
return result.scalars().first()
async def search_words(self, db: AsyncSession, word_pattern: str, limit: int = 10) -> List[DictionaryEntry]:
"""模糊搜索单词"""
query = select(DictionaryEntry).where(
DictionaryEntry.word.ilike(f'%{word_pattern}%')
).limit(limit)
result = await db.execute(query)
return result.scalars().all()
dict_dao = DictCRUD()

View File

@@ -0,0 +1,119 @@
from typing import Optional, List
from sqlalchemy import select, func, and_
from sqlalchemy.ext.asyncio import AsyncSession
from backend.app.admin.model.feedback import Feedback
from backend.app.admin.schema.feedback import CreateFeedbackParam, UpdateFeedbackParam
class FeedbackCRUD:
async def get(self, db: AsyncSession, feedback_id: int) -> Optional[Feedback]:
"""根据ID获取反馈"""
return await db.get(Feedback, feedback_id)
async def get_list(
self,
db: AsyncSession,
user_id: Optional[int] = None,
status: Optional[str] = None,
category: Optional[str] = None,
limit: Optional[int] = None,
offset: Optional[int] = None
) -> List[Feedback]:
"""获取反馈列表"""
query = select(Feedback).order_by(Feedback.created_at.desc())
# 添加过滤条件
filters = []
if user_id is not None:
filters.append(Feedback.user_id == user_id)
if status is not None:
filters.append(Feedback.status == status)
if category is not None:
filters.append(Feedback.category == category)
if filters:
query = query.where(and_(*filters))
# 添加分页
if limit is not None:
query = query.limit(limit)
if offset is not None:
query = query.offset(offset)
result = await db.execute(query)
return result.scalars().all()
async def create(self, db: AsyncSession, user_id: int, obj_in: CreateFeedbackParam) -> Feedback:
"""创建反馈"""
db_obj = Feedback(
user_id=user_id,
content=obj_in.content,
contact_info=obj_in.contact_info,
category=obj_in.category.value if obj_in.category else None,
metadata_info=obj_in.metadata_info
)
db.add(db_obj)
await db.flush()
return db_obj
async def update(self, db: AsyncSession, feedback_id: int, obj_in: UpdateFeedbackParam) -> int:
"""更新反馈"""
query = select(Feedback).where(Feedback.id == feedback_id)
result = await db.execute(query)
db_obj = result.scalars().first()
if db_obj:
update_data = obj_in.model_dump(exclude_unset=True)
# 处理枚举类型
if 'category' in update_data and update_data['category']:
update_data['category'] = update_data['category'].value
if 'status' in update_data and update_data['status']:
update_data['status'] = update_data['status'].value
for field, value in update_data.items():
setattr(db_obj, field, value)
await db.flush()
return 1
return 0
async def delete(self, db: AsyncSession, feedback_id: int) -> int:
"""删除反馈"""
query = select(Feedback).where(Feedback.id == feedback_id)
result = await db.execute(query)
db_obj = result.scalars().first()
if db_obj:
await db.delete(db_obj)
await db.flush()
return 1
return 0
async def count(
self,
db: AsyncSession,
user_id: Optional[int] = None,
status: Optional[str] = None,
category: Optional[str] = None
) -> int:
"""统计反馈数量"""
query = select(func.count(Feedback.id))
# 添加过滤条件
filters = []
if user_id is not None:
filters.append(Feedback.user_id == user_id)
if status is not None:
filters.append(Feedback.status == status)
if category is not None:
filters.append(Feedback.category == category)
if filters:
query = query.where(and_(*filters))
result = await db.execute(query)
return result.scalar_one()
feedback_dao = FeedbackCRUD()

View File

@@ -0,0 +1,60 @@
from typing import Optional
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from backend.app.admin.model.file import File
from backend.app.admin.schema.file import AddFileParam, UpdateFileParam
class FileCRUD:
async def get(self, db: AsyncSession, file_id: int) -> Optional[File]:
"""根据ID获取文件"""
return await db.get(File, file_id)
async def get_by_hash(self, db: AsyncSession, file_hash: str) -> Optional[File]:
"""根据哈希值获取文件"""
query = select(File).where(File.file_hash == file_hash)
result = await db.execute(query)
return result.scalars().first()
async def create(self, db: AsyncSession, obj_in: AddFileParam) -> File:
"""创建文件记录"""
db_obj = File(**obj_in.model_dump())
db.add(db_obj)
await db.flush()
return db_obj
async def update(self, db: AsyncSession, file_id: int, obj_in: UpdateFileParam) -> int:
"""更新文件记录"""
query = select(File).where(File.id == file_id)
result = await db.execute(query)
db_obj = result.scalars().first()
if db_obj:
for field, value in obj_in.model_dump(exclude_unset=True).items():
setattr(db_obj, field, value)
await db.flush()
return 1
return 0
async def delete(self, db: AsyncSession, file_id: int) -> int:
"""删除文件记录"""
query = select(File).where(File.id == file_id)
result = await db.execute(query)
db_obj = result.scalars().first()
if db_obj:
await db.delete(db_obj)
await db.flush()
return 1
return 0
async def count(self, db: AsyncSession) -> int:
"""统计文件数量"""
query = select(func.count(File.id))
result = await db.execute(query)
return result.scalar_one()
file_dao = FileCRUD()

View File

@@ -0,0 +1,38 @@
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus
from backend.app.admin.model.order import FreezeLog
class FreezeLogDao(CRUDPlus[FreezeLog]):
async def get_by_id(self, db: AsyncSession, freeze_id: int) -> Optional[FreezeLog]:
"""
根据ID获取冻结记录
"""
stmt = select(FreezeLog).where(FreezeLog.id == freeze_id)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def get_by_order_id(self, db: AsyncSession, order_id: int) -> Optional[FreezeLog]:
"""
根据订单ID获取冻结记录
"""
stmt = select(FreezeLog).where(FreezeLog.order_id == order_id)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def get_pending_by_user(self, db: AsyncSession, user_id: int) -> list[FreezeLog]:
"""
获取用户所有待处理的冻结记录
"""
stmt = select(FreezeLog).where(
FreezeLog.user_id == user_id,
FreezeLog.status == "pending"
)
result = await db.execute(stmt)
return result.scalars().all()
freeze_log_dao = FreezeLogDao(FreezeLog)

View File

@@ -0,0 +1,135 @@
from typing import Optional, List
from sqlalchemy import select, and_, update, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus
from backend.app.admin.model.notification import Notification, UserNotification
from datetime import datetime
class NotificationDao(CRUDPlus[Notification]):
async def create_notification(self, db: AsyncSession, notification_data: dict) -> Notification:
"""
创建消息通知
"""
notification = Notification(**notification_data)
db.add(notification)
await db.flush()
return notification
async def get_active_notifications(self, db: AsyncSession, limit: int = 100) -> List[Notification]:
"""
获取激活的通知列表
"""
stmt = select(Notification).where(
Notification.is_active == True
).order_by(Notification.created_at.desc()).limit(limit)
result = await db.execute(stmt)
return result.scalars().all()
async def get_notification_by_id(self, db: AsyncSession, notification_id: int) -> Optional[Notification]:
"""
根据ID获取通知
"""
stmt = select(Notification).where(Notification.id == notification_id)
result = await db.execute(stmt)
return result.scalar_one_or_none()
class UserNotificationDao(CRUDPlus[UserNotification]):
async def create_user_notification(self, db: AsyncSession, user_notification_data: dict) -> UserNotification:
"""
创建用户通知关联
"""
user_notification = UserNotification(**user_notification_data)
db.add(user_notification)
await db.flush()
return user_notification
async def create_user_notifications(self, db: AsyncSession, user_notifications_data: List[dict]) -> List[UserNotification]:
"""
批量创建用户通知关联
"""
user_notifications = [UserNotification(**data) for data in user_notifications_data]
db.add_all(user_notifications)
await db.flush()
return user_notifications
async def get_user_notifications(self, db: AsyncSession, user_id: int, limit: int = 100) -> List[UserNotification]:
"""
获取用户的通知列表
"""
stmt = select(UserNotification).where(
UserNotification.user_id == user_id
).order_by(UserNotification.received_at.desc()).limit(limit)
result = await db.execute(stmt)
return result.scalars().all()
async def get_unread_notifications(self, db: AsyncSession, user_id: int, limit: int = 100) -> List[UserNotification]:
"""
获取用户未读通知列表
"""
stmt = select(UserNotification).where(
and_(
UserNotification.user_id == user_id,
UserNotification.is_read == False
)
).order_by(UserNotification.received_at.desc()).limit(limit)
result = await db.execute(stmt)
return result.scalars().all()
async def mark_as_read(self, db: AsyncSession, user_notification_id: int) -> bool:
"""
标记通知为已读
"""
stmt = update(UserNotification).where(
UserNotification.id == user_notification_id
).values(
is_read=True,
read_at=datetime.now()
)
result = await db.execute(stmt)
return result.rowcount > 0
async def mark_multiple_as_read(self, db: AsyncSession, user_id: int, notification_ids: List[int]) -> int:
"""
批量标记通知为已读
"""
stmt = update(UserNotification).where(
and_(
UserNotification.user_id == user_id,
UserNotification.notification_id.in_(notification_ids),
UserNotification.is_read == False
)
).values(
is_read=True,
read_at=datetime.now()
)
result = await db.execute(stmt)
return result.rowcount
async def get_unread_count(self, db: AsyncSession, user_id: int) -> int:
"""
获取用户未读通知数量
"""
stmt = select(func.count(UserNotification.id)).where(
and_(
UserNotification.user_id == user_id,
UserNotification.is_read == False
)
)
result = await db.execute(stmt)
return result.scalar() or 0
async def get_user_notification_by_id(self, db: AsyncSession, user_notification_id: int) -> Optional[UserNotification]:
"""
根据ID获取用户通知关联记录
"""
stmt = select(UserNotification).where(UserNotification.id == user_notification_id)
result = await db.execute(stmt)
return result.scalar_one_or_none()
notification_dao = NotificationDao(Notification)
user_notification_dao = UserNotificationDao(UserNotification)

View File

@@ -0,0 +1,94 @@
from typing import Optional, List
from sqlalchemy import select, and_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus
from backend.app.admin.model.order import Order
class OrderDao(CRUDPlus[Order]):
async def get_by_id(self, db: AsyncSession, order_id: int) -> Optional[Order]:
"""
根据ID获取订单
"""
stmt = select(Order).where(Order.id == order_id)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def get_by_user_id(self, db: AsyncSession, user_id: int, limit: int = 100) -> List[Order]:
"""
根据用户ID获取订单列表
"""
stmt = select(Order).where(
Order.user_id == user_id
).order_by(Order.created_at.desc()).limit(limit)
result = await db.execute(stmt)
return result.scalars().all()
async def get_pending_orders(self, db: AsyncSession, user_id: int) -> List[Order]:
"""
获取用户所有待处理订单
"""
stmt = select(Order).where(
and_(
Order.user_id == user_id,
Order.status == "pending"
)
)
result = await db.execute(stmt)
return result.scalars().all()
async def get_completed_orders(self, db: AsyncSession, user_id: int, limit: int = 50) -> List[Order]:
"""
获取用户已完成的订单
"""
stmt = select(Order).where(
and_(
Order.user_id == user_id,
Order.status == "completed"
)
).order_by(Order.created_at.desc()).limit(limit)
result = await db.execute(stmt)
return result.scalars().all()
async def get_order_by_payment_id(self, db: AsyncSession, payment_id: str) -> Optional[Order]:
"""
根据支付ID获取订单
"""
stmt = select(Order).where(Order.payment_id == payment_id)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def get_order_by_transaction_id(self, db: AsyncSession, transaction_id: str) -> Optional[Order]:
"""
根据交易ID获取订单
"""
stmt = select(Order).where(Order.transaction_id == transaction_id)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def update_order_status(self, db: AsyncSession, order_id: int, status: str) -> bool:
"""
更新订单状态
"""
from sqlalchemy import update
stmt = update(Order).where(Order.id == order_id).values(status=status)
result = await db.execute(stmt)
return result.rowcount > 0
async def get_subscription_orders(self, db: AsyncSession, user_id: int) -> List[Order]:
"""
获取用户所有订阅订单
"""
stmt = select(Order).where(
and_(
Order.user_id == user_id,
Order.order_type == "subscription"
)
).order_by(Order.created_at.desc())
result = await db.execute(stmt)
return result.scalars().all()
order_dao = OrderDao(Order)

View File

@@ -0,0 +1,103 @@
from datetime import datetime
from typing import Optional, Dict, Any
from sqlalchemy import select, update, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus
from backend.app.admin.model.points import Points, PointsLog
class PointsDao(CRUDPlus[Points]):
async def get_by_user_id(self, db: AsyncSession, user_id: int) -> Optional[Points]:
"""
根据用户ID获取积分账户信息
"""
stmt = select(Points).where(Points.user_id == user_id)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def create_user_points(self, db: AsyncSession, user_id: int) -> Points:
"""
为用户创建积分账户
"""
points = Points(user_id=user_id)
db.add(points)
await db.flush()
return points
async def add_points_atomic(self, db: AsyncSession, user_id: int, amount: int, extend_expiration: bool = False) -> bool:
"""
原子性增加用户积分
"""
# 先确保用户有积分账户
points_account = await self.get_by_user_id(db, user_id)
if not points_account:
points_account = await self.create_user_points(db, user_id)
# 准备更新值
update_values = {
"balance": Points.balance + amount,
"total_earned": Points.total_earned + amount
}
# 如果需要延期,则更新过期时间
if extend_expiration:
update_values["expired_time"] = datetime.now() + timedelta(days=30)
stmt = update(Points).where(
Points.user_id == user_id
).values(**update_values)
result = await db.execute(stmt)
return result.rowcount > 0
async def deduct_points_atomic(self, db: AsyncSession, user_id: int, amount: int) -> bool:
"""
原子性扣减用户积分(确保不超扣)
"""
stmt = update(Points).where(
Points.user_id == user_id,
Points.balance >= amount
).values(
balance=Points.balance - amount,
total_spent=Points.total_spent + amount
)
result = await db.execute(stmt)
return result.rowcount > 0
async def get_balance(self, db: AsyncSession, user_id: int) -> int:
"""
获取用户积分余额
"""
points_account = await self.get_by_user_id(db, user_id)
if not points_account:
return 0
return points_account.balance
async def check_and_clear_expired_points(self, db: AsyncSession, user_id: int) -> bool:
"""
检查并清空过期积分
"""
stmt = update(Points).where(
Points.user_id == user_id,
Points.expired_time < datetime.now(),
Points.balance > 0
).values(
balance=0,
total_spent=Points.total_spent + Points.balance
)
result = await db.execute(stmt)
return result.rowcount > 0
class PointsLogDao(CRUDPlus[PointsLog]):
async def add_log(self, db: AsyncSession, log_data: Dict[str, Any]) -> PointsLog:
"""
添加积分变动日志
"""
log = PointsLog(**log_data)
db.add(log)
await db.flush()
return log
points_dao = PointsDao(Points)
points_log_dao = PointsLogDao(PointsLog)

View File

@@ -0,0 +1,61 @@
from typing import Optional, List, Dict, Any
from sqlalchemy import select, and_
from sqlalchemy.ext.asyncio import AsyncSession
from backend.app.admin.model.order import UsageLog
from sqlalchemy_crud_plus import CRUDPlus
class UsageLogDao(CRUDPlus[UsageLog]):
async def get_by_id(self, db: AsyncSession, log_id: int) -> Optional[UsageLog]:
"""
根据ID获取使用日志
"""
stmt = select(UsageLog).where(UsageLog.id == log_id)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def get_by_user_id(self, db: AsyncSession, user_id: int, limit: int = 100) -> List[UsageLog]:
"""
根据用户ID获取使用日志列表
"""
stmt = select(UsageLog).where(
UsageLog.user_id == user_id
).order_by(UsageLog.created_at.desc()).limit(limit)
result = await db.execute(stmt)
return result.scalars().all()
async def get_by_action(self, db: AsyncSession, user_id: int, action: str, limit: int = 50) -> List[UsageLog]:
"""
根据动作类型获取使用日志
"""
stmt = select(UsageLog).where(
and_(
UsageLog.user_id == user_id,
UsageLog.action == action
)
).order_by(UsageLog.created_at.desc()).limit(limit)
result = await db.execute(stmt)
return result.scalars().all()
async def get_balance_history(self, db: AsyncSession, user_id: int, limit: int = 100) -> List[UsageLog]:
"""
获取用户余额变动历史
"""
stmt = select(UsageLog).where(
UsageLog.user_id == user_id
).order_by(UsageLog.created_at.desc()).limit(limit)
result = await db.execute(stmt)
return result.scalars().all()
async def add_log(self, db: AsyncSession, log_data: Dict[str, Any]) -> UsageLog:
"""
添加使用日志
"""
log = UsageLog(**log_data)
db.add(log)
await db.flush()
return log
usage_log_dao = UsageLogDao(UsageLog)

View File

@@ -0,0 +1,105 @@
from datetime import datetime, timedelta
from typing import Optional
from sqlalchemy import select, update, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus
from backend.app.admin.model.order import UserAccount, FreezeLog
class UserAccountDao(CRUDPlus[UserAccount]):
async def get_by_user_id(self, db: AsyncSession, user_id: int) -> Optional[UserAccount]:
"""
根据用户ID获取账户信息
"""
stmt = select(UserAccount).where(UserAccount.user_id == user_id)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def get_by_id(self, db: AsyncSession, account_id: int) -> Optional[UserAccount]:
"""
根据账户ID获取账户信息
"""
stmt = select(UserAccount).where(UserAccount.id == account_id)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def create_new_user_account(self, db: AsyncSession, user_id: int) -> UserAccount:
"""
为新用户创建账户(包含免费试用)
"""
# 设置免费试用期3天
trial_expires_at = datetime.now() + timedelta(days=3)
account = UserAccount(
user_id=user_id,
balance=30, # 初始30次免费次数
free_trial_balance=30,
free_trial_expires_at=trial_expires_at,
free_trial_used=True # 标记为已使用因为已经给了30次
)
db.add(account)
await db.flush()
return account
async def update_balance_atomic(self, db: AsyncSession, user_id: int, amount: int) -> bool:
"""
原子性更新用户余额
"""
stmt = update(UserAccount).where(
UserAccount.user_id == user_id
).values(
balance=UserAccount.balance + amount
)
result = await db.execute(stmt)
return result.rowcount > 0
async def deduct_balance_atomic(self, db: AsyncSession, user_id: int, amount: int) -> bool:
"""
原子性扣减用户余额(确保不超扣)
"""
stmt = update(UserAccount).where(
UserAccount.user_id == user_id,
UserAccount.balance >= amount
).values(
balance=UserAccount.balance - amount
)
result = await db.execute(stmt)
return result.rowcount > 0
async def get_frozen_balance(self, db: AsyncSession, user_id: int) -> int:
"""
获取用户被冻结的次数
"""
stmt = select(func.sum(FreezeLog.amount)).where(
FreezeLog.user_id == user_id,
FreezeLog.status == "pending"
)
result = await db.execute(stmt)
return result.scalar() or 0
async def get_available_balance(self, db: AsyncSession, user_id: int) -> int:
"""
获取用户可用余额(总余额减去冻结余额)
"""
account = await self.get_by_user_id(db, user_id)
if not account:
return 0
frozen_balance = await self.get_frozen_balance(db, user_id)
return max(0, account.balance - frozen_balance)
async def check_free_trial_valid(self, db: AsyncSession, user_id: int) -> bool:
"""
检查用户免费试用是否仍然有效
"""
account = await self.get_by_user_id(db, user_id)
if not account or not account.free_trial_expires_at:
return False
return (account.free_trial_expires_at > datetime.now() and
account.free_trial_balance > 0)
user_account_dao = UserAccountDao(UserAccount)

View File

@@ -0,0 +1,201 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from datetime import datetime
import bcrypt
from sqlalchemy import select, update, desc, and_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import Select
from sqlalchemy_crud_plus import CRUDPlus
from backend.app.admin.model import WxUser
from backend.app.admin.schema.user import RegisterUserParam, UpdateUserParam, AvatarParam
from backend.common.security.jwt import get_hash_password
from backend.utils.wx_pay import wx_pay_utils
class WxUserCRUD(CRUDPlus[WxUser]):
async def get(self, db: AsyncSession, user_id: int) -> WxUser | None:
"""
获取用户
:param db:
:param user_id:
:return:
"""
return await self.select_model(db, user_id)
async def get_by_username(self, db: AsyncSession, username: str) -> WxUser | None:
"""
通过 username 获取用户
:param db:
:param username:
:return:
"""
return await self.select_model_by_column(db, username=username)
async def get_by_openid(self, db: AsyncSession, openid: str) -> WxUser | None:
"""
通过 username 获取用户
:param db:
:param openid:
:return:
"""
return await self.select_model_by_column(db, openid=openid)
async def add(self, db: AsyncSession, new_user: WxUser) -> None:
"""
通过 openid 添加用户
:param db:
:param openid:
:return:
"""
db.add(new_user)
async def update_session_key(self, db: AsyncSession, input_user: int, session_key: str) -> int:
"""
更新用户头像
:param db:
:param input_user:
:param avatar:
:return:
"""
return await self.update_model(db, input_user, {'session_key': session_key})
async def update_user_profile(self, db: AsyncSession, user_id: int, profile: dict) -> int:
"""
更新用户资料并自动更新updated_time字段
:param db: 数据库会话
:param user_id: 用户ID
:param profile: 用户资料
:return: 更新的记录数
"""
return await self.update_model(db, user_id, {'profile': profile})
async def update_login_time(self, db: AsyncSession, username: str, login_time: datetime) -> int:
user = await db.execute(
update(self.model).where(self.model.username == username).values(last_login_time=login_time)
)
return user.rowcount
async def create(self, db: AsyncSession, obj: RegisterUserParam) -> None:
"""
创建用户
:param db:
:param obj:
:return:
"""
salt = bcrypt.gensalt()
obj.password = get_hash_password(obj.password, salt)
dict_obj = obj.model_dump()
dict_obj.update({'salt': salt})
new_user = self.model(**dict_obj)
db.add(new_user)
async def update_userinfo(self, db: AsyncSession, input_user: int, obj: UpdateUserParam) -> int:
"""
更新用户信息
:param db:
:param input_user:
:param obj:
:return:
"""
return await self.update_model(db, input_user, obj)
async def update_avatar(self, db: AsyncSession, input_user: int, avatar: AvatarParam) -> int:
"""
更新用户头像
:param db:
:param input_user:
:param avatar:
:return:
"""
return await self.update_model(db, input_user, {'avatar': avatar.url})
async def delete(self, db: AsyncSession, user_id: int) -> int:
"""
删除用户
:param db:
:param user_id:
:return:
"""
return await self.delete_model(db, user_id)
async def check_email(self, db: AsyncSession, email: str) -> WxUser:
"""
检查邮箱是否存在
:param db:
:param email:
:return:
"""
return await self.select_model_by_column(db, email=email)
async def reset_password(self, db: AsyncSession, pk: int, new_pwd: str) -> int:
"""
重置用户密码
:param db:
:param pk:
:param new_pwd:
:return:
"""
return await self.update_model(db, pk, {'password': new_pwd})
async def get_list(self, username: str = None, phone: str = None, status: int = None) -> Select:
"""
获取用户列表
:param username:
:param phone:
:param status:
:return:
"""
stmt = select(self.model).order_by(desc(self.model.join_time))
filters = []
if username:
filters.append(self.model.username.like(f'%{username}%'))
if phone:
filters.append(self.model.phone.like(f'%{phone}%'))
if status is not None:
filters.append(self.model.status == status)
if filters:
stmt = stmt.where(and_(*filters))
return stmt
async def get_with_relation(
self, db: AsyncSession, *, user_id: int | None = None, openid: str | None = None
) -> WxUser | None:
"""
获取用户关联信息
:param db: 数据库会话
:param user_id: 用户 ID
:param username: 用户名
:return:
"""
filters = {}
if user_id:
filters['id'] = user_id
if openid:
filters['openid'] = openid
return await self.select_model_by_column(
db,
**filters,
)
wx_user_dao: WxUserCRUD = WxUserCRUD(WxUser)

View File

@@ -0,0 +1,12 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from backend.common.model import MappedBase # noqa: I
from backend.app.admin.model.wx_user import WxUser
from backend.app.admin.model.audit_log import AuditLog, DailySummary
from backend.app.admin.model.file import File
from backend.app.admin.model.dict import DictionaryEntry, DictionaryMedia
from backend.app.admin.model.order import Order, UserAccount, FreezeLog, UsageLog
from backend.app.admin.model.coupon import Coupon, CouponUsage
from backend.app.admin.model.notification import Notification, UserNotification
from backend.app.admin.model.points import Points, PointsLog
from backend.app.ai.model import Image

View File

@@ -0,0 +1,53 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional, List
from sqlalchemy import Integer, BigInteger, Text, String, Numeric, Float, DateTime, ForeignKey, Index
from sqlalchemy.dialects.postgresql import JSONB, ARRAY
from sqlalchemy.orm import Mapped, mapped_column
from backend.common.model import snowflake_id_key, Base
class AuditLog(Base):
__tablename__ = 'audit_log'
id: Mapped[snowflake_id_key] = mapped_column(init=False, primary_key=True)
api_type: Mapped[str] = mapped_column(String(20), nullable=False, comment="API类型: recognition embedding assessment")
model_name: Mapped[str] = mapped_column(String(50), nullable=False, comment="模型名称")
request_data: Mapped[Optional[dict]] = mapped_column(JSONB, comment="请求数据")
response_data: Mapped[Optional[dict]] = mapped_column(JSONB, comment="响应数据")
token_usage: Mapped[Optional[dict]] = mapped_column(JSONB, comment="消耗的token数量")
cost: Mapped[Optional[float]] = mapped_column(Numeric(10, 5), comment="API调用成本")
duration: Mapped[Optional[float]] = mapped_column(Float, comment="调用耗时(秒)")
status_code: Mapped[Optional[int]] = mapped_column(Integer, comment="HTTP状态码")
image_id: Mapped[Optional[int]] = mapped_column(BigInteger, ForeignKey('image.id'), comment="关联的图片ID")
user_id: Mapped[Optional[int]] = mapped_column(BigInteger, ForeignKey('wx_user.id'), comment="调用用户ID")
dict_level: Mapped[Optional[str]] = mapped_column(String(20), comment="dict level")
api_version: Mapped[Optional[str]] = mapped_column(String(20), comment="API版本")
error_message: Mapped[Optional[str]] = mapped_column(Text, default=None, comment="错误信息")
called_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment="调用时间")
# 索引优化
__table_args__ = (
Index('idx_audit_logs_image_id', 'image_id'),
# 为用户历史记录查询优化的索引
Index('idx_audit_log_user_api_called', 'user_id', 'api_type', 'called_at'),
)
class DailySummary(Base):
__tablename__ = 'daily_summary'
id: Mapped[snowflake_id_key] = mapped_column(init=False, primary_key=True)
user_id: Mapped[Optional[int]] = mapped_column(BigInteger, ForeignKey('wx_user.id'), comment="调用用户ID")
image_ids: Mapped[List[str]] = mapped_column(ARRAY(Text), default=None, comment="图片ID列表")
thumbnail_ids: Mapped[List[str]] = mapped_column(ARRAY(Text), default=None, comment="图片缩略图列表")
summary_time: Mapped[datetime] = mapped_column(DateTime, default=None, comment="总结的时间")
# 索引优化
__table_args__ = (
# 为用户历史记录查询优化的索引
Index('idx_daily_summary_api_called', 'user_id', 'summary_time'),
)

View File

@@ -0,0 +1,43 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from sqlalchemy import String, Column, BigInteger, ForeignKey, Boolean, DateTime, Index, Text
from sqlalchemy.orm import Mapped, mapped_column
from backend.common.model import snowflake_id_key, Base
class Coupon(Base):
__tablename__ = 'coupon'
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
code: Mapped[str] = mapped_column(String(32), unique=True, nullable=False, comment='兑换码')
duration: Mapped[int] = mapped_column(BigInteger, nullable=False, comment='兑换时长(分钟)')
is_used: Mapped[bool] = mapped_column(Boolean, default=False, comment='是否已使用')
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment='创建时间')
expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=None, comment='过期时间')
created_by: Mapped[Optional[int]] = mapped_column(BigInteger, default=None, comment='创建者ID')
__table_args__ = (
Index('idx_coupon_code', 'code'),
Index('idx_coupon_is_used', 'is_used'),
{'comment': '兑换券表'}
)
class CouponUsage(Base):
__tablename__ = 'coupon_usage'
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
coupon_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('coupon.id'), nullable=False, comment='兑换券ID')
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), nullable=False, comment='使用者ID')
duration: Mapped[int] = mapped_column(BigInteger, nullable=False, comment='兑换时长(分钟)')
used_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment='使用时间')
__table_args__ = (
Index('idx_coupon_usage_user', 'user_id'),
Index('idx_coupon_usage_coupon', 'coupon_id'),
{'comment': '兑换券使用记录表'}
)

42
backend/app/admin/model/dict.py Executable file
View File

@@ -0,0 +1,42 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Optional
from sqlalchemy import String, Column, LargeBinary, ForeignKey, BigInteger, Index, func, JSON, Text, Numeric
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from backend.app.admin.schema.dict import WordMetaData
from backend.app.admin.schema.pydantic_type import PydanticType
from backend.common.model import snowflake_id_key, DataClassBase
class DictionaryEntry(DataClassBase):
__tablename__ = 'dict_entry'
id: Mapped[int] = mapped_column(primary_key=True, init=True, autoincrement=True)
word: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
definition: Mapped[Optional[str]] = mapped_column(Text, default=None)
details: Mapped[Optional[WordMetaData]] = mapped_column(PydanticType(pydantic_type=WordMetaData), default=None) # 其他可能的字段(根据实际需求添加)
__table_args__ = (
Index('idx_dict_word', word),
)
class DictionaryMedia(DataClassBase):
__tablename__ = 'dict_media'
id: Mapped[int] = mapped_column(primary_key=True, init=True, autoincrement=True)
file_name: Mapped[str] = mapped_column(String(255), nullable=False)
file_type: Mapped[str] = mapped_column(String(50), nullable=False) # 'audio', 'image'
dict_id: Mapped[Optional[int]] = mapped_column(BigInteger, ForeignKey("dict_entry.id"), default=None)
file_data: Mapped[Optional[bytes]] = mapped_column(LargeBinary, default=None)
file_hash: Mapped[Optional[str]] = mapped_column(String(64), default=None)
details: Mapped[Optional[dict]] = mapped_column(JSONB(astext_type=Text()), default=None, comment="其他信息") # 其他信息
__table_args__ = (
Index('idx_media_filename', file_name),
Index('idx_media_dict_id', dict_id),
)

View File

@@ -0,0 +1,29 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional, List
from sqlalchemy import String, Text, DateTime, ForeignKey, Index, BigInteger
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column, relationship
from backend.common.model import snowflake_id_key, Base
class Feedback(Base):
__tablename__ = 'feedback'
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), nullable=False, comment='用户ID')
content: Mapped[str] = mapped_column(Text, nullable=False, comment='反馈内容')
contact_info: Mapped[Optional[str]] = mapped_column(String(255), nullable=True, comment='联系方式')
category: Mapped[Optional[str]] = mapped_column(String(50), nullable=True, comment='反馈分类')
status: Mapped[str] = mapped_column(String(20), default='pending', comment='处理状态: pending, processing, resolved')
# 索引优化
__table_args__ = (
Index('idx_feedback_user_id', 'user_id'),
Index('idx_feedback_status', 'status'),
Index('idx_feedback_category', 'category'),
Index('idx_feedback_created_at', 'created_time'),
)

28
backend/app/admin/model/file.py Executable file
View File

@@ -0,0 +1,28 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Optional
from sqlalchemy import BigInteger, Text, String, Index, DateTime, LargeBinary
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import mapped_column, Mapped
from backend.common.model import snowflake_id_key, Base
class File(Base):
__tablename__ = 'file'
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
file_hash: Mapped[str] = mapped_column(String(64), index=True, nullable=False) # SHA256哈希
file_name: Mapped[str] = mapped_column(String(255), nullable=False) # 原始文件名
content_type: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) # MIME类型
file_size: Mapped[int] = mapped_column(BigInteger, nullable=False) # 文件大小(字节)
storage_path: Mapped[Optional[str]] = mapped_column(Text, nullable=True) # 存储路径(非数据库存储时使用)
file_data: Mapped[Optional[bytes]] = mapped_column(LargeBinary, default=None, nullable=True) # 文件二进制数据(数据库存储时使用)
storage_type: Mapped[str] = mapped_column(String(20), nullable=False, default='database') # 存储类型: database, local, s3
metadata_info: Mapped[Optional[dict]] = mapped_column(JSONB(astext_type=Text()), default=None, comment="元数据信息")
# 表参数 - 包含所有必要的约束
__table_args__ = (
Index('idx_file_hash', file_hash),
)

View File

@@ -0,0 +1,45 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from sqlalchemy import String, Column, BigInteger, ForeignKey, Boolean, DateTime, Index, Text
from sqlalchemy.orm import Mapped, mapped_column
from backend.common.model import snowflake_id_key, Base
class Notification(Base):
__tablename__ = 'notification'
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
title: Mapped[str] = mapped_column(String(255), nullable=False, comment='通知标题')
content: Mapped[str] = mapped_column(Text, nullable=False, comment='通知内容')
image_url: Mapped[Optional[str]] = mapped_column(String(512), default=None, comment='图片URL预留')
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment='创建时间')
created_by: Mapped[Optional[int]] = mapped_column(BigInteger, default=None, comment='创建者ID')
is_active: Mapped[bool] = mapped_column(Boolean, default=True, comment='是否激活')
__table_args__ = (
Index('idx_notification_created', 'created_at'),
Index('idx_notification_active', 'is_active'),
{'comment': '消息通知表'}
)
class UserNotification(Base):
__tablename__ = 'user_notification'
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
notification_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('notification.id'), nullable=False, comment='通知ID')
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), nullable=False, comment='用户ID')
is_read: Mapped[bool] = mapped_column(Boolean, default=False, comment='是否已读')
read_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=None, comment='阅读时间')
received_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment='接收时间')
__table_args__ = (
Index('idx_user_notification_user', 'user_id'),
Index('idx_user_notification_notification', 'notification_id'),
Index('idx_user_notification_read', 'is_read'),
{'comment': '用户通知关联表'}
)

View File

@@ -0,0 +1,79 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from sqlalchemy import String, Column, BigInteger, ForeignKey, Boolean, DateTime, Index, func, JSON, Text, Numeric
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from backend.common.model import snowflake_id_key, Base
class Order(Base):
__tablename__ = 'order'
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), nullable=False)
order_type: Mapped[str] = mapped_column(String(32), comment='类型purchase/subscription/extra')
payment_id: Mapped[Optional[str]] = mapped_column(String(64), comment='微信支付ID')
transaction_id: Mapped[Optional[str]] = mapped_column(String(64), comment='微信交易号')
amount_cents: Mapped[int] = mapped_column(BigInteger, comment='金额(分)')
amount_times: Mapped[int] = mapped_column(BigInteger, comment='实际获得次数')
status: Mapped[str] = mapped_column(String(16), default='pending', comment='订单状态')
expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=None, comment='过期时间(用于订阅)')
processed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=None, comment='处理时间(用于幂等性)')
__table_args__ = (
Index('idx_order_user_status', 'user_id', 'status'),
Index('idx_order_processed', 'processed_at'),
)
class UserAccount(Base):
__tablename__ = 'user_account'
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), unique=True, nullable=False, comment='关联的用户ID')
balance: Mapped[int] = mapped_column(BigInteger, default=0, comment='当前可用次数')
total_purchased: Mapped[int] = mapped_column(BigInteger, default=0, comment='累计购买次数')
subscription_type: Mapped[Optional[str]] = mapped_column(String(32), default=None, comment='订阅类型monthly/quarterly/half_yearly/yearly')
subscription_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=None, comment='订阅到期时间')
carryover_balance: Mapped[int] = mapped_column(BigInteger, default=0, comment='上期未使用的次数')
# 新用户免费次数相关
free_trial_balance: Mapped[int] = mapped_column(BigInteger, default=30, comment='新用户免费试用次数')
free_trial_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=None, comment='免费试用期结束时间')
free_trial_used: Mapped[bool] = mapped_column(Boolean, default=False, comment='是否已使用免费试用')
class FreezeLog(Base):
__tablename__ = 'freeze_log'
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), nullable=False)
order_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('order.id'), nullable=False)
amount: Mapped[int] = mapped_column(BigInteger, comment='冻结次数')
reason: Mapped[Optional[str]] = mapped_column(Text, comment='冻结原因')
status: Mapped[str] = mapped_column(String(16), default='pending', comment='状态pending/confirmed/cancelled')
__table_args__ = (
Index('idx_freeze_user_status', 'user_id', 'status'),
{'comment': '次数冻结记录表'}
)
class UsageLog(Base):
__tablename__ = 'usage_log'
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), nullable=False, comment='用户ID')
action: Mapped[str] = mapped_column(String(32), comment='动作purchase/renewal/use/carryover/share/ad/freeze/unfreeze/refund')
amount: Mapped[int] = mapped_column(BigInteger, comment='变动数量')
balance_after: Mapped[int] = mapped_column(BigInteger, comment='变动后余额')
related_id: Mapped[Optional[int]] = mapped_column(BigInteger, default=None, comment='关联ID如订单ID、冻结记录ID')
details: Mapped[Optional[dict]] = mapped_column(JSONB, default=None, comment='附加信息')
__table_args__ = (
Index('idx_usage_user_action', 'user_id', 'action'),
Index('idx_usage_user_time', 'user_id', 'created_time'),
Index('idx_usage_action_time', 'action', 'created_time'),
{'comment': '使用日志表'}
)

View File

@@ -0,0 +1,47 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from datetime import datetime, timedelta
from typing import Optional
from sqlalchemy import String, Column, BigInteger, ForeignKey, DateTime, Index, Text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from backend.common.model import snowflake_id_key, Base
class Points(Base):
__tablename__ = 'points'
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), unique=True, nullable=False, comment='关联的用户ID')
balance: Mapped[int] = mapped_column(BigInteger, default=0, comment='当前积分余额')
total_earned: Mapped[int] = mapped_column(BigInteger, default=0, comment='累计获得积分')
total_spent: Mapped[int] = mapped_column(BigInteger, default=0, comment='累计消费积分')
expired_time: Mapped[datetime] = mapped_column(DateTime, default=datetime.now() + timedelta(days=30), comment="过期时间")
# 索引优化
__table_args__ = (
Index('idx_points_user', 'user_id'),
{'comment': '用户积分表'}
)
class PointsLog(Base):
__tablename__ = 'points_log'
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), nullable=False, comment='用户ID')
action: Mapped[str] = mapped_column(String(32), comment='动作earn/spend')
amount: Mapped[int] = mapped_column(BigInteger, comment='变动数量')
balance_after: Mapped[int] = mapped_column(BigInteger, comment='变动后余额')
related_id: Mapped[Optional[int]] = mapped_column(BigInteger, default=None, comment='关联ID')
details: Mapped[Optional[dict]] = mapped_column(JSONB, default=None, comment='附加信息')
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment='创建时间')
# 索引优化
__table_args__ = (
Index('idx_points_log_user_action', 'user_id', 'action'),
Index('idx_points_log_user_time', 'user_id', 'created_at'),
{'comment': '积分变动日志表'}
)

View File

@@ -0,0 +1,36 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Optional
from sqlalchemy import String, Column, BigInteger, SmallInteger, Boolean, DateTime, Index, func, JSON, Text, Numeric
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from backend.common.model import snowflake_id_key, Base
class WxUser(Base):
__tablename__ = 'wx_user'
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
openid: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, comment='微信OpenID')
session_key: Mapped[str] = mapped_column(String(128), nullable=False, comment='会话密钥')
unionid: Mapped[Optional[str]] = mapped_column(String(64), default=None, index=True, comment='微信UnionID')
mobile: Mapped[Optional[str]] = mapped_column(String(15), default=None, index=True, comment='加密手机号')
profile: Mapped[Optional[dict]] = mapped_column(JSONB(astext_type=Text()), default=None, comment='用户资料JSON')
# class WxPayment(Base):
# __tablename__ = 'wx_payments'
#
# id: Mapped[snowflake_id_key] = mapped_column(init=False, primary_key=True)
# user_id: Mapped[int] = mapped_column(BigInteger, index=True, nullable=False)
# prepay_id: Mapped[str] = mapped_column(String(64), nullable=False, comment='预支付ID')
# transaction_id: Mapped[Optional[str]] = mapped_column(String(32), comment='微信支付单号')
# amount: Mapped[float] = mapped_column(Numeric, nullable=False, comment='分单位金额')
# status: Mapped[str] = mapped_column(String(16), default='pending', comment='支付状态')
#
# __table_args__ = (
# Index('idx_payment_user', 'user_id', 'status'),
# {'comment': '支付记录表'}
# )

View File

@@ -0,0 +1,6 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from backend.app.admin.schema.file import FileSchemaBase, AddFileParam, UpdateFileParam, FileInfoSchema, FileUploadResponse
from backend.app.admin.schema.audit_log import DailySummarySchema, DailySummaryPageSchema
from backend.app.admin.schema.points import PointsSchema, PointsLogSchema, AddPointsRequest, DeductPointsRequest

View File

@@ -0,0 +1,69 @@
from datetime import datetime
from typing import Optional, List
from pydantic import ConfigDict, Field
from backend.common.schema import SchemaBase
class AuditLogSchemaBase(SchemaBase):
""" Audit Log Schema """
api_type: str = Field(description="API类型: recognition/embedding/assessment")
model_name: str = Field(description="模型名称")
request_data: Optional[dict] = Field(None, description="请求数据")
response_data: Optional[dict] = Field(None, description="响应数据")
token_usage: Optional[dict] = Field(0, description="消耗的token数量")
cost: Optional[float] = Field(0.0, description="API调用成本")
duration: float = Field(description="调用耗时(秒)")
status_code: int = Field(description="HTTP状态码")
error_message: Optional[str] = Field("", description="错误信息")
called_at: Optional[datetime] = Field(None, description="调用时间")
image_id: int = Field(description="关联的图片ID")
user_id: int = Field(description="调用用户ID")
api_version: str = Field(description="API版本")
dict_level: Optional[str] = Field(None, description="词典等级")
class CreateAuditLogParam(AuditLogSchemaBase):
"""创建操作日志参数"""
class AuditLogHistorySchema(SchemaBase):
""" Audit Log History Schema for user history records """
image_id: Optional[str] = Field(None, description="图ID")
file_id: Optional[str] = Field(None, description="原图ID")
thumbnail_id: Optional[str] = Field(None, description="缩略图ID")
created_time: Optional[str] = Field(description="图片创建时间")
dict_level: Optional[str] = Field(None, description="词典等级")
class AuditLogStatisticsSchema(SchemaBase):
""" Audit Log Statistics Schema """
total_count: int = Field(description="历史总量")
today_count: int = Field(description="当天总量")
image_count: int = Field(description="图片总量")
class CreateDailySummaryParam(SchemaBase):
"""创建每日总结参数"""
user_id: int = Field(description="调用用户ID")
image_ids: Optional[List[str]] = Field(None, description="图ID")
thumbnail_ids: Optional[List[str]] = Field(None, description="缩略图ID")
summary_time: Optional[datetime] = Field(None, description="调用时间")
class DailySummarySchema(SchemaBase):
""" Daily Summary Schema """
# id: int = Field(description="记录ID")
# user_id: int = Field(description="用户ID")
image_ids: List[str] = Field(description="图片ID列表")
thumbnail_ids: List[str] = Field(description="图片缩略图列表")
summary_time: str = Field(description="创建时间")
class DailySummaryPageSchema(SchemaBase):
""" Daily Summary Page Schema """
items: List[DailySummarySchema] = Field(description="每日汇总记录列表")
total: int = Field(description="总记录数")
page: int = Field(description="当前页码")
size: int = Field(description="每页记录数")

View File

@@ -0,0 +1,10 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from pydantic import Field
from backend.common.schema import SchemaBase
class GetCaptchaDetail(SchemaBase):
image_type: str = Field(description='图片类型')
image: str = Field(description='图片内容')

198
backend/app/admin/schema/dict.py Executable file
View File

@@ -0,0 +1,198 @@
from pydantic import BaseModel, Field
from typing import Optional, Dict, Any, List, Union
# --- 辅助模型 ---
class CrossReference(BaseModel):
# 交叉引用的文本,通常是另一个词条名
text: Optional[str] = None
sense_id: Optional[str] = None
# 链接到另一个词条的 href
entry_href: Optional[str] = None # 例如: "entry://cooking-apple"
# 如果是图片相关的交叉引用,可能包含图片信息
# 图片的 ID (来自 a 标签的 showid 属性)
show_id: Optional[str] = None # 例如: "img3241"
# image_filename: 指向 dict_media 表中的 file_name
image_filename: Optional[str] = None # 例如: "img3241_ldoce4188jpg" (由 showid 和 base64 属性组合)
# 图片的标题/描述 (来自 a 标签的 title 属性)
image_title: Optional[str] = None # 例如: "apple from LDOCE 4"
# LDOCE 版本信息 (如果适用)
ldoce_version: Optional[str] = None # 例如: "LDOCEVERSION_5", "LDOCEVERSION_new"
class FamilyItem(BaseModel):
text: Optional[str] = None
# 链接词用 href 存储,非链接词 href 为 None
href: Optional[str] = None # 例如: "bword://underworld"
class WordFamily(BaseModel):
pos: Optional[str] = None # 词性,如 "noun", "adjective"
items: List[FamilyItem] = [] # 该词性下的相关词项
class Pronunciation(BaseModel):
# IPA 音标
uk_ipa: Optional[str] = None # 例如: "ˈæpəl"
us_ipa: Optional[str] = None # 例如: "ˈæpəl" 或 "wɜːrld"
# 音频文件路径 (相对或绝对)
uk_audio: Optional[str] = None # 例如: "/media/english/breProns/brelasdeapple.mp3"
us_audio: Optional[str] = None # 例如: "/media/english/ameProns/apple1.mp3"
# title
uk_audio_title: Optional[str] = None # 例如: "Play British pronunciation of plane"
us_audio_title: Optional[str] = None # 例如: "Play American pronunciation of plane"
class Frequency(BaseModel):
level: Optional[str] = None # 例如: "Core vocabulary: High-frequency"
spoken: Optional[str] = None # 例如: "Top 2000 spoken words"
written: Optional[str] = None # 例如: "Top 3000 written words"
level_tag: Optional[str] = None # 例如: "●●●"
spoken_tag: Optional[str] = None # 例如: "S2"
written_tag: Optional[str] = None # 例如: "W2"
# --- 核心模型 ---
class Example(BaseModel):
# 英文例句 (包含可能的高亮或链接,但这里简化为纯文本)
en: Optional[str] = None
# 中文翻译
cn: Optional[str] = None
# 例句中突出的搭配 (如 "a world of difference")
collocation: Optional[str] = None
# 例句中链接到的其他词条 (如 "wanting", "loving")
related_words_in_example: Optional[List[str]] = Field(default_factory=list)
# 例句音频文件路径
audio: Optional[str] = None # 例如: "/media/english/exaProns/p008-000499910.mp3"
class Definition(BaseModel):
# 英文定义
en: Optional[str] = None
# 中文定义
cn: Optional[str] = None
# 定义中链接到的其他词条 (来自 defRef 标签)
related_words: Optional[List[str]] = Field(default_factory=list)
class Sense(BaseModel):
# Sense 的唯一标识符 (来自 HTML id)
id: Optional[str] = None # 例如: "apple__1", "world__3"
# Sense 编号 (如果存在)
number: Optional[str] = None # 例如: "1", "2", "a)"
# Signpost (英文标签,如 "OUR PLANET/EVERYONE ON IT")
signpost_en: Optional[str] = None
# Signpost 中文翻译
signpost_cn: Optional[str] = None
ref_hwd: Optional[str] = None
# 语法信息 (如果该 Sense 特有)
grammar: Optional[str] = None # 可以是字符串或更复杂的结构,视情况而定
# 定义 (可能有中英对照)
definitions: Optional[List[Definition]] = Field(default_factory=list)
# 例子 (属于该 Sense)
examples: Optional[List[Example]] = Field(default_factory=list)
# 图片信息 (如果该 Sense 有图)
image: Optional[str] = None # 图片文件名或路径, 例如: "apple.jpg"
# Sense 内的交叉引用 (来自 Crossref 标签)
cross_references: Optional[List[CrossReference]] = Field(default_factory=list)
# 可数性
countability: Optional[List[str]] = Field(default_factory=list)
class Topic(BaseModel):
# Topic 名称
name: Optional[str] = None # 例如: "Food, dish", "Astronomy"
# 链接到 Topic 的 href
href: Optional[str] = None # 例如: "entry://Food, dish-topic food"
class DictEntry(BaseModel):
# 词条名 (标准化形式)
headword: Optional[str] = None # 例如: "apple", "world"
# 同形异义词编号 (如果存在)
homograph_number: Optional[int] = None # 例如: 1, 2 (对应 HOMNUM)
# 词性 (主要词性,或第一个 Sense 的词性)
part_of_speech: Optional[str] = None # 例如: "noun", "verb"
transitive: Optional[List[str]] = Field(default_factory=list) # 例如: "transitive"
# 发音
pronunciations: Optional[Pronunciation] = None
# 频率信息
frequency: Optional[Frequency] = None
# 相关话题 (Topics)
topics: Optional[List[Topic]] = Field(default_factory=list)
# 词族信息
word_family: Optional[List[WordFamily]] = None
# 词条级别的语法信息 (如果适用,不常见)
entry_grammar: Optional[str] = None
# 所有义项 (Sense)
senses: Optional[List[Sense]] = Field(default_factory=list)
class EtymologyItem(BaseModel):
language: Optional[str] = None
origin: Optional[str] = None
class Etymology(BaseModel):
intro: Optional[str] = None
headword: Optional[str] = None
hom_num: Optional[str] = None
item: Optional[List[EtymologyItem]] = None
class WordMetaData(BaseModel):
ref_link: Optional[List[str]] = None
dict_list: List[DictEntry] = Field(default_factory=list)
etymology: Optional[Etymology] = None
# 来源词典信息 (可选)
source_dict: Optional[str] = "Longman Dictionary of Contemporary English 5++" # 可根据需要调整
# --- 简化的查询响应模型 ---
class SimpleCrossReference(BaseModel):
"""简化的交叉引用模型"""
text: Optional[str] = None
show_id: Optional[str] = None
sense_id: Optional[str] = None
image_filename: Optional[str] = None
class SimpleExample(BaseModel):
"""简化的例句模型"""
cn: Optional[str] = None
en: Optional[str] = None
audio: Optional[str] = None
class SimpleDefinition(BaseModel):
"""简化的定义模型"""
cn: Optional[str] = None
en: Optional[str] = None
related_words: Optional[List[str]] = Field(default_factory=list)
class SimpleSense(BaseModel):
"""简化的义项模型"""
id: Optional[str] = None
image: Optional[str] = None
number: Optional[str] = None
grammar: Optional[str] = None
ref_hwd: Optional[str] = None
examples: Optional[List[SimpleExample]] = Field(default_factory=list)
definitions: Optional[List[SimpleDefinition]] = Field(default_factory=list)
signpost_cn: Optional[str] = None
signpost_en: Optional[str] = None
cross_references: Optional[List[SimpleCrossReference]] = Field(default_factory=list)
countability: Optional[List[str]] = Field(default_factory=list)
class SimpleFrequency(BaseModel):
"""简化的频率信息模型"""
level_tag: Optional[str] = None
spoken_tag: Optional[str] = None
written_tag: Optional[str] = None
class SimplePronunciation(BaseModel):
"""简化的发音模型"""
uk_ipa: Optional[str] = None
us_ipa: Optional[str] = None
uk_audio: Optional[str] = None
us_audio: Optional[str] = None
class SimpleDictEntry(BaseModel):
"""字典单词查询响应模型(缩减版 DictEntry"""
part_of_speech: Optional[str] = None # 例如: "noun", "verb"
transitive: Optional[List[str]] = Field(default_factory=list)
senses: Optional[List[SimpleSense]] = Field(default_factory=list)
frequency: Optional[SimpleFrequency] = None
pronunciations: Optional[SimplePronunciation] = None
class DictWordResponse(BaseModel):
dict_list: List[SimpleDictEntry] = Field(default_factory=list)
etymology: Optional[Etymology] = None

View File

@@ -0,0 +1,57 @@
from datetime import datetime
from enum import Enum
from typing import Optional, List
from pydantic import BaseModel
from backend.common.schema import SchemaBase
class FeedbackStatus(str, Enum):
PENDING = "pending"
PROCESSING = "processing"
RESOLVED = "resolved"
class FeedbackCategory(str, Enum):
BUG = "bug"
FEATURE = "feature"
COMPLIMENT = "compliment"
OTHER = "other"
class FeedbackSchemaBase(SchemaBase):
content: str
contact_info: Optional[str] = None
category: Optional[FeedbackCategory] = None
metadata_info: Optional[dict] = None
class CreateFeedbackParam(FeedbackSchemaBase):
"""创建反馈参数"""
pass
class UpdateFeedbackParam(FeedbackSchemaBase):
"""更新反馈参数"""
status: Optional[FeedbackStatus] = None
class FeedbackInfoSchema(FeedbackSchemaBase):
"""反馈信息Schema"""
id: int
user_id: int
status: FeedbackStatus
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class FeedbackUserInfoSchema(FeedbackInfoSchema):
"""包含用户信息的反馈Schema"""
user: Optional[dict] = None # 简化的用户信息
class Config:
from_attributes = True

View File

@@ -0,0 +1,52 @@
from pydantic import BaseModel
from typing import Optional, Dict, Any
from datetime import datetime
from sqlalchemy.sql.sqltypes import BigInteger
from backend.common.schema import SchemaBase
class FileMetadata(BaseModel):
"""文件元数据结构"""
file_name: Optional[str] = None
content_type: Optional[str] = None
file_size: int # 文件大小(字节)
user_agent: Optional[str] = None # 上传客户端信息
extra: Optional[Dict[str, Any]] = None # 其他自定义元数据
class FileSchemaBase(SchemaBase):
file_hash: str
file_name: str
content_type: Optional[str] = None
file_size: int
storage_type: str = "database"
storage_path: Optional[str] = None
metadata_info: Optional[FileMetadata] = None
class AddFileParam(FileSchemaBase):
"""添加文件参数"""
pass
class UpdateFileParam(FileSchemaBase):
"""更新文件参数"""
pass
class FileInfoSchema(FileSchemaBase):
"""文件信息Schema"""
id: int
created_at: datetime
updated_at: datetime
class FileUploadResponse(SchemaBase):
"""文件上传响应"""
id: str
file_hash: str
file_name: str
content_type: Optional[str] = None
file_size: int

View File

@@ -0,0 +1,45 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field
class PointsSchema(BaseModel):
"""积分账户信息"""
id: int
user_id: int
balance: int = Field(default=0, description="当前积分余额")
total_earned: int = Field(default=0, description="累计获得积分")
total_spent: int = Field(default=0, description="累计消费积分")
expired_time: datetime = Field(default_factory=datetime.now, description="过期时间")
class PointsLogSchema(BaseModel):
"""积分变动日志"""
id: int
user_id: int
action: str = Field(description="动作earn/spend")
amount: int = Field(description="变动数量")
balance_after: int = Field(description="变动后余额")
related_id: Optional[int] = Field(default=None, description="关联ID")
details: Optional[dict] = Field(default=None, description="附加信息")
created_at: datetime = Field(default_factory=datetime.now, description="创建时间")
class AddPointsRequest(BaseModel):
"""增加积分请求"""
user_id: int = Field(description="用户ID")
amount: int = Field(gt=0, description="增加的积分数量")
extend_expiration: bool = Field(default=False, description="是否自动延期过期时间")
related_id: Optional[int] = Field(default=None, description="关联ID")
details: Optional[dict] = Field(default=None, description="附加信息")
class DeductPointsRequest(BaseModel):
"""扣减积分请求"""
user_id: int = Field(description="用户ID")
amount: int = Field(gt=0, description="扣减的积分数量")
related_id: Optional[int] = Field(default=None, description="关联ID")
details: Optional[dict] = Field(default=None, description="附加信息")

View File

@@ -0,0 +1,25 @@
from sqlalchemy import Column, BigInteger, String, Text
from sqlalchemy.dialects.postgresql import JSONB
from pgvector.sqlalchemy import Vector
from sqlalchemy.types import TypeDecorator
from backend.utils.json_encoder import jsonable_encoder
class PydanticType(TypeDecorator):
"""处理 Pydantic 模型的 SQLAlchemy 自定义类型"""
impl = JSONB
def __init__(self, pydantic_type=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pydantic_type = pydantic_type
def process_bind_param(self, value, dialect):
if value is None:
return None
return jsonable_encoder(value)
def process_result_value(self, value, dialect):
if value is None or self.pydantic_type is None:
return value
return self.pydantic_type(**value)

View File

@@ -0,0 +1,25 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from pydantic import Field
from backend.common.schema import SchemaBase
class QwenParamBase(SchemaBase):
""" Qwen Base Params """
user_id: int = Field(description="user id")
image_id: int = Field(description="image id")
data: str = Field(description='Base64')
file_name: str = Field(description='文件名')
format: str = Field(description='图片后缀')
dict_level: str = Field(description='dict level')
class QwenEmbedImageParams(QwenParamBase):
""" Embedding Image Params """
class QwenRecognizeImageParams(QwenParamBase):
""" Recognize image Params """
type: str = Field(description='识别类型')
exclude_words: list[str] = Field(description='exclude words')

View File

@@ -0,0 +1,57 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from pydantic import Field
from backend.app.admin.schema.user import GetUserInfoDetail
from backend.common.enums import StatusType
from backend.common.schema import SchemaBase
class GetSwaggerToken(SchemaBase):
"""Swagger 认证令牌"""
access_token: str = Field(description='访问令牌')
token_type: str = Field('Bearer', description='令牌类型')
user: GetUserInfoDetail = Field(description='用户信息')
class AccessTokenBase(SchemaBase):
"""访问令牌基础模型"""
access_token: str = Field(description='访问令牌')
access_token_expire_time: datetime = Field(description='令牌过期时间')
session_uuid: str = Field(description='会话 UUID')
class GetNewToken(AccessTokenBase):
"""获取新令牌"""
class GetLoginToken(AccessTokenBase):
"""获取登录令牌"""
user: GetUserInfoDetail = Field(description='用户信息')
class GetWxLoginToken(AccessTokenBase):
"""微信登录令牌"""
dict_level: Optional[str] = Field(None, description="词典等级")
class GetTokenDetail(SchemaBase):
"""令牌详情"""
id: int = Field(description='用户 ID')
session_uuid: str = Field(description='会话 UUID')
username: str = Field(description='用户名')
nickname: str = Field(description='昵称')
ip: str = Field(description='IP 地址')
os: str = Field(description='操作系统')
browser: str = Field(description='浏览器')
device: str = Field(description='设备')
status: StatusType = Field(description='状态')
last_login_time: str = Field(description='最后登录时间')
expire_time: datetime = Field(description='过期时间')

View File

@@ -0,0 +1,12 @@
from pydantic import BaseModel, Field
from typing import Optional
class PurchaseRequest(BaseModel):
amount_cents: int = Field(..., ge=100, le=10000000, description="充值金额1元=100分")
class SubscriptionRequest(BaseModel):
plan: str = Field(..., pattern=r'^(monthly|quarterly|half_yearly|yearly)$', description="订阅计划")
class RefundRequest(BaseModel):
order_id: int = Field(..., gt=0, description="订单ID")
reason: Optional[str] = Field(None, max_length=200, description="退款原因")

View File

@@ -0,0 +1,48 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from datetime import datetime
from pydantic import Field, EmailStr, ConfigDict, HttpUrl
from backend.common.schema import SchemaBase, CustomPhoneNumber
class AuthSchemaBase(SchemaBase):
username: str = Field(description='用户名')
password: str = Field(description='密码')
class AuthLoginParam(AuthSchemaBase):
captcha: str = Field(description='验证码')
class RegisterUserParam(AuthSchemaBase):
email: EmailStr = Field(examples=['user@example.com'], description='邮箱')
class UpdateUserParam(SchemaBase):
username: str = Field(description='用户名')
email: EmailStr = Field(examples=['user@example.com'], description='邮箱')
phone: CustomPhoneNumber | None = Field(None, description='手机号')
class AvatarParam(SchemaBase):
url: HttpUrl = Field(..., description='头像 http 地址')
class GetUserInfoDetail(UpdateUserParam):
model_config = ConfigDict(from_attributes=True)
id: int = Field(description='用户 ID')
uuid: str = Field(description='用户 UUID')
avatar: str | None = Field(None, description='头像')
status: int = Field(description='状态')
is_superuser: bool = Field(description='是否超级管理员')
join_time: datetime = Field(description='加入时间')
last_login_time: datetime | None = Field(None, description='最后登录时间')
class ResetPassword(SchemaBase):
username: str = Field(description='用户名')
old_password: str = Field(description='旧密码')
new_password: str = Field(description='新密码')
confirm_password: str = Field(description='确认密码')

70
backend/app/admin/schema/wx.py Executable file
View File

@@ -0,0 +1,70 @@
# schemas.py
from enum import Enum
from pydantic import BaseModel, Field
from datetime import datetime
from typing import Optional
from backend.common.schema import SchemaBase
class WxUserBase(SchemaBase):
id: int = Field(description='User ID')
openid: str = Field(description="微信 user openid")
class GetWxUserInfoDetail(WxUserBase):
"""用户信息详情"""
class GetWxUserInfoWithRelationDetail(GetWxUserInfoDetail):
"""用户信息关联详情"""
class WxLoginRequest(BaseModel):
code: str = Field(..., description="微信登录code")
appid: Optional[str] = Field(None, description="微信 appid")
encrypted_data: Optional[str] = Field(None, description="加密的用户数据")
iv: Optional[str] = Field(None, description="加密算法的初始向量")
class TokenResponse(BaseModel):
access_token: str = Field(..., description="访问令牌")
refresh_token: str = Field(..., description="刷新令牌")
token_type: str = Field("bearer", description="令牌类型")
expires_in: int = Field(..., description="过期时间(秒)")
class UserBase(BaseModel):
id: int = Field(..., description="用户ID")
openid: str = Field(..., description="微信OpenID")
class UserInfo(UserBase):
mobile: Optional[str] = Field(None, description="手机号")
profile: Optional[dict] = Field({}, description="用户资料")
created_at: datetime = Field(..., description="创建时间")
class UserAuth(BaseModel):
access_token: str
refresh_token: str
expires_at: datetime
class DictLevel(str, Enum):
LEVEL1 = "LEVEL1" # "小学"
LEVEL2 = "LEVEL2" # "初高中"
LEVEL3 = "LEVEL3" # "四六级"
class UserSettings(BaseModel):
dict_level: Optional[DictLevel] = Field(None, description="词典等级")
class UpdateUserSettingsRequest(UserSettings):
pass
class GetUserSettingsResponse(UserSettings):
pass

View File

@@ -0,0 +1,4 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from backend.app.admin.service.points_service import PointsService

View File

@@ -0,0 +1,247 @@
from backend.app.admin.crud.user_account_crud import user_account_dao
from backend.app.admin.crud.usage_log_crud import usage_log_dao
from backend.database.db import async_db_session
from backend.common.exception import errors
from datetime import datetime, timedelta
from backend.common.log import log as logger
# 导入 Redis 客户端
from backend.database.redis import redis_client
class AdShareService:
# 每日限制配置
DAILY_AD_LIMIT = 5 # 每日广告观看限制
DAILY_SHARE_LIMIT = 3 # 每日分享限制
AD_REWARD_TIMES = 3 # 每次广告奖励次数
SHARE_REWARD_TIMES = 3 # 每次分享奖励次数
@staticmethod
def _get_redis_key(user_id: int, action_type: str, date_str: str = None) -> str:
"""
生成 Redis key
Args:
user_id: 用户ID
action_type: 动作类型 (ad/share)
date_str: 日期字符串 (YYYY-MM-DD)
Returns:
str: Redis key
"""
if date_str is None:
date_str = datetime.now().strftime('%Y-%m-%d')
return f"user:{user_id}:{action_type}:count:{date_str}"
@staticmethod
async def _check_daily_limit(user_id: int, action_type: str, limit: int) -> bool:
"""
检查每日限制
Args:
user_id: 用户ID
action_type: 动作类型
limit: 限制次数
Returns:
bool: 是否超过限制
"""
try:
# 确保 Redis 连接
await redis_client.ping()
# 获取今天的计数
key = AdShareService._get_redis_key(user_id, action_type)
current_count = await redis_client.get(key)
if current_count is None:
current_count = 0
else:
current_count = int(current_count)
return current_count >= limit
except Exception as e:
logger.error(f"检查每日限制失败: {str(e)}")
# Redis 异常时允许继续操作(避免服务中断)
return False
@staticmethod
async def _increment_daily_count(user_id: int, action_type: str) -> int:
"""
增加每日计数
Args:
user_id: 用户ID
action_type: 动作类型
Returns:
int: 当前计数
"""
try:
# 确保 Redis 连接
await redis_client.ping()
key = AdShareService._get_redis_key(user_id, action_type)
# 增加计数
current_count = await redis_client.incr(key)
# 设置过期时间(第二天自动清除)
tomorrow = datetime.now() + timedelta(days=1)
tomorrow_midnight = tomorrow.replace(hour=0, minute=0, second=0, microsecond=0)
expire_seconds = int((tomorrow_midnight - datetime.now()).total_seconds())
if expire_seconds > 0:
await redis_client.expire(key, expire_seconds)
return current_count
except Exception as e:
logger.error(f"增加每日计数失败: {str(e)}")
raise errors.ServerError(msg="系统繁忙,请稍后重试")
@staticmethod
async def _get_daily_count(user_id: int, action_type: str) -> int:
"""
获取今日计数
Args:
user_id: 用户ID
action_type: 动作类型
Returns:
int: 当前计数
"""
try:
# 确保 Redis 连接
await redis_client.ping()
key = AdShareService._get_redis_key(user_id, action_type)
count = await redis_client.get(key)
return int(count) if count is not None else 0
except Exception as e:
logger.error(f"获取每日计数失败: {str(e)}")
return 0
@staticmethod
async def grant_times_by_ad(user_id: int):
"""
通过观看广告获得次数
Args:
user_id: 用户ID
"""
# 检查每日限制
if await AdShareService._check_daily_limit(user_id, "ad", AdShareService.DAILY_AD_LIMIT):
raise errors.ForbiddenError(msg=f"今日广告观看次数已达上限({AdShareService.DAILY_AD_LIMIT}次)")
async with async_db_session.begin() as db:
try:
# 增加计数
current_count = await AdShareService._increment_daily_count(user_id, "ad")
# 增加用户余额
result = await user_account_dao.update_balance_atomic(db, user_id, AdShareService.AD_REWARD_TIMES)
if not result:
# 回滚计数
key = AdShareService._get_redis_key(user_id, "ad")
await redis_client.decr(key)
raise errors.ServerError(msg="账户更新失败")
# 记录使用日志
account = await user_account_dao.get_by_user_id(db, user_id)
await usage_log_dao.add(db, {
"user_id": user_id,
"action": "ad",
"amount": AdShareService.AD_REWARD_TIMES,
"balance_after": account.balance if account else AdShareService.AD_REWARD_TIMES,
"metadata_": {
"daily_count": current_count,
"max_limit": AdShareService.DAILY_AD_LIMIT
}
})
except Exception as e:
if not isinstance(e, errors.ForbiddenError):
logger.error(f"广告奖励处理失败: {str(e)}")
raise
@staticmethod
async def grant_times_by_share(user_id: int):
"""
通过分享获得次数
Args:
user_id: 用户ID
"""
# 检查每日限制
if await AdShareService._check_daily_limit(user_id, "share", AdShareService.DAILY_SHARE_LIMIT):
raise errors.ForbiddenError(msg=f"今日分享次数已达上限({AdShareService.DAILY_SHARE_LIMIT}次)")
async with async_db_session.begin() as db:
try:
# 增加计数
current_count = await AdShareService._increment_daily_count(user_id, "share")
# 增加用户余额
result = await user_account_dao.update_balance_atomic(db, user_id, AdShareService.SHARE_REWARD_TIMES)
if not result:
# 回滚计数
key = AdShareService._get_redis_key(user_id, "share")
await redis_client.decr(key)
raise errors.ServerError(msg="账户更新失败")
# 记录使用日志
account = await user_account_dao.get_by_user_id(db, user_id)
await usage_log_dao.add(db, {
"user_id": user_id,
"action": "share",
"amount": AdShareService.SHARE_REWARD_TIMES,
"balance_after": account.balance if account else AdShareService.SHARE_REWARD_TIMES,
"metadata_": {
"daily_count": current_count,
"max_limit": AdShareService.DAILY_SHARE_LIMIT
}
})
except Exception as e:
if not isinstance(e, errors.ForbiddenError):
logger.error(f"分享奖励处理失败: {str(e)}")
raise
@staticmethod
async def get_daily_stats(user_id: int) -> dict:
"""
获取用户今日统计信息
Args:
user_id: 用户ID
Returns:
dict: 统计信息
"""
try:
ad_count = await AdShareService._get_daily_count(user_id, "ad")
share_count = await AdShareService._get_daily_count(user_id, "share")
return {
"ad_count": ad_count,
"ad_limit": AdShareService.DAILY_AD_LIMIT,
"share_count": share_count,
"share_limit": AdShareService.DAILY_SHARE_LIMIT,
"can_watch_ad": ad_count < AdShareService.DAILY_AD_LIMIT,
"can_share": share_count < AdShareService.DAILY_SHARE_LIMIT
}
except Exception as e:
logger.error(f"获取每日统计失败: {str(e)}")
return {
"ad_count": 0,
"ad_limit": AdShareService.DAILY_AD_LIMIT,
"share_count": 0,
"share_limit": AdShareService.DAILY_SHARE_LIMIT,
"can_watch_ad": True,
"can_share": True
}

View File

@@ -0,0 +1,146 @@
from backend.app.admin.crud.audit_log_crud import audit_log_dao
from backend.app.admin.crud.daily_summary_crud import daily_summary_dao
from backend.app.admin.schema.audit_log import CreateAuditLogParam, AuditLogHistorySchema, DailySummarySchema
from backend.database.db import async_db_session
from typing import List, Tuple
from datetime import datetime, date
from collections import defaultdict
class AuditLogService:
""" Audit Log Service """
@staticmethod
async def create(*, obj: CreateAuditLogParam) -> None:
"""
创建操作日志
:param obj: 操作日志创建参数
:return:
"""
async with async_db_session.begin() as db:
await audit_log_dao.create(db, obj)
@staticmethod
async def get_user_recognition_history(
*,
user_id: int,
page: int = 1,
size: int = 20
) -> Tuple[List[AuditLogHistorySchema], int]:
"""
通过用户ID查询历史记录
:param user_id: 用户ID
:param page: 页码
:param size: 每页数量
:return: 历史记录列表和总数
"""
async with async_db_session() as db:
items, total = await audit_log_dao.get_user_recognition_history(db, user_id, page, size)
# 转换为 schema 对象
history_items = [
AuditLogHistorySchema(
image_id=str(item.id),
file_id=str(item.file_id),
thumbnail_id=str(item.thumbnail_id),
created_time=item.created_time.strftime("%Y-%m-%d %H:%M:%S") if item.created_time else None,
dict_level=item.dict_level
)
for item in items
]
return history_items, total
@staticmethod
async def get_user_recognition_statistics(*, user_id: int) -> Tuple[int, int, int]:
"""
统计用户 recognition 类型的使用记录
返回历史总量和当天总量
:param user_id: 用户ID
:return: (历史总量, 当天总量)
"""
async with async_db_session() as db:
total_count, today_count, image_count = await audit_log_dao.get_user_recognition_statistics(db, user_id)
return total_count, today_count, image_count
@staticmethod
async def get_user_daily_summaries(
*,
user_id: int,
page: int = 1,
size: int = 20
) -> Tuple[List[DailySummarySchema], int]:
"""
通过用户ID查询每日识别汇总记录按创建时间降序排列
:param user_id: 用户ID
:param page: 页码
:param size: 每页数量
:return: 每日汇总记录列表和总数
"""
async with async_db_session() as db:
items, total = await daily_summary_dao.get_user_daily_summaries(db, user_id, page, size)
# 转换为 schema 对象
summary_items = [
DailySummarySchema(
image_ids=item.image_ids,
thumbnail_ids=item.thumbnail_ids,
summary_time=str(item.summary_time.date())
)
for item in items
]
return summary_items, total
@staticmethod
async def get_user_today_summaries(
*,
user_id: int,
page: int = 1,
size: int = 20
) -> Tuple[List[DailySummarySchema], int]:
"""
获取用户当天的识别记录摘要
查询当天时间内AuditLog.api_type == 'recognition' 的所有记录,
获取相关的图片和缩略图,构成返回结构中的数据
:param user_id: 用户ID
:param page: 页码
:param size: 每页数量
:return: 当天识别记录摘要列表和总数
"""
async with async_db_session() as db:
# 获取当天的识别记录
items, total = await audit_log_dao.get_user_today_recognition_history(db, user_id, page, size)
# 如果没有记录,返回空列表
if not items:
return [], total
# 提取图片ID和缩略图ID
image_ids = []
thumbnail_ids = []
for item in items:
if item.id:
image_ids.append(str(item.id))
if item.thumbnail_id:
thumbnail_ids.append(str(item.thumbnail_id))
# 构建返回数据
# 由于是当天的数据,所有记录都属于同一天
today = date.today()
summary_item = DailySummarySchema(
image_ids=image_ids,
thumbnail_ids=thumbnail_ids,
summary_time=str(today)
)
return [summary_item], total
audit_log_service: AuditLogService = AuditLogService()

View File

@@ -0,0 +1,171 @@
# audit_service.py
from sqlalchemy import func, extract, case, Integer
from sqlalchemy.sql import label
from datetime import datetime, timedelta
from typing import List, Dict, Any
from backend.app.admin.model.audit_log import AuditLog
from backend.core.database import SessionLocal
from backend.common.log import log as logger
class AuditService:
@staticmethod
def get_audit_logs(page: int = 1, page_size: int = 20,
filters: Dict[str, Any] = None) -> Dict[str, Any]:
"""获取审计日志(分页)"""
with SessionLocal() as db:
query = db.query(AuditLog)
# 应用过滤条件
if filters:
if filters.get("api_type"):
query = query.filter(AuditLog.api_type == filters["api_type"])
if filters.get("model_name"):
query = query.filter(AuditLog.model_name == filters["model_name"])
if filters.get("status_code"):
query = query.filter(AuditLog.status_code == filters["status_code"])
if filters.get("start_date"):
query = query.filter(AuditLog.called_at >= filters["start_date"])
if filters.get("end_date"):
query = query.filter(AuditLog.called_at <= filters["end_date"])
if filters.get("user_id"):
query = query.filter(AuditLog.user_id == filters["user_id"])
# 分页处理
total = query.count()
logs = query.order_by(AuditLog.called_at.desc()) \
.offset((page - 1) * page_size) \
.limit(page_size) \
.all()
return {
"total": total,
"page": page,
"page_size": page_size,
"results": [log.to_dict() for log in logs]
}
@staticmethod
def get_usage_statistics(time_range: str = "daily") -> Dict[str, Any]:
"""获取API使用统计"""
with SessionLocal() as db:
# 根据时间范围确定分组方式
if time_range == "hourly":
time_group = func.date_trunc('hour', AuditLog.called_at)
elif time_range == "weekly":
time_group = func.date_trunc('week', AuditLog.called_at)
elif time_range == "monthly":
time_group = func.date_trunc('month', AuditLog.called_at)
else: # daily
time_group = func.date_trunc('day', AuditLog.called_at)
# 基本统计
stats = db.query(
time_group.label("time_period"),
AuditLog.api_type,
AuditLog.model_name,
func.count().label("request_count"),
func.sum(AuditLog.cost).label("total_cost"),
func.avg(AuditLog.duration).label("avg_duration"),
func.sum(
case(
(AuditLog.status_code >= 200, 1),
else_=0
)
).label("success_count"),
func.sum(
case(
(AuditLog.status_code >= 400, 1),
else_=0
)
).label("error_count")
).group_by(
"time_period",
AuditLog.api_type,
AuditLog.model_name
).order_by(
"time_period"
).all()
# Token使用统计
token_stats = db.query(
time_group.label("time_period"),
AuditLog.api_type,
AuditLog.model_name,
func.sum(AuditLog.token_usage['input_tokens'].astext.cast(Integer)).label("input_tokens"),
func.sum(AuditLog.token_usage['output_tokens'].astext.cast(Integer)).label("output_tokens"),
func.sum(AuditLog.token_usage['total_tokens'].astext.cast(Integer)).label("total_tokens")
).filter(
AuditLog.token_usage != None
).group_by(
"time_period",
AuditLog.api_type,
AuditLog.model_name
).order_by(
"time_period"
).all()
# 转换结果
return {
"usage_stats": [
{
"time_period": s.time_period,
"api_type": s.api_type,
"model_name": s.model_name,
"request_count": s.request_count,
"total_cost": float(s.total_cost) if s.total_cost else 0.0,
"avg_duration": float(s.avg_duration) if s.avg_duration else 0.0,
"success_rate": s.success_count / s.request_count if s.request_count else 0.0,
"error_rate": s.error_count / s.request_count if s.request_count else 0.0
} for s in stats
],
"token_stats": [
{
"time_period": t.time_period,
"api_type": t.api_type,
"model_name": t.model_name,
"input_tokens": t.input_tokens,
"output_tokens": t.output_tokens,
"total_tokens": t.total_tokens
} for t in token_stats
]
}
@staticmethod
def get_cost_forecast() -> Dict[str, Any]:
"""预测未来成本"""
with SessionLocal() as db:
# 获取最近30天的成本数据
end_date = datetime.utcnow()
start_date = end_date - timedelta(days=30)
daily_costs = db.query(
func.date_trunc('day', AuditLog.called_at).label("day"),
func.sum(AuditLog.cost).label("daily_cost")
).filter(
AuditLog.called_at >= start_date,
AuditLog.called_at <= end_date
).group_by("day").order_by("day").all()
# 简单预测模型 (移动平均)
costs = [float(d.daily_cost) if d.daily_cost else 0.0 for d in daily_costs]
avg_cost = sum(costs) / len(costs) if costs else 0.0
# 生成预测 (未来7天)
forecast = []
for i in range(1, 8):
forecast_date = end_date + timedelta(days=i)
forecast.append({
"date": forecast_date.date(),
"predicted_cost": avg_cost
})
return {
"historical": [
{"date": d.day.date(), "cost": float(d.daily_cost) if d.daily_cost else 0.0}
for d in daily_costs
],
"forecast": forecast
}

View File

@@ -0,0 +1,57 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from fastapi import Request
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from backend.app.admin.crud.wx_user_crud import wx_user_dao
from backend.app.admin.model import WxUser
from backend.app.admin.schema.token import GetLoginToken
from backend.app.admin.schema.user import AuthLoginParam
from backend.common.exception import errors
from backend.common.response.response_code import CustomErrorCode
from backend.common.security.jwt import password_verify, create_access_token
from backend.core.conf import settings
from backend.database.db import async_db_session
from backend.database.redis import redis_client
from backend.utils.timezone import timezone
class AuthService:
@staticmethod
async def user_verify(db: AsyncSession, username: str, password: str) -> WxUser:
user = await wx_user_dao.get_by_username(db, username)
if not user:
raise errors.NotFoundError(msg='用户名或密码有误')
elif not password_verify(password, user.password):
raise errors.AuthorizationError(msg='用户名或密码有误')
elif not user.status:
raise errors.AuthorizationError(msg='用户已被锁定, 请联系统管理员')
return user
async def swagger_login(self, *, form_data: OAuth2PasswordRequestForm) -> tuple[str, WxUser]:
async with async_db_session() as db:
user = await self.user_verify(db, form_data.username, form_data.password)
await wx_user_dao.update_login_time(db, user.username, login_time=timezone.now())
token = create_access_token(str(user.id))
return token, user
async def login(self, *, request: Request, obj: AuthLoginParam) -> GetLoginToken:
async with async_db_session() as db:
user = await self.user_verify(db, obj.username, obj.password)
try:
captcha_uuid = request.app.state.captcha_uuid
redis_code = await redis_client.get(f'{settings.CAPTCHA_LOGIN_REDIS_PREFIX}:{captcha_uuid}')
if not redis_code:
raise errors.ForbiddenError(msg='验证码失效,请重新获取')
except AttributeError:
raise errors.ForbiddenError(msg='验证码失效,请重新获取')
if redis_code.lower() != obj.captcha.lower():
raise errors.CustomError(error=CustomErrorCode.CAPTCHA_ERROR)
await wx_user_dao.update_login_time(db, user.username, login_time=timezone.now())
token = create_access_token(str(user.id))
data = GetLoginToken(access_token=token, user=user)
return data
auth_service: AuthService = AuthService()

View File

@@ -0,0 +1,148 @@
import random
import string
from typing import List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from backend.app.admin.crud.coupon_crud import coupon_dao, coupon_usage_dao
from backend.app.admin.model.coupon import Coupon
from backend.database.db import async_db_session
from backend.common.exception import errors
from datetime import datetime, timedelta
class CouponService:
@staticmethod
def generate_unique_code(length: int = 6) -> str:
"""
生成唯一的兑换码
"""
characters = string.ascii_uppercase + string.digits
# 移除容易混淆的字符
characters = characters.replace('0', '').replace('O', '').replace('I', '').replace('1')
while True:
code = ''.join(random.choice(characters) for _ in range(length))
# 确保生成的兑换码不包含敏感词汇或重复模式
if not CouponService._has_sensitive_pattern(code):
return code
@staticmethod
def _has_sensitive_pattern(code: str) -> bool:
"""
检查兑换码是否包含敏感模式
"""
# 简单的敏感词检查
sensitive_words = ['ADMIN', 'ROOT', 'USER', 'TEST']
for word in sensitive_words:
if word in code:
return True
# 检查是否为重复字符
if len(set(code)) == 1:
return True
return False
@staticmethod
async def create_coupon(duration: int, expires_days: Optional[int] = None) -> Coupon:
"""
创建单个兑换券
"""
async with async_db_session.begin() as db:
# 生成唯一兑换码
code = CouponService.generate_unique_code()
# 确保兑换码唯一性
while await coupon_dao.get_by_code(db, code):
code = CouponService.generate_unique_code()
# 设置过期时间
expires_at = None
if expires_days:
expires_at = datetime.now() + timedelta(days=expires_days)
coupon_data = {
'code': code,
'duration': duration,
'expires_at': expires_at
}
coupon = await coupon_dao.create_coupon(db, coupon_data)
return coupon
@staticmethod
async def batch_create_coupons(count: int, duration: int, expires_days: Optional[int] = None) -> List[Coupon]:
"""
批量创建兑换券
"""
async with async_db_session.begin() as db:
coupons_data = []
# 生成唯一兑换码列表
codes = set()
while len(codes) < count:
code = CouponService.generate_unique_code()
if code not in codes:
# 检查数据库中是否已存在该兑换码
existing_coupon = await coupon_dao.get_by_code(db, code)
if not existing_coupon:
codes.add(code)
# 设置过期时间
expires_at = None
if expires_days:
expires_at = datetime.now() + timedelta(days=expires_days)
# 准备数据
for code in codes:
coupons_data.append({
'code': code,
'duration': duration,
'expires_at': expires_at
})
coupons = await coupon_dao.create_coupons(db, coupons_data)
return coupons
@staticmethod
async def redeem_coupon(code: str, user_id: int) -> dict:
"""
兑换兑换券
"""
async with async_db_session.begin() as db:
# 获取未使用的兑换券
coupon = await coupon_dao.get_unused_coupon_by_code(db, code)
if not coupon:
raise errors.RequestError(msg='兑换码无效或已过期')
# 标记为已使用并创建使用记录
success = await coupon_dao.mark_as_used(db, coupon.id, user_id, coupon.duration)
if not success:
raise errors.ServerError(msg='兑换失败,请稍后重试')
return {
'code': coupon.code,
'duration': coupon.duration,
'used_at': datetime.now()
}
@staticmethod
async def get_user_coupon_history(user_id: int, limit: int = 100) -> List[dict]:
"""
获取用户兑换历史
"""
async with async_db_session() as db:
usages = await coupon_usage_dao.get_by_user_id(db, user_id, limit)
history = []
for usage in usages:
# 获取兑换券信息
coupon = await coupon_dao.get(db, usage.coupon_id)
if coupon:
history.append({
'code': coupon.code,
'duration': usage.duration,
'used_at': usage.used_at
})
return history

View File

@@ -0,0 +1,262 @@
import json
from typing import Optional
from backend.app.admin.crud.dict_crud import dict_dao
from backend.app.admin.model.dict import DictionaryEntry
from backend.app.admin.schema.dict import (
DictWordResponse, SimpleDictEntry, SimpleSense, SimpleDefinition, SimpleExample,
SimpleCrossReference, SimpleFrequency, SimplePronunciation, WordMetaData
)
from backend.common.exception import errors
from backend.database.db import async_db_session
from backend.database.redis import redis_client
class DictService:
# Redis缓存键前缀
WORD_LINK_CACHE_PREFIX = "dict:word_link:"
# 缓存过期时间24小时
CACHE_EXPIRE_TIME = 24 * 60 * 60
@staticmethod
def _convert_to_simple_response(entry: DictionaryEntry) -> DictWordResponse:
"""将完整的字典条目转换为简化的响应格式"""
if not entry.details:
# 如果没有详细信息,返回空的响应
return DictWordResponse()
details: WordMetaData = entry.details
# 转换字典条目列表
simple_dict_list = []
if details.dict_list:
for dict_entry in details.dict_list:
# 转换义项信息
simple_senses = []
if dict_entry.senses:
for sense in dict_entry.senses:
# 转换定义
simple_definitions = []
if sense.definitions:
for definition in sense.definitions:
simple_definitions.append(SimpleDefinition(
cn=definition.cn,
en=definition.en,
related_words=definition.related_words or []
))
# 转换例句
simple_examples = []
if sense.examples:
for example in sense.examples:
simple_examples.append(SimpleExample(
cn=example.cn,
en=example.en,
audio=example.audio
))
# 转换交叉引用
simple_cross_refs = []
if sense.cross_references:
for cross_ref in sense.cross_references:
simple_cross_refs.append(SimpleCrossReference(
text=cross_ref.text,
show_id=cross_ref.show_id,
sense_id=cross_ref.sense_id,
image_filename=cross_ref.image_filename
))
# 创建简化的义项
simple_sense = SimpleSense(
id=sense.id,
image=sense.image,
number=sense.number,
grammar=sense.grammar,
ref_hwd=sense.ref_hwd,
examples=simple_examples,
definitions=simple_definitions,
signpost_cn=sense.signpost_cn,
signpost_en=sense.signpost_en,
cross_references=simple_cross_refs,
countability = sense.countability or [],
)
simple_senses.append(simple_sense)
# 转换频率信息
simple_frequency = None
if dict_entry.frequency:
simple_frequency = SimpleFrequency(
level_tag=dict_entry.frequency.level_tag,
spoken_tag=dict_entry.frequency.spoken_tag,
written_tag=dict_entry.frequency.written_tag
)
# 转换发音信息
simple_pronunciations = None
if dict_entry.pronunciations:
simple_pronunciations = SimplePronunciation(
uk_ipa=dict_entry.pronunciations.uk_ipa,
us_ipa=dict_entry.pronunciations.us_ipa,
uk_audio=dict_entry.pronunciations.uk_audio,
us_audio=dict_entry.pronunciations.us_audio
)
# 创建简化的字典条目
simple_dict_entry = SimpleDictEntry(
part_of_speech=dict_entry.part_of_speech,
transitive=dict_entry.transitive,
senses=simple_senses,
frequency=simple_frequency,
pronunciations=simple_pronunciations
)
simple_dict_list.append(simple_dict_entry)
return DictWordResponse(
dict_list=simple_dict_list,
etymology=details.etymology
)
@staticmethod
async def _get_linked_word_from_cache(word: str) -> Optional[str]:
"""
从Redis缓存中获取单词的链接单词
返回None表示没有链接单词返回具体单词表示有链接单词
"""
cache_key = f"{DictService.WORD_LINK_CACHE_PREFIX}{word.lower()}"
try:
# 尝试从Redis获取缓存
cached_result = await redis_client.get(cache_key)
if cached_result is not None:
# 如果缓存的是"None"字符串,表示没有链接单词
if cached_result == "None":
return word
# 否则返回缓存的链接单词
return cached_result
return None
except Exception as e:
# 如果Redis出错不影响主流程
return None
@staticmethod
async def _set_linked_word_to_cache(word: str, linked_word: Optional[str]) -> None:
"""
将单词的链接单词存入Redis缓存
linked_word为None表示没有链接单词
"""
cache_key = f"{DictService.WORD_LINK_CACHE_PREFIX}{word.lower()}"
try:
# 如果没有链接单词,存储"None"字符串
value = linked_word if linked_word is not None else "None"
await redis_client.setex(cache_key, DictService.CACHE_EXPIRE_TIME, value)
except Exception as e:
# 如果Redis出错不影响主流程
pass
@staticmethod
async def _get_linked_word_from_db(word: str) -> Optional[str]:
"""
从数据库中获取单词的链接单词
"""
try:
async with async_db_session() as db:
entry = await dict_dao.get_by_word(db, word.strip().lower())
if not entry:
return None
# 检查 details 字段内的 json
if (entry.details and
not entry.details.dict_list and
entry.details.ref_link):
# dict_list 为空数组,检查 ref_link 的内容
ref_links = entry.details.ref_link
if ref_links:
# 获取第一个链接
first_link = ref_links[0]
# 检查是否包含 "LINK=" 前缀
if isinstance(first_link, str) and first_link.startswith("LINK="):
# 截取 "LINK=" 后的单词
referenced_word = first_link[5:] # 去掉 "LINK=" 前缀
if referenced_word:
return referenced_word
return word
except Exception as e:
# 如果出现任何错误返回None表示未找到链接单词
return None
@staticmethod
async def get_linked_word(word: str) -> str:
"""
获取单词的链接单词(如果有)
首先检查Redis缓存如果没有则查询数据库并更新缓存
"""
if not word or not word.strip():
return word
word = word.strip().lower()
# 首先尝试从缓存获取
linked_word = await DictService._get_linked_word_from_cache(word)
if linked_word is not None:
return linked_word
# 缓存未命中,从数据库查询
linked_word = await DictService._get_linked_word_from_db(word)
# 更新缓存
await DictService._set_linked_word_to_cache(word, linked_word)
# 返回结果,如果没有链接单词则返回原单词
return linked_word if linked_word is not None else word
@staticmethod
async def get_word_info(word: str) -> DictWordResponse:
"""根据单词获取字典信息"""
if not word or not word.strip():
raise errors.ForbiddenError(msg="单词不能为空")
word = word.strip().lower()
async with async_db_session() as db:
entry = await dict_dao.get_by_word(db, word)
if not entry:
raise errors.NotFoundError(msg=f"未找到单词 '{word}' 的释义")
# 使用新的缓存机制获取链接单词
linked_word = await DictService.get_linked_word(word)
if linked_word != word:
# 如果找到了引用的单词,返回其信息
referenced_entry = await dict_dao.get_by_word(db, linked_word)
if referenced_entry:
return DictService._convert_to_simple_response(referenced_entry)
return DictService._convert_to_simple_response(entry)
@staticmethod
async def check_word_exists(word: str) -> bool:
"""检查单词是否存在"""
if not word or not word.strip():
return False
word = word.strip().lower()
async with async_db_session() as db:
entry = await dict_dao.get_by_word(db, word)
return entry is not None
@staticmethod
async def get_audio_data(file_name: str) -> Optional[bytes]:
"""根据文件名获取音频数据"""
if not file_name or not file_name.strip():
return None
file_name = file_name.strip()
async with async_db_session() as db:
media = await dict_dao.get_media_by_filename(db, file_name)
if not media or media.file_type != 'audio':
return None
return media.file_data
dict_service = DictService()

View File

@@ -0,0 +1,66 @@
from typing import Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from backend.app.admin.crud.feedback_crud import feedback_dao
from backend.app.admin.model.feedback import Feedback
from backend.app.admin.schema.feedback import CreateFeedbackParam, UpdateFeedbackParam, FeedbackInfoSchema
from backend.database.db import async_db_session
class FeedbackService:
@staticmethod
async def create_feedback(user_id: int, obj_in: CreateFeedbackParam) -> FeedbackInfoSchema:
"""创建反馈"""
async with async_db_session.begin() as db:
feedback = await feedback_dao.create(db, user_id, obj_in)
return FeedbackInfoSchema.model_validate(feedback)
@staticmethod
async def get_feedback(feedback_id: int) -> Optional[FeedbackInfoSchema]:
"""获取反馈详情"""
async with async_db_session() as db:
feedback = await feedback_dao.get(db, feedback_id)
if feedback:
return FeedbackInfoSchema.model_validate(feedback)
return None
@staticmethod
async def get_feedback_list(
user_id: Optional[int] = None,
status: Optional[str] = None,
category: Optional[str] = None,
limit: Optional[int] = None,
offset: Optional[int] = None
) -> List[FeedbackInfoSchema]:
"""获取反馈列表"""
async with async_db_session() as db:
feedbacks = await feedback_dao.get_list(db, user_id, status, category, limit, offset)
return [FeedbackInfoSchema.model_validate(feedback) for feedback in feedbacks]
@staticmethod
async def update_feedback(feedback_id: int, obj_in: UpdateFeedbackParam) -> bool:
"""更新反馈"""
async with async_db_session.begin() as db:
count = await feedback_dao.update(db, feedback_id, obj_in)
return count > 0
@staticmethod
async def delete_feedback(feedback_id: int) -> bool:
"""删除反馈"""
async with async_db_session.begin() as db:
count = await feedback_dao.delete(db, feedback_id)
return count > 0
@staticmethod
async def count_feedbacks(
user_id: Optional[int] = None,
status: Optional[str] = None,
category: Optional[str] = None
) -> int:
"""统计反馈数量"""
async with async_db_session() as db:
return await feedback_dao.count(db, user_id, status, category)
feedback_service = FeedbackService()

View File

@@ -0,0 +1,419 @@
import io
import imghdr
from datetime import datetime
from typing import Optional, Dict, Any
from fastapi import UploadFile
from PIL import Image as PILImage, ExifTags
from backend.app.admin.crud.file_crud import file_dao
from backend.app.admin.model.file import File
from backend.app.admin.schema.file import AddFileParam, FileUploadResponse, UpdateFileParam, FileMetadata
from backend.app.ai.schema.image import ColorMode, ImageMetadata, ImageFormat
from backend.app.admin.service.file_storage import get_storage_provider, calculate_file_hash
from backend.common.exception import errors
from backend.core.conf import settings
from backend.database.db import async_db_session
class FileService:
@staticmethod
def is_image_file(content_type: str, file_content: bytes, file_name: str) -> bool:
"""判断是否为图片文件"""
# 首先检查文件扩展名是否在允许的图片类型列表中
if file_name:
file_ext = file_name.split('.')[-1].lower()
if file_ext in settings.UPLOAD_IMAGE_EXT_INCLUDE:
return True
# 然后检查content_type
if content_type and content_type.startswith('image/'):
return True
# 最后通过文件内容检测
image_format = imghdr.what(None, h=file_content)
return image_format is not None
@staticmethod
def validate_image_file(file_name: str) -> None:
"""验证图片文件类型是否被允许"""
if not file_name:
return
file_ext = file_name.split('.')[-1].lower()
if file_ext and file_ext not in settings.UPLOAD_IMAGE_EXT_INCLUDE:
raise errors.ForbiddenError(msg=f'[{file_ext}] 此图片格式暂不支持')
@staticmethod
async def upload_file(
file: UploadFile,
metadata: Optional[dict] = None
) -> FileUploadResponse:
"""上传文件"""
# 读取文件内容
content = await file.read()
await file.seek(0) # 重置文件指针
storage_type = settings.DEFAULT_STORAGE_TYPE
# 计算文件哈希
file_hash = calculate_file_hash(content)
# 检查文件是否已存在
async with async_db_session() as db:
existing_file = await file_dao.get_by_hash(db, file_hash)
if existing_file:
return FileUploadResponse(
id=str(existing_file.id),
file_hash=existing_file.file_hash,
file_name=existing_file.file_name,
content_type=existing_file.content_type,
file_size=existing_file.file_size
)
# 获取存储提供者
storage_provider = get_storage_provider(storage_type)
# 保存文件到存储
storage_path = None
file_data = None
if storage_type == "database":
# 数据库存储,将文件数据保存到数据库
file_data = content
else:
# 其他存储方式,保存文件并记录路径
storage_path = await storage_provider.save(
file_id=0, # 临时ID后续会更新
content=content,
file_name=file.filename or "unnamed"
)
# 创建文件记录
async with async_db_session.begin() as db:
# 创建文件元数据
file_metadata_dict = {
"file_name": file.filename,
"content_type": file.content_type,
"file_size": len(content),
"created_at": datetime.now(),
"updated_at": datetime.now(),
"extra": metadata
}
# 验证图片文件类型
if file.filename:
FileService.validate_image_file(file.filename)
# 如果是图片文件,提取图片元数据
if FileService.is_image_file(file.content_type or "", content, file.filename or ""):
try:
additional_info = {
"file_name": file.filename,
"content_type": file.content_type,
"file_size": len(content),
}
image_metadata = file_service.extract_image_metadata(content, additional_info)
file_metadata_dict["image_info"] = image_metadata.dict()
except Exception as e:
# 如果提取图片元数据失败,记录错误但不中断上传
file_metadata_dict["extra"] = {
**(metadata or {}),
"image_metadata_error": str(e)
}
file_metadata = FileMetadata(**file_metadata_dict)
# 创建文件参数
file_params = AddFileParam(
file_hash=file_hash,
file_name=file.filename or "unnamed",
content_type=file.content_type,
file_size=len(content),
storage_type=storage_type,
storage_path=storage_path,
metadata_info=file_metadata
)
# 保存到数据库
db_file = await file_dao.create(db, file_params)
# 如果是本地存储或其他存储需要更新文件ID
if storage_type != "database":
# 重新保存文件使用真实的文件ID
storage_path = await storage_provider.save(
file_id=db_file.id,
content=content,
file_name=file.filename or "unnamed"
)
# 更新数据库中的存储路径
update_params = UpdateFileParam(
storage_path=storage_path
)
await file_dao.update(db, db_file.id, update_params)
db_file.storage_path = storage_path
# 如果是数据库存储,更新文件数据
if storage_type == "database":
db_file.file_data = content
await db.flush() # 确保将file_data保存到数据库
return FileUploadResponse(
id=str(db_file.id),
file_hash=db_file.file_hash,
file_name=db_file.file_name,
content_type=db_file.content_type,
file_size=db_file.file_size
)
@staticmethod
async def upload_file_with_content_type(
file,
content_type: str,
metadata: Optional[dict] = None
) -> FileUploadResponse:
"""上传文件并指定content_type"""
# 读取文件内容
content = await file.read()
await file.seek(0) # 重置文件指针
storage_type = settings.DEFAULT_STORAGE_TYPE
# 计算文件哈希
file_hash = calculate_file_hash(content)
# 检查文件是否已存在
async with async_db_session() as db:
existing_file = await file_dao.get_by_hash(db, file_hash)
if existing_file:
return FileUploadResponse(
id=str(existing_file.id),
file_hash=existing_file.file_hash,
file_name=existing_file.file_name,
content_type=existing_file.content_type,
file_size=existing_file.file_size
)
# 获取存储提供者
storage_provider = get_storage_provider(storage_type)
# 保存文件到存储
storage_path = None
file_data = None
if storage_type == "database":
# 数据库存储,将文件数据保存到数据库
file_data = content
else:
# 其他存储方式,保存文件并记录路径
storage_path = await storage_provider.save(
file_id=0, # 临时ID后续会更新
content=content,
file_name=getattr(file, 'filename', 'unnamed') or "unnamed"
)
# 创建文件记录
async with async_db_session.begin() as db:
# 创建文件元数据
file_metadata_dict = {
"file_name": getattr(file, 'filename', 'unnamed'),
"content_type": content_type,
"file_size": len(content),
"created_at": datetime.now(),
"updated_at": datetime.now(),
"extra": metadata
}
# 验证图片文件类型
file_name = getattr(file, 'filename', 'unnamed')
if file_name:
FileService.validate_image_file(file_name)
# 如果是图片文件,提取图片元数据
if FileService.is_image_file(content_type or "", content, file_name or ""):
try:
additional_info = {
"file_name": file_name,
"content_type": content_type,
"file_size": len(content),
}
image_metadata = file_service.extract_image_metadata(content, additional_info)
file_metadata_dict["image_info"] = image_metadata.dict()
except Exception as e:
# 如果提取图片元数据失败,记录错误但不中断上传
file_metadata_dict["extra"] = {
**(metadata or {}),
"image_metadata_error": str(e)
}
file_metadata = FileMetadata(**file_metadata_dict)
# 创建文件参数
file_params = AddFileParam(
file_hash=file_hash,
file_name=file_name or "unnamed",
content_type=content_type,
file_size=len(content),
storage_type=storage_type,
storage_path=storage_path,
metadata_info=file_metadata
)
# 保存到数据库
db_file = await file_dao.create(db, file_params)
# 如果是本地存储或其他存储需要更新文件ID
if storage_type != "database":
# 重新保存文件使用真实的文件ID
storage_path = await storage_provider.save(
file_id=db_file.id,
content=content,
file_name=file_name or "unnamed"
)
# 更新数据库中的存储路径
update_params = UpdateFileParam(
storage_path=storage_path
)
await file_dao.update(db, db_file.id, update_params)
db_file.storage_path = storage_path
# 如果是数据库存储,更新文件数据
if storage_type == "database":
db_file.file_data = content
await db.flush() # 确保将file_data保存到数据库
return FileUploadResponse(
id=str(db_file.id),
file_hash=db_file.file_hash,
file_name=db_file.file_name,
content_type=db_file.content_type,
file_size=db_file.file_size
)
@staticmethod
async def get_file(file_id: int) -> Optional[File]:
"""获取文件信息"""
async with async_db_session() as db:
return await file_dao.get(db, file_id)
@staticmethod
async def download_file(file_id: int) -> tuple[bytes, str, str]:
"""下载文件"""
async with async_db_session() as db:
db_file = await file_dao.get(db, file_id)
if not db_file:
raise errors.NotFoundError(msg="文件不存在")
content = b""
storage_provider = get_storage_provider(db_file.storage_type)
if db_file.storage_type == "database":
# 从数据库获取文件数据
content = db_file.file_data or b""
else:
# 从存储中读取文件
content = await storage_provider.read(file_id, db_file.storage_path or "")
return content, db_file.file_name, db_file.content_type or "application/octet-stream"
@staticmethod
async def delete_file(file_id: int) -> bool:
"""删除文件"""
async with async_db_session.begin() as db:
db_file = await file_dao.get(db, file_id)
if not db_file:
return False
# 删除存储中的文件
if db_file.storage_type != "database":
storage_provider = get_storage_provider(db_file.storage_type)
await storage_provider.delete(file_id, db_file.storage_path or "")
# 删除数据库记录
result = await file_dao.delete(db, file_id)
return result > 0
@staticmethod
async def get_file_by_hash(db, file_hash: str) -> Optional[File]:
"""通过哈希值获取文件"""
return await file_dao.get_by_hash(db, file_hash)
@staticmethod
def detect_image_format(image_bytes: bytes) -> ImageFormat:
"""通过二进制数据检测图片格式"""
# 使用imghdr识别基础格式
format_str = imghdr.what(None, h=image_bytes)
# 映射到枚举类型
format_mapping = {
'jpeg': ImageFormat.JPEG,
'jpg': ImageFormat.JPEG,
'png': ImageFormat.PNG,
'gif': ImageFormat.GIF,
'bmp': ImageFormat.BMP,
'webp': ImageFormat.WEBP,
'tiff': ImageFormat.TIFF,
'svg': ImageFormat.SVG
}
return format_mapping.get(format_str, ImageFormat.UNKNOWN)
@staticmethod
def extract_image_metadata(image_bytes: bytes, additional_info: Dict[str, Any] = None) -> ImageMetadata:
"""从图片二进制数据中提取元数据"""
try:
with PILImage.open(io.BytesIO(image_bytes)) as img:
# 获取基础信息
width, height = img.size
color_mode = ColorMode(img.mode) if img.mode in ColorMode.__members__.values() else ColorMode.UNKNOWN
# 获取EXIF数据
exif_data = {}
if hasattr(img, '_getexif') and img._getexif():
for tag, value in img._getexif().items():
decoded_tag = ExifTags.TAGS.get(tag, tag)
# 特殊处理日期时间
if decoded_tag in ['DateTime', 'DateTimeOriginal', 'DateTimeDigitized']:
try:
value = datetime.strptime(value, "%Y:%m:%d %H:%M:%S").isoformat()
except:
pass
exif_data[decoded_tag] = value
# 获取颜色通道数
channels = len(img.getbands())
# 尝试获取DPI
dpi = img.info.get('dpi')
# 创建元数据对象
metadata = ImageMetadata(
format=file_service.detect_image_format(image_bytes),
width=width,
height=height,
color_mode=color_mode,
file_size=len(image_bytes),
channels=channels,
dpi=dpi,
exif=exif_data
)
# 添加额外信息
if additional_info:
for key, value in additional_info.items():
if hasattr(metadata, key):
setattr(metadata, key, value)
return metadata
except Exception as e:
# 无法解析图片时返回基础元数据
return ImageMetadata(
format=file_service.detect_image_format(image_bytes),
width=0,
height=0,
color_mode=ColorMode.UNKNOWN,
file_size=len(image_bytes),
error=f"Metadata extraction failed: {str(e)}"
)
file_service = FileService()

View File

@@ -0,0 +1,114 @@
import hashlib
import os
from abc import ABC, abstractmethod
from typing import Optional
import aiofiles
from fastapi import UploadFile
from backend.core.conf import settings
from backend.core.path_conf import UPLOAD_DIR
class StorageProvider(ABC):
"""存储提供者抽象基类"""
@abstractmethod
async def save(self, file_id: int, content: bytes, file_name: str) -> str:
"""保存文件"""
pass
@abstractmethod
async def read(self, file_id: int, storage_path: str) -> bytes:
"""读取文件"""
pass
@abstractmethod
async def delete(self, file_id: int, storage_path: str) -> bool:
"""删除文件"""
pass
class DatabaseStorage(StorageProvider):
"""数据库存储提供者"""
async def save(self, file_id: int, content: bytes, file_name: str) -> str:
"""数据库存储不需要实际保存文件,直接返回空字符串"""
return ""
async def read(self, file_id: int, storage_path: str) -> bytes:
"""数据库存储不需要读取文件"""
return b""
async def delete(self, file_id: int, storage_path: str) -> bool:
"""数据库存储不需要删除文件"""
return True
class LocalStorage(StorageProvider):
"""本地文件系统存储提供者"""
def __init__(self, base_path: str = settings.STORAGE_PATH):
self.base_path = base_path
def _get_path(self, file_id: int, file_name: str) -> str:
"""构建文件路径"""
# 使用文件ID作为目录名避免单个目录下文件过多
dir_name = str(file_id // 1000)
file_dir = os.path.join(self.base_path, dir_name)
os.makedirs(file_dir, exist_ok=True)
return os.path.join(file_dir, f"{file_id}_{file_name}")
async def save(self, file_id: int, content: bytes, file_name: str) -> str:
"""保存文件到本地"""
path = self._get_path(file_id, file_name)
async with aiofiles.open(path, 'wb') as f:
await f.write(content)
return path
async def read(self, file_id: int, storage_path: str) -> bytes:
"""从本地读取文件"""
async with aiofiles.open(storage_path, 'rb') as f:
return await f.read()
async def delete(self, file_id: int, storage_path: str) -> bool:
"""从本地删除文件"""
try:
os.remove(storage_path)
return True
except:
return False
class S3Storage(StorageProvider):
"""AWS S3存储提供者占位实现"""
async def save(self, file_id: int, content: bytes, file_name: str) -> str:
"""保存文件到S3"""
# 实际实现需要使用boto3等库连接到S3
# 这里仅作为示例展示接口
return f"s3://bucket-name/{file_id}/{file_name}"
async def read(self, file_id: int, storage_path: str) -> bytes:
"""从S3读取文件"""
# 实际实现需要使用boto3等库连接到S3
return b""
async def delete(self, file_id: int, storage_path: str) -> bool:
"""从S3删除文件"""
# 实际实现需要使用boto3等库连接到S3
return True
def get_storage_provider(provider_type: str) -> StorageProvider:
"""根据配置获取存储提供者"""
if provider_type == "local":
return LocalStorage()
elif provider_type == "s3":
return S3Storage()
else: # 默认使用数据库存储
return DatabaseStorage()
def calculate_file_hash(content: bytes) -> str:
"""计算文件的SHA256哈希值"""
return hashlib.sha256(content).hexdigest()

View File

@@ -0,0 +1,146 @@
from typing import List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from backend.app.admin.crud.notification_crud import notification_dao, user_notification_dao
from backend.app.admin.model.notification import Notification, UserNotification
from backend.database.db import async_db_session
from backend.common.exception import errors
from datetime import datetime
class NotificationService:
@staticmethod
async def create_notification(title: str, content: str, image_url: Optional[str] = None,
created_by: Optional[int] = None) -> Notification:
"""
创建消息通知
"""
async with async_db_session.begin() as db:
notification_data = {
'title': title,
'content': content,
'image_url': image_url,
'created_by': created_by
}
notification = await notification_dao.create_notification(db, notification_data)
return notification
@staticmethod
async def send_notification_to_user(notification_id: int, user_id: int) -> UserNotification:
"""
发送通知给指定用户
"""
async with async_db_session.begin() as db:
# 检查通知是否存在
notification = await notification_dao.get(db, notification_id)
if not notification:
raise errors.RequestError(msg='通知不存在')
# 创建用户通知关联
user_notification_data = {
'notification_id': notification_id,
'user_id': user_id
}
user_notification = await user_notification_dao.create_user_notification(db, user_notification_data)
return user_notification
@staticmethod
async def send_notification_to_users(notification_id: int, user_ids: List[int]) -> List[UserNotification]:
"""
批量发送通知给多个用户
"""
async with async_db_session.begin() as db:
# 检查通知是否存在
notification = await notification_dao.get(db, notification_id)
if not notification:
raise errors.RequestError(msg='通知不存在')
# 创建用户通知关联列表
user_notifications_data = [
{
'notification_id': notification_id,
'user_id': user_id
}
for user_id in user_ids
]
user_notifications = await user_notification_dao.create_user_notifications(db, user_notifications_data)
return user_notifications
@staticmethod
async def get_user_notifications(user_id: int, limit: int = 100) -> List[dict]:
"""
获取用户的通知列表(包含通知详情)
"""
async with async_db_session() as db:
user_notifications = await user_notification_dao.get_user_notifications(db, user_id, limit)
notifications = []
for user_notification in user_notifications:
# 获取通知详情
notification = await notification_dao.get(db, user_notification.notification_id)
if notification:
notifications.append({
'id': user_notification.id,
'notification_id': notification.id,
'title': notification.title,
'content': notification.content,
'image_url': notification.image_url,
'is_read': user_notification.is_read,
'received_at': user_notification.received_at,
'read_at': user_notification.read_at
})
return notifications
@staticmethod
async def get_unread_notifications(user_id: int, limit: int = 100) -> List[dict]:
"""
获取用户未读通知列表(包含通知详情)
"""
async with async_db_session() as db:
user_notifications = await user_notification_dao.get_unread_notifications(db, user_id, limit)
notifications = []
for user_notification in user_notifications:
# 获取通知详情
notification = await notification_dao.get(db, user_notification.notification_id)
if notification:
notifications.append({
'id': user_notification.id,
'notification_id': notification.id,
'title': notification.title,
'content': notification.content,
'image_url': notification.image_url,
'received_at': user_notification.received_at
})
return notifications
@staticmethod
async def mark_notification_as_read(user_notification_id: int, user_id: int) -> bool:
"""
标记通知为已读
"""
async with async_db_session.begin() as db:
# 验证用户权限
user_notification = await user_notification_dao.get_user_notification_by_id(db, user_notification_id)
if not user_notification:
raise errors.RequestError(msg='通知不存在')
if user_notification.user_id != user_id:
raise errors.ForbiddenError(msg='无权限操作此通知')
success = await user_notification_dao.mark_as_read(db, user_notification_id)
return success
@staticmethod
async def get_unread_count(user_id: int) -> int:
"""
获取用户未读通知数量
"""
async with async_db_session() as db:
count = await user_notification_dao.get_unread_count(db, user_id)
return count

View File

@@ -0,0 +1,180 @@
from typing import Optional
from backend.app.admin.crud.points_crud import points_dao, points_log_dao
from backend.app.admin.model.points import Points
from backend.database.db import async_db_session
class PointsService:
@staticmethod
async def get_user_points(user_id: int) -> Optional[Points]:
"""
获取用户积分账户信息(会检查并清空过期积分)
"""
async with async_db_session.begin() as db:
# 获取当前积分余额(清空前)
points_account_before = await points_dao.get_by_user_id(db, user_id)
balance_before = points_account_before.balance if points_account_before else 0
# 检查并清空过期积分
expired_cleared = await points_dao.check_and_clear_expired_points(db, user_id)
# 如果清空了过期积分,记录日志
if expired_cleared and balance_before > 0:
await points_log_dao.add_log(db, {
"user_id": user_id,
"action": "expire_clear",
"amount": balance_before, # 记录清空前的积分数量
"balance_after": 0,
"details": {"message": "过期积分已清空", "cleared_amount": balance_before}
})
return await points_dao.get_by_user_id(db, user_id)
@staticmethod
async def get_user_balance(user_id: int) -> int:
"""
获取用户积分余额(会检查并清空过期积分)
"""
async with async_db_session.begin() as db:
# 获取当前积分余额(清空前)
points_account_before = await points_dao.get_by_user_id(db, user_id)
balance_before = points_account_before.balance if points_account_before else 0
# 检查并清空过期积分
expired_cleared = await points_dao.check_and_clear_expired_points(db, user_id)
# 如果清空了过期积分,记录日志
if expired_cleared and balance_before > 0:
await points_log_dao.add_log(db, {
"user_id": user_id,
"action": "expire_clear",
"balance_before": balance_before,
"balance_after": 0,
"details": {"message": "过期积分已清空", "cleared_amount": balance_before}
})
return await points_dao.get_balance(db, user_id)
@staticmethod
async def add_points(user_id: int, amount: int, extend_expiration: bool = False, related_id: Optional[int] = None, details: Optional[dict] = None) -> bool:
"""
为用户增加积分
Args:
user_id: 用户ID
amount: 增加的积分数量
extend_expiration: 是否自动延期过期时间
related_id: 关联ID可选
details: 附加信息(可选)
Returns:
bool: 是否成功
"""
if amount <= 0:
raise ValueError("积分数量必须大于0")
async with async_db_session.begin() as db:
# 获取当前余额以记录日志
points_account = await points_dao.get_by_user_id(db, user_id)
if not points_account:
points_account = await points_dao.create_user_points(db, user_id)
current_balance = points_account.balance
# 原子性增加积分(可能延期过期时间)
result = await points_dao.add_points_atomic(db, user_id, amount, extend_expiration)
if not result:
return False
# 准备日志详情
log_details = details or {}
if extend_expiration:
log_details["expiration_extended"] = True
log_details["extension_days"] = 30
# 记录积分变动日志
new_balance = current_balance + amount
await points_log_dao.add_log(db, {
"user_id": user_id,
"action": "earn",
"amount": amount,
"balance_after": new_balance,
"related_id": related_id,
"details": log_details
})
return True
@staticmethod
async def deduct_points(user_id: int, amount: int, related_id: Optional[int] = None, details: Optional[dict] = None) -> bool:
"""
扣减用户积分(会检查并清空过期积分)
Args:
user_id: 用户ID
amount: 扣减的积分数量
related_id: 关联ID可选
details: 附加信息(可选)
Returns:
bool: 是否成功
"""
if amount <= 0:
raise ValueError("积分数量必须大于0")
async with async_db_session.begin() as db:
# 获取当前积分余额(清空前)
points_account_before = await points_dao.get_by_user_id(db, user_id)
if not points_account_before:
return False
balance_before = points_account_before.balance
# 检查并清空过期积分
expired_cleared = await points_dao.check_and_clear_expired_points(db, user_id)
# 如果清空了过期积分,记录日志
if expired_cleared and balance_before > 0:
await points_log_dao.add_log(db, {
"user_id": user_id,
"action": "expire_clear",
"amount": balance_before, # 记录清空前的积分数量
"balance_after": 0,
"details": {"message": "过期积分已清空", "cleared_amount": balance_before}
})
# 重新获取账户信息(可能已被清空)
points_account = await points_dao.get_by_user_id(db, user_id)
if not points_account or points_account.balance < amount:
return False
current_balance = points_account.balance
# 原子性扣减积分
result = await points_dao.deduct_points_atomic(db, user_id, amount)
if not result:
return False
# 记录积分变动日志
new_balance = current_balance - amount
await points_log_dao.add_log(db, {
"user_id": user_id,
"action": "spend",
"amount": amount,
"balance_after": new_balance,
"related_id": related_id,
"details": details
})
return True
@staticmethod
async def initialize_user_points(user_id: int) -> Points:
"""
为新用户初始化积分账户
"""
async with async_db_session.begin() as db:
points_account = await points_dao.get_by_user_id(db, user_id)
if not points_account:
points_account = await points_dao.create_user_points(db, user_id)
return points_account

View File

@@ -0,0 +1,233 @@
from backend.app.admin.crud.freeze_log_crud import freeze_log_dao
from backend.app.admin.crud.user_account_crud import user_account_dao
from backend.app.admin.crud.order_crud import order_dao
from backend.app.admin.crud.usage_log_crud import usage_log_dao
from backend.app.admin.model.order import FreezeLog
from backend.app.admin.model import UserAccount
from backend.database.db import async_db_session
from wechatpy.pay import WeChatPay
from backend.core.conf import settings
from backend.common.log import log as logger
from backend.common.exception import errors
from sqlalchemy import select, update
from datetime import datetime, timedelta
from fastapi import Request
from backend.utils.wx_pay import wx_pay_utils
wxpay = WeChatPay(
appid=settings.WX_APPID,
api_key=settings.WX_SECRET,
mch_id=settings.WX_MCH_ID,
mch_cert=settings.WX_PAY_CERT_PATH,
mch_key=settings.WX_PAY_KEY_PATH
)
class RefundService:
@staticmethod
async def freeze_times_for_refund(user_id: int, order_id: int, reason: str = None):
"""
为退款冻结次数(增强幂等性和安全性)
"""
async with async_db_session.begin() as db:
# 检查是否已有pending的退款申请幂等性
existing_stmt = select(FreezeLog).where(
FreezeLog.order_id == order_id,
FreezeLog.status == "pending"
)
existing_result = await db.execute(existing_stmt)
if existing_result.scalar_one_or_none():
raise errors.RequestError(msg="该订单已有待处理的退款申请")
order = await order_dao.get_by_id(db, order_id)
if not order:
raise errors.NotFoundError(msg="订单不存在")
if order.status != 'completed':
raise errors.RequestError(msg="订单状态异常,无法退款")
if order.user_id != user_id:
raise errors.ForbiddenError(msg="无权操作该订单")
# 检查退款时效7天内
if datetime.now() - order.created_at > timedelta(days=7):
raise errors.RequestError(msg="超过退款时效7天")
# 原子性冻结次数
result = await db.execute(
update(UserAccount)
.where(UserAccount.user_id == user_id)
.where(UserAccount.balance >= order.amount_times)
.values(balance=UserAccount.balance - order.amount_times)
)
if result.rowcount == 0:
raise errors.ForbiddenError(msg="余额不足或并发冲突")
# 创建冻结记录
freeze_log = FreezeLog(
user_id=user_id,
order_id=order_id,
amount=order.amount_times,
reason=reason or "用户申请退款",
status="pending"
)
await freeze_log_dao.add(db, freeze_log)
# 记录使用日志
account = await user_account_dao.get_by_user_id(db, user_id)
await usage_log_dao.add(db, {
"user_id": user_id,
"action": "freeze",
"amount": -order.amount_times,
"balance_after": account.balance,
"related_id": freeze_log.id,
"details": {"order_id": order_id, "reason": reason}
})
return freeze_log
@staticmethod
async def process_refund(user_id: int, order_id: int, refund_desc: str = "用户申请退款"):
"""
处理微信退款(增强安全性和幂等性)
"""
async with async_db_session() as db:
# 先冻结次数
freeze_log = await RefundService.freeze_times_for_refund(user_id, order_id, refund_desc)
try:
order = await order_dao.get_by_id(db, order_id)
# 调用微信退款接口
refund_result = wxpay.refund.apply(
out_trade_no=str(order_id),
out_refund_no=str(freeze_log.id), # 使用冻结记录ID作为退款单号
total_fee=order.amount_cents,
refund_fee=order.amount_cents,
refund_desc=refund_desc
)
if refund_result['return_code'] == 'SUCCESS' and refund_result['result_code'] == 'SUCCESS':
# 更新冻结记录状态
async with async_db_session.begin() as update_db:
freeze_log.status = "confirmed"
await freeze_log_dao.update(update_db, freeze_log.id, freeze_log)
return {"status": "success", "refund_id": refund_result.get('refund_id')}
else:
# 退款失败,取消冻结
await RefundService.cancel_freeze(freeze_log.id)
raise errors.ServerError(
msg=f"微信退款失败: {refund_result.get('err_code_des', '未知错误')}")
except Exception as e:
# 发生异常时也要取消冻结
await RefundService.cancel_freeze(freeze_log.id)
raise e
@staticmethod
async def cancel_freeze(freeze_id: int):
"""
取消冻结(退款失败时调用)
"""
async with async_db_session.begin() as db:
freeze_log = await freeze_log_dao.get_by_id(db, freeze_id)
if not freeze_log or freeze_log.status != "pending":
return
freeze_log.status = "cancelled"
await freeze_log_dao.update(db, freeze_log.id, freeze_log)
# 原子性恢复用户余额
result = await db.execute(
update(UserAccount)
.where(UserAccount.user_id == freeze_log.user_id)
.values(balance=UserAccount.balance + freeze_log.amount)
)
if result.rowcount == 0:
logger.error(f"恢复冻结次数失败: freeze_id={freeze_id}")
return
# 记录使用日志
account = await user_account_dao.get_by_user_id(db, freeze_log.user_id)
await usage_log_dao.add(db, {
"user_id": freeze_log.user_id,
"action": "unfreeze",
"amount": freeze_log.amount,
"balance_after": account.balance,
"related_id": freeze_log.id,
"details": {"reason": "退款失败,恢复次数"}
})
@staticmethod
async def handle_refund_notify(request: Request):
"""
处理微信退款回调(增强安全性和幂等性)
"""
try:
body = await request.body()
result = wx_pay_utils.parse_payment_result(body)
# 验证签名
if not wx_pay_utils.verify_wxpay_signature(result.copy(), settings.WECHAT_PAY_API_KEY):
logger.warning("微信退款回调签名验证失败")
return {"return_code": "FAIL", "return_msg": "签名验证失败"}
if result['return_code'] == 'SUCCESS':
out_refund_no = result['out_refund_no'] # 对应freeze_log.id
refund_status = result.get('refund_status_0', 'UNKNOWN') # 假设只有一个退款
async with async_db_session.begin() as db:
# 使用SELECT FOR UPDATE确保并发安全
stmt = select(FreezeLog).where(FreezeLog.id == int(out_refund_no)).with_for_update()
freeze_result = await db.execute(stmt)
freeze_log = freeze_result.scalar_one_or_none()
if not freeze_log:
logger.warning(f"冻结记录不存在: {out_refund_no}")
return {"return_code": "SUCCESS"}
# 幂等性检查
if freeze_log.status in ["confirmed", "cancelled"]:
logger.info(f"冻结记录已处理过: {out_refund_no}")
return {"return_code": "SUCCESS"}
if refund_status == 'SUCCESS':
freeze_log.status = "confirmed"
await freeze_log_dao.update(db, freeze_log.id, freeze_log)
# 从总购买次数中扣除
account_stmt = select(UserAccount).where(
UserAccount.user_id == freeze_log.user_id).with_for_update()
account_result = await db.execute(account_stmt)
user_account = account_result.scalar_one_or_none()
if user_account:
user_account.total_purchased = max(0, user_account.total_purchased - freeze_log.amount)
await user_account_dao.update(db, user_account.id, user_account)
# 记录使用日志
await usage_log_dao.add(db, {
"user_id": freeze_log.user_id,
"action": "refund",
"amount": -freeze_log.amount,
"balance_after": user_account.balance if user_account else 0,
"related_id": freeze_log.id,
"details": {"refund_id": result.get('refund_id_0')}
})
elif refund_status in ['FAIL', 'CHANGE']:
# 退款失败,取消冻结
await RefundService.cancel_freeze(freeze_log.id)
return {"return_code": "SUCCESS"}
else:
logger.error(f"微信退款回调失败: {result}")
return {"return_code": "FAIL", "return_msg": "处理失败"}
except Exception as e:
logger.error(f"微信退款回调处理异常: {str(e)}")
return {"return_code": "FAIL", "return_msg": "服务器异常"}

View File

@@ -0,0 +1,44 @@
import os
import aiofiles
from pathlib import Path
from abc import ABC, abstractmethod
from fastapi import UploadFile
from backend.core.conf import settings
from backend.core.path_conf import UPLOAD_DIR
class StorageProvider(ABC):
@abstractmethod
async def save(self, file_id: int, content: bytes):
pass
@abstractmethod
async def read(self, file_id: int) -> bytes:
pass
class LocalStorage(StorageProvider):
def __init__(self, base_path: str = UPLOAD_DIR):
self.base_path = base_path
def _get_path(self, file_id: int) -> str:
"""使用 os.path.join 构建文件路径"""
return os.path.join(self.base_path, str(file_id))
async def save(self, file_id: int, content: bytes):
path = self._get_path(file_id)
# 确保目录存在
os.makedirs(os.path.dirname(path), exist_ok=True)
async with aiofiles.open(path, 'wb') as f:
await f.write(content)
async def read(self, file_id: int) -> bytes:
path = self._get_path(file_id)
async with aiofiles.open(path, 'rb') as f:
return await f.read()
# 未来可添加S3存储
# class S3Storage(StorageProvider): ...

View File

@@ -0,0 +1,42 @@
from datetime import timedelta, datetime
from backend.app.admin.crud.usage_log_crud import usage_log_dao
from backend.app.admin.crud.user_account_crud import user_account_dao
from backend.app.admin.service.usage_service import UsageService
from backend.common.exception import errors
from backend.database.db import async_db_session
SUBSCRIPTION_PLANS = {
"monthly": {"price": 1290, "times": 300, "duration": timedelta(days=30)},
"quarterly": {"price": 2990, "times": 900, "duration": timedelta(days=90)},
"half_yearly": {"price": 3990, "times": 1800, "duration": timedelta(days=180)},
"yearly": {"price": 6990, "times": 3600, "duration": timedelta(days=365)},
}
class SubscriptionService:
@staticmethod
async def subscribe(user_id: int, plan_key: str):
plan = SUBSCRIPTION_PLANS.get(plan_key)
if not plan:
raise errors.RequestError(msg="无效订阅计划")
async with async_db_session.begin() as db:
account = await UsageService.get_user_account(user_id)
# 处理未用完次数
account.balance += account.carryover_balance
account.carryover_balance = 0
# 更新订阅信息
account.subscription_type = plan_key
account.subscription_expires_at = datetime.now() + plan["duration"]
account.balance += plan["times"]
await user_account_dao.update(db, account.id, account)
await usage_log_dao.add(db, {
"user_id": user_id,
"action": "renewal",
"amount": plan["times"],
"balance_after": account.balance,
"details": {"plan": plan_key}
})

View File

@@ -0,0 +1,214 @@
from decimal import Decimal, ROUND_HALF_UP
from backend.app.admin.crud.user_account_crud import user_account_dao
from backend.app.admin.crud.usage_log_crud import usage_log_dao
from backend.app.admin.crud.order_crud import order_dao
from backend.app.admin.model.order import UserAccount
from backend.app.admin.model.order import Order
from backend.app.admin.schema.usage import PurchaseRequest
from backend.common.exception import errors
from backend.database.db import async_db_session
from datetime import datetime, timedelta
from sqlalchemy import func, select, update
class UsageService:
@staticmethod
async def get_user_account(user_id: int) -> UserAccount:
async with async_db_session() as db:
account = await user_account_dao.get_by_user_id(db, user_id)
if not account:
account = UserAccount(user_id=user_id)
await user_account_dao.add(db, account)
return account
@staticmethod
def calculate_purchase_times_safe(amount_cents: int) -> int:
"""
安全的充值次数计算使用Decimal避免浮点数精度问题
"""
if amount_cents <= 0:
raise ValueError("充值金额必须大于0")
# 限制最大充值金额(防止溢出)
if amount_cents > 10000000: # 10万元
raise ValueError("单次充值金额不能超过10万元")
amount_yuan = Decimal(amount_cents) / Decimal(100)
base_times = amount_yuan * Decimal(10)
# 计算优惠比例每10元增加10%最多100%
tens = (amount_yuan // Decimal(10))
bonus_percent = min(tens * Decimal('0.1'), Decimal('1.0'))
total_times = base_times * (Decimal('1') + bonus_percent)
return int(total_times.quantize(Decimal('1'), rounding=ROUND_HALF_UP))
@staticmethod
async def purchase_times(user_id: int, request: PurchaseRequest):
# 输入验证
if request.amount_cents < 100: # 最少1元
raise errors.RequestError(msg="充值金额不能少于1元")
if request.amount_cents > 10000000: # 最多10万元
raise errors.RequestError(msg="单次充值金额不能超过10万元")
async with async_db_session.begin() as db:
account = await UsageService.get_user_account(user_id)
times = UsageService.calculate_purchase_times_safe(request.amount_cents)
order = Order(
user_id=user_id,
order_type="purchase",
amount_cents=request.amount_cents,
amount_times=times,
status="pending"
)
await order_dao.add(db, order)
await db.flush() # 获取order.id
# 原子性更新账户(防止并发问题)
result = await db.execute(
update(UserAccount)
.where(UserAccount.id == account.id)
.values(
balance=UserAccount.balance + times,
total_purchased=UserAccount.total_purchased + times
)
)
if result.rowcount == 0:
raise errors.ServerError(msg="账户更新失败")
# 更新订单状态
order.status = "completed"
order.processed_at = datetime.now()
await order_dao.update(db, order.id, order)
await usage_log_dao.add(db, {
"user_id": user_id,
"action": "purchase",
"amount": times,
"balance_after": account.balance + times,
"related_id": order.id,
"details": {"amount_cents": request.amount_cents}
})
@staticmethod
async def use_times_atomic(user_id: int, count: int = 1):
"""
原子性扣减次数,支持免费试用优先使用
"""
if count <= 0:
raise ValueError("扣减次数必须大于0")
async with async_db_session.begin() as db:
account = await UsageService.get_user_account(user_id)
# 检查免费试用是否有效
is_free_trial_valid = (
account.free_trial_expires_at and
account.free_trial_expires_at > datetime.now() and
account.free_trial_balance > 0
)
if is_free_trial_valid:
# 优先使用免费试用次数
free_trial_deduct = min(count, account.free_trial_balance)
remaining_count = count - free_trial_deduct
# 更新免费试用余额
if free_trial_deduct > 0:
await db.execute(
update(UserAccount)
.where(UserAccount.id == account.id)
.values(
free_trial_balance=UserAccount.free_trial_balance - free_trial_deduct,
balance=UserAccount.balance - free_trial_deduct
)
)
# 如果还有剩余次数,从普通余额中扣除
if remaining_count > 0:
result = await db.execute(
update(UserAccount)
.where(UserAccount.id == account.id)
.where(UserAccount.balance >= remaining_count)
.values(balance=UserAccount.balance - remaining_count)
)
if result.rowcount == 0:
# 恢复免费试用余额
await db.execute(
update(UserAccount)
.where(UserAccount.id == account.id)
.values(
free_trial_balance=UserAccount.free_trial_balance + free_trial_deduct,
balance=UserAccount.balance + free_trial_deduct
)
)
raise errors.ForbiddenError(msg="余额不足")
else:
# 直接从普通余额中扣除
result = await db.execute(
update(UserAccount)
.where(UserAccount.id == account.id)
.where(UserAccount.balance >= count)
.values(balance=UserAccount.balance - count)
)
if result.rowcount == 0:
raise errors.ForbiddenError(msg="余额不足")
# 记录使用日志
updated_account = await user_account_dao.get_by_id(db, account.id)
await usage_log_dao.add(db, {
"user_id": user_id,
"action": "use",
"amount": -count,
"balance_after": updated_account.balance if updated_account else 0,
"metadata_": {
"used_at": datetime.now().isoformat(),
"is_free_trial": is_free_trial_valid
}
})
@staticmethod
async def get_account_info(user_id: int) -> dict:
"""
获取用户账户详细信息
"""
async with async_db_session() as db:
account = await UsageService.get_user_account(user_id)
if not account:
return {}
frozen_balance = await user_account_dao.get_frozen_balance(db, user_id)
available_balance = max(0, account.balance - frozen_balance)
# 检查免费试用状态
is_free_trial_active = (
account.free_trial_expires_at and
account.free_trial_expires_at > datetime.now()
)
return {
"balance": account.balance,
"available_balance": available_balance,
"frozen_balance": frozen_balance,
"total_purchased": account.total_purchased,
"subscription_type": account.subscription_type,
"subscription_expires_at": account.subscription_expires_at,
"carryover_balance": account.carryover_balance,
# 免费试用信息
"free_trial_balance": account.free_trial_balance,
"free_trial_expires_at": account.free_trial_expires_at,
"free_trial_active": is_free_trial_active,
"free_trial_used": account.free_trial_used,
# 计算剩余天数
"free_trial_days_left": (
max(0, (account.free_trial_expires_at - datetime.now()).days)
if account.free_trial_expires_at else 0
)
}

View File

@@ -0,0 +1,128 @@
# services/wx.py
from sqlalchemy.orm import Session
from typing import Optional
from fastapi import Request, Response
from backend.app.admin.crud.wx_user_crud import wx_user_dao
from backend.app.admin.model.wx_user import WxUser
from backend.app.admin.schema.token import GetWxLoginToken
from backend.app.admin.schema.wx import DictLevel
from backend.common.security.jwt import create_access_token, create_refresh_token
from backend.core.conf import settings
import logging
from backend.app.admin.service.wx_user_service import WxUserService
from backend.core.wx_integration import decrypt_wx_data
from backend.database.db import async_db_session
from backend.utils.timezone import timezone
class WxAuthService:
@staticmethod
async def login(
*,
request: Request, response: Response,
openid: str, session_key: str,
encrypted_data: str = None,
iv: str = None
) -> GetWxLoginToken:
"""
处理用户登录逻辑:
1. 查找或创建用户
2. 更新用户session_key
3. 解密用户信息(如果提供)
4. 生成访问令牌
"""
async with async_db_session.begin() as db:
user = None
try:
# 查找或创建用户
user = await wx_user_dao.get_by_openid(db, openid)
if not user:
user = WxUser(
openid=openid,
session_key=session_key,
profile={'dict_level': DictLevel.LEVEL1.value},
)
await wx_user_dao.add(db, user)
await db.flush()
await db.refresh(user)
else:
await wx_user_dao.update_session_key(db, user.id, session_key)
# 解密用户信息(如果提供)
if encrypted_data and iv:
try:
decrypted_data = decrypt_wx_data(
encrypted_data,
session_key,
iv
)
WxUserService.update_user_profile(db, user.id, decrypted_data)
except Exception as e:
logging.warning(f"用户数据解密失败: {str(e)}")
# 生成访问令牌
access_token = await create_access_token(
user.id,
False,
# extra info
ip=request.client.host,
# os=request.state.os,
# browser=request.state.browser,
# device=request.state.device,
)
refresh_token = await create_refresh_token(access_token.session_uuid, user.id, False)
response.set_cookie(
key=settings.COOKIE_REFRESH_TOKEN_KEY,
value=refresh_token.refresh_token,
max_age=settings.COOKIE_REFRESH_TOKEN_EXPIRE_SECONDS,
expires=timezone.to_utc(refresh_token.refresh_token_expire_time),
httponly=True,
)
except Exception as e:
db.rollback()
logging.error(f"登录处理失败: {str(e)}")
raise
else:
# 从用户资料中获取词典等级设置
dict_level = None
if user and user.profile and isinstance(user.profile, dict):
dict_level = user.profile.get("dict_level")
data = GetWxLoginToken(
access_token=access_token.access_token,
access_token_expire_time=access_token.access_token_expire_time,
session_uuid=access_token.session_uuid,
dict_level=dict_level
)
return data
@staticmethod
async def update_user_settings(
*,
user_id: int,
dict_level: Optional[DictLevel] = None
) -> None:
"""
更新用户设置
"""
async with async_db_session.begin() as db:
user = await wx_user_dao.get(db, user_id)
if not user:
raise ValueError("用户不存在")
# 如果用户没有profile初始化为空字典
if not user.profile:
user.profile = {}
# 更新词典等级设置
if dict_level is not None:
user.profile["dict_level"] = dict_level.value
# 使用新的方法更新用户资料会自动更新updated_time字段
await wx_user_dao.update_user_profile(db, user_id, user.profile)
wx_service: WxAuthService = WxAuthService()

View File

@@ -0,0 +1,143 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from sqlalchemy import Select
from sqlalchemy.orm import Session
from backend.common.exception import errors
from backend.common.security.jwt import superuser_verify, password_verify, get_hash_password
from backend.app.admin.crud.wx_user_crud import wx_user_dao
from backend.database.db import async_db_session
from backend.app.admin.model import WxUser
from backend.app.admin.schema.user import RegisterUserParam, ResetPassword, UpdateUserParam, AvatarParam
class WxUserService:
@staticmethod
def update_user_profile(db: Session, user_id: int, profile_data: dict):
"""
更新用户资料
"""
user = db.query(WxUser).get(user_id)
if not user:
raise ValueError("用户不存在")
# 保存用户资料,保留现有设置
current_profile = user.profile or {}
# 更新微信用户信息
if "nickName" in profile_data:
current_profile["nickname"] = profile_data.get("nickName")
if "avatarUrl" in profile_data:
current_profile["avatar"] = profile_data.get("avatarUrl")
if "gender" in profile_data:
current_profile["gender"] = profile_data.get("gender")
if "city" in profile_data:
current_profile["city"] = profile_data.get("city")
if "province" in profile_data:
current_profile["province"] = profile_data.get("province")
if "country" in profile_data:
current_profile["country"] = profile_data.get("country")
user.profile = current_profile
db.commit()
return user
@staticmethod
def get_user_info(db: Session, user_id: int):
"""
获取用户信息(解密敏感数据)
"""
user = db.query(WxUser).get(user_id)
if not user:
return None
# 解密手机号
# encryptor = DataEncrypt()
# mobile = encryptor.decrypt(user.mobile) if user.mobile else None
return {
"id": user.id,
"openid": user.openid,
"mobile": user.mobile,
"profile": user.profile,
"created_at": user.created_at
}
# @staticmethod
# async def register(*, obj: RegisterUserParam) -> None:
# async with async_db_session.begin() as db:
# if not obj.password:
# raise errors.ForbiddenError(msg='密码为空')
# username = await user_dao.get_by_username(db, obj.username)
# if username:
# raise errors.ForbiddenError(msg='用户已注册')
# email = await user_dao.check_email(db, obj.email)
# if email:
# raise errors.ForbiddenError(msg='邮箱已注册')
# await user_dao.create(db, obj)
#
# @staticmethod
# async def pwd_reset(*, obj: ResetPassword) -> int:
# async with async_db_session.begin() as db:
# user = await user_dao.get_by_username(db, obj.username)
# if not password_verify(obj.old_password, user.password):
# raise errors.ForbiddenError(msg='原密码错误')
# np1 = obj.new_password
# np2 = obj.confirm_password
# if np1 != np2:
# raise errors.ForbiddenError(msg='密码输入不一致')
# new_pwd = get_hash_password(obj.new_password, user.salt)
# count = await user_dao.reset_password(db, user.id, new_pwd)
# return count
#
# @staticmethod
# async def get_userinfo(*, username: str) -> User:
# async with async_db_session() as db:
# user = await user_dao.get_by_username(db, username)
# if not user:
# raise errors.NotFoundError(msg='用户不存在')
# return user
#
# @staticmethod
# async def update(*, username: str, obj: UpdateUserParam) -> int:
# async with async_db_session.begin() as db:
# input_user = await user_dao.get_by_username(db, username=username)
# if not input_user:
# raise errors.NotFoundError(msg='用户不存在')
# superuser_verify(input_user)
# if input_user.username != obj.username:
# _username = await user_dao.get_by_username(db, obj.username)
# if _username:
# raise errors.ForbiddenError(msg='用户名已注册')
# if input_user.email != obj.email:
# email = await user_dao.check_email(db, obj.email)
# if email:
# raise errors.ForbiddenError(msg='邮箱已注册')
# count = await user_dao.update_userinfo(db, input_user.id, obj)
# return count
#
# @staticmethod
# async def update_avatar(*, username: str, avatar: AvatarParam) -> int:
# async with async_db_session.begin() as db:
# input_user = await user_dao.get_by_username(db, username)
# if not input_user:
# raise errors.NotFoundError(msg='用户不存在')
# count = await user_dao.update_avatar(db, input_user.id, avatar)
# return count
#
# @staticmethod
# async def get_select(*, username: str = None, phone: str = None, status: int = None) -> Select:
# return await user_dao.get_list(username=username, phone=phone, status=status)
#
# @staticmethod
# async def delete(*, current_user: User, username: str) -> int:
# async with async_db_session.begin() as db:
# superuser_verify(current_user)
# input_user = await user_dao.get_by_username(db, username)
# if not input_user:
# raise errors.NotFoundError(msg='用户不存在')
# count = await user_dao.delete(db, input_user.id)
# return count

View File

@@ -0,0 +1,37 @@
from wechatpy.pay import WeChatPay
from backend.app.admin.model.order import Order
from backend.app.admin.service.usage_service import UsageService
from backend.core.conf import settings
from backend.app.admin.crud.order_crud import order_dao
from backend.database.db import async_db_session
wxpay = WeChatPay(
appid=settings.WECHAT_APP_ID,
api_key=settings.WECHAT_PAY_API_KEY,
mch_id=settings.WECHAT_MCH_ID,
mch_cert=settings.WECHAT_PAY_CERT_PATH,
mch_key=settings.WECHAT_PAY_KEY_PATH
)
class WxPayService:
@staticmethod
async def create_order(user_id: int, amount_cents: int, description: str):
async with async_db_session.begin() as db:
order = Order(
user_id=user_id,
order_type="purchase",
amount_cents=amount_cents,
amount_times=UsageService.calculate_purchase_times(amount_cents),
status="pending"
)
await order_dao.add(db, order)
result = wxpay.order.create(
trade_type="JSAPI",
body=description,
total_fee=amount_cents,
notify_url=settings.WECHAT_NOTIFY_URL,
out_trade_no=str(order.id)
)
return result

158
backend/app/admin/tasks.py Executable file
View File

@@ -0,0 +1,158 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
定时任务处理模块
"""
from datetime import datetime, timedelta
from sqlalchemy import select, and_, desc
from backend.app.admin.model.audit_log import AuditLog, DailySummary
from backend.app.ai.model.image import Image
from backend.app.admin.model.wx_user import WxUser
from backend.app.admin.crud.daily_summary_crud import daily_summary_dao
from backend.app.admin.schema.audit_log import CreateDailySummaryParam
from backend.database.db import async_db_session
async def wx_user_index_history() -> None:
"""异步实现 wx_user_index_history 任务"""
# 计算前一天的时间范围
today = datetime.now().date()
yesterday = today - timedelta(days=1)
yesterday_start = datetime(yesterday.year, yesterday.month, yesterday.day)
yesterday_end = datetime(today.year, today.month, today.day)
async with async_db_session.begin() as db:
# 优化:通过 audit_log 表查询有相关记录的用户,避免遍历所有用户
# 先获取有前一天 recognition 记录的用户 ID 列表
user_ids_stmt = (
select(AuditLog.user_id)
.where(
and_(
AuditLog.api_type == 'recognition',
AuditLog.called_at >= yesterday_start,
AuditLog.called_at < yesterday_end
)
)
.distinct()
)
user_ids_result = await db.execute(user_ids_stmt)
user_ids = [row[0] for row in user_ids_result.fetchall()]
if not user_ids:
return # 没有用户有相关记录,直接返回
# 分批处理用户,避免一次性加载过多数据
batch_size = 500
for i in range(0, len(user_ids), batch_size):
batch_user_ids = user_ids[i:i + batch_size]
# 为这批用户获取 audit_log 记录
audit_logs_stmt = (
select(AuditLog)
.where(
and_(
AuditLog.user_id.in_(batch_user_ids),
AuditLog.api_type == 'recognition',
AuditLog.called_at >= yesterday_start,
AuditLog.called_at < yesterday_end
)
)
.order_by(AuditLog.user_id, desc(AuditLog.called_at))
)
audit_logs_result = await db.execute(audit_logs_stmt)
all_audit_logs = audit_logs_result.scalars().all()
# 按用户 ID 分组 audit_logs
user_audit_logs = {}
for log in all_audit_logs:
if log.user_id not in user_audit_logs:
user_audit_logs[log.user_id] = []
user_audit_logs[log.user_id].append(log)
# 获取这批用户的信息
users_stmt = select(WxUser).where(WxUser.id.in_(batch_user_ids))
users_result = await db.execute(users_stmt)
users = {user.id: user for user in users_result.scalars().all()}
# 获取所有相关的 image 记录
all_image_ids = [log.image_id for log in all_audit_logs if log.image_id]
image_map = {}
if all_image_ids:
images_stmt = select(Image).where(Image.id.in_(all_image_ids))
images_result = await db.execute(images_stmt)
images = images_result.scalars().all()
image_map = {img.id: img for img in images}
# 处理每个用户
for user_id, audit_logs in user_audit_logs.items():
user = users.get(user_id)
if not user:
continue
# 构建 ref_word 数据
image_ids = []
thumbnail_ids = []
for log in audit_logs:
if log.image_id and log.image_id in image_map:
image = image_map[log.image_id]
# 获取用户词典等级
dict_level = log.dict_level or "default"
# 从 image.details 中提取 ref_word
ref_words = []
try:
if image.details and isinstance(image.details, dict):
recognition_result = image.details.get("recognition_result", {})
if isinstance(recognition_result, dict):
dict_level_data = recognition_result.get(dict_level, {})
if isinstance(dict_level_data, dict):
ref_word = dict_level_data.get("ref_word")
if ref_word:
if isinstance(ref_word, list):
ref_words = ref_word
else:
ref_words = [str(ref_word)]
except Exception:
# 如果解析出错,跳过该记录
pass
# 收集图片ID和缩略图ID用于DailySummary
image_ids.append(str(log.image_id))
if image.thumbnail_id:
thumbnail_ids.append(str(image.thumbnail_id))
else:
thumbnail_ids.append('')
# 创建或更新DailySummary记录
daily_summary_data = {
"user_id": user_id,
"image_ids": image_ids,
"thumbnail_ids": thumbnail_ids,
"summary_time": yesterday_start
}
# 检查是否已存在该用户当天的记录
existing_summary_stmt = (
select(DailySummary)
.where(
and_(
DailySummary.user_id == user_id,
DailySummary.summary_time == yesterday_start
)
)
)
existing_summary_result = await db.execute(existing_summary_stmt)
existing_summary = existing_summary_result.scalar_one_or_none()
if existing_summary:
# 更新现有记录
await daily_summary_dao.update_model(db, existing_summary.id, daily_summary_data)
else:
# 创建新记录
await daily_summary_dao.create_model(db, CreateDailySummaryParam(**daily_summary_data))
# 提交批次更改
await db.commit()

View File

@@ -0,0 +1,4 @@
from backend.app.ai.model import Image, ImageText #, Article, ArticleParagraph, ArticleSentence
from backend.app.ai.schema import *
from backend.app.ai.crud import image_dao, image_text_dao #, article_dao, article_paragraph_dao, article_sentence_dao
from backend.app.ai.service import ImageService, image_service, ImageTextService, image_text_service #, ArticleService, article_service

View File

@@ -0,0 +1 @@
from backend.app.ai.api.image import router as image_router

View File

@@ -0,0 +1,342 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List
from fastapi import APIRouter, Request, Query
from starlette.background import BackgroundTasks
from backend.app.ai.schema.article import ArticleSchema, ArticleWithParagraphsSchema, CreateArticleParam, UpdateArticleParam, ArticleParagraphSchema, CreateArticleParagraphParam, UpdateArticleParagraphParam, ArticleSentenceSchema, CreateArticleSentenceParam, UpdateArticleSentenceParam
from backend.app.ai.service.article_service import article_service
from backend.common.response.response_schema import response_base, ResponseSchemaModel
from backend.common.security.jwt import DependsJwtAuth
router = APIRouter()
@router.post("", summary="创建文章", dependencies=[DependsJwtAuth])
async def create_article(
request: Request,
background_tasks: BackgroundTasks,
params: CreateArticleParam
) -> ResponseSchemaModel[ArticleSchema]:
"""
创建文章记录
请求体参数:
- title: 文章标题
- content: 文章完整内容
- author: 作者(可选)
- category: 分类(可选)
- level: 难度等级(可选)
- info: 附加信息(可选)
返回:
- 创建的文章记录
"""
article_id = await article_service.create_article(obj=params)
article = await article_service.get_article_by_id(article_id)
return response_base.success(data=article)
@router.get("/{article_id}", summary="获取文章详情", dependencies=[DependsJwtAuth])
async def get_article(
article_id: int
) -> ResponseSchemaModel[ArticleWithParagraphsSchema]:
"""
获取文章详情,包括所有段落和句子
参数:
- article_id: 文章ID
返回:
- 文章记录及其所有段落和句子
"""
article = await article_service.get_article_with_content(article_id)
if not article:
return response_base.fail(code=404, msg="文章不存在")
return response_base.success(data=article)
@router.put("/{article_id}", summary="更新文章", dependencies=[DependsJwtAuth])
async def update_article(
article_id: int,
request: Request,
background_tasks: BackgroundTasks,
params: UpdateArticleParam
) -> ResponseSchemaModel[ArticleSchema]:
"""
更新文章记录
参数:
- article_id: 文章ID
请求体参数:
- title: 文章标题
- content: 文章完整内容
- author: 作者(可选)
- category: 分类(可选)
- level: 难度等级(可选)
- info: 附加信息(可选)
返回:
- 更新后的文章记录
"""
success = await article_service.update_article(article_id, params)
if not success:
return response_base.fail(code=404, msg="文章不存在")
article = await article_service.get_article_by_id(article_id)
return response_base.success(data=article)
@router.delete("/{article_id}", summary="删除文章", dependencies=[DependsJwtAuth])
async def delete_article(
article_id: int,
request: Request,
background_tasks: BackgroundTasks
) -> ResponseSchemaModel[None]:
"""
删除文章记录
参数:
- article_id: 文章ID
返回:
- 无
"""
success = await article_service.delete_article(article_id)
if not success:
return response_base.fail(code=404, msg="文章不存在")
return response_base.success()
@router.post("/paragraph", summary="创建文章段落", dependencies=[DependsJwtAuth])
async def create_article_paragraph(
request: Request,
background_tasks: BackgroundTasks,
params: CreateArticleParagraphParam
) -> ResponseSchemaModel[ArticleParagraphSchema]:
"""
创建文章段落记录
请求体参数:
- article_id: 关联的文章ID
- paragraph_index: 段落序号
- content: 段落内容
- standard_audio_id: 标准朗读音频文件ID可选
- info: 附加信息(可选)
返回:
- 创建的段落记录
"""
paragraph_id = await article_service.create_article_paragraph(obj=params)
paragraph = await article_service.get_article_paragraph_by_id(paragraph_id)
return response_base.success(data=paragraph)
@router.put("/paragraph/{paragraph_id}", summary="更新文章段落", dependencies=[DependsJwtAuth])
async def update_article_paragraph(
paragraph_id: int,
request: Request,
background_tasks: BackgroundTasks,
params: UpdateArticleParagraphParam
) -> ResponseSchemaModel[ArticleParagraphSchema]:
"""
更新文章段落记录
参数:
- paragraph_id: 段落ID
请求体参数:
- article_id: 关联的文章ID
- paragraph_index: 段落序号
- content: 段落内容
- standard_audio_id: 标准朗读音频文件ID可选
- info: 附加信息(可选)
返回:
- 更新后的段落记录
"""
success = await article_service.update_article_paragraph(paragraph_id, params)
if not success:
return response_base.fail(code=404, msg="段落不存在")
paragraph = await article_service.get_article_paragraph_by_id(paragraph_id)
return response_base.success(data=paragraph)
@router.delete("/paragraph/{paragraph_id}", summary="删除文章段落", dependencies=[DependsJwtAuth])
async def delete_article_paragraph(
paragraph_id: int,
request: Request,
background_tasks: BackgroundTasks
) -> ResponseSchemaModel[None]:
"""
删除文章段落记录
参数:
- paragraph_id: 段落ID
返回:
- 无
"""
success = await article_service.delete_article_paragraph(paragraph_id)
if not success:
return response_base.fail(code=404, msg="段落不存在")
return response_base.success()
@router.get("/paragraph/{paragraph_id}", summary="获取段落详情", dependencies=[DependsJwtAuth])
async def get_article_paragraph(
paragraph_id: int
) -> ResponseSchemaModel[ArticleParagraphSchema]:
"""
获取文章段落详情
参数:
- paragraph_id: 段落ID
返回:
- 段落记录
"""
paragraph = await article_service.get_article_paragraph_by_id(paragraph_id)
if not paragraph:
return response_base.fail(code=404, msg="段落不存在")
return response_base.success(data=paragraph)
@router.get("/paragraph", summary="获取文章的所有段落", dependencies=[DependsJwtAuth])
async def get_article_paragraphs(
article_id: int = Query(..., description="文章ID")
) -> ResponseSchemaModel[List[ArticleParagraphSchema]]:
"""
获取指定文章的所有段落记录
参数:
- article_id: 文章ID
返回:
- 文章的所有段落记录列表
"""
paragraphs = await article_service.get_article_paragraphs_by_article_id(article_id)
return response_base.success(data=paragraphs)
@router.post("/sentence", summary="创建文章句子", dependencies=[DependsJwtAuth])
async def create_article_sentence(
request: Request,
background_tasks: BackgroundTasks,
params: CreateArticleSentenceParam
) -> ResponseSchemaModel[ArticleSentenceSchema]:
"""
创建文章句子记录
请求体参数:
- paragraph_id: 关联的段落ID
- sentence_index: 句子序号
- content: 句子内容
- standard_audio_id: 标准朗读音频文件ID可选
- info: 附加信息(可选)
返回:
- 创建的句子记录
"""
sentence_id = await article_service.create_article_sentence(obj=params)
sentence = await article_service.get_article_sentence_by_id(sentence_id)
return response_base.success(data=sentence)
@router.put("/sentence/{sentence_id}", summary="更新文章句子", dependencies=[DependsJwtAuth])
async def update_article_sentence(
sentence_id: int,
request: Request,
background_tasks: BackgroundTasks,
params: UpdateArticleSentenceParam
) -> ResponseSchemaModel[ArticleSentenceSchema]:
"""
更新文章句子记录
参数:
- sentence_id: 句子ID
请求体参数:
- paragraph_id: 关联的段落ID
- sentence_index: 句子序号
- content: 句子内容
- standard_audio_id: 标准朗读音频文件ID可选
- info: 附加信息(可选)
返回:
- 更新后的句子记录
"""
success = await article_service.update_article_sentence(sentence_id, params)
if not success:
return response_base.fail(code=404, msg="句子不存在")
sentence = await article_service.get_article_sentence_by_id(sentence_id)
return response_base.success(data=sentence)
@router.delete("/sentence/{sentence_id}", summary="删除文章句子", dependencies=[DependsJwtAuth])
async def delete_article_sentence(
sentence_id: int,
request: Request,
background_tasks: BackgroundTasks
) -> ResponseSchemaModel[None]:
"""
删除文章句子记录
参数:
- sentence_id: 句子ID
返回:
- 无
"""
success = await article_service.delete_article_sentence(sentence_id)
if not success:
return response_base.fail(code=404, msg="句子不存在")
return response_base.success()
@router.get("/sentence/{sentence_id}", summary="获取句子详情", dependencies=[DependsJwtAuth])
async def get_article_sentence(
sentence_id: int
) -> ResponseSchemaModel[ArticleSentenceSchema]:
"""
获取文章句子详情
参数:
- sentence_id: 句子ID
返回:
- 句子记录
"""
sentence = await article_service.get_article_sentence_by_id(sentence_id)
if not sentence:
return response_base.fail(code=404, msg="句子不存在")
return response_base.success(data=sentence)
@router.get("/sentence", summary="获取段落的所有句子", dependencies=[DependsJwtAuth])
async def get_article_sentences(
paragraph_id: int = Query(..., description="段落ID")
) -> ResponseSchemaModel[List[ArticleSentenceSchema]]:
"""
获取指定段落的所有句子记录
参数:
- paragraph_id: 段落ID
返回:
- 段落的所有句子记录列表
"""
sentences = await article_service.get_article_sentences_by_paragraph_id(paragraph_id)
return response_base.success(data=sentences)

113
backend/app/ai/api/image.py Executable file
View File

@@ -0,0 +1,113 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from datetime import datetime
from fastapi import APIRouter, UploadFile, HTTPException, Request, Response, Query
from fastapi.params import File, Depends
from starlette.background import BackgroundTasks
from backend.app.admin.schema.audit_log import CreateAuditLogParam
from backend.app.ai.schema.image import ImageRecognizeRes, ProcessImageRequest, ImageBean, ImageShowRes
from backend.app.admin.service.audit_log_service import audit_log_service
from backend.app.ai.service.image_service import ImageService, image_service
from backend.common.response.response_schema import response_base, ResponseSchemaModel
from backend.common.security.jwt import DependsJwtAuth
from backend.app.admin.schema.wx import DictLevel
router = APIRouter()
# @router.post("/upload", summary="上传图片进行识别", dependencies=[DependsJwtAuth])
# async def upload_image(
# request: Request, background_tasks: BackgroundTasks, file: UploadFile = File(...)
# ) -> ResponseSchemaModel[ImageRecognizeRes]:
# """
# 上传图片并调用通义千问API进行识别
#
# 返回:
# - 识别结果 (英文单词或类别)
# - 是否来自缓存
# - 图片ID
# """
# image_service.file_verify(file)
# result = await image_service.process_image_upload(file=file, request=request, background_tasks=background_tasks)
# return response_base.success(data=result)
@router.post("/recognize", summary="处理已上传的图片文件", dependencies=[DependsJwtAuth])
async def process_image(
request: Request,
background_tasks: BackgroundTasks,
params: ProcessImageRequest
) -> ResponseSchemaModel[ImageRecognizeRes]:
"""
处理已上传的图片文件并调用通义千问API进行识别
请求体参数:
- file_id: 已上传文件的ID
- type: 文件类型(默认为"image"
返回:
- 识别结果 (英文单词或类别)
- 是否来自缓存
- 图片ID
"""
result = await image_service.process_image_from_file(
params=params,
request=request,
background_tasks=background_tasks
)
return response_base.success(data=result)
@router.get("/{id}", dependencies=[DependsJwtAuth])
async def get_image(
request: Request,
id: int,
dict_level: DictLevel = Query(DictLevel.LEVEL1, description="词典等级")
) -> ResponseSchemaModel[ImageShowRes]:
image = await image_service.find_image(id)
# 检查details和recognition_result是否存在
if not image.details or "recognition_result" not in image.details:
raise HTTPException(status_code=404, detail="Recognition result not found")
# 根据dict_level获取对应的识别结果
recognition_result = image.details["recognition_result"]
if dict_level.value not in recognition_result:
raise HTTPException(status_code=404, detail=f"Recognition result for {dict_level.value} not found")
result = ImageShowRes(
id=image.id,
file_id=image.file_id,
res=recognition_result[dict_level.value]
)
return response_base.success(data=result)
@router.get("/log")
def log(request: Request, background_tasks: BackgroundTasks) -> ResponseSchemaModel[dict]:
audit_log = CreateAuditLogParam(
api_type="test_api",
model_name="test_model",
response_data={"test": "test_response"},
called_at=datetime.now(),
request_data={"test": "test_request"},
token_usage={"usage": "test_usage"},
cost=0.0,
duration=0.0,
status_code=200,
error_message="",
image_id=123123,
user_id=321312,
api_version="test_version",
)
background_tasks.add_task(
audit_log_service.create,
obj=audit_log,
)
return response_base.success(data={"succ": True})
# @router.get("/download/{image_id}")
# async def download_image(image_id: int):
# result = await image_service.download_image(id=image_id)
# return response_base.success(result)

View File

@@ -0,0 +1,161 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List
from fastapi import APIRouter, Request, Query
from starlette.background import BackgroundTasks
from backend.app.ai.schema.image_text import ImageTextSchema, ImageTextWithRecordingsSchema, CreateImageTextParam, UpdateImageTextParam, ImageTextInitResponseSchema, ImageTextInitParam
from backend.app.ai.service.image_text_service import image_text_service
from backend.app.ai.service.recording_service import recording_service
from backend.common.exception import errors
from backend.common.response.response_schema import response_base, ResponseSchemaModel
from backend.common.security.jwt import DependsJwtAuth
router = APIRouter()
@router.post("/init", summary="初始化图片文本", dependencies=[DependsJwtAuth])
async def init_image_texts(
request: Request,
background_tasks: BackgroundTasks,
params: ImageTextInitParam
) -> ResponseSchemaModel[ImageTextInitResponseSchema]:
"""
初始化图片文本记录
根据dict_level从image的recognition_result中提取文本如果不存在则创建如果已存在则直接返回
参数:
- image_id: 图片ID
- dict_level: 词典等级
返回:
- 图片文本记录列表
"""
result = await image_text_service.init_image_texts(request.user.id, params.image_id, params.dict_level, background_tasks)
return response_base.success(data=result)
@router.get("/standard/{text_id}", summary="获取标准音频文件ID", dependencies=[DependsJwtAuth])
async def get_standard_audio_file_id(
text_id: int,
) -> ResponseSchemaModel[dict]:
"""
根据文本ID获取标准音频的文件ID等待异步任务创建完成
参数:
- text_id: 图片文本ID
返回:
- 标准音频的文件ID
"""
file_id = await recording_service.get_standard_audio_file_id_by_text_id(text_id)
if not file_id:
raise errors.NotFoundError(msg="标准音频不存在或创建超时")
return response_base.success(data={'audio_id': str(file_id)})
@router.get("/{text_id}", summary="获取图片文本详情", dependencies=[DependsJwtAuth])
async def get_image_text(
text_id: int
) -> ResponseSchemaModel[ImageTextWithRecordingsSchema]:
"""
获取图片文本详情,包括关联的录音记录
参数:
- text_id: 图片文本ID
返回:
- 图片文本记录及其关联的录音记录
"""
text = await image_text_service.get_text_by_id(text_id)
if not text:
raise errors.NotFoundError(msg="图片文本不存在")
# 获取关联的录音记录
recordings = await recording_service.get_recordings_by_text_id(text_id)
# 构造返回数据
result = ImageTextWithRecordingsSchema(
id=text.id,
image_id=text.image_id,
content=text.content,
position=text.position,
dict_level=text.dict_level,
standard_audio_id=text.standard_audio_id,
info=text.info,
created_time=text.created_time,
recordings=recordings
)
return response_base.success(data=result)
@router.put("/{text_id}", summary="更新图片文本", dependencies=[DependsJwtAuth])
async def update_image_text(
text_id: int,
request: Request,
background_tasks: BackgroundTasks,
params: UpdateImageTextParam
) -> ResponseSchemaModel[ImageTextSchema]:
"""
更新图片文本记录
参数:
- text_id: 图片文本ID
请求体参数:
- image_id: 图片ID
- content: 文本内容
- position: 文本在图片中的位置信息(可选)
- dict_level: 词典等级(可选)
- standard_audio_id: 标准朗读音频文件ID可选
- info: 附加信息(可选)
返回:
- 更新后的图片文本记录
"""
success = await image_text_service.update_text(text_id, params)
if not success:
return response_base.fail(code=404, msg="图片文本不存在")
text = await image_text_service.get_text_by_id(text_id)
return response_base.success(data=text)
@router.delete("/{text_id}", summary="删除图片文本", dependencies=[DependsJwtAuth])
async def delete_image_text(
text_id: int,
request: Request,
background_tasks: BackgroundTasks
) -> ResponseSchemaModel[None]:
"""
删除图片文本记录
参数:
- text_id: 图片文本ID
返回:
- 无
"""
success = await image_text_service.delete_text(text_id)
if not success:
return response_base.fail(code=404, msg="图片文本不存在")
return response_base.success()
@router.get("", summary="获取图片的所有文本", dependencies=[DependsJwtAuth])
async def get_image_texts(
image_id: int = Query(..., description="图片ID")
) -> ResponseSchemaModel[List[ImageTextSchema]]:
"""
获取指定图片的所有文本记录
参数:
- image_id: 图片ID
返回:
- 图片的所有文本记录列表
"""
texts = await image_text_service.get_texts_by_image_id(image_id)
return response_base.success(data=texts)

View File

@@ -0,0 +1,105 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List
from fastapi import APIRouter, Request, Query
from starlette.background import BackgroundTasks
from backend.app.ai.schema.recording import RecordingAssessmentRequest, RecordingAssessmentResponse, ReadingProgressResponse
from backend.app.ai.service.recording_service import recording_service
from backend.app.ai.service.image_text_service import image_text_service
from backend.common.exception import errors
from backend.common.response.response_schema import response_base, ResponseSchemaModel
from backend.common.security.jwt import DependsJwtAuth
router = APIRouter()
@router.post("/assessment", summary="录音评估接口", dependencies=[DependsJwtAuth])
async def assess_recording(
request: Request,
background_tasks: BackgroundTasks,
params: RecordingAssessmentRequest
) -> ResponseSchemaModel[RecordingAssessmentResponse]:
"""
录音评估接口
接收文件ID和图片文本ID调用第三方API获取评估结果并存储到recording表的details字段
请求体参数:
- file_id: 录音文件ID
- image_text_id: 关联的图片文本ID
返回:
- file_id: 录音文件ID
- assessment_result: 评估结果
- image_id: 关联的图片ID
- image_text_id: 关联的图片文本ID
"""
# 获取图片文本记录以获取image_id用于响应
image_text = await image_text_service.get_text_by_id(params.image_text_id)
if not image_text:
raise errors.NotFoundError(msg=f"ImageText with id {params.image_text_id} not found")
# 调用录音服务进行评估
assessment_result = await recording_service.assess_recording(
params.file_id,
# 2087227590107594752,
image_text_id=params.image_text_id,
user_id=request.user.id,
background_tasks=background_tasks
)
# 返回结果
response_data = RecordingAssessmentResponse(
file_id=params.file_id,
assessment_result=assessment_result,
image_text_id=params.image_text_id
)
return response_base.success(data=response_data)
@router.get("/progress", summary="获取朗读进步统计", dependencies=[DependsJwtAuth])
async def get_reading_progress(
image_id: int = Query(..., description="图片ID"),
text: str = Query(..., description="朗读文本")
) -> ResponseSchemaModel[ReadingProgressResponse]:
"""
获取特定图片和文本的朗读进步统计
参数:
- image_id: 图片ID
- text: 朗读文本
返回:
- 包含进步统计信息的字典
"""
# 获取所有相关的录音记录
recordings = await recording_service.get_recordings_by_image_and_text(image_id, text)
# 计算进步统计
progress_stats = recording_service.calculate_progress(recordings)
return response_base.success(data=progress_stats)
@router.get("/progress-by-text-id", summary="根据文本ID获取朗读进步统计", dependencies=[DependsJwtAuth])
async def get_reading_progress_by_text_id(
image_text_id: int = Query(..., description="图片文本ID")
) -> ResponseSchemaModel[ReadingProgressResponse]:
"""
根据图片文本ID获取朗读进步统计
参数:
- image_text_id: 图片文本ID
返回:
- 包含进步统计信息的字典
"""
# 获取所有相关的录音记录
recordings = await recording_service.get_recordings_by_text_id(image_text_id)
# 计算进步统计
progress_stats = recording_service.calculate_progress(recordings)
return response_base.success(data=progress_stats)

View File

@@ -0,0 +1,16 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from fastapi import APIRouter
from backend.app.ai.api.image import router as image_router
from backend.app.ai.api.recording import router as recording_router
from backend.app.ai.api.image_text import router as image_text_router
from backend.app.ai.api.article import router as article_router
from backend.core.conf import settings
v1 = APIRouter(prefix=settings.FASTAPI_API_V1_PATH)
v1.include_router(image_router, prefix='/image', tags=['AI图片服务'])
v1.include_router(recording_router, prefix='/recording', tags=['AI录音服务'])
v1.include_router(image_text_router, prefix='/image_text', tags=['AI图片文本服务'])
# v1.include_router(article_router, prefix='/article', tags=['AI文章服务'])

View File

@@ -0,0 +1,3 @@
from backend.app.ai.crud.image_curd import image_dao
from backend.app.ai.crud.image_text_crud import image_text_dao
from backend.app.ai.crud.recording_crud import recording_dao

View File

@@ -0,0 +1,65 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Optional, List
from sqlalchemy import select, and_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus
from backend.app.ai.model.article import Article, ArticleParagraph, ArticleSentence
class CRUDArticle(CRUDPlus[Article]):
async def get_by_title(self, db: AsyncSession, title: str) -> Optional[Article]:
"""根据标题获取文章"""
stmt = select(self.model).where(self.model.title == title)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def get_articles_by_category(self, db: AsyncSession, category: str) -> List[Article]:
"""根据分类获取文章列表"""
stmt = select(self.model).where(self.model.category == category)
result = await db.execute(stmt)
return list(result.scalars().all())
class CRUDArticleParagraph(CRUDPlus[ArticleParagraph]):
async def get_by_article_id(self, db: AsyncSession, article_id: int) -> List[ArticleParagraph]:
"""根据文章ID获取所有段落按序号排序"""
stmt = select(self.model).where(self.model.article_id == article_id).order_by(self.model.paragraph_index)
result = await db.execute(stmt)
return list(result.scalars().all())
async def get_by_article_id_and_index(self, db: AsyncSession, article_id: int, paragraph_index: int) -> Optional[ArticleParagraph]:
"""根据文章ID和段落序号获取段落"""
stmt = select(self.model).where(
and_(
self.model.article_id == article_id,
self.model.paragraph_index == paragraph_index
)
)
result = await db.execute(stmt)
return result.scalar_one_or_none()
class CRUDArticleSentence(CRUDPlus[ArticleSentence]):
async def get_by_paragraph_id(self, db: AsyncSession, paragraph_id: int) -> List[ArticleSentence]:
"""根据段落ID获取所有句子按序号排序"""
stmt = select(self.model).where(self.model.paragraph_id == paragraph_id).order_by(self.model.sentence_index)
result = await db.execute(stmt)
return list(result.scalars().all())
async def get_by_paragraph_id_and_index(self, db: AsyncSession, paragraph_id: int, sentence_index: int) -> Optional[ArticleSentence]:
"""根据段落ID和句子序号获取句子"""
stmt = select(self.model).where(
and_(
self.model.paragraph_id == paragraph_id,
self.model.sentence_index == sentence_index
)
)
result = await db.execute(stmt)
return result.scalar_one_or_none()
article_dao: CRUDArticle = CRUDArticle(Article)
article_paragraph_dao: CRUDArticleParagraph = CRUDArticleParagraph(ArticleParagraph)
article_sentence_dao: CRUDArticleSentence = CRUDArticleSentence(ArticleSentence)

160
backend/app/ai/crud/image_curd.py Executable file
View File

@@ -0,0 +1,160 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List
import numpy as np
from sqlalchemy import select, update, desc, and_, func, Float
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import Select
from sqlalchemy_crud_plus import CRUDPlus
from backend.app.ai.model import Image
from backend.app.ai.schema.image import UpdateImageParam, AddImageParam
class ImageCRUD(CRUDPlus[Image]):
async def get(self, db: AsyncSession, id: int) -> Image | None:
return await self.select_model(db, id)
async def get_by_file_id(self, db: AsyncSession, file_id: int) -> Image | None:
return await self.select_model_by_column(db, file_id=file_id)
async def get_by_file_id_and_dict_level(self, db: AsyncSession, file_id: int, dict_level: str) -> Image | None:
"""
根据文件ID和词典等级获取图片记录
:param db: 数据库会话
:param file_id: 文件ID
:param dict_level: 词典等级
:return: 图片记录或None
"""
return await self.select_model_by_column(db, file_id=file_id, dict_level=dict_level)
async def get_images_by_file_id(self, db: AsyncSession, file_id: int) -> List[Image]:
"""
根据文件ID获取具有不同词典等级的图片记录
:param db: 数据库会话
:param file_id: 文件ID
:param dict_level: 当前词典等级
:return: 图片记录列表
"""
stmt = select(Image).where(Image.file_id == file_id)
result = await db.execute(stmt)
return result.scalars().all()
async def get_images_by_file_id_and_different_dict_level(self, db: AsyncSession, file_id: int, dict_level: str) -> \
List[Image]:
"""
根据文件ID获取具有不同词典等级的图片记录
:param db: 数据库会话
:param file_id: 文件ID
:param dict_level: 当前词典等级
:return: 图片记录列表
"""
stmt = select(Image).where(
and_(
Image.file_id == file_id,
Image.dict_level != dict_level
)
)
result = await db.execute(stmt)
return result.scalars().all()
async def get_images_by_ids(self, db: AsyncSession, image_ids: List[int]) -> List[Image]:
"""
根据ID列表获取图片记录
:param db: 数据库会话
:param image_ids: 图片ID列表
:return: 图片记录列表
"""
if not image_ids:
return []
stmt = select(Image).where(Image.id.in_(image_ids))
result = await db.execute(stmt)
return result.scalars().all()
async def find_similar_images_by_dict_level(self, db: AsyncSession, embedding: List[float], dict_level: str,
top_k: int = 3, threshold: float = 0.8) -> List[int]:
"""
根据向量和词典等级查找相似图片
参数:
db: 数据库会话
embedding: 1024维向量
dict_level: 词典等级
top_k: 返回最相似的K个结果
threshold: 相似度阈值 (0.0-1.0)
"""
# 确保向量是numpy数组
if not isinstance(embedding, np.ndarray):
embedding = np.array(embedding)
# 转换为 NumPy 数组提高性能
target_embedding = np.array(embedding, dtype=np.float32)
# 构建查询
cosine_distance_expr = Image.embedding.cosine_distance(target_embedding)
similarity_expr = (1 - func.cast(cosine_distance_expr, Float)).label("similarity")
# 构建查询
stmt = select(
Image.id,
# Image.info,
similarity_expr
).where(
and_(
Image.dict_level == dict_level,
similarity_expr >= threshold
)
).order_by(
cosine_distance_expr
).limit(top_k)
results = await db.execute(stmt)
id_list: List[int] = results.scalars().all()
return id_list
async def add(self, db: AsyncSession, new_image: Image) -> None:
db.add(new_image)
async def update(self, db: AsyncSession, id: int, obj: UpdateImageParam) -> int:
return await self.update_model(db, id, obj)
async def find_similar_image_ids(self, db: AsyncSession, embedding: List[float], top_k: int = 3,
threshold: float = 0.8) -> List[int]:
"""
直接通过向量查找相似图片
参数:
embedding: 1024维向量
top_k: 返回最相似的K个结果
threshold: 相似度阈值 (0.0-1.0)
"""
# 确保向量是numpy数组
if not isinstance(embedding, np.ndarray):
embedding = np.array(embedding)
# 转换为 NumPy 数组提高性能
target_embedding = np.array(embedding, dtype=np.float32)
# 构建查询
cosine_distance_expr = Image.embedding.cosine_distance(target_embedding)
similarity_expr = (1 - func.cast(cosine_distance_expr, Float)).label("similarity")
# 构建查询
stmt = select(
Image.id,
# Image.info,
similarity_expr
).where(
similarity_expr >= threshold
).order_by(
cosine_distance_expr
).limit(top_k)
results = await db.execute(stmt)
id_list: List[int] = results.scalars().all()
return id_list
image_dao: ImageCRUD = ImageCRUD(Image)

View File

@@ -0,0 +1,44 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Optional, List
from sqlalchemy import select, and_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus
from backend.app.ai.model.image_text import ImageText
class ImageTextCRUD(CRUDPlus[ImageText]):
async def get(self, db: AsyncSession, id: int) -> Optional[ImageText]:
"""根据ID获取文本记录"""
return await self.select_model(db, id)
async def get_by_image_id(self, db: AsyncSession, image_id: int) -> List[ImageText]:
"""根据图片ID获取所有文本"""
stmt = select(self.model).where(self.model.image_id == image_id)
result = await db.execute(stmt)
return list(result.scalars().all())
async def get_by_image_id_and_content(self, db: AsyncSession, image_id: int, content: str) -> Optional[ImageText]:
"""根据图片ID和文本内容获取文本记录"""
stmt = select(self.model).where(
and_(
self.model.image_id == image_id,
self.model.content == content
)
)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def get_by_standard_audio_id(self, db: AsyncSession, standard_audio_id: int) -> Optional[ImageText]:
"""根据标准音频文件ID获取文本记录"""
stmt = select(self.model).where(self.model.standard_audio_id == standard_audio_id)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def update(self, db: AsyncSession, id: int, obj_in: dict) -> int:
"""更新文本记录"""
return await self.update_model(db, id, obj_in)
image_text_dao: ImageTextCRUD = ImageTextCRUD(ImageText)

View File

@@ -0,0 +1,57 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Optional, List
from sqlalchemy import select, and_, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus
from backend.app.ai.model.recording import Recording
class RecordingCRUD(CRUDPlus[Recording]):
async def get(self, db: AsyncSession, id: int) -> Optional[Recording]:
"""根据ID获取录音记录"""
return await self.select_model(db, id)
async def get_by_file_id(self, db: AsyncSession, file_id: int) -> Optional[Recording]:
"""根据文件ID获取录音记录"""
stmt = select(self.model).where(self.model.file_id == file_id)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def get_by_text_id(self, db: AsyncSession, text_id: int) -> List[Recording]:
"""根据文本ID获取所有录音记录不包括标准音频"""
stmt = select(self.model).where(
and_(
self.model.image_text_id == text_id,
self.model.is_standard == False
)
).order_by(self.model.created_time.asc())
result = await db.execute(stmt)
return list(result.scalars().all())
async def get_latest_by_text_id(self, db: AsyncSession, text_id: int) -> Optional[Recording]:
"""根据文本ID获取最新的录音记录不包括标准音频"""
stmt = select(self.model).where(
and_(
self.model.image_text_id == text_id,
self.model.is_standard == False
)
).order_by(self.model.created_time.desc()).limit(1)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def get_standard_by_text_id(self, db: AsyncSession, text_id: int) -> Optional[Recording]:
"""根据文本ID获取标准音频记录"""
stmt = select(self.model).where(
and_(
self.model.image_text_id == text_id,
self.model.is_standard == True
)
).limit(1)
result = await db.execute(stmt)
return result.scalar_one_or_none()
recording_dao: RecordingCRUD = RecordingCRUD(Recording)

Some files were not shown because too many files have changed in this diff Show More