Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -80,7 +80,7 @@ def embed(req: EmbedRequest):
80
  # Case 1: Query → mean pool across token embeddings
81
  # -----------------------------
82
  if (req.prompt_name or "").lower() == "query":
83
- with torch.inference_mode():
84
  outputs = model.encode_text(
85
  texts=[text],
86
  task=req.task,
@@ -107,7 +107,7 @@ def embed(req: EmbedRequest):
107
  end = min(position + max_len, total_tokens)
108
  window_ids = input_ids[position:end].unsqueeze(0).to(device)
109
 
110
- with torch.inference_mode():
111
  outputs = model.encode_text(
112
  texts=[tokenizer.decode(window_ids[0])],
113
  task=req.task,
@@ -139,7 +139,7 @@ def embed(req: EmbedRequest):
139
  @app.post("/embed_image", response_model=EmbedImageResponse)
140
  def embed_image(req: EmbedImageRequest):
141
  try:
142
- with torch.inference_mode():
143
  outputs = model.encode_image(
144
  images=[req.image],
145
  task=req.task,
 
80
  # Case 1: Query → mean pool across token embeddings
81
  # -----------------------------
82
  if (req.prompt_name or "").lower() == "query":
83
+ with torch.no_grad():
84
  outputs = model.encode_text(
85
  texts=[text],
86
  task=req.task,
 
107
  end = min(position + max_len, total_tokens)
108
  window_ids = input_ids[position:end].unsqueeze(0).to(device)
109
 
110
+ with torch.no_grad():
111
  outputs = model.encode_text(
112
  texts=[tokenizer.decode(window_ids[0])],
113
  task=req.task,
 
139
  @app.post("/embed_image", response_model=EmbedImageResponse)
140
  def embed_image(req: EmbedImageRequest):
141
  try:
142
+ with torch.no_grad():
143
  outputs = model.encode_image(
144
  images=[req.image],
145
  task=req.task,