resrer-pegasus-x / reader.py
seonglae's picture
Training in progress, step 500
a9082f6
raw
history blame contribute delete
No virus
1.24 kB
from typing import TypedDict, List, Dict
from re import sub
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, DPRReaderTokenizer, DPRReader, logging
from transformers import QuestionAnsweringPipeline
# Forwarded to QuestionAnsweringPipeline as its `max_answer_len` argument in ask_reader.
max_answer_len = 8
# `logging` here is the transformers logging module; restrict it to error-level output.
logging.set_verbosity_error()
class AnswerInfo(TypedDict):
    """One answer record as handled by ``ask_reader``.

    ``ask_reader`` reads and rewrites the ``answer`` entry; the remaining
    entries are passed through unchanged from the question-answering pipeline.
    """

    # Score reported by the pipeline for this answer.
    score: float
    # Start offset of the answer span (presumably a character index into the
    # context -- confirm against the pipeline documentation).
    start: int
    # End offset of the answer span (same unit as ``start``).
    end: int
    # Answer text; ask_reader strips . ( ) " ' , characters from it.
    answer: str
@torch.inference_mode()
def ask_reader(tokenizer: AutoTokenizer, model: AutoModelForQuestionAnswering,
               questions: List[str], ctxs: List[str]) -> List[AnswerInfo]:
    """Run extractive question answering over question/context inputs.

    A ``QuestionAnsweringPipeline`` is built on the CUDA device for every call
    and invoked with ``question=questions`` and ``context=ctxs``.

    Args:
        tokenizer: Tokenizer handed to the pipeline.
        model: Question-answering model handed to the pipeline.
        questions: Questions to answer.
        ctxs: Context passages forwarded alongside ``questions``.

    Returns:
        A list of ``AnswerInfo`` dicts whose ``answer`` text has had the
        characters ``. ( ) " ' ,`` removed. A list is returned even when the
        pipeline produced a single result.
    """
    # Only the flash scaled-dot-product-attention kernel stays enabled while
    # the pipeline runs; the math and memory-efficient kernels are turned off.
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        pipeline = QuestionAnsweringPipeline(
            model=model, tokenizer=tokenizer, device='cuda', max_answer_len=max_answer_len)
        answer_infos = pipeline(question=questions, context=ctxs)

    # The pipeline returns a bare dict (not a list) when it was given exactly
    # one example. Iterating that dict would yield its string keys and the
    # indexing below would raise TypeError, so normalise to a list first.
    if isinstance(answer_infos, dict):
        answer_infos = [answer_infos]

    for answer_info in answer_infos:
        answer_info['answer'] = sub(r'[.\(\)"\',]', '', answer_info['answer'])
    return answer_infos
def get_reader(model_id="mrm8488/longformer-base-4096-finetuned-squadv2"):
    """Load the tokenizer and question-answering model for ``model_id``.

    Args:
        model_id: Model identifier passed to ``from_pretrained``. The default
            is a Longformer checkpoint fine-tuned on SQuAD v2.

    Returns:
        A ``(tokenizer, model)`` tuple, with the model moved to device 0.
    """
    # Use the Auto* classes so the architecture is resolved from the
    # checkpoint's own config. The DPR reader classes do not match the default
    # Longformer checkpoint, nor the AutoTokenizer /
    # AutoModelForQuestionAnswering types that ask_reader is annotated with.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForQuestionAnswering.from_pretrained(model_id).to(0)
    return tokenizer, model