import asyncio
import logging
import os
import secrets
from typing import List

import aiohttp
import requests
import tiktoken
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel

# Configuration is passed in through environment variables.
sk_key = os.environ.get('SK')
model = os.environ.get('MODEL')
API_URL = f"https://api-inference.huggingface.co/models/{model}"
headers = {"Authorization": f"Bearer {sk_key}"}

# Upper bound (seconds) on a single upstream call made by `query`, so a
# stalled connection cannot hold a worker forever.
REQUEST_TIMEOUT = 60

# Create the FastAPI instance.
app = FastAPI()

# CORS is fully open: any origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Create the HTTPBearer instance used to extract the bearer token.
security = HTTPBearer()


def query(payload):
    """POST *payload* as JSON to ``API_URL`` and return element 0 of the reply.

    This is a blocking call; from a coroutine, run it in an executor.

    Args:
        payload: JSON-serialisable request body.

    Returns:
        The first element of the decoded JSON response.

    Raises:
        requests.RequestException: on transport errors or timeout.
        ValueError: if the response body is not valid JSON.
        KeyError, IndexError: if the decoded body has no element ``0``
            (for example when it is a JSON object or an empty list).
    """
    response = requests.post(
        API_URL,
        headers=headers,
        json=payload,
        timeout=REQUEST_TIMEOUT,
    )
    return response.json()[0]


async def async_query(payload):
    """Asynchronous variant of `query` built on aiohttp.

    ``payload`` must be a dict; it is modified in place so that
    ``wait_for_model`` defaults to ``True``.

    Args:
        payload: request body as a dict.

    Returns:
        The first element of the decoded JSON response, or ``None`` when the
        response has no element ``0``. In that case the response is logged
        and the coroutine sleeps for 3 seconds before returning.
        NOTE(review): the sleep looks like a back-off for a caller-side
        retry loop -- confirm against callers.
    """
    payload.setdefault('wait_for_model', True)
    async with aiohttp.ClientSession(headers=headers) as session:
        print(payload)
        async with session.post(API_URL, json=payload) as response:
            response_json = await response.json()
            try:
                return response_json[0]
            except (KeyError, IndexError):
                logging.warning(response_json)
                await asyncio.sleep(3)
    return None


class EmbeddingRequest(BaseModel):
    """Request body of the embeddings endpoint."""

    # Texts to embed.
    input: List[str]
    # Model name; echoed back unchanged in the response.
    model: str


class EmbeddingResponse(BaseModel):
    """Response body of the embeddings endpoint."""

    # One entry per element returned by the upstream service.
    data: list
    model: str
    object: str
    # Token accounting: ``prompt_tokens`` and ``total_tokens``.
    usage: dict


def num_tokens_from_string(string: str) -> int:
    """Return the number of ``cl100k_base`` tokens in a text string."""
    encoding = tiktoken.get_encoding('cl100k_base')
    num_tokens = len(encoding.encode(string))
    return num_tokens


def _is_authorized(token: str) -> bool:
    """Return True when *token* equals the configured ``SK`` secret.

    The comparison is constant-time. Both sides are compared as bytes because
    ``secrets.compare_digest`` rejects non-ASCII ``str`` arguments. An unset
    ``SK`` never authorizes anything.
    """
    if sk_key is None:
        return False
    return secrets.compare_digest(token.encode("utf-8"), sk_key.encode("utf-8"))


@app.post("/api/v1/embeddings", response_model=EmbeddingResponse)
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return embeddings for ``request.input`` in an OpenAI-style envelope.

    Args:
        request: texts to embed plus the model name to echo back.
        credentials: bearer token, which must equal the ``SK`` secret.

    Returns:
        A dict with ``data``, ``model``, ``object`` and ``usage`` keys.

    Raises:
        HTTPException: 401 when the bearer token is wrong; 502 when the
            upstream call fails or returns an unexpected body.
    """
    if not _is_authorized(credentials.credentials):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authorization code",
        )

    print(request.json())

    # `query` uses blocking `requests`; run it in the default executor so the
    # event loop keeps serving other requests during the upstream round trip.
    loop = asyncio.get_running_loop()
    try:
        embeddings = await loop.run_in_executor(None, query, request.input)
    except (requests.RequestException, ValueError, KeyError, IndexError) as exc:
        logging.exception("Upstream embedding request failed")
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail="Upstream embedding request failed",
        ) from exc

    # NOTE(review): prompt_tokens is a whitespace word count while
    # total_tokens is a tiktoken count, so the two disagree. Kept as-is to
    # preserve the current response values -- confirm which one is intended.
    prompt_tokens = sum(len(text.split()) for text in request.input)
    total_tokens = sum(num_tokens_from_string(text) for text in request.input)

    response = {
        "data": [
            {
                "embedding": embedding,
                "index": index,
                "object": "embedding",
            }
            for index, embedding in enumerate(embeddings)
        ],
        "model": request.model,
        "object": "list",
        "usage": {
            "prompt_tokens": prompt_tokens,
            "total_tokens": total_tokens,
        },
    }

    return response


if __name__ == "__main__":
    uvicorn.run("app:app", host='0.0.0.0', port=7860, workers=2)