howard-hou commited on
Commit
256abb9
·
1 Parent(s): c5fa72e

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +2 -0
modeling.py CHANGED
@@ -59,6 +59,8 @@ class EmbeddingMixer(nn.Module):
59
  self.image_start_index = len(original_embedding)
60
 
61
  def set_image_embeddings(self, image_embeddings):
 
 
62
  end_index = self.image_start_index + image_embeddings.shape[0]
63
  self.embedding[self.image_start_index:end_index] = image_embeddings
64
 
 
59
  self.image_start_index = len(original_embedding)
60
 
61
  def set_image_embeddings(self, image_embeddings):
62
+ if len(image_embeddings.shape) == 3:
63
+ image_embeddings = image_embeddings.squeeze(0) # remove batch dim
64
  end_index = self.image_start_index + image_embeddings.shape[0]
65
  self.embedding[self.image_start_index:end_index] = image_embeddings
66