Files changed (1) hide show
  1. app.py +53 -29
app.py CHANGED
@@ -1,9 +1,9 @@
1
  # app.py (FastAPI server to host the Jina Embedding model)
2
- # Must be set before importing Hugging Face libs
3
  import os
4
  os.environ["HF_HOME"] = "/tmp/huggingface"
5
  os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
6
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
 
7
  from fastapi import FastAPI
8
  from pydantic import BaseModel
9
  from typing import List, Optional
@@ -30,14 +30,25 @@ model.eval()
30
  # -----------------------------
31
  class EmbedRequest(BaseModel):
32
  text: str
33
- task: str = "retrieval" # "retrieval", "text-matching", "code", etc.
34
  prompt_name: Optional[str] = None
35
- return_token_embeddings: bool = True # False β†’ for queries (pooled embedding)
 
36
 
37
 
38
  class EmbedResponse(BaseModel):
39
- embeddings: List[List[float]] # (num_tokens, hidden_dim) if token-level
40
- # (1, hidden_dim) if pooled query
 
 
 
 
 
 
 
 
 
 
41
 
42
 
43
  class TokenizeRequest(BaseModel):
@@ -57,34 +68,33 @@ class DecodeResponse(BaseModel):
57
 
58
 
59
  # -----------------------------
60
- # Embedding Endpoint
61
  # -----------------------------
62
  @app.post("/embed", response_model=EmbedResponse)
63
  def embed(req: EmbedRequest):
64
  text = req.text
65
 
66
- # -----------------------------
67
- # Case 1: Query β†’ directly pooled embedding
68
- # -----------------------------
69
  if not req.return_token_embeddings:
70
  with torch.no_grad():
71
- emb = model.encode_text(
72
  texts=[text],
73
  task=req.task,
74
  prompt_name=req.prompt_name or "query",
75
- return_multivector=False
 
76
  )
77
- return {"embeddings": emb} # shape: (1, hidden_dim)
 
 
78
 
79
- # -----------------------------
80
- # Case 2: Long passages β†’ sliding window token embeddings
81
- # -----------------------------
82
  enc = tokenizer(text, add_special_tokens=False, return_tensors="pt")
83
- input_ids = enc["input_ids"].squeeze(0).to(device) # (total_tokens,)
84
  total_tokens = input_ids.size(0)
85
 
86
- max_len = model.config.max_position_embeddings # e.g., 32k for v4
87
- stride = 50 # overlap for sliding window
88
  embeddings = []
89
  position = 0
90
 
@@ -94,27 +104,41 @@ def embed(req: EmbedRequest):
94
 
95
  with torch.no_grad():
96
  outputs = model.encode_text(
97
- texts=[tokenizer.decode(window_ids[0])],
98
- task=req.task,
99
- prompt_name=req.prompt_name or "passage",
100
- return_multivector=True,
101
- )
 
102
 
103
- window_embeds = outputs[0].cpu() # (window_len, hidden_dim)
104
 
105
- # Drop overlapping tokens except in first window
106
  if position > 0:
107
  window_embeds = window_embeds[stride:]
108
 
109
  embeddings.append(window_embeds)
110
-
111
- # Advance window
112
  position += max_len - stride
113
 
114
- full_embeddings = torch.cat(embeddings, dim=0) # (total_tokens, hidden_dim)
115
  return {"embeddings": full_embeddings}
116
 
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  # -----------------------------
119
  # Tokenize Endpoint
120
  # -----------------------------
@@ -130,4 +154,4 @@ def tokenize(req: TokenizeRequest):
130
  @app.post("/decode", response_model=DecodeResponse)
131
  def decode(req: DecodeRequest):
132
  decoded = tokenizer.decode(req.input_ids)
133
- return {"text": decoded}
 
# app.py (FastAPI server to host the Jina Embedding model)

# Cache locations must be set BEFORE importing Hugging Face libraries:
# they read these environment variables at import time.
import os
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"

from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Optional
 
# -----------------------------
class EmbedRequest(BaseModel):
    """Request body for the /embed (text) endpoint."""
    text: str
    # Jina task adapter: "retrieval", "text-matching", "code", etc.
    task: str = "retrieval"
    # Optional prompt name; the endpoint defaults it to "query" or "passage".
    prompt_name: Optional[str] = None
    # False -> pooled single embedding (queries);
    # True  -> token-level embeddings (passages).
    return_token_embeddings: bool = True
    # Matryoshka truncation dimension; None keeps the full hidden size.
    truncate_dim: Optional[int] = None
