from transformers import BertTokenizer, BertModel
import torch


class TextEmbedder:
    """Computes mean-pooled BERT sentence embeddings for documents and queries."""

    def __init__(self):
        # bert-base-uncased: 12-layer BERT with 768-dimensional hidden states.
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.model = BertModel.from_pretrained('bert-base-uncased')

    def _mean_pooling(self, model_output, attention_mask):
        # Average the token embeddings, weighting by the attention mask so
        # padding tokens do not contribute to the sentence vector.
        token_embeddings = model_output.last_hidden_state
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def embed_text(self, examples):
        # Batched map function for a Hugging Face dataset: tokenizes the
        # "content" column and returns one mean-pooled embedding per example.
        inputs = self.tokenizer(
            examples["content"], padding=True, truncation=True, return_tensors="pt"
        )
        with torch.no_grad():
            model_output = self.model(**inputs)
        pooled_embeds = self._mean_pooling(model_output, inputs["attention_mask"])
        return {"embedding": pooled_embeds.cpu().numpy()}
    
    def generate_embeddings(self, dataset):
        # Adds an "embedding" column to the dataset, 128 examples per batch.
        return dataset.map(self.embed_text, batched=True, batch_size=128)
    
    def embed_query(self, query_text):
        # Embeds a single query string; returns the pooled vector as a
        # torch tensor of shape (1, hidden_size).
        query_inputs = self.tokenizer(
            query_text,
            padding=True,
            truncation=True,
            return_tensors="pt"
        )

        with torch.no_grad():
            query_model_output = self.model(**query_inputs)

        query_embedding = self._mean_pooling(query_model_output, query_inputs["attention_mask"])

        return query_embedding
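

# Minimal usage sketch, assuming documents live in a Hugging Face
# `datasets.Dataset` with a "content" text column; the sample corpus and the
# query below are illustrative additions, not part of the original module.
if __name__ == "__main__":
    from datasets import Dataset

    embedder = TextEmbedder()

    # Hypothetical corpus exposing the "content" field that embed_text expects.
    corpus = Dataset.from_dict(
        {"content": ["transformers make embedding easy",
                     "mean pooling averages token vectors"]}
    )
    corpus = embedder.generate_embeddings(corpus)
    print(corpus["embedding"][0][:5])  # first dimensions of the first document vector

    # Single-query embedding, returned as a (1, hidden_size) torch tensor.
    query_vec = embedder.embed_query("how does mean pooling work?")
    print(query_vec.shape)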