fix auth
This commit is contained in:
@@ -1,6 +1,13 @@
|
||||
from fastapi import APIRouter, UploadFile, File, Query, Depends, Response
|
||||
from fastapi import APIRouter, UploadFile, File, Query, Depends, Response, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.app.admin.schema.file import FileUploadResponse
|
||||
from backend.app.admin.schema.file import (
|
||||
FileUploadResponse,
|
||||
UploadInitRequest,
|
||||
UploadInitResponse,
|
||||
UploadCompleteRequest,
|
||||
UploadCompleteResponse,
|
||||
)
|
||||
from backend.app.admin.service.file_service import file_service
|
||||
from backend.common.response.response_schema import response_base, ResponseSchemaModel
|
||||
from backend.common.security.jwt import DependsJwtAuth
|
||||
@@ -49,3 +56,34 @@ async def delete_file(file_id: int) -> ResponseSchemaModel[bool]:
|
||||
if not result:
|
||||
return await response_base.fail(message="文件不存在或删除失败")
|
||||
return await response_base.success(data=result)
|
||||
|
||||
|
||||
@router.post("/upload/init", summary="初始化对象存储上传", dependencies=[DependsJwtAuth])
|
||||
async def upload_init(body: UploadInitRequest, request: Request) -> ResponseSchemaModel[UploadInitResponse]:
|
||||
user_id = getattr(request, 'user', None).id if hasattr(request, 'user') and request.user else None
|
||||
d = await file_service.init_upload(body.filename, body.size, body.mime, body.biz_type, user_id)
|
||||
data = UploadInitResponse(**d)
|
||||
return response_base.success(data=data)
|
||||
|
||||
|
||||
@router.post("/upload/complete", summary="完成对象存储上传", dependencies=[DependsJwtAuth])
|
||||
async def upload_complete(body: UploadCompleteRequest, request: Request) -> ResponseSchemaModel[UploadCompleteResponse]:
|
||||
user_id = getattr(request, 'user', None).id if hasattr(request, 'user') and request.user else None
|
||||
d = await file_service.complete_upload(
|
||||
file_id=body.file_id,
|
||||
cloud_path=body.cloud_path,
|
||||
file_id_in_cos=body.fileID,
|
||||
sha256=body.sha256,
|
||||
size=body.size,
|
||||
mime=body.mime,
|
||||
wx_user_id=user_id,
|
||||
)
|
||||
data = UploadCompleteResponse(**d)
|
||||
return response_base.success(data=data)
|
||||
|
||||
|
||||
@router.get("/temp_url/{file_id}", summary="获取临时下载URL", dependencies=[DependsJwtAuth])
|
||||
async def get_temp_url(file_id: int, request: Request) -> ResponseSchemaModel[dict]:
|
||||
user_id = getattr(request, 'user', None).id if hasattr(request, 'user') and request.user else None
|
||||
url = await file_service.get_presigned_download_url(file_id, user_id)
|
||||
return response_base.success(data={"url": url})
|
||||
|
||||
@@ -8,6 +8,11 @@ from backend.app.admin.schema.wx import WxLoginRequest, TokenResponse, UserInfo,
|
||||
|
||||
from backend.app.admin.service.wx_service import wx_service
|
||||
from backend.common.response.response_schema import response_base, ResponseSchemaModel
|
||||
from backend.common.security.jwt import wx_openid_authentication, create_access_token, create_refresh_token
|
||||
from backend.app.admin.schema.token import BasicWxUserInfo, WxSettings, PointsBrief
|
||||
from backend.app.admin.service.points_service import points_service
|
||||
from backend.utils.timezone import timezone
|
||||
from backend.core.conf import settings
|
||||
from backend.core.wx_integration import verify_wx_code
|
||||
|
||||
router = APIRouter()
|
||||
@@ -54,6 +59,60 @@ async def wechat_login(
|
||||
return response_base.success(data=result)
|
||||
|
||||
|
||||
@router.get("/user", summary="获取当前用户(云托管识别)")
|
||||
async def get_current_user(request: Request, response: Response) -> ResponseSchemaModel[GetWxLoginToken]:
|
||||
openid = request.headers.get("x-wx-openid") or request.headers.get("X-WX-OPENID")
|
||||
unionid = request.headers.get("x-wx-unionid") or request.headers.get("X-WX-UNIONID")
|
||||
if not openid:
|
||||
raise HTTPException(status_code=401, detail="未识别到云托管身份")
|
||||
|
||||
user = await wx_openid_authentication(openid, unionid or "")
|
||||
|
||||
access = await create_access_token(user_id=user.id, multi_login=True)
|
||||
refresh = await create_refresh_token(access.session_uuid, user.id, True)
|
||||
response.set_cookie(
|
||||
key=settings.COOKIE_REFRESH_TOKEN_KEY,
|
||||
value=refresh.refresh_token,
|
||||
max_age=settings.COOKIE_REFRESH_TOKEN_EXPIRE_SECONDS,
|
||||
expires=timezone.to_utc(refresh.refresh_token_expire_time),
|
||||
httponly=True,
|
||||
)
|
||||
|
||||
profile = getattr(user, "profile", {}) or {}
|
||||
dict_level = profile.get("dict_level")
|
||||
dict_category = profile.get("dict_category")
|
||||
basic_user = BasicWxUserInfo(
|
||||
id=user.id,
|
||||
nickname=profile.get("nickname"),
|
||||
avatar_url=profile.get("avatar_url"),
|
||||
gender=profile.get("gender"),
|
||||
country=profile.get("country"),
|
||||
province=profile.get("province"),
|
||||
city=profile.get("city"),
|
||||
language=profile.get("language"),
|
||||
)
|
||||
settings_obj = WxSettings(dict_level=dict_level)
|
||||
|
||||
pts = await points_service.get_user_points(user.id)
|
||||
points_brief = PointsBrief(
|
||||
balance=pts.balance if pts else 0,
|
||||
expired_time=pts.expired_time.isoformat() if pts and getattr(pts, "expired_time", None) else None,
|
||||
)
|
||||
|
||||
data = GetWxLoginToken(
|
||||
access_token=access.access_token,
|
||||
access_token_expire_time=access.access_token_expire_time,
|
||||
session_uuid=access.session_uuid,
|
||||
dict_level=dict_level,
|
||||
dict_category=dict_category,
|
||||
user=basic_user,
|
||||
settings=settings_obj,
|
||||
points=points_brief,
|
||||
server_time=int(timezone.to_utc(timezone.now()).timestamp()),
|
||||
)
|
||||
return response_base.success(data=data)
|
||||
|
||||
|
||||
# @router.put("/settings", summary="更新用户设置", dependencies=[DependsJwtAuth])
|
||||
# async def update_user_settings(
|
||||
# request: Request,
|
||||
@@ -105,4 +164,4 @@ async def wechat_login(
|
||||
# dict_level=dict_level
|
||||
# )
|
||||
#
|
||||
# return response_base.success(data=response_data)
|
||||
# return response_base.success(data=response_data)
|
||||
|
||||
@@ -53,3 +53,33 @@ class FileUploadResponse(SchemaBase):
|
||||
file_name: str
|
||||
content_type: Optional[str] = None
|
||||
file_size: int
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class UploadInitRequest(BaseModel):
|
||||
filename: str
|
||||
size: int
|
||||
mime: Optional[str] = None
|
||||
biz_type: Optional[str] = None
|
||||
|
||||
|
||||
class UploadInitResponse(SchemaBase):
|
||||
file_id: str
|
||||
cloud_path: str
|
||||
|
||||
|
||||
class UploadCompleteRequest(BaseModel):
|
||||
file_id: int
|
||||
cloud_path: str
|
||||
fileID: str
|
||||
size: Optional[int] = None
|
||||
sha256: Optional[str] = None
|
||||
mime: Optional[str] = None
|
||||
|
||||
|
||||
class UploadCompleteResponse(SchemaBase):
|
||||
id: str
|
||||
url: Optional[str] = None
|
||||
status: str
|
||||
|
||||
@@ -40,6 +40,30 @@ class GetWxLoginToken(AccessTokenBase):
|
||||
"""微信登录令牌"""
|
||||
dict_level: Optional[str] = Field(None, description="词典等级")
|
||||
dict_category: Optional[str] = Field(None, description="词典类型")
|
||||
user: Optional["BasicWxUserInfo"] = Field(None, description="用户基础信息")
|
||||
settings: Optional["WxSettings"] = Field(None, description="用户设置")
|
||||
points: Optional["PointsBrief"] = Field(None, description="积分简要信息")
|
||||
server_time: Optional[int] = Field(None, description="服务端当前时间戳(秒)")
|
||||
|
||||
|
||||
class BasicWxUserInfo(SchemaBase):
|
||||
id: int = Field(description="用户ID")
|
||||
nickname: Optional[str] = Field(None, description="昵称")
|
||||
avatar_url: Optional[str] = Field(None, description="头像URL")
|
||||
gender: Optional[str] = Field(None, description="性别")
|
||||
country: Optional[str] = Field(None, description="国家")
|
||||
province: Optional[str] = Field(None, description="省份")
|
||||
city: Optional[str] = Field(None, description="城市")
|
||||
language: Optional[str] = Field(None, description="语言")
|
||||
|
||||
|
||||
class WxSettings(SchemaBase):
|
||||
dict_level: Optional[str] = Field(None, description="词典等级")
|
||||
|
||||
|
||||
class PointsBrief(SchemaBase):
|
||||
balance: int = Field(description="积分余额")
|
||||
expired_time: Optional[str] = Field(None, description="积分过期时间(ISO字符串)")
|
||||
|
||||
|
||||
class GetTokenDetail(SchemaBase):
|
||||
@@ -55,4 +79,4 @@ class GetTokenDetail(SchemaBase):
|
||||
device: str = Field(description='设备')
|
||||
status: StatusType = Field(description='状态')
|
||||
last_login_time: str = Field(description='最后登录时间')
|
||||
expire_time: datetime = Field(description='过期时间')
|
||||
expire_time: datetime = Field(description='过期时间')
|
||||
|
||||
@@ -689,4 +689,213 @@ class FileService:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _mime_to_ext(mime: str | None, filename: str | None) -> str:
|
||||
if filename and '.' in filename:
|
||||
return filename.rsplit('.', 1)[-1].lower()
|
||||
mapping = {
|
||||
'image/jpeg': 'jpg',
|
||||
'image/jpg': 'jpg',
|
||||
'image/png': 'png',
|
||||
'image/gif': 'gif',
|
||||
'image/webp': 'webp',
|
||||
'image/bmp': 'bmp',
|
||||
'image/svg+xml': 'svg',
|
||||
'audio/mpeg': 'mp3',
|
||||
'audio/wav': 'wav',
|
||||
'video/mp4': 'mp4',
|
||||
'video/quicktime': 'mov',
|
||||
}
|
||||
return mapping.get((mime or '').lower(), 'dat')
|
||||
|
||||
@staticmethod
|
||||
async def init_upload(filename: str, size: int, mime: str | None, biz_type: str | None, wx_user_id: int):
|
||||
storage_type = 'cos'
|
||||
async with async_db_session.begin() as db:
|
||||
temp_hash = hashlib.sha256(f"pending:{filename}:{size}".encode()).hexdigest()
|
||||
params = AddFileParam(
|
||||
file_hash=temp_hash,
|
||||
file_name=filename or 'unnamed',
|
||||
content_type=mime,
|
||||
file_size=size,
|
||||
storage_type=storage_type,
|
||||
storage_path=None,
|
||||
metadata_info=FileMetadata(
|
||||
file_name=filename or 'unnamed',
|
||||
content_type=mime,
|
||||
file_size=size,
|
||||
extra={"biz_type": biz_type} if biz_type else None,
|
||||
),
|
||||
)
|
||||
db_file = await file_dao.create(db, params)
|
||||
await db.flush()
|
||||
cos_client = CosClient()
|
||||
ext = FileService._mime_to_ext(mime, filename)
|
||||
cloud_path = cos_client.build_key(db_file.id)
|
||||
await file_dao.update(
|
||||
db,
|
||||
db_file.id,
|
||||
UpdateFileParam(
|
||||
storage_path=cloud_path,
|
||||
details={"status": "pending", "cloud_path": cloud_path, "wx_user_id": wx_user_id},
|
||||
),
|
||||
)
|
||||
db_file.storage_path = cloud_path
|
||||
return {
|
||||
"file_id": str(db_file.id),
|
||||
"cloud_path": cloud_path,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def complete_upload(
|
||||
file_id: int,
|
||||
cloud_path: str,
|
||||
file_id_in_cos: str,
|
||||
sha256: str | None = None,
|
||||
size: int | None = None,
|
||||
mime: str | None = None,
|
||||
wx_user_id: int = 0,
|
||||
):
|
||||
cos = CosClient()
|
||||
cos_key = cloud_path
|
||||
if not cos.object_exists(cos_key):
|
||||
raise errors.NotFoundError(msg="对象未找到或尚未上传")
|
||||
head = cos.head_object(cos_key) or {}
|
||||
content_length = head.get('Content-Length') or head.get('ContentLength')
|
||||
try:
|
||||
content_length = int(content_length) if content_length is not None else (size or 0)
|
||||
except:
|
||||
content_length = size or 0
|
||||
content_type = head.get('Content-Type') or head.get('ContentType') or mime
|
||||
etag = (head.get('ETag') or '').strip('"')
|
||||
file_hash_val = sha256 or etag or hashlib.sha256(f"cos:{cloud_path}:{file_id_in_cos}".encode()).hexdigest()
|
||||
is_image = (content_type or '').lower().startswith('image/') or (mime or '').lower().startswith('image/')
|
||||
if is_image:
|
||||
avif_key = f"{file_id}_avif"
|
||||
pic_ops = {
|
||||
"is_pic_info": 1,
|
||||
"rules": [
|
||||
{
|
||||
"fileid": avif_key,
|
||||
"rule": "imageMogr2/format/avif",
|
||||
}
|
||||
],
|
||||
}
|
||||
resp = cos.process_image(cos_key, pic_ops)
|
||||
data = resp[1] if isinstance(resp, (list, tuple)) and len(resp) >= 2 else (resp if isinstance(resp, dict) else {})
|
||||
process_results = data.get("ProcessResults", {}) or {}
|
||||
obj = process_results.get("Object", {}) or {}
|
||||
final_key = obj.get("Key") or avif_key
|
||||
size_str = obj.get("Size")
|
||||
try:
|
||||
final_size = int(size_str) if isinstance(size_str, str) else (size_str or content_length or 0)
|
||||
except:
|
||||
final_size = content_length or 0
|
||||
final_content_type = "image/avif"
|
||||
expired_seconds = 30 * 24 * 60 * 60
|
||||
url = cos.get_presigned_download_url(final_key, expired_seconds)
|
||||
from datetime import datetime, timezone as dt_tz
|
||||
now_ts = int(datetime.now(dt_tz.utc).timestamp())
|
||||
expire_ts = now_ts + expired_seconds - 60
|
||||
async with async_db_session.begin() as db:
|
||||
meta = FileMetadata(
|
||||
file_name=None,
|
||||
content_type=final_content_type,
|
||||
file_size=final_size,
|
||||
extra=None,
|
||||
)
|
||||
await file_dao.update(
|
||||
db,
|
||||
file_id,
|
||||
UpdateFileParam(
|
||||
file_hash=file_hash_val,
|
||||
storage_path=final_key,
|
||||
metadata_info=meta,
|
||||
details={
|
||||
"status": "stored",
|
||||
"cloud_path": cloud_path,
|
||||
"fileID": file_id_in_cos,
|
||||
"download_url": url,
|
||||
"download_url_expire_ts": expire_ts,
|
||||
"wx_user_id": wx_user_id,
|
||||
},
|
||||
),
|
||||
)
|
||||
return {
|
||||
"id": str(file_id),
|
||||
"url": url,
|
||||
"status": "stored",
|
||||
}
|
||||
else:
|
||||
expired_seconds = 30 * 24 * 60 * 60
|
||||
url = cos.get_presigned_download_url(cos_key, expired_seconds)
|
||||
from datetime import datetime, timezone as dt_tz
|
||||
now_ts = int(datetime.now(dt_tz.utc).timestamp())
|
||||
expire_ts = now_ts + expired_seconds - 60
|
||||
async with async_db_session.begin() as db:
|
||||
meta = FileMetadata(
|
||||
file_name=None,
|
||||
content_type=content_type,
|
||||
file_size=content_length,
|
||||
extra=None,
|
||||
)
|
||||
await file_dao.update(
|
||||
db,
|
||||
file_id,
|
||||
UpdateFileParam(
|
||||
file_hash=file_hash_val,
|
||||
storage_path=cloud_path,
|
||||
metadata_info=meta,
|
||||
details={
|
||||
"status": "stored",
|
||||
"cloud_path": cloud_path,
|
||||
"fileID": file_id_in_cos,
|
||||
"download_url": url,
|
||||
"download_url_expire_ts": expire_ts,
|
||||
"wx_user_id": wx_user_id,
|
||||
},
|
||||
),
|
||||
)
|
||||
return {
|
||||
"id": str(file_id),
|
||||
"url": url,
|
||||
"status": "stored",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def get_presigned_download_url(file_id: int, wx_user_id: int) -> str:
|
||||
async with async_db_session() as db:
|
||||
db_file = await file_dao.get(db, file_id)
|
||||
if not db_file:
|
||||
raise errors.NotFoundError(msg="文件不存在")
|
||||
details = db_file.details or {}
|
||||
owner_id = details.get("wx_user_id")
|
||||
if owner_id is not None and str(owner_id) != str(wx_user_id):
|
||||
raise errors.ForbiddenError(msg="无权限访问该文件")
|
||||
cloud_path = db_file.storage_path or details.get("cloud_path")
|
||||
if not cloud_path:
|
||||
raise errors.ServerError(msg="文件路径缺失")
|
||||
cos = CosClient()
|
||||
cos_key = cloud_path
|
||||
url = details.get("download_url")
|
||||
expire_ts = int(details.get("download_url_expire_ts") or 0)
|
||||
from datetime import datetime, timezone as dt_tz
|
||||
now_ts = int(datetime.now(dt_tz.utc).timestamp())
|
||||
if (not url) or (now_ts >= expire_ts):
|
||||
expired_seconds = 30 * 24 * 60 * 60
|
||||
url = cos.get_presigned_download_url(cos_key, expired_seconds)
|
||||
expire_ts = now_ts + expired_seconds - 60
|
||||
async with async_db_session.begin() as wdb:
|
||||
await file_dao.update(
|
||||
wdb,
|
||||
file_id,
|
||||
UpdateFileParam(details={
|
||||
**details,
|
||||
"download_url": url,
|
||||
"download_url_expire_ts": expire_ts,
|
||||
})
|
||||
)
|
||||
return url
|
||||
|
||||
file_service = FileService()
|
||||
|
||||
@@ -17,6 +17,7 @@ from backend.app.admin.service.wx_user_service import WxUserService
|
||||
from backend.core.wx_integration import decrypt_wx_data
|
||||
from backend.database.db import async_db_session
|
||||
from backend.utils.timezone import timezone
|
||||
from backend.app.admin.schema.token import BasicWxUserInfo, WxSettings, PointsBrief
|
||||
|
||||
|
||||
class WxAuthService:
|
||||
@@ -104,7 +105,23 @@ class WxAuthService:
|
||||
access_token_expire_time=access_token.access_token_expire_time,
|
||||
session_uuid=access_token.session_uuid,
|
||||
dict_level=dict_level,
|
||||
dict_category=dict_category
|
||||
dict_category=dict_category,
|
||||
user=BasicWxUserInfo(
|
||||
id=user.id,
|
||||
nickname=(user.profile or {}).get("nickname") if isinstance(user.profile, dict) else None,
|
||||
avatar_url=(user.profile or {}).get("avatar_url") if isinstance(user.profile, dict) else None,
|
||||
gender=(user.profile or {}).get("gender") if isinstance(user.profile, dict) else None,
|
||||
country=(user.profile or {}).get("country") if isinstance(user.profile, dict) else None,
|
||||
province=(user.profile or {}).get("province") if isinstance(user.profile, dict) else None,
|
||||
city=(user.profile or {}).get("city") if isinstance(user.profile, dict) else None,
|
||||
language=(user.profile or {}).get("language") if isinstance(user.profile, dict) else None,
|
||||
),
|
||||
settings=WxSettings(dict_level=dict_level),
|
||||
points=PointsBrief(
|
||||
balance=(await points_service.get_user_balance(user.id)),
|
||||
expired_time=None,
|
||||
),
|
||||
server_time=int(timezone.to_utc(timezone.now()).timestamp()),
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -138,4 +155,4 @@ class WxAuthService:
|
||||
await wx_user_dao.update_user_profile(db, user_id, user.profile)
|
||||
|
||||
|
||||
wx_service: WxAuthService = WxAuthService()
|
||||
wx_service: WxAuthService = WxAuthService()
|
||||
|
||||
@@ -131,28 +131,28 @@ async def get_image(
|
||||
)
|
||||
return response_base.success(data=result)
|
||||
|
||||
@router.get("/log")
|
||||
def log(request: Request, background_tasks: BackgroundTasks) -> ResponseSchemaModel[dict]:
|
||||
audit_log = CreateAuditLogParam(
|
||||
api_type="test_api",
|
||||
model_name="test_model",
|
||||
response_data={"test": "test_response"},
|
||||
called_at=datetime.now(),
|
||||
request_data={"test": "test_request"},
|
||||
token_usage={"usage": "test_usage"},
|
||||
cost=0.0,
|
||||
duration=0.0,
|
||||
status_code=200,
|
||||
error_message="",
|
||||
image_id=123123,
|
||||
user_id=321312,
|
||||
api_version="test_version",
|
||||
)
|
||||
background_tasks.add_task(
|
||||
audit_log_service.create,
|
||||
obj=audit_log,
|
||||
)
|
||||
return response_base.success(data={"succ": True})
|
||||
# @router.get("/log")
|
||||
# def log(request: Request, background_tasks: BackgroundTasks) -> ResponseSchemaModel[dict]:
|
||||
# audit_log = CreateAuditLogParam(
|
||||
# api_type="test_api",
|
||||
# model_name="test_model",
|
||||
# response_data={"test": "test_response"},
|
||||
# called_at=datetime.now(),
|
||||
# request_data={"test": "test_request"},
|
||||
# token_usage={"usage": "test_usage"},
|
||||
# cost=0.0,
|
||||
# duration=0.0,
|
||||
# status_code=200,
|
||||
# error_message="",
|
||||
# image_id=123123,
|
||||
# user_id=321312,
|
||||
# api_version="test_version",
|
||||
# )
|
||||
# background_tasks.add_task(
|
||||
# audit_log_service.create,
|
||||
# obj=audit_log,
|
||||
# )
|
||||
# return response_base.success(data={"succ": True})
|
||||
|
||||
# @router.get("/download/{image_id}")
|
||||
# async def download_image(image_id: int):
|
||||
|
||||
@@ -88,6 +88,6 @@ class UpdateImageParam(ImageInfoSchemaBase):
|
||||
|
||||
|
||||
class ProcessImageRequest(BaseModel):
|
||||
file_id: int
|
||||
file_id: str
|
||||
type: str = "word"
|
||||
dict_level: Optional[DictLevel] = Field(None, description="词典等级")
|
||||
@@ -189,7 +189,7 @@ class ImageService:
|
||||
"""异步处理图片识别 - 立即返回任务ID"""
|
||||
|
||||
current_user = request.user
|
||||
file_id = params.file_id
|
||||
file_id = int(params.file_id)
|
||||
type = params.type
|
||||
dict_level = params.dict_level.name
|
||||
if not dict_level:
|
||||
|
||||
@@ -33,7 +33,7 @@ from backend.utils.timezone import timezone
|
||||
|
||||
async def wx_openid_authentication(openid: str, unionid: str) -> GetWxUserInfoWithRelationDetail:
|
||||
from backend.app.admin.crud.wx_user_crud import wx_user_dao
|
||||
async with async_db_session() as db:
|
||||
async with async_db_session.begin() as db:
|
||||
user = None
|
||||
try:
|
||||
# 查找或创建用户
|
||||
@@ -46,7 +46,7 @@ async def wx_openid_authentication(openid: str, unionid: str) -> GetWxUserInfoWi
|
||||
session_key=session_key,
|
||||
profile={
|
||||
'dict_level': DictLevel.LEVEL1.value,
|
||||
'dict_category': DictCategory.GENERAL.value
|
||||
'dict_category': DictCategory.GENERAL.value
|
||||
},
|
||||
)
|
||||
await wx_user_dao.add(db, user)
|
||||
@@ -68,6 +68,8 @@ class CustomHTTPBearer(HTTPBearer):
|
||||
async def __call__(self, request: Request) -> HTTPAuthorizationCredentials | None:
|
||||
# Check for x-wx-openid header first (WeChat Cloud Hosting authentication)
|
||||
wx_openid = request.headers.get('x-wx-openid')
|
||||
# print(request.headers)
|
||||
# print(wx_openid)
|
||||
if not wx_openid:
|
||||
wx_openid = request.headers.get('X-WX-OPENID')
|
||||
|
||||
@@ -81,38 +83,26 @@ class CustomHTTPBearer(HTTPBearer):
|
||||
# Check if we have a cached token for this openid
|
||||
cached_token = await redis_client.get(f'wx_openid_token:{wx_openid}')
|
||||
if cached_token:
|
||||
# Verify the cached token is still valid
|
||||
# Verify the cached token is still valid via standard flow
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
cached_token,
|
||||
settings.TOKEN_SECRET_KEY,
|
||||
algorithms=[settings.TOKEN_ALGORITHM],
|
||||
options={'verify_exp': True},
|
||||
)
|
||||
# If token is valid, get user and attach to request
|
||||
user = await jwt_authentication(cached_token)
|
||||
await jwt_authentication(cached_token)
|
||||
return HTTPAuthorizationCredentials(scheme="Bearer", credentials=cached_token)
|
||||
except Exception:
|
||||
# If token is invalid, remove it from cache
|
||||
# If token is invalid, remove it from cache and recreate below
|
||||
await redis_client.delete(f'wx_openid_token:{wx_openid}')
|
||||
|
||||
# If no cached token or invalid token, authenticate the user
|
||||
user = await wx_openid_authentication(wx_openid, wx_unionid)
|
||||
if user:
|
||||
# Attach user to request
|
||||
# Create a new token
|
||||
temp_token = jwt_encode({
|
||||
'session_uuid': str(uuid4()),
|
||||
'exp': timezone.to_utc(timezone.now() + timedelta(seconds=settings.TOKEN_EXPIRE_SECONDS)).timestamp(),
|
||||
'sub': str(user.id),
|
||||
})
|
||||
# Cache the token for this openid
|
||||
# Create a new token using unified storage
|
||||
access = await create_access_token(user_id=user.id, multi_login=True)
|
||||
# Cache the token for this openid mapping
|
||||
await redis_client.setex(
|
||||
f'wx_openid_token:{wx_openid}',
|
||||
settings.TOKEN_EXPIRE_SECONDS,
|
||||
temp_token,
|
||||
access.access_token,
|
||||
)
|
||||
return HTTPAuthorizationCredentials(scheme="Bearer", credentials=temp_token)
|
||||
return HTTPAuthorizationCredentials(scheme="Bearer", credentials=access.access_token)
|
||||
except Exception:
|
||||
# If WeChat authentication fails, continue with normal JWT authentication
|
||||
pass
|
||||
|
||||
@@ -95,6 +95,32 @@ class CosClient:
|
||||
logger.error(f"cos object_url failed: {e}")
|
||||
raise errors.ServerError(msg="COS get url failed")
|
||||
|
||||
def head_object(self, key: str) -> dict | None:
|
||||
try:
|
||||
resp = self.client.head_object(Bucket=self.bucket, Key=key)
|
||||
return resp
|
||||
except Exception as e:
|
||||
logger.error(f"cos head_object failed: {e}")
|
||||
return None
|
||||
|
||||
def object_exists(self, key: str) -> bool:
|
||||
try:
|
||||
resp = self.client.head_object(Bucket=self.bucket, Key=key)
|
||||
return resp is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_presigned_download_url(self, key: str, expired_seconds: int) -> str:
|
||||
try:
|
||||
return self.client.get_presigned_download_url(
|
||||
Bucket=self.bucket,
|
||||
Key=key,
|
||||
Expired=expired_seconds,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"cos presigned_download_url failed: {e}")
|
||||
raise errors.ServerError(msg="COS presigned url failed")
|
||||
|
||||
def process_image(self, key: str, pic_operations: dict | str):
|
||||
try:
|
||||
ops = pic_operations if isinstance(pic_operations, str) else json.dumps(pic_operations)
|
||||
|
||||
Reference in New Issue
Block a user