fiewolf1000 commited on
Commit
9493f68
·
verified ·
1 Parent(s): c6533bc

Upload 3 files

Browse files
Files changed (3) hide show
  1. Procfile +1 -1
  2. app.py +94 -97
  3. requirements.txt +4 -2
Procfile CHANGED
@@ -1 +1 @@
1
- web: gunicorn app:app
 
1
+ web: gunicorn app:app -w 2 -k uvicorn.workers.UvicornWorker
app.py CHANGED
@@ -1,108 +1,105 @@
1
- from flask import Flask, request, jsonify
2
- from sentence_transformers import SentenceTransformer
3
- import numpy as np
4
  import os
5
- import time
 
 
6
 
7
- app = Flask(__name__)
8
 
9
- # 加载模型
10
- model_name = "BAAI/bge-small-en-v1.5"
11
- model = SentenceTransformer(model_name)
 
 
 
 
 
12
 
13
- # 支持的模型列表
14
- SUPPORTED_MODELS = {
15
- "text-embedding-3-small": model,
16
- "bge-small-en-v1.5": model
17
  }
18
 
19
- # 简单的API密钥验证(可选)
20
- API_KEY = os.getenv("API_KEY", "your-default-api-key")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- def verify_api_key(headers):
23
- """验证API密钥"""
24
- auth_header = headers.get("Authorization")
25
- if not auth_header or not auth_header.startswith("Bearer "):
26
- return False
27
- return auth_header.split("Bearer ")[1] == API_KEY
28
 
29
- @app.route('/v1/embeddings', methods=['POST'])
30
- def create_embedding():
31
- """生成嵌入向量,兼容OpenAI API格式"""
32
- # 验证API密钥
33
- if not verify_api_key(request.headers):
34
- return jsonify({
35
- "error": {
36
- "message": "Invalid API key",
37
- "type": "invalid_request_error",
38
- "param": None,
39
- "code": "invalid_api_key"
40
- }
41
- }), 401
42
-
43
- # 解析请求
44
- data = request.json
45
- if not data or "input" not in data:
46
- return jsonify({
47
- "error": {
48
- "message": "Missing input",
49
- "type": "invalid_request_error",
50
- "param": None,
51
- "code": "missing_input"
52
- }
53
- }), 400
54
-
55
- # 获取模型(默认为text-embedding-3-small)
56
- model_name = data.get("model", "text-embedding-3-small")
57
- if model_name not in SUPPORTED_MODELS:
58
- return jsonify({
59
- "error": {
60
- "message": f"Model {model_name} not found",
61
- "type": "invalid_request_error",
62
- "param": None,
63
- "code": "model_not_found"
64
- }
65
- }), 404
66
-
67
- # 处理输入(支持单文本或文本列表)
68
- inputs = data["input"]
69
- if isinstance(inputs, str):
70
- inputs = [inputs]
71
-
72
- # 计算嵌入向量
73
- start_time = time.time()
74
- embeddings = model.encode(inputs, normalize_embeddings=True)
75
- processing_time = time.time() - start_time
76
-
77
- # 准备响应数据
78
- response_data = {
79
- "object": "list",
80
- "data": [
81
- {
82
- "object": "embedding",
83
- "embedding": embedding.tolist(),
84
- "index": i
85
- } for i, embedding in enumerate(embeddings)
86
- ],
87
- "model": model_name,
88
- "usage": {
89
- "prompt_tokens": sum(len(text.split()) for text in inputs), # 简单估算
90
- "total_tokens": sum(len(text.split()) for text in inputs)
91
- }
92
- }
93
-
94
- return jsonify(response_data)
95
 
96
- @app.route('/health', methods=['GET'])
97
- def health_check():
98
- """健康检查接口"""
99
- return jsonify({
100
- "status": "healthy",
101
- "model": model_name,
102
- "supported_models": list(SUPPORTED_MODELS.keys())
103
- })
104
 
105
- if __name__ == '__main__':
106
- # 生产环境应使用Gunicorn等WSGI服务器
107
- app.run(host='0.0.0.0', port=int(os.getenv('PORT', 7860)))
108
 
 
1
+ from fastapi import FastAPI, HTTPException, Depends
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
  import os
5
+ import numpy as np
6
+ from sentence_transformers import SentenceTransformer
7
+ from typing import List, Optional
8
 
9
+ app = FastAPI()
10
 
