foxxy-hm committed on
Commit
2fe0a0f
1 Parent(s): 8a73eb1

Update models/qa_model.py

Browse files
Files changed (1) hide show
  1. models/qa_model.py +55 -19
models/qa_model.py CHANGED
@@ -1,7 +1,7 @@
1
  import numpy as np
2
  import torch
3
  import torch.nn as nn
4
- from transformers import AutoModelForQuestionAnswering, pipeline
5
  from features.text_utils import post_process_answer
6
  from features.graph_utils import find_best_cluster
7
  from optimum.onnxruntime import ORTModelForQuestionAnswering
@@ -11,13 +11,18 @@ class QAEnsembleModel(nn.Module):
11
  def __init__(self, model_name, model_checkpoints, entity_dict,
12
  thr=0.1, device="cpu"):
13
  super(QAEnsembleModel, self).__init__()
14
- self.nlps = []
 
 
15
  for model_checkpoint in model_checkpoints:
16
  model = ORTModelForQuestionAnswering.from_pretrained(model_name, from_transformers=True)#.half()
17
  model.load_state_dict(torch.load(model_checkpoint, map_location=torch.device('cpu')), strict=False)
18
- nlp = pipeline('question-answering', model=model,
19
- tokenizer=model_name, device=device)
20
- self.nlps.append(nlp)
 
 
 
21
  self.entity_dict = entity_dict
22
  self.thr = thr
23
 
@@ -28,22 +33,53 @@ class QAEnsembleModel(nn.Module):
28
  curr_answers = []
29
  curr_scores = []
30
  best_score = 0
31
- for i, nlp in enumerate(self.nlps):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  for text, score in zip(texts, ranking_scores):
33
- QA_input = {
34
- 'question': question,
35
- 'context': text
36
- }
37
- res = nlp(QA_input)
38
- # print(res)
39
- if res["score"] > self.thr:
40
- curr_answers.append(res["answer"])
41
- curr_scores.append(res["score"])
42
- res["score"] = res["score"] * score
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  if i == 0:
44
- if res["score"] > best_score:
45
- answer = res["answer"]
46
- best_score = res["score"]
47
  if len(curr_answers) == 0:
48
  return None
49
  curr_answers = [post_process_answer(x, self.entity_dict) for x in curr_answers]
 
1
  import numpy as np
2
  import torch
3
  import torch.nn as nn
4
+ # from transformers import AutoModelForQuestionAnswering, pipeline
5
  from features.text_utils import post_process_answer
6
  from features.graph_utils import find_best_cluster
7
  from optimum.onnxruntime import ORTModelForQuestionAnswering
 
11
  def __init__(self, model_name, model_checkpoints, entity_dict,
12
  thr=0.1, device="cpu"):
13
  super(QAEnsembleModel, self).__init__()
14
+ # self.nlps = []
15
+ self.models = []
16
+ self.tokenizers = []
17
  for model_checkpoint in model_checkpoints:
18
  model = ORTModelForQuestionAnswering.from_pretrained(model_name, from_transformers=True)#.half()
19
  model.load_state_dict(torch.load(model_checkpoint, map_location=torch.device('cpu')), strict=False)
20
+ # nlp = pipeline('question-answering', model=model,
21
+ # tokenizer=model_name, device=device)
22
+ # self.nlps.append(nlp)
23
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
24
+ self.models.append(model)
25
+ self.tokenizers.append(tokenizer)
26
  self.entity_dict = entity_dict
27
  self.thr = thr
28
 
 
33
  curr_answers = []
34
  curr_scores = []
35
  best_score = 0
36
+ # for i, nlp in enumerate(self.nlps):
37
+ # for text, score in zip(texts, ranking_scores):
38
+ # QA_input = {
39
+ # 'question': question,
40
+ # 'context': text
41
+ # }
42
+ # res = nlp(QA_input)
43
+ # # print(res)
44
+ # if res["score"] > self.thr:
45
+ # curr_answers.append(res["answer"])
46
+ # curr_scores.append(res["score"])
47
+ # res["score"] = res["score"] * score
48
+ # if i == 0:
49
+ # if res["score"] > best_score:
50
+ # answer = res["answer"]
51
+ # best_score = res["score"]
52
+
53
+ for i, (model, tokenizer) in enumerate(zip(self.models, self.tokenizers)):
54
  for text, score in zip(texts, ranking_scores):
55
+ # Encode the question and context as input ids and attention mask
56
+ inputs = tokenizer(question, text, return_tensors="pt")
57
+ input_ids = inputs["input_ids"]
58
+ attention_mask = inputs["attention_mask"]
59
+ # Get the start and end logits from the model
60
+ outputs = model(input_ids, attention_mask=attention_mask)
61
+ start_logits = outputs.start_logits
62
+ end_logits = outputs.end_logits
63
+ # Get the most likely start and end indices
64
+ start_idx = torch.argmax(start_logits)
65
+ end_idx = torch.argmax(end_logits)
66
+ # Get the answer span from the input ids
67
+ answer_ids = input_ids[0][start_idx:end_idx+1]
68
+ # Decode the answer ids to get the answer text
69
+ answer_text = tokenizer.decode(answer_ids)
70
+ # Get the answer score from the start and end logits
71
+ answer_score = torch.max(start_logits) + torch.max(end_logits)
72
+ # Convert to numpy values
73
+ answer_text = answer_text.numpy()
74
+ answer_score = answer_score.numpy()
75
+ if answer_score > self.thr:
76
+ curr_answers.append(answer_text)
77
+ curr_scores.append(answer_score)
78
+ answer_score = answer_score * score
79
  if i == 0:
80
+ if answer_score > best_score:
81
+ answer = answer_text
82
+ best_score = answer_score
83
  if len(curr_answers) == 0:
84
  return None
85
  curr_answers = [post_process_answer(x, self.entity_dict) for x in curr_answers]