Spaces:
Runtime error
Runtime error
Commit
·
256abb9
1
Parent(s):
c5fa72e
Update modeling.py
Browse files- modeling.py +2 -0
modeling.py
CHANGED
@@ -59,6 +59,8 @@ class EmbeddingMixer(nn.Module):
|
|
59 |
self.image_start_index = len(original_embedding)
|
60 |
|
61 |
def set_image_embeddings(self, image_embeddings):
|
|
|
|
|
62 |
end_index = self.image_start_index + image_embeddings.shape[0]
|
63 |
self.embedding[self.image_start_index:end_index] = image_embeddings
|
64 |
|
|
|
59 |
self.image_start_index = len(original_embedding)
|
60 |
|
61 |
def set_image_embeddings(self, image_embeddings):
|
62 |
+
if len(image_embeddings.shape) == 3:
|
63 |
+
image_embeddings = image_embeddings.squeeze(0) # remove batch dim
|
64 |
end_index = self.image_start_index + image_embeddings.shape[0]
|
65 |
self.embedding[self.image_start_index:end_index] = image_embeddings
|
66 |
|