from typing import List, Tuple, TypedDict
from re import sub

from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer, logging
from transformers import AutoModelForQuestionAnswering, DPRReaderTokenizer, DPRReader
from transformers import QuestionAnsweringPipeline
from transformers import AutoTokenizer, PegasusXForConditionalGeneration, PegasusTokenizerFast
import torch

cuda = torch.cuda.is_available()  # run models on GPU 0 when available
max_answer_len = 8  # cap extracted answers at 8 tokens
logging.set_verbosity_error()  # silence transformers info/warning logs


@torch.inference_mode()
def summarize_text(tokenizer: PegasusTokenizerFast, model: PegasusXForConditionalGeneration,
                   input_texts: List[str]) -> List[str]:
  """Summarize a batch of texts with a Pegasus-X model."""
  inputs = tokenizer(input_texts, padding=True,
                     return_tensors='pt', truncation=True)
  if cuda:
    inputs = inputs.to(0)
    # Restrict scaled dot-product attention to the flash kernel on GPU.
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
      summary_ids = model.generate(inputs["input_ids"],
                                   attention_mask=inputs["attention_mask"])
  else:
    summary_ids = model.generate(inputs["input_ids"],
                                 attention_mask=inputs["attention_mask"])
  summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True,
                                     clean_up_tokenization_spaces=False)
  return summaries


def get_summarizer(model_id="seonglae/resrer-pegasus-x") -> Tuple[PegasusTokenizerFast, PegasusXForConditionalGeneration]:
  """Load the Pegasus-X summarization tokenizer and model."""
  tokenizer = PegasusTokenizerFast.from_pretrained(model_id)
  model = PegasusXForConditionalGeneration.from_pretrained(model_id)
  if cuda:
    model = model.to(0)
  # Compile the model for faster repeated inference (PyTorch 2.x).
  model = torch.compile(model)
  return tokenizer, model
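

# Usage sketch (illustrative, not part of the original module): wires
# get_summarizer and summarize_text together; assumes the default checkpoint
# downloads successfully, and the sample passage is hypothetical.
def _demo_summarizer() -> None:
  tokenizer, model = get_summarizer()
  passages = ["Dense retrieval returns several long passages per question; "
              "summarizing them first shortens the reader's input."]
  print(summarize_text(tokenizer, model, passages))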


class AnswerInfo(TypedDict):
  score: float
  start: int
  end: int
  answer: str


@torch.inference_mode()
def ask_reader(tokenizer: AutoTokenizer, model: AutoModelForQuestionAnswering,
               questions: List[str], ctxs: List[str]) -> List[AnswerInfo]:
  """Run extractive QA over parallel lists of questions and contexts."""
  pipeline = QuestionAnsweringPipeline(model=model, tokenizer=tokenizer,
                                       device='cuda' if cuda else 'cpu',
                                       max_answer_len=max_answer_len)
  if cuda:
    # Restrict scaled dot-product attention to the flash kernel on GPU.
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
      answer_infos: List[AnswerInfo] = pipeline(
          question=questions, context=ctxs)
  else:
    answer_infos = pipeline(question=questions, context=ctxs)
  # The pipeline returns a single dict for a single pair; normalize to a list.
  if not isinstance(answer_infos, list):
    answer_infos = [answer_infos]
  # Strip punctuation that would hurt exact-match evaluation.
  for answer_info in answer_infos:
    answer_info['answer'] = sub(r'[.\(\)"\',]', '', answer_info['answer'])
  return answer_infos


def get_reader(model_id="facebook/dpr-reader-single-nq-base") -> Tuple[DPRReaderTokenizer, DPRReader]:
  """Load the DPR reader tokenizer and model."""
  tokenizer = DPRReaderTokenizer.from_pretrained(model_id)
  model = DPRReader.from_pretrained(model_id)
  if cuda:
    model = model.to(0)
  return tokenizer, model
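

# Usage sketch (illustrative, not part of the original module): answers one
# question against one context with the reader loaded by get_reader; the
# question/context pair is hypothetical.
def _demo_reader() -> None:
  tokenizer, model = get_reader()
  answers = ask_reader(tokenizer, model,
                       questions=["Who wrote Hamlet?"],
                       ctxs=["Hamlet is a tragedy written by William Shakespeare."])
  print(answers[0]['answer'], answers[0]['score'])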


@torch.inference_mode()
def encode_dpr_question(tokenizer: DPRQuestionEncoderTokenizer, model: DPRQuestionEncoder, questions: List[str]) -> torch.FloatTensor:
  """Encode questions with the DPR question encoder.
  https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DPRQuestionEncoder

  Args:
      tokenizer (DPRQuestionEncoderTokenizer): tokenizer matching the encoder
      model (DPRQuestionEncoder): loaded DPR question encoder
      questions (List[str]): question strings to encode

  Returns:
      torch.FloatTensor: pooled embeddings, one row per question
  """
  batch_dict = tokenizer(questions, return_tensors="pt",
                         padding=True, truncation=True)
  if cuda:
    batch_dict = batch_dict.to(0)
    # Restrict scaled dot-product attention to the flash kernel on GPU.
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
      embeddings: torch.FloatTensor = model(**batch_dict).pooler_output
  else:
    embeddings = model(**batch_dict).pooler_output
  return embeddings


def get_dpr_encoder(model_id="facebook/dpr-question_encoder-single-nq-base") -> Tuple[DPRQuestionEncoderTokenizer, DPRQuestionEncoder]:
  """Load the DPR question encoder tokenizer and model.
  https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DPRQuestionEncoder

  Args:
      model_id (str, optional): defaults to the NQ-trained encoder; use
          "facebook/dpr-question_encoder-multiset-base" for the multiset variant.
  """
  tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(model_id)
  model = DPRQuestionEncoder.from_pretrained(model_id)
  if cuda:
    model = model.to(0)
  return tokenizer, model
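

# Usage sketch (illustrative, not part of the original module): encodes a
# couple of questions into dense vectors for nearest-neighbor search against
# a passage index; the base DPR encoder yields 768-dimensional embeddings.
def _demo_encoder() -> None:
  tokenizer, model = get_dpr_encoder()
  embeddings = encode_dpr_question(tokenizer, model,
                                   ["Who wrote Hamlet?",
                                    "What is the capital of France?"])
  print(embeddings.shape)  # torch.Size([2, 768])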