from fastapi import FastAPI, File, UploadFile, HTTPException, Request, Depends, APIRouter
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
import requests
import os
import json
import base64
from io import BytesIO
from datetime import datetime
import hashlib
import hmac
import logging
import shutil
import contextlib
import time
from urllib.parse import quote
from huggingface_hub import HfApi

logger = logging.getLogger(__name__)

# Shared secret used to sign and verify ``auth_key`` tokens. When it is unset,
# verify_internal_token rejects every token instead of accepting tokens signed
# with the literal string "None".
API_AUTH_KEY = os.environ.get('AUTH_KEY')


def _sign(path, timestamp, rand, uid):
    """Return the hex digest for an ``auth_key`` token.

    The digest is MD5 over ``"{path}-{timestamp}-{rand}-{uid}-{API_AUTH_KEY}"``.
    The same scheme verifies incoming requests and signs the outgoing status
    callback, so it presumably must match the peer service -- do not change the
    hash or field order without updating that side as well.
    """
    payload = f"{path}-{timestamp}-{rand}-{uid}-{API_AUTH_KEY}"
    return hashlib.md5(payload.encode()).hexdigest()


async def verify_internal_token(request: Request):
    """Validate the ``auth_key`` query parameter of an incoming request.

    The token has the form ``"{expires_at}-{rand}-{uid}-{digest}"`` where
    ``expires_at`` is a Unix timestamp and ``digest`` is ``_sign`` applied to
    the request path and the first three fields.

    Raises:
        HTTPException: 403 if the key is missing or malformed, has expired,
            carries a wrong digest, or ``AUTH_KEY`` is not configured.
    """
    auth_key = request.query_params.get("auth_key")
    # Fail closed: without a configured secret every digest would be computed
    # with the string "None" and could be forged by anyone.
    if not auth_key or not API_AUTH_KEY:
        raise HTTPException(status_code=403, detail="Forbidden")

    parts = auth_key.split('-', 3)
    if len(parts) != 4:
        raise HTTPException(status_code=403, detail="Forbidden")
    timestampStr, rand, uid, hashStr = parts

    try:
        expires_at = int(timestampStr)
    except ValueError:
        raise HTTPException(status_code=403, detail="Forbidden") from None
    if int(time.time()) > expires_at:
        raise HTTPException(status_code=403, detail="Forbidden")

    expected = _sign(request.url.path, timestampStr, rand, uid)
    # Constant-time comparison so response timing does not leak the digest.
    if not hmac.compare_digest(expected.encode(), hashStr.encode()):
        raise HTTPException(status_code=403, detail="Forbidden")


router = APIRouter()

REPO_ID = "wsj1995/aigc-user-uploaded-models"
UPLOAD_DIRECTORY = "/tmp/uploaded_files"
HUGGINGFACE_API_TOKEN = os.environ.get("HF_TOKEN")

# (connect, read) timeouts in seconds for outbound HTTP calls. The read timeout
# limits the wait between received bytes, not the total transfer time.
HTTP_TIMEOUT = (10, 60)
# Chunk sizes used when streaming model files, which may be very large.
DOWNLOAD_CHUNK_SIZE = 64 * 1024
UPLOAD_COPY_CHUNK_SIZE = 1024 * 1024

# Create the local staging directory for uploads if it does not exist yet.
os.makedirs(UPLOAD_DIRECTORY, exist_ok=True)


def _require_safe_segments(*segments):
    """Reject path parameters that could escape the intended directory or repo path.

    Raises:
        HTTPException: 400 if any segment is empty, ``.``/``..``, or contains
            a path separator or NUL byte.
    """
    for segment in segments:
        if (
            segment in ("", ".", "..")
            or "/" in segment
            or "\\" in segment
            or "\x00" in segment
        ):
            raise HTTPException(status_code=400, detail="Invalid path parameter")


def download_file(url: str):
    """Open a streaming GET for a file stored in ``REPO_ID`` on the Hugging Face Hub.

    Args:
        url: Path of the file inside the repository, e.g. ``"user/model/file"``.

    Returns:
        The open ``requests.Response`` (``stream=True``). The caller is
        responsible for consuming and closing it.

    Raises:
        HTTPException: with the upstream status code if the Hub does not
            answer 200.
    """
    headers = {"Authorization": f"Bearer {HUGGINGFACE_API_TOKEN}"}
    # Percent-encode the repo path (keeping '/') so names with spaces, '#' or
    # '?' address the intended file.
    response = requests.get(
        f"https://huggingface.co/{REPO_ID}/resolve/main/{quote(url)}?download=true",
        headers=headers,
        stream=True,
        timeout=HTTP_TIMEOUT,
    )
    if response.status_code != 200:
        # Release the pooled connection before bailing out.
        response.close()
        raise HTTPException(status_code=response.status_code, detail="Failed to download file")
    return response


@router.get("/")
def read_root():
    return {"Hello": "World!"}


