Files changed (1) hide show
  1. modeling_xlm_roberta.py +19 -14
modeling_xlm_roberta.py CHANGED
@@ -600,7 +600,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
600
 
601
  truncate_dim = truncate_dim or self.config.truncate_dim
602
  if truncate_dim:
603
- all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
604
 
605
  if convert_to_tensor:
606
  all_embeddings = torch.stack(all_embeddings)
@@ -613,19 +613,24 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
613
  self.train(is_training)
614
  return all_embeddings
615
 
616
def truncate_embeddings(self, embeddings, truncate_dim):
    """Cut every embedding down to its first ``truncate_dim`` entries.

    Args:
        embeddings: Iterable of sliceable embeddings (presumably 1-D
            tensors, one per input -- confirm against the caller).
        truncate_dim: Number of leading entries to keep. Must be one of
            ``self.config.matryoshka_dimensions``.

    Returns:
        A new list of truncated embeddings, or ``embeddings`` itself,
        untouched, when the config declares no Matryoshka dimensions.

    Raises:
        ValueError: If ``truncate_dim`` is not among the configured
            Matryoshka dimensions.
    """
    supported_dims = self.config.matryoshka_dimensions

    # No Matryoshka dimensions configured: warn and hand the input back.
    if not supported_dims:
        logger.warning(
            "Matryoshka embeddings are not supported, so dimension truncation will not be performed."
        )
        return embeddings

    # Reject sizes that are not explicitly listed in the config.
    if truncate_dim not in supported_dims:
        raise ValueError(
            f"The provided `truncate_dim` value of {truncate_dim} is not supported. "
            f"Supported dimensions are {self.config.matryoshka_dimensions}."
        )

    return [vector[:truncate_dim] for vector in embeddings]
 
 
 
 
 
629
 
630
  def mean_pooling(
631
  self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
 
600
 
601
  truncate_dim = truncate_dim or self.config.truncate_dim
602
  if truncate_dim:
603
+ all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim, normalize_embeddings)
604
 
605
  if convert_to_tensor:
606
  all_embeddings = torch.stack(all_embeddings)
 
613
  self.train(is_training)
614
  return all_embeddings
615
 
616
def truncate_embeddings(self, embeddings, truncate_dim, normalize_embeddings=False):
    """Truncate each embedding to ``truncate_dim`` entries, optionally re-normalizing.

    Args:
        embeddings: Iterable of embeddings. Slicing and normalization both
            act along dim 0, so each item is assumed to be a 1-D tensor
            -- TODO confirm against the caller.
        truncate_dim: Number of leading entries to keep. Must be one of
            ``self.config.matryoshka_dimensions``.
        normalize_embeddings: When true, L2-normalize every truncated
            tensor. Defaults to ``False`` so that two-argument calls
            written against the earlier signature keep working and keep
            their previous (non-renormalizing) behaviour.

    Returns:
        A new list of truncated (and possibly normalized) tensors, or
        ``embeddings`` itself, untouched, when the config declares no
        Matryoshka dimensions.

    Raises:
        ValueError: If ``truncate_dim`` is not among the configured
            Matryoshka dimensions.
    """
    supported_dims = self.config.matryoshka_dimensions

    # No Matryoshka dimensions configured: warn and hand the input back.
    if not supported_dims:
        logger.warning(
            "Matryoshka embeddings are not supported, so dimension truncation will not be performed."
        )
        return embeddings

    # Reject sizes that are not explicitly listed in the config.
    if truncate_dim not in supported_dims:
        raise ValueError(
            f"The provided `truncate_dim` value of {truncate_dim} is not supported. "
            f"Supported dimensions are {supported_dims}."
        )

    truncated_embeddings = [tensor[:truncate_dim] for tensor in embeddings]
    if normalize_embeddings:
        # Dropping trailing entries changes a vector's length, so a prefix
        # of a unit-norm vector is generally no longer unit-norm; rescale
        # after truncation.
        truncated_embeddings = [
            torch.nn.functional.normalize(tensor, p=2, dim=0)
            for tensor in truncated_embeddings
        ]
    return truncated_embeddings
634
 
635
  def mean_pooling(
636
  self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor