# Hugging Face Space (status: Sleeping)
# File size: 3,086 bytes
# Commit history: 441b241 69a7826 441b241 69a7826 441b241 69a7826 dcd51f5 69a7826 5654556 441b241 69a7826 6887a00 5654556 69a7826 31ee60c 7945b4e fd937c6 441b241 5e6213c 69a7826 5e6213c 69a7826 74e82c1
import asyncio
import logging
import os
import secrets
from typing import List

import aiohttp
import requests
import tiktoken
import uvicorn
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
# Configuration is supplied through environment variables:
#   SK    - secret key, sent upstream as the bearer token (and compared
#           against the token presented by API clients)
#   MODEL - model identifier appended to the inference API URL
sk_key = os.getenv('SK')
model = os.getenv('MODEL')

# Upstream endpoint and the authorization header used for requests to it.
API_URL = f"https://api-inference.huggingface.co/models/{model}"
headers = {"Authorization": f"Bearer {sk_key}"}
# Create the FastAPI application.
app = FastAPI()

# CORS: accept any origin, method and header, with credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)

# HTTPBearer instance, used as the dependency that supplies the client's
# bearer credentials to the endpoint.
security = HTTPBearer()
def query(payload, timeout: float = 60.0):
    """Send ``payload`` to the inference API synchronously and return element 0 of the reply.

    Args:
        payload: JSON-serialisable request body posted to ``API_URL``.
        timeout: Seconds to wait for the upstream server before giving up.
            Without a timeout ``requests`` may block forever on a stalled
            connection.

    Returns:
        The first element of the decoded JSON response.

    Raises:
        requests.HTTPError: If the upstream API answers with a 4xx/5xx status.
        requests.Timeout: If the upstream API does not answer within ``timeout``.
    """
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    # Surface upstream failures explicitly instead of as an opaque
    # KeyError/IndexError from the indexing below.
    response.raise_for_status()
    # NOTE(review): only element 0 of the response is returned — confirm this
    # matches the response shape when several input texts are sent.
    return response.json()[0]
async def async_query(payload, *, retries: int = 3, retry_delay: float = 3.0):
    """Send ``payload`` to the inference API asynchronously, retrying on bad replies.

    Args:
        payload: Mapping used as the JSON request body. It is copied, so the
            caller's object is left untouched.
        retries: Maximum number of attempts (at least one is always made).
        retry_delay: Seconds to sleep between failed attempts.

    Returns:
        Element 0 of the decoded JSON response, or ``None`` when every attempt
        produced a response that could not be indexed with ``[0]`` (for
        example a JSON object or an empty list).
    """
    # Copy before adding the default so the caller's payload is not mutated.
    request_body = dict(payload)
    # NOTE(review): 'wait_for_model' is set at the top level of the body —
    # confirm the upstream API reads it there.
    request_body.setdefault('wait_for_model', True)
    logging.debug("async_query payload: %s", request_body)

    auth_headers = {"Authorization": f"Bearer {sk_key}"}
    attempts = max(1, retries)
    async with aiohttp.ClientSession(headers=auth_headers) as session:
        for attempt in range(1, attempts + 1):
            async with session.post(API_URL, json=request_body) as response:
                response_json = await response.json()
            try:
                return response_json[0]
            except (KeyError, IndexError):
                logging.warning(response_json)
                # Back off only when another attempt will follow.
                if attempt < attempts:
                    await asyncio.sleep(retry_delay)
    return None
class EmbeddingRequest(BaseModel):
    # Request body accepted by the embeddings endpoints.

    # Texts to embed.
    input: List[str]
    # Model name supplied by the client; echoed back in the response.
    model: str
class EmbeddingResponse(BaseModel):
    # Response body returned by the embeddings endpoints.

    # One entry per embedding, each with "embedding", "index" and "object" keys.
    data: list
    # Model name copied from the request.
    model: str
    # Object type tag; the endpoint sets it to "list".
    object: str
    # Token accounting with "prompt_tokens" and "total_tokens" keys.
    usage: dict
def num_tokens_from_string(string: str) -> int:
    """Return the number of ``cl100k_base`` tokens in ``string``.

    Text that happens to contain special-token markers (for example
    ``<|endoftext|>``) is counted as ordinary text rather than rejected.
    """
    encoding = tiktoken.get_encoding('cl100k_base')
    # By default encode() raises ValueError when the text contains a special
    # token; disallowed_special=() turns that check off so arbitrary user
    # input can be counted.
    return len(encoding.encode(string, disallowed_special=()))
@app.post("/api/v1/embeddings", response_model=EmbeddingResponse)
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return embeddings for the texts in ``request.input``.

    Args:
        request: Parsed request body (list of input texts and a model name).
        credentials: Bearer credentials extracted by the ``security`` dependency.

    Returns:
        A dict with ``data`` (one ``{"embedding", "index", "object"}`` entry
        per item returned by ``query``), the requested ``model`` name,
        ``object`` set to ``"list"`` and a ``usage`` token count.

    Raises:
        HTTPException: 401 when the bearer token does not match the
            configured key (or no key is configured).
    """
    # Constant-time comparison on bytes: avoids leaking how much of the token
    # matched, and str inputs to compare_digest must be ASCII-only.
    supplied_token = credentials.credentials.encode("utf-8")
    expected_token = sk_key.encode("utf-8") if sk_key else None
    if expected_token is None or not secrets.compare_digest(supplied_token, expected_token):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authorization code",
        )

    # Log request metadata only, not the submitted texts.
    logging.debug(
        "Embedding request: model=%s, inputs=%d", request.model, len(request.input)
    )

    # query() performs blocking HTTP I/O; run it in the default executor so
    # the event loop keeps serving other requests meanwhile.
    loop = asyncio.get_running_loop()
    embeddings = await loop.run_in_executor(None, query, request.input)
    # NOTE(review): each element of `embeddings` becomes one "data" entry —
    # confirm query() returns one vector per input text.

    # An embeddings call has no completion tokens, so prompt and total counts
    # are the same tokenizer-based number.
    token_count = sum(num_tokens_from_string(text) for text in request.input)

    response = {
        "data": [
            {
                "embedding": embedding,
                "index": index,
                "object": "embedding",
            }
            for index, embedding in enumerate(embeddings)
        ],
        "model": request.model,
        "object": "list",
        "usage": {
            "prompt_tokens": token_count,
            "total_tokens": token_count,
        },
    }
    return response
if __name__ == "__main__":
    # Script entry point: serve the app defined in this module ("app:app")
    # on all network interfaces, port 7860, with two worker processes.
    uvicorn.run("app:app", host='0.0.0.0', port=7860, workers=2)