gordonchan's picture
Update app.py
7945b4e verified
import asyncio
import logging
import os
import secrets
from typing import List

import aiohttp
import requests
import tiktoken
import uvicorn
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
# Settings are passed in through environment variables.
sk_key = os.getenv('SK')  # secret placed in the upstream Authorization header
model = os.getenv('MODEL')  # model name interpolated into the upstream URL

# Upstream inference endpoint and the headers sent with each request to it.
API_URL = f"https://api-inference.huggingface.co/models/{model}"
headers = {"Authorization": f"Bearer {sk_key}"}
# Create the FastAPI application instance.
app = FastAPI()

# Attach CORS middleware; origins, methods and headers are all wildcards and
# credentials are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Create the HTTPBearer instance used as the endpoint's auth dependency.
security = HTTPBearer()
def query(payload, timeout=60):
    """Send ``payload`` to the inference endpoint and return the first result.

    Performs a blocking POST to ``API_URL`` with the module-level
    ``headers``, decodes the response body as JSON and returns its element
    at index 0.

    Args:
        payload: JSON-serialisable request body.
        timeout: Seconds to wait for the upstream server. ``requests`` never
            times out unless asked to, so without this an unresponsive
            upstream would hang the caller indefinitely.

    Returns:
        The element at index 0 of the decoded JSON response.

    Raises:
        requests.exceptions.RequestException: On connection failure or
            timeout.
        KeyError, IndexError: If the decoded response has no element 0
            (for example a JSON object rather than a non-empty list).
    """
    response = requests.post(
        API_URL,
        headers=headers,
        json=payload,
        timeout=timeout,
    )
    return response.json()[0]
async def async_query(payload, max_attempts: int = 3, retry_delay: float = 3):
    """Asynchronously POST ``payload`` to ``API_URL`` and return the first result.

    A ``wait_for_model`` key set to ``True`` is added to the request body
    when the caller did not supply one. If the decoded JSON response has no
    element 0, the response is logged as a warning and the request is tried
    again after ``retry_delay`` seconds, up to ``max_attempts`` attempts in
    total.

    Args:
        payload: Dict-like JSON request body. It is copied, so the caller's
            object is left unmodified.
        max_attempts: Total number of requests to make before giving up;
            values below 1 are treated as 1.
        retry_delay: Seconds to sleep between attempts.

    Returns:
        The element at index 0 of the first usable JSON response, or
        ``None`` if every attempt produced an unusable response.
    """
    headers = {"Authorization": f"Bearer {sk_key}"}
    # Copy before adding the default so the caller's dict is not mutated.
    body = dict(payload)
    body.setdefault('wait_for_model', True)
    attempts = max(1, max_attempts)
    async with aiohttp.ClientSession(headers=headers) as session:
        print(body)
        for attempt in range(1, attempts + 1):
            async with session.post(API_URL, json=body) as response:
                response_json = await response.json()
            try:
                return response_json[0]
            except (KeyError, IndexError):
                logging.warning(response_json)
            # Pause only when another attempt will actually follow.
            if attempt < attempts:
                await asyncio.sleep(retry_delay)
    return None
# Request body accepted by the embeddings endpoints.
# (Documented with comments rather than a docstring so the generated schema
# description is unchanged.)
class EmbeddingRequest(BaseModel):
    # Texts that get_embeddings forwards to the upstream inference API.
    input: List[str]

    # Model name supplied by the client; get_embeddings copies it into the
    # response.
    model: str
# Response body returned by the embeddings endpoints.
# (Documented with comments rather than a docstring so the generated schema
# description is unchanged.)
class EmbeddingResponse(BaseModel):
    # Entries built by get_embeddings, each with "embedding", "index" and
    # "object" keys.
    data: list

    # Model name taken from the request.
    model: str

    # Object type tag; get_embeddings sets it to "list".
    object: str

    # Usage counters; get_embeddings fills "prompt_tokens" and "total_tokens".
    usage: dict
def num_tokens_from_string(string: str) -> int:
    """Return the number of ``cl100k_base`` tokens in ``string``.

    Args:
        string: Text to tokenize.

    Returns:
        The length of the token sequence produced by the ``cl100k_base``
        encoding.
    """
    encoding = tiktoken.get_encoding('cl100k_base')
    # By default tiktoken raises ValueError when the text contains a
    # special-token string such as "<|endoftext|>". The text here is
    # client-supplied, so disable that check and count such text as
    # ordinary characters instead of failing the request.
    return len(encoding.encode(string, disallowed_special=()))
@app.post("/api/v1/embeddings", response_model=EmbeddingResponse)
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return embeddings for ``request.input`` in an OpenAI-style envelope.

    The bearer token must equal ``sk_key``. The input texts are forwarded to
    the upstream inference API through ``query`` and every item of its
    result becomes one entry of ``data``.

    Args:
        request: Parsed request body (``input`` texts and ``model`` name).
        credentials: Bearer credentials extracted by the ``security``
            dependency.

    Returns:
        A dict with ``data``, ``model``, ``object`` and ``usage`` keys,
        validated against ``EmbeddingResponse``.

    Raises:
        HTTPException: 401 when the bearer token does not match ``sk_key``
            (including when ``sk_key`` is not configured).
    """
    # Compare as bytes in constant time so response timing does not reveal
    # how much of the token matched; bytes also accept non-ASCII input.
    supplied_token = credentials.credentials.encode("utf-8")
    if sk_key is None or not secrets.compare_digest(supplied_token, sk_key.encode("utf-8")):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authorization code",
        )
    print(request.json())

    # query() performs blocking network I/O; run it in the default executor
    # so the event loop can keep serving other requests meanwhile.
    loop = asyncio.get_running_loop()
    embeddings = await loop.run_in_executor(None, query, request.input)

    # Both usage counters are token counts. The previous whitespace word
    # count for prompt_tokens was not a token count and disagreed with
    # total_tokens.
    token_count = sum(num_tokens_from_string(text) for text in request.input)

    return {
        "data": [
            {
                "embedding": embedding,
                "index": index,
                "object": "embedding",
            }
            for index, embedding in enumerate(embeddings)
        ],
        "model": request.model,
        "object": "list",
        "usage": {
            "prompt_tokens": token_count,
            "total_tokens": token_count,
        },
    }
if __name__ == "__main__":
    # When run as a script, serve the app with uvicorn: all interfaces,
    # port 7860, two worker processes.
    uvicorn.run(
        "app:app",
        host='0.0.0.0',
        port=7860,
        workers=2,
    )