jgrosjean committed on
Commit
1e41992
1 Parent(s): 786feed

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +7 -3
README.md CHANGED
@@ -50,7 +50,7 @@ def generate_sentence_embedding(sentence, language):
50
  model.set_default_language("it_CH")
51
  if "rm" in language:
52
  model.set_default_language("rm_CH")
53
-
54
  # Tokenize input sentence
55
  inputs = tokenizer(sentence, padding=True, truncation=True, return_tensors="pt", max_length=512)
56
 
@@ -58,8 +58,12 @@ def generate_sentence_embedding(sentence, language):
58
  with torch.no_grad():
59
  outputs = model(**inputs)
60
 
61
- # Extract average sentence embeddings from the last hidden layer
62
- embedding = outputs.last_hidden_state.mean(dim=1)
 
 
 
 
63
 
64
  return embedding
65
 
 
50
  model.set_default_language("it_CH")
51
  if "rm" in language:
52
  model.set_default_language("rm_CH")
53
+
54
  # Tokenize input sentence
55
  inputs = tokenizer(sentence, padding=True, truncation=True, return_tensors="pt", max_length=512)
56
 
 
58
  with torch.no_grad():
59
  outputs = model(**inputs)
60
 
61
+ # Extract sentence embeddings via mean pooling
62
+ token_embeddings = outputs.last_hidden_state
63
+ attention_mask = inputs['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float()
64
+ sum_embeddings = torch.sum(token_embeddings * attention_mask, 1)
65
+ sum_mask = torch.clamp(attention_mask.sum(1), min=1e-9)
66
+ embedding = sum_embeddings / sum_mask
67
 
68
  return embedding
69