This commit is contained in:
Felix
2025-12-05 19:55:31 +08:00
parent a9df500e67
commit 28ee12ca72
11 changed files with 445 additions and 52 deletions

View File

@@ -1,6 +1,13 @@
from fastapi import APIRouter, UploadFile, File, Query, Depends, Response
from fastapi import APIRouter, UploadFile, File, Query, Depends, Response, Request
from pydantic import BaseModel
from backend.app.admin.schema.file import FileUploadResponse
from backend.app.admin.schema.file import (
FileUploadResponse,
UploadInitRequest,
UploadInitResponse,
UploadCompleteRequest,
UploadCompleteResponse,
)
from backend.app.admin.service.file_service import file_service
from backend.common.response.response_schema import response_base, ResponseSchemaModel
from backend.common.security.jwt import DependsJwtAuth
@@ -49,3 +56,34 @@ async def delete_file(file_id: int) -> ResponseSchemaModel[bool]:
if not result:
return await response_base.fail(message="文件不存在或删除失败")
return await response_base.success(data=result)
@router.post("/upload/init", summary="初始化对象存储上传", dependencies=[DependsJwtAuth])
async def upload_init(body: UploadInitRequest, request: Request) -> ResponseSchemaModel[UploadInitResponse]:
user_id = getattr(request, 'user', None).id if hasattr(request, 'user') and request.user else None
d = await file_service.init_upload(body.filename, body.size, body.mime, body.biz_type, user_id)
data = UploadInitResponse(**d)
return response_base.success(data=data)
@router.post("/upload/complete", summary="完成对象存储上传", dependencies=[DependsJwtAuth])
async def upload_complete(body: UploadCompleteRequest, request: Request) -> ResponseSchemaModel[UploadCompleteResponse]:
user_id = getattr(request, 'user', None).id if hasattr(request, 'user') and request.user else None
d = await file_service.complete_upload(
file_id=body.file_id,
cloud_path=body.cloud_path,
file_id_in_cos=body.fileID,
sha256=body.sha256,
size=body.size,
mime=body.mime,
wx_user_id=user_id,
)
data = UploadCompleteResponse(**d)
return response_base.success(data=data)
@router.get("/temp_url/{file_id}", summary="获取临时下载URL", dependencies=[DependsJwtAuth])
async def get_temp_url(file_id: int, request: Request) -> ResponseSchemaModel[dict]:
user_id = getattr(request, 'user', None).id if hasattr(request, 'user') and request.user else None
url = await file_service.get_presigned_download_url(file_id, user_id)
return response_base.success(data={"url": url})

View File

