Abdul-Ib committed on
Commit
4afd3d6
1 Parent(s): 3d42900

Upload sentenceTranformer.py

Browse files
Files changed (1) hide show
  1. sentenceTranformer.py +31 -0
sentenceTranformer.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Pipeline
2
+ import torch.nn.functional as F
3
+ import torch
4
+
5
+ # copied from the model card
6
def mean_pooling(model_output, attention_mask):
    """Return the attention-masked mean of the token embeddings.

    Args:
        model_output: Indexable model result whose first element holds the
            token embeddings (presumably shaped (batch, seq, hidden) --
            TODO confirm against the model in use).
        attention_mask: Mask with one entry per token; it is expanded along a
            new trailing axis to the embeddings' size and cast to float.

    Returns:
        The embeddings summed over dimension 1 with the mask applied as a
        weight, divided by the per-row sum of the mask.
    """
    hidden_states = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    masked_total = torch.sum(hidden_states * mask, 1)
    # Clamp the denominator so a fully masked row does not divide by zero.
    token_counts = torch.clamp(mask.sum(1), min=1e-9)
    return masked_total / token_counts
10
+
11
+
12
class SentenceEmbeddingPipeline(Pipeline):
    """Pipeline that turns text into mean-pooled, L2-normalized embeddings.

    The input is tokenized, passed through the model, mean-pooled with
    ``mean_pooling`` using the attention mask, and then normalized to unit
    L2 norm along dimension 1.
    """

    def _sanitize_parameters(self, **kwargs):
        """Return three empty parameter dicts; keyword arguments are ignored.

        NOTE(review): presumably these correspond to the preprocess, forward
        and postprocess kwargs expected by the base ``Pipeline`` -- confirm
        against the transformers documentation.
        """
        # This pipeline has no hyperparameters to sanitize.
        return {}, {}, {}

    def preprocess(self, inputs):
        """Tokenize ``inputs`` with padding and truncation as PyTorch tensors."""
        return self.tokenizer(
            inputs,
            padding=True,
            truncation=True,
            return_tensors='pt',
        )

    def _forward(self, model_inputs):
        """Run the model and keep the attention mask for the pooling step."""
        result = self.model(**model_inputs)
        mask = model_inputs["attention_mask"]
        return {"outputs": result, "attention_mask": mask}

    def postprocess(self, model_outputs):
        """Mean-pool the model output and return it L2-normalized."""
        pooled = mean_pooling(
            model_outputs["outputs"],
            model_outputs['attention_mask'],
        )
        # Scale every pooled vector to unit L2 norm.
        return F.normalize(pooled, p=2, dim=1)