fix-tuncate
#47
by
bwang0911
- opened
- modeling_xlm_roberta.py +19 -14
modeling_xlm_roberta.py
CHANGED
@@ -600,7 +600,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
600 |
|
601 |
truncate_dim = truncate_dim or self.config.truncate_dim
|
602 |
if truncate_dim:
|
603 |
-
all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
|
604 |
|
605 |
if convert_to_tensor:
|
606 |
all_embeddings = torch.stack(all_embeddings)
|
@@ -613,19 +613,24 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
613 |
self.train(is_training)
|
614 |
return all_embeddings
|
615 |
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
-
|
622 |
-
|
623 |
-
|
624 |
-
|
625 |
-
|
626 |
-
|
627 |
-
|
628 |
-
|
|
|
|
|
|
|
|
|
|
|
629 |
|
630 |
def mean_pooling(
|
631 |
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
|
|
600 |
|
601 |
truncate_dim = truncate_dim or self.config.truncate_dim
|
602 |
if truncate_dim:
|
603 |
+
all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim, normalize_embeddings)
|
604 |
|
605 |
if convert_to_tensor:
|
606 |
all_embeddings = torch.stack(all_embeddings)
|
|
|
613 |
self.train(is_training)
|
614 |
return all_embeddings
|
615 |
|
616 |
+
def truncate_embeddings(self, embeddings, truncate_dim, normalize_embeddings):
|
617 |
+
if not self.config.matryoshka_dimensions:
|
618 |
+
logger.warning(
|
619 |
+
"Matryoshka embeddings are not supported, so dimension truncation will not be performed."
|
620 |
+
)
|
621 |
+
return embeddings
|
622 |
+
elif truncate_dim in self.config.matryoshka_dimensions:
|
623 |
+
truncated_embeddings = [tensor[:truncate_dim] for tensor in embeddings]
|
624 |
+
if normalize_embeddings:
|
625 |
+
truncated_embeddings = [
|
626 |
+
torch.nn.functional.normalize(tensor, p=2, dim=0) for tensor in truncated_embeddings
|
627 |
+
]
|
628 |
+
return truncated_embeddings
|
629 |
+
else:
|
630 |
+
raise ValueError(
|
631 |
+
f"The provided `truncate_dim` value of {truncate_dim} is not supported. "
|
632 |
+
f"Supported dimensions are {self.config.matryoshka_dimensions}."
|
633 |
+
)
|
634 |
|
635 |
def mean_pooling(
|
636 |
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|