File size: 4,105 Bytes
69a7826
 
 
 
 
 
 
3909335
69a7826
3909335
69a7826
 
 
 
94d1ea3
69a7826
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcd51f5
69a7826
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d1c302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69a7826
 
6887a00
265959d
69a7826
 
 
 
 
 
 
 
 
31ee60c
69a7826
871ea87
5e6213c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69a7826
5e6213c
69a7826
 
 
 
7804982
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import tiktoken
import numpy as np
# from scipy.interpolate import interp1d
from typing import List
# from sklearn.preprocessing import PolynomialFeatures
import os
import requests

#环境变量传入
sk_key = os.environ.get('SK', 'sk-aaabbbcccdddeeefffggghhhiiijjjkkk')
API_URL = "https://api-inference.huggingface.co/models/BAAI/bge-large-zh-v1.5"
headers = {"Authorization": f"Bearer {sk_key}"}

# 创建一个FastAPI实例
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# 创建一个HTTPBearer实例
security = HTTPBearer()

# 预加载模型
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # 检测是否有GPU可用,如果有则使用cuda设备,否则使用cpu设备
# if torch.cuda.is_available():
#     print('本次加载模型的设备为GPU: ', torch.cuda.get_device_name(0))
# else:
#     print('本次加载模型的设备为CPU.')
#model = SentenceTransformer('./moka-ai_m3e-large', device=device)

def query(payload):
	response = requests.post(API_URL, headers=headers, json=payload)
	return response.json()[0]

class EmbeddingRequest(BaseModel):
    input: List[str]
    model: str


class EmbeddingResponse(BaseModel):
    data: list
    model: str
    object: str
    usage: dict


def num_tokens_from_string(string: str) -> int:
    """Returns the number of tokens in a text string."""
    encoding = tiktoken.get_encoding('cl100k_base')
    num_tokens = len(encoding.encode(string))
    return num_tokens


# 插值法
# def interpolate_vector(vector, target_length):
#     original_indices = np.arange(len(vector))
#     target_indices = np.linspace(0, len(vector) - 1, target_length)
#     f = interp1d(original_indices, vector, kind='linear')
#     return f(target_indices)


# def expand_features(embedding, target_length):
#     poly = PolynomialFeatures(degree=2)
#     expanded_embedding = poly.fit_transform(embedding.reshape(1, -1))
#     expanded_embedding = expanded_embedding.flatten()
#     if len(expanded_embedding) > target_length:
#         # 如果扩展后的特征超过目标长度,可以通过截断或其他方法来减少维度
#         expanded_embedding = expanded_embedding[:target_length]
#     elif len(expanded_embedding) < target_length:
#         # 如果扩展后的特征少于目标长度,可以通过填充或其他方法来增加维度
#         expanded_embedding = np.pad(expanded_embedding, (0, target_length - len(expanded_embedding)))
#     return expanded_embedding


@app.post("/api/v1/embeddings", response_model=EmbeddingResponse)
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest, credentials: HTTPAuthorizationCredentials = Depends(security)):
    if credentials.credentials != sk_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authorization code",
        )

    # 计算嵌入向量和tokens数量
    # embeddings = [model.encode(text) for text in request.input]
    print(request.json())
    embeddings = query(request.input)
    #print(embeddings)
    prompt_tokens = sum(len(text.split()) for text in request.input)
    total_tokens = sum(num_tokens_from_string(text) for text in request.input)

    response = {
        "data": [
            {
                "embedding": embedding,
                "index": index,
                "object": "embedding"
            } for index, embedding in enumerate(embeddings)
        ],
        "model": request.model,
        "object": "list",
        "usage": {
            "prompt_tokens": prompt_tokens,
            "total_tokens": total_tokens,
        }
    }

    return response



if __name__ == "__main__":
    uvicorn.run("app:app", host='0.0.0.0', port=7860, workers=2)