11
+ # 允许跨域请求
12
+ app.add_middleware(
13
+ CORSMiddleware,
14
+ allow_origins=["*"],
15
+ allow_credentials=True,
16
+ allow_methods=["*"],
17
+ allow_headers=["*"],
18
+ )
19
 
20
+ # 模型映射:OpenAI模型名 → 开源模型名
21
+ MODEL_MAPPING = {
22
+ "text-embedding-3-small": "BAAI/bge-small-en-v1.5",
23
+ "text-embedding-3-large": "BAAI/bge-large-en-v1.5" # 新增大模型映射
24
  }
25
 
26
+ # 加载模型(懒加载,首次请求时加载)
27
+ models = {}
28
+
29
+ def get_model(model_name: str):
30
+ if model_name not in models:
31
+ # 检查是否支持该模型
32
+ if model_name not in MODEL_MAPPING:
33
+ raise HTTPException(status_code=400, detail=f"不支持的模型: {model_name}")
34
+ # 加载模型
35
+ models[model_name] = SentenceTransformer(MODEL_MAPPING[model_name])
36
+ return models[model_name]
37
+
38
+ # 验证API密钥
39
+ def verify_api_key(authorization: Optional[str] = None):
40
+ if not authorization or not authorization.startswith("Bearer "):
41
+ raise HTTPException(status_code=401, detail="未提供有效的API密钥")
42
+ api_key = authorization[len("Bearer "):]
43
+ if api_key != os.getenv("API_KEY"):
44
+ raise HTTPException(status_code=401, detail="无效的API密钥")
45
+ return True
46
+
47
+ # 请求体模型(对齐OpenAI格式)
48
+ class EmbeddingRequest(BaseModel):
49
+ input: str or List[str]
50
+ model: str
51
+ encoding_format: Optional[str] = "float" # 仅支持float,忽略base64
52
+
53
+ # 响应体模型(对齐OpenAI格式)
54
+ class EmbeddingData(BaseModel):
55
+ object: str = "embedding"
56
+ embedding: List[float]
57
+ index: int
58
 
59
+ class EmbeddingResponse(BaseModel):
60
+ object: str = "list"
61
+ data: List[EmbeddingData]
62
+ model: str
63
+ usage: dict = {"prompt_tokens": 0, "total_tokens": 0}
 
64
 
65
+ @app.post("/v1/embeddings", response_model=EmbeddingResponse)
66
+ async def create_embedding(
67
+ request: EmbeddingRequest,
68
+ _: bool = Depends(verify_api_key)
69
+ ):
70
+ try:
71
+ # 获取模型
72
+ model = get_model(request.model)
73
+
74
+ # 处理输入(支持单文本或文本列表)
75
+ inputs = [request.input] if isinstance(request.input, str) else request.input
76
+
77
+ # 计算嵌入
78
+ embeddings = model.encode(inputs, normalize_embeddings=True)
79
+
80
+ # 构建响应
81
+ data = [
82
+ EmbeddingData(embedding=embedding.tolist(), index=i)
83
+ for i, embedding in enumerate(embeddings)
84
+ ]
85
+
86
+ # 估算token数(简单近似:每个单词约1 token)
87
+ prompt_tokens = sum(len(text.split()) for text in inputs)
88
+
89
+ return EmbeddingResponse(
90
+ data=data,
91
+ model=request.model,
92
+ usage={"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens}
93
+ )
94
+ except Exception as e:
95
+ raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
+ # 健康检查接口
98
+ @app.get("/health")
99
+ async def health_check():
100
+ return {"status": "healthy", "models": list(MODEL_MAPPING.keys())}
 
 
 
 
101
 
102
+ if __name__ == "__main__":
103
+ import uvicorn
104
+ uvicorn.run(app, host="0.0.0.0", port=7860)
105
 
requirements.txt CHANGED
@@ -1,5 +1,7 @@
1
- flask==2.3.3
 
 
2
  sentence-transformers==2.7.0
3
  torch==2.2.2
4
  numpy==1.26.4
5
- gunicorn==21.2.0 # 用于生产环境部署
 
1
+ fastapi==0.110.0
2
+ uvicorn==0.29.0
3
+ gunicorn==21.2.0
4
  sentence-transformers==2.7.0
5
  torch==2.2.2
6
  numpy==1.26.4
7
+ pydantic==2.6.4