Update app.py #2
by Amlan99 · opened

app.py CHANGED
@@ -1,9 +1,9 @@
 # app.py (FastAPI server to host the Jina Embedding model)
-# Must be set before importing Hugging Face libs
 import os
 os.environ["HF_HOME"] = "/tmp/huggingface"
 os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
 os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
+
 from fastapi import FastAPI
 from pydantic import BaseModel
 from typing import List, Optional
@@ -30,14 +30,25 @@ model.eval()
 # -----------------------------
 class EmbedRequest(BaseModel):
     text: str
     task: str = "retrieval"
     prompt_name: Optional[str] = None
     return_token_embeddings: bool = True
+    truncate_dim: Optional[int] = None  # for matryoshka embeddings


 class EmbedResponse(BaseModel):
     embeddings: List[List[float]]

+
+class EmbedImageRequest(BaseModel):
+    image: str
+    task: str = "retrieval"
+    return_multivector: bool = True
+    truncate_dim: Optional[int] = None
+
+
+class EmbedImageResponse(BaseModel):
+    embeddings: List[List[float]]


 class TokenizeRequest(BaseModel):
@@ -57,34 +68,33 @@ class DecodeResponse:


 # -----------------------------
-# Embedding Endpoint
+# Embedding Endpoint (text)
 # -----------------------------
 @app.post("/embed", response_model=EmbedResponse)
 def embed(req: EmbedRequest):
     text = req.text

-    #
-    # Case 1: Query → directly pooled embedding
-    # -----------------------------
+    # Case 1: Query → pooled mean of multivectors
     if not req.return_token_embeddings:
         with torch.no_grad():
+            outputs = model.encode_text(
                 texts=[text],
                 task=req.task,
                 prompt_name=req.prompt_name or "query",
-                return_multivector=
+                return_multivector=True,
+                truncate_dim=req.truncate_dim,
             )
+            # outputs[0] = (num_vectors, hidden_dim)
+            pooled = outputs[0].mean(dim=0).cpu().tolist()
+            return {"embeddings": [pooled]}

-    #
-    # Case 2: Long passages → sliding window token embeddings
-    # -----------------------------
+    # Case 2: Passage → sliding window, token-level embeddings
     enc = tokenizer(text, add_special_tokens=False, return_tensors="pt")
     input_ids = enc["input_ids"].squeeze(0).to(device)
     total_tokens = input_ids.size(0)

-    max_len = model.config.max_position_embeddings  #
+    max_len = model.config.max_position_embeddings  # ~32k
     stride = 50
     embeddings = []
     position = 0

@@ -94,27 +104,41 @@ def embed(req: EmbedRequest):

         with torch.no_grad():
             outputs = model.encode_text(
+                texts=[tokenizer.decode(window_ids[0])],
+                task=req.task,
+                prompt_name=req.prompt_name or "passage",
+                return_multivector=True,
+                truncate_dim=req.truncate_dim,
+            )

             window_embeds = outputs[0].cpu()

-        # Drop overlapping tokens except in first window
         if position > 0:
             window_embeds = window_embeds[stride:]

         embeddings.append(window_embeds)
-
-        # Advance window
         position += max_len - stride

     full_embeddings = torch.cat(embeddings, dim=0)
     return {"embeddings": full_embeddings.tolist()}


+# -----------------------------
+# Embedding Endpoint (image)
+# -----------------------------
+@app.post("/embed_image", response_model=EmbedImageResponse)
+def embed_image(req: EmbedImageRequest):
+    with torch.no_grad():
+        outputs = model.encode_image(
+            images=[req.image],
+            task=req.task,
+            return_multivector=req.return_multivector,
+            truncate_dim=req.truncate_dim,
+        )
+    pooled = outputs[0].mean(dim=0).cpu().tolist()
+    return {"embeddings": [pooled]}
+
+
 # -----------------------------
 # Tokenize Endpoint
 # -----------------------------
@@ -130,4 +154,4 @@ def tokenize(req: TokenizeRequest):
 @app.post("/decode", response_model=DecodeResponse)
 def decode(req: DecodeRequest):
     decoded = tokenizer.decode(req.input_ids)
     return {"text": decoded}
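Note on the elided window loop: the hunks above use `window_ids`, but the loop header that builds it sits between hunks (old lines 91–93 / new lines 101–103) and never appears in the diff. A minimal sketch of what that loop plausibly looks like, reusing only names that already appear in the diff (`input_ids`, `total_tokens`, `max_len`, `stride`, `position`) and assuming a batch dimension of 1, which is why the body indexes `window_ids[0]`:

    # Hypothetical sketch of the loop between the hunks; not part of this PR.
    while position < total_tokens:
        # Take up to max_len tokens from the current position; consecutive
        # windows overlap by `stride` tokens, which is exactly the overlap
        # that `window_embeds = window_embeds[stride:]` trims afterwards.
        window_ids = input_ids[position : position + max_len].unsqueeze(0)
        # ... the encode / trim / append body shown in the diff runs here ...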
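For reference, exercising the endpoints this diff touches could look like the snippet below. The host and port, and whether `image` carries a URL or a base64 string (that depends on how `model.encode_image` accepts images), are assumptions rather than anything this PR specifies:

    # Hypothetical client; assumes the server runs at http://localhost:8000.
    import requests

    BASE = "http://localhost:8000"

    # Query → one pooled vector (Case 1 of /embed).
    r = requests.post(f"{BASE}/embed", json={
        "text": "what does truncate_dim change?",
        "return_token_embeddings": False,
        "truncate_dim": 128,  # optional matryoshka truncation
    })
    query_vec = r.json()["embeddings"][0]

    # Long passage → one vector per token via the sliding window (Case 2).
    r = requests.post(f"{BASE}/embed", json={"text": "a very long passage ..."})
    token_vecs = r.json()["embeddings"]

    # Image → one pooled vector from the new /embed_image endpoint.
    r = requests.post(f"{BASE}/embed_image",
                      json={"image": "https://example.com/cat.jpg"})
    image_vec = r.json()["embeddings"][0]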