import torch
from transformers import DPRReader, DPRReaderTokenizer
from typing import List, Dict, Tuple
from dotenv import load_dotenv
from src.readers.base_reader import Reader

# Load environment variables (e.g. HF cache paths / tokens) before any
# `from_pretrained` call can read them.
load_dotenv()


class DprReader(Reader):
    """Extractive reader backed by Facebook's DPR reader (single-NQ checkpoint).

    Given a question and a set of retrieved passages (titles + texts), it
    scores candidate answer spans and returns the best ones.
    """

    # Single source of truth for the checkpoint name — used for both the
    # tokenizer and the model so they can never drift apart.
    _MODEL_NAME = "facebook/dpr-reader-single-nq-base"

    def __init__(self) -> None:
        # Downloads/loads the pretrained tokenizer and model on first use.
        self._tokenizer = DPRReaderTokenizer.from_pretrained(self._MODEL_NAME)
        self._model = DPRReader.from_pretrained(self._MODEL_NAME)

    def read(self, query: str, context: Dict[str, List[str]],
             num_answers: int = 5) -> List[Tuple]:
        """Extract the best answer spans for `query` from `context`.

        Args:
            query: The question to answer.
            context: Dict with parallel lists under keys 'titles' and
                'texts' — one entry per retrieved passage.
            num_answers: Maximum number of spans to return.

        Returns:
            A list of DPRSpanPrediction tuples (span text, scores, and
            passage/span indices), best first.
        """
        encoded_inputs = self._tokenizer(
            questions=query,
            titles=context['titles'],
            texts=context['texts'],
            return_tensors='pt',
            truncation=True,
            padding=True,
        )
        # Pure inference: disable autograd so no computation graph is
        # built — saves memory and time on the forward pass.
        with torch.no_grad():
            outputs = self._model(**encoded_inputs)
        predicted_spans = self._tokenizer.decode_best_spans(
            encoded_inputs,
            outputs,
            num_spans=num_answers,
            num_spans_per_passage=2,
            max_answer_length=512,
        )
        return predicted_spans