Files changed (1)
  1. app.py +77 -62
app.py CHANGED
@@ -8,6 +8,7 @@ from fastapi import FastAPI
 from pydantic import BaseModel
 from typing import List, Optional
 import torch
+import gc
 from transformers import AutoModel, AutoTokenizer
 
 app = FastAPI()
@@ -72,74 +73,88 @@ class DecodeResponse(BaseModel):
 # -----------------------------
 @app.post("/embed", response_model=EmbedResponse)
 def embed(req: EmbedRequest):
-    text = req.text
-
-    # -----------------------------
-    # Case 1: Query → mean pool across token embeddings
-    # -----------------------------
-    if (req.prompt_name or "").lower() == "query":
-        with torch.no_grad():
-            outputs = model.encode_text(
-                texts=[text],
-                task=req.task,
-                prompt_name="query",
-                return_multivector=True,  # always token-level
-                truncate_dim=req.truncate_dim,
-            )
-        pooled = outputs[0].mean(dim=0).cpu().tolist()
-        return {"embeddings": [pooled]}  # wrap in batch dimension
-
-    # -----------------------------
-    # Case 2: Passage → sliding window, token-level embeddings
-    # -----------------------------
-    enc = tokenizer(text, add_special_tokens=False, return_tensors="pt")
-    input_ids = enc["input_ids"].squeeze(0).to(device)
-    total_tokens = input_ids.size(0)
-
-    max_len = min(15_000, model.config.max_position_embeddings)  # ~32k
-    stride = 50
-    embeddings = []
-    position = 0
-
-    while position < total_tokens:
-        end = min(position + max_len, total_tokens)
-        window_ids = input_ids[position:end].unsqueeze(0).to(device)
-
-        with torch.no_grad():
-            outputs = model.encode_text(
-                texts=[tokenizer.decode(window_ids[0])],
-                task=req.task,
-                prompt_name="passage",
-                return_multivector=True,  # always token-level
-                truncate_dim=req.truncate_dim,
-            )
-
-        window_embeds = outputs[0].cpu()
-
-        if position > 0:
-            window_embeds = window_embeds[stride:]
-
-        embeddings.append(window_embeds)
-        position += max_len - stride
-
-    full_embeddings = torch.cat(embeddings, dim=0).tolist()
-    return {"embeddings": full_embeddings}
+    try:
+        text = req.text
+
+        # -----------------------------
+        # Case 1: Query → mean pool across token embeddings
+        # -----------------------------
+        if (req.prompt_name or "").lower() == "query":
+            with torch.inference_mode():
+                outputs = model.encode_text(
+                    texts=[text],
+                    task=req.task,
+                    prompt_name="query",
+                    return_multivector=True,  # always token-level
+                    truncate_dim=req.truncate_dim,
+                )
+            pooled = outputs[0].mean(dim=0).cpu().tolist()
+            return {"embeddings": [pooled]}  # wrap in batch dimension
+
+        # -----------------------------
+        # Case 2: Passage → sliding window, token-level embeddings
+        # -----------------------------
+        enc = tokenizer(text, add_special_tokens=False, return_tensors="pt")
+        input_ids = enc["input_ids"].squeeze(0).to(device)
+        total_tokens = input_ids.size(0)
+
+        max_len = min(15_000, model.config.max_position_embeddings)  # ~32k
+        stride = 50
+        embeddings = []
+        position = 0
+
+        while position < total_tokens:
+            end = min(position + max_len, total_tokens)
+            window_ids = input_ids[position:end].unsqueeze(0).to(device)
+
+            with torch.inference_mode():
+                outputs = model.encode_text(
+                    texts=[tokenizer.decode(window_ids[0])],
+                    task=req.task,
+                    prompt_name="passage",
+                    return_multivector=True,  # always token-level
+                    truncate_dim=req.truncate_dim,
+                )
+
+            window_embeds = outputs[0].cpu()
+
+            if position > 0:
+                window_embeds = window_embeds[stride:]
+
+            embeddings.append(window_embeds)
+            position += max_len - stride
+
+        full_embeddings = torch.cat(embeddings, dim=0).tolist()
+        return {"embeddings": full_embeddings}
+    finally:
+        # --- Cleanup CUDA memory ---
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
 
 # -----------------------------
 # Embedding Endpoint (image)
 # -----------------------------
 @app.post("/embed_image", response_model=EmbedImageResponse)
 def embed_image(req: EmbedImageRequest):
-    with torch.no_grad():
-        outputs = model.encode_image(
-            images=[req.image],
-            task=req.task,
-            return_multivector=req.return_multivector,
-            truncate_dim=req.truncate_dim,
-        )
-    pooled = outputs[0].mean(dim=0).cpu()
-    return {"embeddings": [pooled]}
-
+    try:
+        with torch.inference_mode():
+            outputs = model.encode_image(
+                images=[req.image],
+                task=req.task,
+                return_multivector=req.return_multivector,
+                truncate_dim=req.truncate_dim,
+            )
+        pooled = outputs[0].mean(dim=0).cpu()
+        return {"embeddings": [pooled]}
+
+    finally:
+        # --- Cleanup CUDA memory ---
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
 
 # -----------------------------
 # Tokenize Endpoint
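
A note on the passage branch: the stride bookkeeping is the easiest part of this change to get wrong, so below is a small self-contained sketch (not part of the patch) that mirrors the loop's window arithmetic and checks that, once every window after the first drops its leading `stride` token embeddings, the kept spans tile the token sequence exactly once, with no gaps and no duplicates. It assumes the decode/re-encode round trip preserves token counts, which the endpoint itself relies on.

def window_spans(total_tokens: int, max_len: int = 15_000, stride: int = 50):
    """Mirror of the endpoint's while-loop: the token span each window keeps."""
    spans, position = [], 0
    while position < total_tokens:
        end = min(position + max_len, total_tokens)
        # The first window keeps everything; later windows drop `stride` rows.
        start = position if position == 0 else min(position + stride, end)
        spans.append((start, end))
        position += max_len - stride
    return spans

for total in (10, 15_000, 15_001, 29_950, 100_000):
    spans = window_spans(total)
    assert sum(e - s for s, e in spans) == total            # no token lost or doubled
    assert all(a[1] == b[0] for a, b in zip(spans, spans[1:]))  # spans are contiguous
print("sliding-window spans tile the sequence exactly")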
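For manually verifying that both endpoints still return well-formed responses with the new try/finally cleanup in place, a smoke test along these lines can be run against a local instance. The base URL, task name, and truncate_dim value are placeholders; only the field names (text, task, prompt_name, truncate_dim) come from the request models used above.

import requests

BASE = "http://localhost:8000"  # placeholder; adjust to the actual deployment

# Query path: the server mean-pools token embeddings into a single vector.
r = requests.post(f"{BASE}/embed", json={
    "text": "what is late interaction retrieval?",
    "task": "retrieval",        # placeholder task name
    "prompt_name": "query",
    "truncate_dim": 128,        # placeholder dimension
})
r.raise_for_status()
print(len(r.json()["embeddings"]))  # expect 1: a pooled vector in a batch dim

# Passage path: token-level embeddings stitched across sliding windows.
r = requests.post(f"{BASE}/embed", json={
    "text": "a long passage ... " * 100,
    "task": "retrieval",
    "prompt_name": "passage",
    "truncate_dim": 128,
})
r.raise_for_status()
print(len(r.json()["embeddings"]))  # expect roughly one vector per token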