justinhl commited on
Commit
46202d7
1 Parent(s): 216500a

Delete hybrid_model.py

Browse files
Files changed (1) hide show
  1. hybrid_model.py +0 -84
hybrid_model.py DELETED
@@ -1,84 +0,0 @@
1
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering
2
- from transformers import pipeline, PretrainedConfig
3
-
4
- from huggingface_hub import PyTorchModelHubMixin
5
-
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- import os
10
- import json
11
-
12
- class HybridQAModel(nn.Module, PyTorchModelHubMixin):
13
- #config_class = HybridQAConfig
14
- def __init__(self, config):
15
- super().__init__()
16
- self.config = config
17
- self.load_models(config.extractive_id, config.generative_id)
18
-
19
- def can_generate(self):
20
- return False
21
-
22
- def load_models(self, extractive_id, generative_id):
23
- self.tokenizer_extractive = AutoTokenizer.from_pretrained(extractive_id)
24
- self.tokenizer_generative = AutoTokenizer.from_pretrained(generative_id)
25
-
26
- self.model_extractive = AutoModelForQuestionAnswering.from_pretrained(extractive_id)
27
- self.model_generative = AutoModelForSeq2SeqLM.from_pretrained(generative_id)
28
-
29
- def predict(self, question, context):
30
- result_gen, conf_gen = self.infer_generative(self.model_generative, self.tokenizer_generative, question)
31
- result_ext, conf_ext = self.infer_extractive(self.model_extractive, self.tokenizer_extractive, question, context)
32
-
33
- # print("Generative result: ",result_gen)
34
- # print("Confidence: ", conf_gen)
35
- # print("Extractive result: ", result_ext)
36
- # print("Confidence: ", conf_ext)
37
-
38
- if conf_gen > conf_ext:
39
- return {'guess':result_gen, 'confidence':conf_gen}
40
- else:
41
- return {'guess':result_ext, 'confidence':conf_ext}
42
-
43
- def infer_generative(self, model, tokenizer, input_text, **generate_kwargs):
44
- max_input_length = min(tokenizer.model_max_length, model.config.max_length)
45
- input_ids = tokenizer.encode(input_text, return_tensors="pt", max_length=max_input_length)
46
-
47
- with torch.no_grad():
48
- output_ids = model.generate(input_ids, **generate_kwargs)
49
- decoded_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
50
-
51
- output_probs = F.softmax(output_ids.float(), dim=-1).squeeze(0)
52
- entropy = -(output_probs * torch.log(output_probs)).sum(dim=-1)
53
- confidence_score = 1 - entropy.item()
54
-
55
- model.save_pretrained("./base_models")
56
- return decoded_output, confidence_score
57
-
58
- def infer_extractive(self, model, tokenizer, question, context):
59
- qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)
60
- result = qa_pipeline(question=question, context=context)
61
- confidence_score = result['score']
62
-
63
- model.save_pretrained("./base_models")
64
- return result['answer'], confidence_score
65
-
66
- def save_pretrained(self, save_directory, **kwargs):
67
- if not os.path.exists(save_directory):
68
- os.makedirs(save_directory, exist_ok=True)
69
- self.config.save_pretrained(save_directory, **kwargs)
70
- self.model_extractive.save_pretrained(save_directory, **kwargs)
71
- self.tokenizer_extractive.save_pretrained(save_directory, **kwargs)
72
- self.model_generative.save_pretrained(save_directory, **kwargs)
73
- self.tokenizer_generative.save_pretrained(save_directory, **kwargs)
74
-
75
- def from_pretrained(cls, save_directory, *model_args, **model_kwargs):
76
- config = PretrainedConfig.from_pretrained(save_directory)
77
- model = HybridQAModel(config)
78
-
79
- model.model_extractive = AutoModelForQuestionAnswering.from_pretrained(save_directory)
80
- model.tokenizer_extractive = AutoTokenizer.from_pretrained(save_directory)
81
- model.model_generative = AutoModelForSeq2SeqLM.from_pretrained(save_directory)
82
- model.tokenizer_generative = AutoTokenizer.from_pretrained(save_directory)
83
-
84
- return model