Update app.py
#7
by
Amlan99
- opened
app.py
CHANGED
@@ -8,6 +8,7 @@ from fastapi import FastAPI
|
|
8 |
from pydantic import BaseModel
|
9 |
from typing import List, Optional
|
10 |
import torch
|
|
|
11 |
from transformers import AutoModel, AutoTokenizer
|
12 |
|
13 |
app = FastAPI()
|
@@ -72,74 +73,88 @@ class DecodeResponse(BaseModel):
|
|
72 |
# -----------------------------
|
73 |
@app.post("/embed", response_model=EmbedResponse)
|
74 |
def embed(req: EmbedRequest):
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
# -----------------------------
|
129 |
# Embedding Endpoint (image)
|
130 |
# -----------------------------
|
131 |
@app.post("/embed_image", response_model=EmbedImageResponse)
|
132 |
def embed_image(req: EmbedImageRequest):
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
# -----------------------------
|
145 |
# Tokenize Endpoint
|
|
|
8 |
from pydantic import BaseModel
|
9 |
from typing import List, Optional
|
10 |
import torch
|
11 |
+
import gc
|
12 |
from transformers import AutoModel, AutoTokenizer
|
13 |
|
14 |
app = FastAPI()
|
|
|
73 |
# -----------------------------
|
74 |
@app.post("/embed", response_model=EmbedResponse)
|
75 |
def embed(req: EmbedRequest):
|
76 |
+
try:
|
77 |
+
text = req.text
|
78 |
+
|
79 |
+
# -----------------------------
|
80 |
+
# Case 1: Query β mean pool across token embeddings
|
81 |
+
# -----------------------------
|
82 |
+
if (req.prompt_name or "").lower() == "query":
|
83 |
+
with torch.inference_mode():
|
84 |
+
outputs = model.encode_text(
|
85 |
+
texts=[text],
|
86 |
+
task=req.task,
|
87 |
+
prompt_name="query",
|
88 |
+
return_multivector=True, # always token-level
|
89 |
+
truncate_dim=req.truncate_dim,
|
90 |
+
)
|
91 |
+
pooled = outputs[0].mean(dim=0).cpu().tolist()
|
92 |
+
return {"embeddings": [pooled]} # wrap in batch dimension
|
93 |
+
|
94 |
+
# -----------------------------
|
95 |
+
# Case 2: Passage β sliding window, token-level embeddings
|
96 |
+
# -----------------------------
|
97 |
+
enc = tokenizer(text, add_special_tokens=False, return_tensors="pt")
|
98 |
+
input_ids = enc["input_ids"].squeeze(0).to(device)
|
99 |
+
total_tokens = input_ids.size(0)
|
100 |
+
|
101 |
+
max_len = min(15_000, model.config.max_position_embeddings) # ~32k
|
102 |
+
stride = 50
|
103 |
+
embeddings = []
|
104 |
+
position = 0
|
105 |
+
|
106 |
+
while position < total_tokens:
|
107 |
+
end = min(position + max_len, total_tokens)
|
108 |
+
window_ids = input_ids[position:end].unsqueeze(0).to(device)
|
109 |
+
|
110 |
+
with torch.inference_mode():
|
111 |
+
outputs = model.encode_text(
|
112 |
+
texts=[tokenizer.decode(window_ids[0])],
|
113 |
+
task=req.task,
|
114 |
+
prompt_name="passage",
|
115 |
+
return_multivector=True, # always token-level
|
116 |
+
truncate_dim=req.truncate_dim,
|
117 |
+
)
|
118 |
+
|
119 |
+
window_embeds = outputs[0].cpu()
|
120 |
+
|
121 |
+
if position > 0:
|
122 |
+
window_embeds = window_embeds[stride:]
|
123 |
+
|
124 |
+
embeddings.append(window_embeds)
|
125 |
+
position += max_len - stride
|
126 |
+
|
127 |
+
full_embeddings = torch.cat(embeddings, dim=0).tolist()
|
128 |
+
return {"embeddings": full_embeddings}
|
129 |
+
finally:
|
130 |
+
# --- Cleanup CUDA memory ---
|
131 |
+
gc.collect()
|
132 |
+
if torch.cuda.is_available():
|
133 |
+
torch.cuda.empty_cache()
|
134 |
+
torch.cuda.ipc_collect()
|
135 |
|
136 |
# -----------------------------
|
137 |
# Embedding Endpoint (image)
|
138 |
# -----------------------------
|
139 |
@app.post("/embed_image", response_model=EmbedImageResponse)
|
140 |
def embed_image(req: EmbedImageRequest):
|
141 |
+
try:
|
142 |
+
with torch.inference_mode():
|
143 |
+
outputs = model.encode_image(
|
144 |
+
images=[req.image],
|
145 |
+
task=req.task,
|
146 |
+
return_multivector=req.return_multivector,
|
147 |
+
truncate_dim=req.truncate_dim,
|
148 |
+
)
|
149 |
+
pooled = outputs[0].mean(dim=0).cpu()
|
150 |
+
return {"embeddings": [pooled]}
|
151 |
+
|
152 |
+
finally:
|
153 |
+
# --- Cleanup CUDA memory ---
|
154 |
+
gc.collect()
|
155 |
+
if torch.cuda.is_available():
|
156 |
+
torch.cuda.empty_cache()
|
157 |
+
torch.cuda.ipc_collect()
|
158 |
|
159 |
# -----------------------------
|
160 |
# Tokenize Endpoint
|