37
 
class EmbedResponse(BaseModel):
    """Response for /embed.

    embeddings is (num_tokens, hidden_dim) for token-level output,
    or (1, hidden_dim) for a pooled query embedding.
    """
    embeddings: List[List[float]]
41
+
42
+
class EmbedImageRequest(BaseModel):
    """Request body for the /embed_image endpoint."""
    # Image payload passed straight to model.encode_image — presumably a URL
    # or base64 string; confirm against the model's encode_image API.
    image: str
    task: str = "retrieval"
    # True -> request multivector output from the model before pooling.
    return_multivector: bool = True
    # Matryoshka truncation dimension; None keeps the full hidden size.
    truncate_dim: Optional[int] = None
48
+
49
+
class EmbedImageResponse(BaseModel):
    """Response for /embed_image: a single pooled vector, shape (1, hidden_dim)."""
    embeddings: List[List[float]]
52
 
53
 
54
  class TokenizeRequest(BaseModel):
 
68
 
69
 
70
  # -----------------------------
71
+ # Embedding Endpoint (text)
72
  # -----------------------------
73
  @app.post("/embed", response_model=EmbedResponse)
74
  def embed(req: EmbedRequest):
75
  text = req.text
76
 
77
+ # Case 1: Query β†’ pooled mean of multivectors
 
 
78
  if not req.return_token_embeddings:
79
  with torch.no_grad():
80
+ outputs = model.encode_text(
81
  texts=[text],
82
  task=req.task,
83
  prompt_name=req.prompt_name or "query",
84
+ return_multivector=True,
85
+ truncate_dim=req.truncate_dim,
86
  )
87
+ # outputs[0] = (num_vectors, hidden_dim)
88
+ pooled = outputs[0].mean(dim=0).cpu().tolist()
89
+ return {"embeddings": [pooled]}
90
 
91
+ # Case 2: Passage β†’ sliding window, token-level embeddings
 
 
92
  enc = tokenizer(text, add_special_tokens=False, return_tensors="pt")
93
+ input_ids = enc["input_ids"].squeeze(0).to(device)
94
  total_tokens = input_ids.size(0)
95
 
96
+ max_len = model.config.max_position_embeddings # ~32k
97
+ stride = 50
98
  embeddings = []
99
  position = 0
100
 
 
104
 
105
  with torch.no_grad():
106
  outputs = model.encode_text(
107
+ texts=[tokenizer.decode(window_ids[0])],
108
+ task=req.task,
109
+ prompt_name=req.prompt_name or "passage",
110
+ return_multivector=True,
111
+ truncate_dim=req.truncate_dim,
112
+ )
113
 
114
+ window_embeds = outputs[0].cpu()
115
 
 
116
  if position > 0:
117
  window_embeds = window_embeds[stride:]
118
 
119
  embeddings.append(window_embeds)
 
 
120
  position += max_len - stride
121
 
122
+ full_embeddings = torch.cat(embeddings, dim=0)
123
  return {"embeddings": full_embeddings}
124
 
125
 
126
# -----------------------------
# Embedding Endpoint (image)
# -----------------------------
@app.post("/embed_image", response_model=EmbedImageResponse)
def embed_image(req: EmbedImageRequest):
    """Embed one image and return a single pooled embedding vector.

    Fixes vs. previous revision:
    - convert to a plain Python list with .tolist(): a torch.Tensor is not
      JSON-serializable and fails EmbedImageResponse validation;
    - mean-pool only when the model actually returned a multivector. With
      return_multivector=False the output is already one vector, and an
      unconditional mean(dim=0) would collapse it over the hidden dimension.
    """
    with torch.no_grad():
        outputs = model.encode_image(
            images=[req.image],
            task=req.task,
            return_multivector=req.return_multivector,
            truncate_dim=req.truncate_dim,
        )

    emb = outputs[0]
    # Multivector output is 2-D (num_vectors, hidden_dim) -> pool to one vector.
    if emb.dim() > 1:
        emb = emb.mean(dim=0)
    return {"embeddings": [emb.cpu().tolist()]}
140
+
141
+
142
  # -----------------------------
143
  # Tokenize Endpoint
144
  # -----------------------------
 
@app.post("/decode", response_model=DecodeResponse)
def decode(req: DecodeRequest):
    """Decode a sequence of token ids back into text with the shared tokenizer."""
    return {"text": tokenizer.decode(req.input_ids)}