import torch
from transformers import (
    LongformerTokenizer,
    LongformerForQuestionAnswering,
)
from typing import List, Dict, Tuple
from dotenv import load_dotenv

from src.readers.base_reader import Reader

load_dotenv()


class LongformerReader(Reader):
    """Extractive QA reader backed by a Longformer fine-tuned on SQuAD v1.

    Given a query and a collection of candidate passages, predicts an answer
    span independently for each passage and returns the decoded span text.
    """

    def __init__(self) -> None:
        checkpoint = "valhalla/longformer-base-4096-finetuned-squadv1"
        self.tokenizer = LongformerTokenizer.from_pretrained(checkpoint)
        self.model = LongformerForQuestionAnswering.from_pretrained(checkpoint)
        # Inference-only reader: disable dropout/batch-norm training behavior.
        self.model.eval()

    def read(self, query: str, context: Dict[str, List[str]],
             num_answers: int = 5) -> List[Tuple]:
        """Extract one answer span per passage for the given query.

        Args:
            query: The natural-language question.
            context: Retrieval output; must contain a 'texts' key mapping to
                a list of passage strings (only the first ``num_answers`` are
                read — other keys are ignored here).
            num_answers: Maximum number of passages to read.

        Returns:
            A list of ``[answer_text, [], []]`` triples, one per passage
            processed. (NOTE: elements are lists, kept for backward
            compatibility with callers despite the ``List[Tuple]`` hint.)
        """
        answers = []
        for text in context['texts'][:num_answers]:
            # Truncate to the model's 4096-token window so overly long
            # passages cannot crash the position embeddings.
            encoding = self.tokenizer(
                query, text, return_tensors="pt",
                truncation=True, max_length=4096,
            )
            input_ids = encoding["input_ids"]
            attention_mask = encoding["attention_mask"]

            # no_grad: avoid building an autograd graph during inference.
            with torch.no_grad():
                outputs = self.model(input_ids, attention_mask=attention_mask)

            # Most likely span boundaries; if end < start the slice is empty
            # and the decoded answer is "" (same as the original behavior).
            start = int(torch.argmax(outputs.start_logits))
            end = int(torch.argmax(outputs.end_logits))

            # Decode the span ids directly — equivalent to the former
            # ids -> tokens -> ids round trip, without the detour.
            answer = self.tokenizer.decode(input_ids[0][start:end + 1])

            answers.append([answer, [], []])
        return answers