Upload 2 files
hybrid_model.py +84 -0
hybrid_model.py
ADDED
@@ -0,0 +1,84 @@
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering
from transformers import pipeline, PretrainedConfig

from huggingface_hub import PyTorchModelHubMixin

import torch
import torch.nn as nn
import torch.nn.functional as F
import os


class HybridQAModel(nn.Module, PyTorchModelHubMixin):
    # config_class = HybridQAConfig

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.load_models(config.extractive_id, config.generative_id)

    def can_generate(self):
        return False

    def load_models(self, extractive_id, generative_id):
        # Load both base models and their tokenizers from the Hub.
        self.tokenizer_extractive = AutoTokenizer.from_pretrained(extractive_id)
        self.tokenizer_generative = AutoTokenizer.from_pretrained(generative_id)

        self.model_extractive = AutoModelForQuestionAnswering.from_pretrained(extractive_id)
        self.model_generative = AutoModelForSeq2SeqLM.from_pretrained(generative_id)

    def predict(self, question, context):
        # Run both models and keep whichever answer comes back more confident.
        result_gen, conf_gen = self.infer_generative(self.model_generative, self.tokenizer_generative, question)
        result_ext, conf_ext = self.infer_extractive(self.model_extractive, self.tokenizer_extractive, question, context)

        if conf_gen > conf_ext:
            return {'guess': result_gen, 'confidence': conf_gen}
        return {'guess': result_ext, 'confidence': conf_ext}

    def infer_generative(self, model, tokenizer, input_text, **generate_kwargs):
        max_input_length = min(tokenizer.model_max_length, model.config.max_length)
        input_ids = tokenizer.encode(input_text, return_tensors="pt", max_length=max_input_length)

        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                output_scores=True,
                return_dict_in_generate=True,
                **generate_kwargs,
            )
        decoded_output = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)

        # Confidence: mean probability assigned to each generated token.
        # Assumes greedy decoding, where outputs.scores holds one logit
        # tensor per step and sequences[:, 0] is the decoder start token,
        # so step i produced sequences[:, i + 1].
        token_probs = []
        for step, logits in enumerate(outputs.scores):
            probs = F.softmax(logits, dim=-1)
            token_id = outputs.sequences[0, step + 1]
            token_probs.append(probs[0, token_id].item())
        confidence_score = sum(token_probs) / len(token_probs) if token_probs else 0.0

        return decoded_output, confidence_score

    def infer_extractive(self, model, tokenizer, question, context):
        qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)
        result = qa_pipeline(question=question, context=context)
        confidence_score = result['score']

        return result['answer'], confidence_score

    def save_pretrained(self, save_directory, **kwargs):
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory, **kwargs)
        # Save each model/tokenizer pair to its own subdirectory so their
        # files (config.json, weights, vocab) do not overwrite one another.
        extractive_dir = os.path.join(save_directory, "extractive")
        generative_dir = os.path.join(save_directory, "generative")
        self.model_extractive.save_pretrained(extractive_dir, **kwargs)
        self.tokenizer_extractive.save_pretrained(extractive_dir, **kwargs)
        self.model_generative.save_pretrained(generative_dir, **kwargs)
        self.tokenizer_generative.save_pretrained(generative_dir, **kwargs)

    @classmethod
    def from_pretrained(cls, save_directory, *model_args, **model_kwargs):
        config = PretrainedConfig.from_pretrained(save_directory)
        model = cls(config)

        extractive_dir = os.path.join(save_directory, "extractive")
        generative_dir = os.path.join(save_directory, "generative")
        model.model_extractive = AutoModelForQuestionAnswering.from_pretrained(extractive_dir)
        model.tokenizer_extractive = AutoTokenizer.from_pretrained(extractive_dir)
        model.model_generative = AutoModelForSeq2SeqLM.from_pretrained(generative_dir)
        model.tokenizer_generative = AutoTokenizer.from_pretrained(generative_dir)

        return model
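
For context, here is a minimal sketch of how the class could be exercised. It is hypothetical: HybridQAConfig is not part of this commit (it only appears in the commented-out config_class line), and the two checkpoint IDs are stand-in examples, not the repo's actual base models.

# Hypothetical usage sketch -- HybridQAConfig and the checkpoint IDs below
# are illustrative assumptions, not part of this commit.
from transformers import PretrainedConfig

class HybridQAConfig(PretrainedConfig):
    model_type = "hybrid-qa"

    def __init__(self,
                 extractive_id="distilbert-base-cased-distilled-squad",
                 generative_id="google/flan-t5-base",
                 **kwargs):
        self.extractive_id = extractive_id    # placeholder extractive QA checkpoint
        self.generative_id = generative_id    # placeholder seq2seq checkpoint
        super().__init__(**kwargs)

model = HybridQAModel(HybridQAConfig())
print(model.predict(
    question="What is the capital of France?",
    context="Paris is the capital and most populous city of France.",
))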