first commit
This commit is contained in:
9
.gitignore
vendored
Executable file
9
.gitignore
vendored
Executable file
@@ -0,0 +1,9 @@
|
|||||||
|
__pycache__/
|
||||||
|
.idea/
|
||||||
|
backend/.env
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
backend/alembic/versions/
|
||||||
|
*.log
|
||||||
|
.ruff_cache/
|
||||||
|
backend/static
|
||||||
29
.pre-commit-config.yaml
Executable file
29
.pre-commit-config.yaml
Executable file
@@ -0,0 +1,29 @@
|
|||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v5.0.0
|
||||||
|
hooks:
|
||||||
|
- id: check-added-large-files
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
- id: check-yaml
|
||||||
|
- id: check-toml
|
||||||
|
|
||||||
|
- repo: https://github.com/charliermarsh/ruff-pre-commit
|
||||||
|
rev: v0.11.4
|
||||||
|
hooks:
|
||||||
|
- id: ruff
|
||||||
|
args:
|
||||||
|
- '--config'
|
||||||
|
- '.ruff.toml'
|
||||||
|
- '--fix'
|
||||||
|
- '--unsafe-fixes'
|
||||||
|
- id: ruff-format
|
||||||
|
|
||||||
|
- repo: https://github.com/astral-sh/uv-pre-commit
|
||||||
|
rev: 0.6.14
|
||||||
|
hooks:
|
||||||
|
- id: uv-lock
|
||||||
|
- id: uv-export
|
||||||
|
args:
|
||||||
|
- '-o'
|
||||||
|
- 'requirements.txt'
|
||||||
|
- '--no-hashes'
|
||||||
35
.ruff.toml
Executable file
35
.ruff.toml
Executable file
@@ -0,0 +1,35 @@
|
|||||||
|
line-length = 120
|
||||||
|
unsafe-fixes = true
|
||||||
|
cache-dir = ".ruff_cache"
|
||||||
|
target-version = "py310"
|
||||||
|
|
||||||
|
[lint]
|
||||||
|
select = [
|
||||||
|
"E",
|
||||||
|
"F",
|
||||||
|
"W505",
|
||||||
|
"SIM101",
|
||||||
|
"SIM114",
|
||||||
|
"PGH004",
|
||||||
|
"PLE1142",
|
||||||
|
"RUF100",
|
||||||
|
"I002",
|
||||||
|
"F404",
|
||||||
|
"TC",
|
||||||
|
"UP007"
|
||||||
|
]
|
||||||
|
preview = true
|
||||||
|
|
||||||
|
[lint.isort]
|
||||||
|
lines-between-types = 1
|
||||||
|
order-by-type = true
|
||||||
|
|
||||||
|
[lint.per-file-ignores]
|
||||||
|
"**/api/v1/*.py" = ["TC"]
|
||||||
|
"**/model/*.py" = ["TC003"]
|
||||||
|
"**/model/__init__.py" = ["F401"]
|
||||||
|
|
||||||
|
[format]
|
||||||
|
preview = true
|
||||||
|
quote-style = "single"
|
||||||
|
docstring-code-format = true
|
||||||
30
Dockerfile
Executable file
30
Dockerfile
Executable file
@@ -0,0 +1,30 @@
|
|||||||
|
FROM python:3.10-slim
|
||||||
|
|
||||||
|
WORKDIR /fsm
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
RUN sed -i 's/deb.debian.org/mirrors.ustc.edu.cn/g' /etc/apt/sources.list.d/debian.sources \
|
||||||
|
&& sed -i 's|security.debian.org/debian-security|mirrors.ustc.edu.cn/debian-security|g' /etc/apt/sources.list.d/debian.sources
|
||||||
|
|
||||||
|
RUN apt-get update \
|
||||||
|
&& apt-get install -y --no-install-recommends gcc python3-dev supervisor \
|
||||||
|
&& rm -rf /var/lib/apt/lists/* \
|
||||||
|
# 某些包可能存在同步不及时导致安装失败的情况,可更改为官方源:https://pypi.org/simple
|
||||||
|
&& pip install --upgrade pip -i https://mirrors.aliyun.com/pypi/simple \
|
||||||
|
&& pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple \
|
||||||
|
&& pip install gunicorn wait-for-it -i https://mirrors.aliyun.com/pypi/simple
|
||||||
|
|
||||||
|
ENV TZ="Asia/Shanghai"
|
||||||
|
|
||||||
|
RUN mkdir -p /var/log/fastapi_server \
|
||||||
|
&& mkdir -p /var/log/supervisor \
|
||||||
|
&& mkdir -p /etc/supervisor/conf.d
|
||||||
|
|
||||||
|
COPY deploy/supervisor.conf /etc/supervisor/supervisord.conf
|
||||||
|
|
||||||
|
COPY deploy/fastapi_server.conf /etc/supervisor/conf.d/
|
||||||
|
|
||||||
|
EXPOSE 8001
|
||||||
|
|
||||||
|
CMD ["uvicorn", "backend.main:app", "--host", "0.0.0.0", "--port", "8001"]
|
||||||
1144
assets/dict/dictionary_parser.py
Executable file
1144
assets/dict/dictionary_parser.py
Executable file
File diff suppressed because it is too large
Load Diff
14
backend/.env.example
Executable file
14
backend/.env.example
Executable file
@@ -0,0 +1,14 @@
|
|||||||
|
# Env: dev、pro
|
||||||
|
ENVIRONMENT='dev'
|
||||||
|
# Database
|
||||||
|
DATABASE_HOST='127.0.0.1'
|
||||||
|
DATABASE_PORT=3306
|
||||||
|
DATABASE_USER='root'
|
||||||
|
DATABASE_PASSWORD='123456'
|
||||||
|
# Redis
|
||||||
|
REDIS_HOST='127.0.0.1'
|
||||||
|
REDIS_PORT=6379
|
||||||
|
REDIS_PASSWORD=''
|
||||||
|
REDIS_DATABASE=0
|
||||||
|
# Token
|
||||||
|
TOKEN_SECRET_KEY='1VkVF75nsNABBjK_7-qz7GtzNy3AMvktc9TCPwKczCk'
|
||||||
2
backend/__init__.py
Executable file
2
backend/__init__.py
Executable file
@@ -0,0 +1,2 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
100
backend/alembic.ini
Executable file
100
backend/alembic.ini
Executable file
@@ -0,0 +1,100 @@
|
|||||||
|
# A generic, single database configuration.
|
||||||
|
|
||||||
|
[alembic]
|
||||||
|
# path to migration scripts
|
||||||
|
script_location = alembic
|
||||||
|
|
||||||
|
# template used to generate migration files
|
||||||
|
file_template = %%(year)d-%%(month).2d-%%(day).2d_%%(hour).2d-%%(minute).2d_%%(rev)s_%%(slug)s
|
||||||
|
|
||||||
|
# sys.path path, will be prepended to sys.path if present.
|
||||||
|
# defaults to the current working directory.
|
||||||
|
prepend_sys_path = .
|
||||||
|
|
||||||
|
# timezone to use when rendering the date within the migration file
|
||||||
|
# as well as the filename.
|
||||||
|
# If specified, requires the python-dateutil library that can be
|
||||||
|
# installed by adding `alembic[tz]` to the pip requirements
|
||||||
|
# string value is passed to dateutil.tz.gettz()
|
||||||
|
# leave blank for localtime
|
||||||
|
# timezone =
|
||||||
|
|
||||||
|
# max length of characters to apply to the
|
||||||
|
# "slug" field
|
||||||
|
# truncate_slug_length = 40
|
||||||
|
|
||||||
|
# set to 'true' to run the environment during
|
||||||
|
# the 'revision' command, regardless of autogenerate
|
||||||
|
# revision_environment = false
|
||||||
|
|
||||||
|
# set to 'true' to allow .pyc and .pyo files without
|
||||||
|
# a source .py file to be detected as revisions in the
|
||||||
|
# versions/ directory
|
||||||
|
# sourceless = false
|
||||||
|
|
||||||
|
# version location specification; This defaults
|
||||||
|
# to alembic/versions. When using multiple version
|
||||||
|
# directories, initial revisions must be specified with --version-path.
|
||||||
|
# The path separator used here should be the separator specified by "version_path_separator"
|
||||||
|
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
|
||||||
|
|
||||||
|
# version path separator; As mentioned above, this is the character used to split
|
||||||
|
# version_locations. Valid values are:
|
||||||
|
#
|
||||||
|
# version_path_separator = :
|
||||||
|
# version_path_separator = ;
|
||||||
|
# version_path_separator = space
|
||||||
|
version_path_separator = os # default: use os.pathsep
|
||||||
|
|
||||||
|
# the output encoding used when revision files
|
||||||
|
# are written from script.py.mako
|
||||||
|
# output_encoding = utf-8
|
||||||
|
|
||||||
|
sqlalchemy.url = postgresql+asyncpg://root:root@127.0.0.1:5432/db
|
||||||
|
|
||||||
|
|
||||||
|
[post_write_hooks]
|
||||||
|
# post_write_hooks defines scripts or Python functions that are run
|
||||||
|
# on newly generated revision scripts. See the documentation for further
|
||||||
|
# detail and examples
|
||||||
|
|
||||||
|
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||||
|
# hooks = black
|
||||||
|
# black.type = console_scripts
|
||||||
|
# black.entrypoint = black
|
||||||
|
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||||
|
|
||||||
|
# Logging configuration
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARN
|
||||||
|
handlers = console
|
||||||
|
qualname =
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARN
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
1
backend/alembic/README
Executable file
1
backend/alembic/README
Executable file
@@ -0,0 +1 @@
|
|||||||
|
Generic single-database configuration with an async dbapi.
|
||||||
99
backend/alembic/env.py
Executable file
99
backend/alembic/env.py
Executable file
@@ -0,0 +1,99 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# ruff: noqa: E402
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
from sqlalchemy import engine_from_config, MetaData
|
||||||
|
from sqlalchemy import pool
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
|
||||||
|
sys.path.append('../')
|
||||||
|
|
||||||
|
from backend.core import path_conf
|
||||||
|
|
||||||
|
if not os.path.exists(path_conf.ALEMBIC_VERSION_DIR):
|
||||||
|
os.makedirs(path_conf.ALEMBIC_VERSION_DIR)
|
||||||
|
|
||||||
|
# this is the Alembic Config object, which provides
|
||||||
|
# access to the values within the .ini file in use.
|
||||||
|
config = context.config
|
||||||
|
|
||||||
|
# Interpret the config file for Python logging.
|
||||||
|
# This line sets up loggers basically.
|
||||||
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
|
# add your model's MetaData object here
|
||||||
|
# for 'autogenerate' support
|
||||||
|
# https://alembic.sqlalchemy.org/en/latest/autogenerate.html#autogenerating-multiple-metadata-collections
|
||||||
|
from backend.app.admin.model import MappedBase
|
||||||
|
|
||||||
|
target_metadata = [
|
||||||
|
MappedBase.metadata,
|
||||||
|
]
|
||||||
|
|
||||||
|
# other values from the config, defined by the needs of env.py,
|
||||||
|
from backend.database.db import SQLALCHEMY_DATABASE_URL
|
||||||
|
|
||||||
|
config.set_main_option('sqlalchemy.url', SQLALCHEMY_DATABASE_URL)
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline():
|
||||||
|
"""Run migrations in 'offline' mode.
|
||||||
|
|
||||||
|
This configures the context with just a URL
|
||||||
|
and not an Engine, though an Engine is acceptable
|
||||||
|
here as well. By skipping the Engine creation
|
||||||
|
we don't even need a DBAPI to be available.
|
||||||
|
|
||||||
|
Calls to context.execute() here emit the given string to the
|
||||||
|
script output.
|
||||||
|
|
||||||
|
"""
|
||||||
|
url = config.get_main_option('sqlalchemy.url')
|
||||||
|
context.configure(
|
||||||
|
url=url,
|
||||||
|
target_metadata=target_metadata, # type: ignore
|
||||||
|
literal_binds=True,
|
||||||
|
dialect_opts={'paramstyle': 'named'},
|
||||||
|
)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def do_run_migrations(connection):
|
||||||
|
context.configure(connection=connection, target_metadata=target_metadata) # type: ignore
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_migrations_online():
|
||||||
|
"""Run migrations in 'online' mode.
|
||||||
|
|
||||||
|
In this scenario we need to create an Engine
|
||||||
|
and associate a connection with the context.
|
||||||
|
|
||||||
|
"""
|
||||||
|
connectable = AsyncEngine(
|
||||||
|
engine_from_config(
|
||||||
|
config.get_section(config.config_ini_section),
|
||||||
|
prefix='sqlalchemy.',
|
||||||
|
poolclass=pool.NullPool,
|
||||||
|
future=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async with connectable.connect() as connection:
|
||||||
|
await connection.run_sync(do_run_migrations)
|
||||||
|
|
||||||
|
|
||||||
|
if context.is_offline_mode():
|
||||||
|
run_migrations_offline()
|
||||||
|
else:
|
||||||
|
asyncio.run(run_migrations_online())
|
||||||
0
backend/alembic/hooks.py
Executable file
0
backend/alembic/hooks.py
Executable file
24
backend/alembic/script.py.mako
Executable file
24
backend/alembic/script.py.mako
Executable file
@@ -0,0 +1,24 @@
|
|||||||
|
"""${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = ${repr(up_revision)}
|
||||||
|
down_revision = ${repr(down_revision)}
|
||||||
|
branch_labels = ${repr(branch_labels)}
|
||||||
|
depends_on = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
35
backend/app/__init__.py
Executable file
35
backend/app/__init__.py
Executable file
@@ -0,0 +1,35 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import os.path
|
||||||
|
|
||||||
|
from backend.core.path_conf import BASE_PATH
|
||||||
|
from backend.utils.import_parse import get_model_objects
|
||||||
|
|
||||||
|
|
||||||
|
def get_app_models() -> list[type]:
|
||||||
|
"""获取 app 所有模型类"""
|
||||||
|
app_path = os.path.join(BASE_PATH, 'app')
|
||||||
|
list_dirs = os.listdir(app_path)
|
||||||
|
|
||||||
|
apps = []
|
||||||
|
|
||||||
|
for d in list_dirs:
|
||||||
|
if os.path.isdir(os.path.join(app_path, d)) and d != '__pycache__':
|
||||||
|
apps.append(d)
|
||||||
|
|
||||||
|
objs = []
|
||||||
|
|
||||||
|
for app in apps:
|
||||||
|
module_path = f'backend.app.{app}.model'
|
||||||
|
obj = get_model_objects(module_path)
|
||||||
|
if obj:
|
||||||
|
objs.extend(obj)
|
||||||
|
|
||||||
|
return objs
|
||||||
|
|
||||||
|
|
||||||
|
# import all app models for auto create db tables
|
||||||
|
for cls in get_app_models():
|
||||||
|
class_name = cls.__name__
|
||||||
|
if class_name not in globals():
|
||||||
|
globals()[class_name] = cls
|
||||||
2
backend/app/admin/__init__.py
Executable file
2
backend/app/admin/__init__.py
Executable file
@@ -0,0 +1,2 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
2
backend/app/admin/api/__init__.py
Executable file
2
backend/app/admin/api/__init__.py
Executable file
@@ -0,0 +1,2 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
26
backend/app/admin/api/router.py
Executable file
26
backend/app/admin/api/router.py
Executable file
@@ -0,0 +1,26 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from backend.app.admin.api.v1.wx import router as wx_router
|
||||||
|
from backend.app.admin.api.v1.wxpay_callback import router as wx_pay_router
|
||||||
|
from backend.app.admin.api.v1.account import router as account_router
|
||||||
|
from backend.app.admin.api.v1.file import router as file_router
|
||||||
|
from backend.app.admin.api.v1.dict import router as dict_router
|
||||||
|
from backend.app.admin.api.v1.audit_log import router as audit_log_router
|
||||||
|
from backend.app.admin.api.v1.feedback import router as feedback_router
|
||||||
|
from backend.app.admin.api.v1.coupon import router as coupon_router
|
||||||
|
from backend.app.admin.api.v1.notification import router as notification_router
|
||||||
|
from backend.core.conf import settings
|
||||||
|
|
||||||
|
v1 = APIRouter(prefix=settings.FASTAPI_API_V1_PATH)
|
||||||
|
|
||||||
|
v1.include_router(account_router, prefix='/account', tags=['账户服务'])
|
||||||
|
v1.include_router(wx_pay_router, prefix="/wxpay", tags=["WeChat Pay Callback"])
|
||||||
|
v1.include_router(wx_router, prefix='/wx', tags=['微信服务'])
|
||||||
|
v1.include_router(file_router, prefix='/file', tags=['文件服务'])
|
||||||
|
v1.include_router(dict_router, prefix='/dict', tags=['字典服务'])
|
||||||
|
v1.include_router(audit_log_router, prefix='/audit', tags=['审计日志服务'])
|
||||||
|
v1.include_router(feedback_router, prefix='/feedback', tags=['反馈服务'])
|
||||||
|
v1.include_router(coupon_router, prefix='/coupon', tags=['兑换券服务'])
|
||||||
|
v1.include_router(notification_router, prefix='/notification', tags=['消息通知服务'])
|
||||||
2
backend/app/admin/api/v1/__init__.py
Executable file
2
backend/app/admin/api/v1/__init__.py
Executable file
@@ -0,0 +1,2 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
54
backend/app/admin/api/v1/account.py
Executable file
54
backend/app/admin/api/v1/account.py
Executable file
@@ -0,0 +1,54 @@
|
|||||||
|
# routers/account.py
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||||
|
from backend.app.admin.schema.usage import PurchaseRequest, SubscriptionRequest
|
||||||
|
from backend.app.admin.service.ad_share_service import AdShareService
|
||||||
|
from backend.app.admin.service.refund_service import RefundService
|
||||||
|
from backend.app.admin.service.subscription_service import SubscriptionService
|
||||||
|
from backend.app.admin.service.usage_service import UsageService
|
||||||
|
from backend.common.exception import errors
|
||||||
|
from backend.common.response.response_schema import response_base
|
||||||
|
from backend.common.security.jwt import DependsJwtAuth
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.post("/purchase", dependencies=[DependsJwtAuth])
|
||||||
|
async def purchase_times_api(
|
||||||
|
purchase_request: PurchaseRequest,
|
||||||
|
request: Request
|
||||||
|
):
|
||||||
|
await UsageService.purchase_times(request.user.id, purchase_request)
|
||||||
|
return response_base.success(data={"msg": "充值成功"})
|
||||||
|
|
||||||
|
@router.post("/subscribe", dependencies=[DependsJwtAuth])
|
||||||
|
async def subscribe_api(
|
||||||
|
sub_request: SubscriptionRequest,
|
||||||
|
request: Request
|
||||||
|
):
|
||||||
|
await SubscriptionService.subscribe(request.user.id, sub_request.plan)
|
||||||
|
return response_base.success(data={"msg": "订阅成功"})
|
||||||
|
|
||||||
|
@router.post("/ad", dependencies=[DependsJwtAuth])
|
||||||
|
async def ad_grant_api(request: Request):
|
||||||
|
await AdShareService.grant_times_by_ad(request.user.id)
|
||||||
|
return response_base.success(data={"msg": "已通过广告获得次数"})
|
||||||
|
|
||||||
|
@router.post("/share", dependencies=[DependsJwtAuth])
|
||||||
|
async def share_grant_api(request: Request):
|
||||||
|
await AdShareService.grant_times_by_share(request.user.id)
|
||||||
|
return response_base.success(data={"msg": "已通过分享获得次数"})
|
||||||
|
|
||||||
|
@router.post("/refund/{order_id}", dependencies=[DependsJwtAuth])
|
||||||
|
async def apply_refund(
|
||||||
|
order_id: int,
|
||||||
|
request: Request,
|
||||||
|
reason: str = "用户申请退款"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
申请退款
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await RefundService.process_refund(request.user.id, order_id, reason)
|
||||||
|
return response_base.success(data={"msg": "退款申请已提交", "data": result})
|
||||||
|
except Exception as e:
|
||||||
|
raise errors.RequestError(msg=str(e))
|
||||||
114
backend/app/admin/api/v1/audit_log.py
Executable file
114
backend/app/admin/api/v1/audit_log.py
Executable file
@@ -0,0 +1,114 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from fastapi import APIRouter, Depends, Request
|
||||||
|
from fastapi_pagination import Page
|
||||||
|
from fastapi_pagination.ext.sqlalchemy import paginate
|
||||||
|
|
||||||
|
from backend.app.admin.service.audit_log_service import audit_log_service
|
||||||
|
from backend.app.admin.schema.audit_log import AuditLogHistorySchema, AuditLogStatisticsSchema, DailySummaryPageSchema
|
||||||
|
from backend.app.admin.tasks import wx_user_index_history
|
||||||
|
from backend.common.response.response_schema import response_base, ResponseSchemaModel
|
||||||
|
from backend.common.security.jwt import DependsJwtAuth
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/history", summary="获取用户识别历史记录", dependencies=[DependsJwtAuth])
|
||||||
|
async def get_user_recognition_history(
|
||||||
|
request:Request,
|
||||||
|
page: int = 1,
|
||||||
|
size: int = 20
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
通过用户ID查询历史记录
|
||||||
|
"""
|
||||||
|
history_items, total = await audit_log_service.get_user_recognition_history(
|
||||||
|
user_id=request.user.id,
|
||||||
|
page=page,
|
||||||
|
size=size
|
||||||
|
)
|
||||||
|
|
||||||
|
# await wx_user_index_history()
|
||||||
|
|
||||||
|
# 创建分页结果
|
||||||
|
result = {
|
||||||
|
"items": [item.model_dump() for item in history_items],
|
||||||
|
"total": total,
|
||||||
|
"page": page,
|
||||||
|
"size": size
|
||||||
|
}
|
||||||
|
|
||||||
|
return response_base.success(data=result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/statistics", summary="获取用户识别统计信息", dependencies=[DependsJwtAuth])
|
||||||
|
async def get_user_recognition_statistics(request:Request):
|
||||||
|
"""
|
||||||
|
统计用户 recognition 类型的使用记录
|
||||||
|
返回历史总量和当天总量
|
||||||
|
"""
|
||||||
|
total_count, today_count, image_count = await audit_log_service.get_user_recognition_statistics(user_id=request.user.id)
|
||||||
|
|
||||||
|
result = AuditLogStatisticsSchema(
|
||||||
|
total_count=total_count,
|
||||||
|
today_count=today_count,
|
||||||
|
image_count=image_count
|
||||||
|
)
|
||||||
|
|
||||||
|
return response_base.success(data=result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/summary", summary="获取用户每日识别汇总记录", dependencies=[DependsJwtAuth])
|
||||||
|
async def get_user_daily_summary(
|
||||||
|
request: Request,
|
||||||
|
page: int = 1,
|
||||||
|
size: int = 20
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
通过用户ID查询每日识别汇总记录,按创建时间降序排列
|
||||||
|
"""
|
||||||
|
summary_items, total = await audit_log_service.get_user_daily_summaries(
|
||||||
|
user_id=request.user.id,
|
||||||
|
page=page,
|
||||||
|
size=size
|
||||||
|
)
|
||||||
|
|
||||||
|
# await wx_user_index_history()
|
||||||
|
|
||||||
|
# 创建分页结果
|
||||||
|
result = DailySummaryPageSchema(
|
||||||
|
items=summary_items,
|
||||||
|
total=total,
|
||||||
|
page=page,
|
||||||
|
size=size
|
||||||
|
)
|
||||||
|
|
||||||
|
return response_base.success(data=result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/today_summary", summary="获取用户今日识别汇总记录", dependencies=[DependsJwtAuth])
|
||||||
|
async def get_user_today_summary(
|
||||||
|
request: Request,
|
||||||
|
page: int = 1,
|
||||||
|
size: int = 20
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
通过用户ID查询每日识别汇总记录,按创建时间降序排列
|
||||||
|
"""
|
||||||
|
summary_items, total = await audit_log_service.get_user_today_summaries(
|
||||||
|
user_id=request.user.id,
|
||||||
|
page=page,
|
||||||
|
size=size
|
||||||
|
)
|
||||||
|
|
||||||
|
# await wx_user_index_history()
|
||||||
|
|
||||||
|
# 创建分页结果
|
||||||
|
result = DailySummaryPageSchema(
|
||||||
|
items=summary_items,
|
||||||
|
total=total,
|
||||||
|
page=page,
|
||||||
|
size=size
|
||||||
|
)
|
||||||
|
|
||||||
|
return response_base.success(data=result)
|
||||||
11
backend/app/admin/api/v1/auth/__init__.py
Executable file
11
backend/app/admin/api/v1/auth/__init__.py
Executable file
@@ -0,0 +1,11 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from backend.app.admin.api.v1.auth.auth import router as auth_router
|
||||||
|
from backend.app.admin.api.v1.auth.captcha import router as captcha_router
|
||||||
|
|
||||||
|
router = APIRouter(prefix='/auth')
|
||||||
|
|
||||||
|
router.include_router(auth_router, tags=['授权'])
|
||||||
|
router.include_router(captcha_router, prefix='/captcha', tags=['验证码'])
|
||||||
30
backend/app/admin/api/v1/auth/auth.py
Executable file
30
backend/app/admin/api/v1/auth/auth.py
Executable file
@@ -0,0 +1,30 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Request
|
||||||
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
|
||||||
|
from backend.app.admin.service.auth_service import auth_service
|
||||||
|
from backend.common.security.jwt import DependsJwtAuth
|
||||||
|
from backend.common.response.response_schema import response_base, ResponseModel, ResponseSchemaModel
|
||||||
|
from backend.app.admin.schema.token import GetSwaggerToken, GetLoginToken
|
||||||
|
from backend.app.admin.schema.user import AuthLoginParam
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post('/login/swagger', summary='swagger 调试专用', description='用于快捷进行 swagger 认证')
|
||||||
|
async def swagger_login(form_data: OAuth2PasswordRequestForm = Depends()) -> GetSwaggerToken:
|
||||||
|
token, user = await auth_service.swagger_login(form_data=form_data)
|
||||||
|
return GetSwaggerToken(access_token=token, user=user) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
@router.post('/login', summary='验证码登录')
|
||||||
|
async def user_login(request: Request, obj: AuthLoginParam) -> ResponseSchemaModel[GetLoginToken]:
|
||||||
|
data = await auth_service.login(request=request, obj=obj)
|
||||||
|
return response_base.success(data=data)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post('/logout', summary='用户登出', dependencies=[DependsJwtAuth])
|
||||||
|
async def user_logout() -> ResponseModel:
|
||||||
|
return response_base.success()
|
||||||
36
backend/app/admin/api/v1/auth/captcha.py
Executable file
36
backend/app/admin/api/v1/auth/captcha.py
Executable file
@@ -0,0 +1,36 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from fast_captcha import img_captcha
|
||||||
|
from fastapi import APIRouter, Depends, Request
|
||||||
|
from fastapi_limiter.depends import RateLimiter
|
||||||
|
from starlette.concurrency import run_in_threadpool
|
||||||
|
|
||||||
|
from backend.app.admin.schema.captcha import GetCaptchaDetail
|
||||||
|
from backend.common.response.response_schema import ResponseSchemaModel, response_base
|
||||||
|
from backend.core.conf import settings
|
||||||
|
from backend.database.db import uuid4_str
|
||||||
|
from backend.database.redis import redis_client
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
'',
|
||||||
|
summary='获取登录验证码',
|
||||||
|
dependencies=[Depends(RateLimiter(times=5, seconds=10))],
|
||||||
|
)
|
||||||
|
async def get_captcha(request: Request) -> ResponseSchemaModel[GetCaptchaDetail]:
|
||||||
|
"""
|
||||||
|
此接口可能存在性能损耗,尽管是异步接口,但是验证码生成是IO密集型任务,使用线程池尽量减少性能损耗
|
||||||
|
"""
|
||||||
|
img_type: str = 'base64'
|
||||||
|
img, code = await run_in_threadpool(img_captcha, img_byte=img_type)
|
||||||
|
uuid = uuid4_str()
|
||||||
|
request.app.state.captcha_uuid = uuid
|
||||||
|
await redis_client.set(
|
||||||
|
f'{settings.CAPTCHA_LOGIN_REDIS_PREFIX}:{uuid}',
|
||||||
|
code,
|
||||||
|
ex=settings.CAPTCHA_LOGIN_EXPIRE_SECONDS,
|
||||||
|
)
|
||||||
|
data = GetCaptchaDetail(image_type=img_type, image=img)
|
||||||
|
return response_base.success(data=data)
|
||||||
67
backend/app/admin/api/v1/coupon.py
Executable file
67
backend/app/admin/api/v1/coupon.py
Executable file
@@ -0,0 +1,67 @@
|
|||||||
|
from fastapi import APIRouter, Depends, Request
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from backend.app.admin.service.coupon_service import CouponService
|
||||||
|
from backend.common.response.response_schema import response_base
|
||||||
|
from backend.common.security.jwt import DependsJwtAuth
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
class RedeemCouponRequest(BaseModel):
|
||||||
|
code: str = Field(..., min_length=1, max_length=32, description="兑换码")
|
||||||
|
|
||||||
|
class CreateCouponRequest(BaseModel):
|
||||||
|
duration: int = Field(..., gt=0, description="兑换时长(分钟)")
|
||||||
|
count: int = Field(1, ge=1, le=1000, description="生成数量")
|
||||||
|
expires_days: Optional[int] = Field(None, ge=1, description="过期天数")
|
||||||
|
|
||||||
|
class CouponHistoryResponse(BaseModel):
|
||||||
|
code: str
|
||||||
|
duration: int
|
||||||
|
used_at: str
|
||||||
|
|
||||||
|
@router.post("/redeem", dependencies=[DependsJwtAuth])
|
||||||
|
async def redeem_coupon_api(
|
||||||
|
request: Request,
|
||||||
|
redeem_request: RedeemCouponRequest
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
兑换兑换券
|
||||||
|
"""
|
||||||
|
result = await CouponService.redeem_coupon(redeem_request.code, request.user.id)
|
||||||
|
return response_base.success(data=result)
|
||||||
|
|
||||||
|
@router.get("/history", dependencies=[DependsJwtAuth])
|
||||||
|
async def get_coupon_history_api(
|
||||||
|
request: Request,
|
||||||
|
limit: int = 100
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取用户兑换历史
|
||||||
|
"""
|
||||||
|
history = await CouponService.get_user_coupon_history(request.user.id, limit)
|
||||||
|
return response_base.success(data=history)
|
||||||
|
|
||||||
|
# 管理员接口,用于批量生成兑换券
|
||||||
|
@router.post("/generate", dependencies=[DependsJwtAuth])
|
||||||
|
async def generate_coupons_api(
|
||||||
|
request: Request,
|
||||||
|
create_request: CreateCouponRequest
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
批量生成兑换券(管理员接口)
|
||||||
|
"""
|
||||||
|
# 这里应该添加管理员权限验证
|
||||||
|
# 为简化示例,暂时省略权限验证
|
||||||
|
|
||||||
|
coupons = await CouponService.batch_create_coupons(
|
||||||
|
create_request.count,
|
||||||
|
create_request.duration,
|
||||||
|
create_request.expires_days
|
||||||
|
)
|
||||||
|
|
||||||
|
return response_base.success(data={
|
||||||
|
"count": len(coupons),
|
||||||
|
"message": f"成功生成{len(coupons)}个兑换券"
|
||||||
|
})
|
||||||
71
backend/app/admin/api/v1/dict.py
Executable file
71
backend/app/admin/api/v1/dict.py
Executable file
@@ -0,0 +1,71 @@
|
|||||||
|
from fastapi import APIRouter, Query, Path, Response
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
import io
|
||||||
|
|
||||||
|
from backend.app.admin.schema.dict import DictWordResponse
|
||||||
|
from backend.app.admin.service.dict_service import dict_service
|
||||||
|
from backend.common.response.response_schema import response_base, ResponseSchemaModel
|
||||||
|
from backend.common.security.jwt import DependsJwtAuth
|
||||||
|
from backend.middleware import tencent_cloud
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/word/{word}", summary="根据单词查询字典信息", dependencies=[DependsJwtAuth])
|
||||||
|
async def get_word_info(
|
||||||
|
word: str = Path(..., description="要查询的单词", min_length=1, max_length=100)
|
||||||
|
) -> ResponseSchemaModel[DictWordResponse]:
|
||||||
|
"""
|
||||||
|
根据单词查询字典信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
word: 要查询的单词
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含单词详细信息的响应,包括:
|
||||||
|
- senses: 义项列表
|
||||||
|
- frequency: 频率信息
|
||||||
|
- pronunciations: 发音信息
|
||||||
|
"""
|
||||||
|
# from backend.middleware.tencent_cloud import TencentCloud
|
||||||
|
# data = await TencentCloud().text_to_speak(image_id=2088374826979950592,content='green vegetable cut into small pieces and added to food',image_text_id=2088375029640331267,user_id=2083326996703739904)
|
||||||
|
result = await dict_service.get_word_info(word)
|
||||||
|
return response_base.success(data=result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/check/{word}", summary="检查单词是否存在", dependencies=[DependsJwtAuth])
|
||||||
|
async def check_word_exists(
|
||||||
|
word: str = Path(..., description="要检查的单词", min_length=1, max_length=100)
|
||||||
|
) -> ResponseSchemaModel[bool]:
|
||||||
|
"""
|
||||||
|
检查单词是否在字典中存在
|
||||||
|
|
||||||
|
Args:
|
||||||
|
word: 要检查的单词
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
布尔值,表示单词是否存在
|
||||||
|
"""
|
||||||
|
result = await dict_service.check_word_exists(word)
|
||||||
|
return response_base.success(data=result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/audio/{file_name:path}", summary="播放音频文件", dependencies=[DependsJwtAuth])
|
||||||
|
async def play_audio(
|
||||||
|
file_name: str = Path(..., description="音频文件名", min_length=1, max_length=255)
|
||||||
|
) -> StreamingResponse:
|
||||||
|
"""
|
||||||
|
根据文件名播放音频文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_name: 音频文件名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
音频文件流
|
||||||
|
"""
|
||||||
|
audio_data = await dict_service.get_audio_data(file_name)
|
||||||
|
if not audio_data:
|
||||||
|
# 返回空的响应或默认音频
|
||||||
|
return StreamingResponse(io.BytesIO(b""), media_type="audio/mpeg")
|
||||||
|
|
||||||
|
return StreamingResponse(io.BytesIO(audio_data), media_type="audio/mpeg")
|
||||||
57
backend/app/admin/api/v1/feedback.py
Executable file
57
backend/app/admin/api/v1/feedback.py
Executable file
@@ -0,0 +1,57 @@
|
|||||||
|
from fastapi import APIRouter, Depends, Request
|
||||||
|
|
||||||
|
from backend.app.admin.schema.feedback import CreateFeedbackParam, UpdateFeedbackParam, FeedbackInfoSchema
|
||||||
|
from backend.app.admin.service.feedback_service import feedback_service
|
||||||
|
from backend.common.response.response_schema import response_base, ResponseModel, ResponseSchemaModel
|
||||||
|
from backend.common.security.jwt import DependsJwtAuth
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post('', summary='创建用户反馈', dependencies=[DependsJwtAuth])
|
||||||
|
async def create_feedback(request: Request, obj: CreateFeedbackParam) -> ResponseSchemaModel[FeedbackInfoSchema]:
|
||||||
|
"""创建用户反馈"""
|
||||||
|
# 从JWT中获取用户ID
|
||||||
|
user_id = request.user.id
|
||||||
|
feedback = await feedback_service.create_feedback(user_id, obj)
|
||||||
|
return response_base.success(data=feedback)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get('/{feedback_id}', summary='获取反馈详情', dependencies=[DependsJwtAuth])
|
||||||
|
async def get_feedback(feedback_id: int) -> ResponseSchemaModel[FeedbackInfoSchema]:
|
||||||
|
"""获取反馈详情"""
|
||||||
|
feedback = await feedback_service.get_feedback(feedback_id)
|
||||||
|
if not feedback:
|
||||||
|
return response_base.fail(msg='反馈不存在')
|
||||||
|
return response_base.success(data=feedback)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get('', summary='获取反馈列表', dependencies=[DependsJwtAuth])
|
||||||
|
async def get_feedback_list(
|
||||||
|
user_id: int = None,
|
||||||
|
status: str = None,
|
||||||
|
category: str = None,
|
||||||
|
limit: int = 10,
|
||||||
|
offset: int = 0
|
||||||
|
) -> ResponseSchemaModel[list[FeedbackInfoSchema]]:
|
||||||
|
"""获取反馈列表"""
|
||||||
|
feedbacks = await feedback_service.get_feedback_list(user_id, status, category, limit, offset)
|
||||||
|
return response_base.success(data=feedbacks)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put('/{feedback_id}', summary='更新反馈状态', dependencies=[DependsJwtAuth])
|
||||||
|
async def update_feedback(feedback_id: int, obj: UpdateFeedbackParam) -> ResponseModel:
|
||||||
|
"""更新反馈状态"""
|
||||||
|
success = await feedback_service.update_feedback(feedback_id, obj)
|
||||||
|
if not success:
|
||||||
|
return response_base.fail(msg='反馈不存在或更新失败')
|
||||||
|
return response_base.success()
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete('/{feedback_id}', summary='删除反馈', dependencies=[DependsJwtAuth])
|
||||||
|
async def delete_feedback(feedback_id: int) -> ResponseModel:
|
||||||
|
"""删除反馈"""
|
||||||
|
success = await feedback_service.delete_feedback(feedback_id)
|
||||||
|
if not success:
|
||||||
|
return response_base.fail(msg='反馈不存在或删除失败')
|
||||||
|
return response_base.success()
|
||||||
41
backend/app/admin/api/v1/file.py
Executable file
41
backend/app/admin/api/v1/file.py
Executable file
@@ -0,0 +1,41 @@
|
|||||||
|
from fastapi import APIRouter, UploadFile, File, Query, Depends, Response
|
||||||
|
|
||||||
|
from backend.app.admin.schema.file import FileUploadResponse
|
||||||
|
from backend.app.admin.service.file_service import file_service
|
||||||
|
from backend.common.response.response_schema import response_base, ResponseSchemaModel
|
||||||
|
from backend.common.security.jwt import DependsJwtAuth
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/upload", summary="上传文件", dependencies=[DependsJwtAuth])
|
||||||
|
async def upload_file(
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
) -> ResponseSchemaModel[FileUploadResponse]:
|
||||||
|
"""上传文件"""
|
||||||
|
result = await file_service.upload_file(file)
|
||||||
|
return response_base.success(data=result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{file_id}", summary="下载文件", dependencies=[DependsJwtAuth])
|
||||||
|
# @router.get("/{file_id}", summary="下载文件")
|
||||||
|
async def download_file(file_id: int) -> Response:
|
||||||
|
"""下载文件"""
|
||||||
|
try:
|
||||||
|
content, filename, content_type = await file_service.download_file(file_id)
|
||||||
|
headers = {
|
||||||
|
"Content-Disposition": f'attachment; filename="{filename}"',
|
||||||
|
"Content-Type": content_type
|
||||||
|
}
|
||||||
|
return Response(content=content, headers=headers)
|
||||||
|
except Exception as e:
|
||||||
|
return Response(content=str(e), status_code=404)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{file_id}", summary="删除文件", dependencies=[DependsJwtAuth])
|
||||||
|
async def delete_file(file_id: int) -> ResponseSchemaModel[bool]:
|
||||||
|
"""删除文件"""
|
||||||
|
result = await file_service.delete_file(file_id)
|
||||||
|
if not result:
|
||||||
|
return await response_base.fail(message="文件不存在或删除失败")
|
||||||
|
return await response_base.success(data=result)
|
||||||
97
backend/app/admin/api/v1/notification.py
Executable file
97
backend/app/admin/api/v1/notification.py
Executable file
@@ -0,0 +1,97 @@
|
|||||||
|
from fastapi import APIRouter, Depends, Request
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from backend.app.admin.service.notification_service import NotificationService
|
||||||
|
from backend.common.response.response_schema import response_base
|
||||||
|
from backend.common.security.jwt import DependsJwtAuth
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
class CreateNotificationRequest(BaseModel):
|
||||||
|
title: str = Field(..., min_length=1, max_length=255, description="通知标题")
|
||||||
|
content: str = Field(..., min_length=1, description="通知内容")
|
||||||
|
image_url: Optional[str] = Field(None, max_length=512, description="图片URL(可选)")
|
||||||
|
|
||||||
|
class MarkAsReadRequest(BaseModel):
|
||||||
|
notification_ids: List[int] = Field(..., description="要标记为已读的通知ID列表")
|
||||||
|
|
||||||
|
class NotificationResponse(BaseModel):
|
||||||
|
id: int
|
||||||
|
notification_id: int
|
||||||
|
title: str
|
||||||
|
content: str
|
||||||
|
image_url: Optional[str]
|
||||||
|
is_read: Optional[bool]
|
||||||
|
received_at: str
|
||||||
|
read_at: Optional[str]
|
||||||
|
|
||||||
|
@router.get("/list", dependencies=[DependsJwtAuth])
|
||||||
|
async def get_notifications_api(
|
||||||
|
request: Request,
|
||||||
|
limit: int = 100
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取用户消息通知列表
|
||||||
|
"""
|
||||||
|
notifications = await NotificationService.get_user_notifications(request.user.id, limit)
|
||||||
|
return response_base.success(data=notifications)
|
||||||
|
|
||||||
|
@router.get("/unread", dependencies=[DependsJwtAuth])
|
||||||
|
async def get_unread_notifications_api(
|
||||||
|
request: Request,
|
||||||
|
limit: int = 100
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取用户未读消息通知列表
|
||||||
|
"""
|
||||||
|
notifications = await NotificationService.get_unread_notifications(request.user.id, limit)
|
||||||
|
return response_base.success(data=notifications)
|
||||||
|
|
||||||
|
@router.get("/unread/count", dependencies=[DependsJwtAuth])
|
||||||
|
async def get_unread_count_api(
|
||||||
|
request: Request
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取用户未读消息通知数量
|
||||||
|
"""
|
||||||
|
count = await NotificationService.get_unread_count(request.user.id)
|
||||||
|
return response_base.success(data={"count": count})
|
||||||
|
|
||||||
|
@router.post("/{notification_id}/read", dependencies=[DependsJwtAuth])
|
||||||
|
async def mark_as_read_api(
|
||||||
|
request: Request,
|
||||||
|
notification_id: int
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
标记指定通知为已读
|
||||||
|
"""
|
||||||
|
success = await NotificationService.mark_notification_as_read(notification_id, request.user.id)
|
||||||
|
if success:
|
||||||
|
return response_base.success(data={"msg": "标记成功"})
|
||||||
|
else:
|
||||||
|
return response_base.fail(data={"msg": "标记失败"})
|
||||||
|
|
||||||
|
# 管理员接口,用于创建通知
|
||||||
|
@router.post("/create", dependencies=[DependsJwtAuth])
|
||||||
|
async def create_notification_api(
|
||||||
|
request: Request,
|
||||||
|
create_request: CreateNotificationRequest
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
创建消息通知(管理员接口)
|
||||||
|
"""
|
||||||
|
# 这里应该添加管理员权限验证
|
||||||
|
# 为简化示例,暂时省略权限验证
|
||||||
|
|
||||||
|
notification = await NotificationService.create_notification(
|
||||||
|
create_request.title,
|
||||||
|
create_request.content,
|
||||||
|
create_request.image_url,
|
||||||
|
request.user.id
|
||||||
|
)
|
||||||
|
|
||||||
|
return response_base.success(data={
|
||||||
|
"id": notification.id,
|
||||||
|
"msg": "通知创建成功"
|
||||||
|
})
|
||||||
88
backend/app/admin/api/v1/user.py
Executable file
88
backend/app/admin/api/v1/user.py
Executable file
@@ -0,0 +1,88 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Query
|
||||||
|
|
||||||
|
from backend.common.security.jwt import DependsJwtAuth
|
||||||
|
from backend.common.pagination import paging_data, DependsPagination, PageData
|
||||||
|
from backend.common.response.response_schema import response_base, ResponseModel, ResponseSchemaModel
|
||||||
|
from backend.database.db import CurrentSession
|
||||||
|
from backend.app.admin.schema.user import (
|
||||||
|
RegisterUserParam,
|
||||||
|
GetUserInfoDetail,
|
||||||
|
ResetPassword,
|
||||||
|
UpdateUserParam,
|
||||||
|
AvatarParam,
|
||||||
|
)
|
||||||
|
from backend.app.admin.service.wx_user_service import WxUserService
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post('/register', summary='用户注册')
|
||||||
|
async def user_register(obj: RegisterUserParam) -> ResponseModel:
|
||||||
|
await WxUserService.register(obj=obj)
|
||||||
|
return response_base.success()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post('/password/reset', summary='密码重置', dependencies=[DependsJwtAuth])
|
||||||
|
async def password_reset(obj: ResetPassword) -> ResponseModel:
|
||||||
|
count = await WxUserService.pwd_reset(obj=obj)
|
||||||
|
if count > 0:
|
||||||
|
return response_base.success()
|
||||||
|
return response_base.fail()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get('/{username}', summary='查看用户信息', dependencies=[DependsJwtAuth])
|
||||||
|
async def get_user(username: str) -> ResponseSchemaModel[GetUserInfoDetail]:
|
||||||
|
data = await WxUserService.get_userinfo(username=username)
|
||||||
|
return response_base.success(data=data)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put('/{username}', summary='更新用户信息', dependencies=[DependsJwtAuth])
|
||||||
|
async def update_userinfo(username: str, obj: UpdateUserParam) -> ResponseModel:
|
||||||
|
count = await WxUserService.update(username=username, obj=obj)
|
||||||
|
if count > 0:
|
||||||
|
return response_base.success()
|
||||||
|
return response_base.fail()
|
||||||
|
|
||||||
|
|
||||||
|
@router.put('/{username}/avatar', summary='更新头像', dependencies=[DependsJwtAuth])
|
||||||
|
async def update_avatar(username: str, avatar: AvatarParam) -> ResponseModel:
|
||||||
|
count = await WxUserService.update_avatar(username=username, avatar=avatar)
|
||||||
|
if count > 0:
|
||||||
|
return response_base.success()
|
||||||
|
return response_base.fail()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
'',
|
||||||
|
summary='(模糊条件)分页获取所有用户',
|
||||||
|
dependencies=[
|
||||||
|
DependsJwtAuth,
|
||||||
|
DependsPagination,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def get_all_users(
|
||||||
|
db: CurrentSession,
|
||||||
|
username: Annotated[str | None, Query()] = None,
|
||||||
|
phone: Annotated[str | None, Query()] = None,
|
||||||
|
status: Annotated[int | None, Query()] = None,
|
||||||
|
) -> ResponseSchemaModel[PageData[GetUserInfoDetail]]:
|
||||||
|
user_select = await WxUserService.get_select(username=username, phone=phone, status=status)
|
||||||
|
page_data = await paging_data(db, user_select)
|
||||||
|
return response_base.success(data=page_data)
|
||||||
|
|
||||||
|
|
||||||
|
# @router.delete(
|
||||||
|
# path='/{username}',
|
||||||
|
# summary='用户注销',
|
||||||
|
# description='用户注销 != 用户登出,注销之后用户将从数据库删除',
|
||||||
|
# dependencies=[DependsJwtAuth],
|
||||||
|
# )
|
||||||
|
# async def delete_user(current_user: CurrentUser, username: str) -> ResponseModel:
|
||||||
|
# count = await WxUserService.delete(current_user=current_user, username=username)
|
||||||
|
# if count > 0:
|
||||||
|
# return response_base.success()
|
||||||
|
# return response_base.fail()
|
||||||
95
backend/app/admin/api/v1/wx.py
Executable file
95
backend/app/admin/api/v1/wx.py
Executable file
@@ -0,0 +1,95 @@
|
|||||||
|
# routers/wx.py
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||||
|
from fastapi_limiter.depends import RateLimiter
|
||||||
|
|
||||||
|
from backend.app.admin.schema.token import GetWxLoginToken
|
||||||
|
from backend.app.admin.schema.wx import WxLoginRequest, TokenResponse, UserInfo, UpdateUserSettingsRequest, GetUserSettingsResponse, DictLevel
|
||||||
|
|
||||||
|
from backend.app.admin.service.wx_service import wx_service
|
||||||
|
from backend.common.response.response_schema import response_base, ResponseSchemaModel
|
||||||
|
from backend.core.wx_integration import verify_wx_code
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login",
|
||||||
|
summary="微信登录",
|
||||||
|
dependencies=[Depends(RateLimiter(times=5, minutes=1))])
|
||||||
|
async def wechat_login(
|
||||||
|
request: Request, response: Response,
|
||||||
|
wx_request: WxLoginRequest
|
||||||
|
) -> ResponseSchemaModel[GetWxLoginToken]:
|
||||||
|
"""
|
||||||
|
微信小程序登录接口
|
||||||
|
- **code**: 微信小程序前端获取的临时code
|
||||||
|
- **encrypted_data** (可选): 加密的用户信息
|
||||||
|
- **iv** (可选): 加密算法的初始向量
|
||||||
|
"""
|
||||||
|
# 验证微信code并获取用户信息
|
||||||
|
wx_result = await verify_wx_code(wx_request.code)
|
||||||
|
if not wx_result:
|
||||||
|
raise HTTPException(status_code=401, detail="微信认证失败")
|
||||||
|
|
||||||
|
# 处理用户登录逻辑
|
||||||
|
result = await wx_service.login(
|
||||||
|
request=request, response=response,
|
||||||
|
openid=wx_result.get("openid"),
|
||||||
|
session_key=wx_result.get("session_key"),
|
||||||
|
encrypted_data=wx_request.encrypted_data,
|
||||||
|
iv=wx_request.iv
|
||||||
|
)
|
||||||
|
return response_base.success(data=result)
|
||||||
|
|
||||||
|
|
||||||
|
# @router.put("/settings", summary="更新用户设置", dependencies=[DependsJwtAuth])
|
||||||
|
# async def update_user_settings(
|
||||||
|
# request: Request,
|
||||||
|
# settings: UpdateUserSettingsRequest
|
||||||
|
# ) -> ResponseSchemaModel[None]:
|
||||||
|
# """
|
||||||
|
# 更新用户设置
|
||||||
|
# """
|
||||||
|
#
|
||||||
|
# # 更新用户设置
|
||||||
|
# await wx_service.update_user_settings(
|
||||||
|
# user_id=request.user.id,
|
||||||
|
# dict_level=settings.dict_level
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# return response_base.success()
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# @router.get("/settings", summary="获取用户设置", dependencies=[DependsJwtAuth])
|
||||||
|
# async def get_user_settings(
|
||||||
|
# request: Request
|
||||||
|
# ) -> ResponseSchemaModel[GetUserSettingsResponse]:
|
||||||
|
# """
|
||||||
|
# 获取用户设置
|
||||||
|
# """
|
||||||
|
# # 从请求中获取用户ID(实际项目中应该从JWT token中获取)
|
||||||
|
# user_id = getattr(request.state, 'user_id', None)
|
||||||
|
# if not user_id:
|
||||||
|
# raise HTTPException(status_code=401, detail="未授权访问")
|
||||||
|
#
|
||||||
|
# # 获取用户信息
|
||||||
|
# async with async_db_session() as db:
|
||||||
|
# user = await wx_user_dao.get(db, user_id)
|
||||||
|
# if not user:
|
||||||
|
# raise HTTPException(status_code=404, detail="用户不存在")
|
||||||
|
#
|
||||||
|
# # 提取设置信息
|
||||||
|
# dict_level = None
|
||||||
|
# if user.profile and isinstance(user.profile, dict):
|
||||||
|
# dict_level_value = user.profile.get("dict_level")
|
||||||
|
# if dict_level_value:
|
||||||
|
# # 将字符串值转换为枚举值
|
||||||
|
# try:
|
||||||
|
# dict_level = DictLevel(dict_level_value)
|
||||||
|
# except ValueError:
|
||||||
|
# pass # 无效值,保持为None
|
||||||
|
#
|
||||||
|
# response_data = GetUserSettingsResponse(
|
||||||
|
# dict_level=dict_level
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# return response_base.success(data=response_data)
|
||||||
150
backend/app/admin/api/v1/wxpay_callback.py
Executable file
150
backend/app/admin/api/v1/wxpay_callback.py
Executable file
@@ -0,0 +1,150 @@
|
|||||||
|
from fastapi import APIRouter, Request, BackgroundTasks
|
||||||
|
from wechatpy.exceptions import InvalidSignatureException
|
||||||
|
from backend.app.admin.crud.order_crud import order_dao
|
||||||
|
from backend.app.admin.crud.user_account_crud import user_account_dao
|
||||||
|
from backend.app.admin.model import Order, UserAccount
|
||||||
|
from backend.app.admin.service.refund_service import RefundService
|
||||||
|
from backend.app.admin.service.subscription_service import SUBSCRIPTION_PLANS
|
||||||
|
from backend.database.db import async_db_session
|
||||||
|
from backend.common.log import log as logger
|
||||||
|
from backend.app.admin.crud.usage_log_crud import usage_log_dao
|
||||||
|
from sqlalchemy import select, update
|
||||||
|
from datetime import datetime
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
from backend.utils.wx_pay import wx_pay_utils
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
def verify_wxpay_signature(data: dict, api_key: str) -> bool:
|
||||||
|
"""
|
||||||
|
验证微信支付回调签名
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 获取签名
|
||||||
|
sign = data.pop('sign', None)
|
||||||
|
if not sign:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 重新计算签名
|
||||||
|
sorted_keys = sorted(data.keys())
|
||||||
|
stringA = '&'.join(f"{key}={data[key]}" for key in sorted_keys if data[key])
|
||||||
|
stringSignTemp = f"{stringA}&key={api_key}"
|
||||||
|
calculated_sign = hashlib.md5(stringSignTemp.encode('utf-8')).hexdigest().upper()
|
||||||
|
|
||||||
|
return sign == calculated_sign
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"签名验证异常: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/notify")
|
||||||
|
async def wxpay_notify(request: Request):
|
||||||
|
"""
|
||||||
|
微信支付异步通知处理(增强安全性和幂等性)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 读取原始数据
|
||||||
|
body = await request.body()
|
||||||
|
body_str = body.decode('utf-8')
|
||||||
|
|
||||||
|
# 解析微信回调数据
|
||||||
|
result = wx_pay_utils.parse_payment_result(body)
|
||||||
|
|
||||||
|
# 验证签名(增强安全性)
|
||||||
|
from backend.core.conf import settings
|
||||||
|
if not verify_wxpay_signature(result.copy(), settings.WECHAT_PAY_API_KEY):
|
||||||
|
logger.warning("微信支付回调签名验证失败")
|
||||||
|
return {"return_code": "FAIL", "return_msg": "签名验证失败"}
|
||||||
|
|
||||||
|
if result['return_code'] == 'SUCCESS' and result['result_code'] == 'SUCCESS':
|
||||||
|
out_trade_no = result['out_trade_no']
|
||||||
|
transaction_id = result['transaction_id']
|
||||||
|
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
# 使用SELECT FOR UPDATE确保并发安全
|
||||||
|
stmt = select(Order).where(Order.id == int(out_trade_no)).with_for_update()
|
||||||
|
order_result = await db.execute(stmt)
|
||||||
|
order = order_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not order:
|
||||||
|
logger.warning(f"订单不存在: {out_trade_no}")
|
||||||
|
return {"return_code": "SUCCESS"}
|
||||||
|
|
||||||
|
# 幂等性检查
|
||||||
|
if order.processed_at is not None:
|
||||||
|
logger.info(f"订单已处理过: {out_trade_no}")
|
||||||
|
return {"return_code": "SUCCESS"}
|
||||||
|
|
||||||
|
if order.status != 'pending':
|
||||||
|
logger.warning(f"订单状态异常: {out_trade_no}, status: {order.status}")
|
||||||
|
return {"return_code": "SUCCESS"}
|
||||||
|
|
||||||
|
# 更新订单状态
|
||||||
|
order.status = 'completed'
|
||||||
|
order.transaction_id = transaction_id
|
||||||
|
order.processed_at = datetime.now()
|
||||||
|
await order_dao.update(db, order.id, order)
|
||||||
|
|
||||||
|
# 如果是订阅订单,更新用户订阅信息
|
||||||
|
if order.order_type == 'subscription':
|
||||||
|
# 使用SELECT FOR UPDATE锁定用户账户
|
||||||
|
account_stmt = select(UserAccount).where(UserAccount.user_id == order.user_id).with_for_update()
|
||||||
|
account_result = await db.execute(account_stmt)
|
||||||
|
user_account = account_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if user_account:
|
||||||
|
plan = None
|
||||||
|
for key, value in SUBSCRIPTION_PLANS.items():
|
||||||
|
if value['price'] == order.amount_cents:
|
||||||
|
plan = key
|
||||||
|
break
|
||||||
|
|
||||||
|
if plan:
|
||||||
|
# 处理未用完次数(仅累计一个月)
|
||||||
|
new_balance = user_account.balance + user_account.carryover_balance
|
||||||
|
carryover = 0
|
||||||
|
|
||||||
|
# 如果是续费且当期次数未用完,则累计到下一期
|
||||||
|
if (user_account.subscription_type and
|
||||||
|
user_account.subscription_expires_at and
|
||||||
|
user_account.subscription_expires_at > datetime.now()):
|
||||||
|
# 计算当期剩余次数
|
||||||
|
remaining = max(0, user_account.balance)
|
||||||
|
carryover = min(remaining, SUBSCRIPTION_PLANS[plan]["times"])
|
||||||
|
|
||||||
|
# 更新订阅信息
|
||||||
|
user_account.subscription_type = plan
|
||||||
|
user_account.subscription_expires_at = datetime.now() + SUBSCRIPTION_PLANS[plan]["duration"]
|
||||||
|
user_account.balance = new_balance + SUBSCRIPTION_PLANS[plan]["times"]
|
||||||
|
user_account.carryover_balance = carryover
|
||||||
|
|
||||||
|
await user_account_dao.update(db, user_account.id, user_account)
|
||||||
|
|
||||||
|
# 记录使用日志
|
||||||
|
account = await user_account_dao.get_by_user_id(db, order.user_id)
|
||||||
|
await usage_log_dao.add(db, {
|
||||||
|
"user_id": order.user_id,
|
||||||
|
"action": "purchase" if order.order_type == "purchase" else "renewal",
|
||||||
|
"amount": order.amount_times,
|
||||||
|
"balance_after": account.balance if account else 0,
|
||||||
|
"related_id": order.id,
|
||||||
|
"details": {"transaction_id": transaction_id}
|
||||||
|
})
|
||||||
|
|
||||||
|
return {"return_code": "SUCCESS"}
|
||||||
|
else:
|
||||||
|
logger.error(f"微信支付回调失败: {result}")
|
||||||
|
return {"return_code": "FAIL", "return_msg": "处理失败"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"微信支付回调处理异常: {str(e)}")
|
||||||
|
return {"return_code": "FAIL", "return_msg": "服务器异常"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refund/notify")
|
||||||
|
async def wxpay_refund_notify(request: Request):
|
||||||
|
"""
|
||||||
|
微信退款异步通知处理
|
||||||
|
"""
|
||||||
|
return await RefundService.handle_refund_notify(request)
|
||||||
6
backend/app/admin/crud/__init__.py
Executable file
6
backend/app/admin/crud/__init__.py
Executable file
@@ -0,0 +1,6 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
from backend.app.admin.crud.file_crud import file_dao
|
||||||
|
from backend.app.admin.crud.daily_summary_crud import daily_summary_dao
|
||||||
|
from backend.app.admin.crud.points_crud import points_dao, points_log_dao
|
||||||
204
backend/app/admin/crud/audit_log_crud.py
Executable file
204
backend/app/admin/crud/audit_log_crud.py
Executable file
@@ -0,0 +1,204 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from datetime import datetime, timedelta, date
|
||||||
|
|
||||||
|
from sqlalchemy import select, func
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy_crud_plus import CRUDPlus
|
||||||
|
|
||||||
|
from backend.app.admin.model.audit_log import AuditLog, DailySummary
|
||||||
|
from backend.app.admin.schema.audit_log import CreateAuditLogParam, AuditLogStatisticsSchema, CreateDailySummaryParam
|
||||||
|
from backend.app.ai import Image
|
||||||
|
|
||||||
|
|
||||||
|
class CRUDAuditLog(CRUDPlus[AuditLog]):
|
||||||
|
async def create(self, db: AsyncSession, obj: CreateAuditLogParam) -> None:
|
||||||
|
"""
|
||||||
|
创建操作日志
|
||||||
|
|
||||||
|
:param db: 数据库会话
|
||||||
|
:param obj: 创建操作日志参数
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
await self.create_model(db, obj)
|
||||||
|
|
||||||
|
async def get_user_recognition_history(self, db: AsyncSession, user_id: int, page: int = 1, size: int = 20):
|
||||||
|
"""
|
||||||
|
通过用户ID查询历史记录,联合image表,查找api_type='recognition'的记录
|
||||||
|
一个image可能有多个audit_log记录,找出其中called_at最早的
|
||||||
|
支持分页查询
|
||||||
|
返回的数据有image.thumbnail_id, image.created_time, audit_log.dict_level
|
||||||
|
|
||||||
|
:param db: 数据库会话
|
||||||
|
:param user_id: 用户ID
|
||||||
|
:param page: 页码
|
||||||
|
:param size: 每页数量
|
||||||
|
:return: 查询结果和总数
|
||||||
|
"""
|
||||||
|
# 子查询:找出每个image_id对应的最早called_at记录
|
||||||
|
subquery = (
|
||||||
|
select(
|
||||||
|
AuditLog.image_id,
|
||||||
|
func.min(AuditLog.called_at).label('earliest_called_at')
|
||||||
|
)
|
||||||
|
.where(AuditLog.user_id == user_id, AuditLog.api_type == 'recognition')
|
||||||
|
.group_by(AuditLog.image_id)
|
||||||
|
.subquery()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 主查询:关联image表和audit_log表,获取所需字段
|
||||||
|
stmt = (
|
||||||
|
select(
|
||||||
|
Image.id,
|
||||||
|
Image.thumbnail_id,
|
||||||
|
Image.file_id,
|
||||||
|
Image.created_time,
|
||||||
|
AuditLog.dict_level
|
||||||
|
)
|
||||||
|
.join(Image, AuditLog.image_id == Image.id)
|
||||||
|
.join(
|
||||||
|
subquery,
|
||||||
|
(AuditLog.image_id == subquery.c.image_id) &
|
||||||
|
(AuditLog.called_at == subquery.c.earliest_called_at)
|
||||||
|
)
|
||||||
|
.where(AuditLog.user_id == user_id, AuditLog.api_type == 'recognition')
|
||||||
|
.order_by(AuditLog.called_at.desc(), AuditLog.id.desc())
|
||||||
|
.offset((page - 1) * size)
|
||||||
|
.limit(size)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
items = result.fetchall()
|
||||||
|
|
||||||
|
# 获取总数
|
||||||
|
count_stmt = (
|
||||||
|
select(func.count(func.distinct(AuditLog.image_id)))
|
||||||
|
.where(AuditLog.user_id == user_id, AuditLog.api_type == 'recognition')
|
||||||
|
)
|
||||||
|
total_result = await db.execute(count_stmt)
|
||||||
|
total = total_result.scalar()
|
||||||
|
|
||||||
|
return items, total
|
||||||
|
|
||||||
|
async def get_user_today_recognition_history(self, db: AsyncSession, user_id: int, page: int = 1, size: int = 20):
|
||||||
|
"""
|
||||||
|
通过用户ID查询当天的识别记录,联合image表,查找api_type='recognition'的记录
|
||||||
|
一个image可能有多个audit_log记录,找出其中called_at最早的
|
||||||
|
支持分页查询
|
||||||
|
返回的数据有image.thumbnail_id, image.created_time, audit_log.dict_level
|
||||||
|
|
||||||
|
:param db: 数据库会话
|
||||||
|
:param user_id: 用户ID
|
||||||
|
:param page: 页码
|
||||||
|
:param size: 每页数量
|
||||||
|
:return: 查询结果和总数
|
||||||
|
"""
|
||||||
|
# 获取当天的开始时间
|
||||||
|
today = date.today()
|
||||||
|
today_start = datetime(today.year, today.month, today.day)
|
||||||
|
|
||||||
|
# 子查询:找出每个image_id对应的最早called_at记录(仅当天)
|
||||||
|
subquery = (
|
||||||
|
select(
|
||||||
|
AuditLog.image_id,
|
||||||
|
func.min(AuditLog.called_at).label('earliest_called_at')
|
||||||
|
)
|
||||||
|
.where(
|
||||||
|
AuditLog.user_id == user_id,
|
||||||
|
AuditLog.api_type == 'recognition',
|
||||||
|
AuditLog.called_at >= today_start
|
||||||
|
)
|
||||||
|
.group_by(AuditLog.image_id)
|
||||||
|
.subquery()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 主查询:关联image表和audit_log表,获取所需字段(仅当天)
|
||||||
|
stmt = (
|
||||||
|
select(
|
||||||
|
Image.id,
|
||||||
|
Image.thumbnail_id,
|
||||||
|
Image.file_id,
|
||||||
|
Image.created_time,
|
||||||
|
AuditLog.dict_level
|
||||||
|
)
|
||||||
|
.join(Image, AuditLog.image_id == Image.id)
|
||||||
|
.join(
|
||||||
|
subquery,
|
||||||
|
(AuditLog.image_id == subquery.c.image_id) &
|
||||||
|
(AuditLog.called_at == subquery.c.earliest_called_at)
|
||||||
|
)
|
||||||
|
.where(
|
||||||
|
AuditLog.user_id == user_id,
|
||||||
|
AuditLog.api_type == 'recognition',
|
||||||
|
AuditLog.called_at >= today_start
|
||||||
|
)
|
||||||
|
.order_by(AuditLog.called_at.desc(), AuditLog.id.desc())
|
||||||
|
.offset((page - 1) * size)
|
||||||
|
.limit(size)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
items = result.fetchall()
|
||||||
|
|
||||||
|
# 获取当天总数
|
||||||
|
count_stmt = (
|
||||||
|
select(func.count(func.distinct(AuditLog.image_id)))
|
||||||
|
.where(
|
||||||
|
AuditLog.user_id == user_id,
|
||||||
|
AuditLog.api_type == 'recognition',
|
||||||
|
AuditLog.called_at >= today_start
|
||||||
|
)
|
||||||
|
)
|
||||||
|
total_result = await db.execute(count_stmt)
|
||||||
|
total = total_result.scalar()
|
||||||
|
|
||||||
|
return items, total
|
||||||
|
|
||||||
|
async def get_user_recognition_statistics(self, db: AsyncSession, user_id: int):
|
||||||
|
"""
|
||||||
|
统计用户 recognition 类型的使用记录
|
||||||
|
返回历史总量和当天总量
|
||||||
|
|
||||||
|
:param db: 数据库会话
|
||||||
|
:param user_id: 用户ID
|
||||||
|
:return: (历史总量, 当天总量)
|
||||||
|
"""
|
||||||
|
# 获取历史总量
|
||||||
|
total_stmt = (
|
||||||
|
select(func.count())
|
||||||
|
.where(AuditLog.user_id == user_id, AuditLog.api_type == 'recognition')
|
||||||
|
)
|
||||||
|
total_result = await db.execute(total_stmt)
|
||||||
|
total_count = total_result.scalar()
|
||||||
|
|
||||||
|
# 获取当天总量
|
||||||
|
today = date.today()
|
||||||
|
today_start = datetime(today.year, today.month, today.day)
|
||||||
|
|
||||||
|
today_stmt = (
|
||||||
|
select(func.count())
|
||||||
|
.where(
|
||||||
|
AuditLog.user_id == user_id,
|
||||||
|
AuditLog.api_type == 'recognition',
|
||||||
|
AuditLog.called_at >= today_start
|
||||||
|
)
|
||||||
|
)
|
||||||
|
today_result = await db.execute(today_stmt)
|
||||||
|
today_count = today_result.scalar()
|
||||||
|
|
||||||
|
# 获取总数
|
||||||
|
count_stmt = (
|
||||||
|
select(func.count(func.distinct(AuditLog.image_id)))
|
||||||
|
.where(AuditLog.user_id == user_id, AuditLog.api_type == 'recognition')
|
||||||
|
)
|
||||||
|
count_stmt = (
|
||||||
|
select(func.count(func.distinct(AuditLog.image_id)))
|
||||||
|
.where(AuditLog.user_id == user_id, AuditLog.api_type == 'recognition')
|
||||||
|
)
|
||||||
|
count_result = await db.execute(count_stmt)
|
||||||
|
image_count = count_result.scalar()
|
||||||
|
|
||||||
|
return total_count, today_count, image_count
|
||||||
|
|
||||||
|
|
||||||
|
audit_log_dao: CRUDAuditLog = CRUDAuditLog(AuditLog)
|
||||||
105
backend/app/admin/crud/coupon_crud.py
Executable file
105
backend/app/admin/crud/coupon_crud.py
Executable file
@@ -0,0 +1,105 @@
|
|||||||
|
from typing import Optional, List
|
||||||
|
from sqlalchemy import select, and_, update
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy_crud_plus import CRUDPlus
|
||||||
|
from backend.app.admin.model.coupon import Coupon, CouponUsage
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class CouponDao(CRUDPlus[Coupon]):
|
||||||
|
|
||||||
|
async def get_by_code(self, db: AsyncSession, code: str) -> Optional[Coupon]:
|
||||||
|
"""
|
||||||
|
根据兑换码获取兑换券
|
||||||
|
"""
|
||||||
|
stmt = select(Coupon).where(Coupon.code == code)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_unused_coupon_by_code(self, db: AsyncSession, code: str) -> Optional[Coupon]:
|
||||||
|
"""
|
||||||
|
根据兑换码获取未使用的兑换券
|
||||||
|
"""
|
||||||
|
stmt = select(Coupon).where(
|
||||||
|
and_(
|
||||||
|
Coupon.code == code,
|
||||||
|
Coupon.is_used == False,
|
||||||
|
(Coupon.expires_at.is_(None)) | (Coupon.expires_at > datetime.now())
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def create_coupon(self, db: AsyncSession, coupon_data: dict) -> Coupon:
|
||||||
|
"""
|
||||||
|
创建兑换券
|
||||||
|
"""
|
||||||
|
coupon = Coupon(**coupon_data)
|
||||||
|
db.add(coupon)
|
||||||
|
await db.flush()
|
||||||
|
return coupon
|
||||||
|
|
||||||
|
async def create_coupons(self, db: AsyncSession, coupons_data: List[dict]) -> List[Coupon]:
|
||||||
|
"""
|
||||||
|
批量创建兑换券
|
||||||
|
"""
|
||||||
|
coupons = [Coupon(**data) for data in coupons_data]
|
||||||
|
db.add_all(coupons)
|
||||||
|
await db.flush()
|
||||||
|
return coupons
|
||||||
|
|
||||||
|
async def mark_as_used(self, db: AsyncSession, coupon_id: int, user_id: int, duration: int) -> bool:
|
||||||
|
"""
|
||||||
|
标记兑换券为已使用并创建使用记录
|
||||||
|
"""
|
||||||
|
# 更新兑换券状态
|
||||||
|
stmt = update(Coupon).where(Coupon.id == coupon_id).values(is_used=True)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
|
||||||
|
if result.rowcount == 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 创建使用记录
|
||||||
|
usage = CouponUsage(
|
||||||
|
coupon_id=coupon_id,
|
||||||
|
user_id=user_id,
|
||||||
|
duration=duration
|
||||||
|
)
|
||||||
|
db.add(usage)
|
||||||
|
await db.flush()
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def get_user_coupons(self, db: AsyncSession, user_id: int, limit: int = 100) -> List[CouponUsage]:
|
||||||
|
"""
|
||||||
|
获取用户使用的兑换券记录
|
||||||
|
"""
|
||||||
|
stmt = select(CouponUsage).where(
|
||||||
|
CouponUsage.user_id == user_id
|
||||||
|
).order_by(CouponUsage.used_at.desc()).limit(limit)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
class CouponUsageDao(CRUDPlus[CouponUsage]):
|
||||||
|
|
||||||
|
async def get_by_user_id(self, db: AsyncSession, user_id: int, limit: int = 100) -> List[CouponUsage]:
|
||||||
|
"""
|
||||||
|
根据用户ID获取兑换记录
|
||||||
|
"""
|
||||||
|
stmt = select(CouponUsage).where(
|
||||||
|
CouponUsage.user_id == user_id
|
||||||
|
).order_by(CouponUsage.used_at.desc()).limit(limit)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def get_by_coupon_id(self, db: AsyncSession, coupon_id: int) -> Optional[CouponUsage]:
|
||||||
|
"""
|
||||||
|
根据兑换券ID获取使用记录
|
||||||
|
"""
|
||||||
|
stmt = select(CouponUsage).where(CouponUsage.coupon_id == coupon_id)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
coupon_dao = CouponDao(Coupon)
|
||||||
|
coupon_usage_dao = CouponUsageDao(CouponUsage)
|
||||||
119
backend/app/admin/crud/crud_data_scope.py
Normal file
119
backend/app/admin/crud/crud_data_scope.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
|
from sqlalchemy import Select, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy_crud_plus import CRUDPlus
|
||||||
|
|
||||||
|
from backend.app.admin.model import DataRule, DataScope
|
||||||
|
from backend.app.admin.schema.data_scope import CreateDataScopeParam, UpdateDataScopeParam, UpdateDataScopeRuleParam
|
||||||
|
|
||||||
|
|
||||||
|
class CRUDDataScope(CRUDPlus[DataScope]):
|
||||||
|
"""数据范围数据库操作类"""
|
||||||
|
|
||||||
|
async def get(self, db: AsyncSession, pk: int) -> DataScope | None:
|
||||||
|
"""
|
||||||
|
获取数据范围详情
|
||||||
|
|
||||||
|
:param db: 数据库会话
|
||||||
|
:param pk: 范围 ID
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return await self.select_model(db, pk)
|
||||||
|
|
||||||
|
async def get_by_name(self, db: AsyncSession, name: str) -> DataScope | None:
|
||||||
|
"""
|
||||||
|
通过名称获取数据范围
|
||||||
|
|
||||||
|
:param db: 数据库会话
|
||||||
|
:param name: 范围名称
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return await self.select_model_by_column(db, name=name)
|
||||||
|
|
||||||
|
async def get_with_relation(self, db: AsyncSession, pk: int) -> DataScope:
|
||||||
|
"""
|
||||||
|
获取数据范围关联数据
|
||||||
|
|
||||||
|
:param db: 数据库会话
|
||||||
|
:param pk: 范围 ID
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return await self.select_model(db, pk, load_strategies=['rules'])
|
||||||
|
|
||||||
|
async def get_all(self, db: AsyncSession) -> Sequence[DataScope]:
|
||||||
|
"""
|
||||||
|
获取所有数据范围
|
||||||
|
|
||||||
|
:param db: 数据库会话
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return await self.select_models(db)
|
||||||
|
|
||||||
|
async def get_list(self, name: str | None, status: int | None) -> Select:
|
||||||
|
"""
|
||||||
|
获取数据范围列表
|
||||||
|
|
||||||
|
:param name: 范围名称
|
||||||
|
:param status: 范围状态
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
filters = {}
|
||||||
|
|
||||||
|
if name is not None:
|
||||||
|
filters['name__like'] = f'%{name}%'
|
||||||
|
if status is not None:
|
||||||
|
filters['status'] = status
|
||||||
|
|
||||||
|
return await self.select_order('id', load_strategies={'rules': 'noload', 'roles': 'noload'}, **filters)
|
||||||
|
|
||||||
|
async def create(self, db: AsyncSession, obj: CreateDataScopeParam) -> None:
|
||||||
|
"""
|
||||||
|
创建数据范围
|
||||||
|
|
||||||
|
:param db: 数据库会话
|
||||||
|
:param obj: 创建数据范围参数
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
await self.create_model(db, obj)
|
||||||
|
|
||||||
|
async def update(self, db: AsyncSession, pk: int, obj: UpdateDataScopeParam) -> int:
|
||||||
|
"""
|
||||||
|
更新数据范围
|
||||||
|
|
||||||
|
:param db: 数据库会话
|
||||||
|
:param pk: 范围 ID
|
||||||
|
:param obj: 更新数据范围参数
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return await self.update_model(db, pk, obj)
|
||||||
|
|
||||||
|
async def update_rules(self, db: AsyncSession, pk: int, rule_ids: UpdateDataScopeRuleParam) -> int:
|
||||||
|
"""
|
||||||
|
更新数据范围规则
|
||||||
|
|
||||||
|
:param db: 数据库会话
|
||||||
|
:param pk: 范围 ID
|
||||||
|
:param rule_ids: 数据规则 ID 列表
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
current_data_scope = await self.get_with_relation(db, pk)
|
||||||
|
stmt = select(DataRule).where(DataRule.id.in_(rule_ids.rules))
|
||||||
|
rules = await db.execute(stmt)
|
||||||
|
current_data_scope.rules = rules.scalars().all()
|
||||||
|
return len(current_data_scope.rules)
|
||||||
|
|
||||||
|
async def delete(self, db: AsyncSession, pks: list[int]) -> int:
|
||||||
|
"""
|
||||||
|
批量删除数据范围
|
||||||
|
|
||||||
|
:param db: 数据库会话
|
||||||
|
:param pks: 范围 ID 列表
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return await self.delete_model_by_column(db, allow_multiple=True, id__in=pks)
|
||||||
|
|
||||||
|
|
||||||
|
data_scope_dao: CRUDDataScope = CRUDDataScope(DataScope)
|
||||||
46
backend/app/admin/crud/daily_summary_crud.py
Normal file
46
backend/app/admin/crud/daily_summary_crud.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy_crud_plus import CRUDPlus
|
||||||
|
from sqlalchemy import select, func, desc
|
||||||
|
from datetime import datetime, date
|
||||||
|
|
||||||
|
from backend.app.admin.model.audit_log import DailySummary
|
||||||
|
from backend.app.admin.schema.audit_log import DailySummarySchema
|
||||||
|
|
||||||
|
|
||||||
|
class DailySummaryCRUD(CRUDPlus[DailySummary]):
|
||||||
|
""" Daily Summary CRUD """
|
||||||
|
|
||||||
|
async def get_user_daily_summaries(self, db: AsyncSession, user_id: int, page: int = 1, size: int = 20):
|
||||||
|
"""
|
||||||
|
获取用户的每日汇总记录,按创建时间降序排列
|
||||||
|
|
||||||
|
:param db: 数据库会话
|
||||||
|
:param user_id: 用户ID
|
||||||
|
:param page: 页码
|
||||||
|
:param size: 每页数量
|
||||||
|
:return: 查询结果和总数
|
||||||
|
"""
|
||||||
|
# 主查询:获取用户每日汇总记录
|
||||||
|
stmt = (
|
||||||
|
select(DailySummary)
|
||||||
|
.where(DailySummary.user_id == user_id)
|
||||||
|
.order_by(desc(DailySummary.created_time))
|
||||||
|
.offset((page - 1) * size)
|
||||||
|
.limit(size)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
items = result.scalars().all()
|
||||||
|
|
||||||
|
# 获取总数
|
||||||
|
count_stmt = (
|
||||||
|
select(func.count())
|
||||||
|
.where(DailySummary.user_id == user_id)
|
||||||
|
)
|
||||||
|
total_result = await db.execute(count_stmt)
|
||||||
|
total = total_result.scalar()
|
||||||
|
|
||||||
|
return items, total
|
||||||
|
|
||||||
|
|
||||||
|
daily_summary_dao: DailySummaryCRUD = DailySummaryCRUD(DailySummary)
|
||||||
42
backend/app/admin/crud/dict_crud.py
Executable file
42
backend/app/admin/crud/dict_crud.py
Executable file
@@ -0,0 +1,42 @@
|
|||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
from sqlalchemy import select, and_
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
from backend.app.admin.model.dict import DictionaryEntry, DictionaryMedia
|
||||||
|
|
||||||
|
|
||||||
|
class DictCRUD:
|
||||||
|
async def get_by_word(self, db: AsyncSession, word: str) -> Optional[DictionaryEntry]:
|
||||||
|
"""根据单词查询字典条目"""
|
||||||
|
query = select(DictionaryEntry).where(DictionaryEntry.word == word)
|
||||||
|
result = await db.execute(query)
|
||||||
|
return result.scalars().first()
|
||||||
|
|
||||||
|
async def get_by_id(self, db: AsyncSession, entry_id: int) -> Optional[DictionaryEntry]:
|
||||||
|
"""根据ID获取字典条目"""
|
||||||
|
return await db.get(DictionaryEntry, entry_id)
|
||||||
|
|
||||||
|
async def get_media_by_dict_id(self, db: AsyncSession, dict_id: int) -> List[DictionaryMedia]:
|
||||||
|
"""根据字典条目ID获取相关媒体文件"""
|
||||||
|
query = select(DictionaryMedia).where(DictionaryMedia.dict_id == dict_id)
|
||||||
|
result = await db.execute(query)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def get_media_by_filename(self, db: AsyncSession, filename: str) -> Optional[DictionaryMedia]:
|
||||||
|
"""根据文件名获取媒体文件"""
|
||||||
|
query = select(DictionaryMedia).where(DictionaryMedia.file_name == filename)
|
||||||
|
result = await db.execute(query)
|
||||||
|
return result.scalars().first()
|
||||||
|
|
||||||
|
async def search_words(self, db: AsyncSession, word_pattern: str, limit: int = 10) -> List[DictionaryEntry]:
|
||||||
|
"""模糊搜索单词"""
|
||||||
|
query = select(DictionaryEntry).where(
|
||||||
|
DictionaryEntry.word.ilike(f'%{word_pattern}%')
|
||||||
|
).limit(limit)
|
||||||
|
result = await db.execute(query)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
dict_dao = DictCRUD()
|
||||||
119
backend/app/admin/crud/feedback_crud.py
Executable file
119
backend/app/admin/crud/feedback_crud.py
Executable file
@@ -0,0 +1,119 @@
|
|||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
from sqlalchemy import select, func, and_
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from backend.app.admin.model.feedback import Feedback
|
||||||
|
from backend.app.admin.schema.feedback import CreateFeedbackParam, UpdateFeedbackParam
|
||||||
|
|
||||||
|
|
||||||
|
class FeedbackCRUD:
|
||||||
|
async def get(self, db: AsyncSession, feedback_id: int) -> Optional[Feedback]:
|
||||||
|
"""根据ID获取反馈"""
|
||||||
|
return await db.get(Feedback, feedback_id)
|
||||||
|
|
||||||
|
async def get_list(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: Optional[int] = None,
|
||||||
|
status: Optional[str] = None,
|
||||||
|
category: Optional[str] = None,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
offset: Optional[int] = None
|
||||||
|
) -> List[Feedback]:
|
||||||
|
"""获取反馈列表"""
|
||||||
|
query = select(Feedback).order_by(Feedback.created_at.desc())
|
||||||
|
|
||||||
|
# 添加过滤条件
|
||||||
|
filters = []
|
||||||
|
if user_id is not None:
|
||||||
|
filters.append(Feedback.user_id == user_id)
|
||||||
|
if status is not None:
|
||||||
|
filters.append(Feedback.status == status)
|
||||||
|
if category is not None:
|
||||||
|
filters.append(Feedback.category == category)
|
||||||
|
|
||||||
|
if filters:
|
||||||
|
query = query.where(and_(*filters))
|
||||||
|
|
||||||
|
# 添加分页
|
||||||
|
if limit is not None:
|
||||||
|
query = query.limit(limit)
|
||||||
|
if offset is not None:
|
||||||
|
query = query.offset(offset)
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def create(self, db: AsyncSession, user_id: int, obj_in: CreateFeedbackParam) -> Feedback:
|
||||||
|
"""创建反馈"""
|
||||||
|
db_obj = Feedback(
|
||||||
|
user_id=user_id,
|
||||||
|
content=obj_in.content,
|
||||||
|
contact_info=obj_in.contact_info,
|
||||||
|
category=obj_in.category.value if obj_in.category else None,
|
||||||
|
metadata_info=obj_in.metadata_info
|
||||||
|
)
|
||||||
|
db.add(db_obj)
|
||||||
|
await db.flush()
|
||||||
|
return db_obj
|
||||||
|
|
||||||
|
async def update(self, db: AsyncSession, feedback_id: int, obj_in: UpdateFeedbackParam) -> int:
|
||||||
|
"""更新反馈"""
|
||||||
|
query = select(Feedback).where(Feedback.id == feedback_id)
|
||||||
|
result = await db.execute(query)
|
||||||
|
db_obj = result.scalars().first()
|
||||||
|
|
||||||
|
if db_obj:
|
||||||
|
update_data = obj_in.model_dump(exclude_unset=True)
|
||||||
|
# 处理枚举类型
|
||||||
|
if 'category' in update_data and update_data['category']:
|
||||||
|
update_data['category'] = update_data['category'].value
|
||||||
|
if 'status' in update_data and update_data['status']:
|
||||||
|
update_data['status'] = update_data['status'].value
|
||||||
|
|
||||||
|
for field, value in update_data.items():
|
||||||
|
setattr(db_obj, field, value)
|
||||||
|
await db.flush()
|
||||||
|
return 1
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def delete(self, db: AsyncSession, feedback_id: int) -> int:
|
||||||
|
"""删除反馈"""
|
||||||
|
query = select(Feedback).where(Feedback.id == feedback_id)
|
||||||
|
result = await db.execute(query)
|
||||||
|
db_obj = result.scalars().first()
|
||||||
|
|
||||||
|
if db_obj:
|
||||||
|
await db.delete(db_obj)
|
||||||
|
await db.flush()
|
||||||
|
return 1
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def count(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: Optional[int] = None,
|
||||||
|
status: Optional[str] = None,
|
||||||
|
category: Optional[str] = None
|
||||||
|
) -> int:
|
||||||
|
"""统计反馈数量"""
|
||||||
|
query = select(func.count(Feedback.id))
|
||||||
|
|
||||||
|
# 添加过滤条件
|
||||||
|
filters = []
|
||||||
|
if user_id is not None:
|
||||||
|
filters.append(Feedback.user_id == user_id)
|
||||||
|
if status is not None:
|
||||||
|
filters.append(Feedback.status == status)
|
||||||
|
if category is not None:
|
||||||
|
filters.append(Feedback.category == category)
|
||||||
|
|
||||||
|
if filters:
|
||||||
|
query = query.where(and_(*filters))
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
return result.scalar_one()
|
||||||
|
|
||||||
|
|
||||||
|
feedback_dao = FeedbackCRUD()
|
||||||
60
backend/app/admin/crud/file_crud.py
Executable file
60
backend/app/admin/crud/file_crud.py
Executable file
@@ -0,0 +1,60 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import select, func
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from backend.app.admin.model.file import File
|
||||||
|
from backend.app.admin.schema.file import AddFileParam, UpdateFileParam
|
||||||
|
|
||||||
|
|
||||||
|
class FileCRUD:
|
||||||
|
async def get(self, db: AsyncSession, file_id: int) -> Optional[File]:
|
||||||
|
"""根据ID获取文件"""
|
||||||
|
return await db.get(File, file_id)
|
||||||
|
|
||||||
|
async def get_by_hash(self, db: AsyncSession, file_hash: str) -> Optional[File]:
|
||||||
|
"""根据哈希值获取文件"""
|
||||||
|
query = select(File).where(File.file_hash == file_hash)
|
||||||
|
result = await db.execute(query)
|
||||||
|
return result.scalars().first()
|
||||||
|
|
||||||
|
async def create(self, db: AsyncSession, obj_in: AddFileParam) -> File:
|
||||||
|
"""创建文件记录"""
|
||||||
|
db_obj = File(**obj_in.model_dump())
|
||||||
|
db.add(db_obj)
|
||||||
|
await db.flush()
|
||||||
|
return db_obj
|
||||||
|
|
||||||
|
async def update(self, db: AsyncSession, file_id: int, obj_in: UpdateFileParam) -> int:
|
||||||
|
"""更新文件记录"""
|
||||||
|
query = select(File).where(File.id == file_id)
|
||||||
|
result = await db.execute(query)
|
||||||
|
db_obj = result.scalars().first()
|
||||||
|
|
||||||
|
if db_obj:
|
||||||
|
for field, value in obj_in.model_dump(exclude_unset=True).items():
|
||||||
|
setattr(db_obj, field, value)
|
||||||
|
await db.flush()
|
||||||
|
return 1
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def delete(self, db: AsyncSession, file_id: int) -> int:
|
||||||
|
"""删除文件记录"""
|
||||||
|
query = select(File).where(File.id == file_id)
|
||||||
|
result = await db.execute(query)
|
||||||
|
db_obj = result.scalars().first()
|
||||||
|
|
||||||
|
if db_obj:
|
||||||
|
await db.delete(db_obj)
|
||||||
|
await db.flush()
|
||||||
|
return 1
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def count(self, db: AsyncSession) -> int:
|
||||||
|
"""统计文件数量"""
|
||||||
|
query = select(func.count(File.id))
|
||||||
|
result = await db.execute(query)
|
||||||
|
return result.scalar_one()
|
||||||
|
|
||||||
|
|
||||||
|
file_dao = FileCRUD()
|
||||||
38
backend/app/admin/crud/freeze_log_crud.py
Executable file
38
backend/app/admin/crud/freeze_log_crud.py
Executable file
@@ -0,0 +1,38 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy_crud_plus import CRUDPlus
|
||||||
|
from backend.app.admin.model.order import FreezeLog
|
||||||
|
|
||||||
|
|
||||||
|
class FreezeLogDao(CRUDPlus[FreezeLog]):
|
||||||
|
|
||||||
|
async def get_by_id(self, db: AsyncSession, freeze_id: int) -> Optional[FreezeLog]:
|
||||||
|
"""
|
||||||
|
根据ID获取冻结记录
|
||||||
|
"""
|
||||||
|
stmt = select(FreezeLog).where(FreezeLog.id == freeze_id)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_by_order_id(self, db: AsyncSession, order_id: int) -> Optional[FreezeLog]:
|
||||||
|
"""
|
||||||
|
根据订单ID获取冻结记录
|
||||||
|
"""
|
||||||
|
stmt = select(FreezeLog).where(FreezeLog.order_id == order_id)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_pending_by_user(self, db: AsyncSession, user_id: int) -> list[FreezeLog]:
|
||||||
|
"""
|
||||||
|
获取用户所有待处理的冻结记录
|
||||||
|
"""
|
||||||
|
stmt = select(FreezeLog).where(
|
||||||
|
FreezeLog.user_id == user_id,
|
||||||
|
FreezeLog.status == "pending"
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
freeze_log_dao = FreezeLogDao(FreezeLog)
|
||||||
135
backend/app/admin/crud/notification_crud.py
Executable file
135
backend/app/admin/crud/notification_crud.py
Executable file
@@ -0,0 +1,135 @@
|
|||||||
|
from typing import Optional, List
|
||||||
|
from sqlalchemy import select, and_, update, func
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy_crud_plus import CRUDPlus
|
||||||
|
from backend.app.admin.model.notification import Notification, UserNotification
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class NotificationDao(CRUDPlus[Notification]):
|
||||||
|
|
||||||
|
async def create_notification(self, db: AsyncSession, notification_data: dict) -> Notification:
|
||||||
|
"""
|
||||||
|
创建消息通知
|
||||||
|
"""
|
||||||
|
notification = Notification(**notification_data)
|
||||||
|
db.add(notification)
|
||||||
|
await db.flush()
|
||||||
|
return notification
|
||||||
|
|
||||||
|
async def get_active_notifications(self, db: AsyncSession, limit: int = 100) -> List[Notification]:
|
||||||
|
"""
|
||||||
|
获取激活的通知列表
|
||||||
|
"""
|
||||||
|
stmt = select(Notification).where(
|
||||||
|
Notification.is_active == True
|
||||||
|
).order_by(Notification.created_at.desc()).limit(limit)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def get_notification_by_id(self, db: AsyncSession, notification_id: int) -> Optional[Notification]:
|
||||||
|
"""
|
||||||
|
根据ID获取通知
|
||||||
|
"""
|
||||||
|
stmt = select(Notification).where(Notification.id == notification_id)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
class UserNotificationDao(CRUDPlus[UserNotification]):
|
||||||
|
|
||||||
|
async def create_user_notification(self, db: AsyncSession, user_notification_data: dict) -> UserNotification:
|
||||||
|
"""
|
||||||
|
创建用户通知关联
|
||||||
|
"""
|
||||||
|
user_notification = UserNotification(**user_notification_data)
|
||||||
|
db.add(user_notification)
|
||||||
|
await db.flush()
|
||||||
|
return user_notification
|
||||||
|
|
||||||
|
async def create_user_notifications(self, db: AsyncSession, user_notifications_data: List[dict]) -> List[UserNotification]:
|
||||||
|
"""
|
||||||
|
批量创建用户通知关联
|
||||||
|
"""
|
||||||
|
user_notifications = [UserNotification(**data) for data in user_notifications_data]
|
||||||
|
db.add_all(user_notifications)
|
||||||
|
await db.flush()
|
||||||
|
return user_notifications
|
||||||
|
|
||||||
|
async def get_user_notifications(self, db: AsyncSession, user_id: int, limit: int = 100) -> List[UserNotification]:
|
||||||
|
"""
|
||||||
|
获取用户的通知列表
|
||||||
|
"""
|
||||||
|
stmt = select(UserNotification).where(
|
||||||
|
UserNotification.user_id == user_id
|
||||||
|
).order_by(UserNotification.received_at.desc()).limit(limit)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def get_unread_notifications(self, db: AsyncSession, user_id: int, limit: int = 100) -> List[UserNotification]:
|
||||||
|
"""
|
||||||
|
获取用户未读通知列表
|
||||||
|
"""
|
||||||
|
stmt = select(UserNotification).where(
|
||||||
|
and_(
|
||||||
|
UserNotification.user_id == user_id,
|
||||||
|
UserNotification.is_read == False
|
||||||
|
)
|
||||||
|
).order_by(UserNotification.received_at.desc()).limit(limit)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def mark_as_read(self, db: AsyncSession, user_notification_id: int) -> bool:
|
||||||
|
"""
|
||||||
|
标记通知为已读
|
||||||
|
"""
|
||||||
|
stmt = update(UserNotification).where(
|
||||||
|
UserNotification.id == user_notification_id
|
||||||
|
).values(
|
||||||
|
is_read=True,
|
||||||
|
read_at=datetime.now()
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
async def mark_multiple_as_read(self, db: AsyncSession, user_id: int, notification_ids: List[int]) -> int:
|
||||||
|
"""
|
||||||
|
批量标记通知为已读
|
||||||
|
"""
|
||||||
|
stmt = update(UserNotification).where(
|
||||||
|
and_(
|
||||||
|
UserNotification.user_id == user_id,
|
||||||
|
UserNotification.notification_id.in_(notification_ids),
|
||||||
|
UserNotification.is_read == False
|
||||||
|
)
|
||||||
|
).values(
|
||||||
|
is_read=True,
|
||||||
|
read_at=datetime.now()
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.rowcount
|
||||||
|
|
||||||
|
async def get_unread_count(self, db: AsyncSession, user_id: int) -> int:
|
||||||
|
"""
|
||||||
|
获取用户未读通知数量
|
||||||
|
"""
|
||||||
|
stmt = select(func.count(UserNotification.id)).where(
|
||||||
|
and_(
|
||||||
|
UserNotification.user_id == user_id,
|
||||||
|
UserNotification.is_read == False
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar() or 0
|
||||||
|
|
||||||
|
async def get_user_notification_by_id(self, db: AsyncSession, user_notification_id: int) -> Optional[UserNotification]:
|
||||||
|
"""
|
||||||
|
根据ID获取用户通知关联记录
|
||||||
|
"""
|
||||||
|
stmt = select(UserNotification).where(UserNotification.id == user_notification_id)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
notification_dao = NotificationDao(Notification)
|
||||||
|
user_notification_dao = UserNotificationDao(UserNotification)
|
||||||
94
backend/app/admin/crud/order_crud.py
Executable file
94
backend/app/admin/crud/order_crud.py
Executable file
@@ -0,0 +1,94 @@
|
|||||||
|
from typing import Optional, List
|
||||||
|
from sqlalchemy import select, and_
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy_crud_plus import CRUDPlus
|
||||||
|
from backend.app.admin.model.order import Order
|
||||||
|
|
||||||
|
|
||||||
|
class OrderDao(CRUDPlus[Order]):
|
||||||
|
|
||||||
|
async def get_by_id(self, db: AsyncSession, order_id: int) -> Optional[Order]:
|
||||||
|
"""
|
||||||
|
根据ID获取订单
|
||||||
|
"""
|
||||||
|
stmt = select(Order).where(Order.id == order_id)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_by_user_id(self, db: AsyncSession, user_id: int, limit: int = 100) -> List[Order]:
|
||||||
|
"""
|
||||||
|
根据用户ID获取订单列表
|
||||||
|
"""
|
||||||
|
stmt = select(Order).where(
|
||||||
|
Order.user_id == user_id
|
||||||
|
).order_by(Order.created_at.desc()).limit(limit)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def get_pending_orders(self, db: AsyncSession, user_id: int) -> List[Order]:
|
||||||
|
"""
|
||||||
|
获取用户所有待处理订单
|
||||||
|
"""
|
||||||
|
stmt = select(Order).where(
|
||||||
|
and_(
|
||||||
|
Order.user_id == user_id,
|
||||||
|
Order.status == "pending"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def get_completed_orders(self, db: AsyncSession, user_id: int, limit: int = 50) -> List[Order]:
|
||||||
|
"""
|
||||||
|
获取用户已完成的订单
|
||||||
|
"""
|
||||||
|
stmt = select(Order).where(
|
||||||
|
and_(
|
||||||
|
Order.user_id == user_id,
|
||||||
|
Order.status == "completed"
|
||||||
|
)
|
||||||
|
).order_by(Order.created_at.desc()).limit(limit)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def get_order_by_payment_id(self, db: AsyncSession, payment_id: str) -> Optional[Order]:
|
||||||
|
"""
|
||||||
|
根据支付ID获取订单
|
||||||
|
"""
|
||||||
|
stmt = select(Order).where(Order.payment_id == payment_id)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_order_by_transaction_id(self, db: AsyncSession, transaction_id: str) -> Optional[Order]:
|
||||||
|
"""
|
||||||
|
根据交易ID获取订单
|
||||||
|
"""
|
||||||
|
stmt = select(Order).where(Order.transaction_id == transaction_id)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def update_order_status(self, db: AsyncSession, order_id: int, status: str) -> bool:
|
||||||
|
"""
|
||||||
|
更新订单状态
|
||||||
|
"""
|
||||||
|
from sqlalchemy import update
|
||||||
|
|
||||||
|
stmt = update(Order).where(Order.id == order_id).values(status=status)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
async def get_subscription_orders(self, db: AsyncSession, user_id: int) -> List[Order]:
|
||||||
|
"""
|
||||||
|
获取用户所有订阅订单
|
||||||
|
"""
|
||||||
|
stmt = select(Order).where(
|
||||||
|
and_(
|
||||||
|
Order.user_id == user_id,
|
||||||
|
Order.order_type == "subscription"
|
||||||
|
)
|
||||||
|
).order_by(Order.created_at.desc())
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
order_dao = OrderDao(Order)
|
||||||
103
backend/app/admin/crud/points_crud.py
Normal file
103
backend/app/admin/crud/points_crud.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from sqlalchemy import select, update, func
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy_crud_plus import CRUDPlus
|
||||||
|
from backend.app.admin.model.points import Points, PointsLog
|
||||||
|
|
||||||
|
|
||||||
|
class PointsDao(CRUDPlus[Points]):
|
||||||
|
async def get_by_user_id(self, db: AsyncSession, user_id: int) -> Optional[Points]:
|
||||||
|
"""
|
||||||
|
根据用户ID获取积分账户信息
|
||||||
|
"""
|
||||||
|
stmt = select(Points).where(Points.user_id == user_id)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def create_user_points(self, db: AsyncSession, user_id: int) -> Points:
|
||||||
|
"""
|
||||||
|
为用户创建积分账户
|
||||||
|
"""
|
||||||
|
points = Points(user_id=user_id)
|
||||||
|
db.add(points)
|
||||||
|
await db.flush()
|
||||||
|
return points
|
||||||
|
|
||||||
|
async def add_points_atomic(self, db: AsyncSession, user_id: int, amount: int, extend_expiration: bool = False) -> bool:
|
||||||
|
"""
|
||||||
|
原子性增加用户积分
|
||||||
|
"""
|
||||||
|
# 先确保用户有积分账户
|
||||||
|
points_account = await self.get_by_user_id(db, user_id)
|
||||||
|
if not points_account:
|
||||||
|
points_account = await self.create_user_points(db, user_id)
|
||||||
|
|
||||||
|
# 准备更新值
|
||||||
|
update_values = {
|
||||||
|
"balance": Points.balance + amount,
|
||||||
|
"total_earned": Points.total_earned + amount
|
||||||
|
}
|
||||||
|
|
||||||
|
# 如果需要延期,则更新过期时间
|
||||||
|
if extend_expiration:
|
||||||
|
update_values["expired_time"] = datetime.now() + timedelta(days=30)
|
||||||
|
|
||||||
|
stmt = update(Points).where(
|
||||||
|
Points.user_id == user_id
|
||||||
|
).values(**update_values)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
async def deduct_points_atomic(self, db: AsyncSession, user_id: int, amount: int) -> bool:
|
||||||
|
"""
|
||||||
|
原子性扣减用户积分(确保不超扣)
|
||||||
|
"""
|
||||||
|
stmt = update(Points).where(
|
||||||
|
Points.user_id == user_id,
|
||||||
|
Points.balance >= amount
|
||||||
|
).values(
|
||||||
|
balance=Points.balance - amount,
|
||||||
|
total_spent=Points.total_spent + amount
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
async def get_balance(self, db: AsyncSession, user_id: int) -> int:
|
||||||
|
"""
|
||||||
|
获取用户积分余额
|
||||||
|
"""
|
||||||
|
points_account = await self.get_by_user_id(db, user_id)
|
||||||
|
if not points_account:
|
||||||
|
return 0
|
||||||
|
return points_account.balance
|
||||||
|
|
||||||
|
async def check_and_clear_expired_points(self, db: AsyncSession, user_id: int) -> bool:
|
||||||
|
"""
|
||||||
|
检查并清空过期积分
|
||||||
|
"""
|
||||||
|
stmt = update(Points).where(
|
||||||
|
Points.user_id == user_id,
|
||||||
|
Points.expired_time < datetime.now(),
|
||||||
|
Points.balance > 0
|
||||||
|
).values(
|
||||||
|
balance=0,
|
||||||
|
total_spent=Points.total_spent + Points.balance
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
|
||||||
|
class PointsLogDao(CRUDPlus[PointsLog]):
|
||||||
|
async def add_log(self, db: AsyncSession, log_data: Dict[str, Any]) -> PointsLog:
|
||||||
|
"""
|
||||||
|
添加积分变动日志
|
||||||
|
"""
|
||||||
|
log = PointsLog(**log_data)
|
||||||
|
db.add(log)
|
||||||
|
await db.flush()
|
||||||
|
return log
|
||||||
|
|
||||||
|
|
||||||
|
points_dao = PointsDao(Points)
|
||||||
|
points_log_dao = PointsLogDao(PointsLog)
|
||||||
61
backend/app/admin/crud/usage_log_crud.py
Executable file
61
backend/app/admin/crud/usage_log_crud.py
Executable file
@@ -0,0 +1,61 @@
|
|||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from sqlalchemy import select, and_
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from backend.app.admin.model.order import UsageLog
|
||||||
|
from sqlalchemy_crud_plus import CRUDPlus
|
||||||
|
|
||||||
|
|
||||||
|
class UsageLogDao(CRUDPlus[UsageLog]):
|
||||||
|
|
||||||
|
async def get_by_id(self, db: AsyncSession, log_id: int) -> Optional[UsageLog]:
|
||||||
|
"""
|
||||||
|
根据ID获取使用日志
|
||||||
|
"""
|
||||||
|
stmt = select(UsageLog).where(UsageLog.id == log_id)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_by_user_id(self, db: AsyncSession, user_id: int, limit: int = 100) -> List[UsageLog]:
|
||||||
|
"""
|
||||||
|
根据用户ID获取使用日志列表
|
||||||
|
"""
|
||||||
|
stmt = select(UsageLog).where(
|
||||||
|
UsageLog.user_id == user_id
|
||||||
|
).order_by(UsageLog.created_at.desc()).limit(limit)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def get_by_action(self, db: AsyncSession, user_id: int, action: str, limit: int = 50) -> List[UsageLog]:
|
||||||
|
"""
|
||||||
|
根据动作类型获取使用日志
|
||||||
|
"""
|
||||||
|
stmt = select(UsageLog).where(
|
||||||
|
and_(
|
||||||
|
UsageLog.user_id == user_id,
|
||||||
|
UsageLog.action == action
|
||||||
|
)
|
||||||
|
).order_by(UsageLog.created_at.desc()).limit(limit)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def get_balance_history(self, db: AsyncSession, user_id: int, limit: int = 100) -> List[UsageLog]:
|
||||||
|
"""
|
||||||
|
获取用户余额变动历史
|
||||||
|
"""
|
||||||
|
stmt = select(UsageLog).where(
|
||||||
|
UsageLog.user_id == user_id
|
||||||
|
).order_by(UsageLog.created_at.desc()).limit(limit)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def add_log(self, db: AsyncSession, log_data: Dict[str, Any]) -> UsageLog:
|
||||||
|
"""
|
||||||
|
添加使用日志
|
||||||
|
"""
|
||||||
|
log = UsageLog(**log_data)
|
||||||
|
db.add(log)
|
||||||
|
await db.flush()
|
||||||
|
return log
|
||||||
|
|
||||||
|
|
||||||
|
usage_log_dao = UsageLogDao(UsageLog)
|
||||||
105
backend/app/admin/crud/user_account_crud.py
Executable file
105
backend/app/admin/crud/user_account_crud.py
Executable file
@@ -0,0 +1,105 @@
|
|||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Optional
|
||||||
|
from sqlalchemy import select, update, func
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy_crud_plus import CRUDPlus
|
||||||
|
from backend.app.admin.model.order import UserAccount, FreezeLog
|
||||||
|
|
||||||
|
|
||||||
|
class UserAccountDao(CRUDPlus[UserAccount]):
|
||||||
|
|
||||||
|
async def get_by_user_id(self, db: AsyncSession, user_id: int) -> Optional[UserAccount]:
|
||||||
|
"""
|
||||||
|
根据用户ID获取账户信息
|
||||||
|
"""
|
||||||
|
stmt = select(UserAccount).where(UserAccount.user_id == user_id)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_by_id(self, db: AsyncSession, account_id: int) -> Optional[UserAccount]:
|
||||||
|
"""
|
||||||
|
根据账户ID获取账户信息
|
||||||
|
"""
|
||||||
|
stmt = select(UserAccount).where(UserAccount.id == account_id)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def create_new_user_account(self, db: AsyncSession, user_id: int) -> UserAccount:
|
||||||
|
"""
|
||||||
|
为新用户创建账户(包含免费试用)
|
||||||
|
"""
|
||||||
|
# 设置免费试用期(3天)
|
||||||
|
trial_expires_at = datetime.now() + timedelta(days=3)
|
||||||
|
|
||||||
|
account = UserAccount(
|
||||||
|
user_id=user_id,
|
||||||
|
balance=30, # 初始30次免费次数
|
||||||
|
free_trial_balance=30,
|
||||||
|
free_trial_expires_at=trial_expires_at,
|
||||||
|
free_trial_used=True # 标记为已使用(因为已经给了30次)
|
||||||
|
)
|
||||||
|
db.add(account)
|
||||||
|
await db.flush()
|
||||||
|
return account
|
||||||
|
|
||||||
|
async def update_balance_atomic(self, db: AsyncSession, user_id: int, amount: int) -> bool:
|
||||||
|
"""
|
||||||
|
原子性更新用户余额
|
||||||
|
"""
|
||||||
|
stmt = update(UserAccount).where(
|
||||||
|
UserAccount.user_id == user_id
|
||||||
|
).values(
|
||||||
|
balance=UserAccount.balance + amount
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
async def deduct_balance_atomic(self, db: AsyncSession, user_id: int, amount: int) -> bool:
|
||||||
|
"""
|
||||||
|
原子性扣减用户余额(确保不超扣)
|
||||||
|
"""
|
||||||
|
stmt = update(UserAccount).where(
|
||||||
|
UserAccount.user_id == user_id,
|
||||||
|
UserAccount.balance >= amount
|
||||||
|
).values(
|
||||||
|
balance=UserAccount.balance - amount
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
async def get_frozen_balance(self, db: AsyncSession, user_id: int) -> int:
|
||||||
|
"""
|
||||||
|
获取用户被冻结的次数
|
||||||
|
"""
|
||||||
|
|
||||||
|
stmt = select(func.sum(FreezeLog.amount)).where(
|
||||||
|
FreezeLog.user_id == user_id,
|
||||||
|
FreezeLog.status == "pending"
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar() or 0
|
||||||
|
|
||||||
|
async def get_available_balance(self, db: AsyncSession, user_id: int) -> int:
|
||||||
|
"""
|
||||||
|
获取用户可用余额(总余额减去冻结余额)
|
||||||
|
"""
|
||||||
|
account = await self.get_by_user_id(db, user_id)
|
||||||
|
if not account:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
frozen_balance = await self.get_frozen_balance(db, user_id)
|
||||||
|
return max(0, account.balance - frozen_balance)
|
||||||
|
|
||||||
|
async def check_free_trial_valid(self, db: AsyncSession, user_id: int) -> bool:
|
||||||
|
"""
|
||||||
|
检查用户免费试用是否仍然有效
|
||||||
|
"""
|
||||||
|
account = await self.get_by_user_id(db, user_id)
|
||||||
|
if not account or not account.free_trial_expires_at:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return (account.free_trial_expires_at > datetime.now() and
|
||||||
|
account.free_trial_balance > 0)
|
||||||
|
|
||||||
|
|
||||||
|
user_account_dao = UserAccountDao(UserAccount)
|
||||||
201
backend/app/admin/crud/wx_user_crud.py
Executable file
201
backend/app/admin/crud/wx_user_crud.py
Executable file
@@ -0,0 +1,201 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import bcrypt
|
||||||
|
from sqlalchemy import select, update, desc, and_
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.sql import Select
|
||||||
|
from sqlalchemy_crud_plus import CRUDPlus
|
||||||
|
|
||||||
|
from backend.app.admin.model import WxUser
|
||||||
|
from backend.app.admin.schema.user import RegisterUserParam, UpdateUserParam, AvatarParam
|
||||||
|
from backend.common.security.jwt import get_hash_password
|
||||||
|
from backend.utils.wx_pay import wx_pay_utils
|
||||||
|
|
||||||
|
|
||||||
|
class WxUserCRUD(CRUDPlus[WxUser]):
|
||||||
|
async def get(self, db: AsyncSession, user_id: int) -> WxUser | None:
|
||||||
|
"""
|
||||||
|
获取用户
|
||||||
|
|
||||||
|
:param db:
|
||||||
|
:param user_id:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return await self.select_model(db, user_id)
|
||||||
|
|
||||||
|
async def get_by_username(self, db: AsyncSession, username: str) -> WxUser | None:
|
||||||
|
"""
|
||||||
|
通过 username 获取用户
|
||||||
|
|
||||||
|
:param db:
|
||||||
|
:param username:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return await self.select_model_by_column(db, username=username)
|
||||||
|
|
||||||
|
async def get_by_openid(self, db: AsyncSession, openid: str) -> WxUser | None:
|
||||||
|
"""
|
||||||
|
通过 username 获取用户
|
||||||
|
|
||||||
|
:param db:
|
||||||
|
:param openid:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return await self.select_model_by_column(db, openid=openid)
|
||||||
|
|
||||||
|
async def add(self, db: AsyncSession, new_user: WxUser) -> None:
|
||||||
|
"""
|
||||||
|
通过 openid 添加用户
|
||||||
|
|
||||||
|
:param db:
|
||||||
|
:param openid:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
db.add(new_user)
|
||||||
|
|
||||||
|
async def update_session_key(self, db: AsyncSession, input_user: int, session_key: str) -> int:
|
||||||
|
"""
|
||||||
|
更新用户头像
|
||||||
|
|
||||||
|
:param db:
|
||||||
|
:param input_user:
|
||||||
|
:param avatar:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return await self.update_model(db, input_user, {'session_key': session_key})
|
||||||
|
|
||||||
|
async def update_user_profile(self, db: AsyncSession, user_id: int, profile: dict) -> int:
|
||||||
|
"""
|
||||||
|
更新用户资料并自动更新updated_time字段
|
||||||
|
|
||||||
|
:param db: 数据库会话
|
||||||
|
:param user_id: 用户ID
|
||||||
|
:param profile: 用户资料
|
||||||
|
:return: 更新的记录数
|
||||||
|
"""
|
||||||
|
return await self.update_model(db, user_id, {'profile': profile})
|
||||||
|
|
||||||
|
async def update_login_time(self, db: AsyncSession, username: str, login_time: datetime) -> int:
|
||||||
|
user = await db.execute(
|
||||||
|
update(self.model).where(self.model.username == username).values(last_login_time=login_time)
|
||||||
|
)
|
||||||
|
return user.rowcount
|
||||||
|
|
||||||
|
async def create(self, db: AsyncSession, obj: RegisterUserParam) -> None:
|
||||||
|
"""
|
||||||
|
创建用户
|
||||||
|
|
||||||
|
:param db:
|
||||||
|
:param obj:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
salt = bcrypt.gensalt()
|
||||||
|
obj.password = get_hash_password(obj.password, salt)
|
||||||
|
dict_obj = obj.model_dump()
|
||||||
|
dict_obj.update({'salt': salt})
|
||||||
|
new_user = self.model(**dict_obj)
|
||||||
|
db.add(new_user)
|
||||||
|
|
||||||
|
async def update_userinfo(self, db: AsyncSession, input_user: int, obj: UpdateUserParam) -> int:
|
||||||
|
"""
|
||||||
|
更新用户信息
|
||||||
|
|
||||||
|
:param db:
|
||||||
|
:param input_user:
|
||||||
|
:param obj:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return await self.update_model(db, input_user, obj)
|
||||||
|
|
||||||
|
async def update_avatar(self, db: AsyncSession, input_user: int, avatar: AvatarParam) -> int:
|
||||||
|
"""
|
||||||
|
更新用户头像
|
||||||
|
|
||||||
|
:param db:
|
||||||
|
:param input_user:
|
||||||
|
:param avatar:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return await self.update_model(db, input_user, {'avatar': avatar.url})
|
||||||
|
|
||||||
|
async def delete(self, db: AsyncSession, user_id: int) -> int:
|
||||||
|
"""
|
||||||
|
删除用户
|
||||||
|
|
||||||
|
:param db:
|
||||||
|
:param user_id:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return await self.delete_model(db, user_id)
|
||||||
|
|
||||||
|
async def check_email(self, db: AsyncSession, email: str) -> WxUser:
|
||||||
|
"""
|
||||||
|
检查邮箱是否存在
|
||||||
|
|
||||||
|
:param db:
|
||||||
|
:param email:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return await self.select_model_by_column(db, email=email)
|
||||||
|
|
||||||
|
async def reset_password(self, db: AsyncSession, pk: int, new_pwd: str) -> int:
|
||||||
|
"""
|
||||||
|
重置用户密码
|
||||||
|
|
||||||
|
:param db:
|
||||||
|
:param pk:
|
||||||
|
:param new_pwd:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return await self.update_model(db, pk, {'password': new_pwd})
|
||||||
|
|
||||||
|
async def get_list(self, username: str = None, phone: str = None, status: int = None) -> Select:
|
||||||
|
"""
|
||||||
|
获取用户列表
|
||||||
|
|
||||||
|
:param username:
|
||||||
|
:param phone:
|
||||||
|
:param status:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
stmt = select(self.model).order_by(desc(self.model.join_time))
|
||||||
|
|
||||||
|
filters = []
|
||||||
|
if username:
|
||||||
|
filters.append(self.model.username.like(f'%{username}%'))
|
||||||
|
if phone:
|
||||||
|
filters.append(self.model.phone.like(f'%{phone}%'))
|
||||||
|
if status is not None:
|
||||||
|
filters.append(self.model.status == status)
|
||||||
|
|
||||||
|
if filters:
|
||||||
|
stmt = stmt.where(and_(*filters))
|
||||||
|
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
async def get_with_relation(
|
||||||
|
self, db: AsyncSession, *, user_id: int | None = None, openid: str | None = None
|
||||||
|
) -> WxUser | None:
|
||||||
|
"""
|
||||||
|
获取用户关联信息
|
||||||
|
|
||||||
|
:param db: 数据库会话
|
||||||
|
:param user_id: 用户 ID
|
||||||
|
:param username: 用户名
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
filters = {}
|
||||||
|
|
||||||
|
if user_id:
|
||||||
|
filters['id'] = user_id
|
||||||
|
if openid:
|
||||||
|
filters['openid'] = openid
|
||||||
|
|
||||||
|
return await self.select_model_by_column(
|
||||||
|
db,
|
||||||
|
**filters,
|
||||||
|
)
|
||||||
|
|
||||||
|
wx_user_dao: WxUserCRUD = WxUserCRUD(WxUser)
|
||||||
12
backend/app/admin/model/__init__.py
Executable file
12
backend/app/admin/model/__init__.py
Executable file
@@ -0,0 +1,12 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from backend.common.model import MappedBase # noqa: I
|
||||||
|
from backend.app.admin.model.wx_user import WxUser
|
||||||
|
from backend.app.admin.model.audit_log import AuditLog, DailySummary
|
||||||
|
from backend.app.admin.model.file import File
|
||||||
|
from backend.app.admin.model.dict import DictionaryEntry, DictionaryMedia
|
||||||
|
from backend.app.admin.model.order import Order, UserAccount, FreezeLog, UsageLog
|
||||||
|
from backend.app.admin.model.coupon import Coupon, CouponUsage
|
||||||
|
from backend.app.admin.model.notification import Notification, UserNotification
|
||||||
|
from backend.app.admin.model.points import Points, PointsLog
|
||||||
|
from backend.app.ai.model import Image
|
||||||
53
backend/app/admin/model/audit_log.py
Executable file
53
backend/app/admin/model/audit_log.py
Executable file
@@ -0,0 +1,53 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
from sqlalchemy import Integer, BigInteger, Text, String, Numeric, Float, DateTime, ForeignKey, Index
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB, ARRAY
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from backend.common.model import snowflake_id_key, Base
|
||||||
|
|
||||||
|
|
||||||
|
class AuditLog(Base):
|
||||||
|
__tablename__ = 'audit_log'
|
||||||
|
|
||||||
|
id: Mapped[snowflake_id_key] = mapped_column(init=False, primary_key=True)
|
||||||
|
api_type: Mapped[str] = mapped_column(String(20), nullable=False, comment="API类型: recognition embedding assessment")
|
||||||
|
model_name: Mapped[str] = mapped_column(String(50), nullable=False, comment="模型名称")
|
||||||
|
request_data: Mapped[Optional[dict]] = mapped_column(JSONB, comment="请求数据")
|
||||||
|
response_data: Mapped[Optional[dict]] = mapped_column(JSONB, comment="响应数据")
|
||||||
|
token_usage: Mapped[Optional[dict]] = mapped_column(JSONB, comment="消耗的token数量")
|
||||||
|
cost: Mapped[Optional[float]] = mapped_column(Numeric(10, 5), comment="API调用成本")
|
||||||
|
duration: Mapped[Optional[float]] = mapped_column(Float, comment="调用耗时(秒)")
|
||||||
|
status_code: Mapped[Optional[int]] = mapped_column(Integer, comment="HTTP状态码")
|
||||||
|
image_id: Mapped[Optional[int]] = mapped_column(BigInteger, ForeignKey('image.id'), comment="关联的图片ID")
|
||||||
|
user_id: Mapped[Optional[int]] = mapped_column(BigInteger, ForeignKey('wx_user.id'), comment="调用用户ID")
|
||||||
|
dict_level: Mapped[Optional[str]] = mapped_column(String(20), comment="dict level")
|
||||||
|
api_version: Mapped[Optional[str]] = mapped_column(String(20), comment="API版本")
|
||||||
|
error_message: Mapped[Optional[str]] = mapped_column(Text, default=None, comment="错误信息")
|
||||||
|
called_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment="调用时间")
|
||||||
|
|
||||||
|
# 索引优化
|
||||||
|
__table_args__ = (
|
||||||
|
Index('idx_audit_logs_image_id', 'image_id'),
|
||||||
|
# 为用户历史记录查询优化的索引
|
||||||
|
Index('idx_audit_log_user_api_called', 'user_id', 'api_type', 'called_at'),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DailySummary(Base):
|
||||||
|
__tablename__ = 'daily_summary'
|
||||||
|
|
||||||
|
id: Mapped[snowflake_id_key] = mapped_column(init=False, primary_key=True)
|
||||||
|
user_id: Mapped[Optional[int]] = mapped_column(BigInteger, ForeignKey('wx_user.id'), comment="调用用户ID")
|
||||||
|
image_ids: Mapped[List[str]] = mapped_column(ARRAY(Text), default=None, comment="图片ID列表")
|
||||||
|
thumbnail_ids: Mapped[List[str]] = mapped_column(ARRAY(Text), default=None, comment="图片缩略图列表")
|
||||||
|
summary_time: Mapped[datetime] = mapped_column(DateTime, default=None, comment="总结的时间")
|
||||||
|
|
||||||
|
# 索引优化
|
||||||
|
__table_args__ = (
|
||||||
|
# 为用户历史记录查询优化的索引
|
||||||
|
Index('idx_daily_summary_api_called', 'user_id', 'summary_time'),
|
||||||
|
)
|
||||||
43
backend/app/admin/model/coupon.py
Executable file
43
backend/app/admin/model/coupon.py
Executable file
@@ -0,0 +1,43 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import String, Column, BigInteger, ForeignKey, Boolean, DateTime, Index, Text
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from backend.common.model import snowflake_id_key, Base
|
||||||
|
|
||||||
|
|
||||||
|
class Coupon(Base):
|
||||||
|
__tablename__ = 'coupon'
|
||||||
|
|
||||||
|
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
|
||||||
|
code: Mapped[str] = mapped_column(String(32), unique=True, nullable=False, comment='兑换码')
|
||||||
|
duration: Mapped[int] = mapped_column(BigInteger, nullable=False, comment='兑换时长(分钟)')
|
||||||
|
is_used: Mapped[bool] = mapped_column(Boolean, default=False, comment='是否已使用')
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment='创建时间')
|
||||||
|
expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=None, comment='过期时间')
|
||||||
|
created_by: Mapped[Optional[int]] = mapped_column(BigInteger, default=None, comment='创建者ID')
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index('idx_coupon_code', 'code'),
|
||||||
|
Index('idx_coupon_is_used', 'is_used'),
|
||||||
|
{'comment': '兑换券表'}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CouponUsage(Base):
|
||||||
|
__tablename__ = 'coupon_usage'
|
||||||
|
|
||||||
|
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
|
||||||
|
coupon_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('coupon.id'), nullable=False, comment='兑换券ID')
|
||||||
|
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), nullable=False, comment='使用者ID')
|
||||||
|
duration: Mapped[int] = mapped_column(BigInteger, nullable=False, comment='兑换时长(分钟)')
|
||||||
|
used_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment='使用时间')
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index('idx_coupon_usage_user', 'user_id'),
|
||||||
|
Index('idx_coupon_usage_coupon', 'coupon_id'),
|
||||||
|
{'comment': '兑换券使用记录表'}
|
||||||
|
)
|
||||||
42
backend/app/admin/model/dict.py
Executable file
42
backend/app/admin/model/dict.py
Executable file
@@ -0,0 +1,42 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import String, Column, LargeBinary, ForeignKey, BigInteger, Index, func, JSON, Text, Numeric
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from backend.app.admin.schema.dict import WordMetaData
|
||||||
|
from backend.app.admin.schema.pydantic_type import PydanticType
|
||||||
|
from backend.common.model import snowflake_id_key, DataClassBase
|
||||||
|
|
||||||
|
|
||||||
|
class DictionaryEntry(DataClassBase):
|
||||||
|
__tablename__ = 'dict_entry'
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True, init=True, autoincrement=True)
|
||||||
|
word: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
|
||||||
|
definition: Mapped[Optional[str]] = mapped_column(Text, default=None)
|
||||||
|
details: Mapped[Optional[WordMetaData]] = mapped_column(PydanticType(pydantic_type=WordMetaData), default=None) # 其他可能的字段(根据实际需求添加)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index('idx_dict_word', word),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class DictionaryMedia(DataClassBase):
|
||||||
|
__tablename__ = 'dict_media'
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True, init=True, autoincrement=True)
|
||||||
|
file_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
file_type: Mapped[str] = mapped_column(String(50), nullable=False) # 'audio', 'image'
|
||||||
|
dict_id: Mapped[Optional[int]] = mapped_column(BigInteger, ForeignKey("dict_entry.id"), default=None)
|
||||||
|
file_data: Mapped[Optional[bytes]] = mapped_column(LargeBinary, default=None)
|
||||||
|
file_hash: Mapped[Optional[str]] = mapped_column(String(64), default=None)
|
||||||
|
details: Mapped[Optional[dict]] = mapped_column(JSONB(astext_type=Text()), default=None, comment="其他信息") # 其他信息
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index('idx_media_filename', file_name),
|
||||||
|
Index('idx_media_dict_id', dict_id),
|
||||||
|
)
|
||||||
29
backend/app/admin/model/feedback.py
Executable file
29
backend/app/admin/model/feedback.py
Executable file
@@ -0,0 +1,29 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
from sqlalchemy import String, Text, DateTime, ForeignKey, Index, BigInteger
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from backend.common.model import snowflake_id_key, Base
|
||||||
|
|
||||||
|
|
||||||
|
class Feedback(Base):
|
||||||
|
__tablename__ = 'feedback'
|
||||||
|
|
||||||
|
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), nullable=False, comment='用户ID')
|
||||||
|
content: Mapped[str] = mapped_column(Text, nullable=False, comment='反馈内容')
|
||||||
|
contact_info: Mapped[Optional[str]] = mapped_column(String(255), nullable=True, comment='联系方式')
|
||||||
|
category: Mapped[Optional[str]] = mapped_column(String(50), nullable=True, comment='反馈分类')
|
||||||
|
status: Mapped[str] = mapped_column(String(20), default='pending', comment='处理状态: pending, processing, resolved')
|
||||||
|
|
||||||
|
# 索引优化
|
||||||
|
__table_args__ = (
|
||||||
|
Index('idx_feedback_user_id', 'user_id'),
|
||||||
|
Index('idx_feedback_status', 'status'),
|
||||||
|
Index('idx_feedback_category', 'category'),
|
||||||
|
Index('idx_feedback_created_at', 'created_time'),
|
||||||
|
)
|
||||||
28
backend/app/admin/model/file.py
Executable file
28
backend/app/admin/model/file.py
Executable file
@@ -0,0 +1,28 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import BigInteger, Text, String, Index, DateTime, LargeBinary
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
from sqlalchemy.orm import mapped_column, Mapped
|
||||||
|
|
||||||
|
from backend.common.model import snowflake_id_key, Base
|
||||||
|
|
||||||
|
|
||||||
|
class File(Base):
|
||||||
|
__tablename__ = 'file'
|
||||||
|
|
||||||
|
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
|
||||||
|
file_hash: Mapped[str] = mapped_column(String(64), index=True, nullable=False) # SHA256哈希
|
||||||
|
file_name: Mapped[str] = mapped_column(String(255), nullable=False) # 原始文件名
|
||||||
|
content_type: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) # MIME类型
|
||||||
|
file_size: Mapped[int] = mapped_column(BigInteger, nullable=False) # 文件大小(字节)
|
||||||
|
storage_path: Mapped[Optional[str]] = mapped_column(Text, nullable=True) # 存储路径(非数据库存储时使用)
|
||||||
|
file_data: Mapped[Optional[bytes]] = mapped_column(LargeBinary, default=None, nullable=True) # 文件二进制数据(数据库存储时使用)
|
||||||
|
storage_type: Mapped[str] = mapped_column(String(20), nullable=False, default='database') # 存储类型: database, local, s3
|
||||||
|
metadata_info: Mapped[Optional[dict]] = mapped_column(JSONB(astext_type=Text()), default=None, comment="元数据信息")
|
||||||
|
|
||||||
|
# 表参数 - 包含所有必要的约束
|
||||||
|
__table_args__ = (
|
||||||
|
Index('idx_file_hash', file_hash),
|
||||||
|
)
|
||||||
45
backend/app/admin/model/notification.py
Executable file
45
backend/app/admin/model/notification.py
Executable file
@@ -0,0 +1,45 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import String, Column, BigInteger, ForeignKey, Boolean, DateTime, Index, Text
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from backend.common.model import snowflake_id_key, Base
|
||||||
|
|
||||||
|
|
||||||
|
class Notification(Base):
|
||||||
|
__tablename__ = 'notification'
|
||||||
|
|
||||||
|
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
|
||||||
|
title: Mapped[str] = mapped_column(String(255), nullable=False, comment='通知标题')
|
||||||
|
content: Mapped[str] = mapped_column(Text, nullable=False, comment='通知内容')
|
||||||
|
image_url: Mapped[Optional[str]] = mapped_column(String(512), default=None, comment='图片URL(预留)')
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment='创建时间')
|
||||||
|
created_by: Mapped[Optional[int]] = mapped_column(BigInteger, default=None, comment='创建者ID')
|
||||||
|
is_active: Mapped[bool] = mapped_column(Boolean, default=True, comment='是否激活')
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index('idx_notification_created', 'created_at'),
|
||||||
|
Index('idx_notification_active', 'is_active'),
|
||||||
|
{'comment': '消息通知表'}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UserNotification(Base):
|
||||||
|
__tablename__ = 'user_notification'
|
||||||
|
|
||||||
|
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
|
||||||
|
notification_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('notification.id'), nullable=False, comment='通知ID')
|
||||||
|
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), nullable=False, comment='用户ID')
|
||||||
|
is_read: Mapped[bool] = mapped_column(Boolean, default=False, comment='是否已读')
|
||||||
|
read_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=None, comment='阅读时间')
|
||||||
|
received_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment='接收时间')
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index('idx_user_notification_user', 'user_id'),
|
||||||
|
Index('idx_user_notification_notification', 'notification_id'),
|
||||||
|
Index('idx_user_notification_read', 'is_read'),
|
||||||
|
{'comment': '用户通知关联表'}
|
||||||
|
)
|
||||||
79
backend/app/admin/model/order.py
Executable file
79
backend/app/admin/model/order.py
Executable file
@@ -0,0 +1,79 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import String, Column, BigInteger, ForeignKey, Boolean, DateTime, Index, func, JSON, Text, Numeric
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from backend.common.model import snowflake_id_key, Base
|
||||||
|
|
||||||
|
|
||||||
|
class Order(Base):
|
||||||
|
__tablename__ = 'order'
|
||||||
|
|
||||||
|
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), nullable=False)
|
||||||
|
order_type: Mapped[str] = mapped_column(String(32), comment='类型:purchase/subscription/extra')
|
||||||
|
payment_id: Mapped[Optional[str]] = mapped_column(String(64), comment='微信支付ID')
|
||||||
|
transaction_id: Mapped[Optional[str]] = mapped_column(String(64), comment='微信交易号')
|
||||||
|
amount_cents: Mapped[int] = mapped_column(BigInteger, comment='金额(分)')
|
||||||
|
amount_times: Mapped[int] = mapped_column(BigInteger, comment='实际获得次数')
|
||||||
|
status: Mapped[str] = mapped_column(String(16), default='pending', comment='订单状态')
|
||||||
|
expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=None, comment='过期时间(用于订阅)')
|
||||||
|
processed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=None, comment='处理时间(用于幂等性)')
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index('idx_order_user_status', 'user_id', 'status'),
|
||||||
|
Index('idx_order_processed', 'processed_at'),
|
||||||
|
)
|
||||||
|
|
||||||
|
class UserAccount(Base):
|
||||||
|
__tablename__ = 'user_account'
|
||||||
|
|
||||||
|
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), unique=True, nullable=False, comment='关联的用户ID')
|
||||||
|
balance: Mapped[int] = mapped_column(BigInteger, default=0, comment='当前可用次数')
|
||||||
|
total_purchased: Mapped[int] = mapped_column(BigInteger, default=0, comment='累计购买次数')
|
||||||
|
subscription_type: Mapped[Optional[str]] = mapped_column(String(32), default=None, comment='订阅类型:monthly/quarterly/half_yearly/yearly')
|
||||||
|
subscription_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=None, comment='订阅到期时间')
|
||||||
|
carryover_balance: Mapped[int] = mapped_column(BigInteger, default=0, comment='上期未使用的次数')
|
||||||
|
# 新用户免费次数相关
|
||||||
|
free_trial_balance: Mapped[int] = mapped_column(BigInteger, default=30, comment='新用户免费试用次数')
|
||||||
|
free_trial_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=None, comment='免费试用期结束时间')
|
||||||
|
free_trial_used: Mapped[bool] = mapped_column(Boolean, default=False, comment='是否已使用免费试用')
|
||||||
|
|
||||||
|
class FreezeLog(Base):
|
||||||
|
__tablename__ = 'freeze_log'
|
||||||
|
|
||||||
|
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), nullable=False)
|
||||||
|
order_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('order.id'), nullable=False)
|
||||||
|
amount: Mapped[int] = mapped_column(BigInteger, comment='冻结次数')
|
||||||
|
reason: Mapped[Optional[str]] = mapped_column(Text, comment='冻结原因')
|
||||||
|
status: Mapped[str] = mapped_column(String(16), default='pending', comment='状态:pending/confirmed/cancelled')
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index('idx_freeze_user_status', 'user_id', 'status'),
|
||||||
|
{'comment': '次数冻结记录表'}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UsageLog(Base):
|
||||||
|
__tablename__ = 'usage_log'
|
||||||
|
|
||||||
|
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), nullable=False, comment='用户ID')
|
||||||
|
action: Mapped[str] = mapped_column(String(32), comment='动作:purchase/renewal/use/carryover/share/ad/freeze/unfreeze/refund')
|
||||||
|
amount: Mapped[int] = mapped_column(BigInteger, comment='变动数量')
|
||||||
|
balance_after: Mapped[int] = mapped_column(BigInteger, comment='变动后余额')
|
||||||
|
related_id: Mapped[Optional[int]] = mapped_column(BigInteger, default=None, comment='关联ID,如订单ID、冻结记录ID')
|
||||||
|
details: Mapped[Optional[dict]] = mapped_column(JSONB, default=None, comment='附加信息')
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index('idx_usage_user_action', 'user_id', 'action'),
|
||||||
|
Index('idx_usage_user_time', 'user_id', 'created_time'),
|
||||||
|
Index('idx_usage_action_time', 'action', 'created_time'),
|
||||||
|
{'comment': '使用日志表'}
|
||||||
|
)
|
||||||
47
backend/app/admin/model/points.py
Normal file
47
backend/app/admin/model/points.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import String, Column, BigInteger, ForeignKey, DateTime, Index, Text
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from backend.common.model import snowflake_id_key, Base
|
||||||
|
|
||||||
|
|
||||||
|
class Points(Base):
|
||||||
|
__tablename__ = 'points'
|
||||||
|
|
||||||
|
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), unique=True, nullable=False, comment='关联的用户ID')
|
||||||
|
balance: Mapped[int] = mapped_column(BigInteger, default=0, comment='当前积分余额')
|
||||||
|
total_earned: Mapped[int] = mapped_column(BigInteger, default=0, comment='累计获得积分')
|
||||||
|
total_spent: Mapped[int] = mapped_column(BigInteger, default=0, comment='累计消费积分')
|
||||||
|
expired_time: Mapped[datetime] = mapped_column(DateTime, default=datetime.now() + timedelta(days=30), comment="过期时间")
|
||||||
|
|
||||||
|
# 索引优化
|
||||||
|
__table_args__ = (
|
||||||
|
Index('idx_points_user', 'user_id'),
|
||||||
|
{'comment': '用户积分表'}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PointsLog(Base):
|
||||||
|
__tablename__ = 'points_log'
|
||||||
|
|
||||||
|
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey('wx_user.id'), nullable=False, comment='用户ID')
|
||||||
|
action: Mapped[str] = mapped_column(String(32), comment='动作:earn/spend')
|
||||||
|
amount: Mapped[int] = mapped_column(BigInteger, comment='变动数量')
|
||||||
|
balance_after: Mapped[int] = mapped_column(BigInteger, comment='变动后余额')
|
||||||
|
related_id: Mapped[Optional[int]] = mapped_column(BigInteger, default=None, comment='关联ID')
|
||||||
|
details: Mapped[Optional[dict]] = mapped_column(JSONB, default=None, comment='附加信息')
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.now, comment='创建时间')
|
||||||
|
|
||||||
|
# 索引优化
|
||||||
|
__table_args__ = (
|
||||||
|
Index('idx_points_log_user_action', 'user_id', 'action'),
|
||||||
|
Index('idx_points_log_user_time', 'user_id', 'created_at'),
|
||||||
|
{'comment': '积分变动日志表'}
|
||||||
|
)
|
||||||
36
backend/app/admin/model/wx_user.py
Executable file
36
backend/app/admin/model/wx_user.py
Executable file
@@ -0,0 +1,36 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import String, Column, BigInteger, SmallInteger, Boolean, DateTime, Index, func, JSON, Text, Numeric
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from backend.common.model import snowflake_id_key, Base
|
||||||
|
|
||||||
|
|
||||||
|
class WxUser(Base):
|
||||||
|
__tablename__ = 'wx_user'
|
||||||
|
|
||||||
|
id: Mapped[snowflake_id_key] = mapped_column(BigInteger, init=False, primary_key=True)
|
||||||
|
openid: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, comment='微信OpenID')
|
||||||
|
session_key: Mapped[str] = mapped_column(String(128), nullable=False, comment='会话密钥')
|
||||||
|
unionid: Mapped[Optional[str]] = mapped_column(String(64), default=None, index=True, comment='微信UnionID')
|
||||||
|
mobile: Mapped[Optional[str]] = mapped_column(String(15), default=None, index=True, comment='加密手机号')
|
||||||
|
profile: Mapped[Optional[dict]] = mapped_column(JSONB(astext_type=Text()), default=None, comment='用户资料JSON')
|
||||||
|
|
||||||
|
|
||||||
|
# class WxPayment(Base):
|
||||||
|
# __tablename__ = 'wx_payments'
|
||||||
|
#
|
||||||
|
# id: Mapped[snowflake_id_key] = mapped_column(init=False, primary_key=True)
|
||||||
|
# user_id: Mapped[int] = mapped_column(BigInteger, index=True, nullable=False)
|
||||||
|
# prepay_id: Mapped[str] = mapped_column(String(64), nullable=False, comment='预支付ID')
|
||||||
|
# transaction_id: Mapped[Optional[str]] = mapped_column(String(32), comment='微信支付单号')
|
||||||
|
# amount: Mapped[float] = mapped_column(Numeric, nullable=False, comment='分单位金额')
|
||||||
|
# status: Mapped[str] = mapped_column(String(16), default='pending', comment='支付状态')
|
||||||
|
#
|
||||||
|
# __table_args__ = (
|
||||||
|
# Index('idx_payment_user', 'user_id', 'status'),
|
||||||
|
# {'comment': '支付记录表'}
|
||||||
|
# )
|
||||||
6
backend/app/admin/schema/__init__.py
Executable file
6
backend/app/admin/schema/__init__.py
Executable file
@@ -0,0 +1,6 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
from backend.app.admin.schema.file import FileSchemaBase, AddFileParam, UpdateFileParam, FileInfoSchema, FileUploadResponse
|
||||||
|
from backend.app.admin.schema.audit_log import DailySummarySchema, DailySummaryPageSchema
|
||||||
|
from backend.app.admin.schema.points import PointsSchema, PointsLogSchema, AddPointsRequest, DeductPointsRequest
|
||||||
69
backend/app/admin/schema/audit_log.py
Executable file
69
backend/app/admin/schema/audit_log.py
Executable file
@@ -0,0 +1,69 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
from pydantic import ConfigDict, Field
|
||||||
|
|
||||||
|
from backend.common.schema import SchemaBase
|
||||||
|
|
||||||
|
|
||||||
|
class AuditLogSchemaBase(SchemaBase):
|
||||||
|
""" Audit Log Schema """
|
||||||
|
api_type: str = Field(description="API类型: recognition/embedding/assessment")
|
||||||
|
model_name: str = Field(description="模型名称")
|
||||||
|
request_data: Optional[dict] = Field(None, description="请求数据")
|
||||||
|
response_data: Optional[dict] = Field(None, description="响应数据")
|
||||||
|
token_usage: Optional[dict] = Field(0, description="消耗的token数量")
|
||||||
|
cost: Optional[float] = Field(0.0, description="API调用成本")
|
||||||
|
duration: float = Field(description="调用耗时(秒)")
|
||||||
|
status_code: int = Field(description="HTTP状态码")
|
||||||
|
error_message: Optional[str] = Field("", description="错误信息")
|
||||||
|
called_at: Optional[datetime] = Field(None, description="调用时间")
|
||||||
|
image_id: int = Field(description="关联的图片ID")
|
||||||
|
user_id: int = Field(description="调用用户ID")
|
||||||
|
api_version: str = Field(description="API版本")
|
||||||
|
dict_level: Optional[str] = Field(None, description="词典等级")
|
||||||
|
|
||||||
|
|
||||||
|
class CreateAuditLogParam(AuditLogSchemaBase):
|
||||||
|
"""创建操作日志参数"""
|
||||||
|
|
||||||
|
|
||||||
|
class AuditLogHistorySchema(SchemaBase):
|
||||||
|
""" Audit Log History Schema for user history records """
|
||||||
|
image_id: Optional[str] = Field(None, description="图ID")
|
||||||
|
file_id: Optional[str] = Field(None, description="原图ID")
|
||||||
|
thumbnail_id: Optional[str] = Field(None, description="缩略图ID")
|
||||||
|
created_time: Optional[str] = Field(description="图片创建时间")
|
||||||
|
dict_level: Optional[str] = Field(None, description="词典等级")
|
||||||
|
|
||||||
|
|
||||||
|
class AuditLogStatisticsSchema(SchemaBase):
|
||||||
|
""" Audit Log Statistics Schema """
|
||||||
|
total_count: int = Field(description="历史总量")
|
||||||
|
today_count: int = Field(description="当天总量")
|
||||||
|
image_count: int = Field(description="图片总量")
|
||||||
|
|
||||||
|
|
||||||
|
class CreateDailySummaryParam(SchemaBase):
|
||||||
|
"""创建每日总结参数"""
|
||||||
|
user_id: int = Field(description="调用用户ID")
|
||||||
|
image_ids: Optional[List[str]] = Field(None, description="图ID")
|
||||||
|
thumbnail_ids: Optional[List[str]] = Field(None, description="缩略图ID")
|
||||||
|
summary_time: Optional[datetime] = Field(None, description="调用时间")
|
||||||
|
|
||||||
|
|
||||||
|
class DailySummarySchema(SchemaBase):
|
||||||
|
""" Daily Summary Schema """
|
||||||
|
# id: int = Field(description="记录ID")
|
||||||
|
# user_id: int = Field(description="用户ID")
|
||||||
|
image_ids: List[str] = Field(description="图片ID列表")
|
||||||
|
thumbnail_ids: List[str] = Field(description="图片缩略图列表")
|
||||||
|
summary_time: str = Field(description="创建时间")
|
||||||
|
|
||||||
|
|
||||||
|
class DailySummaryPageSchema(SchemaBase):
|
||||||
|
""" Daily Summary Page Schema """
|
||||||
|
items: List[DailySummarySchema] = Field(description="每日汇总记录列表")
|
||||||
|
total: int = Field(description="总记录数")
|
||||||
|
page: int = Field(description="当前页码")
|
||||||
|
size: int = Field(description="每页记录数")
|
||||||
10
backend/app/admin/schema/captcha.py
Executable file
10
backend/app/admin/schema/captcha.py
Executable file
@@ -0,0 +1,10 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from backend.common.schema import SchemaBase
|
||||||
|
|
||||||
|
|
||||||
|
class GetCaptchaDetail(SchemaBase):
|
||||||
|
image_type: str = Field(description='图片类型')
|
||||||
|
image: str = Field(description='图片内容')
|
||||||
198
backend/app/admin/schema/dict.py
Executable file
198
backend/app/admin/schema/dict.py
Executable file
@@ -0,0 +1,198 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Optional, Dict, Any, List, Union
|
||||||
|
|
||||||
|
# --- 辅助模型 ---
|
||||||
|
|
||||||
|
class CrossReference(BaseModel):
|
||||||
|
# 交叉引用的文本,通常是另一个词条名
|
||||||
|
text: Optional[str] = None
|
||||||
|
sense_id: Optional[str] = None
|
||||||
|
# 链接到另一个词条的 href
|
||||||
|
entry_href: Optional[str] = None # 例如: "entry://cooking-apple"
|
||||||
|
# 如果是图片相关的交叉引用,可能包含图片信息
|
||||||
|
# 图片的 ID (来自 a 标签的 showid 属性)
|
||||||
|
show_id: Optional[str] = None # 例如: "img3241"
|
||||||
|
# image_filename: 指向 dict_media 表中的 file_name
|
||||||
|
image_filename: Optional[str] = None # 例如: "img3241_ldoce4188jpg" (由 showid 和 base64 属性组合)
|
||||||
|
# 图片的标题/描述 (来自 a 标签的 title 属性)
|
||||||
|
image_title: Optional[str] = None # 例如: "apple from LDOCE 4"
|
||||||
|
# LDOCE 版本信息 (如果适用)
|
||||||
|
ldoce_version: Optional[str] = None # 例如: "LDOCEVERSION_5", "LDOCEVERSION_new"
|
||||||
|
|
||||||
|
class FamilyItem(BaseModel):
|
||||||
|
text: Optional[str] = None
|
||||||
|
# 链接词用 href 存储,非链接词 href 为 None
|
||||||
|
href: Optional[str] = None # 例如: "bword://underworld"
|
||||||
|
|
||||||
|
class WordFamily(BaseModel):
|
||||||
|
pos: Optional[str] = None # 词性,如 "noun", "adjective"
|
||||||
|
items: List[FamilyItem] = [] # 该词性下的相关词项
|
||||||
|
|
||||||
|
class Pronunciation(BaseModel):
|
||||||
|
# IPA 音标
|
||||||
|
uk_ipa: Optional[str] = None # 例如: "ˈæpəl"
|
||||||
|
us_ipa: Optional[str] = None # 例如: "ˈæpəl" 或 "wɜːrld"
|
||||||
|
# 音频文件路径 (相对或绝对)
|
||||||
|
uk_audio: Optional[str] = None # 例如: "/media/english/breProns/brelasdeapple.mp3"
|
||||||
|
us_audio: Optional[str] = None # 例如: "/media/english/ameProns/apple1.mp3"
|
||||||
|
# title
|
||||||
|
uk_audio_title: Optional[str] = None # 例如: "Play British pronunciation of plane"
|
||||||
|
us_audio_title: Optional[str] = None # 例如: "Play American pronunciation of plane"
|
||||||
|
|
||||||
|
class Frequency(BaseModel):
|
||||||
|
level: Optional[str] = None # 例如: "Core vocabulary: High-frequency"
|
||||||
|
spoken: Optional[str] = None # 例如: "Top 2000 spoken words"
|
||||||
|
written: Optional[str] = None # 例如: "Top 3000 written words"
|
||||||
|
level_tag: Optional[str] = None # 例如: "●●●"
|
||||||
|
spoken_tag: Optional[str] = None # 例如: "S2"
|
||||||
|
written_tag: Optional[str] = None # 例如: "W2"
|
||||||
|
|
||||||
|
# --- 核心模型 ---
|
||||||
|
|
||||||
|
class Example(BaseModel):
|
||||||
|
# 英文例句 (包含可能的高亮或链接,但这里简化为纯文本)
|
||||||
|
en: Optional[str] = None
|
||||||
|
# 中文翻译
|
||||||
|
cn: Optional[str] = None
|
||||||
|
# 例句中突出的搭配 (如 "a world of difference")
|
||||||
|
collocation: Optional[str] = None
|
||||||
|
# 例句中链接到的其他词条 (如 "wanting", "loving")
|
||||||
|
related_words_in_example: Optional[List[str]] = Field(default_factory=list)
|
||||||
|
# 例句音频文件路径
|
||||||
|
audio: Optional[str] = None # 例如: "/media/english/exaProns/p008-000499910.mp3"
|
||||||
|
|
||||||
|
class Definition(BaseModel):
|
||||||
|
# 英文定义
|
||||||
|
en: Optional[str] = None
|
||||||
|
# 中文定义
|
||||||
|
cn: Optional[str] = None
|
||||||
|
# 定义中链接到的其他词条 (来自 defRef 标签)
|
||||||
|
related_words: Optional[List[str]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
class Sense(BaseModel):
|
||||||
|
# Sense 的唯一标识符 (来自 HTML id)
|
||||||
|
id: Optional[str] = None # 例如: "apple__1", "world__3"
|
||||||
|
# Sense 编号 (如果存在)
|
||||||
|
number: Optional[str] = None # 例如: "1", "2", "a)"
|
||||||
|
# Signpost (英文标签,如 "OUR PLANET/EVERYONE ON IT")
|
||||||
|
signpost_en: Optional[str] = None
|
||||||
|
# Signpost 中文翻译
|
||||||
|
signpost_cn: Optional[str] = None
|
||||||
|
ref_hwd: Optional[str] = None
|
||||||
|
# 语法信息 (如果该 Sense 特有)
|
||||||
|
grammar: Optional[str] = None # 可以是字符串或更复杂的结构,视情况而定
|
||||||
|
# 定义 (可能有中英对照)
|
||||||
|
definitions: Optional[List[Definition]] = Field(default_factory=list)
|
||||||
|
# 例子 (属于该 Sense)
|
||||||
|
examples: Optional[List[Example]] = Field(default_factory=list)
|
||||||
|
# 图片信息 (如果该 Sense 有图)
|
||||||
|
image: Optional[str] = None # 图片文件名或路径, 例如: "apple.jpg"
|
||||||
|
# Sense 内的交叉引用 (来自 Crossref 标签)
|
||||||
|
cross_references: Optional[List[CrossReference]] = Field(default_factory=list)
|
||||||
|
# 可数性
|
||||||
|
countability: Optional[List[str]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
class Topic(BaseModel):
|
||||||
|
# Topic 名称
|
||||||
|
name: Optional[str] = None # 例如: "Food, dish", "Astronomy"
|
||||||
|
# 链接到 Topic 的 href
|
||||||
|
href: Optional[str] = None # 例如: "entry://Food, dish-topic food"
|
||||||
|
|
||||||
|
class DictEntry(BaseModel):
|
||||||
|
# 词条名 (标准化形式)
|
||||||
|
headword: Optional[str] = None # 例如: "apple", "world"
|
||||||
|
# 同形异义词编号 (如果存在)
|
||||||
|
homograph_number: Optional[int] = None # 例如: 1, 2 (对应 HOMNUM)
|
||||||
|
# 词性 (主要词性,或第一个 Sense 的词性)
|
||||||
|
part_of_speech: Optional[str] = None # 例如: "noun", "verb"
|
||||||
|
transitive: Optional[List[str]] = Field(default_factory=list) # 例如: "transitive"
|
||||||
|
# 发音
|
||||||
|
pronunciations: Optional[Pronunciation] = None
|
||||||
|
# 频率信息
|
||||||
|
frequency: Optional[Frequency] = None
|
||||||
|
# 相关话题 (Topics)
|
||||||
|
topics: Optional[List[Topic]] = Field(default_factory=list)
|
||||||
|
# 词族信息
|
||||||
|
word_family: Optional[List[WordFamily]] = None
|
||||||
|
# 词条级别的语法信息 (如果适用,不常见)
|
||||||
|
entry_grammar: Optional[str] = None
|
||||||
|
# 所有义项 (Sense)
|
||||||
|
senses: Optional[List[Sense]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
class EtymologyItem(BaseModel):
|
||||||
|
language: Optional[str] = None
|
||||||
|
origin: Optional[str] = None
|
||||||
|
|
||||||
|
class Etymology(BaseModel):
|
||||||
|
intro: Optional[str] = None
|
||||||
|
headword: Optional[str] = None
|
||||||
|
hom_num: Optional[str] = None
|
||||||
|
item: Optional[List[EtymologyItem]] = None
|
||||||
|
|
||||||
|
class WordMetaData(BaseModel):
|
||||||
|
ref_link: Optional[List[str]] = None
|
||||||
|
dict_list: List[DictEntry] = Field(default_factory=list)
|
||||||
|
etymology: Optional[Etymology] = None
|
||||||
|
# 来源词典信息 (可选)
|
||||||
|
source_dict: Optional[str] = "Longman Dictionary of Contemporary English 5++" # 可根据需要调整
|
||||||
|
|
||||||
|
# --- 简化的查询响应模型 ---
|
||||||
|
|
||||||
|
class SimpleCrossReference(BaseModel):
|
||||||
|
"""简化的交叉引用模型"""
|
||||||
|
text: Optional[str] = None
|
||||||
|
show_id: Optional[str] = None
|
||||||
|
sense_id: Optional[str] = None
|
||||||
|
image_filename: Optional[str] = None
|
||||||
|
|
||||||
|
class SimpleExample(BaseModel):
|
||||||
|
"""简化的例句模型"""
|
||||||
|
cn: Optional[str] = None
|
||||||
|
en: Optional[str] = None
|
||||||
|
audio: Optional[str] = None
|
||||||
|
|
||||||
|
class SimpleDefinition(BaseModel):
|
||||||
|
"""简化的定义模型"""
|
||||||
|
cn: Optional[str] = None
|
||||||
|
en: Optional[str] = None
|
||||||
|
related_words: Optional[List[str]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
class SimpleSense(BaseModel):
|
||||||
|
"""简化的义项模型"""
|
||||||
|
id: Optional[str] = None
|
||||||
|
image: Optional[str] = None
|
||||||
|
number: Optional[str] = None
|
||||||
|
grammar: Optional[str] = None
|
||||||
|
ref_hwd: Optional[str] = None
|
||||||
|
examples: Optional[List[SimpleExample]] = Field(default_factory=list)
|
||||||
|
definitions: Optional[List[SimpleDefinition]] = Field(default_factory=list)
|
||||||
|
signpost_cn: Optional[str] = None
|
||||||
|
signpost_en: Optional[str] = None
|
||||||
|
cross_references: Optional[List[SimpleCrossReference]] = Field(default_factory=list)
|
||||||
|
countability: Optional[List[str]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
class SimpleFrequency(BaseModel):
|
||||||
|
"""简化的频率信息模型"""
|
||||||
|
level_tag: Optional[str] = None
|
||||||
|
spoken_tag: Optional[str] = None
|
||||||
|
written_tag: Optional[str] = None
|
||||||
|
|
||||||
|
class SimplePronunciation(BaseModel):
|
||||||
|
"""简化的发音模型"""
|
||||||
|
uk_ipa: Optional[str] = None
|
||||||
|
us_ipa: Optional[str] = None
|
||||||
|
uk_audio: Optional[str] = None
|
||||||
|
us_audio: Optional[str] = None
|
||||||
|
|
||||||
|
class SimpleDictEntry(BaseModel):
|
||||||
|
"""字典单词查询响应模型(缩减版 DictEntry)"""
|
||||||
|
part_of_speech: Optional[str] = None # 例如: "noun", "verb"
|
||||||
|
transitive: Optional[List[str]] = Field(default_factory=list)
|
||||||
|
senses: Optional[List[SimpleSense]] = Field(default_factory=list)
|
||||||
|
frequency: Optional[SimpleFrequency] = None
|
||||||
|
pronunciations: Optional[SimplePronunciation] = None
|
||||||
|
|
||||||
|
class DictWordResponse(BaseModel):
|
||||||
|
dict_list: List[SimpleDictEntry] = Field(default_factory=list)
|
||||||
|
etymology: Optional[Etymology] = None
|
||||||
|
|
||||||
57
backend/app/admin/schema/feedback.py
Executable file
57
backend/app/admin/schema/feedback.py
Executable file
@@ -0,0 +1,57 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from backend.common.schema import SchemaBase
|
||||||
|
|
||||||
|
|
||||||
|
class FeedbackStatus(str, Enum):
|
||||||
|
PENDING = "pending"
|
||||||
|
PROCESSING = "processing"
|
||||||
|
RESOLVED = "resolved"
|
||||||
|
|
||||||
|
|
||||||
|
class FeedbackCategory(str, Enum):
|
||||||
|
BUG = "bug"
|
||||||
|
FEATURE = "feature"
|
||||||
|
COMPLIMENT = "compliment"
|
||||||
|
OTHER = "other"
|
||||||
|
|
||||||
|
|
||||||
|
class FeedbackSchemaBase(SchemaBase):
|
||||||
|
content: str
|
||||||
|
contact_info: Optional[str] = None
|
||||||
|
category: Optional[FeedbackCategory] = None
|
||||||
|
metadata_info: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CreateFeedbackParam(FeedbackSchemaBase):
|
||||||
|
"""创建反馈参数"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateFeedbackParam(FeedbackSchemaBase):
|
||||||
|
"""更新反馈参数"""
|
||||||
|
status: Optional[FeedbackStatus] = None
|
||||||
|
|
||||||
|
|
||||||
|
class FeedbackInfoSchema(FeedbackSchemaBase):
|
||||||
|
"""反馈信息Schema"""
|
||||||
|
id: int
|
||||||
|
user_id: int
|
||||||
|
status: FeedbackStatus
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
class FeedbackUserInfoSchema(FeedbackInfoSchema):
|
||||||
|
"""包含用户信息的反馈Schema"""
|
||||||
|
user: Optional[dict] = None # 简化的用户信息
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
52
backend/app/admin/schema/file.py
Executable file
52
backend/app/admin/schema/file.py
Executable file
@@ -0,0 +1,52 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy.sql.sqltypes import BigInteger
|
||||||
|
|
||||||
|
from backend.common.schema import SchemaBase
|
||||||
|
|
||||||
|
|
||||||
|
class FileMetadata(BaseModel):
|
||||||
|
"""文件元数据结构"""
|
||||||
|
file_name: Optional[str] = None
|
||||||
|
content_type: Optional[str] = None
|
||||||
|
file_size: int # 文件大小(字节)
|
||||||
|
user_agent: Optional[str] = None # 上传客户端信息
|
||||||
|
extra: Optional[Dict[str, Any]] = None # 其他自定义元数据
|
||||||
|
|
||||||
|
|
||||||
|
class FileSchemaBase(SchemaBase):
|
||||||
|
file_hash: str
|
||||||
|
file_name: str
|
||||||
|
content_type: Optional[str] = None
|
||||||
|
file_size: int
|
||||||
|
storage_type: str = "database"
|
||||||
|
storage_path: Optional[str] = None
|
||||||
|
metadata_info: Optional[FileMetadata] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AddFileParam(FileSchemaBase):
|
||||||
|
"""添加文件参数"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateFileParam(FileSchemaBase):
|
||||||
|
"""更新文件参数"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FileInfoSchema(FileSchemaBase):
|
||||||
|
"""文件信息Schema"""
|
||||||
|
id: int
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class FileUploadResponse(SchemaBase):
|
||||||
|
"""文件上传响应"""
|
||||||
|
id: str
|
||||||
|
file_hash: str
|
||||||
|
file_name: str
|
||||||
|
content_type: Optional[str] = None
|
||||||
|
file_size: int
|
||||||
45
backend/app/admin/schema/points.py
Normal file
45
backend/app/admin/schema/points.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class PointsSchema(BaseModel):
|
||||||
|
"""积分账户信息"""
|
||||||
|
id: int
|
||||||
|
user_id: int
|
||||||
|
balance: int = Field(default=0, description="当前积分余额")
|
||||||
|
total_earned: int = Field(default=0, description="累计获得积分")
|
||||||
|
total_spent: int = Field(default=0, description="累计消费积分")
|
||||||
|
expired_time: datetime = Field(default_factory=datetime.now, description="过期时间")
|
||||||
|
|
||||||
|
|
||||||
|
class PointsLogSchema(BaseModel):
|
||||||
|
"""积分变动日志"""
|
||||||
|
id: int
|
||||||
|
user_id: int
|
||||||
|
action: str = Field(description="动作:earn/spend")
|
||||||
|
amount: int = Field(description="变动数量")
|
||||||
|
balance_after: int = Field(description="变动后余额")
|
||||||
|
related_id: Optional[int] = Field(default=None, description="关联ID")
|
||||||
|
details: Optional[dict] = Field(default=None, description="附加信息")
|
||||||
|
created_at: datetime = Field(default_factory=datetime.now, description="创建时间")
|
||||||
|
|
||||||
|
|
||||||
|
class AddPointsRequest(BaseModel):
|
||||||
|
"""增加积分请求"""
|
||||||
|
user_id: int = Field(description="用户ID")
|
||||||
|
amount: int = Field(gt=0, description="增加的积分数量")
|
||||||
|
extend_expiration: bool = Field(default=False, description="是否自动延期过期时间")
|
||||||
|
related_id: Optional[int] = Field(default=None, description="关联ID")
|
||||||
|
details: Optional[dict] = Field(default=None, description="附加信息")
|
||||||
|
|
||||||
|
|
||||||
|
class DeductPointsRequest(BaseModel):
|
||||||
|
"""扣减积分请求"""
|
||||||
|
user_id: int = Field(description="用户ID")
|
||||||
|
amount: int = Field(gt=0, description="扣减的积分数量")
|
||||||
|
related_id: Optional[int] = Field(default=None, description="关联ID")
|
||||||
|
details: Optional[dict] = Field(default=None, description="附加信息")
|
||||||
25
backend/app/admin/schema/pydantic_type.py
Executable file
25
backend/app/admin/schema/pydantic_type.py
Executable file
@@ -0,0 +1,25 @@
|
|||||||
|
from sqlalchemy import Column, BigInteger, String, Text
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
from pgvector.sqlalchemy import Vector
|
||||||
|
from sqlalchemy.types import TypeDecorator
|
||||||
|
|
||||||
|
from backend.utils.json_encoder import jsonable_encoder
|
||||||
|
|
||||||
|
|
||||||
|
class PydanticType(TypeDecorator):
|
||||||
|
"""处理 Pydantic 模型的 SQLAlchemy 自定义类型"""
|
||||||
|
impl = JSONB
|
||||||
|
|
||||||
|
def __init__(self, pydantic_type=None, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.pydantic_type = pydantic_type
|
||||||
|
|
||||||
|
def process_bind_param(self, value, dialect):
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
return jsonable_encoder(value)
|
||||||
|
|
||||||
|
def process_result_value(self, value, dialect):
|
||||||
|
if value is None or self.pydantic_type is None:
|
||||||
|
return value
|
||||||
|
return self.pydantic_type(**value)
|
||||||
25
backend/app/admin/schema/qwen.py
Executable file
25
backend/app/admin/schema/qwen.py
Executable file
@@ -0,0 +1,25 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from backend.common.schema import SchemaBase
|
||||||
|
|
||||||
|
|
||||||
|
class QwenParamBase(SchemaBase):
|
||||||
|
""" Qwen Base Params """
|
||||||
|
user_id: int = Field(description="user id")
|
||||||
|
image_id: int = Field(description="image id")
|
||||||
|
data: str = Field(description='Base64')
|
||||||
|
file_name: str = Field(description='文件名')
|
||||||
|
format: str = Field(description='图片后缀')
|
||||||
|
dict_level: str = Field(description='dict level')
|
||||||
|
|
||||||
|
|
||||||
|
class QwenEmbedImageParams(QwenParamBase):
|
||||||
|
""" Embedding Image Params """
|
||||||
|
|
||||||
|
|
||||||
|
class QwenRecognizeImageParams(QwenParamBase):
|
||||||
|
""" Recognize image Params """
|
||||||
|
type: str = Field(description='识别类型')
|
||||||
|
exclude_words: list[str] = Field(description='exclude words')
|
||||||
57
backend/app/admin/schema/token.py
Executable file
57
backend/app/admin/schema/token.py
Executable file
@@ -0,0 +1,57 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from backend.app.admin.schema.user import GetUserInfoDetail
|
||||||
|
from backend.common.enums import StatusType
|
||||||
|
from backend.common.schema import SchemaBase
|
||||||
|
|
||||||
|
|
||||||
|
class GetSwaggerToken(SchemaBase):
|
||||||
|
"""Swagger 认证令牌"""
|
||||||
|
|
||||||
|
access_token: str = Field(description='访问令牌')
|
||||||
|
token_type: str = Field('Bearer', description='令牌类型')
|
||||||
|
user: GetUserInfoDetail = Field(description='用户信息')
|
||||||
|
|
||||||
|
|
||||||
|
class AccessTokenBase(SchemaBase):
|
||||||
|
"""访问令牌基础模型"""
|
||||||
|
|
||||||
|
access_token: str = Field(description='访问令牌')
|
||||||
|
access_token_expire_time: datetime = Field(description='令牌过期时间')
|
||||||
|
session_uuid: str = Field(description='会话 UUID')
|
||||||
|
|
||||||
|
|
||||||
|
class GetNewToken(AccessTokenBase):
|
||||||
|
"""获取新令牌"""
|
||||||
|
|
||||||
|
|
||||||
|
class GetLoginToken(AccessTokenBase):
|
||||||
|
"""获取登录令牌"""
|
||||||
|
|
||||||
|
user: GetUserInfoDetail = Field(description='用户信息')
|
||||||
|
|
||||||
|
|
||||||
|
class GetWxLoginToken(AccessTokenBase):
|
||||||
|
"""微信登录令牌"""
|
||||||
|
dict_level: Optional[str] = Field(None, description="词典等级")
|
||||||
|
|
||||||
|
|
||||||
|
class GetTokenDetail(SchemaBase):
|
||||||
|
"""令牌详情"""
|
||||||
|
|
||||||
|
id: int = Field(description='用户 ID')
|
||||||
|
session_uuid: str = Field(description='会话 UUID')
|
||||||
|
username: str = Field(description='用户名')
|
||||||
|
nickname: str = Field(description='昵称')
|
||||||
|
ip: str = Field(description='IP 地址')
|
||||||
|
os: str = Field(description='操作系统')
|
||||||
|
browser: str = Field(description='浏览器')
|
||||||
|
device: str = Field(description='设备')
|
||||||
|
status: StatusType = Field(description='状态')
|
||||||
|
last_login_time: str = Field(description='最后登录时间')
|
||||||
|
expire_time: datetime = Field(description='过期时间')
|
||||||
12
backend/app/admin/schema/usage.py
Executable file
12
backend/app/admin/schema/usage.py
Executable file
@@ -0,0 +1,12 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
class PurchaseRequest(BaseModel):
|
||||||
|
amount_cents: int = Field(..., ge=100, le=10000000, description="充值金额(分),1元=100分")
|
||||||
|
|
||||||
|
class SubscriptionRequest(BaseModel):
|
||||||
|
plan: str = Field(..., pattern=r'^(monthly|quarterly|half_yearly|yearly)$', description="订阅计划")
|
||||||
|
|
||||||
|
class RefundRequest(BaseModel):
|
||||||
|
order_id: int = Field(..., gt=0, description="订单ID")
|
||||||
|
reason: Optional[str] = Field(None, max_length=200, description="退款原因")
|
||||||
48
backend/app/admin/schema/user.py
Executable file
48
backend/app/admin/schema/user.py
Executable file
@@ -0,0 +1,48 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import Field, EmailStr, ConfigDict, HttpUrl
|
||||||
|
|
||||||
|
from backend.common.schema import SchemaBase, CustomPhoneNumber
|
||||||
|
|
||||||
|
class AuthSchemaBase(SchemaBase):
|
||||||
|
username: str = Field(description='用户名')
|
||||||
|
password: str = Field(description='密码')
|
||||||
|
|
||||||
|
|
||||||
|
class AuthLoginParam(AuthSchemaBase):
|
||||||
|
captcha: str = Field(description='验证码')
|
||||||
|
|
||||||
|
|
||||||
|
class RegisterUserParam(AuthSchemaBase):
|
||||||
|
email: EmailStr = Field(examples=['user@example.com'], description='邮箱')
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateUserParam(SchemaBase):
|
||||||
|
username: str = Field(description='用户名')
|
||||||
|
email: EmailStr = Field(examples=['user@example.com'], description='邮箱')
|
||||||
|
phone: CustomPhoneNumber | None = Field(None, description='手机号')
|
||||||
|
|
||||||
|
|
||||||
|
class AvatarParam(SchemaBase):
|
||||||
|
url: HttpUrl = Field(..., description='头像 http 地址')
|
||||||
|
|
||||||
|
|
||||||
|
class GetUserInfoDetail(UpdateUserParam):
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
id: int = Field(description='用户 ID')
|
||||||
|
uuid: str = Field(description='用户 UUID')
|
||||||
|
avatar: str | None = Field(None, description='头像')
|
||||||
|
status: int = Field(description='状态')
|
||||||
|
is_superuser: bool = Field(description='是否超级管理员')
|
||||||
|
join_time: datetime = Field(description='加入时间')
|
||||||
|
last_login_time: datetime | None = Field(None, description='最后登录时间')
|
||||||
|
|
||||||
|
|
||||||
|
class ResetPassword(SchemaBase):
|
||||||
|
username: str = Field(description='用户名')
|
||||||
|
old_password: str = Field(description='旧密码')
|
||||||
|
new_password: str = Field(description='新密码')
|
||||||
|
confirm_password: str = Field(description='确认密码')
|
||||||
70
backend/app/admin/schema/wx.py
Executable file
70
backend/app/admin/schema/wx.py
Executable file
@@ -0,0 +1,70 @@
|
|||||||
|
# schemas.py
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from backend.common.schema import SchemaBase
|
||||||
|
|
||||||
|
|
||||||
|
class WxUserBase(SchemaBase):
|
||||||
|
id: int = Field(description='User ID')
|
||||||
|
openid: str = Field(description="微信 user openid")
|
||||||
|
|
||||||
|
|
||||||
|
class GetWxUserInfoDetail(WxUserBase):
|
||||||
|
"""用户信息详情"""
|
||||||
|
|
||||||
|
|
||||||
|
class GetWxUserInfoWithRelationDetail(GetWxUserInfoDetail):
|
||||||
|
"""用户信息关联详情"""
|
||||||
|
|
||||||
|
|
||||||
|
class WxLoginRequest(BaseModel):
|
||||||
|
code: str = Field(..., description="微信登录code")
|
||||||
|
appid: Optional[str] = Field(None, description="微信 appid")
|
||||||
|
encrypted_data: Optional[str] = Field(None, description="加密的用户数据")
|
||||||
|
iv: Optional[str] = Field(None, description="加密算法的初始向量")
|
||||||
|
|
||||||
|
|
||||||
|
class TokenResponse(BaseModel):
|
||||||
|
access_token: str = Field(..., description="访问令牌")
|
||||||
|
refresh_token: str = Field(..., description="刷新令牌")
|
||||||
|
token_type: str = Field("bearer", description="令牌类型")
|
||||||
|
expires_in: int = Field(..., description="过期时间(秒)")
|
||||||
|
|
||||||
|
|
||||||
|
class UserBase(BaseModel):
|
||||||
|
id: int = Field(..., description="用户ID")
|
||||||
|
openid: str = Field(..., description="微信OpenID")
|
||||||
|
|
||||||
|
|
||||||
|
class UserInfo(UserBase):
|
||||||
|
mobile: Optional[str] = Field(None, description="手机号")
|
||||||
|
profile: Optional[dict] = Field({}, description="用户资料")
|
||||||
|
created_at: datetime = Field(..., description="创建时间")
|
||||||
|
|
||||||
|
|
||||||
|
class UserAuth(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
refresh_token: str
|
||||||
|
expires_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class DictLevel(str, Enum):
|
||||||
|
LEVEL1 = "LEVEL1" # "小学"
|
||||||
|
LEVEL2 = "LEVEL2" # "初高中"
|
||||||
|
LEVEL3 = "LEVEL3" # "四六级"
|
||||||
|
|
||||||
|
|
||||||
|
class UserSettings(BaseModel):
|
||||||
|
dict_level: Optional[DictLevel] = Field(None, description="词典等级")
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateUserSettingsRequest(UserSettings):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GetUserSettingsResponse(UserSettings):
|
||||||
|
pass
|
||||||
4
backend/app/admin/service/__init__.py
Executable file
4
backend/app/admin/service/__init__.py
Executable file
@@ -0,0 +1,4 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
from backend.app.admin.service.points_service import PointsService
|
||||||
247
backend/app/admin/service/ad_share_service.py
Executable file
247
backend/app/admin/service/ad_share_service.py
Executable file
@@ -0,0 +1,247 @@
|
|||||||
|
from backend.app.admin.crud.user_account_crud import user_account_dao
|
||||||
|
from backend.app.admin.crud.usage_log_crud import usage_log_dao
|
||||||
|
from backend.database.db import async_db_session
|
||||||
|
from backend.common.exception import errors
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from backend.common.log import log as logger
|
||||||
|
|
||||||
|
# 导入 Redis 客户端
|
||||||
|
from backend.database.redis import redis_client
|
||||||
|
|
||||||
|
|
||||||
|
class AdShareService:
|
||||||
|
# 每日限制配置
|
||||||
|
DAILY_AD_LIMIT = 5 # 每日广告观看限制
|
||||||
|
DAILY_SHARE_LIMIT = 3 # 每日分享限制
|
||||||
|
AD_REWARD_TIMES = 3 # 每次广告奖励次数
|
||||||
|
SHARE_REWARD_TIMES = 3 # 每次分享奖励次数
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_redis_key(user_id: int, action_type: str, date_str: str = None) -> str:
|
||||||
|
"""
|
||||||
|
生成 Redis key
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
action_type: 动作类型 (ad/share)
|
||||||
|
date_str: 日期字符串 (YYYY-MM-DD)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Redis key
|
||||||
|
"""
|
||||||
|
if date_str is None:
|
||||||
|
date_str = datetime.now().strftime('%Y-%m-%d')
|
||||||
|
return f"user:{user_id}:{action_type}:count:{date_str}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _check_daily_limit(user_id: int, action_type: str, limit: int) -> bool:
|
||||||
|
"""
|
||||||
|
检查每日限制
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
action_type: 动作类型
|
||||||
|
limit: 限制次数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否超过限制
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 确保 Redis 连接
|
||||||
|
await redis_client.ping()
|
||||||
|
|
||||||
|
# 获取今天的计数
|
||||||
|
key = AdShareService._get_redis_key(user_id, action_type)
|
||||||
|
current_count = await redis_client.get(key)
|
||||||
|
|
||||||
|
if current_count is None:
|
||||||
|
current_count = 0
|
||||||
|
else:
|
||||||
|
current_count = int(current_count)
|
||||||
|
|
||||||
|
return current_count >= limit
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"检查每日限制失败: {str(e)}")
|
||||||
|
# Redis 异常时允许继续操作(避免服务中断)
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _increment_daily_count(user_id: int, action_type: str) -> int:
|
||||||
|
"""
|
||||||
|
增加每日计数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
action_type: 动作类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 当前计数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 确保 Redis 连接
|
||||||
|
await redis_client.ping()
|
||||||
|
|
||||||
|
key = AdShareService._get_redis_key(user_id, action_type)
|
||||||
|
|
||||||
|
# 增加计数
|
||||||
|
current_count = await redis_client.incr(key)
|
||||||
|
|
||||||
|
# 设置过期时间(第二天自动清除)
|
||||||
|
tomorrow = datetime.now() + timedelta(days=1)
|
||||||
|
tomorrow_midnight = tomorrow.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
expire_seconds = int((tomorrow_midnight - datetime.now()).total_seconds())
|
||||||
|
|
||||||
|
if expire_seconds > 0:
|
||||||
|
await redis_client.expire(key, expire_seconds)
|
||||||
|
|
||||||
|
return current_count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"增加每日计数失败: {str(e)}")
|
||||||
|
raise errors.ServerError(msg="系统繁忙,请稍后重试")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _get_daily_count(user_id: int, action_type: str) -> int:
|
||||||
|
"""
|
||||||
|
获取今日计数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
action_type: 动作类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 当前计数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 确保 Redis 连接
|
||||||
|
await redis_client.ping()
|
||||||
|
|
||||||
|
key = AdShareService._get_redis_key(user_id, action_type)
|
||||||
|
count = await redis_client.get(key)
|
||||||
|
|
||||||
|
return int(count) if count is not None else 0
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取每日计数失败: {str(e)}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def grant_times_by_ad(user_id: int):
|
||||||
|
"""
|
||||||
|
通过观看广告获得次数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
"""
|
||||||
|
# 检查每日限制
|
||||||
|
if await AdShareService._check_daily_limit(user_id, "ad", AdShareService.DAILY_AD_LIMIT):
|
||||||
|
raise errors.ForbiddenError(msg=f"今日广告观看次数已达上限({AdShareService.DAILY_AD_LIMIT}次)")
|
||||||
|
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
try:
|
||||||
|
# 增加计数
|
||||||
|
current_count = await AdShareService._increment_daily_count(user_id, "ad")
|
||||||
|
|
||||||
|
# 增加用户余额
|
||||||
|
result = await user_account_dao.update_balance_atomic(db, user_id, AdShareService.AD_REWARD_TIMES)
|
||||||
|
if not result:
|
||||||
|
# 回滚计数
|
||||||
|
key = AdShareService._get_redis_key(user_id, "ad")
|
||||||
|
await redis_client.decr(key)
|
||||||
|
raise errors.ServerError(msg="账户更新失败")
|
||||||
|
|
||||||
|
# 记录使用日志
|
||||||
|
account = await user_account_dao.get_by_user_id(db, user_id)
|
||||||
|
await usage_log_dao.add(db, {
|
||||||
|
"user_id": user_id,
|
||||||
|
"action": "ad",
|
||||||
|
"amount": AdShareService.AD_REWARD_TIMES,
|
||||||
|
"balance_after": account.balance if account else AdShareService.AD_REWARD_TIMES,
|
||||||
|
"metadata_": {
|
||||||
|
"daily_count": current_count,
|
||||||
|
"max_limit": AdShareService.DAILY_AD_LIMIT
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if not isinstance(e, errors.ForbiddenError):
|
||||||
|
logger.error(f"广告奖励处理失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def grant_times_by_share(user_id: int):
|
||||||
|
"""
|
||||||
|
通过分享获得次数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
"""
|
||||||
|
# 检查每日限制
|
||||||
|
if await AdShareService._check_daily_limit(user_id, "share", AdShareService.DAILY_SHARE_LIMIT):
|
||||||
|
raise errors.ForbiddenError(msg=f"今日分享次数已达上限({AdShareService.DAILY_SHARE_LIMIT}次)")
|
||||||
|
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
try:
|
||||||
|
# 增加计数
|
||||||
|
current_count = await AdShareService._increment_daily_count(user_id, "share")
|
||||||
|
|
||||||
|
# 增加用户余额
|
||||||
|
result = await user_account_dao.update_balance_atomic(db, user_id, AdShareService.SHARE_REWARD_TIMES)
|
||||||
|
if not result:
|
||||||
|
# 回滚计数
|
||||||
|
key = AdShareService._get_redis_key(user_id, "share")
|
||||||
|
await redis_client.decr(key)
|
||||||
|
raise errors.ServerError(msg="账户更新失败")
|
||||||
|
|
||||||
|
# 记录使用日志
|
||||||
|
account = await user_account_dao.get_by_user_id(db, user_id)
|
||||||
|
await usage_log_dao.add(db, {
|
||||||
|
"user_id": user_id,
|
||||||
|
"action": "share",
|
||||||
|
"amount": AdShareService.SHARE_REWARD_TIMES,
|
||||||
|
"balance_after": account.balance if account else AdShareService.SHARE_REWARD_TIMES,
|
||||||
|
"metadata_": {
|
||||||
|
"daily_count": current_count,
|
||||||
|
"max_limit": AdShareService.DAILY_SHARE_LIMIT
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if not isinstance(e, errors.ForbiddenError):
|
||||||
|
logger.error(f"分享奖励处理失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_daily_stats(user_id: int) -> dict:
|
||||||
|
"""
|
||||||
|
获取用户今日统计信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 统计信息
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
ad_count = await AdShareService._get_daily_count(user_id, "ad")
|
||||||
|
share_count = await AdShareService._get_daily_count(user_id, "share")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"ad_count": ad_count,
|
||||||
|
"ad_limit": AdShareService.DAILY_AD_LIMIT,
|
||||||
|
"share_count": share_count,
|
||||||
|
"share_limit": AdShareService.DAILY_SHARE_LIMIT,
|
||||||
|
"can_watch_ad": ad_count < AdShareService.DAILY_AD_LIMIT,
|
||||||
|
"can_share": share_count < AdShareService.DAILY_SHARE_LIMIT
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取每日统计失败: {str(e)}")
|
||||||
|
return {
|
||||||
|
"ad_count": 0,
|
||||||
|
"ad_limit": AdShareService.DAILY_AD_LIMIT,
|
||||||
|
"share_count": 0,
|
||||||
|
"share_limit": AdShareService.DAILY_SHARE_LIMIT,
|
||||||
|
"can_watch_ad": True,
|
||||||
|
"can_share": True
|
||||||
|
}
|
||||||
146
backend/app/admin/service/audit_log_service.py
Executable file
146
backend/app/admin/service/audit_log_service.py
Executable file
@@ -0,0 +1,146 @@
|
|||||||
|
from backend.app.admin.crud.audit_log_crud import audit_log_dao
|
||||||
|
from backend.app.admin.crud.daily_summary_crud import daily_summary_dao
|
||||||
|
from backend.app.admin.schema.audit_log import CreateAuditLogParam, AuditLogHistorySchema, DailySummarySchema
|
||||||
|
from backend.database.db import async_db_session
|
||||||
|
from typing import List, Tuple
|
||||||
|
from datetime import datetime, date
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|
||||||
|
class AuditLogService:
|
||||||
|
""" Audit Log Service """
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def create(*, obj: CreateAuditLogParam) -> None:
|
||||||
|
"""
|
||||||
|
创建操作日志
|
||||||
|
|
||||||
|
:param obj: 操作日志创建参数
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
await audit_log_dao.create(db, obj)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_user_recognition_history(
|
||||||
|
*,
|
||||||
|
user_id: int,
|
||||||
|
page: int = 1,
|
||||||
|
size: int = 20
|
||||||
|
) -> Tuple[List[AuditLogHistorySchema], int]:
|
||||||
|
"""
|
||||||
|
通过用户ID查询历史记录
|
||||||
|
|
||||||
|
:param user_id: 用户ID
|
||||||
|
:param page: 页码
|
||||||
|
:param size: 每页数量
|
||||||
|
:return: 历史记录列表和总数
|
||||||
|
"""
|
||||||
|
async with async_db_session() as db:
|
||||||
|
items, total = await audit_log_dao.get_user_recognition_history(db, user_id, page, size)
|
||||||
|
|
||||||
|
# 转换为 schema 对象
|
||||||
|
history_items = [
|
||||||
|
AuditLogHistorySchema(
|
||||||
|
image_id=str(item.id),
|
||||||
|
file_id=str(item.file_id),
|
||||||
|
thumbnail_id=str(item.thumbnail_id),
|
||||||
|
created_time=item.created_time.strftime("%Y-%m-%d %H:%M:%S") if item.created_time else None,
|
||||||
|
dict_level=item.dict_level
|
||||||
|
)
|
||||||
|
for item in items
|
||||||
|
]
|
||||||
|
|
||||||
|
return history_items, total
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_user_recognition_statistics(*, user_id: int) -> Tuple[int, int, int]:
|
||||||
|
"""
|
||||||
|
统计用户 recognition 类型的使用记录
|
||||||
|
返回历史总量和当天总量
|
||||||
|
|
||||||
|
:param user_id: 用户ID
|
||||||
|
:return: (历史总量, 当天总量)
|
||||||
|
"""
|
||||||
|
async with async_db_session() as db:
|
||||||
|
total_count, today_count, image_count = await audit_log_dao.get_user_recognition_statistics(db, user_id)
|
||||||
|
return total_count, today_count, image_count
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_user_daily_summaries(
|
||||||
|
*,
|
||||||
|
user_id: int,
|
||||||
|
page: int = 1,
|
||||||
|
size: int = 20
|
||||||
|
) -> Tuple[List[DailySummarySchema], int]:
|
||||||
|
"""
|
||||||
|
通过用户ID查询每日识别汇总记录,按创建时间降序排列
|
||||||
|
|
||||||
|
:param user_id: 用户ID
|
||||||
|
:param page: 页码
|
||||||
|
:param size: 每页数量
|
||||||
|
:return: 每日汇总记录列表和总数
|
||||||
|
"""
|
||||||
|
async with async_db_session() as db:
|
||||||
|
items, total = await daily_summary_dao.get_user_daily_summaries(db, user_id, page, size)
|
||||||
|
|
||||||
|
# 转换为 schema 对象
|
||||||
|
summary_items = [
|
||||||
|
DailySummarySchema(
|
||||||
|
image_ids=item.image_ids,
|
||||||
|
thumbnail_ids=item.thumbnail_ids,
|
||||||
|
summary_time=str(item.summary_time.date())
|
||||||
|
)
|
||||||
|
for item in items
|
||||||
|
]
|
||||||
|
|
||||||
|
return summary_items, total
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_user_today_summaries(
|
||||||
|
*,
|
||||||
|
user_id: int,
|
||||||
|
page: int = 1,
|
||||||
|
size: int = 20
|
||||||
|
) -> Tuple[List[DailySummarySchema], int]:
|
||||||
|
"""
|
||||||
|
获取用户当天的识别记录摘要
|
||||||
|
查询当天时间内,AuditLog.api_type == 'recognition' 的所有记录,
|
||||||
|
获取相关的图片和缩略图,构成返回结构中的数据
|
||||||
|
|
||||||
|
:param user_id: 用户ID
|
||||||
|
:param page: 页码
|
||||||
|
:param size: 每页数量
|
||||||
|
:return: 当天识别记录摘要列表和总数
|
||||||
|
"""
|
||||||
|
async with async_db_session() as db:
|
||||||
|
# 获取当天的识别记录
|
||||||
|
items, total = await audit_log_dao.get_user_today_recognition_history(db, user_id, page, size)
|
||||||
|
|
||||||
|
# 如果没有记录,返回空列表
|
||||||
|
if not items:
|
||||||
|
return [], total
|
||||||
|
|
||||||
|
# 提取图片ID和缩略图ID
|
||||||
|
image_ids = []
|
||||||
|
thumbnail_ids = []
|
||||||
|
|
||||||
|
for item in items:
|
||||||
|
if item.id:
|
||||||
|
image_ids.append(str(item.id))
|
||||||
|
if item.thumbnail_id:
|
||||||
|
thumbnail_ids.append(str(item.thumbnail_id))
|
||||||
|
|
||||||
|
# 构建返回数据
|
||||||
|
# 由于是当天的数据,所有记录都属于同一天
|
||||||
|
today = date.today()
|
||||||
|
summary_item = DailySummarySchema(
|
||||||
|
image_ids=image_ids,
|
||||||
|
thumbnail_ids=thumbnail_ids,
|
||||||
|
summary_time=str(today)
|
||||||
|
)
|
||||||
|
|
||||||
|
return [summary_item], total
|
||||||
|
|
||||||
|
|
||||||
|
audit_log_service: AuditLogService = AuditLogService()
|
||||||
171
backend/app/admin/service/audit_service.py
Executable file
171
backend/app/admin/service/audit_service.py
Executable file
@@ -0,0 +1,171 @@
|
|||||||
|
# audit_service.py
|
||||||
|
from sqlalchemy import func, extract, case, Integer
|
||||||
|
from sqlalchemy.sql import label
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
from backend.app.admin.model.audit_log import AuditLog
|
||||||
|
from backend.core.database import SessionLocal
|
||||||
|
from backend.common.log import log as logger
|
||||||
|
|
||||||
|
|
||||||
|
class AuditService:
|
||||||
|
@staticmethod
|
||||||
|
def get_audit_logs(page: int = 1, page_size: int = 20,
|
||||||
|
filters: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||||
|
"""获取审计日志(分页)"""
|
||||||
|
with SessionLocal() as db:
|
||||||
|
query = db.query(AuditLog)
|
||||||
|
|
||||||
|
# 应用过滤条件
|
||||||
|
if filters:
|
||||||
|
if filters.get("api_type"):
|
||||||
|
query = query.filter(AuditLog.api_type == filters["api_type"])
|
||||||
|
if filters.get("model_name"):
|
||||||
|
query = query.filter(AuditLog.model_name == filters["model_name"])
|
||||||
|
if filters.get("status_code"):
|
||||||
|
query = query.filter(AuditLog.status_code == filters["status_code"])
|
||||||
|
if filters.get("start_date"):
|
||||||
|
query = query.filter(AuditLog.called_at >= filters["start_date"])
|
||||||
|
if filters.get("end_date"):
|
||||||
|
query = query.filter(AuditLog.called_at <= filters["end_date"])
|
||||||
|
if filters.get("user_id"):
|
||||||
|
query = query.filter(AuditLog.user_id == filters["user_id"])
|
||||||
|
|
||||||
|
# 分页处理
|
||||||
|
total = query.count()
|
||||||
|
logs = query.order_by(AuditLog.called_at.desc()) \
|
||||||
|
.offset((page - 1) * page_size) \
|
||||||
|
.limit(page_size) \
|
||||||
|
.all()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total": total,
|
||||||
|
"page": page,
|
||||||
|
"page_size": page_size,
|
||||||
|
"results": [log.to_dict() for log in logs]
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_usage_statistics(time_range: str = "daily") -> Dict[str, Any]:
|
||||||
|
"""获取API使用统计"""
|
||||||
|
with SessionLocal() as db:
|
||||||
|
# 根据时间范围确定分组方式
|
||||||
|
if time_range == "hourly":
|
||||||
|
time_group = func.date_trunc('hour', AuditLog.called_at)
|
||||||
|
elif time_range == "weekly":
|
||||||
|
time_group = func.date_trunc('week', AuditLog.called_at)
|
||||||
|
elif time_range == "monthly":
|
||||||
|
time_group = func.date_trunc('month', AuditLog.called_at)
|
||||||
|
else: # daily
|
||||||
|
time_group = func.date_trunc('day', AuditLog.called_at)
|
||||||
|
|
||||||
|
# 基本统计
|
||||||
|
stats = db.query(
|
||||||
|
time_group.label("time_period"),
|
||||||
|
AuditLog.api_type,
|
||||||
|
AuditLog.model_name,
|
||||||
|
func.count().label("request_count"),
|
||||||
|
func.sum(AuditLog.cost).label("total_cost"),
|
||||||
|
func.avg(AuditLog.duration).label("avg_duration"),
|
||||||
|
func.sum(
|
||||||
|
case(
|
||||||
|
(AuditLog.status_code >= 200, 1),
|
||||||
|
else_=0
|
||||||
|
)
|
||||||
|
).label("success_count"),
|
||||||
|
func.sum(
|
||||||
|
case(
|
||||||
|
(AuditLog.status_code >= 400, 1),
|
||||||
|
else_=0
|
||||||
|
)
|
||||||
|
).label("error_count")
|
||||||
|
).group_by(
|
||||||
|
"time_period",
|
||||||
|
AuditLog.api_type,
|
||||||
|
AuditLog.model_name
|
||||||
|
).order_by(
|
||||||
|
"time_period"
|
||||||
|
).all()
|
||||||
|
|
||||||
|
# Token使用统计
|
||||||
|
token_stats = db.query(
|
||||||
|
time_group.label("time_period"),
|
||||||
|
AuditLog.api_type,
|
||||||
|
AuditLog.model_name,
|
||||||
|
func.sum(AuditLog.token_usage['input_tokens'].astext.cast(Integer)).label("input_tokens"),
|
||||||
|
func.sum(AuditLog.token_usage['output_tokens'].astext.cast(Integer)).label("output_tokens"),
|
||||||
|
func.sum(AuditLog.token_usage['total_tokens'].astext.cast(Integer)).label("total_tokens")
|
||||||
|
).filter(
|
||||||
|
AuditLog.token_usage != None
|
||||||
|
).group_by(
|
||||||
|
"time_period",
|
||||||
|
AuditLog.api_type,
|
||||||
|
AuditLog.model_name
|
||||||
|
).order_by(
|
||||||
|
"time_period"
|
||||||
|
).all()
|
||||||
|
|
||||||
|
# 转换结果
|
||||||
|
return {
|
||||||
|
"usage_stats": [
|
||||||
|
{
|
||||||
|
"time_period": s.time_period,
|
||||||
|
"api_type": s.api_type,
|
||||||
|
"model_name": s.model_name,
|
||||||
|
"request_count": s.request_count,
|
||||||
|
"total_cost": float(s.total_cost) if s.total_cost else 0.0,
|
||||||
|
"avg_duration": float(s.avg_duration) if s.avg_duration else 0.0,
|
||||||
|
"success_rate": s.success_count / s.request_count if s.request_count else 0.0,
|
||||||
|
"error_rate": s.error_count / s.request_count if s.request_count else 0.0
|
||||||
|
} for s in stats
|
||||||
|
],
|
||||||
|
"token_stats": [
|
||||||
|
{
|
||||||
|
"time_period": t.time_period,
|
||||||
|
"api_type": t.api_type,
|
||||||
|
"model_name": t.model_name,
|
||||||
|
"input_tokens": t.input_tokens,
|
||||||
|
"output_tokens": t.output_tokens,
|
||||||
|
"total_tokens": t.total_tokens
|
||||||
|
} for t in token_stats
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_cost_forecast() -> Dict[str, Any]:
|
||||||
|
"""预测未来成本"""
|
||||||
|
with SessionLocal() as db:
|
||||||
|
# 获取最近30天的成本数据
|
||||||
|
end_date = datetime.utcnow()
|
||||||
|
start_date = end_date - timedelta(days=30)
|
||||||
|
|
||||||
|
daily_costs = db.query(
|
||||||
|
func.date_trunc('day', AuditLog.called_at).label("day"),
|
||||||
|
func.sum(AuditLog.cost).label("daily_cost")
|
||||||
|
).filter(
|
||||||
|
AuditLog.called_at >= start_date,
|
||||||
|
AuditLog.called_at <= end_date
|
||||||
|
).group_by("day").order_by("day").all()
|
||||||
|
|
||||||
|
# 简单预测模型 (移动平均)
|
||||||
|
costs = [float(d.daily_cost) if d.daily_cost else 0.0 for d in daily_costs]
|
||||||
|
avg_cost = sum(costs) / len(costs) if costs else 0.0
|
||||||
|
|
||||||
|
# 生成预测 (未来7天)
|
||||||
|
forecast = []
|
||||||
|
for i in range(1, 8):
|
||||||
|
forecast_date = end_date + timedelta(days=i)
|
||||||
|
forecast.append({
|
||||||
|
"date": forecast_date.date(),
|
||||||
|
"predicted_cost": avg_cost
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"historical": [
|
||||||
|
{"date": d.day.date(), "cost": float(d.daily_cost) if d.daily_cost else 0.0}
|
||||||
|
for d in daily_costs
|
||||||
|
],
|
||||||
|
"forecast": forecast
|
||||||
|
}
|
||||||
57
backend/app/admin/service/auth_service.py
Executable file
57
backend/app/admin/service/auth_service.py
Executable file
@@ -0,0 +1,57 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from fastapi import Request
|
||||||
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from backend.app.admin.crud.wx_user_crud import wx_user_dao
|
||||||
|
from backend.app.admin.model import WxUser
|
||||||
|
from backend.app.admin.schema.token import GetLoginToken
|
||||||
|
from backend.app.admin.schema.user import AuthLoginParam
|
||||||
|
from backend.common.exception import errors
|
||||||
|
from backend.common.response.response_code import CustomErrorCode
|
||||||
|
from backend.common.security.jwt import password_verify, create_access_token
|
||||||
|
from backend.core.conf import settings
|
||||||
|
from backend.database.db import async_db_session
|
||||||
|
from backend.database.redis import redis_client
|
||||||
|
from backend.utils.timezone import timezone
|
||||||
|
|
||||||
|
|
||||||
|
class AuthService:
|
||||||
|
@staticmethod
|
||||||
|
async def user_verify(db: AsyncSession, username: str, password: str) -> WxUser:
|
||||||
|
user = await wx_user_dao.get_by_username(db, username)
|
||||||
|
if not user:
|
||||||
|
raise errors.NotFoundError(msg='用户名或密码有误')
|
||||||
|
elif not password_verify(password, user.password):
|
||||||
|
raise errors.AuthorizationError(msg='用户名或密码有误')
|
||||||
|
elif not user.status:
|
||||||
|
raise errors.AuthorizationError(msg='用户已被锁定, 请联系统管理员')
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def swagger_login(self, *, form_data: OAuth2PasswordRequestForm) -> tuple[str, WxUser]:
|
||||||
|
async with async_db_session() as db:
|
||||||
|
user = await self.user_verify(db, form_data.username, form_data.password)
|
||||||
|
await wx_user_dao.update_login_time(db, user.username, login_time=timezone.now())
|
||||||
|
token = create_access_token(str(user.id))
|
||||||
|
return token, user
|
||||||
|
|
||||||
|
async def login(self, *, request: Request, obj: AuthLoginParam) -> GetLoginToken:
|
||||||
|
async with async_db_session() as db:
|
||||||
|
user = await self.user_verify(db, obj.username, obj.password)
|
||||||
|
try:
|
||||||
|
captcha_uuid = request.app.state.captcha_uuid
|
||||||
|
redis_code = await redis_client.get(f'{settings.CAPTCHA_LOGIN_REDIS_PREFIX}:{captcha_uuid}')
|
||||||
|
if not redis_code:
|
||||||
|
raise errors.ForbiddenError(msg='验证码失效,请重新获取')
|
||||||
|
except AttributeError:
|
||||||
|
raise errors.ForbiddenError(msg='验证码失效,请重新获取')
|
||||||
|
if redis_code.lower() != obj.captcha.lower():
|
||||||
|
raise errors.CustomError(error=CustomErrorCode.CAPTCHA_ERROR)
|
||||||
|
await wx_user_dao.update_login_time(db, user.username, login_time=timezone.now())
|
||||||
|
token = create_access_token(str(user.id))
|
||||||
|
data = GetLoginToken(access_token=token, user=user)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
auth_service: AuthService = AuthService()
|
||||||
148
backend/app/admin/service/coupon_service.py
Executable file
148
backend/app/admin/service/coupon_service.py
Executable file
@@ -0,0 +1,148 @@
|
|||||||
|
import random
|
||||||
|
import string
|
||||||
|
from typing import List, Optional
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from backend.app.admin.crud.coupon_crud import coupon_dao, coupon_usage_dao
|
||||||
|
from backend.app.admin.model.coupon import Coupon
|
||||||
|
from backend.database.db import async_db_session
|
||||||
|
from backend.common.exception import errors
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
|
||||||
|
class CouponService:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_unique_code(length: int = 6) -> str:
|
||||||
|
"""
|
||||||
|
生成唯一的兑换码
|
||||||
|
"""
|
||||||
|
characters = string.ascii_uppercase + string.digits
|
||||||
|
# 移除容易混淆的字符
|
||||||
|
characters = characters.replace('0', '').replace('O', '').replace('I', '').replace('1')
|
||||||
|
|
||||||
|
while True:
|
||||||
|
code = ''.join(random.choice(characters) for _ in range(length))
|
||||||
|
# 确保生成的兑换码不包含敏感词汇或重复模式
|
||||||
|
if not CouponService._has_sensitive_pattern(code):
|
||||||
|
return code
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _has_sensitive_pattern(code: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查兑换码是否包含敏感模式
|
||||||
|
"""
|
||||||
|
# 简单的敏感词检查
|
||||||
|
sensitive_words = ['ADMIN', 'ROOT', 'USER', 'TEST']
|
||||||
|
for word in sensitive_words:
|
||||||
|
if word in code:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 检查是否为重复字符
|
||||||
|
if len(set(code)) == 1:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def create_coupon(duration: int, expires_days: Optional[int] = None) -> Coupon:
|
||||||
|
"""
|
||||||
|
创建单个兑换券
|
||||||
|
"""
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
# 生成唯一兑换码
|
||||||
|
code = CouponService.generate_unique_code()
|
||||||
|
|
||||||
|
# 确保兑换码唯一性
|
||||||
|
while await coupon_dao.get_by_code(db, code):
|
||||||
|
code = CouponService.generate_unique_code()
|
||||||
|
|
||||||
|
# 设置过期时间
|
||||||
|
expires_at = None
|
||||||
|
if expires_days:
|
||||||
|
expires_at = datetime.now() + timedelta(days=expires_days)
|
||||||
|
|
||||||
|
coupon_data = {
|
||||||
|
'code': code,
|
||||||
|
'duration': duration,
|
||||||
|
'expires_at': expires_at
|
||||||
|
}
|
||||||
|
|
||||||
|
coupon = await coupon_dao.create_coupon(db, coupon_data)
|
||||||
|
return coupon
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def batch_create_coupons(count: int, duration: int, expires_days: Optional[int] = None) -> List[Coupon]:
|
||||||
|
"""
|
||||||
|
批量创建兑换券
|
||||||
|
"""
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
coupons_data = []
|
||||||
|
|
||||||
|
# 生成唯一兑换码列表
|
||||||
|
codes = set()
|
||||||
|
while len(codes) < count:
|
||||||
|
code = CouponService.generate_unique_code()
|
||||||
|
if code not in codes:
|
||||||
|
# 检查数据库中是否已存在该兑换码
|
||||||
|
existing_coupon = await coupon_dao.get_by_code(db, code)
|
||||||
|
if not existing_coupon:
|
||||||
|
codes.add(code)
|
||||||
|
|
||||||
|
# 设置过期时间
|
||||||
|
expires_at = None
|
||||||
|
if expires_days:
|
||||||
|
expires_at = datetime.now() + timedelta(days=expires_days)
|
||||||
|
|
||||||
|
# 准备数据
|
||||||
|
for code in codes:
|
||||||
|
coupons_data.append({
|
||||||
|
'code': code,
|
||||||
|
'duration': duration,
|
||||||
|
'expires_at': expires_at
|
||||||
|
})
|
||||||
|
|
||||||
|
coupons = await coupon_dao.create_coupons(db, coupons_data)
|
||||||
|
return coupons
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def redeem_coupon(code: str, user_id: int) -> dict:
|
||||||
|
"""
|
||||||
|
兑换兑换券
|
||||||
|
"""
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
# 获取未使用的兑换券
|
||||||
|
coupon = await coupon_dao.get_unused_coupon_by_code(db, code)
|
||||||
|
if not coupon:
|
||||||
|
raise errors.RequestError(msg='兑换码无效或已过期')
|
||||||
|
|
||||||
|
# 标记为已使用并创建使用记录
|
||||||
|
success = await coupon_dao.mark_as_used(db, coupon.id, user_id, coupon.duration)
|
||||||
|
if not success:
|
||||||
|
raise errors.ServerError(msg='兑换失败,请稍后重试')
|
||||||
|
|
||||||
|
return {
|
||||||
|
'code': coupon.code,
|
||||||
|
'duration': coupon.duration,
|
||||||
|
'used_at': datetime.now()
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_user_coupon_history(user_id: int, limit: int = 100) -> List[dict]:
|
||||||
|
"""
|
||||||
|
获取用户兑换历史
|
||||||
|
"""
|
||||||
|
async with async_db_session() as db:
|
||||||
|
usages = await coupon_usage_dao.get_by_user_id(db, user_id, limit)
|
||||||
|
|
||||||
|
history = []
|
||||||
|
for usage in usages:
|
||||||
|
# 获取兑换券信息
|
||||||
|
coupon = await coupon_dao.get(db, usage.coupon_id)
|
||||||
|
if coupon:
|
||||||
|
history.append({
|
||||||
|
'code': coupon.code,
|
||||||
|
'duration': usage.duration,
|
||||||
|
'used_at': usage.used_at
|
||||||
|
})
|
||||||
|
|
||||||
|
return history
|
||||||
262
backend/app/admin/service/dict_service.py
Executable file
262
backend/app/admin/service/dict_service.py
Executable file
@@ -0,0 +1,262 @@
|
|||||||
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from backend.app.admin.crud.dict_crud import dict_dao
|
||||||
|
from backend.app.admin.model.dict import DictionaryEntry
|
||||||
|
from backend.app.admin.schema.dict import (
|
||||||
|
DictWordResponse, SimpleDictEntry, SimpleSense, SimpleDefinition, SimpleExample,
|
||||||
|
SimpleCrossReference, SimpleFrequency, SimplePronunciation, WordMetaData
|
||||||
|
)
|
||||||
|
from backend.common.exception import errors
|
||||||
|
from backend.database.db import async_db_session
|
||||||
|
from backend.database.redis import redis_client
|
||||||
|
|
||||||
|
|
||||||
|
class DictService:
|
||||||
|
# Redis缓存键前缀
|
||||||
|
WORD_LINK_CACHE_PREFIX = "dict:word_link:"
|
||||||
|
# 缓存过期时间(24小时)
|
||||||
|
CACHE_EXPIRE_TIME = 24 * 60 * 60
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_to_simple_response(entry: DictionaryEntry) -> DictWordResponse:
|
||||||
|
"""将完整的字典条目转换为简化的响应格式"""
|
||||||
|
if not entry.details:
|
||||||
|
# 如果没有详细信息,返回空的响应
|
||||||
|
return DictWordResponse()
|
||||||
|
|
||||||
|
details: WordMetaData = entry.details
|
||||||
|
|
||||||
|
# 转换字典条目列表
|
||||||
|
simple_dict_list = []
|
||||||
|
if details.dict_list:
|
||||||
|
for dict_entry in details.dict_list:
|
||||||
|
# 转换义项信息
|
||||||
|
simple_senses = []
|
||||||
|
if dict_entry.senses:
|
||||||
|
for sense in dict_entry.senses:
|
||||||
|
# 转换定义
|
||||||
|
simple_definitions = []
|
||||||
|
if sense.definitions:
|
||||||
|
for definition in sense.definitions:
|
||||||
|
simple_definitions.append(SimpleDefinition(
|
||||||
|
cn=definition.cn,
|
||||||
|
en=definition.en,
|
||||||
|
related_words=definition.related_words or []
|
||||||
|
))
|
||||||
|
|
||||||
|
# 转换例句
|
||||||
|
simple_examples = []
|
||||||
|
if sense.examples:
|
||||||
|
for example in sense.examples:
|
||||||
|
simple_examples.append(SimpleExample(
|
||||||
|
cn=example.cn,
|
||||||
|
en=example.en,
|
||||||
|
audio=example.audio
|
||||||
|
))
|
||||||
|
|
||||||
|
# 转换交叉引用
|
||||||
|
simple_cross_refs = []
|
||||||
|
if sense.cross_references:
|
||||||
|
for cross_ref in sense.cross_references:
|
||||||
|
simple_cross_refs.append(SimpleCrossReference(
|
||||||
|
text=cross_ref.text,
|
||||||
|
show_id=cross_ref.show_id,
|
||||||
|
sense_id=cross_ref.sense_id,
|
||||||
|
image_filename=cross_ref.image_filename
|
||||||
|
))
|
||||||
|
|
||||||
|
# 创建简化的义项
|
||||||
|
simple_sense = SimpleSense(
|
||||||
|
id=sense.id,
|
||||||
|
image=sense.image,
|
||||||
|
number=sense.number,
|
||||||
|
grammar=sense.grammar,
|
||||||
|
ref_hwd=sense.ref_hwd,
|
||||||
|
examples=simple_examples,
|
||||||
|
definitions=simple_definitions,
|
||||||
|
signpost_cn=sense.signpost_cn,
|
||||||
|
signpost_en=sense.signpost_en,
|
||||||
|
cross_references=simple_cross_refs,
|
||||||
|
countability = sense.countability or [],
|
||||||
|
)
|
||||||
|
simple_senses.append(simple_sense)
|
||||||
|
|
||||||
|
# 转换频率信息
|
||||||
|
simple_frequency = None
|
||||||
|
if dict_entry.frequency:
|
||||||
|
simple_frequency = SimpleFrequency(
|
||||||
|
level_tag=dict_entry.frequency.level_tag,
|
||||||
|
spoken_tag=dict_entry.frequency.spoken_tag,
|
||||||
|
written_tag=dict_entry.frequency.written_tag
|
||||||
|
)
|
||||||
|
|
||||||
|
# 转换发音信息
|
||||||
|
simple_pronunciations = None
|
||||||
|
if dict_entry.pronunciations:
|
||||||
|
simple_pronunciations = SimplePronunciation(
|
||||||
|
uk_ipa=dict_entry.pronunciations.uk_ipa,
|
||||||
|
us_ipa=dict_entry.pronunciations.us_ipa,
|
||||||
|
uk_audio=dict_entry.pronunciations.uk_audio,
|
||||||
|
us_audio=dict_entry.pronunciations.us_audio
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建简化的字典条目
|
||||||
|
simple_dict_entry = SimpleDictEntry(
|
||||||
|
part_of_speech=dict_entry.part_of_speech,
|
||||||
|
transitive=dict_entry.transitive,
|
||||||
|
senses=simple_senses,
|
||||||
|
frequency=simple_frequency,
|
||||||
|
pronunciations=simple_pronunciations
|
||||||
|
)
|
||||||
|
simple_dict_list.append(simple_dict_entry)
|
||||||
|
|
||||||
|
return DictWordResponse(
|
||||||
|
dict_list=simple_dict_list,
|
||||||
|
etymology=details.etymology
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _get_linked_word_from_cache(word: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
从Redis缓存中获取单词的链接单词
|
||||||
|
返回None表示没有链接单词,返回具体单词表示有链接单词
|
||||||
|
"""
|
||||||
|
cache_key = f"{DictService.WORD_LINK_CACHE_PREFIX}{word.lower()}"
|
||||||
|
try:
|
||||||
|
# 尝试从Redis获取缓存
|
||||||
|
cached_result = await redis_client.get(cache_key)
|
||||||
|
if cached_result is not None:
|
||||||
|
# 如果缓存的是"None"字符串,表示没有链接单词
|
||||||
|
if cached_result == "None":
|
||||||
|
return word
|
||||||
|
# 否则返回缓存的链接单词
|
||||||
|
return cached_result
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
# 如果Redis出错,不影响主流程
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _set_linked_word_to_cache(word: str, linked_word: Optional[str]) -> None:
|
||||||
|
"""
|
||||||
|
将单词的链接单词存入Redis缓存
|
||||||
|
linked_word为None表示没有链接单词
|
||||||
|
"""
|
||||||
|
cache_key = f"{DictService.WORD_LINK_CACHE_PREFIX}{word.lower()}"
|
||||||
|
try:
|
||||||
|
# 如果没有链接单词,存储"None"字符串
|
||||||
|
value = linked_word if linked_word is not None else "None"
|
||||||
|
await redis_client.setex(cache_key, DictService.CACHE_EXPIRE_TIME, value)
|
||||||
|
except Exception as e:
|
||||||
|
# 如果Redis出错,不影响主流程
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _get_linked_word_from_db(word: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
从数据库中获取单词的链接单词
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async with async_db_session() as db:
|
||||||
|
entry = await dict_dao.get_by_word(db, word.strip().lower())
|
||||||
|
if not entry:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 检查 details 字段内的 json
|
||||||
|
if (entry.details and
|
||||||
|
not entry.details.dict_list and
|
||||||
|
entry.details.ref_link):
|
||||||
|
# dict_list 为空数组,检查 ref_link 的内容
|
||||||
|
ref_links = entry.details.ref_link
|
||||||
|
if ref_links:
|
||||||
|
# 获取第一个链接
|
||||||
|
first_link = ref_links[0]
|
||||||
|
# 检查是否包含 "LINK=" 前缀
|
||||||
|
if isinstance(first_link, str) and first_link.startswith("LINK="):
|
||||||
|
# 截取 "LINK=" 后的单词
|
||||||
|
referenced_word = first_link[5:] # 去掉 "LINK=" 前缀
|
||||||
|
if referenced_word:
|
||||||
|
return referenced_word
|
||||||
|
return word
|
||||||
|
except Exception as e:
|
||||||
|
# 如果出现任何错误,返回None表示未找到链接单词
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_linked_word(word: str) -> str:
|
||||||
|
"""
|
||||||
|
获取单词的链接单词(如果有)
|
||||||
|
首先检查Redis缓存,如果没有则查询数据库并更新缓存
|
||||||
|
"""
|
||||||
|
if not word or not word.strip():
|
||||||
|
return word
|
||||||
|
|
||||||
|
word = word.strip().lower()
|
||||||
|
|
||||||
|
# 首先尝试从缓存获取
|
||||||
|
linked_word = await DictService._get_linked_word_from_cache(word)
|
||||||
|
if linked_word is not None:
|
||||||
|
return linked_word
|
||||||
|
|
||||||
|
# 缓存未命中,从数据库查询
|
||||||
|
linked_word = await DictService._get_linked_word_from_db(word)
|
||||||
|
|
||||||
|
# 更新缓存
|
||||||
|
await DictService._set_linked_word_to_cache(word, linked_word)
|
||||||
|
|
||||||
|
# 返回结果,如果没有链接单词则返回原单词
|
||||||
|
return linked_word if linked_word is not None else word
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_word_info(word: str) -> DictWordResponse:
|
||||||
|
"""根据单词获取字典信息"""
|
||||||
|
if not word or not word.strip():
|
||||||
|
raise errors.ForbiddenError(msg="单词不能为空")
|
||||||
|
|
||||||
|
word = word.strip().lower()
|
||||||
|
|
||||||
|
async with async_db_session() as db:
|
||||||
|
entry = await dict_dao.get_by_word(db, word)
|
||||||
|
if not entry:
|
||||||
|
raise errors.NotFoundError(msg=f"未找到单词 '{word}' 的释义")
|
||||||
|
|
||||||
|
# 使用新的缓存机制获取链接单词
|
||||||
|
linked_word = await DictService.get_linked_word(word)
|
||||||
|
if linked_word != word:
|
||||||
|
# 如果找到了引用的单词,返回其信息
|
||||||
|
referenced_entry = await dict_dao.get_by_word(db, linked_word)
|
||||||
|
if referenced_entry:
|
||||||
|
return DictService._convert_to_simple_response(referenced_entry)
|
||||||
|
|
||||||
|
return DictService._convert_to_simple_response(entry)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def check_word_exists(word: str) -> bool:
|
||||||
|
"""检查单词是否存在"""
|
||||||
|
if not word or not word.strip():
|
||||||
|
return False
|
||||||
|
|
||||||
|
word = word.strip().lower()
|
||||||
|
|
||||||
|
async with async_db_session() as db:
|
||||||
|
entry = await dict_dao.get_by_word(db, word)
|
||||||
|
return entry is not None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_audio_data(file_name: str) -> Optional[bytes]:
|
||||||
|
"""根据文件名获取音频数据"""
|
||||||
|
if not file_name or not file_name.strip():
|
||||||
|
return None
|
||||||
|
|
||||||
|
file_name = file_name.strip()
|
||||||
|
|
||||||
|
async with async_db_session() as db:
|
||||||
|
media = await dict_dao.get_media_by_filename(db, file_name)
|
||||||
|
if not media or media.file_type != 'audio':
|
||||||
|
return None
|
||||||
|
|
||||||
|
return media.file_data
|
||||||
|
|
||||||
|
|
||||||
|
dict_service = DictService()
|
||||||
66
backend/app/admin/service/feedback_service.py
Executable file
66
backend/app/admin/service/feedback_service.py
Executable file
@@ -0,0 +1,66 @@
|
|||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from backend.app.admin.crud.feedback_crud import feedback_dao
|
||||||
|
from backend.app.admin.model.feedback import Feedback
|
||||||
|
from backend.app.admin.schema.feedback import CreateFeedbackParam, UpdateFeedbackParam, FeedbackInfoSchema
|
||||||
|
from backend.database.db import async_db_session
|
||||||
|
|
||||||
|
|
||||||
|
class FeedbackService:
|
||||||
|
@staticmethod
|
||||||
|
async def create_feedback(user_id: int, obj_in: CreateFeedbackParam) -> FeedbackInfoSchema:
|
||||||
|
"""创建反馈"""
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
feedback = await feedback_dao.create(db, user_id, obj_in)
|
||||||
|
return FeedbackInfoSchema.model_validate(feedback)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_feedback(feedback_id: int) -> Optional[FeedbackInfoSchema]:
|
||||||
|
"""获取反馈详情"""
|
||||||
|
async with async_db_session() as db:
|
||||||
|
feedback = await feedback_dao.get(db, feedback_id)
|
||||||
|
if feedback:
|
||||||
|
return FeedbackInfoSchema.model_validate(feedback)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_feedback_list(
|
||||||
|
user_id: Optional[int] = None,
|
||||||
|
status: Optional[str] = None,
|
||||||
|
category: Optional[str] = None,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
offset: Optional[int] = None
|
||||||
|
) -> List[FeedbackInfoSchema]:
|
||||||
|
"""获取反馈列表"""
|
||||||
|
async with async_db_session() as db:
|
||||||
|
feedbacks = await feedback_dao.get_list(db, user_id, status, category, limit, offset)
|
||||||
|
return [FeedbackInfoSchema.model_validate(feedback) for feedback in feedbacks]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def update_feedback(feedback_id: int, obj_in: UpdateFeedbackParam) -> bool:
|
||||||
|
"""更新反馈"""
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
count = await feedback_dao.update(db, feedback_id, obj_in)
|
||||||
|
return count > 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def delete_feedback(feedback_id: int) -> bool:
|
||||||
|
"""删除反馈"""
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
count = await feedback_dao.delete(db, feedback_id)
|
||||||
|
return count > 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def count_feedbacks(
|
||||||
|
user_id: Optional[int] = None,
|
||||||
|
status: Optional[str] = None,
|
||||||
|
category: Optional[str] = None
|
||||||
|
) -> int:
|
||||||
|
"""统计反馈数量"""
|
||||||
|
async with async_db_session() as db:
|
||||||
|
return await feedback_dao.count(db, user_id, status, category)
|
||||||
|
|
||||||
|
|
||||||
|
feedback_service = FeedbackService()
|
||||||
419
backend/app/admin/service/file_service.py
Executable file
419
backend/app/admin/service/file_service.py
Executable file
@@ -0,0 +1,419 @@
|
|||||||
|
import io
|
||||||
|
import imghdr
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
from fastapi import UploadFile
|
||||||
|
from PIL import Image as PILImage, ExifTags
|
||||||
|
|
||||||
|
from backend.app.admin.crud.file_crud import file_dao
|
||||||
|
from backend.app.admin.model.file import File
|
||||||
|
from backend.app.admin.schema.file import AddFileParam, FileUploadResponse, UpdateFileParam, FileMetadata
|
||||||
|
from backend.app.ai.schema.image import ColorMode, ImageMetadata, ImageFormat
|
||||||
|
from backend.app.admin.service.file_storage import get_storage_provider, calculate_file_hash
|
||||||
|
from backend.common.exception import errors
|
||||||
|
from backend.core.conf import settings
|
||||||
|
from backend.database.db import async_db_session
|
||||||
|
|
||||||
|
|
||||||
|
class FileService:
|
||||||
|
@staticmethod
|
||||||
|
def is_image_file(content_type: str, file_content: bytes, file_name: str) -> bool:
|
||||||
|
"""判断是否为图片文件"""
|
||||||
|
# 首先检查文件扩展名是否在允许的图片类型列表中
|
||||||
|
if file_name:
|
||||||
|
file_ext = file_name.split('.')[-1].lower()
|
||||||
|
if file_ext in settings.UPLOAD_IMAGE_EXT_INCLUDE:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 然后检查content_type
|
||||||
|
if content_type and content_type.startswith('image/'):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 最后通过文件内容检测
|
||||||
|
image_format = imghdr.what(None, h=file_content)
|
||||||
|
return image_format is not None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_image_file(file_name: str) -> None:
|
||||||
|
"""验证图片文件类型是否被允许"""
|
||||||
|
if not file_name:
|
||||||
|
return
|
||||||
|
|
||||||
|
file_ext = file_name.split('.')[-1].lower()
|
||||||
|
if file_ext and file_ext not in settings.UPLOAD_IMAGE_EXT_INCLUDE:
|
||||||
|
raise errors.ForbiddenError(msg=f'[{file_ext}] 此图片格式暂不支持')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def upload_file(
|
||||||
|
file: UploadFile,
|
||||||
|
metadata: Optional[dict] = None
|
||||||
|
) -> FileUploadResponse:
|
||||||
|
"""上传文件"""
|
||||||
|
# 读取文件内容
|
||||||
|
content = await file.read()
|
||||||
|
await file.seek(0) # 重置文件指针
|
||||||
|
|
||||||
|
storage_type = settings.DEFAULT_STORAGE_TYPE
|
||||||
|
# 计算文件哈希
|
||||||
|
file_hash = calculate_file_hash(content)
|
||||||
|
|
||||||
|
# 检查文件是否已存在
|
||||||
|
async with async_db_session() as db:
|
||||||
|
existing_file = await file_dao.get_by_hash(db, file_hash)
|
||||||
|
if existing_file:
|
||||||
|
return FileUploadResponse(
|
||||||
|
id=str(existing_file.id),
|
||||||
|
file_hash=existing_file.file_hash,
|
||||||
|
file_name=existing_file.file_name,
|
||||||
|
content_type=existing_file.content_type,
|
||||||
|
file_size=existing_file.file_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取存储提供者
|
||||||
|
storage_provider = get_storage_provider(storage_type)
|
||||||
|
|
||||||
|
# 保存文件到存储
|
||||||
|
storage_path = None
|
||||||
|
file_data = None
|
||||||
|
|
||||||
|
if storage_type == "database":
|
||||||
|
# 数据库存储,将文件数据保存到数据库
|
||||||
|
file_data = content
|
||||||
|
else:
|
||||||
|
# 其他存储方式,保存文件并记录路径
|
||||||
|
storage_path = await storage_provider.save(
|
||||||
|
file_id=0, # 临时ID,后续会更新
|
||||||
|
content=content,
|
||||||
|
file_name=file.filename or "unnamed"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建文件记录
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
# 创建文件元数据
|
||||||
|
file_metadata_dict = {
|
||||||
|
"file_name": file.filename,
|
||||||
|
"content_type": file.content_type,
|
||||||
|
"file_size": len(content),
|
||||||
|
"created_at": datetime.now(),
|
||||||
|
"updated_at": datetime.now(),
|
||||||
|
"extra": metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
# 验证图片文件类型
|
||||||
|
if file.filename:
|
||||||
|
FileService.validate_image_file(file.filename)
|
||||||
|
|
||||||
|
# 如果是图片文件,提取图片元数据
|
||||||
|
if FileService.is_image_file(file.content_type or "", content, file.filename or ""):
|
||||||
|
try:
|
||||||
|
additional_info = {
|
||||||
|
"file_name": file.filename,
|
||||||
|
"content_type": file.content_type,
|
||||||
|
"file_size": len(content),
|
||||||
|
}
|
||||||
|
image_metadata = file_service.extract_image_metadata(content, additional_info)
|
||||||
|
file_metadata_dict["image_info"] = image_metadata.dict()
|
||||||
|
except Exception as e:
|
||||||
|
# 如果提取图片元数据失败,记录错误但不中断上传
|
||||||
|
file_metadata_dict["extra"] = {
|
||||||
|
**(metadata or {}),
|
||||||
|
"image_metadata_error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
file_metadata = FileMetadata(**file_metadata_dict)
|
||||||
|
|
||||||
|
# 创建文件参数
|
||||||
|
file_params = AddFileParam(
|
||||||
|
file_hash=file_hash,
|
||||||
|
file_name=file.filename or "unnamed",
|
||||||
|
content_type=file.content_type,
|
||||||
|
file_size=len(content),
|
||||||
|
storage_type=storage_type,
|
||||||
|
storage_path=storage_path,
|
||||||
|
metadata_info=file_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存到数据库
|
||||||
|
db_file = await file_dao.create(db, file_params)
|
||||||
|
|
||||||
|
# 如果是本地存储或其他存储,需要更新文件ID
|
||||||
|
if storage_type != "database":
|
||||||
|
# 重新保存文件,使用真实的文件ID
|
||||||
|
storage_path = await storage_provider.save(
|
||||||
|
file_id=db_file.id,
|
||||||
|
content=content,
|
||||||
|
file_name=file.filename or "unnamed"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新数据库中的存储路径
|
||||||
|
update_params = UpdateFileParam(
|
||||||
|
storage_path=storage_path
|
||||||
|
)
|
||||||
|
await file_dao.update(db, db_file.id, update_params)
|
||||||
|
db_file.storage_path = storage_path
|
||||||
|
|
||||||
|
# 如果是数据库存储,更新文件数据
|
||||||
|
if storage_type == "database":
|
||||||
|
db_file.file_data = content
|
||||||
|
await db.flush() # 确保将file_data保存到数据库
|
||||||
|
|
||||||
|
return FileUploadResponse(
|
||||||
|
id=str(db_file.id),
|
||||||
|
file_hash=db_file.file_hash,
|
||||||
|
file_name=db_file.file_name,
|
||||||
|
content_type=db_file.content_type,
|
||||||
|
file_size=db_file.file_size
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def upload_file_with_content_type(
|
||||||
|
file,
|
||||||
|
content_type: str,
|
||||||
|
metadata: Optional[dict] = None
|
||||||
|
) -> FileUploadResponse:
|
||||||
|
"""上传文件并指定content_type"""
|
||||||
|
# 读取文件内容
|
||||||
|
content = await file.read()
|
||||||
|
await file.seek(0) # 重置文件指针
|
||||||
|
|
||||||
|
storage_type = settings.DEFAULT_STORAGE_TYPE
|
||||||
|
# 计算文件哈希
|
||||||
|
file_hash = calculate_file_hash(content)
|
||||||
|
|
||||||
|
# 检查文件是否已存在
|
||||||
|
async with async_db_session() as db:
|
||||||
|
existing_file = await file_dao.get_by_hash(db, file_hash)
|
||||||
|
if existing_file:
|
||||||
|
return FileUploadResponse(
|
||||||
|
id=str(existing_file.id),
|
||||||
|
file_hash=existing_file.file_hash,
|
||||||
|
file_name=existing_file.file_name,
|
||||||
|
content_type=existing_file.content_type,
|
||||||
|
file_size=existing_file.file_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取存储提供者
|
||||||
|
storage_provider = get_storage_provider(storage_type)
|
||||||
|
|
||||||
|
# 保存文件到存储
|
||||||
|
storage_path = None
|
||||||
|
file_data = None
|
||||||
|
|
||||||
|
if storage_type == "database":
|
||||||
|
# 数据库存储,将文件数据保存到数据库
|
||||||
|
file_data = content
|
||||||
|
else:
|
||||||
|
# 其他存储方式,保存文件并记录路径
|
||||||
|
storage_path = await storage_provider.save(
|
||||||
|
file_id=0, # 临时ID,后续会更新
|
||||||
|
content=content,
|
||||||
|
file_name=getattr(file, 'filename', 'unnamed') or "unnamed"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建文件记录
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
# 创建文件元数据
|
||||||
|
file_metadata_dict = {
|
||||||
|
"file_name": getattr(file, 'filename', 'unnamed'),
|
||||||
|
"content_type": content_type,
|
||||||
|
"file_size": len(content),
|
||||||
|
"created_at": datetime.now(),
|
||||||
|
"updated_at": datetime.now(),
|
||||||
|
"extra": metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
# 验证图片文件类型
|
||||||
|
file_name = getattr(file, 'filename', 'unnamed')
|
||||||
|
if file_name:
|
||||||
|
FileService.validate_image_file(file_name)
|
||||||
|
|
||||||
|
# 如果是图片文件,提取图片元数据
|
||||||
|
if FileService.is_image_file(content_type or "", content, file_name or ""):
|
||||||
|
try:
|
||||||
|
additional_info = {
|
||||||
|
"file_name": file_name,
|
||||||
|
"content_type": content_type,
|
||||||
|
"file_size": len(content),
|
||||||
|
}
|
||||||
|
image_metadata = file_service.extract_image_metadata(content, additional_info)
|
||||||
|
file_metadata_dict["image_info"] = image_metadata.dict()
|
||||||
|
except Exception as e:
|
||||||
|
# 如果提取图片元数据失败,记录错误但不中断上传
|
||||||
|
file_metadata_dict["extra"] = {
|
||||||
|
**(metadata or {}),
|
||||||
|
"image_metadata_error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
file_metadata = FileMetadata(**file_metadata_dict)
|
||||||
|
|
||||||
|
# 创建文件参数
|
||||||
|
file_params = AddFileParam(
|
||||||
|
file_hash=file_hash,
|
||||||
|
file_name=file_name or "unnamed",
|
||||||
|
content_type=content_type,
|
||||||
|
file_size=len(content),
|
||||||
|
storage_type=storage_type,
|
||||||
|
storage_path=storage_path,
|
||||||
|
metadata_info=file_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存到数据库
|
||||||
|
db_file = await file_dao.create(db, file_params)
|
||||||
|
|
||||||
|
# 如果是本地存储或其他存储,需要更新文件ID
|
||||||
|
if storage_type != "database":
|
||||||
|
# 重新保存文件,使用真实的文件ID
|
||||||
|
storage_path = await storage_provider.save(
|
||||||
|
file_id=db_file.id,
|
||||||
|
content=content,
|
||||||
|
file_name=file_name or "unnamed"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新数据库中的存储路径
|
||||||
|
update_params = UpdateFileParam(
|
||||||
|
storage_path=storage_path
|
||||||
|
)
|
||||||
|
await file_dao.update(db, db_file.id, update_params)
|
||||||
|
db_file.storage_path = storage_path
|
||||||
|
|
||||||
|
# 如果是数据库存储,更新文件数据
|
||||||
|
if storage_type == "database":
|
||||||
|
db_file.file_data = content
|
||||||
|
await db.flush() # 确保将file_data保存到数据库
|
||||||
|
|
||||||
|
return FileUploadResponse(
|
||||||
|
id=str(db_file.id),
|
||||||
|
file_hash=db_file.file_hash,
|
||||||
|
file_name=db_file.file_name,
|
||||||
|
content_type=db_file.content_type,
|
||||||
|
file_size=db_file.file_size
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_file(file_id: int) -> Optional[File]:
|
||||||
|
"""获取文件信息"""
|
||||||
|
async with async_db_session() as db:
|
||||||
|
return await file_dao.get(db, file_id)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def download_file(file_id: int) -> tuple[bytes, str, str]:
|
||||||
|
"""下载文件"""
|
||||||
|
async with async_db_session() as db:
|
||||||
|
db_file = await file_dao.get(db, file_id)
|
||||||
|
if not db_file:
|
||||||
|
raise errors.NotFoundError(msg="文件不存在")
|
||||||
|
|
||||||
|
content = b""
|
||||||
|
storage_provider = get_storage_provider(db_file.storage_type)
|
||||||
|
|
||||||
|
if db_file.storage_type == "database":
|
||||||
|
# 从数据库获取文件数据
|
||||||
|
content = db_file.file_data or b""
|
||||||
|
else:
|
||||||
|
# 从存储中读取文件
|
||||||
|
content = await storage_provider.read(file_id, db_file.storage_path or "")
|
||||||
|
|
||||||
|
return content, db_file.file_name, db_file.content_type or "application/octet-stream"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def delete_file(file_id: int) -> bool:
|
||||||
|
"""删除文件"""
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
db_file = await file_dao.get(db, file_id)
|
||||||
|
if not db_file:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 删除存储中的文件
|
||||||
|
if db_file.storage_type != "database":
|
||||||
|
storage_provider = get_storage_provider(db_file.storage_type)
|
||||||
|
await storage_provider.delete(file_id, db_file.storage_path or "")
|
||||||
|
|
||||||
|
# 删除数据库记录
|
||||||
|
result = await file_dao.delete(db, file_id)
|
||||||
|
return result > 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_file_by_hash(db, file_hash: str) -> Optional[File]:
|
||||||
|
"""通过哈希值获取文件"""
|
||||||
|
return await file_dao.get_by_hash(db, file_hash)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def detect_image_format(image_bytes: bytes) -> ImageFormat:
|
||||||
|
"""通过二进制数据检测图片格式"""
|
||||||
|
# 使用imghdr识别基础格式
|
||||||
|
format_str = imghdr.what(None, h=image_bytes)
|
||||||
|
|
||||||
|
# 映射到枚举类型
|
||||||
|
format_mapping = {
|
||||||
|
'jpeg': ImageFormat.JPEG,
|
||||||
|
'jpg': ImageFormat.JPEG,
|
||||||
|
'png': ImageFormat.PNG,
|
||||||
|
'gif': ImageFormat.GIF,
|
||||||
|
'bmp': ImageFormat.BMP,
|
||||||
|
'webp': ImageFormat.WEBP,
|
||||||
|
'tiff': ImageFormat.TIFF,
|
||||||
|
'svg': ImageFormat.SVG
|
||||||
|
}
|
||||||
|
|
||||||
|
return format_mapping.get(format_str, ImageFormat.UNKNOWN)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def extract_image_metadata(image_bytes: bytes, additional_info: Dict[str, Any] = None) -> ImageMetadata:
|
||||||
|
"""从图片二进制数据中提取元数据"""
|
||||||
|
try:
|
||||||
|
with PILImage.open(io.BytesIO(image_bytes)) as img:
|
||||||
|
# 获取基础信息
|
||||||
|
width, height = img.size
|
||||||
|
color_mode = ColorMode(img.mode) if img.mode in ColorMode.__members__.values() else ColorMode.UNKNOWN
|
||||||
|
|
||||||
|
# 获取EXIF数据
|
||||||
|
exif_data = {}
|
||||||
|
if hasattr(img, '_getexif') and img._getexif():
|
||||||
|
for tag, value in img._getexif().items():
|
||||||
|
decoded_tag = ExifTags.TAGS.get(tag, tag)
|
||||||
|
# 特殊处理日期时间
|
||||||
|
if decoded_tag in ['DateTime', 'DateTimeOriginal', 'DateTimeDigitized']:
|
||||||
|
try:
|
||||||
|
value = datetime.strptime(value, "%Y:%m:%d %H:%M:%S").isoformat()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
exif_data[decoded_tag] = value
|
||||||
|
|
||||||
|
# 获取颜色通道数
|
||||||
|
channels = len(img.getbands())
|
||||||
|
|
||||||
|
# 尝试获取DPI
|
||||||
|
dpi = img.info.get('dpi')
|
||||||
|
|
||||||
|
# 创建元数据对象
|
||||||
|
metadata = ImageMetadata(
|
||||||
|
format=file_service.detect_image_format(image_bytes),
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
color_mode=color_mode,
|
||||||
|
file_size=len(image_bytes),
|
||||||
|
channels=channels,
|
||||||
|
dpi=dpi,
|
||||||
|
exif=exif_data
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加额外信息
|
||||||
|
if additional_info:
|
||||||
|
for key, value in additional_info.items():
|
||||||
|
if hasattr(metadata, key):
|
||||||
|
setattr(metadata, key, value)
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
except Exception as e:
|
||||||
|
# 无法解析图片时返回基础元数据
|
||||||
|
return ImageMetadata(
|
||||||
|
format=file_service.detect_image_format(image_bytes),
|
||||||
|
width=0,
|
||||||
|
height=0,
|
||||||
|
color_mode=ColorMode.UNKNOWN,
|
||||||
|
file_size=len(image_bytes),
|
||||||
|
error=f"Metadata extraction failed: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
file_service = FileService()
|
||||||
114
backend/app/admin/service/file_storage.py
Executable file
114
backend/app/admin/service/file_storage.py
Executable file
@@ -0,0 +1,114 @@
|
|||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
import aiofiles
|
||||||
|
from fastapi import UploadFile
|
||||||
|
|
||||||
|
from backend.core.conf import settings
|
||||||
|
from backend.core.path_conf import UPLOAD_DIR
|
||||||
|
|
||||||
|
|
||||||
|
class StorageProvider(ABC):
|
||||||
|
"""存储提供者抽象基类"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def save(self, file_id: int, content: bytes, file_name: str) -> str:
|
||||||
|
"""保存文件"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def read(self, file_id: int, storage_path: str) -> bytes:
|
||||||
|
"""读取文件"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def delete(self, file_id: int, storage_path: str) -> bool:
|
||||||
|
"""删除文件"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseStorage(StorageProvider):
|
||||||
|
"""数据库存储提供者"""
|
||||||
|
|
||||||
|
async def save(self, file_id: int, content: bytes, file_name: str) -> str:
|
||||||
|
"""数据库存储不需要实际保存文件,直接返回空字符串"""
|
||||||
|
return ""
|
||||||
|
|
||||||
|
async def read(self, file_id: int, storage_path: str) -> bytes:
|
||||||
|
"""数据库存储不需要读取文件"""
|
||||||
|
return b""
|
||||||
|
|
||||||
|
async def delete(self, file_id: int, storage_path: str) -> bool:
|
||||||
|
"""数据库存储不需要删除文件"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class LocalStorage(StorageProvider):
|
||||||
|
"""本地文件系统存储提供者"""
|
||||||
|
|
||||||
|
def __init__(self, base_path: str = settings.STORAGE_PATH):
|
||||||
|
self.base_path = base_path
|
||||||
|
|
||||||
|
def _get_path(self, file_id: int, file_name: str) -> str:
|
||||||
|
"""构建文件路径"""
|
||||||
|
# 使用文件ID作为目录名,避免单个目录下文件过多
|
||||||
|
dir_name = str(file_id // 1000)
|
||||||
|
file_dir = os.path.join(self.base_path, dir_name)
|
||||||
|
os.makedirs(file_dir, exist_ok=True)
|
||||||
|
return os.path.join(file_dir, f"{file_id}_{file_name}")
|
||||||
|
|
||||||
|
async def save(self, file_id: int, content: bytes, file_name: str) -> str:
|
||||||
|
"""保存文件到本地"""
|
||||||
|
path = self._get_path(file_id, file_name)
|
||||||
|
async with aiofiles.open(path, 'wb') as f:
|
||||||
|
await f.write(content)
|
||||||
|
return path
|
||||||
|
|
||||||
|
async def read(self, file_id: int, storage_path: str) -> bytes:
|
||||||
|
"""从本地读取文件"""
|
||||||
|
async with aiofiles.open(storage_path, 'rb') as f:
|
||||||
|
return await f.read()
|
||||||
|
|
||||||
|
async def delete(self, file_id: int, storage_path: str) -> bool:
|
||||||
|
"""从本地删除文件"""
|
||||||
|
try:
|
||||||
|
os.remove(storage_path)
|
||||||
|
return True
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class S3Storage(StorageProvider):
|
||||||
|
"""AWS S3存储提供者(占位实现)"""
|
||||||
|
|
||||||
|
async def save(self, file_id: int, content: bytes, file_name: str) -> str:
|
||||||
|
"""保存文件到S3"""
|
||||||
|
# 实际实现需要使用boto3等库连接到S3
|
||||||
|
# 这里仅作为示例展示接口
|
||||||
|
return f"s3://bucket-name/{file_id}/{file_name}"
|
||||||
|
|
||||||
|
async def read(self, file_id: int, storage_path: str) -> bytes:
|
||||||
|
"""从S3读取文件"""
|
||||||
|
# 实际实现需要使用boto3等库连接到S3
|
||||||
|
return b""
|
||||||
|
|
||||||
|
async def delete(self, file_id: int, storage_path: str) -> bool:
|
||||||
|
"""从S3删除文件"""
|
||||||
|
# 实际实现需要使用boto3等库连接到S3
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_storage_provider(provider_type: str) -> StorageProvider:
|
||||||
|
"""根据配置获取存储提供者"""
|
||||||
|
if provider_type == "local":
|
||||||
|
return LocalStorage()
|
||||||
|
elif provider_type == "s3":
|
||||||
|
return S3Storage()
|
||||||
|
else: # 默认使用数据库存储
|
||||||
|
return DatabaseStorage()
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_file_hash(content: bytes) -> str:
|
||||||
|
"""计算文件的SHA256哈希值"""
|
||||||
|
return hashlib.sha256(content).hexdigest()
|
||||||
146
backend/app/admin/service/notification_service.py
Executable file
146
backend/app/admin/service/notification_service.py
Executable file
@@ -0,0 +1,146 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from backend.app.admin.crud.notification_crud import notification_dao, user_notification_dao
|
||||||
|
from backend.app.admin.model.notification import Notification, UserNotification
|
||||||
|
from backend.database.db import async_db_session
|
||||||
|
from backend.common.exception import errors
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class NotificationService:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def create_notification(title: str, content: str, image_url: Optional[str] = None,
|
||||||
|
created_by: Optional[int] = None) -> Notification:
|
||||||
|
"""
|
||||||
|
创建消息通知
|
||||||
|
"""
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
notification_data = {
|
||||||
|
'title': title,
|
||||||
|
'content': content,
|
||||||
|
'image_url': image_url,
|
||||||
|
'created_by': created_by
|
||||||
|
}
|
||||||
|
|
||||||
|
notification = await notification_dao.create_notification(db, notification_data)
|
||||||
|
return notification
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def send_notification_to_user(notification_id: int, user_id: int) -> UserNotification:
|
||||||
|
"""
|
||||||
|
发送通知给指定用户
|
||||||
|
"""
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
# 检查通知是否存在
|
||||||
|
notification = await notification_dao.get(db, notification_id)
|
||||||
|
if not notification:
|
||||||
|
raise errors.RequestError(msg='通知不存在')
|
||||||
|
|
||||||
|
# 创建用户通知关联
|
||||||
|
user_notification_data = {
|
||||||
|
'notification_id': notification_id,
|
||||||
|
'user_id': user_id
|
||||||
|
}
|
||||||
|
|
||||||
|
user_notification = await user_notification_dao.create_user_notification(db, user_notification_data)
|
||||||
|
return user_notification
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def send_notification_to_users(notification_id: int, user_ids: List[int]) -> List[UserNotification]:
|
||||||
|
"""
|
||||||
|
批量发送通知给多个用户
|
||||||
|
"""
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
# 检查通知是否存在
|
||||||
|
notification = await notification_dao.get(db, notification_id)
|
||||||
|
if not notification:
|
||||||
|
raise errors.RequestError(msg='通知不存在')
|
||||||
|
|
||||||
|
# 创建用户通知关联列表
|
||||||
|
user_notifications_data = [
|
||||||
|
{
|
||||||
|
'notification_id': notification_id,
|
||||||
|
'user_id': user_id
|
||||||
|
}
|
||||||
|
for user_id in user_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
user_notifications = await user_notification_dao.create_user_notifications(db, user_notifications_data)
|
||||||
|
return user_notifications
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_user_notifications(user_id: int, limit: int = 100) -> List[dict]:
|
||||||
|
"""
|
||||||
|
获取用户的通知列表(包含通知详情)
|
||||||
|
"""
|
||||||
|
async with async_db_session() as db:
|
||||||
|
user_notifications = await user_notification_dao.get_user_notifications(db, user_id, limit)
|
||||||
|
|
||||||
|
notifications = []
|
||||||
|
for user_notification in user_notifications:
|
||||||
|
# 获取通知详情
|
||||||
|
notification = await notification_dao.get(db, user_notification.notification_id)
|
||||||
|
if notification:
|
||||||
|
notifications.append({
|
||||||
|
'id': user_notification.id,
|
||||||
|
'notification_id': notification.id,
|
||||||
|
'title': notification.title,
|
||||||
|
'content': notification.content,
|
||||||
|
'image_url': notification.image_url,
|
||||||
|
'is_read': user_notification.is_read,
|
||||||
|
'received_at': user_notification.received_at,
|
||||||
|
'read_at': user_notification.read_at
|
||||||
|
})
|
||||||
|
|
||||||
|
return notifications
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_unread_notifications(user_id: int, limit: int = 100) -> List[dict]:
|
||||||
|
"""
|
||||||
|
获取用户未读通知列表(包含通知详情)
|
||||||
|
"""
|
||||||
|
async with async_db_session() as db:
|
||||||
|
user_notifications = await user_notification_dao.get_unread_notifications(db, user_id, limit)
|
||||||
|
|
||||||
|
notifications = []
|
||||||
|
for user_notification in user_notifications:
|
||||||
|
# 获取通知详情
|
||||||
|
notification = await notification_dao.get(db, user_notification.notification_id)
|
||||||
|
if notification:
|
||||||
|
notifications.append({
|
||||||
|
'id': user_notification.id,
|
||||||
|
'notification_id': notification.id,
|
||||||
|
'title': notification.title,
|
||||||
|
'content': notification.content,
|
||||||
|
'image_url': notification.image_url,
|
||||||
|
'received_at': user_notification.received_at
|
||||||
|
})
|
||||||
|
|
||||||
|
return notifications
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def mark_notification_as_read(user_notification_id: int, user_id: int) -> bool:
|
||||||
|
"""
|
||||||
|
标记通知为已读
|
||||||
|
"""
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
# 验证用户权限
|
||||||
|
user_notification = await user_notification_dao.get_user_notification_by_id(db, user_notification_id)
|
||||||
|
if not user_notification:
|
||||||
|
raise errors.RequestError(msg='通知不存在')
|
||||||
|
|
||||||
|
if user_notification.user_id != user_id:
|
||||||
|
raise errors.ForbiddenError(msg='无权限操作此通知')
|
||||||
|
|
||||||
|
success = await user_notification_dao.mark_as_read(db, user_notification_id)
|
||||||
|
return success
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_unread_count(user_id: int) -> int:
|
||||||
|
"""
|
||||||
|
获取用户未读通知数量
|
||||||
|
"""
|
||||||
|
async with async_db_session() as db:
|
||||||
|
count = await user_notification_dao.get_unread_count(db, user_id)
|
||||||
|
return count
|
||||||
180
backend/app/admin/service/points_service.py
Normal file
180
backend/app/admin/service/points_service.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from backend.app.admin.crud.points_crud import points_dao, points_log_dao
|
||||||
|
from backend.app.admin.model.points import Points
|
||||||
|
from backend.database.db import async_db_session
|
||||||
|
|
||||||
|
|
||||||
|
class PointsService:
|
||||||
|
@staticmethod
|
||||||
|
async def get_user_points(user_id: int) -> Optional[Points]:
|
||||||
|
"""
|
||||||
|
获取用户积分账户信息(会检查并清空过期积分)
|
||||||
|
"""
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
# 获取当前积分余额(清空前)
|
||||||
|
points_account_before = await points_dao.get_by_user_id(db, user_id)
|
||||||
|
balance_before = points_account_before.balance if points_account_before else 0
|
||||||
|
|
||||||
|
# 检查并清空过期积分
|
||||||
|
expired_cleared = await points_dao.check_and_clear_expired_points(db, user_id)
|
||||||
|
|
||||||
|
# 如果清空了过期积分,记录日志
|
||||||
|
if expired_cleared and balance_before > 0:
|
||||||
|
await points_log_dao.add_log(db, {
|
||||||
|
"user_id": user_id,
|
||||||
|
"action": "expire_clear",
|
||||||
|
"amount": balance_before, # 记录清空前的积分数量
|
||||||
|
"balance_after": 0,
|
||||||
|
"details": {"message": "过期积分已清空", "cleared_amount": balance_before}
|
||||||
|
})
|
||||||
|
|
||||||
|
return await points_dao.get_by_user_id(db, user_id)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_user_balance(user_id: int) -> int:
|
||||||
|
"""
|
||||||
|
获取用户积分余额(会检查并清空过期积分)
|
||||||
|
"""
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
# 获取当前积分余额(清空前)
|
||||||
|
points_account_before = await points_dao.get_by_user_id(db, user_id)
|
||||||
|
balance_before = points_account_before.balance if points_account_before else 0
|
||||||
|
|
||||||
|
# 检查并清空过期积分
|
||||||
|
expired_cleared = await points_dao.check_and_clear_expired_points(db, user_id)
|
||||||
|
|
||||||
|
# 如果清空了过期积分,记录日志
|
||||||
|
if expired_cleared and balance_before > 0:
|
||||||
|
await points_log_dao.add_log(db, {
|
||||||
|
"user_id": user_id,
|
||||||
|
"action": "expire_clear",
|
||||||
|
"balance_before": balance_before,
|
||||||
|
"balance_after": 0,
|
||||||
|
"details": {"message": "过期积分已清空", "cleared_amount": balance_before}
|
||||||
|
})
|
||||||
|
|
||||||
|
return await points_dao.get_balance(db, user_id)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def add_points(user_id: int, amount: int, extend_expiration: bool = False, related_id: Optional[int] = None, details: Optional[dict] = None) -> bool:
|
||||||
|
"""
|
||||||
|
为用户增加积分
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
amount: 增加的积分数量
|
||||||
|
extend_expiration: 是否自动延期过期时间
|
||||||
|
related_id: 关联ID(可选)
|
||||||
|
details: 附加信息(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否成功
|
||||||
|
"""
|
||||||
|
if amount <= 0:
|
||||||
|
raise ValueError("积分数量必须大于0")
|
||||||
|
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
# 获取当前余额以记录日志
|
||||||
|
points_account = await points_dao.get_by_user_id(db, user_id)
|
||||||
|
if not points_account:
|
||||||
|
points_account = await points_dao.create_user_points(db, user_id)
|
||||||
|
|
||||||
|
current_balance = points_account.balance
|
||||||
|
|
||||||
|
# 原子性增加积分(可能延期过期时间)
|
||||||
|
result = await points_dao.add_points_atomic(db, user_id, amount, extend_expiration)
|
||||||
|
if not result:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 准备日志详情
|
||||||
|
log_details = details or {}
|
||||||
|
if extend_expiration:
|
||||||
|
log_details["expiration_extended"] = True
|
||||||
|
log_details["extension_days"] = 30
|
||||||
|
|
||||||
|
# 记录积分变动日志
|
||||||
|
new_balance = current_balance + amount
|
||||||
|
await points_log_dao.add_log(db, {
|
||||||
|
"user_id": user_id,
|
||||||
|
"action": "earn",
|
||||||
|
"amount": amount,
|
||||||
|
"balance_after": new_balance,
|
||||||
|
"related_id": related_id,
|
||||||
|
"details": log_details
|
||||||
|
})
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def deduct_points(user_id: int, amount: int, related_id: Optional[int] = None, details: Optional[dict] = None) -> bool:
|
||||||
|
"""
|
||||||
|
扣减用户积分(会检查并清空过期积分)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
amount: 扣减的积分数量
|
||||||
|
related_id: 关联ID(可选)
|
||||||
|
details: 附加信息(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否成功
|
||||||
|
"""
|
||||||
|
if amount <= 0:
|
||||||
|
raise ValueError("积分数量必须大于0")
|
||||||
|
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
# 获取当前积分余额(清空前)
|
||||||
|
points_account_before = await points_dao.get_by_user_id(db, user_id)
|
||||||
|
if not points_account_before:
|
||||||
|
return False
|
||||||
|
|
||||||
|
balance_before = points_account_before.balance
|
||||||
|
|
||||||
|
# 检查并清空过期积分
|
||||||
|
expired_cleared = await points_dao.check_and_clear_expired_points(db, user_id)
|
||||||
|
|
||||||
|
# 如果清空了过期积分,记录日志
|
||||||
|
if expired_cleared and balance_before > 0:
|
||||||
|
await points_log_dao.add_log(db, {
|
||||||
|
"user_id": user_id,
|
||||||
|
"action": "expire_clear",
|
||||||
|
"amount": balance_before, # 记录清空前的积分数量
|
||||||
|
"balance_after": 0,
|
||||||
|
"details": {"message": "过期积分已清空", "cleared_amount": balance_before}
|
||||||
|
})
|
||||||
|
|
||||||
|
# 重新获取账户信息(可能已被清空)
|
||||||
|
points_account = await points_dao.get_by_user_id(db, user_id)
|
||||||
|
if not points_account or points_account.balance < amount:
|
||||||
|
return False
|
||||||
|
|
||||||
|
current_balance = points_account.balance
|
||||||
|
|
||||||
|
# 原子性扣减积分
|
||||||
|
result = await points_dao.deduct_points_atomic(db, user_id, amount)
|
||||||
|
if not result:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 记录积分变动日志
|
||||||
|
new_balance = current_balance - amount
|
||||||
|
await points_log_dao.add_log(db, {
|
||||||
|
"user_id": user_id,
|
||||||
|
"action": "spend",
|
||||||
|
"amount": amount,
|
||||||
|
"balance_after": new_balance,
|
||||||
|
"related_id": related_id,
|
||||||
|
"details": details
|
||||||
|
})
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def initialize_user_points(user_id: int) -> Points:
|
||||||
|
"""
|
||||||
|
为新用户初始化积分账户
|
||||||
|
"""
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
points_account = await points_dao.get_by_user_id(db, user_id)
|
||||||
|
if not points_account:
|
||||||
|
points_account = await points_dao.create_user_points(db, user_id)
|
||||||
|
return points_account
|
||||||
233
backend/app/admin/service/refund_service.py
Executable file
233
backend/app/admin/service/refund_service.py
Executable file
@@ -0,0 +1,233 @@
|
|||||||
|
from backend.app.admin.crud.freeze_log_crud import freeze_log_dao
|
||||||
|
from backend.app.admin.crud.user_account_crud import user_account_dao
|
||||||
|
from backend.app.admin.crud.order_crud import order_dao
|
||||||
|
from backend.app.admin.crud.usage_log_crud import usage_log_dao
|
||||||
|
from backend.app.admin.model.order import FreezeLog
|
||||||
|
|
||||||
|
from backend.app.admin.model import UserAccount
|
||||||
|
from backend.database.db import async_db_session
|
||||||
|
from wechatpy.pay import WeChatPay
|
||||||
|
from backend.core.conf import settings
|
||||||
|
from backend.common.log import log as logger
|
||||||
|
from backend.common.exception import errors
|
||||||
|
from sqlalchemy import select, update
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from backend.utils.wx_pay import wx_pay_utils
|
||||||
|
|
||||||
|
wxpay = WeChatPay(
|
||||||
|
appid=settings.WX_APPID,
|
||||||
|
api_key=settings.WX_SECRET,
|
||||||
|
mch_id=settings.WX_MCH_ID,
|
||||||
|
mch_cert=settings.WX_PAY_CERT_PATH,
|
||||||
|
mch_key=settings.WX_PAY_KEY_PATH
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RefundService:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def freeze_times_for_refund(user_id: int, order_id: int, reason: str = None):
|
||||||
|
"""
|
||||||
|
为退款冻结次数(增强幂等性和安全性)
|
||||||
|
"""
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
# 检查是否已有pending的退款申请(幂等性)
|
||||||
|
existing_stmt = select(FreezeLog).where(
|
||||||
|
FreezeLog.order_id == order_id,
|
||||||
|
FreezeLog.status == "pending"
|
||||||
|
)
|
||||||
|
existing_result = await db.execute(existing_stmt)
|
||||||
|
if existing_result.scalar_one_or_none():
|
||||||
|
raise errors.RequestError(msg="该订单已有待处理的退款申请")
|
||||||
|
|
||||||
|
order = await order_dao.get_by_id(db, order_id)
|
||||||
|
if not order:
|
||||||
|
raise errors.NotFoundError(msg="订单不存在")
|
||||||
|
|
||||||
|
if order.status != 'completed':
|
||||||
|
raise errors.RequestError(msg="订单状态异常,无法退款")
|
||||||
|
|
||||||
|
if order.user_id != user_id:
|
||||||
|
raise errors.ForbiddenError(msg="无权操作该订单")
|
||||||
|
|
||||||
|
# 检查退款时效(7天内)
|
||||||
|
if datetime.now() - order.created_at > timedelta(days=7):
|
||||||
|
raise errors.RequestError(msg="超过退款时效(7天)")
|
||||||
|
|
||||||
|
# 原子性冻结次数
|
||||||
|
result = await db.execute(
|
||||||
|
update(UserAccount)
|
||||||
|
.where(UserAccount.user_id == user_id)
|
||||||
|
.where(UserAccount.balance >= order.amount_times)
|
||||||
|
.values(balance=UserAccount.balance - order.amount_times)
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.rowcount == 0:
|
||||||
|
raise errors.ForbiddenError(msg="余额不足或并发冲突")
|
||||||
|
|
||||||
|
# 创建冻结记录
|
||||||
|
freeze_log = FreezeLog(
|
||||||
|
user_id=user_id,
|
||||||
|
order_id=order_id,
|
||||||
|
amount=order.amount_times,
|
||||||
|
reason=reason or "用户申请退款",
|
||||||
|
status="pending"
|
||||||
|
)
|
||||||
|
await freeze_log_dao.add(db, freeze_log)
|
||||||
|
|
||||||
|
# 记录使用日志
|
||||||
|
account = await user_account_dao.get_by_user_id(db, user_id)
|
||||||
|
await usage_log_dao.add(db, {
|
||||||
|
"user_id": user_id,
|
||||||
|
"action": "freeze",
|
||||||
|
"amount": -order.amount_times,
|
||||||
|
"balance_after": account.balance,
|
||||||
|
"related_id": freeze_log.id,
|
||||||
|
"details": {"order_id": order_id, "reason": reason}
|
||||||
|
})
|
||||||
|
|
||||||
|
return freeze_log
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def process_refund(user_id: int, order_id: int, refund_desc: str = "用户申请退款"):
|
||||||
|
"""
|
||||||
|
处理微信退款(增强安全性和幂等性)
|
||||||
|
"""
|
||||||
|
async with async_db_session() as db:
|
||||||
|
# 先冻结次数
|
||||||
|
freeze_log = await RefundService.freeze_times_for_refund(user_id, order_id, refund_desc)
|
||||||
|
|
||||||
|
try:
|
||||||
|
order = await order_dao.get_by_id(db, order_id)
|
||||||
|
|
||||||
|
# 调用微信退款接口
|
||||||
|
refund_result = wxpay.refund.apply(
|
||||||
|
out_trade_no=str(order_id),
|
||||||
|
out_refund_no=str(freeze_log.id), # 使用冻结记录ID作为退款单号
|
||||||
|
total_fee=order.amount_cents,
|
||||||
|
refund_fee=order.amount_cents,
|
||||||
|
refund_desc=refund_desc
|
||||||
|
)
|
||||||
|
|
||||||
|
if refund_result['return_code'] == 'SUCCESS' and refund_result['result_code'] == 'SUCCESS':
|
||||||
|
# 更新冻结记录状态
|
||||||
|
async with async_db_session.begin() as update_db:
|
||||||
|
freeze_log.status = "confirmed"
|
||||||
|
await freeze_log_dao.update(update_db, freeze_log.id, freeze_log)
|
||||||
|
return {"status": "success", "refund_id": refund_result.get('refund_id')}
|
||||||
|
else:
|
||||||
|
# 退款失败,取消冻结
|
||||||
|
await RefundService.cancel_freeze(freeze_log.id)
|
||||||
|
raise errors.ServerError(
|
||||||
|
msg=f"微信退款失败: {refund_result.get('err_code_des', '未知错误')}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# 发生异常时也要取消冻结
|
||||||
|
await RefundService.cancel_freeze(freeze_log.id)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def cancel_freeze(freeze_id: int):
|
||||||
|
"""
|
||||||
|
取消冻结(退款失败时调用)
|
||||||
|
"""
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
freeze_log = await freeze_log_dao.get_by_id(db, freeze_id)
|
||||||
|
if not freeze_log or freeze_log.status != "pending":
|
||||||
|
return
|
||||||
|
|
||||||
|
freeze_log.status = "cancelled"
|
||||||
|
await freeze_log_dao.update(db, freeze_log.id, freeze_log)
|
||||||
|
|
||||||
|
# 原子性恢复用户余额
|
||||||
|
result = await db.execute(
|
||||||
|
update(UserAccount)
|
||||||
|
.where(UserAccount.user_id == freeze_log.user_id)
|
||||||
|
.values(balance=UserAccount.balance + freeze_log.amount)
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.rowcount == 0:
|
||||||
|
logger.error(f"恢复冻结次数失败: freeze_id={freeze_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 记录使用日志
|
||||||
|
account = await user_account_dao.get_by_user_id(db, freeze_log.user_id)
|
||||||
|
await usage_log_dao.add(db, {
|
||||||
|
"user_id": freeze_log.user_id,
|
||||||
|
"action": "unfreeze",
|
||||||
|
"amount": freeze_log.amount,
|
||||||
|
"balance_after": account.balance,
|
||||||
|
"related_id": freeze_log.id,
|
||||||
|
"details": {"reason": "退款失败,恢复次数"}
|
||||||
|
})
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def handle_refund_notify(request: Request):
|
||||||
|
"""
|
||||||
|
处理微信退款回调(增强安全性和幂等性)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
body = await request.body()
|
||||||
|
result = wx_pay_utils.parse_payment_result(body)
|
||||||
|
|
||||||
|
# 验证签名
|
||||||
|
if not wx_pay_utils.verify_wxpay_signature(result.copy(), settings.WECHAT_PAY_API_KEY):
|
||||||
|
logger.warning("微信退款回调签名验证失败")
|
||||||
|
return {"return_code": "FAIL", "return_msg": "签名验证失败"}
|
||||||
|
|
||||||
|
if result['return_code'] == 'SUCCESS':
|
||||||
|
out_refund_no = result['out_refund_no'] # 对应freeze_log.id
|
||||||
|
refund_status = result.get('refund_status_0', 'UNKNOWN') # 假设只有一个退款
|
||||||
|
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
# 使用SELECT FOR UPDATE确保并发安全
|
||||||
|
stmt = select(FreezeLog).where(FreezeLog.id == int(out_refund_no)).with_for_update()
|
||||||
|
freeze_result = await db.execute(stmt)
|
||||||
|
freeze_log = freeze_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not freeze_log:
|
||||||
|
logger.warning(f"冻结记录不存在: {out_refund_no}")
|
||||||
|
return {"return_code": "SUCCESS"}
|
||||||
|
|
||||||
|
# 幂等性检查
|
||||||
|
if freeze_log.status in ["confirmed", "cancelled"]:
|
||||||
|
logger.info(f"冻结记录已处理过: {out_refund_no}")
|
||||||
|
return {"return_code": "SUCCESS"}
|
||||||
|
|
||||||
|
if refund_status == 'SUCCESS':
|
||||||
|
freeze_log.status = "confirmed"
|
||||||
|
await freeze_log_dao.update(db, freeze_log.id, freeze_log)
|
||||||
|
|
||||||
|
# 从总购买次数中扣除
|
||||||
|
account_stmt = select(UserAccount).where(
|
||||||
|
UserAccount.user_id == freeze_log.user_id).with_for_update()
|
||||||
|
account_result = await db.execute(account_stmt)
|
||||||
|
user_account = account_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if user_account:
|
||||||
|
user_account.total_purchased = max(0, user_account.total_purchased - freeze_log.amount)
|
||||||
|
await user_account_dao.update(db, user_account.id, user_account)
|
||||||
|
|
||||||
|
# 记录使用日志
|
||||||
|
await usage_log_dao.add(db, {
|
||||||
|
"user_id": freeze_log.user_id,
|
||||||
|
"action": "refund",
|
||||||
|
"amount": -freeze_log.amount,
|
||||||
|
"balance_after": user_account.balance if user_account else 0,
|
||||||
|
"related_id": freeze_log.id,
|
||||||
|
"details": {"refund_id": result.get('refund_id_0')}
|
||||||
|
})
|
||||||
|
elif refund_status in ['FAIL', 'CHANGE']:
|
||||||
|
# 退款失败,取消冻结
|
||||||
|
await RefundService.cancel_freeze(freeze_log.id)
|
||||||
|
|
||||||
|
return {"return_code": "SUCCESS"}
|
||||||
|
else:
|
||||||
|
logger.error(f"微信退款回调失败: {result}")
|
||||||
|
return {"return_code": "FAIL", "return_msg": "处理失败"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"微信退款回调处理异常: {str(e)}")
|
||||||
|
return {"return_code": "FAIL", "return_msg": "服务器异常"}
|
||||||
44
backend/app/admin/service/storage.py
Executable file
44
backend/app/admin/service/storage.py
Executable file
@@ -0,0 +1,44 @@
|
|||||||
|
import os
|
||||||
|
import aiofiles
|
||||||
|
from pathlib import Path
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from fastapi import UploadFile
|
||||||
|
|
||||||
|
from backend.core.conf import settings
|
||||||
|
from backend.core.path_conf import UPLOAD_DIR
|
||||||
|
|
||||||
|
|
||||||
|
class StorageProvider(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
async def save(self, file_id: int, content: bytes):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def read(self, file_id: int) -> bytes:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class LocalStorage(StorageProvider):
|
||||||
|
def __init__(self, base_path: str = UPLOAD_DIR):
|
||||||
|
self.base_path = base_path
|
||||||
|
|
||||||
|
def _get_path(self, file_id: int) -> str:
|
||||||
|
"""使用 os.path.join 构建文件路径"""
|
||||||
|
return os.path.join(self.base_path, str(file_id))
|
||||||
|
|
||||||
|
async def save(self, file_id: int, content: bytes):
|
||||||
|
path = self._get_path(file_id)
|
||||||
|
# 确保目录存在
|
||||||
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||||
|
|
||||||
|
async with aiofiles.open(path, 'wb') as f:
|
||||||
|
await f.write(content)
|
||||||
|
|
||||||
|
async def read(self, file_id: int) -> bytes:
|
||||||
|
path = self._get_path(file_id)
|
||||||
|
async with aiofiles.open(path, 'rb') as f:
|
||||||
|
return await f.read()
|
||||||
|
|
||||||
|
|
||||||
|
# 未来可添加S3存储
|
||||||
|
# class S3Storage(StorageProvider): ...
|
||||||
42
backend/app/admin/service/subscription_service.py
Executable file
42
backend/app/admin/service/subscription_service.py
Executable file
@@ -0,0 +1,42 @@
|
|||||||
|
from datetime import timedelta, datetime
|
||||||
|
|
||||||
|
from backend.app.admin.crud.usage_log_crud import usage_log_dao
|
||||||
|
from backend.app.admin.crud.user_account_crud import user_account_dao
|
||||||
|
from backend.app.admin.service.usage_service import UsageService
|
||||||
|
from backend.common.exception import errors
|
||||||
|
|
||||||
|
from backend.database.db import async_db_session
|
||||||
|
|
||||||
|
SUBSCRIPTION_PLANS = {
|
||||||
|
"monthly": {"price": 1290, "times": 300, "duration": timedelta(days=30)},
|
||||||
|
"quarterly": {"price": 2990, "times": 900, "duration": timedelta(days=90)},
|
||||||
|
"half_yearly": {"price": 3990, "times": 1800, "duration": timedelta(days=180)},
|
||||||
|
"yearly": {"price": 6990, "times": 3600, "duration": timedelta(days=365)},
|
||||||
|
}
|
||||||
|
|
||||||
|
class SubscriptionService:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def subscribe(user_id: int, plan_key: str):
|
||||||
|
plan = SUBSCRIPTION_PLANS.get(plan_key)
|
||||||
|
if not plan:
|
||||||
|
raise errors.RequestError(msg="无效订阅计划")
|
||||||
|
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
account = await UsageService.get_user_account(user_id)
|
||||||
|
# 处理未用完次数
|
||||||
|
account.balance += account.carryover_balance
|
||||||
|
account.carryover_balance = 0
|
||||||
|
# 更新订阅信息
|
||||||
|
account.subscription_type = plan_key
|
||||||
|
account.subscription_expires_at = datetime.now() + plan["duration"]
|
||||||
|
account.balance += plan["times"]
|
||||||
|
await user_account_dao.update(db, account.id, account)
|
||||||
|
|
||||||
|
await usage_log_dao.add(db, {
|
||||||
|
"user_id": user_id,
|
||||||
|
"action": "renewal",
|
||||||
|
"amount": plan["times"],
|
||||||
|
"balance_after": account.balance,
|
||||||
|
"details": {"plan": plan_key}
|
||||||
|
})
|
||||||
214
backend/app/admin/service/usage_service.py
Executable file
214
backend/app/admin/service/usage_service.py
Executable file
@@ -0,0 +1,214 @@
|
|||||||
|
from decimal import Decimal, ROUND_HALF_UP
|
||||||
|
from backend.app.admin.crud.user_account_crud import user_account_dao
|
||||||
|
from backend.app.admin.crud.usage_log_crud import usage_log_dao
|
||||||
|
from backend.app.admin.crud.order_crud import order_dao
|
||||||
|
from backend.app.admin.model.order import UserAccount
|
||||||
|
from backend.app.admin.model.order import Order
|
||||||
|
from backend.app.admin.schema.usage import PurchaseRequest
|
||||||
|
from backend.common.exception import errors
|
||||||
|
from backend.database.db import async_db_session
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from sqlalchemy import func, select, update
|
||||||
|
|
||||||
|
|
||||||
|
class UsageService:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_user_account(user_id: int) -> UserAccount:
|
||||||
|
async with async_db_session() as db:
|
||||||
|
account = await user_account_dao.get_by_user_id(db, user_id)
|
||||||
|
if not account:
|
||||||
|
account = UserAccount(user_id=user_id)
|
||||||
|
await user_account_dao.add(db, account)
|
||||||
|
return account
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def calculate_purchase_times_safe(amount_cents: int) -> int:
|
||||||
|
"""
|
||||||
|
安全的充值次数计算(使用Decimal避免浮点数精度问题)
|
||||||
|
"""
|
||||||
|
if amount_cents <= 0:
|
||||||
|
raise ValueError("充值金额必须大于0")
|
||||||
|
|
||||||
|
# 限制最大充值金额(防止溢出)
|
||||||
|
if amount_cents > 10000000: # 10万元
|
||||||
|
raise ValueError("单次充值金额不能超过10万元")
|
||||||
|
|
||||||
|
amount_yuan = Decimal(amount_cents) / Decimal(100)
|
||||||
|
base_times = amount_yuan * Decimal(10)
|
||||||
|
|
||||||
|
# 计算优惠比例(每10元增加10%,最多100%)
|
||||||
|
tens = (amount_yuan // Decimal(10))
|
||||||
|
bonus_percent = min(tens * Decimal('0.1'), Decimal('1.0'))
|
||||||
|
|
||||||
|
total_times = base_times * (Decimal('1') + bonus_percent)
|
||||||
|
return int(total_times.quantize(Decimal('1'), rounding=ROUND_HALF_UP))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def purchase_times(user_id: int, request: PurchaseRequest):
|
||||||
|
# 输入验证
|
||||||
|
if request.amount_cents < 100: # 最少1元
|
||||||
|
raise errors.RequestError(msg="充值金额不能少于1元")
|
||||||
|
if request.amount_cents > 10000000: # 最多10万元
|
||||||
|
raise errors.RequestError(msg="单次充值金额不能超过10万元")
|
||||||
|
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
account = await UsageService.get_user_account(user_id)
|
||||||
|
times = UsageService.calculate_purchase_times_safe(request.amount_cents)
|
||||||
|
|
||||||
|
order = Order(
|
||||||
|
user_id=user_id,
|
||||||
|
order_type="purchase",
|
||||||
|
amount_cents=request.amount_cents,
|
||||||
|
amount_times=times,
|
||||||
|
status="pending"
|
||||||
|
)
|
||||||
|
await order_dao.add(db, order)
|
||||||
|
await db.flush() # 获取order.id
|
||||||
|
|
||||||
|
# 原子性更新账户(防止并发问题)
|
||||||
|
result = await db.execute(
|
||||||
|
update(UserAccount)
|
||||||
|
.where(UserAccount.id == account.id)
|
||||||
|
.values(
|
||||||
|
balance=UserAccount.balance + times,
|
||||||
|
total_purchased=UserAccount.total_purchased + times
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.rowcount == 0:
|
||||||
|
raise errors.ServerError(msg="账户更新失败")
|
||||||
|
|
||||||
|
# 更新订单状态
|
||||||
|
order.status = "completed"
|
||||||
|
order.processed_at = datetime.now()
|
||||||
|
await order_dao.update(db, order.id, order)
|
||||||
|
|
||||||
|
await usage_log_dao.add(db, {
|
||||||
|
"user_id": user_id,
|
||||||
|
"action": "purchase",
|
||||||
|
"amount": times,
|
||||||
|
"balance_after": account.balance + times,
|
||||||
|
"related_id": order.id,
|
||||||
|
"details": {"amount_cents": request.amount_cents}
|
||||||
|
})
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def use_times_atomic(user_id: int, count: int = 1):
|
||||||
|
"""
|
||||||
|
原子性扣减次数,支持免费试用优先使用
|
||||||
|
"""
|
||||||
|
if count <= 0:
|
||||||
|
raise ValueError("扣减次数必须大于0")
|
||||||
|
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
account = await UsageService.get_user_account(user_id)
|
||||||
|
|
||||||
|
# 检查免费试用是否有效
|
||||||
|
is_free_trial_valid = (
|
||||||
|
account.free_trial_expires_at and
|
||||||
|
account.free_trial_expires_at > datetime.now() and
|
||||||
|
account.free_trial_balance > 0
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_free_trial_valid:
|
||||||
|
# 优先使用免费试用次数
|
||||||
|
free_trial_deduct = min(count, account.free_trial_balance)
|
||||||
|
remaining_count = count - free_trial_deduct
|
||||||
|
|
||||||
|
# 更新免费试用余额
|
||||||
|
if free_trial_deduct > 0:
|
||||||
|
await db.execute(
|
||||||
|
update(UserAccount)
|
||||||
|
.where(UserAccount.id == account.id)
|
||||||
|
.values(
|
||||||
|
free_trial_balance=UserAccount.free_trial_balance - free_trial_deduct,
|
||||||
|
balance=UserAccount.balance - free_trial_deduct
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 如果还有剩余次数,从普通余额中扣除
|
||||||
|
if remaining_count > 0:
|
||||||
|
result = await db.execute(
|
||||||
|
update(UserAccount)
|
||||||
|
.where(UserAccount.id == account.id)
|
||||||
|
.where(UserAccount.balance >= remaining_count)
|
||||||
|
.values(balance=UserAccount.balance - remaining_count)
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.rowcount == 0:
|
||||||
|
# 恢复免费试用余额
|
||||||
|
await db.execute(
|
||||||
|
update(UserAccount)
|
||||||
|
.where(UserAccount.id == account.id)
|
||||||
|
.values(
|
||||||
|
free_trial_balance=UserAccount.free_trial_balance + free_trial_deduct,
|
||||||
|
balance=UserAccount.balance + free_trial_deduct
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise errors.ForbiddenError(msg="余额不足")
|
||||||
|
else:
|
||||||
|
# 直接从普通余额中扣除
|
||||||
|
result = await db.execute(
|
||||||
|
update(UserAccount)
|
||||||
|
.where(UserAccount.id == account.id)
|
||||||
|
.where(UserAccount.balance >= count)
|
||||||
|
.values(balance=UserAccount.balance - count)
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.rowcount == 0:
|
||||||
|
raise errors.ForbiddenError(msg="余额不足")
|
||||||
|
|
||||||
|
# 记录使用日志
|
||||||
|
updated_account = await user_account_dao.get_by_id(db, account.id)
|
||||||
|
await usage_log_dao.add(db, {
|
||||||
|
"user_id": user_id,
|
||||||
|
"action": "use",
|
||||||
|
"amount": -count,
|
||||||
|
"balance_after": updated_account.balance if updated_account else 0,
|
||||||
|
"metadata_": {
|
||||||
|
"used_at": datetime.now().isoformat(),
|
||||||
|
"is_free_trial": is_free_trial_valid
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_account_info(user_id: int) -> dict:
|
||||||
|
"""
|
||||||
|
获取用户账户详细信息
|
||||||
|
"""
|
||||||
|
async with async_db_session() as db:
|
||||||
|
account = await UsageService.get_user_account(user_id)
|
||||||
|
if not account:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
frozen_balance = await user_account_dao.get_frozen_balance(db, user_id)
|
||||||
|
available_balance = max(0, account.balance - frozen_balance)
|
||||||
|
|
||||||
|
# 检查免费试用状态
|
||||||
|
is_free_trial_active = (
|
||||||
|
account.free_trial_expires_at and
|
||||||
|
account.free_trial_expires_at > datetime.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"balance": account.balance,
|
||||||
|
"available_balance": available_balance,
|
||||||
|
"frozen_balance": frozen_balance,
|
||||||
|
"total_purchased": account.total_purchased,
|
||||||
|
"subscription_type": account.subscription_type,
|
||||||
|
"subscription_expires_at": account.subscription_expires_at,
|
||||||
|
"carryover_balance": account.carryover_balance,
|
||||||
|
|
||||||
|
# 免费试用信息
|
||||||
|
"free_trial_balance": account.free_trial_balance,
|
||||||
|
"free_trial_expires_at": account.free_trial_expires_at,
|
||||||
|
"free_trial_active": is_free_trial_active,
|
||||||
|
"free_trial_used": account.free_trial_used,
|
||||||
|
|
||||||
|
# 计算剩余天数
|
||||||
|
"free_trial_days_left": (
|
||||||
|
max(0, (account.free_trial_expires_at - datetime.now()).days)
|
||||||
|
if account.free_trial_expires_at else 0
|
||||||
|
)
|
||||||
|
}
|
||||||
128
backend/app/admin/service/wx_service.py
Executable file
128
backend/app/admin/service/wx_service.py
Executable file
@@ -0,0 +1,128 @@
|
|||||||
|
# services/wx.py
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import Request, Response
|
||||||
|
from backend.app.admin.crud.wx_user_crud import wx_user_dao
|
||||||
|
from backend.app.admin.model.wx_user import WxUser
|
||||||
|
from backend.app.admin.schema.token import GetWxLoginToken
|
||||||
|
from backend.app.admin.schema.wx import DictLevel
|
||||||
|
from backend.common.security.jwt import create_access_token, create_refresh_token
|
||||||
|
from backend.core.conf import settings
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from backend.app.admin.service.wx_user_service import WxUserService
|
||||||
|
from backend.core.wx_integration import decrypt_wx_data
|
||||||
|
from backend.database.db import async_db_session
|
||||||
|
from backend.utils.timezone import timezone
|
||||||
|
|
||||||
|
|
||||||
|
class WxAuthService:
|
||||||
|
@staticmethod
|
||||||
|
async def login(
|
||||||
|
*,
|
||||||
|
request: Request, response: Response,
|
||||||
|
openid: str, session_key: str,
|
||||||
|
encrypted_data: str = None,
|
||||||
|
iv: str = None
|
||||||
|
) -> GetWxLoginToken:
|
||||||
|
"""
|
||||||
|
处理用户登录逻辑:
|
||||||
|
1. 查找或创建用户
|
||||||
|
2. 更新用户session_key
|
||||||
|
3. 解密用户信息(如果提供)
|
||||||
|
4. 生成访问令牌
|
||||||
|
"""
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
user = None
|
||||||
|
try:
|
||||||
|
# 查找或创建用户
|
||||||
|
user = await wx_user_dao.get_by_openid(db, openid)
|
||||||
|
if not user:
|
||||||
|
user = WxUser(
|
||||||
|
openid=openid,
|
||||||
|
session_key=session_key,
|
||||||
|
profile={'dict_level': DictLevel.LEVEL1.value},
|
||||||
|
)
|
||||||
|
await wx_user_dao.add(db, user)
|
||||||
|
await db.flush()
|
||||||
|
await db.refresh(user)
|
||||||
|
else:
|
||||||
|
await wx_user_dao.update_session_key(db, user.id, session_key)
|
||||||
|
|
||||||
|
|
||||||
|
# 解密用户信息(如果提供)
|
||||||
|
if encrypted_data and iv:
|
||||||
|
try:
|
||||||
|
decrypted_data = decrypt_wx_data(
|
||||||
|
encrypted_data,
|
||||||
|
session_key,
|
||||||
|
iv
|
||||||
|
)
|
||||||
|
WxUserService.update_user_profile(db, user.id, decrypted_data)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"用户数据解密失败: {str(e)}")
|
||||||
|
|
||||||
|
# 生成访问令牌
|
||||||
|
access_token = await create_access_token(
|
||||||
|
user.id,
|
||||||
|
False,
|
||||||
|
# extra info
|
||||||
|
ip=request.client.host,
|
||||||
|
# os=request.state.os,
|
||||||
|
# browser=request.state.browser,
|
||||||
|
# device=request.state.device,
|
||||||
|
)
|
||||||
|
refresh_token = await create_refresh_token(access_token.session_uuid, user.id, False)
|
||||||
|
response.set_cookie(
|
||||||
|
key=settings.COOKIE_REFRESH_TOKEN_KEY,
|
||||||
|
value=refresh_token.refresh_token,
|
||||||
|
max_age=settings.COOKIE_REFRESH_TOKEN_EXPIRE_SECONDS,
|
||||||
|
expires=timezone.to_utc(refresh_token.refresh_token_expire_time),
|
||||||
|
httponly=True,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
db.rollback()
|
||||||
|
logging.error(f"登录处理失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
# 从用户资料中获取词典等级设置
|
||||||
|
dict_level = None
|
||||||
|
if user and user.profile and isinstance(user.profile, dict):
|
||||||
|
dict_level = user.profile.get("dict_level")
|
||||||
|
|
||||||
|
data = GetWxLoginToken(
|
||||||
|
access_token=access_token.access_token,
|
||||||
|
access_token_expire_time=access_token.access_token_expire_time,
|
||||||
|
session_uuid=access_token.session_uuid,
|
||||||
|
dict_level=dict_level
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def update_user_settings(
|
||||||
|
*,
|
||||||
|
user_id: int,
|
||||||
|
dict_level: Optional[DictLevel] = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
更新用户设置
|
||||||
|
"""
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
user = await wx_user_dao.get(db, user_id)
|
||||||
|
if not user:
|
||||||
|
raise ValueError("用户不存在")
|
||||||
|
|
||||||
|
# 如果用户没有profile,初始化为空字典
|
||||||
|
if not user.profile:
|
||||||
|
user.profile = {}
|
||||||
|
|
||||||
|
# 更新词典等级设置
|
||||||
|
if dict_level is not None:
|
||||||
|
user.profile["dict_level"] = dict_level.value
|
||||||
|
|
||||||
|
# 使用新的方法更新用户资料,会自动更新updated_time字段
|
||||||
|
await wx_user_dao.update_user_profile(db, user_id, user.profile)
|
||||||
|
|
||||||
|
|
||||||
|
wx_service: WxAuthService = WxAuthService()
|
||||||
143
backend/app/admin/service/wx_user_service.py
Executable file
143
backend/app/admin/service/wx_user_service.py
Executable file
@@ -0,0 +1,143 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from sqlalchemy import Select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from backend.common.exception import errors
|
||||||
|
from backend.common.security.jwt import superuser_verify, password_verify, get_hash_password
|
||||||
|
from backend.app.admin.crud.wx_user_crud import wx_user_dao
|
||||||
|
from backend.database.db import async_db_session
|
||||||
|
from backend.app.admin.model import WxUser
|
||||||
|
from backend.app.admin.schema.user import RegisterUserParam, ResetPassword, UpdateUserParam, AvatarParam
|
||||||
|
|
||||||
|
|
||||||
|
class WxUserService:
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def update_user_profile(db: Session, user_id: int, profile_data: dict):
|
||||||
|
"""
|
||||||
|
更新用户资料
|
||||||
|
"""
|
||||||
|
user = db.query(WxUser).get(user_id)
|
||||||
|
if not user:
|
||||||
|
raise ValueError("用户不存在")
|
||||||
|
|
||||||
|
# 保存用户资料,保留现有设置
|
||||||
|
current_profile = user.profile or {}
|
||||||
|
|
||||||
|
# 更新微信用户信息
|
||||||
|
if "nickName" in profile_data:
|
||||||
|
current_profile["nickname"] = profile_data.get("nickName")
|
||||||
|
if "avatarUrl" in profile_data:
|
||||||
|
current_profile["avatar"] = profile_data.get("avatarUrl")
|
||||||
|
if "gender" in profile_data:
|
||||||
|
current_profile["gender"] = profile_data.get("gender")
|
||||||
|
if "city" in profile_data:
|
||||||
|
current_profile["city"] = profile_data.get("city")
|
||||||
|
if "province" in profile_data:
|
||||||
|
current_profile["province"] = profile_data.get("province")
|
||||||
|
if "country" in profile_data:
|
||||||
|
current_profile["country"] = profile_data.get("country")
|
||||||
|
|
||||||
|
user.profile = current_profile
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
return user
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_user_info(db: Session, user_id: int):
|
||||||
|
"""
|
||||||
|
获取用户信息(解密敏感数据)
|
||||||
|
"""
|
||||||
|
user = db.query(WxUser).get(user_id)
|
||||||
|
if not user:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 解密手机号
|
||||||
|
# encryptor = DataEncrypt()
|
||||||
|
# mobile = encryptor.decrypt(user.mobile) if user.mobile else None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": user.id,
|
||||||
|
"openid": user.openid,
|
||||||
|
"mobile": user.mobile,
|
||||||
|
"profile": user.profile,
|
||||||
|
"created_at": user.created_at
|
||||||
|
}
|
||||||
|
|
||||||
|
# @staticmethod
|
||||||
|
# async def register(*, obj: RegisterUserParam) -> None:
|
||||||
|
# async with async_db_session.begin() as db:
|
||||||
|
# if not obj.password:
|
||||||
|
# raise errors.ForbiddenError(msg='密码为空')
|
||||||
|
# username = await user_dao.get_by_username(db, obj.username)
|
||||||
|
# if username:
|
||||||
|
# raise errors.ForbiddenError(msg='用户已注册')
|
||||||
|
# email = await user_dao.check_email(db, obj.email)
|
||||||
|
# if email:
|
||||||
|
# raise errors.ForbiddenError(msg='邮箱已注册')
|
||||||
|
# await user_dao.create(db, obj)
|
||||||
|
#
|
||||||
|
# @staticmethod
|
||||||
|
# async def pwd_reset(*, obj: ResetPassword) -> int:
|
||||||
|
# async with async_db_session.begin() as db:
|
||||||
|
# user = await user_dao.get_by_username(db, obj.username)
|
||||||
|
# if not password_verify(obj.old_password, user.password):
|
||||||
|
# raise errors.ForbiddenError(msg='原密码错误')
|
||||||
|
# np1 = obj.new_password
|
||||||
|
# np2 = obj.confirm_password
|
||||||
|
# if np1 != np2:
|
||||||
|
# raise errors.ForbiddenError(msg='密码输入不一致')
|
||||||
|
# new_pwd = get_hash_password(obj.new_password, user.salt)
|
||||||
|
# count = await user_dao.reset_password(db, user.id, new_pwd)
|
||||||
|
# return count
|
||||||
|
#
|
||||||
|
# @staticmethod
|
||||||
|
# async def get_userinfo(*, username: str) -> User:
|
||||||
|
# async with async_db_session() as db:
|
||||||
|
# user = await user_dao.get_by_username(db, username)
|
||||||
|
# if not user:
|
||||||
|
# raise errors.NotFoundError(msg='用户不存在')
|
||||||
|
# return user
|
||||||
|
#
|
||||||
|
# @staticmethod
|
||||||
|
# async def update(*, username: str, obj: UpdateUserParam) -> int:
|
||||||
|
# async with async_db_session.begin() as db:
|
||||||
|
# input_user = await user_dao.get_by_username(db, username=username)
|
||||||
|
# if not input_user:
|
||||||
|
# raise errors.NotFoundError(msg='用户不存在')
|
||||||
|
# superuser_verify(input_user)
|
||||||
|
# if input_user.username != obj.username:
|
||||||
|
# _username = await user_dao.get_by_username(db, obj.username)
|
||||||
|
# if _username:
|
||||||
|
# raise errors.ForbiddenError(msg='用户名已注册')
|
||||||
|
# if input_user.email != obj.email:
|
||||||
|
# email = await user_dao.check_email(db, obj.email)
|
||||||
|
# if email:
|
||||||
|
# raise errors.ForbiddenError(msg='邮箱已注册')
|
||||||
|
# count = await user_dao.update_userinfo(db, input_user.id, obj)
|
||||||
|
# return count
|
||||||
|
#
|
||||||
|
# @staticmethod
|
||||||
|
# async def update_avatar(*, username: str, avatar: AvatarParam) -> int:
|
||||||
|
# async with async_db_session.begin() as db:
|
||||||
|
# input_user = await user_dao.get_by_username(db, username)
|
||||||
|
# if not input_user:
|
||||||
|
# raise errors.NotFoundError(msg='用户不存在')
|
||||||
|
# count = await user_dao.update_avatar(db, input_user.id, avatar)
|
||||||
|
# return count
|
||||||
|
#
|
||||||
|
# @staticmethod
|
||||||
|
# async def get_select(*, username: str = None, phone: str = None, status: int = None) -> Select:
|
||||||
|
# return await user_dao.get_list(username=username, phone=phone, status=status)
|
||||||
|
#
|
||||||
|
# @staticmethod
|
||||||
|
# async def delete(*, current_user: User, username: str) -> int:
|
||||||
|
# async with async_db_session.begin() as db:
|
||||||
|
# superuser_verify(current_user)
|
||||||
|
# input_user = await user_dao.get_by_username(db, username)
|
||||||
|
# if not input_user:
|
||||||
|
# raise errors.NotFoundError(msg='用户不存在')
|
||||||
|
# count = await user_dao.delete(db, input_user.id)
|
||||||
|
# return count
|
||||||
37
backend/app/admin/service/wxpay_service.py
Executable file
37
backend/app/admin/service/wxpay_service.py
Executable file
@@ -0,0 +1,37 @@
|
|||||||
|
from wechatpy.pay import WeChatPay
|
||||||
|
|
||||||
|
from backend.app.admin.model.order import Order
|
||||||
|
from backend.app.admin.service.usage_service import UsageService
|
||||||
|
from backend.core.conf import settings
|
||||||
|
from backend.app.admin.crud.order_crud import order_dao
|
||||||
|
from backend.database.db import async_db_session
|
||||||
|
|
||||||
|
wxpay = WeChatPay(
|
||||||
|
appid=settings.WECHAT_APP_ID,
|
||||||
|
api_key=settings.WECHAT_PAY_API_KEY,
|
||||||
|
mch_id=settings.WECHAT_MCH_ID,
|
||||||
|
mch_cert=settings.WECHAT_PAY_CERT_PATH,
|
||||||
|
mch_key=settings.WECHAT_PAY_KEY_PATH
|
||||||
|
)
|
||||||
|
|
||||||
|
class WxPayService:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def create_order(user_id: int, amount_cents: int, description: str):
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
order = Order(
|
||||||
|
user_id=user_id,
|
||||||
|
order_type="purchase",
|
||||||
|
amount_cents=amount_cents,
|
||||||
|
amount_times=UsageService.calculate_purchase_times(amount_cents),
|
||||||
|
status="pending"
|
||||||
|
)
|
||||||
|
await order_dao.add(db, order)
|
||||||
|
result = wxpay.order.create(
|
||||||
|
trade_type="JSAPI",
|
||||||
|
body=description,
|
||||||
|
total_fee=amount_cents,
|
||||||
|
notify_url=settings.WECHAT_NOTIFY_URL,
|
||||||
|
out_trade_no=str(order.id)
|
||||||
|
)
|
||||||
|
return result
|
||||||
158
backend/app/admin/tasks.py
Executable file
158
backend/app/admin/tasks.py
Executable file
@@ -0,0 +1,158 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
定时任务处理模块
|
||||||
|
"""
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from sqlalchemy import select, and_, desc
|
||||||
|
from backend.app.admin.model.audit_log import AuditLog, DailySummary
|
||||||
|
from backend.app.ai.model.image import Image
|
||||||
|
from backend.app.admin.model.wx_user import WxUser
|
||||||
|
from backend.app.admin.crud.daily_summary_crud import daily_summary_dao
|
||||||
|
from backend.app.admin.schema.audit_log import CreateDailySummaryParam
|
||||||
|
from backend.database.db import async_db_session
|
||||||
|
|
||||||
|
|
||||||
|
async def wx_user_index_history() -> None:
|
||||||
|
"""异步实现 wx_user_index_history 任务"""
|
||||||
|
# 计算前一天的时间范围
|
||||||
|
today = datetime.now().date()
|
||||||
|
yesterday = today - timedelta(days=1)
|
||||||
|
yesterday_start = datetime(yesterday.year, yesterday.month, yesterday.day)
|
||||||
|
yesterday_end = datetime(today.year, today.month, today.day)
|
||||||
|
|
||||||
|
async with async_db_session.begin() as db:
|
||||||
|
# 优化:通过 audit_log 表查询有相关记录的用户,避免遍历所有用户
|
||||||
|
# 先获取有前一天 recognition 记录的用户 ID 列表
|
||||||
|
user_ids_stmt = (
|
||||||
|
select(AuditLog.user_id)
|
||||||
|
.where(
|
||||||
|
and_(
|
||||||
|
AuditLog.api_type == 'recognition',
|
||||||
|
AuditLog.called_at >= yesterday_start,
|
||||||
|
AuditLog.called_at < yesterday_end
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.distinct()
|
||||||
|
)
|
||||||
|
|
||||||
|
user_ids_result = await db.execute(user_ids_stmt)
|
||||||
|
user_ids = [row[0] for row in user_ids_result.fetchall()]
|
||||||
|
|
||||||
|
if not user_ids:
|
||||||
|
return # 没有用户有相关记录,直接返回
|
||||||
|
|
||||||
|
# 分批处理用户,避免一次性加载过多数据
|
||||||
|
batch_size = 500
|
||||||
|
for i in range(0, len(user_ids), batch_size):
|
||||||
|
batch_user_ids = user_ids[i:i + batch_size]
|
||||||
|
|
||||||
|
# 为这批用户获取 audit_log 记录
|
||||||
|
audit_logs_stmt = (
|
||||||
|
select(AuditLog)
|
||||||
|
.where(
|
||||||
|
and_(
|
||||||
|
AuditLog.user_id.in_(batch_user_ids),
|
||||||
|
AuditLog.api_type == 'recognition',
|
||||||
|
AuditLog.called_at >= yesterday_start,
|
||||||
|
AuditLog.called_at < yesterday_end
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.order_by(AuditLog.user_id, desc(AuditLog.called_at))
|
||||||
|
)
|
||||||
|
|
||||||
|
audit_logs_result = await db.execute(audit_logs_stmt)
|
||||||
|
all_audit_logs = audit_logs_result.scalars().all()
|
||||||
|
|
||||||
|
# 按用户 ID 分组 audit_logs
|
||||||
|
user_audit_logs = {}
|
||||||
|
for log in all_audit_logs:
|
||||||
|
if log.user_id not in user_audit_logs:
|
||||||
|
user_audit_logs[log.user_id] = []
|
||||||
|
user_audit_logs[log.user_id].append(log)
|
||||||
|
|
||||||
|
# 获取这批用户的信息
|
||||||
|
users_stmt = select(WxUser).where(WxUser.id.in_(batch_user_ids))
|
||||||
|
users_result = await db.execute(users_stmt)
|
||||||
|
users = {user.id: user for user in users_result.scalars().all()}
|
||||||
|
|
||||||
|
# 获取所有相关的 image 记录
|
||||||
|
all_image_ids = [log.image_id for log in all_audit_logs if log.image_id]
|
||||||
|
image_map = {}
|
||||||
|
if all_image_ids:
|
||||||
|
images_stmt = select(Image).where(Image.id.in_(all_image_ids))
|
||||||
|
images_result = await db.execute(images_stmt)
|
||||||
|
images = images_result.scalars().all()
|
||||||
|
image_map = {img.id: img for img in images}
|
||||||
|
|
||||||
|
# 处理每个用户
|
||||||
|
for user_id, audit_logs in user_audit_logs.items():
|
||||||
|
user = users.get(user_id)
|
||||||
|
if not user:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 构建 ref_word 数据
|
||||||
|
image_ids = []
|
||||||
|
thumbnail_ids = []
|
||||||
|
|
||||||
|
for log in audit_logs:
|
||||||
|
if log.image_id and log.image_id in image_map:
|
||||||
|
image = image_map[log.image_id]
|
||||||
|
# 获取用户词典等级
|
||||||
|
dict_level = log.dict_level or "default"
|
||||||
|
|
||||||
|
# 从 image.details 中提取 ref_word
|
||||||
|
ref_words = []
|
||||||
|
try:
|
||||||
|
if image.details and isinstance(image.details, dict):
|
||||||
|
recognition_result = image.details.get("recognition_result", {})
|
||||||
|
if isinstance(recognition_result, dict):
|
||||||
|
dict_level_data = recognition_result.get(dict_level, {})
|
||||||
|
if isinstance(dict_level_data, dict):
|
||||||
|
ref_word = dict_level_data.get("ref_word")
|
||||||
|
if ref_word:
|
||||||
|
if isinstance(ref_word, list):
|
||||||
|
ref_words = ref_word
|
||||||
|
else:
|
||||||
|
ref_words = [str(ref_word)]
|
||||||
|
except Exception:
|
||||||
|
# 如果解析出错,跳过该记录
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 收集图片ID和缩略图ID用于DailySummary
|
||||||
|
image_ids.append(str(log.image_id))
|
||||||
|
if image.thumbnail_id:
|
||||||
|
thumbnail_ids.append(str(image.thumbnail_id))
|
||||||
|
else:
|
||||||
|
thumbnail_ids.append('')
|
||||||
|
|
||||||
|
# 创建或更新DailySummary记录
|
||||||
|
daily_summary_data = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"image_ids": image_ids,
|
||||||
|
"thumbnail_ids": thumbnail_ids,
|
||||||
|
"summary_time": yesterday_start
|
||||||
|
}
|
||||||
|
|
||||||
|
# 检查是否已存在该用户当天的记录
|
||||||
|
existing_summary_stmt = (
|
||||||
|
select(DailySummary)
|
||||||
|
.where(
|
||||||
|
and_(
|
||||||
|
DailySummary.user_id == user_id,
|
||||||
|
DailySummary.summary_time == yesterday_start
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
existing_summary_result = await db.execute(existing_summary_stmt)
|
||||||
|
existing_summary = existing_summary_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if existing_summary:
|
||||||
|
# 更新现有记录
|
||||||
|
await daily_summary_dao.update_model(db, existing_summary.id, daily_summary_data)
|
||||||
|
else:
|
||||||
|
# 创建新记录
|
||||||
|
await daily_summary_dao.create_model(db, CreateDailySummaryParam(**daily_summary_data))
|
||||||
|
|
||||||
|
# 提交批次更改
|
||||||
|
await db.commit()
|
||||||
4
backend/app/ai/__init__.py
Normal file
4
backend/app/ai/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from backend.app.ai.model import Image, ImageText #, Article, ArticleParagraph, ArticleSentence
|
||||||
|
from backend.app.ai.schema import *
|
||||||
|
from backend.app.ai.crud import image_dao, image_text_dao #, article_dao, article_paragraph_dao, article_sentence_dao
|
||||||
|
from backend.app.ai.service import ImageService, image_service, ImageTextService, image_text_service #, ArticleService, article_service
|
||||||
1
backend/app/ai/api/__init__.py
Normal file
1
backend/app/ai/api/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from backend.app.ai.api.image import router as image_router
|
||||||
342
backend/app/ai/api/article.py
Normal file
342
backend/app/ai/api/article.py
Normal file
@@ -0,0 +1,342 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from typing import List
|
||||||
|
from fastapi import APIRouter, Request, Query
|
||||||
|
from starlette.background import BackgroundTasks
|
||||||
|
|
||||||
|
from backend.app.ai.schema.article import ArticleSchema, ArticleWithParagraphsSchema, CreateArticleParam, UpdateArticleParam, ArticleParagraphSchema, CreateArticleParagraphParam, UpdateArticleParagraphParam, ArticleSentenceSchema, CreateArticleSentenceParam, UpdateArticleSentenceParam
|
||||||
|
from backend.app.ai.service.article_service import article_service
|
||||||
|
from backend.common.response.response_schema import response_base, ResponseSchemaModel
|
||||||
|
from backend.common.security.jwt import DependsJwtAuth
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", summary="创建文章", dependencies=[DependsJwtAuth])
|
||||||
|
async def create_article(
|
||||||
|
request: Request,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
params: CreateArticleParam
|
||||||
|
) -> ResponseSchemaModel[ArticleSchema]:
|
||||||
|
"""
|
||||||
|
创建文章记录
|
||||||
|
|
||||||
|
请求体参数:
|
||||||
|
- title: 文章标题
|
||||||
|
- content: 文章完整内容
|
||||||
|
- author: 作者(可选)
|
||||||
|
- category: 分类(可选)
|
||||||
|
- level: 难度等级(可选)
|
||||||
|
- info: 附加信息(可选)
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 创建的文章记录
|
||||||
|
"""
|
||||||
|
article_id = await article_service.create_article(obj=params)
|
||||||
|
article = await article_service.get_article_by_id(article_id)
|
||||||
|
|
||||||
|
return response_base.success(data=article)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{article_id}", summary="获取文章详情", dependencies=[DependsJwtAuth])
|
||||||
|
async def get_article(
|
||||||
|
article_id: int
|
||||||
|
) -> ResponseSchemaModel[ArticleWithParagraphsSchema]:
|
||||||
|
"""
|
||||||
|
获取文章详情,包括所有段落和句子
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- article_id: 文章ID
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 文章记录及其所有段落和句子
|
||||||
|
"""
|
||||||
|
article = await article_service.get_article_with_content(article_id)
|
||||||
|
if not article:
|
||||||
|
return response_base.fail(code=404, msg="文章不存在")
|
||||||
|
|
||||||
|
return response_base.success(data=article)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{article_id}", summary="更新文章", dependencies=[DependsJwtAuth])
|
||||||
|
async def update_article(
|
||||||
|
article_id: int,
|
||||||
|
request: Request,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
params: UpdateArticleParam
|
||||||
|
) -> ResponseSchemaModel[ArticleSchema]:
|
||||||
|
"""
|
||||||
|
更新文章记录
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- article_id: 文章ID
|
||||||
|
|
||||||
|
请求体参数:
|
||||||
|
- title: 文章标题
|
||||||
|
- content: 文章完整内容
|
||||||
|
- author: 作者(可选)
|
||||||
|
- category: 分类(可选)
|
||||||
|
- level: 难度等级(可选)
|
||||||
|
- info: 附加信息(可选)
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 更新后的文章记录
|
||||||
|
"""
|
||||||
|
success = await article_service.update_article(article_id, params)
|
||||||
|
if not success:
|
||||||
|
return response_base.fail(code=404, msg="文章不存在")
|
||||||
|
|
||||||
|
article = await article_service.get_article_by_id(article_id)
|
||||||
|
return response_base.success(data=article)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{article_id}", summary="删除文章", dependencies=[DependsJwtAuth])
|
||||||
|
async def delete_article(
|
||||||
|
article_id: int,
|
||||||
|
request: Request,
|
||||||
|
background_tasks: BackgroundTasks
|
||||||
|
) -> ResponseSchemaModel[None]:
|
||||||
|
"""
|
||||||
|
删除文章记录
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- article_id: 文章ID
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 无
|
||||||
|
"""
|
||||||
|
success = await article_service.delete_article(article_id)
|
||||||
|
if not success:
|
||||||
|
return response_base.fail(code=404, msg="文章不存在")
|
||||||
|
|
||||||
|
return response_base.success()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/paragraph", summary="创建文章段落", dependencies=[DependsJwtAuth])
|
||||||
|
async def create_article_paragraph(
|
||||||
|
request: Request,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
params: CreateArticleParagraphParam
|
||||||
|
) -> ResponseSchemaModel[ArticleParagraphSchema]:
|
||||||
|
"""
|
||||||
|
创建文章段落记录
|
||||||
|
|
||||||
|
请求体参数:
|
||||||
|
- article_id: 关联的文章ID
|
||||||
|
- paragraph_index: 段落序号
|
||||||
|
- content: 段落内容
|
||||||
|
- standard_audio_id: 标准朗读音频文件ID(可选)
|
||||||
|
- info: 附加信息(可选)
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 创建的段落记录
|
||||||
|
"""
|
||||||
|
paragraph_id = await article_service.create_article_paragraph(obj=params)
|
||||||
|
paragraph = await article_service.get_article_paragraph_by_id(paragraph_id)
|
||||||
|
|
||||||
|
return response_base.success(data=paragraph)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/paragraph/{paragraph_id}", summary="更新文章段落", dependencies=[DependsJwtAuth])
|
||||||
|
async def update_article_paragraph(
|
||||||
|
paragraph_id: int,
|
||||||
|
request: Request,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
params: UpdateArticleParagraphParam
|
||||||
|
) -> ResponseSchemaModel[ArticleParagraphSchema]:
|
||||||
|
"""
|
||||||
|
更新文章段落记录
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- paragraph_id: 段落ID
|
||||||
|
|
||||||
|
请求体参数:
|
||||||
|
- article_id: 关联的文章ID
|
||||||
|
- paragraph_index: 段落序号
|
||||||
|
- content: 段落内容
|
||||||
|
- standard_audio_id: 标准朗读音频文件ID(可选)
|
||||||
|
- info: 附加信息(可选)
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 更新后的段落记录
|
||||||
|
"""
|
||||||
|
success = await article_service.update_article_paragraph(paragraph_id, params)
|
||||||
|
if not success:
|
||||||
|
return response_base.fail(code=404, msg="段落不存在")
|
||||||
|
|
||||||
|
paragraph = await article_service.get_article_paragraph_by_id(paragraph_id)
|
||||||
|
return response_base.success(data=paragraph)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/paragraph/{paragraph_id}", summary="删除文章段落", dependencies=[DependsJwtAuth])
|
||||||
|
async def delete_article_paragraph(
|
||||||
|
paragraph_id: int,
|
||||||
|
request: Request,
|
||||||
|
background_tasks: BackgroundTasks
|
||||||
|
) -> ResponseSchemaModel[None]:
|
||||||
|
"""
|
||||||
|
删除文章段落记录
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- paragraph_id: 段落ID
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 无
|
||||||
|
"""
|
||||||
|
success = await article_service.delete_article_paragraph(paragraph_id)
|
||||||
|
if not success:
|
||||||
|
return response_base.fail(code=404, msg="段落不存在")
|
||||||
|
|
||||||
|
return response_base.success()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/paragraph/{paragraph_id}", summary="获取段落详情", dependencies=[DependsJwtAuth])
|
||||||
|
async def get_article_paragraph(
|
||||||
|
paragraph_id: int
|
||||||
|
) -> ResponseSchemaModel[ArticleParagraphSchema]:
|
||||||
|
"""
|
||||||
|
获取文章段落详情
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- paragraph_id: 段落ID
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 段落记录
|
||||||
|
"""
|
||||||
|
paragraph = await article_service.get_article_paragraph_by_id(paragraph_id)
|
||||||
|
if not paragraph:
|
||||||
|
return response_base.fail(code=404, msg="段落不存在")
|
||||||
|
|
||||||
|
return response_base.success(data=paragraph)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/paragraph", summary="获取文章的所有段落", dependencies=[DependsJwtAuth])
|
||||||
|
async def get_article_paragraphs(
|
||||||
|
article_id: int = Query(..., description="文章ID")
|
||||||
|
) -> ResponseSchemaModel[List[ArticleParagraphSchema]]:
|
||||||
|
"""
|
||||||
|
获取指定文章的所有段落记录
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- article_id: 文章ID
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 文章的所有段落记录列表
|
||||||
|
"""
|
||||||
|
paragraphs = await article_service.get_article_paragraphs_by_article_id(article_id)
|
||||||
|
return response_base.success(data=paragraphs)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sentence", summary="创建文章句子", dependencies=[DependsJwtAuth])
|
||||||
|
async def create_article_sentence(
|
||||||
|
request: Request,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
params: CreateArticleSentenceParam
|
||||||
|
) -> ResponseSchemaModel[ArticleSentenceSchema]:
|
||||||
|
"""
|
||||||
|
创建文章句子记录
|
||||||
|
|
||||||
|
请求体参数:
|
||||||
|
- paragraph_id: 关联的段落ID
|
||||||
|
- sentence_index: 句子序号
|
||||||
|
- content: 句子内容
|
||||||
|
- standard_audio_id: 标准朗读音频文件ID(可选)
|
||||||
|
- info: 附加信息(可选)
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 创建的句子记录
|
||||||
|
"""
|
||||||
|
sentence_id = await article_service.create_article_sentence(obj=params)
|
||||||
|
sentence = await article_service.get_article_sentence_by_id(sentence_id)
|
||||||
|
|
||||||
|
return response_base.success(data=sentence)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/sentence/{sentence_id}", summary="更新文章句子", dependencies=[DependsJwtAuth])
|
||||||
|
async def update_article_sentence(
|
||||||
|
sentence_id: int,
|
||||||
|
request: Request,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
params: UpdateArticleSentenceParam
|
||||||
|
) -> ResponseSchemaModel[ArticleSentenceSchema]:
|
||||||
|
"""
|
||||||
|
更新文章句子记录
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- sentence_id: 句子ID
|
||||||
|
|
||||||
|
请求体参数:
|
||||||
|
- paragraph_id: 关联的段落ID
|
||||||
|
- sentence_index: 句子序号
|
||||||
|
- content: 句子内容
|
||||||
|
- standard_audio_id: 标准朗读音频文件ID(可选)
|
||||||
|
- info: 附加信息(可选)
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 更新后的句子记录
|
||||||
|
"""
|
||||||
|
success = await article_service.update_article_sentence(sentence_id, params)
|
||||||
|
if not success:
|
||||||
|
return response_base.fail(code=404, msg="句子不存在")
|
||||||
|
|
||||||
|
sentence = await article_service.get_article_sentence_by_id(sentence_id)
|
||||||
|
return response_base.success(data=sentence)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/sentence/{sentence_id}", summary="删除文章句子", dependencies=[DependsJwtAuth])
|
||||||
|
async def delete_article_sentence(
|
||||||
|
sentence_id: int,
|
||||||
|
request: Request,
|
||||||
|
background_tasks: BackgroundTasks
|
||||||
|
) -> ResponseSchemaModel[None]:
|
||||||
|
"""
|
||||||
|
删除文章句子记录
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- sentence_id: 句子ID
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 无
|
||||||
|
"""
|
||||||
|
success = await article_service.delete_article_sentence(sentence_id)
|
||||||
|
if not success:
|
||||||
|
return response_base.fail(code=404, msg="句子不存在")
|
||||||
|
|
||||||
|
return response_base.success()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sentence/{sentence_id}", summary="获取句子详情", dependencies=[DependsJwtAuth])
|
||||||
|
async def get_article_sentence(
|
||||||
|
sentence_id: int
|
||||||
|
) -> ResponseSchemaModel[ArticleSentenceSchema]:
|
||||||
|
"""
|
||||||
|
获取文章句子详情
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- sentence_id: 句子ID
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 句子记录
|
||||||
|
"""
|
||||||
|
sentence = await article_service.get_article_sentence_by_id(sentence_id)
|
||||||
|
if not sentence:
|
||||||
|
return response_base.fail(code=404, msg="句子不存在")
|
||||||
|
|
||||||
|
return response_base.success(data=sentence)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sentence", summary="获取段落的所有句子", dependencies=[DependsJwtAuth])
|
||||||
|
async def get_article_sentences(
|
||||||
|
paragraph_id: int = Query(..., description="段落ID")
|
||||||
|
) -> ResponseSchemaModel[List[ArticleSentenceSchema]]:
|
||||||
|
"""
|
||||||
|
获取指定段落的所有句子记录
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- paragraph_id: 段落ID
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 段落的所有句子记录列表
|
||||||
|
"""
|
||||||
|
sentences = await article_service.get_article_sentences_by_paragraph_id(paragraph_id)
|
||||||
|
return response_base.success(data=sentences)
|
||||||
113
backend/app/ai/api/image.py
Executable file
113
backend/app/ai/api/image.py
Executable file
@@ -0,0 +1,113 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from fastapi import APIRouter, UploadFile, HTTPException, Request, Response, Query
|
||||||
|
from fastapi.params import File, Depends
|
||||||
|
from starlette.background import BackgroundTasks
|
||||||
|
|
||||||
|
from backend.app.admin.schema.audit_log import CreateAuditLogParam
|
||||||
|
from backend.app.ai.schema.image import ImageRecognizeRes, ProcessImageRequest, ImageBean, ImageShowRes
|
||||||
|
from backend.app.admin.service.audit_log_service import audit_log_service
|
||||||
|
from backend.app.ai.service.image_service import ImageService, image_service
|
||||||
|
from backend.common.response.response_schema import response_base, ResponseSchemaModel
|
||||||
|
from backend.common.security.jwt import DependsJwtAuth
|
||||||
|
from backend.app.admin.schema.wx import DictLevel
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# @router.post("/upload", summary="上传图片进行识别", dependencies=[DependsJwtAuth])
|
||||||
|
# async def upload_image(
|
||||||
|
# request: Request, background_tasks: BackgroundTasks, file: UploadFile = File(...)
|
||||||
|
# ) -> ResponseSchemaModel[ImageRecognizeRes]:
|
||||||
|
# """
|
||||||
|
# 上传图片并调用通义千问API进行识别
|
||||||
|
#
|
||||||
|
# 返回:
|
||||||
|
# - 识别结果 (英文单词或类别)
|
||||||
|
# - 是否来自缓存
|
||||||
|
# - 图片ID
|
||||||
|
# """
|
||||||
|
# image_service.file_verify(file)
|
||||||
|
# result = await image_service.process_image_upload(file=file, request=request, background_tasks=background_tasks)
|
||||||
|
# return response_base.success(data=result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/recognize", summary="处理已上传的图片文件", dependencies=[DependsJwtAuth])
|
||||||
|
async def process_image(
|
||||||
|
request: Request,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
params: ProcessImageRequest
|
||||||
|
) -> ResponseSchemaModel[ImageRecognizeRes]:
|
||||||
|
"""
|
||||||
|
处理已上传的图片文件并调用通义千问API进行识别
|
||||||
|
|
||||||
|
请求体参数:
|
||||||
|
- file_id: 已上传文件的ID
|
||||||
|
- type: 文件类型(默认为"image")
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 识别结果 (英文单词或类别)
|
||||||
|
- 是否来自缓存
|
||||||
|
- 图片ID
|
||||||
|
"""
|
||||||
|
result = await image_service.process_image_from_file(
|
||||||
|
params=params,
|
||||||
|
request=request,
|
||||||
|
background_tasks=background_tasks
|
||||||
|
)
|
||||||
|
return response_base.success(data=result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{id}", dependencies=[DependsJwtAuth])
|
||||||
|
async def get_image(
|
||||||
|
request: Request,
|
||||||
|
id: int,
|
||||||
|
dict_level: DictLevel = Query(DictLevel.LEVEL1, description="词典等级")
|
||||||
|
) -> ResponseSchemaModel[ImageShowRes]:
|
||||||
|
image = await image_service.find_image(id)
|
||||||
|
|
||||||
|
# 检查details和recognition_result是否存在
|
||||||
|
if not image.details or "recognition_result" not in image.details:
|
||||||
|
raise HTTPException(status_code=404, detail="Recognition result not found")
|
||||||
|
|
||||||
|
# 根据dict_level获取对应的识别结果
|
||||||
|
recognition_result = image.details["recognition_result"]
|
||||||
|
if dict_level.value not in recognition_result:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Recognition result for {dict_level.value} not found")
|
||||||
|
|
||||||
|
result = ImageShowRes(
|
||||||
|
id=image.id,
|
||||||
|
file_id=image.file_id,
|
||||||
|
res=recognition_result[dict_level.value]
|
||||||
|
)
|
||||||
|
return response_base.success(data=result)
|
||||||
|
|
||||||
|
@router.get("/log")
|
||||||
|
def log(request: Request, background_tasks: BackgroundTasks) -> ResponseSchemaModel[dict]:
|
||||||
|
audit_log = CreateAuditLogParam(
|
||||||
|
api_type="test_api",
|
||||||
|
model_name="test_model",
|
||||||
|
response_data={"test": "test_response"},
|
||||||
|
called_at=datetime.now(),
|
||||||
|
request_data={"test": "test_request"},
|
||||||
|
token_usage={"usage": "test_usage"},
|
||||||
|
cost=0.0,
|
||||||
|
duration=0.0,
|
||||||
|
status_code=200,
|
||||||
|
error_message="",
|
||||||
|
image_id=123123,
|
||||||
|
user_id=321312,
|
||||||
|
api_version="test_version",
|
||||||
|
)
|
||||||
|
background_tasks.add_task(
|
||||||
|
audit_log_service.create,
|
||||||
|
obj=audit_log,
|
||||||
|
)
|
||||||
|
return response_base.success(data={"succ": True})
|
||||||
|
|
||||||
|
# @router.get("/download/{image_id}")
|
||||||
|
# async def download_image(image_id: int):
|
||||||
|
# result = await image_service.download_image(id=image_id)
|
||||||
|
# return response_base.success(result)
|
||||||
161
backend/app/ai/api/image_text.py
Normal file
161
backend/app/ai/api/image_text.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from typing import List
|
||||||
|
from fastapi import APIRouter, Request, Query
|
||||||
|
from starlette.background import BackgroundTasks
|
||||||
|
|
||||||
|
from backend.app.ai.schema.image_text import ImageTextSchema, ImageTextWithRecordingsSchema, CreateImageTextParam, UpdateImageTextParam, ImageTextInitResponseSchema, ImageTextInitParam
|
||||||
|
from backend.app.ai.service.image_text_service import image_text_service
|
||||||
|
from backend.app.ai.service.recording_service import recording_service
|
||||||
|
from backend.common.exception import errors
|
||||||
|
from backend.common.response.response_schema import response_base, ResponseSchemaModel
|
||||||
|
from backend.common.security.jwt import DependsJwtAuth
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/init", summary="初始化图片文本", dependencies=[DependsJwtAuth])
|
||||||
|
async def init_image_texts(
|
||||||
|
request: Request,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
params: ImageTextInitParam
|
||||||
|
) -> ResponseSchemaModel[ImageTextInitResponseSchema]:
|
||||||
|
"""
|
||||||
|
初始化图片文本记录
|
||||||
|
根据dict_level从image的recognition_result中提取文本,如果不存在则创建,如果已存在则直接返回
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- image_id: 图片ID
|
||||||
|
- dict_level: 词典等级
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 图片文本记录列表
|
||||||
|
"""
|
||||||
|
result = await image_text_service.init_image_texts(request.user.id, params.image_id, params.dict_level, background_tasks)
|
||||||
|
return response_base.success(data=result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/standard/{text_id}", summary="获取标准音频文件ID", dependencies=[DependsJwtAuth])
|
||||||
|
async def get_standard_audio_file_id(
|
||||||
|
text_id: int,
|
||||||
|
) -> ResponseSchemaModel[dict]:
|
||||||
|
"""
|
||||||
|
根据文本ID获取标准音频的文件ID,等待异步任务创建完成
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- text_id: 图片文本ID
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 标准音频的文件ID
|
||||||
|
"""
|
||||||
|
file_id = await recording_service.get_standard_audio_file_id_by_text_id(text_id)
|
||||||
|
if not file_id:
|
||||||
|
raise errors.NotFoundError(msg="标准音频不存在或创建超时")
|
||||||
|
return response_base.success(data={'audio_id': str(file_id)})
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{text_id}", summary="获取图片文本详情", dependencies=[DependsJwtAuth])
|
||||||
|
async def get_image_text(
|
||||||
|
text_id: int
|
||||||
|
) -> ResponseSchemaModel[ImageTextWithRecordingsSchema]:
|
||||||
|
"""
|
||||||
|
获取图片文本详情,包括关联的录音记录
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- text_id: 图片文本ID
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 图片文本记录及其关联的录音记录
|
||||||
|
"""
|
||||||
|
text = await image_text_service.get_text_by_id(text_id)
|
||||||
|
if not text:
|
||||||
|
raise errors.NotFoundError(msg="图片文本不存在")
|
||||||
|
|
||||||
|
# 获取关联的录音记录
|
||||||
|
recordings = await recording_service.get_recordings_by_text_id(text_id)
|
||||||
|
|
||||||
|
# 构造返回数据
|
||||||
|
result = ImageTextWithRecordingsSchema(
|
||||||
|
id=text.id,
|
||||||
|
image_id=text.image_id,
|
||||||
|
content=text.content,
|
||||||
|
position=text.position,
|
||||||
|
dict_level=text.dict_level,
|
||||||
|
standard_audio_id=text.standard_audio_id,
|
||||||
|
info=text.info,
|
||||||
|
created_time=text.created_time,
|
||||||
|
recordings=recordings
|
||||||
|
)
|
||||||
|
|
||||||
|
return response_base.success(data=result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{text_id}", summary="更新图片文本", dependencies=[DependsJwtAuth])
|
||||||
|
async def update_image_text(
|
||||||
|
text_id: int,
|
||||||
|
request: Request,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
params: UpdateImageTextParam
|
||||||
|
) -> ResponseSchemaModel[ImageTextSchema]:
|
||||||
|
"""
|
||||||
|
更新图片文本记录
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- text_id: 图片文本ID
|
||||||
|
|
||||||
|
请求体参数:
|
||||||
|
- image_id: 图片ID
|
||||||
|
- content: 文本内容
|
||||||
|
- position: 文本在图片中的位置信息(可选)
|
||||||
|
- dict_level: 词典等级(可选)
|
||||||
|
- standard_audio_id: 标准朗读音频文件ID(可选)
|
||||||
|
- info: 附加信息(可选)
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 更新后的图片文本记录
|
||||||
|
"""
|
||||||
|
success = await image_text_service.update_text(text_id, params)
|
||||||
|
if not success:
|
||||||
|
return response_base.fail(code=404, msg="图片文本不存在")
|
||||||
|
|
||||||
|
text = await image_text_service.get_text_by_id(text_id)
|
||||||
|
return response_base.success(data=text)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{text_id}", summary="删除图片文本", dependencies=[DependsJwtAuth])
|
||||||
|
async def delete_image_text(
|
||||||
|
text_id: int,
|
||||||
|
request: Request,
|
||||||
|
background_tasks: BackgroundTasks
|
||||||
|
) -> ResponseSchemaModel[None]:
|
||||||
|
"""
|
||||||
|
删除图片文本记录
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- text_id: 图片文本ID
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 无
|
||||||
|
"""
|
||||||
|
success = await image_text_service.delete_text(text_id)
|
||||||
|
if not success:
|
||||||
|
return response_base.fail(code=404, msg="图片文本不存在")
|
||||||
|
|
||||||
|
return response_base.success()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", summary="获取图片的所有文本", dependencies=[DependsJwtAuth])
|
||||||
|
async def get_image_texts(
|
||||||
|
image_id: int = Query(..., description="图片ID")
|
||||||
|
) -> ResponseSchemaModel[List[ImageTextSchema]]:
|
||||||
|
"""
|
||||||
|
获取指定图片的所有文本记录
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- image_id: 图片ID
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 图片的所有文本记录列表
|
||||||
|
"""
|
||||||
|
texts = await image_text_service.get_texts_by_image_id(image_id)
|
||||||
|
return response_base.success(data=texts)
|
||||||
105
backend/app/ai/api/recording.py
Normal file
105
backend/app/ai/api/recording.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from typing import List
|
||||||
|
from fastapi import APIRouter, Request, Query
|
||||||
|
from starlette.background import BackgroundTasks
|
||||||
|
|
||||||
|
from backend.app.ai.schema.recording import RecordingAssessmentRequest, RecordingAssessmentResponse, ReadingProgressResponse
|
||||||
|
from backend.app.ai.service.recording_service import recording_service
|
||||||
|
from backend.app.ai.service.image_text_service import image_text_service
|
||||||
|
from backend.common.exception import errors
|
||||||
|
from backend.common.response.response_schema import response_base, ResponseSchemaModel
|
||||||
|
from backend.common.security.jwt import DependsJwtAuth
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/assessment", summary="录音评估接口", dependencies=[DependsJwtAuth])
|
||||||
|
async def assess_recording(
|
||||||
|
request: Request,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
params: RecordingAssessmentRequest
|
||||||
|
) -> ResponseSchemaModel[RecordingAssessmentResponse]:
|
||||||
|
"""
|
||||||
|
录音评估接口
|
||||||
|
|
||||||
|
接收文件ID和图片文本ID,调用第三方API获取评估结果并存储到recording表的details字段
|
||||||
|
|
||||||
|
请求体参数:
|
||||||
|
- file_id: 录音文件ID
|
||||||
|
- image_text_id: 关联的图片文本ID
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- file_id: 录音文件ID
|
||||||
|
- assessment_result: 评估结果
|
||||||
|
- image_id: 关联的图片ID
|
||||||
|
- image_text_id: 关联的图片文本ID
|
||||||
|
"""
|
||||||
|
# 获取图片文本记录以获取image_id用于响应
|
||||||
|
image_text = await image_text_service.get_text_by_id(params.image_text_id)
|
||||||
|
if not image_text:
|
||||||
|
raise errors.NotFoundError(msg=f"ImageText with id {params.image_text_id} not found")
|
||||||
|
|
||||||
|
# 调用录音服务进行评估
|
||||||
|
assessment_result = await recording_service.assess_recording(
|
||||||
|
params.file_id,
|
||||||
|
# 2087227590107594752,
|
||||||
|
image_text_id=params.image_text_id,
|
||||||
|
user_id=request.user.id,
|
||||||
|
background_tasks=background_tasks
|
||||||
|
)
|
||||||
|
|
||||||
|
# 返回结果
|
||||||
|
response_data = RecordingAssessmentResponse(
|
||||||
|
file_id=params.file_id,
|
||||||
|
assessment_result=assessment_result,
|
||||||
|
image_text_id=params.image_text_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return response_base.success(data=response_data)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/progress", summary="获取朗读进步统计", dependencies=[DependsJwtAuth])
|
||||||
|
async def get_reading_progress(
|
||||||
|
image_id: int = Query(..., description="图片ID"),
|
||||||
|
text: str = Query(..., description="朗读文本")
|
||||||
|
) -> ResponseSchemaModel[ReadingProgressResponse]:
|
||||||
|
"""
|
||||||
|
获取特定图片和文本的朗读进步统计
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- image_id: 图片ID
|
||||||
|
- text: 朗读文本
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 包含进步统计信息的字典
|
||||||
|
"""
|
||||||
|
# 获取所有相关的录音记录
|
||||||
|
recordings = await recording_service.get_recordings_by_image_and_text(image_id, text)
|
||||||
|
|
||||||
|
# 计算进步统计
|
||||||
|
progress_stats = recording_service.calculate_progress(recordings)
|
||||||
|
|
||||||
|
return response_base.success(data=progress_stats)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/progress-by-text-id", summary="根据文本ID获取朗读进步统计", dependencies=[DependsJwtAuth])
|
||||||
|
async def get_reading_progress_by_text_id(
|
||||||
|
image_text_id: int = Query(..., description="图片文本ID")
|
||||||
|
) -> ResponseSchemaModel[ReadingProgressResponse]:
|
||||||
|
"""
|
||||||
|
根据图片文本ID获取朗读进步统计
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- image_text_id: 图片文本ID
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 包含进步统计信息的字典
|
||||||
|
"""
|
||||||
|
# 获取所有相关的录音记录
|
||||||
|
recordings = await recording_service.get_recordings_by_text_id(image_text_id)
|
||||||
|
|
||||||
|
# 计算进步统计
|
||||||
|
progress_stats = recording_service.calculate_progress(recordings)
|
||||||
|
|
||||||
|
return response_base.success(data=progress_stats)
|
||||||
16
backend/app/ai/api/router.py
Normal file
16
backend/app/ai/api/router.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from backend.app.ai.api.image import router as image_router
|
||||||
|
from backend.app.ai.api.recording import router as recording_router
|
||||||
|
from backend.app.ai.api.image_text import router as image_text_router
|
||||||
|
from backend.app.ai.api.article import router as article_router
|
||||||
|
from backend.core.conf import settings
|
||||||
|
|
||||||
|
v1 = APIRouter(prefix=settings.FASTAPI_API_V1_PATH)
|
||||||
|
|
||||||
|
v1.include_router(image_router, prefix='/image', tags=['AI图片服务'])
|
||||||
|
v1.include_router(recording_router, prefix='/recording', tags=['AI录音服务'])
|
||||||
|
v1.include_router(image_text_router, prefix='/image_text', tags=['AI图片文本服务'])
|
||||||
|
# v1.include_router(article_router, prefix='/article', tags=['AI文章服务'])
|
||||||
3
backend/app/ai/crud/__init__.py
Normal file
3
backend/app/ai/crud/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from backend.app.ai.crud.image_curd import image_dao
|
||||||
|
from backend.app.ai.crud.image_text_crud import image_text_dao
|
||||||
|
from backend.app.ai.crud.recording_crud import recording_dao
|
||||||
65
backend/app/ai/crud/article_crud.py
Normal file
65
backend/app/ai/crud/article_crud.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from typing import Optional, List
|
||||||
|
from sqlalchemy import select, and_
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy_crud_plus import CRUDPlus
|
||||||
|
|
||||||
|
from backend.app.ai.model.article import Article, ArticleParagraph, ArticleSentence
|
||||||
|
|
||||||
|
|
||||||
|
class CRUDArticle(CRUDPlus[Article]):
|
||||||
|
async def get_by_title(self, db: AsyncSession, title: str) -> Optional[Article]:
|
||||||
|
"""根据标题获取文章"""
|
||||||
|
stmt = select(self.model).where(self.model.title == title)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_articles_by_category(self, db: AsyncSession, category: str) -> List[Article]:
|
||||||
|
"""根据分类获取文章列表"""
|
||||||
|
stmt = select(self.model).where(self.model.category == category)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
|
||||||
|
class CRUDArticleParagraph(CRUDPlus[ArticleParagraph]):
|
||||||
|
async def get_by_article_id(self, db: AsyncSession, article_id: int) -> List[ArticleParagraph]:
|
||||||
|
"""根据文章ID获取所有段落,按序号排序"""
|
||||||
|
stmt = select(self.model).where(self.model.article_id == article_id).order_by(self.model.paragraph_index)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
async def get_by_article_id_and_index(self, db: AsyncSession, article_id: int, paragraph_index: int) -> Optional[ArticleParagraph]:
|
||||||
|
"""根据文章ID和段落序号获取段落"""
|
||||||
|
stmt = select(self.model).where(
|
||||||
|
and_(
|
||||||
|
self.model.article_id == article_id,
|
||||||
|
self.model.paragraph_index == paragraph_index
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
class CRUDArticleSentence(CRUDPlus[ArticleSentence]):
|
||||||
|
async def get_by_paragraph_id(self, db: AsyncSession, paragraph_id: int) -> List[ArticleSentence]:
|
||||||
|
"""根据段落ID获取所有句子,按序号排序"""
|
||||||
|
stmt = select(self.model).where(self.model.paragraph_id == paragraph_id).order_by(self.model.sentence_index)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
async def get_by_paragraph_id_and_index(self, db: AsyncSession, paragraph_id: int, sentence_index: int) -> Optional[ArticleSentence]:
|
||||||
|
"""根据段落ID和句子序号获取句子"""
|
||||||
|
stmt = select(self.model).where(
|
||||||
|
and_(
|
||||||
|
self.model.paragraph_id == paragraph_id,
|
||||||
|
self.model.sentence_index == sentence_index
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
article_dao: CRUDArticle = CRUDArticle(Article)
|
||||||
|
article_paragraph_dao: CRUDArticleParagraph = CRUDArticleParagraph(ArticleParagraph)
|
||||||
|
article_sentence_dao: CRUDArticleSentence = CRUDArticleSentence(ArticleSentence)
|
||||||
160
backend/app/ai/crud/image_curd.py
Executable file
160
backend/app/ai/crud/image_curd.py
Executable file
@@ -0,0 +1,160 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from sqlalchemy import select, update, desc, and_, func, Float
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.sql import Select
|
||||||
|
from sqlalchemy_crud_plus import CRUDPlus
|
||||||
|
|
||||||
|
from backend.app.ai.model import Image
|
||||||
|
from backend.app.ai.schema.image import UpdateImageParam, AddImageParam
|
||||||
|
|
||||||
|
|
||||||
|
class ImageCRUD(CRUDPlus[Image]):
|
||||||
|
|
||||||
|
async def get(self, db: AsyncSession, id: int) -> Image | None:
|
||||||
|
return await self.select_model(db, id)
|
||||||
|
|
||||||
|
async def get_by_file_id(self, db: AsyncSession, file_id: int) -> Image | None:
|
||||||
|
return await self.select_model_by_column(db, file_id=file_id)
|
||||||
|
|
||||||
|
async def get_by_file_id_and_dict_level(self, db: AsyncSession, file_id: int, dict_level: str) -> Image | None:
|
||||||
|
"""
|
||||||
|
根据文件ID和词典等级获取图片记录
|
||||||
|
|
||||||
|
:param db: 数据库会话
|
||||||
|
:param file_id: 文件ID
|
||||||
|
:param dict_level: 词典等级
|
||||||
|
:return: 图片记录或None
|
||||||
|
"""
|
||||||
|
return await self.select_model_by_column(db, file_id=file_id, dict_level=dict_level)
|
||||||
|
|
||||||
|
async def get_images_by_file_id(self, db: AsyncSession, file_id: int) -> List[Image]:
|
||||||
|
"""
|
||||||
|
根据文件ID获取具有不同词典等级的图片记录
|
||||||
|
|
||||||
|
:param db: 数据库会话
|
||||||
|
:param file_id: 文件ID
|
||||||
|
:param dict_level: 当前词典等级
|
||||||
|
:return: 图片记录列表
|
||||||
|
"""
|
||||||
|
stmt = select(Image).where(Image.file_id == file_id)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def get_images_by_file_id_and_different_dict_level(self, db: AsyncSession, file_id: int, dict_level: str) -> \
|
||||||
|
List[Image]:
|
||||||
|
"""
|
||||||
|
根据文件ID获取具有不同词典等级的图片记录
|
||||||
|
|
||||||
|
:param db: 数据库会话
|
||||||
|
:param file_id: 文件ID
|
||||||
|
:param dict_level: 当前词典等级
|
||||||
|
:return: 图片记录列表
|
||||||
|
"""
|
||||||
|
stmt = select(Image).where(
|
||||||
|
and_(
|
||||||
|
Image.file_id == file_id,
|
||||||
|
Image.dict_level != dict_level
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def get_images_by_ids(self, db: AsyncSession, image_ids: List[int]) -> List[Image]:
|
||||||
|
"""
|
||||||
|
根据ID列表获取图片记录
|
||||||
|
|
||||||
|
:param db: 数据库会话
|
||||||
|
:param image_ids: 图片ID列表
|
||||||
|
:return: 图片记录列表
|
||||||
|
"""
|
||||||
|
if not image_ids:
|
||||||
|
return []
|
||||||
|
stmt = select(Image).where(Image.id.in_(image_ids))
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def find_similar_images_by_dict_level(self, db: AsyncSession, embedding: List[float], dict_level: str,
|
||||||
|
top_k: int = 3, threshold: float = 0.8) -> List[int]:
|
||||||
|
"""
|
||||||
|
根据向量和词典等级查找相似图片
|
||||||
|
参数:
|
||||||
|
db: 数据库会话
|
||||||
|
embedding: 1024维向量
|
||||||
|
dict_level: 词典等级
|
||||||
|
top_k: 返回最相似的K个结果
|
||||||
|
threshold: 相似度阈值 (0.0-1.0)
|
||||||
|
"""
|
||||||
|
# 确保向量是numpy数组
|
||||||
|
if not isinstance(embedding, np.ndarray):
|
||||||
|
embedding = np.array(embedding)
|
||||||
|
|
||||||
|
# 转换为 NumPy 数组提高性能
|
||||||
|
target_embedding = np.array(embedding, dtype=np.float32)
|
||||||
|
|
||||||
|
# 构建查询
|
||||||
|
cosine_distance_expr = Image.embedding.cosine_distance(target_embedding)
|
||||||
|
similarity_expr = (1 - func.cast(cosine_distance_expr, Float)).label("similarity")
|
||||||
|
# 构建查询
|
||||||
|
stmt = select(
|
||||||
|
Image.id,
|
||||||
|
# Image.info,
|
||||||
|
similarity_expr
|
||||||
|
).where(
|
||||||
|
and_(
|
||||||
|
Image.dict_level == dict_level,
|
||||||
|
similarity_expr >= threshold
|
||||||
|
)
|
||||||
|
).order_by(
|
||||||
|
cosine_distance_expr
|
||||||
|
).limit(top_k)
|
||||||
|
|
||||||
|
results = await db.execute(stmt)
|
||||||
|
id_list: List[int] = results.scalars().all()
|
||||||
|
return id_list
|
||||||
|
|
||||||
|
async def add(self, db: AsyncSession, new_image: Image) -> None:
|
||||||
|
db.add(new_image)
|
||||||
|
|
||||||
|
async def update(self, db: AsyncSession, id: int, obj: UpdateImageParam) -> int:
|
||||||
|
return await self.update_model(db, id, obj)
|
||||||
|
|
||||||
|
async def find_similar_image_ids(self, db: AsyncSession, embedding: List[float], top_k: int = 3,
|
||||||
|
threshold: float = 0.8) -> List[int]:
|
||||||
|
"""
|
||||||
|
直接通过向量查找相似图片
|
||||||
|
参数:
|
||||||
|
embedding: 1024维向量
|
||||||
|
top_k: 返回最相似的K个结果
|
||||||
|
threshold: 相似度阈值 (0.0-1.0)
|
||||||
|
"""
|
||||||
|
# 确保向量是numpy数组
|
||||||
|
if not isinstance(embedding, np.ndarray):
|
||||||
|
embedding = np.array(embedding)
|
||||||
|
|
||||||
|
# 转换为 NumPy 数组提高性能
|
||||||
|
target_embedding = np.array(embedding, dtype=np.float32)
|
||||||
|
|
||||||
|
# 构建查询
|
||||||
|
cosine_distance_expr = Image.embedding.cosine_distance(target_embedding)
|
||||||
|
similarity_expr = (1 - func.cast(cosine_distance_expr, Float)).label("similarity")
|
||||||
|
# 构建查询
|
||||||
|
stmt = select(
|
||||||
|
Image.id,
|
||||||
|
# Image.info,
|
||||||
|
similarity_expr
|
||||||
|
).where(
|
||||||
|
similarity_expr >= threshold
|
||||||
|
).order_by(
|
||||||
|
cosine_distance_expr
|
||||||
|
).limit(top_k)
|
||||||
|
|
||||||
|
results = await db.execute(stmt)
|
||||||
|
id_list: List[int] = results.scalars().all()
|
||||||
|
return id_list
|
||||||
|
|
||||||
|
|
||||||
|
image_dao: ImageCRUD = ImageCRUD(Image)
|
||||||
44
backend/app/ai/crud/image_text_crud.py
Normal file
44
backend/app/ai/crud/image_text_crud.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from typing import Optional, List
|
||||||
|
from sqlalchemy import select, and_
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy_crud_plus import CRUDPlus
|
||||||
|
|
||||||
|
from backend.app.ai.model.image_text import ImageText
|
||||||
|
|
||||||
|
|
||||||
|
class ImageTextCRUD(CRUDPlus[ImageText]):
|
||||||
|
async def get(self, db: AsyncSession, id: int) -> Optional[ImageText]:
|
||||||
|
"""根据ID获取文本记录"""
|
||||||
|
return await self.select_model(db, id)
|
||||||
|
|
||||||
|
async def get_by_image_id(self, db: AsyncSession, image_id: int) -> List[ImageText]:
|
||||||
|
"""根据图片ID获取所有文本"""
|
||||||
|
stmt = select(self.model).where(self.model.image_id == image_id)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
async def get_by_image_id_and_content(self, db: AsyncSession, image_id: int, content: str) -> Optional[ImageText]:
|
||||||
|
"""根据图片ID和文本内容获取文本记录"""
|
||||||
|
stmt = select(self.model).where(
|
||||||
|
and_(
|
||||||
|
self.model.image_id == image_id,
|
||||||
|
self.model.content == content
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_by_standard_audio_id(self, db: AsyncSession, standard_audio_id: int) -> Optional[ImageText]:
|
||||||
|
"""根据标准音频文件ID获取文本记录"""
|
||||||
|
stmt = select(self.model).where(self.model.standard_audio_id == standard_audio_id)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def update(self, db: AsyncSession, id: int, obj_in: dict) -> int:
|
||||||
|
"""更新文本记录"""
|
||||||
|
return await self.update_model(db, id, obj_in)
|
||||||
|
|
||||||
|
|
||||||
|
image_text_dao: ImageTextCRUD = ImageTextCRUD(ImageText)
|
||||||
57
backend/app/ai/crud/recording_crud.py
Normal file
57
backend/app/ai/crud/recording_crud.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from typing import Optional, List
|
||||||
|
from sqlalchemy import select, and_, func
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy_crud_plus import CRUDPlus
|
||||||
|
|
||||||
|
from backend.app.ai.model.recording import Recording
|
||||||
|
|
||||||
|
|
||||||
|
class RecordingCRUD(CRUDPlus[Recording]):
|
||||||
|
async def get(self, db: AsyncSession, id: int) -> Optional[Recording]:
|
||||||
|
"""根据ID获取录音记录"""
|
||||||
|
return await self.select_model(db, id)
|
||||||
|
|
||||||
|
async def get_by_file_id(self, db: AsyncSession, file_id: int) -> Optional[Recording]:
|
||||||
|
"""根据文件ID获取录音记录"""
|
||||||
|
stmt = select(self.model).where(self.model.file_id == file_id)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_by_text_id(self, db: AsyncSession, text_id: int) -> List[Recording]:
|
||||||
|
"""根据文本ID获取所有录音记录(不包括标准音频)"""
|
||||||
|
stmt = select(self.model).where(
|
||||||
|
and_(
|
||||||
|
self.model.image_text_id == text_id,
|
||||||
|
self.model.is_standard == False
|
||||||
|
)
|
||||||
|
).order_by(self.model.created_time.asc())
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
async def get_latest_by_text_id(self, db: AsyncSession, text_id: int) -> Optional[Recording]:
|
||||||
|
"""根据文本ID获取最新的录音记录(不包括标准音频)"""
|
||||||
|
stmt = select(self.model).where(
|
||||||
|
and_(
|
||||||
|
self.model.image_text_id == text_id,
|
||||||
|
self.model.is_standard == False
|
||||||
|
)
|
||||||
|
).order_by(self.model.created_time.desc()).limit(1)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_standard_by_text_id(self, db: AsyncSession, text_id: int) -> Optional[Recording]:
|
||||||
|
"""根据文本ID获取标准音频记录"""
|
||||||
|
stmt = select(self.model).where(
|
||||||
|
and_(
|
||||||
|
self.model.image_text_id == text_id,
|
||||||
|
self.model.is_standard == True
|
||||||
|
)
|
||||||
|
).limit(1)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
recording_dao: RecordingCRUD = RecordingCRUD(Recording)
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user