gordonchan committed on
Commit
69a7826
1 Parent(s): 4c66e55

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -0
app.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import hmac
import os
from typing import List

import numpy as np
import requests
import tiktoken
import uvicorn
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from scipy.interpolate import interp1d
from sklearn.preprocessing import PolynomialFeatures
13
+
14
# API key passed in via environment variable.
# NOTE(review): env-var names containing '-' cannot be set from most shells —
# confirm how 'sk-key' is actually injected (e.g. HF Spaces secret). The
# hard-coded fallback means the service runs with a publicly visible key if
# the variable is unset — confirm this is intentional.
sk_key = os.environ.get('sk-key', 'sk-aaabbbcccdddeeefffggghhhiiijjjkkk')
# Hugging Face Inference API endpoint for the BAAI/bge-large-zh-v1.5 embedding model.
API_URL = "https://api-inference.huggingface.co/models/BAAI/bge-large-zh-v1.5"
# The same key doubles as the upstream Hugging Face bearer token.
headers = {"Authorization": f"Bearer {sk_key}"}

# Create the FastAPI application instance.
app = FastAPI()

# Allow cross-origin requests from any origin.
# NOTE(review): per the CORS spec, browsers reject a wildcard origin combined
# with allow_credentials=True — confirm the intended client setup.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# HTTP Bearer scheme used to extract the Authorization header on each request.
security = HTTPBearer()
32
+
33
# Preload a local model (disabled — embeddings are fetched from the HF Inference API instead)
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use the CUDA device if a GPU is available, otherwise fall back to CPU
# if torch.cuda.is_available():
#     print('Loading the model on GPU: ', torch.cuda.get_device_name(0))
# else:
#     print('Loading the model on CPU.')
# model = SentenceTransformer('./moka-ai_m3e-large', device=device)
40
+
41
def query(payload):
    """Fetch embeddings for *payload* from the Hugging Face Inference API.

    payload: the request JSON body forwarded upstream (here, a list of
        input strings).
    Returns response.json()[0][0] — the first row of the first result.
    Raises requests.HTTPError on a non-2xx upstream response (rate limit,
    model cold-start, bad key) instead of failing later with an opaque
    KeyError/IndexError from the JSON indexing.
    """
    response = requests.post(API_URL, headers=headers, json=payload)
    # Surface upstream failures explicitly before touching the body.
    response.raise_for_status()
    # NOTE(review): [0][0] assumes the API nests results one extra level for
    # this model — verify against the bge-large-zh-v1.5 feature-extraction
    # response shape, especially for multi-string inputs.
    return response.json()[0][0]
44
+
45
class EmbeddingRequest(BaseModel):
    """Request body for /v1/embeddings (mirrors the OpenAI embeddings API)."""

    # Texts to embed.
    input: List[str]
    # Model name; echoed back in the response (the upstream model is fixed).
    model: str
48
+
49
+
50
class EmbeddingResponse(BaseModel):
    """Response body for /v1/embeddings (OpenAI embeddings "list" shape)."""

    # One {"embedding", "index", "object"} item per input text.
    data: list
    # Echo of the requested model name.
    model: str
    # Always "list" for this endpoint.
    object: str
    # Token accounting: {"prompt_tokens", "total_tokens"}.
    usage: dict
55
+
56
+
57
def num_tokens_from_string(string: str) -> int:
    """Count the tokens *string* encodes to under the cl100k_base tokenizer."""
    tokenizer = tiktoken.get_encoding('cl100k_base')
    return len(tokenizer.encode(string))
62
+
63
+
64
# Interpolation-based resampling
def interpolate_vector(vector, target_length):
    """Linearly resample *vector* onto *target_length* evenly spaced points.

    The output spans the same value range: its first and last entries equal
    the first and last entries of *vector*.
    """
    source_positions = np.arange(len(vector))
    # target_length points evenly spread across the original index range.
    target_positions = np.linspace(0, len(vector) - 1, target_length)
    interpolator = interp1d(source_positions, vector, kind='linear')
    return interpolator(target_positions)
70
+
71
+
72
def expand_features(embedding, target_length):
    """Expand *embedding* with degree-2 polynomial features, then trim or
    zero-pad the flattened result to exactly *target_length* values."""
    poly_transform = PolynomialFeatures(degree=2)
    expanded = poly_transform.fit_transform(embedding.reshape(1, -1)).flatten()
    size = len(expanded)
    if size > target_length:
        # Too many features for the target size: truncate the tail.
        return expanded[:target_length]
    if size < target_length:
        # Too few features: right-pad with zeros up to the target size.
        return np.pad(expanded, (0, target_length - size))
    return expanded
83
+
84
+
85
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Embed `request.input` via the upstream HF model, in OpenAI list shape.

    Raises HTTP 401 when the bearer token does not match the configured key.
    """
    # Constant-time comparison: `!=` on a secret leaks its value through
    # timing differences.
    if not hmac.compare_digest(credentials.credentials, sk_key):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authorization code",
        )

    # Compute the embeddings remotely (local model loading is disabled above).
    embeddings = query(request.input)

    # prompt_tokens is a cheap whitespace-split approximation; total_tokens
    # uses the cl100k_base tokenizer for an OpenAI-comparable count.
    prompt_tokens = sum(len(text.split()) for text in request.input)
    total_tokens = sum(num_tokens_from_string(text) for text in request.input)

    return {
        "data": [
            {
                "embedding": embedding,
                "index": index,
                "object": "embedding"
            } for index, embedding in enumerate(embeddings)
        ],
        "model": request.model,
        "object": "list",
        "usage": {
            "prompt_tokens": prompt_tokens,
            "total_tokens": total_tokens,
        }
    }
128
+
129
+
130
# Run the development server when executed directly (not under an external
# ASGI runner); binds to all interfaces on port 7860 with a single worker.
if __name__ == "__main__":
    uvicorn.run("app:app", host='0.0.0.0', port=7860, workers=1)