# Provenance (from Hugging Face blob page): Ramon Meffert — "Add evaluation",
# commit 492106d, 1.48 kB.
import torch
from transformers import (
LongformerTokenizer,
LongformerForQuestionAnswering
)
from typing import List, Dict, Tuple
from dotenv import load_dotenv
from src.readers.base_reader import Reader
# Load environment variables from a local .env file at import time.
# NOTE(review): nothing in this file reads the environment directly —
# presumably other project modules (or transformers caching) rely on it.
load_dotenv()
class LongformerReader(Reader):
    """Extractive QA reader backed by a Longformer fine-tuned on SQuAD v1.

    Downloads (and caches) the ``valhalla/longformer-base-4096-finetuned-squadv1``
    checkpoint on first construction.
    """

    def __init__(self) -> None:
        checkpoint = "valhalla/longformer-base-4096-finetuned-squadv1"
        self.tokenizer = LongformerTokenizer.from_pretrained(checkpoint)
        # from_pretrained returns the model already in eval mode.
        self.model = LongformerForQuestionAnswering.from_pretrained(checkpoint)

    def read(self,
             query: str,
             context: Dict[str, List[str]],
             num_answers: int = 5) -> List[Tuple]:
        """Extract one answer span from each of the first ``num_answers`` texts.

        Args:
            query: The question to answer.
            context: Mapping that must contain a ``'texts'`` key holding the
                candidate passages; other keys are ignored here.
            num_answers: Maximum number of passages to read (default 5).

        Returns:
            A list of ``[answer, [], []]`` triples, one per passage read.
            The two empty lists pad the shape — presumably to match other
            Reader implementations; confirm against base_reader.
        """
        answers = []
        for text in context['texts'][:num_answers]:
            encoding = self.tokenizer(query, text, return_tensors="pt")
            input_ids = encoding["input_ids"]
            attention_mask = encoding["attention_mask"]
            # Inference only: disable autograd so no graph is built per passage.
            with torch.no_grad():
                outputs = self.model(input_ids, attention_mask=attention_mask)
            start_logits = outputs.start_logits
            end_logits = outputs.end_logits
            all_tokens = self.tokenizer.convert_ids_to_tokens(
                input_ids[0].tolist())
            # NOTE(review): if argmax(end_logits) precedes argmax(start_logits)
            # this slice is empty and the answer decodes to "" — behavior
            # inherited from the HF reference example; left unchanged.
            answer_tokens = all_tokens[
                torch.argmax(start_logits):torch.argmax(end_logits) + 1]
            answer = self.tokenizer.decode(
                self.tokenizer.convert_tokens_to_ids(answer_tokens)
            )
            answers.append([answer, [], []])
        return answers