Spaces:
Running
Running
sanbo
commited on
Commit
·
bf8b09b
1
Parent(s):
e767741
update sth. at 2025-02-03 21:03:19
Browse files
app.py
CHANGED
@@ -4,16 +4,35 @@ import torch
|
|
4 |
import gradio as gr
|
5 |
from fastapi import FastAPI, HTTPException
|
6 |
from fastapi.middleware.cors import CORSMiddleware
|
7 |
-
from pydantic import BaseModel
|
8 |
-
from typing import List, Dict
|
9 |
from functools import lru_cache
|
10 |
-
import numpy as np
|
11 |
from threading import Lock
|
12 |
import uvicorn
|
13 |
|
14 |
class EmbeddingRequest(BaseModel):
|
15 |
-
|
16 |
-
model: str =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
class EmbeddingResponse(BaseModel):
|
19 |
status: str
|
@@ -21,7 +40,7 @@ class EmbeddingResponse(BaseModel):
|
|
21 |
|
22 |
class EmbeddingService:
|
23 |
def __init__(self):
|
24 |
-
self.
|
25 |
self.max_length = 512
|
26 |
self.device = torch.device("cpu")
|
27 |
self.model = None
|
@@ -41,23 +60,22 @@ class EmbeddingService:
|
|
41 |
try:
|
42 |
from transformers import AutoTokenizer, AutoModel
|
43 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
44 |
-
self.
|
45 |
trust_remote_code=True
|
46 |
)
|
47 |
self.model = AutoModel.from_pretrained(
|
48 |
-
self.
|
49 |
trust_remote_code=True
|
50 |
).to(self.device)
|
51 |
self.model.eval()
|
52 |
torch.set_grad_enabled(False)
|
53 |
-
self.logger.info(f"
|
54 |
except Exception as e:
|
55 |
self.logger.error(f"模型初始化失败: {str(e)}")
|
56 |
raise
|
57 |
|
58 |
@lru_cache(maxsize=1000)
|
59 |
def get_embedding(self, text: str) -> List[float]:
|
60 |
-
"""同步生成嵌入向量,带缓存"""
|
61 |
with self.lock:
|
62 |
try:
|
63 |
inputs = self.tokenizer(
|
@@ -85,7 +103,8 @@ app.add_middleware(
|
|
85 |
allow_methods=["*"],
|
86 |
allow_headers=["*"],
|
87 |
)
|
88 |
-
|
|
|
89 |
@app.post("/api/embed", response_model=EmbeddingResponse)
|
90 |
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
|
91 |
@app.post("/generate_embeddings", response_model=EmbeddingResponse)
|
@@ -95,11 +114,10 @@ app.add_middleware(
|
|
95 |
@app.post("/hf/v1/chat/completions", response_model=EmbeddingResponse)
|
96 |
async def generate_embeddings(request: EmbeddingRequest):
|
97 |
try:
|
98 |
-
# 使用run_in_executor避免事件循环问题
|
99 |
embedding = await asyncio.get_running_loop().run_in_executor(
|
100 |
None,
|
101 |
embedding_service.get_embedding,
|
102 |
-
request.
|
103 |
)
|
104 |
return EmbeddingResponse(
|
105 |
status="success",
|
@@ -112,7 +130,7 @@ async def generate_embeddings(request: EmbeddingRequest):
|
|
112 |
async def root():
|
113 |
return {
|
114 |
"status": "active",
|
115 |
-
"
|
116 |
"device": str(embedding_service.device)
|
117 |
}
|
118 |
|
@@ -134,8 +152,11 @@ iface = gr.Interface(
|
|
134 |
inputs=gr.Textbox(lines=3, label="输入文本"),
|
135 |
outputs=gr.JSON(label="嵌入向量结果"),
|
136 |
title="Jina Embeddings V3",
|
137 |
-
description="
|
138 |
-
examples=[[
|
|
|
|
|
|
|
139 |
)
|
140 |
|
141 |
@app.on_event("startup")
|
@@ -145,4 +166,4 @@ async def startup_event():
|
|
145 |
if __name__ == "__main__":
|
146 |
asyncio.run(embedding_service.initialize())
|
147 |
gr.mount_gradio_app(app, iface, path="/ui")
|
148 |
-
uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)
|
|
|
4 |
import gradio as gr
|
5 |
from fastapi import FastAPI, HTTPException
|
6 |
from fastapi.middleware.cors import CORSMiddleware
|
7 |
+
from pydantic import BaseModel, Field, root_validator
|
8 |
+
from typing import List, Dict, Optional
|
9 |
from functools import lru_cache
|
|
|
10 |
from threading import Lock
|
11 |
import uvicorn
|
12 |
|
13 |
class EmbeddingRequest(BaseModel):
|
14 |
+
# 强制锁定模型参数
|
15 |
+
model: str = Field(
|
16 |
+
default="jinaai/jina-embeddings-v3",
|
17 |
+
description="此参数仅用于API兼容,实际模型固定为jinaai/jina-embeddings-v3",
|
18 |
+
frozen=True # 禁止修改
|
19 |
+
)
|
20 |
+
# 支持三种输入字段
|
21 |
+
inputs: Optional[str] = Field(None, description="输入文本(兼容HuggingFace格式)")
|
22 |
+
input: Optional[str] = Field(None, description="输入文本(兼容OpenAI格式)")
|
23 |
+
prompt: Optional[str] = Field(None, description="输入文本(兼容Ollama格式)")
|
24 |
+
|
25 |
+
# 自动合并输入字段
|
26 |
+
@root_validator(pre=True)
|
27 |
+
def merge_input_fields(cls, values):
|
28 |
+
input_fields = ["inputs", "input", "prompt"]
|
29 |
+
for field in input_fields:
|
30 |
+
if values.get(field):
|
31 |
+
values["inputs"] = values[field]
|
32 |
+
break
|
33 |
+
else:
|
34 |
+
raise ValueError("必须提供 inputs/input/prompt 任一字段")
|
35 |
+
return values
|
36 |
|
37 |
class EmbeddingResponse(BaseModel):
|
38 |
status: str
|
|
|
40 |
|
41 |
class EmbeddingService:
|
42 |
def __init__(self):
|
43 |
+
self._true_model_name = "jinaai/jina-embeddings-v3" # 硬编码模型名称
|
44 |
self.max_length = 512
|
45 |
self.device = torch.device("cpu")
|
46 |
self.model = None
|
|
|
60 |
try:
|
61 |
from transformers import AutoTokenizer, AutoModel
|
62 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
63 |
+
self._true_model_name,
|
64 |
trust_remote_code=True
|
65 |
)
|
66 |
self.model = AutoModel.from_pretrained(
|
67 |
+
self._true_model_name,
|
68 |
trust_remote_code=True
|
69 |
).to(self.device)
|
70 |
self.model.eval()
|
71 |
torch.set_grad_enabled(False)
|
72 |
+
self.logger.info(f"强制加载模型: {self._true_model_name}")
|
73 |
except Exception as e:
|
74 |
self.logger.error(f"模型初始化失败: {str(e)}")
|
75 |
raise
|
76 |
|
77 |
@lru_cache(maxsize=1000)
|
78 |
def get_embedding(self, text: str) -> List[float]:
|
|
|
79 |
with self.lock:
|
80 |
try:
|
81 |
inputs = self.tokenizer(
|
|
|
103 |
allow_methods=["*"],
|
104 |
allow_headers=["*"],
|
105 |
)
|
106 |
+
@app.post("/embed", response_model=EmbeddingResponse)
|
107 |
+
@app.post("/api/embeddings", response_model=EmbeddingResponse)
|
108 |
@app.post("/api/embed", response_model=EmbeddingResponse)
|
109 |
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
|
110 |
@app.post("/generate_embeddings", response_model=EmbeddingResponse)
|
|
|
114 |
@app.post("/hf/v1/chat/completions", response_model=EmbeddingResponse)
|
115 |
async def generate_embeddings(request: EmbeddingRequest):
|
116 |
try:
|
|
|
117 |
embedding = await asyncio.get_running_loop().run_in_executor(
|
118 |
None,
|
119 |
embedding_service.get_embedding,
|
120 |
+
request.inputs # 使用合并后的输入字段
|
121 |
)
|
122 |
return EmbeddingResponse(
|
123 |
status="success",
|
|
|
130 |
async def root():
|
131 |
return {
|
132 |
"status": "active",
|
133 |
+
"true_model": embedding_service._true_model_name,
|
134 |
"device": str(embedding_service.device)
|
135 |
}
|
136 |
|
|
|
152 |
inputs=gr.Textbox(lines=3, label="输入文本"),
|
153 |
outputs=gr.JSON(label="嵌入向量结果"),
|
154 |
title="Jina Embeddings V3",
|
155 |
+
description="强制使用jinaai/jina-embeddings-v3模型(无视请求中的model参数)",
|
156 |
+
examples=[[
|
157 |
+
"Represent this sentence for searching relevant passages: "
|
158 |
+
"The sky is blue because of Rayleigh scattering"
|
159 |
+
]]
|
160 |
)
|
161 |
|
162 |
@app.on_event("startup")
|
|
|
166 |
if __name__ == "__main__":
|
167 |
asyncio.run(embedding_service.initialize())
|
168 |
gr.mount_gradio_app(app, iface, path="/ui")
|
169 |
+
uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)
|