radames HF staff commited on
Commit
021ceab
1 Parent(s): 9663a4b

update comment, disable warning msg

Browse files
Files changed (1) hide show
  1. embeddings_encoder.py +5 -3
embeddings_encoder.py CHANGED
@@ -1,8 +1,9 @@
1
- # from https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1
2
  from transformers import AutoTokenizer, AutoModel
3
  import torch
4
  import torch.nn.functional as F
5
-
 
6
 
7
  class EmbeddingsEncoder:
8
  def __init__(self):
@@ -17,7 +18,8 @@ class EmbeddingsEncoder:
17
  def mean_pooling(self, model_output, attention_mask):
18
  # First element of model_output contains all token embeddings
19
  token_embeddings = model_output.last_hidden_state
20
- input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
 
21
  return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
22
 
23
  # Encode text
1
+ # from https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
2
  from transformers import AutoTokenizer, AutoModel
3
  import torch
4
  import torch.nn.functional as F
5
+ import os
6
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
7
 
8
  class EmbeddingsEncoder:
9
  def __init__(self):
18
  def mean_pooling(self, model_output, attention_mask):
19
  # First element of model_output contains all token embeddings
20
  token_embeddings = model_output.last_hidden_state
21
+ input_mask_expanded = attention_mask.unsqueeze(
22
+ -1).expand(token_embeddings.size()).float()
23
  return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
24
 
25
  # Encode text