gordonchan commited on
Commit
5e6213c
1 Parent(s): 7804982

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -10
app.py CHANGED
@@ -94,18 +94,35 @@ async def get_embeddings(request: EmbeddingRequest, credentials: HTTPAuthorizati
94
  # embeddings = [model.encode(text) for text in request.input]
95
  print(request.json())
96
  embeddings = query(request.input)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
 
98
 
99
- # 如果嵌入向量的维度不为1536,则使用插值法扩展至1536维度
100
- # embeddings = [interpolate_vector(embedding, 1536) if len(embedding) < 1536 else embedding for embedding in embeddings]
101
- # 如果嵌入向量的维度不为1536,则使用特征扩展法扩展至1536维度
102
- #embeddings = [expand_features(embedding, 1536) if len(embedding) < 1536 else embedding for embedding in embeddings]
 
 
 
103
 
104
- # Min-Max normalization
105
- # embeddings = [(embedding - np.min(embedding)) / (np.max(embedding) - np.min(embedding)) if np.max(embedding) != np.min(embedding) else embedding for embedding in embeddings]
106
- #embeddings = [embedding / np.linalg.norm(embedding) for embedding in embeddings]
107
- # 将numpy数组转换为列表
108
- #embeddings = [embedding.tolist() for embedding in embeddings]
109
  prompt_tokens = sum(len(text.split()) for text in request.input)
110
  total_tokens = sum(num_tokens_from_string(text) for text in request.input)
111
 
@@ -127,6 +144,5 @@ async def get_embeddings(request: EmbeddingRequest, credentials: HTTPAuthorizati
127
 
128
  return response
129
 
130
-
131
  if __name__ == "__main__":
132
  uvicorn.run("app:app", host='0.0.0.0', port=7860, workers=2)
 
94
  # embeddings = [model.encode(text) for text in request.input]
95
  print(request.json())
96
  embeddings = query(request.input)
97
+ prompt_tokens = sum(len(text.split()) for text in request.input)
98
+ total_tokens = sum(num_tokens_from_string(text) for text in request.input)
99
+
100
+ response = {
101
+ "data": [
102
+ {
103
+ "embedding": embedding,
104
+ "index": index,
105
+ "object": "embedding"
106
+ } for index, embedding in enumerate(embeddings)
107
+ ],
108
+ "model": request.model,
109
+ "object": "list",
110
+ "usage": {
111
+ "prompt_tokens": prompt_tokens,
112
+ "total_tokens": total_tokens,
113
+ }
114
+ }
115
 
116
+ return response
117
 
118
+ @app.post("/v1/embeddings", response_model=EmbeddingResponse)
119
+ async def get_embeddings(request: EmbeddingRequest, credentials: HTTPAuthorizationCredentials = Depends(security)):
120
+ if credentials.credentials != sk_key:
121
+ raise HTTPException(
122
+ status_code=status.HTTP_401_UNAUTHORIZED,
123
+ detail="Invalid authorization code",
124
+ )
125
 
 
 
 
 
 
126
  prompt_tokens = sum(len(text.split()) for text in request.input)
127
  total_tokens = sum(num_tokens_from_string(text) for text in request.input)
128
 
 
144
 
145
  return response
146
 
 
147
  if __name__ == "__main__":
148
  uvicorn.run("app:app", host='0.0.0.0', port=7860, workers=2)