Files changed (1)
  1. modeling_cased.py +8 -3
modeling_cased.py CHANGED
@@ -212,6 +212,8 @@ class CaSEDModel(PreTrainedModel):
 
         vocabularies, samples_p = [], []
         for image_z in images_z:
+            image_z = image_z.unsqueeze(0)
+
             # generate a single text embedding from the unfiltered vocabulary
             vocabulary = self.query_index(image_z)
             text = self.processor(text=vocabulary, return_tensors="pt", padding=True)
@@ -219,6 +221,9 @@ class CaSEDModel(PreTrainedModel):
             text["attention_mask"] = text["attention_mask"][:, :77].to(self.device)
             text_z = self.language_encoder(**text)[1]
             text_z = self.language_proj(text_z)
+            text_z = text_z / text_z.norm(dim=-1, keepdim=True)
+            text_z = text_z.mean(dim=0).unsqueeze(0)
+            text_z = text_z / text_z.norm(dim=-1, keepdim=True)
 
             # filter the vocabulary, embed it, and get its mean embedding
             vocabulary = self.vocabulary_transforms(vocabulary) or ["object"]
@@ -231,8 +236,8 @@ class CaSEDModel(PreTrainedModel):
             # get the image and text predictions
             image_z = image_z / image_z.norm(dim=-1, keepdim=True)
             text_z = text_z / text_z.norm(dim=-1, keepdim=True)
-            image_p = (torch.matmul(image_z, vocabulary_z.T) * self.logit_scale).softmax(dim=-1)
-            text_p = (torch.matmul(text_z, vocabulary_z.T) * self.logit_scale).softmax(dim=-1)
+            image_p = (self.logit_scale * image_z @ vocabulary_z.T).softmax(dim=-1)
+            text_p = (self.logit_scale * text_z @ vocabulary_z.T).softmax(dim=-1)
 
             # average the image and text predictions
             alpha = alpha or self.hparams["alpha"]
@@ -244,7 +249,7 @@ class CaSEDModel(PreTrainedModel):
 
         # get the scores
         samples_p = torch.stack(samples_p, dim=0)
-        scores = sample_p.cpu().tolist()
+        scores = sample_p.cpu()
 
         # define the results
         results = {"vocabularies": vocabularies, "scores": scores}