def _save_and_push(source, file_location, path_in_repo, modelVersionId):
    """Write the upload to disk, push it to the Hub, and report progress via ``callback``.

    This is blocking work and is meant to run in a worker thread. Failures are
    logged and reported with a ``'FAIL'`` callback (best-effort). The local
    copy is always removed afterwards.
    """
    try:
        with open(file_location, "wb") as buffer:
            # Copy in chunks; model files may be far larger than available RAM.
            shutil.copyfileobj(source, buffer, UPLOAD_COPY_CHUNK_SIZE)
        callback(modelVersionId, 'UPLOADING')
        huggingfaceApi = HfApi(token=HUGGINGFACE_API_TOKEN)
        huggingfaceApi.upload_file(
            path_or_fileobj=file_location,
            path_in_repo=path_in_repo,
            repo_id=REPO_ID,
            repo_type="model",
        )
        callback(modelVersionId, 'UPLOADED')
    except Exception:
        logger.exception("Upload of %s failed", path_in_repo)
        try:
            callback(modelVersionId, 'FAIL')
        except Exception:
            # Reporting is best-effort; do not let it mask the original error
            # or skip cleanup.
            logger.exception("FAIL callback for model version %s failed", modelVersionId)
    finally:
        # The file may not exist if opening it failed.
        with contextlib.suppress(FileNotFoundError):
            os.remove(file_location)


@router.post("/upload/{userId}/{modelId}/{modelVersionId}/{filename}")
async def upload_file(userId: str, modelId: str, modelVersionId: str, filename: str, file: UploadFile = File(...)):
    """Store an uploaded model file in the Hub repo under ``{userId}/{modelId}/{filename}``.

    The response is always ``{'success': True, 'id': modelVersionId}``. The
    actual outcome is reported through ``callback`` with the statuses
    ``'UPLOADING'``, ``'UPLOADED'`` or ``'FAIL'``.

    Raises:
        HTTPException: 400 if a path parameter is unsafe.
    """
    _require_safe_segments(userId, modelId, filename)
    file_location_folder = os.path.join(UPLOAD_DIRECTORY, userId, modelId)
    file_location = os.path.join(file_location_folder, filename)
    os.makedirs(file_location_folder, exist_ok=True)
    pathInRepo = f"{userId}/{modelId}/{filename}"
    # Disk copy, Hub upload and callbacks are blocking; keep them off the event loop.
    await run_in_threadpool(_save_and_push, file.file, file_location, pathInRepo, modelVersionId)
    return {'success': True, 'id': modelVersionId}


def _iter_and_close(response, chunk_size):
    """Yield the body of a streaming ``requests`` response, then close it.

    The ``finally`` runs when the generator is exhausted, closed, or
    garbage-collected.
    """
    try:
        yield from response.iter_content(chunk_size=chunk_size)
    finally:
        response.close()


def _content_disposition(filename):
    """Build an ``attachment`` Content-Disposition value for *filename*.

    Header values are encoded as latin-1 by Starlette, so a raw name with
    characters outside latin-1 (e.g. CJK) would fail. The RFC 6266
    ``filename*`` parameter carries the exact UTF-8 name, and ``filename``
    carries a printable-ASCII fallback.
    """
    fallback = "".join(
        ch if " " <= ch <= "~" and ch not in '"\\' else "_" for ch in filename
    )
    return f"attachment; filename=\"{fallback}\"; filename*=UTF-8''{quote(filename, safe='')}"


@router.get("/download/{userId}/{modelId}/{modelVersionId}/{filename}")
def download(userId: str, modelId: str, modelVersionId: str, filename: str):
    """Stream ``{userId}/{modelId}/{filename}`` from the Hub repo back to the client.

    ``modelVersionId`` is part of the route but is not used to locate the file.

    Raises:
        HTTPException: 400 for unsafe path parameters, or the upstream status
            if the Hub download fails.
    """
    _require_safe_segments(userId, modelId, filename)
    pathInRepo = f"{userId}/{modelId}/{filename}"
    response = download_file(pathInRepo)
    # Relay the upstream body chunk by chunk, closing the upstream response afterwards.
    return StreamingResponse(
        _iter_and_close(response, DOWNLOAD_CHUNK_SIZE),
        media_type="application/octet-stream",
        headers={"Content-Disposition": _content_disposition(filename)},
    )


def callback(modelVersionId, status):
    """POST an upload-status update for *modelVersionId* to the callback service.

    The target is ``CALLBACK_DOMAIN`` + ``/api/v1/callback/user/sd/model/upload``.
    The request is signed with an ``auth_key`` that expires 60 seconds from now.

    Raises:
        requests.RequestException: on network errors, timeouts, or a missing
            or invalid ``CALLBACK_DOMAIN``.
    """
    timestamp = int(time.time()) + 60
    rand = 0
    uid = 0
    url = "/api/v1/callback/user/sd/model/upload"
    md5 = _sign(url, timestamp, rand, uid)
    res = requests.post(
        f"{os.environ.get('CALLBACK_DOMAIN')}{url}?auth_key={timestamp}-{rand}-{uid}-{md5}",
        json={
            'status': status,
            'model_version_id': modelVersionId,
        },
        timeout=HTTP_TIMEOUT,
    )
    # Print the callback's HTTP status ("回调结果" means "callback result").
    print(f"回调结果 {res.status_code}")