@@ -8,6 +8,11 @@ from backend.app.admin.schema.wx import WxLoginRequest, TokenResponse, UserInfo,
from backend.app.admin.service.wx_service import wx_service
from backend.common.response.response_schema import response_base, ResponseSchemaModel
from backend.common.security.jwt import wx_openid_authentication, create_access_token, create_refresh_token
from backend.app.admin.schema.token import BasicWxUserInfo, WxSettings, PointsBrief
from backend.app.admin.service.points_service import points_service
from backend.utils.timezone import timezone
from backend.core.conf import settings
from backend.core.wx_integration import verify_wx_code
router = APIRouter()
@@ -54,6 +59,60 @@ async def wechat_login(
return response_base.success(data=result)
@router.get("/user", summary="获取当前用户(云托管识别)")
async def get_current_user(request: Request, response: Response) -> ResponseSchemaModel[GetWxLoginToken]:
openid = request.headers.get("x-wx-openid") or request.headers.get("X-WX-OPENID")
unionid = request.headers.get("x-wx-unionid") or request.headers.get("X-WX-UNIONID")
if not openid:
raise HTTPException(status_code=401, detail="未识别到云托管身份")
user = await wx_openid_authentication(openid, unionid or "")
access = await create_access_token(user_id=user.id, multi_login=True)
refresh = await create_refresh_token(access.session_uuid, user.id, True)
response.set_cookie(
key=settings.COOKIE_REFRESH_TOKEN_KEY,
value=refresh.refresh_token,
max_age=settings.COOKIE_REFRESH_TOKEN_EXPIRE_SECONDS,
expires=timezone.to_utc(refresh.refresh_token_expire_time),
httponly=True,
)
profile = getattr(user, "profile", {}) or {}
dict_level = profile.get("dict_level")
dict_category = profile.get("dict_category")
basic_user = BasicWxUserInfo(
id=user.id,
nickname=profile.get("nickname"),
avatar_url=profile.get("avatar_url"),
gender=profile.get("gender"),
country=profile.get("country"),
province=profile.get("province"),
city=profile.get("city"),
language=profile.get("language"),
)
settings_obj = WxSettings(dict_level=dict_level)
pts = await points_service.get_user_points(user.id)
points_brief = PointsBrief(
balance=pts.balance if pts else 0,
expired_time=pts.expired_time.isoformat() if pts and getattr(pts, "expired_time", None) else None,
)
data = GetWxLoginToken(
access_token=access.access_token,
access_token_expire_time=access.access_token_expire_time,
session_uuid=access.session_uuid,
dict_level=dict_level,
dict_category=dict_category,
user=basic_user,
settings=settings_obj,
points=points_brief,
server_time=int(timezone.to_utc(timezone.now()).timestamp()),
)
return response_base.success(data=data)
# @router.put("/settings", summary="更新用户设置", dependencies=[DependsJwtAuth])
# async def update_user_settings(
# request: Request,
@@ -105,4 +164,4 @@ async def wechat_login(
# dict_level=dict_level
# )
#
# return response_base.success(data=response_data)
# return response_base.success(data=response_data)

View File

@@ -53,3 +53,33 @@ class FileUploadResponse(SchemaBase):
file_name: str
content_type: Optional[str] = None
file_size: int
class UploadInitRequest(BaseModel):
filename: str
size: int
mime: Optional[str] = None
biz_type: Optional[str] = None
class UploadInitResponse(SchemaBase):
file_id: str
cloud_path: str
class UploadCompleteRequest(BaseModel):
file_id: int
cloud_path: str
fileID: str
size: Optional[int] = None
sha256: Optional[str] = None
mime: Optional[str] = None
class UploadCompleteResponse(SchemaBase):
id: str
url: Optional[str] = None
status: str

View File

@@ -40,6 +40,30 @@ class GetWxLoginToken(AccessTokenBase):
"""微信登录令牌"""
dict_level: Optional[str] = Field(None, description="词典等级")
dict_category: Optional[str] = Field(None, description="词典类型")
user: Optional["BasicWxUserInfo"] = Field(None, description="用户基础信息")
settings: Optional["WxSettings"] = Field(None, description="用户设置")
points: Optional["PointsBrief"] = Field(None, description="积分简要信息")
server_time: Optional[int] = Field(None, description="服务端当前时间戳(秒)")
class BasicWxUserInfo(SchemaBase):
id: int = Field(description="用户ID")
nickname: Optional[str] = Field(None, description="昵称")
avatar_url: Optional[str] = Field(None, description="头像URL")
gender: Optional[str] = Field(None, description="性别")
country: Optional[str] = Field(None, description="国家")
province: Optional[str] = Field(None, description="省份")
city: Optional[str] = Field(None, description="城市")
language: Optional[str] = Field(None, description="语言")
class WxSettings(SchemaBase):
dict_level: Optional[str] = Field(None, description="词典等级")
class PointsBrief(SchemaBase):
balance: int = Field(description="积分余额")
expired_time: Optional[str] = Field(None, description="积分过期时间(ISO字符串)")
class GetTokenDetail(SchemaBase):
@@ -55,4 +79,4 @@ class GetTokenDetail(SchemaBase):
device: str = Field(description='设备')
status: StatusType = Field(description='状态')
last_login_time: str = Field(description='最后登录时间')
expire_time: datetime = Field(description='过期时间')
expire_time: datetime = Field(description='过期时间')

View File

@@ -689,4 +689,213 @@ class FileService:
except Exception:
return None
@staticmethod
def _mime_to_ext(mime: str | None, filename: str | None) -> str:
if filename and '.' in filename:
return filename.rsplit('.', 1)[-1].lower()
mapping = {
'image/jpeg': 'jpg',
'image/jpg': 'jpg',
'image/png': 'png',
'image/gif': 'gif',
'image/webp': 'webp',
'image/bmp': 'bmp',
'image/svg+xml': 'svg',
'audio/mpeg': 'mp3',
'audio/wav': 'wav',
'video/mp4': 'mp4',
'video/quicktime': 'mov',
}
return mapping.get((mime or '').lower(), 'dat')
@staticmethod
async def init_upload(filename: str, size: int, mime: str | None, biz_type: str | None, wx_user_id: int):
storage_type = 'cos'
async with async_db_session.begin() as db:
temp_hash = hashlib.sha256(f"pending:{filename}:{size}".encode()).hexdigest()
params = AddFileParam(
file_hash=temp_hash,
file_name=filename or 'unnamed',
content_type=mime,
file_size=size,
storage_type=storage_type,
storage_path=None,
metadata_info=FileMetadata(
file_name=filename or 'unnamed',
content_type=mime,
file_size=size,
extra={"biz_type": biz_type} if biz_type else None,
),
)
db_file = await file_dao.create(db, params)
await db.flush()
cos_client = CosClient()
ext = FileService._mime_to_ext(mime, filename)
cloud_path = cos_client.build_key(db_file.id)
await file_dao.update(
db,
db_file.id,
UpdateFileParam(
storage_path=cloud_path,
details={"status": "pending", "cloud_path": cloud_path, "wx_user_id": wx_user_id},
),
)
db_file.storage_path = cloud_path
return {
"file_id": str(db_file.id),
"cloud_path": cloud_path,
}
@staticmethod
async def complete_upload(
file_id: int,
cloud_path: str,
file_id_in_cos: str,
sha256: str | None = None,
size: int | None = None,
mime: str | None = None,
wx_user_id: int = 0,
):
cos = CosClient()
cos_key = cloud_path
if not cos.object_exists(cos_key):
raise errors.NotFoundError(msg="对象未找到或尚未上传")
head = cos.head_object(cos_key) or {}
content_length = head.get('Content-Length') or head.get('ContentLength')
try:
content_length = int(content_length) if content_length is not None else (size or 0)
except:
content_length = size or 0
content_type = head.get('Content-Type') or head.get('ContentType') or mime
etag = (head.get('ETag') or '').strip('"')
file_hash_val = sha256 or etag or hashlib.sha256(f"cos:{cloud_path}:{file_id_in_cos}".encode()).hexdigest()
is_image = (content_type or '').lower().startswith('image/') or (mime or '').lower().startswith('image/')
if is_image:
avif_key = f"{file_id}_avif"
pic_ops = {
"is_pic_info": 1,
"rules": [
{
"fileid": avif_key,
"rule": "imageMogr2/format/avif",
}
],
}
resp = cos.process_image(cos_key, pic_ops)
data = resp[1] if isinstance(resp, (list, tuple)) and len(resp) >= 2 else (resp if isinstance(resp, dict) else {})
process_results = data.get("ProcessResults", {}) or {}
obj = process_results.get("Object", {}) or {}
final_key = obj.get("Key") or avif_key
size_str = obj.get("Size")
try:
final_size = int(size_str) if isinstance(size_str, str) else (size_str or content_length or 0)
except:
final_size = content_length or 0
final_content_type = "image/avif"
expired_seconds = 30 * 24 * 60 * 60
url = cos.get_presigned_download_url(final_key, expired_seconds)
from datetime import datetime, timezone as dt_tz
now_ts = int(datetime.now(dt_tz.utc).timestamp())
expire_ts = now_ts + expired_seconds - 60
async with async_db_session.begin() as db:
meta = FileMetadata(
file_name=None,
content_type=final_content_type,
file_size=final_size,
extra=None,
)
await file_dao.update(
db,
file_id,
UpdateFileParam(
file_hash=file_hash_val,
storage_path=final_key,
metadata_info=meta,
details={
"status": "stored",
"cloud_path": cloud_path,
"fileID": file_id_in_cos,
"download_url": url,
"download_url_expire_ts": expire_ts,
"wx_user_id": wx_user_id,
},
),
)
return {
"id": str(file_id),
"url": url,
"status": "stored",
}
else:
expired_seconds = 30 * 24 * 60 * 60
url = cos.get_presigned_download_url(cos_key, expired_seconds)
from datetime import datetime, timezone as dt_tz
now_ts = int(datetime.now(dt_tz.utc).timestamp())
expire_ts = now_ts + expired_seconds - 60
async with async_db_session.begin() as db:
meta = FileMetadata(
file_name=None,
content_type=content_type,
file_size=content_length,
extra=None,
)
await file_dao.update(
db,
file_id,
UpdateFileParam(
file_hash=file_hash_val,
storage_path=cloud_path,
metadata_info=meta,
details={
"status": "stored",
"cloud_path": cloud_path,
"fileID": file_id_in_cos,
"download_url": url,
"download_url_expire_ts": expire_ts,
"wx_user_id": wx_user_id,
},
),
)
return {
"id": str(file_id),
"url": url,
"status": "stored",
}
@staticmethod
async def get_presigned_download_url(file_id: int, wx_user_id: int) -> str:
async with async_db_session() as db:
db_file = await file_dao.get(db, file_id)
if not db_file:
raise errors.NotFoundError(msg="文件不存在")
details = db_file.details or {}
owner_id = details.get("wx_user_id")
if owner_id is not None and str(owner_id) != str(wx_user_id):
raise errors.ForbiddenError(msg="无权限访问该文件")
cloud_path = db_file.storage_path or details.get("cloud_path")
if not cloud_path:
raise errors.ServerError(msg="文件路径缺失")
cos = CosClient()
cos_key = cloud_path
url = details.get("download_url")
expire_ts = int(details.get("download_url_expire_ts") or 0)
from datetime import datetime, timezone as dt_tz
now_ts = int(datetime.now(dt_tz.utc).timestamp())
if (not url) or (now_ts >= expire_ts):
expired_seconds = 30 * 24 * 60 * 60
url = cos.get_presigned_download_url(cos_key, expired_seconds)
expire_ts = now_ts + expired_seconds - 60
async with async_db_session.begin() as wdb:
await file_dao.update(
wdb,
file_id,
UpdateFileParam(details={
**details,
"download_url": url,
"download_url_expire_ts": expire_ts,
})
)
return url
file_service = FileService()

View File

@@ -17,6 +17,7 @@ from backend.app.admin.service.wx_user_service import WxUserService
from backend.core.wx_integration import decrypt_wx_data
from backend.database.db import async_db_session
from backend.utils.timezone import timezone
from backend.app.admin.schema.token import BasicWxUserInfo, WxSettings, PointsBrief
class WxAuthService:
@@ -104,7 +105,23 @@ class WxAuthService:
access_token_expire_time=access_token.access_token_expire_time,
session_uuid=access_token.session_uuid,
dict_level=dict_level,
dict_category=dict_category
dict_category=dict_category,
user=BasicWxUserInfo(
id=user.id,
nickname=(user.profile or {}).get("nickname") if isinstance(user.profile, dict) else None,
avatar_url=(user.profile or {}).get("avatar_url") if isinstance(user.profile, dict) else None,
gender=(user.profile or {}).get("gender") if isinstance(user.profile, dict) else None,
country=(user.profile or {}).get("country") if isinstance(user.profile, dict) else None,
province=(user.profile or {}).get("province") if isinstance(user.profile, dict) else None,
city=(user.profile or {}).get("city") if isinstance(user.profile, dict) else None,
language=(user.profile or {}).get("language") if isinstance(user.profile, dict) else None,
),
settings=WxSettings(dict_level=dict_level),
points=PointsBrief(
balance=(await points_service.get_user_balance(user.id)),
expired_time=None,
),
server_time=int(timezone.to_utc(timezone.now()).timestamp()),
)
return data
@@ -138,4 +155,4 @@ class WxAuthService:
await wx_user_dao.update_user_profile(db, user_id, user.profile)
wx_service: WxAuthService = WxAuthService()
wx_service: WxAuthService = WxAuthService()

View File

@@ -131,28 +131,28 @@ async def get_image(
)
return response_base.success(data=result)
@router.get("/log")
def log(request: Request, background_tasks: BackgroundTasks) -> ResponseSchemaModel[dict]:
audit_log = CreateAuditLogParam(
api_type="test_api",
model_name="test_model",
response_data={"test": "test_response"},
called_at=datetime.now(),
request_data={"test": "test_request"},
token_usage={"usage": "test_usage"},
cost=0.0,
duration=0.0,
status_code=200,
error_message="",
image_id=123123,
user_id=321312,
api_version="test_version",
)
background_tasks.add_task(
audit_log_service.create,
obj=audit_log,
)
return response_base.success(data={"succ": True})
# @router.get("/log")
# def log(request: Request, background_tasks: BackgroundTasks) -> ResponseSchemaModel[dict]:
# audit_log = CreateAuditLogParam(
# api_type="test_api",
# model_name="test_model",
# response_data={"test": "test_response"},
# called_at=datetime.now(),
# request_data={"test": "test_request"},
# token_usage={"usage": "test_usage"},
# cost=0.0,
# duration=0.0,
# status_code=200,
# error_message="",
# image_id=123123,
# user_id=321312,
# api_version="test_version",
# )
# background_tasks.add_task(
# audit_log_service.create,
# obj=audit_log,
# )
# return response_base.success(data={"succ": True})
# @router.get("/download/{image_id}")
# async def download_image(image_id: int):

View File

@@ -88,6 +88,6 @@ class UpdateImageParam(ImageInfoSchemaBase):
class ProcessImageRequest(BaseModel):
file_id: int
file_id: str
type: str = "word"
dict_level: Optional[DictLevel] = Field(None, description="词典等级")

View File

@@ -189,7 +189,7 @@ class ImageService:
"""异步处理图片识别 - 立即返回任务ID"""
current_user = request.user
file_id = params.file_id
file_id = int(params.file_id)
type = params.type
dict_level = params.dict_level.name
if not dict_level:

View File

@@ -33,7 +33,7 @@ from backend.utils.timezone import timezone
async def wx_openid_authentication(openid: str, unionid: str) -> GetWxUserInfoWithRelationDetail:
from backend.app.admin.crud.wx_user_crud import wx_user_dao
async with async_db_session() as db:
async with async_db_session.begin() as db:
user = None
try:
# 查找或创建用户
@@ -46,7 +46,7 @@ async def wx_openid_authentication(openid: str, unionid: str) -> GetWxUserInfoWi
session_key=session_key,
profile={
'dict_level': DictLevel.LEVEL1.value,
'dict_category': DictCategory.GENERAL.value
'dict_category': DictCategory.GENERAL.value
},
)
await wx_user_dao.add(db, user)
@@ -68,6 +68,8 @@ class CustomHTTPBearer(HTTPBearer):
async def __call__(self, request: Request) -> HTTPAuthorizationCredentials | None:
# Check for x-wx-openid header first (WeChat Cloud Hosting authentication)
wx_openid = request.headers.get('x-wx-openid')
# print(request.headers)
# print(wx_openid)
if not wx_openid:
wx_openid = request.headers.get('X-WX-OPENID')
@@ -81,38 +83,26 @@ class CustomHTTPBearer(HTTPBearer):
# Check if we have a cached token for this openid
cached_token = await redis_client.get(f'wx_openid_token:{wx_openid}')
if cached_token:
# Verify the cached token is still valid
# Verify the cached token is still valid via standard flow
try:
payload = jwt.decode(
cached_token,
settings.TOKEN_SECRET_KEY,
algorithms=[settings.TOKEN_ALGORITHM],
options={'verify_exp': True},
)
# If token is valid, get user and attach to request
user = await jwt_authentication(cached_token)
await jwt_authentication(cached_token)
return HTTPAuthorizationCredentials(scheme="Bearer", credentials=cached_token)
except Exception:
# If token is invalid, remove it from cache
# If token is invalid, remove it from cache and recreate below
await redis_client.delete(f'wx_openid_token:{wx_openid}')
# If no cached token or invalid token, authenticate the user
user = await wx_openid_authentication(wx_openid, wx_unionid)
if user:
# Attach user to request
# Create a new token
temp_token = jwt_encode({
'session_uuid': str(uuid4()),
'exp': timezone.to_utc(timezone.now() + timedelta(seconds=settings.TOKEN_EXPIRE_SECONDS)).timestamp(),
'sub': str(user.id),
})
# Cache the token for this openid
# Create a new token using unified storage
access = await create_access_token(user_id=user.id, multi_login=True)
# Cache the token for this openid mapping
await redis_client.setex(
f'wx_openid_token:{wx_openid}',
settings.TOKEN_EXPIRE_SECONDS,
temp_token,
access.access_token,
)
return HTTPAuthorizationCredentials(scheme="Bearer", credentials=temp_token)
return HTTPAuthorizationCredentials(scheme="Bearer", credentials=access.access_token)
except Exception:
# If WeChat authentication fails, continue with normal JWT authentication
pass

View File

@@ -95,6 +95,32 @@ class CosClient:
logger.error(f"cos object_url failed: {e}")
raise errors.ServerError(msg="COS get url failed")
def head_object(self, key: str) -> dict | None:
try:
resp = self.client.head_object(Bucket=self.bucket, Key=key)
return resp
except Exception as e:
logger.error(f"cos head_object failed: {e}")
return None
def object_exists(self, key: str) -> bool:
try:
resp = self.client.head_object(Bucket=self.bucket, Key=key)
return resp is not None
except Exception:
return False
def get_presigned_download_url(self, key: str, expired_seconds: int) -> str:
try:
return self.client.get_presigned_download_url(
Bucket=self.bucket,
Key=key,
Expired=expired_seconds,
)
except Exception as e:
logger.error(f"cos presigned_download_url failed: {e}")
raise errors.ServerError(msg="COS presigned url failed")
def process_image(self, key: str, pic_operations: dict | str):
try:
ops = pic_operations if isinstance(pic_operations, str) else json.dumps(pic_operations)