Islam YAHIAOUI
Update space
1e4288a
raw
history blame contribute delete
No virus
1.59 kB
from transformers import BertTokenizer, BertModel
import torch
class TextEmbedder:
    """Turn text into fixed-size vectors by mean-pooling BERT token embeddings.

    The tokenizer and model are loaded once in ``__init__`` and reused by
    every call. ``embed_text`` is shaped for use as a batched ``map``
    callback (see ``generate_embeddings``); ``embed_query`` embeds ad-hoc
    text and returns the raw tensor.
    """

    def __init__(self, model_name='bert-base-uncased'):
        """Load the tokenizer and encoder weights.

        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained``
                for both the tokenizer and the model. Defaults to
                ``'bert-base-uncased'``.
        """
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name)

    def _mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over the positions the mask marks as real.

        Args:
            model_output: Object exposing ``last_hidden_state``; assumed to be
                ``(batch, seq_len, hidden)`` as BertModel produces it.
            attention_mask: Tensor with one entry per token position
                (non-zero for real tokens, zero for padding).

        Returns:
            A float tensor with the sequence dimension (dim 1) reduced away.
        """
        token_embeddings = model_output.last_hidden_state
        # Trailing singleton axis lets the mask broadcast across the hidden dim.
        mask = attention_mask.unsqueeze(-1).float()
        summed = (token_embeddings * mask).sum(dim=1)
        # Clamp so a row whose mask is all zeros yields zeros, not a
        # division by zero.
        token_counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / token_counts

    def _encode(self, texts):
        """Tokenize ``texts``, run the model without gradients, and mean-pool.

        Shared by ``embed_text`` and ``embed_query`` so both paths use the
        same tokenizer settings and pooling.

        Args:
            texts: A string or a list of strings accepted by the tokenizer.

        Returns:
            The pooled embedding tensor, one row per input text.
        """
        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            model_output = self.model(**encoded)
        return self._mean_pooling(model_output, encoded["attention_mask"])

    def embed_text(self, examples):
        """Embed the ``"content"`` column of a batch.

        Args:
            examples: Mapping with a ``"content"`` entry holding the texts to
                embed (presumably a list of strings when invoked through a
                batched ``map`` - confirm against the dataset schema).

        Returns:
            ``{"embedding": array}`` where ``array`` is the pooled embeddings
            moved to CPU and converted to NumPy.
        """
        pooled_embeds = self._encode(examples["content"])
        return {"embedding": pooled_embeds.cpu().numpy()}

    def generate_embeddings(self, dataset):
        """Apply ``embed_text`` over ``dataset`` in batches of 128.

        Args:
            dataset: Object exposing ``map(fn, batched=..., batch_size=...)``;
                presumably a Hugging Face ``datasets.Dataset`` with a
                ``"content"`` column - verify against the caller.

        Returns:
            Whatever ``dataset.map`` returns.
        """
        return dataset.map(self.embed_text, batched=True, batch_size=128)

    def embed_query(self, query_text):
        """Embed query text and return the pooled tensor.

        Unlike ``embed_text``, the result stays a ``torch.Tensor`` (it is not
        converted to NumPy).

        Args:
            query_text: A string or a list of strings accepted by the
                tokenizer.

        Returns:
            The pooled embedding tensor, one row per query.
        """
        return self._encode(query_text)