File size: 3,086 Bytes
441b241
 
69a7826
 
 
 
 
 
 
 
 
441b241
69a7826
 
441b241
 
 
69a7826
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcd51f5
69a7826
5654556
441b241
 
 
 
 
 
 
 
 
 
 
 
 
 
69a7826
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6887a00
5654556
69a7826
 
 
 
 
 
 
31ee60c
7945b4e
fd937c6
441b241
5e6213c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69a7826
5e6213c
69a7826
 
 
74e82c1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import asyncio
import logging
import os
import secrets
from typing import Any, List, Optional

import aiohttp
import requests
import tiktoken
import uvicorn
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel

# Configuration is supplied through environment variables.
#   SK    - token sent upstream as the Authorization bearer; get_embeddings
#           also requires clients to present this same value.
#   MODEL - Hugging Face model id appended to the Inference API URL.
# NOTE(review): if MODEL is unset the URL ends in ".../models/None".
sk_key = os.getenv("SK")
model = os.getenv("MODEL")
API_URL = "https://api-inference.huggingface.co/models/{}".format(model)
headers = {"Authorization": "Bearer {}".format(sk_key)}

# Bearer-token extractor, injected into the embeddings endpoint via Depends().
security = HTTPBearer()

# The FastAPI application served by uvicorn.
app = FastAPI()

# Permissive CORS policy: any origin, method and header, with credentials.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; restrict allow_origins if browser clients ever send cookies.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

def query(payload, timeout: Optional[float] = 60.0):
    """Synchronously POST *payload* to the Inference API and return ``reply[0]``.

    Args:
        payload: JSON-serialisable request body (get_embeddings passes the
            list of input strings).
        timeout: Seconds to wait for the upstream response; ``None`` waits
            indefinitely.

    Returns:
        Element 0 of the decoded JSON reply.
        NOTE(review): what element 0 holds depends on the model's output
        layout -- confirm it is the embedding list callers expect.

    Raises:
        requests.RequestException: on connection errors, timeouts, or a
            non-2xx status (via ``raise_for_status``).
        ValueError: if the reply body is not valid JSON.
        KeyError / IndexError: if the reply has no element 0.
    """
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    # Fail on non-2xx responses with requests.HTTPError instead of indexing
    # into whatever error body came back and raising an opaque KeyError.
    response.raise_for_status()
    return response.json()[0]

async def async_query(payload, retries: int = 0, retry_delay: float = 3.0) -> Any:
    """Asynchronously POST *payload* to the Inference API and return ``reply[0]``.

    Args:
        payload: JSON-serialisable dict used as the request body. It is not
            modified; a shallow copy is sent instead.
        retries: Extra attempts to make when the reply cannot be indexed with
            ``[0]`` (e.g. an error object). Defaults to 0 (a single attempt).
        retry_delay: Seconds to wait between attempts.

    Returns:
        Element 0 of the decoded JSON reply, or None if every attempt produced
        a reply without an element 0 (each such reply is logged as a warning).
    """
    body = dict(payload)
    # The Inference API documents wait_for_model as a key of the nested
    # "options" object. A top-level "wait_for_model" from older callers is
    # moved there rather than sent where it would be ignored.
    options = dict(body.get("options") or {})
    options.setdefault("wait_for_model", body.pop("wait_for_model", True))
    body["options"] = options
    logging.debug("async_query payload: %s", body)

    async with aiohttp.ClientSession(headers=headers) as session:
        for attempt in range(retries + 1):
            async with session.post(API_URL, json=body) as response:
                response_json = await response.json()
            try:
                return response_json[0]
            except (KeyError, IndexError):
                logging.warning(response_json)
            # Only pause when another attempt will actually follow.
            if attempt < retries:
                await asyncio.sleep(retry_delay)
    return None


class EmbeddingRequest(BaseModel):
    """Body accepted by the embeddings endpoints."""

    # Texts to embed; forwarded upstream as-is and tokenized for usage counts.
    input: List[str]
    # Model name supplied by the client; echoed back in the response.
    model: str


class EmbeddingResponse(BaseModel):
    """Reply returned by the embeddings endpoints, modelled on OpenAI's layout."""

    # One {"embedding", "index", "object"} dict per returned vector.
    data: list
    # Model name copied from the request.
    model: str
    # Set to "list" by get_embeddings.
    object: str
    # Token accounting: {"prompt_tokens": ..., "total_tokens": ...}.
    usage: dict


def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
    """Return the number of tokens *string* encodes to.

    Args:
        string: Text to measure; may be arbitrary client-supplied input.
        encoding_name: tiktoken encoding to count with (default ``cl100k_base``).

    Returns:
        The token count.
    """
    encoding = tiktoken.get_encoding(encoding_name)
    # encode() raises ValueError by default when the text contains a special
    # token string such as "<|endoftext|>". Client text is untrusted, so count
    # such sequences as ordinary text instead of failing the whole request.
    return len(encoding.encode(string, disallowed_special=()))

@app.post("/api/v1/embeddings", response_model=EmbeddingResponse)
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Embed ``request.input`` via the upstream model and return an OpenAI-style reply.

    Args:
        request: Parsed request body (texts and model name).
        credentials: Bearer token extracted by ``HTTPBearer``.

    Returns:
        A dict matching ``EmbeddingResponse``: one entry per upstream
        embedding, the requested model name, and tokenizer-based usage.

    Raises:
        HTTPException: 401 if the bearer token does not match ``SK``;
            502 if the upstream call fails or returns a non-list payload.
    """
    # Constant-time comparison so response timing does not reveal how much of
    # the key matched. An unset SK rejects every request.
    if not sk_key or not secrets.compare_digest(
        credentials.credentials.encode("utf-8"), sk_key.encode("utf-8")
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authorization code",
        )

    # Log metadata only; the input texts themselves may be sensitive.
    logging.debug(
        "embeddings request: model=%s inputs=%d", request.model, len(request.input)
    )

    # query() uses the blocking `requests` client; run it in the default
    # executor so it does not stall the event loop for concurrent requests.
    loop = asyncio.get_running_loop()
    try:
        embeddings = await loop.run_in_executor(None, query, request.input)
    except (requests.RequestException, KeyError, IndexError, TypeError, ValueError) as exc:
        logging.warning("upstream embedding call failed: %r", exc)
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail="Upstream embedding service error",
        ) from exc

    # Enumerating a dict (e.g. an error object) would yield its keys as
    # "embeddings"; reject anything that is not a list.
    if not isinstance(embeddings, list):
        logging.warning("unexpected upstream payload: %r", embeddings)
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail="Upstream embedding service error",
        )

    # Embeddings consume no completion tokens, so prompt and total usage are
    # the same tokenizer count.
    token_count = sum(num_tokens_from_string(text) for text in request.input)

    return {
        "data": [
            {
                "embedding": embedding,
                "index": index,
                "object": "embedding",
            }
            for index, embedding in enumerate(embeddings)
        ],
        "model": request.model,
        "object": "list",
        "usage": {
            "prompt_tokens": token_count,
            "total_tokens": token_count,
        },
    }


if __name__ == "__main__":
    # Script entry point: serve on all interfaces, port 7860, two workers.
    # The app is passed as a "module:attr" import string, which uvicorn needs
    # in order to start multiple workers.
    # NOTE(review): assumes this file is importable as the module "app".
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        workers=2,
    )