Update README.md
Browse files
README.md
CHANGED
@@ -179,7 +179,7 @@ outputs = model(**batch_dict)
|
|
179 |
embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
|
180 |
|
181 |
embeddings = F.normalize(embeddings, p=2, dim=1)
|
182 |
-
similarity_matrix = (embeddings[:len(queries)] @ embeddings[len(queries):].T)
|
183 |
print(similarity_matrix.shape)
|
184 |
# (3, 6)
|
185 |
print(np.round(similarity_matrix, 2))
|
|
|
179 |
embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
|
180 |
|
181 |
embeddings = F.normalize(embeddings, p=2, dim=1)
|
182 |
+
similarity_matrix = (embeddings[:len(queries)] @ embeddings[len(queries):].T).detach().numpy()
|
183 |
print(similarity_matrix.shape)
|
184 |
# (3, 6)
|
185 |
print(np.round(similarity_matrix, 2))
|