Spaces:
Sleeping
Sleeping
Commit
•
69a7826
1
Parent(s):
4c66e55
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, Depends, HTTPException, status
|
2 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
3 |
+
from pydantic import BaseModel
|
4 |
+
from fastapi.middleware.cors import CORSMiddleware
|
5 |
+
import uvicorn
|
6 |
+
import tiktoken
|
7 |
+
import numpy as np
|
8 |
+
from scipy.interpolate import interp1d
|
9 |
+
from typing import List
|
10 |
+
from sklearn.preprocessing import PolynomialFeatures
|
11 |
+
import os
|
12 |
+
import requests
|
13 |
+
|
14 |
+
# API key taken from the environment, with a hardcoded fallback.
# NOTE(review): shipping a default key in source is a security risk — confirm
# this is a placeholder and rotate it if it is ever a real secret.
sk_key = os.environ.get('sk-key', 'sk-aaabbbcccdddeeefffggghhhiiijjjkkk')
API_URL = "https://api-inference.huggingface.co/models/BAAI/bge-large-zh-v1.5"
headers = {"Authorization": f"Bearer {sk_key}"}

# Create the FastAPI application instance.
app = FastAPI()

# Allow cross-origin requests from any origin, with any method/header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Bearer-token scheme; credentials are checked against sk_key in the endpoint.
security = HTTPBearer()

# Local model preloading (disabled): replaced by the hosted HF Inference API
# call in query() below.
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use CUDA if available, else CPU
# if torch.cuda.is_available():
#     print('Loading model on GPU: ', torch.cuda.get_device_name(0))
# else:
#     print('Loading model on CPU.')
#model = SentenceTransformer('./moka-ai_m3e-large', device=device)
41 |
+
def query(payload):
    """POST *payload* to the HF Inference API and return the first embedding.

    The feature-extraction endpoint returns a nested list; this unwraps the
    outer two levels ([0][0]) to yield one embedding vector.
    NOTE(review): this returns a single vector even when *payload* contains
    several texts — confirm against the caller in get_embeddings().

    Raises:
        requests.HTTPError: if the API responds with a 4xx/5xx status.
        requests.Timeout: if the API does not respond within 30 seconds.
    """
    # Explicit timeout: requests has no default and would otherwise hang
    # forever on a stalled upstream API.
    response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
    # Fail loudly on HTTP errors instead of surfacing an opaque KeyError /
    # JSONDecodeError from the indexing below.
    response.raise_for_status()
    return response.json()[0][0]
|
44 |
+
|
45 |
+
class EmbeddingRequest(BaseModel):
    """Request body for /v1/embeddings (OpenAI-compatible shape)."""
    # Texts to embed; forwarded as-is to query().
    input: List[str]
    # Model name; only echoed back in the response, not used for selection.
    model: str
|
48 |
+
|
49 |
+
|
50 |
+
class EmbeddingResponse(BaseModel):
    """Response body for /v1/embeddings, mirroring OpenAI's embeddings schema."""
    # List of {"embedding", "index", "object"} entries.
    data: list
    # Model name copied from the request.
    model: str
    # Always "list" (set by the endpoint).
    object: str
    # {"prompt_tokens": int, "total_tokens": int}.
    usage: dict
|
55 |
+
|
56 |
+
|
57 |
+
def num_tokens_from_string(string: str) -> int:
    """Count the tokens produced for *string* by tiktoken's cl100k_base encoding."""
    return len(tiktoken.get_encoding('cl100k_base').encode(string))
|
62 |
+
|
63 |
+
|
64 |
+
# 插值法
|
65 |
+
# Interpolation-based resampling
def interpolate_vector(vector, target_length):
    """Linearly resample *vector* onto *target_length* evenly spaced points."""
    n = len(vector)
    interpolator = interp1d(np.arange(n), vector, kind='linear')
    return interpolator(np.linspace(0, n - 1, target_length))
|
70 |
+
|
71 |
+
|
72 |
+
def expand_features(embedding, target_length):
    """Expand *embedding* with degree-2 polynomial features, then truncate or
    zero-pad the flattened result to exactly *target_length* values."""
    expanded = PolynomialFeatures(degree=2).fit_transform(
        embedding.reshape(1, -1)
    ).flatten()
    if len(expanded) > target_length:
        # Too many features: keep only the leading target_length values.
        return expanded[:target_length]
    if len(expanded) < target_length:
        # Too few features: right-pad with zeros up to target_length.
        return np.pad(expanded, (0, target_length - len(expanded)))
    return expanded
|
83 |
+
|
84 |
+
|
85 |
+
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """OpenAI-compatible embeddings endpoint backed by the HF Inference API.

    Rejects requests whose Bearer token does not match sk_key, fetches
    embeddings via query(), and returns them with token-usage counts.
    """
    if credentials.credentials != sk_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authorization code",
        )

    # Compute the embedding vectors and token counts.
    # Previous local-model path (disabled):
    # embeddings = [model.encode(text) for text in request.input]
    embeddings = query(request.input)
    # NOTE(review): query() returns response.json()[0][0] — apparently a single
    # vector — while enumerate(embeddings) below expects one vector per input
    # text. Confirm behavior for multi-text requests.

    # If the embedding dimension is not 1536, expand to 1536 via interpolation (disabled):
    # embeddings = [interpolate_vector(embedding, 1536) if len(embedding) < 1536 else embedding for embedding in embeddings]
    # ...or via polynomial feature expansion (disabled):
    #embeddings = [expand_features(embedding, 1536) if len(embedding) < 1536 else embedding for embedding in embeddings]

    # Min-Max normalization (disabled):
    # embeddings = [(embedding - np.min(embedding)) / (np.max(embedding) - np.min(embedding)) if np.max(embedding) != np.min(embedding) else embedding for embedding in embeddings]
    #embeddings = [embedding / np.linalg.norm(embedding) for embedding in embeddings]
    # Convert numpy arrays to plain lists (disabled):
    #embeddings = [embedding.tolist() for embedding in embeddings]
    # prompt_tokens is a rough whitespace word count; total_tokens is the exact
    # tiktoken count.
    prompt_tokens = sum(len(text.split()) for text in request.input)
    total_tokens = sum(num_tokens_from_string(text) for text in request.input)

    response = {
        "data": [
            {
                "embedding": embedding,
                "index": index,
                "object": "embedding"
            } for index, embedding in enumerate(embeddings)
        ],
        "model": request.model,
        "object": "list",
        "usage": {
            "prompt_tokens": prompt_tokens,
            "total_tokens": total_tokens,
        }
    }

    return response
|
128 |
+
|
129 |
+
|
130 |
+
# Serve the app with uvicorn when executed directly (single worker, port 7860).
if __name__ == "__main__":
    uvicorn.run("app:app", host='0.0.0.0', port=7860, workers=1)
|