Spaces:
Sleeping
Sleeping
import asyncio | |
import aiohttp | |
from fastapi import FastAPI, Depends, HTTPException, status | |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
from pydantic import BaseModel | |
from fastapi.middleware.cors import CORSMiddleware | |
import uvicorn | |
import tiktoken | |
from typing import List | |
import os | |
import requests | |
import logging | |
# Settings are passed in through environment variables.
#   SK    - token sent as "Bearer <SK>" to the Inference API; get_embeddings
#           also requires clients of this service to present this same value.
#   MODEL - model id appended to the Inference API URL.
# When a variable is unset, os.environ.get returns None and the literal text
# "None" ends up in the URL / header.
sk_key = os.environ.get('SK')
model = os.environ.get('MODEL')
API_URL = "https://api-inference.huggingface.co/models/" + str(model)
headers = dict(Authorization="Bearer " + str(sk_key))
# FastAPI application object; uvicorn serves it as "app:app" (see __main__).
app = FastAPI()
# Permissive CORS: any origin, method and header may call this service.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Extracts the "Authorization: Bearer <token>" header; used as a dependency
# by get_embeddings.
security = HTTPBearer()
def query(payload):
    """Synchronously POST *payload* as JSON to ``API_URL`` and return item 0 of the reply.

    Authenticates with the module-level ``headers``. No timeout is passed to
    ``requests.post``, so the call blocks until the server answers.

    Raises:
        ValueError: if the reply body is not valid JSON.
        KeyError / IndexError / TypeError: if the decoded reply cannot be
            indexed with ``[0]`` (e.g. an error object).

    NOTE(review): get_embeddings passes the raw list of input strings as
    *payload* and treats element ``[0]`` of the reply as the list of
    embeddings -- confirm this matches the Inference API's response shape for
    the configured model.
    """
    reply = requests.post(API_URL, json=payload, headers=headers)
    decoded = reply.json()
    return decoded[0]
async def async_query(payload, *, retries: int = 0, retry_delay: float = 3.0):
    """Asynchronously POST *payload* to ``API_URL`` and return item 0 of the reply.

    The caller's mapping is never mutated: a shallow copy is sent, with
    ``options.wait_for_model`` defaulting to True. The Inference API reads
    ``wait_for_model`` from the ``options`` object (not the top level), and
    it makes the API wait for a cold model to load instead of answering with
    an error.

    Args:
        payload: Mapping used as the JSON body; presumably holds an
            ``"inputs"`` key -- TODO confirm against callers.
        retries: Additional attempts after a failed one. The default of 0
            makes exactly one attempt.
        retry_delay: Seconds to sleep after each failed attempt.

    Returns:
        ``response_json[0]`` from the first reply that can be indexed that
        way, or None if no attempt produced one.
    """
    auth_headers = {"Authorization": f"Bearer {sk_key}"}

    body = {**payload}  # shallow copy; also rejects non-mapping payloads
    options = dict(body.get("options") or {})
    options.setdefault("wait_for_model", True)
    body["options"] = options

    attempts = max(retries, 0) + 1
    async with aiohttp.ClientSession(headers=auth_headers) as session:
        for attempt in range(1, attempts + 1):
            logging.debug("Inference API request (attempt %d/%d): %s", attempt, attempts, body)
            async with session.post(API_URL, json=body) as response:
                response_json = await response.json()
            try:
                return response_json[0]
            except (KeyError, IndexError, TypeError):
                # Presumably an error object (e.g. {"error": ...}); log it for
                # inspection and back off before the next attempt.
                logging.warning(response_json)
                await asyncio.sleep(retry_delay)
    return None
class EmbeddingRequest(BaseModel):
    """Request body accepted by get_embeddings."""

    # Texts to embed; get_embeddings forwards this list as-is to query() and
    # counts tokens for each element.
    input: List[str]
    # Client-supplied model name. It is only echoed back in the response; the
    # upstream model is fixed by the MODEL environment variable.
    model: str
class EmbeddingResponse(BaseModel):
    """Schema whose fields match the dict returned by get_embeddings.

    NOTE(review): not referenced in the visible code; presumably meant to be
    the endpoint's ``response_model`` -- confirm.
    """

    # One {"embedding", "index", "object"} entry per returned embedding.
    data: list
    # Model name echoed from the request.
    model: str
    # get_embeddings always sets this to "list".
    object: str
    # Token usage counts: {"prompt_tokens": ..., "total_tokens": ...}.
    usage: dict
def num_tokens_from_string(string: str) -> int:
    """Count how many tokens *string* encodes to under tiktoken's ``cl100k_base``."""
    enc = tiktoken.get_encoding('cl100k_base')
    return len(enc.encode(string))
async def get_embeddings(request: EmbeddingRequest, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Embed ``request.input`` via the Inference API and return an OpenAI-style list response.

    NOTE(review): no ``@app.post(...)`` decorator is visible on this handler,
    so it does not appear to be registered as a route on ``app``. The unused
    ``EmbeddingResponse`` model mirrors the dict built here, which suggests
    something like ``@app.post("/v1/embeddings", response_model=EmbeddingResponse)``
    was intended -- confirm the path and register it.

    Args:
        request: Texts to embed plus a model name that is only echoed back;
            the upstream model is fixed by the MODEL environment variable.
        credentials: Bearer token extracted by ``security``; it must equal
            the SK environment variable.

    Returns:
        A dict with ``data`` (one entry per element returned by ``query``),
        ``model``, ``object`` ("list") and ``usage`` token counts.

    Raises:
        HTTPException: 401 if the token does not match (or SK is unset);
            502 if the upstream call fails or returns an unusable body.
    """
    import secrets  # constant-time comparison of the shared secret

    expected_key = sk_key or ""
    # compare_digest does not leak via timing how much of the token matched;
    # comparing bytes also accepts non-ASCII tokens without a TypeError.
    if not expected_key or not secrets.compare_digest(
        credentials.credentials.encode("utf-8"), expected_key.encode("utf-8")
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authorization code",
        )

    logging.debug("Embedding request: model=%s, %d input(s)", request.model, len(request.input))

    loop = asyncio.get_running_loop()
    try:
        # query() performs a blocking requests.post; run it in the default
        # thread pool so this coroutine does not stall the event loop.
        embeddings = await loop.run_in_executor(None, query, request.input)
    except (requests.RequestException, ValueError, KeyError, IndexError, TypeError) as exc:
        logging.warning("Embedding backend request failed: %r", exc)
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail="Embedding backend request failed",
        ) from exc

    # An embedding request has no completion tokens, so both usage fields
    # report the same tiktoken count.
    token_count = sum(num_tokens_from_string(text) for text in request.input)

    return {
        "data": [
            {
                "embedding": embedding,
                "index": index,
                "object": "embedding",
            }
            for index, embedding in enumerate(embeddings)
        ],
        "model": request.model,
        "object": "list",
        "usage": {
            "prompt_tokens": token_count,
            "total_tokens": token_count,
        },
    }
if __name__ == "__main__":
    # Serve attribute "app" of module "app" (assumes this file is app.py --
    # confirm) on all interfaces, port 7860, with two uvicorn workers.
    uvicorn.run(
        "app:app",
        workers=2,
        port=7860,
        host='0.0.0.0',
    )