justinhl committed on
Commit 8afa796
Parent: f254d2a

Upload 2 files

Files changed (1):
  hybrid_model.py  +84 -0

hybrid_model.py ADDED

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering
from transformers import pipeline, PretrainedConfig

from huggingface_hub import PyTorchModelHubMixin

import torch
import torch.nn as nn
import torch.nn.functional as F
import os

class HybridQAModel(nn.Module, PyTorchModelHubMixin):
    """Hybrid QA model that runs both an extractive and a generative
    sub-model and returns whichever answer is more confident."""
    # config_class = HybridQAConfig

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.load_models(config.extractive_id, config.generative_id)

    def can_generate(self):
        return False

    def load_models(self, extractive_id, generative_id):
        self.tokenizer_extractive = AutoTokenizer.from_pretrained(extractive_id)
        self.tokenizer_generative = AutoTokenizer.from_pretrained(generative_id)

        self.model_extractive = AutoModelForQuestionAnswering.from_pretrained(extractive_id)
        self.model_generative = AutoModelForSeq2SeqLM.from_pretrained(generative_id)

    def predict(self, question, context):
        result_gen, conf_gen = self.infer_generative(self.model_generative, self.tokenizer_generative, question)
        result_ext, conf_ext = self.infer_extractive(self.model_extractive, self.tokenizer_extractive, question, context)

        # print("Generative result: ", result_gen)
        # print("Confidence: ", conf_gen)
        # print("Extractive result: ", result_ext)
        # print("Confidence: ", conf_ext)

        if conf_gen > conf_ext:
            return {'guess': result_gen, 'confidence': conf_gen}
        else:
            return {'guess': result_ext, 'confidence': conf_ext}

    def infer_generative(self, model, tokenizer, input_text, **generate_kwargs):
        input_ids = tokenizer.encode(input_text, return_tensors="pt",
                                     truncation=True, max_length=tokenizer.model_max_length)

        with torch.no_grad():
            outputs = model.generate(input_ids, return_dict_in_generate=True,
                                     output_scores=True, **generate_kwargs)
        decoded_output = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)

        # Confidence proxy: the mean probability the model assigned to each
        # generated token (outputs.scores holds one logit tensor per step).
        step_logits = torch.stack(outputs.scores, dim=1)
        step_probs = F.softmax(step_logits, dim=-1)
        confidence_score = step_probs.max(dim=-1).values.mean().item()

        return decoded_output, confidence_score

    def infer_extractive(self, model, tokenizer, question, context):
        qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)
        result = qa_pipeline(question=question, context=context)
        confidence_score = result['score']  # the pipeline's span score is already a probability

        return result['answer'], confidence_score

    def save_pretrained(self, save_directory, **kwargs):
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory, **kwargs)
        # Write each sub-model to its own subdirectory so the two
        # checkpoints do not overwrite each other's files.
        extractive_dir = os.path.join(save_directory, "extractive")
        generative_dir = os.path.join(save_directory, "generative")
        self.model_extractive.save_pretrained(extractive_dir, **kwargs)
        self.tokenizer_extractive.save_pretrained(extractive_dir, **kwargs)
        self.model_generative.save_pretrained(generative_dir, **kwargs)
        self.tokenizer_generative.save_pretrained(generative_dir, **kwargs)

    @classmethod
    def from_pretrained(cls, save_directory, *model_args, **model_kwargs):
        config = PretrainedConfig.from_pretrained(save_directory)
        model = cls(config)

        # Replace the freshly downloaded base models with the weights
        # written by save_pretrained above.
        extractive_dir = os.path.join(save_directory, "extractive")
        generative_dir = os.path.join(save_directory, "generative")
        model.model_extractive = AutoModelForQuestionAnswering.from_pretrained(extractive_dir)
        model.tokenizer_extractive = AutoTokenizer.from_pretrained(extractive_dir)
        model.model_generative = AutoModelForSeq2SeqLM.from_pretrained(generative_dir)
        model.tokenizer_generative = AutoTokenizer.from_pretrained(generative_dir)

        return model
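
For reference, a minimal usage sketch (not part of the commit). The two checkpoint IDs are illustrative assumptions, as are the question and context strings; PretrainedConfig stores arbitrary keyword arguments as attributes, which is all HybridQAModel reads from its config.

from transformers import PretrainedConfig

# Illustrative checkpoints; swap in whatever extractive/generative pair you use.
config = PretrainedConfig(
    extractive_id="distilbert-base-cased-distilled-squad",
    generative_id="google/flan-t5-small",
)
model = HybridQAModel(config)

result = model.predict(
    question="Which planet is known as the Red Planet?",
    context="Mars is often called the Red Planet because of its reddish surface.",
)
print(result)  # {'guess': ..., 'confidence': ...}

# Round trip: writes config.json plus extractive/ and generative/ subdirectories.
model.save_pretrained("./hybrid_qa")
restored = HybridQAModel.from_pretrained("./hybrid_qa")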