remzicam commited on
Commit
de11a1f
1 Parent(s): e6fd326

Upload blender_model.py

Browse files
Files changed (1) hide show
  1. blender_model.py +207 -0
blender_model.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import (
2
+ AutoConfig,
3
+ BlenderbotSmallForConditionalGeneration,
4
+ logging
5
+ )
6
+ from transformers.modeling_outputs import (
7
+ Seq2SeqLMOutput,
8
+ BaseModelOutput,
9
+ )
10
+ from huggingface_hub import hf_hub_url, cached_download
11
+ from onnxruntime import (GraphOptimizationLevel,
12
+ InferenceSession,
13
+ SessionOptions)
14
+
15
+ from torch import from_numpy
16
+ from torch.nn import Module
17
+ from functools import reduce
18
+ from operator import iconcat
19
+
20
+ #supress huggingface warnings
21
+ logging.set_verbosity_error()
22
+
23
+ model_vocab_size=30000
24
+ model_card="remzicam/xs_blenderbot_onnx"
25
+ model_file_names=["blenderbot_small-90M-encoder-quantized.onnx",
26
+ "blenderbot_small-90M-decoder-quantized.onnx",
27
+ "blenderbot_small-90M-init-decoder-quantized.onnx"]
28
+
29
+ class BlenderEncoder(Module):
30
+ def __init__(self, encoder_sess):
31
+ super().__init__()
32
+ self.encoder = encoder_sess
33
+
34
+ def forward(
35
+ self,
36
+ input_ids,
37
+ attention_mask,
38
+ inputs_embeds=None,
39
+ head_mask=None,
40
+ output_attentions=None,
41
+ output_hidden_states=None,
42
+ return_dict=None,
43
+ ):
44
+
45
+ encoder_hidden_state = from_numpy(
46
+ self.encoder.run(
47
+ None,
48
+ {
49
+ "input_ids": input_ids.cpu().numpy(),
50
+ "attention_mask": attention_mask.cpu().numpy(),
51
+ },
52
+ )[0]
53
+ )
54
+
55
+ return BaseModelOutput(encoder_hidden_state)
56
+
57
+
58
+ class BlenderDecoderInit(Module):
59
+ def __init__(self, decoder_sess):
60
+ super().__init__()
61
+ self.decoder = decoder_sess
62
+
63
+ def forward(self, input_ids, encoder_attention_mask, encoder_hidden_states):
64
+
65
+ decoder_outputs = self.decoder.run(
66
+ None,
67
+ {
68
+ "input_ids": input_ids.cpu().numpy(),
69
+ "encoder_attention_mask": encoder_attention_mask.cpu().numpy(),
70
+ "encoder_hidden_states": encoder_hidden_states.cpu().numpy(),
71
+ },
72
+ )
73
+
74
+ list_pkv = tuple(from_numpy(x) for x in decoder_outputs[1:])
75
+
76
+ out_past_key_values = tuple(
77
+ list_pkv[i : i + 4] for i in range(0, len(list_pkv), 4)
78
+ )
79
+
80
+ return from_numpy(decoder_outputs[0]), out_past_key_values
81
+
82
+
83
+ class BlenderDecoder(Module):
84
+ def __init__(self, decoder_sess):
85
+ super().__init__()
86
+ self.decoder = decoder_sess
87
+
88
+ def forward(self, input_ids, attention_mask, encoder_output, past_key_values):
89
+
90
+ decoder_inputs = {
91
+ "input_ids": input_ids.cpu().numpy(),
92
+ "encoder_attention_mask": attention_mask.cpu().numpy(),
93
+ }
94
+
95
+ flat_past_key_values = reduce(iconcat, past_key_values, [])
96
+
97
+ past_key_values = {
98
+ f"pkv_{i}": pkv.cpu().numpy() for i, pkv in enumerate(flat_past_key_values)
99
+ }
100
+
101
+ decoder_outputs = self.decoder.run(None, {**decoder_inputs, **past_key_values})
102
+ # converts each value of the list to tensor from numpy
103
+ list_pkv = tuple(from_numpy(x) for x in decoder_outputs[1:])
104
+
105
+ # creates a tuple of tuples of shape 6x4 from the above tuple
106
+ out_past_key_values = tuple(
107
+ list_pkv[i : i + 4] for i in range(0, len(list_pkv), 4)
108
+ )
109
+
110
+ return from_numpy(decoder_outputs[0]), out_past_key_values
111
+
112
+
113
+ class OnnxBlender(BlenderbotSmallForConditionalGeneration):
114
+ """creates a Blender model using onnx sessions (encode, decoder & init_decoder)"""
115
+
116
+ def __init__(self, onnx_model_sessions):
117
+ config = AutoConfig.from_pretrained("facebook/blenderbot_small-90M")
118
+ config.vocab_size=model_vocab_size
119
+ super().__init__(config)
120
+
121
+ assert len(onnx_model_sessions) == 3, "all three models should be given"
122
+
123
+ encoder_sess, decoder_sess, decoder_sess_init = onnx_model_sessions
124
+
125
+ self.encoder = BlenderEncoder(encoder_sess)
126
+ self.decoder = BlenderDecoder(decoder_sess)
127
+ self.decoder_init = BlenderDecoderInit(decoder_sess_init)
128
+
129
+ def get_encoder(self):
130
+ return self.encoder
131
+
132
+ def get_decoder(self):
133
+ return self.decoder
134
+
135
+ def forward(
136
+ self,
137
+ input_ids=None,
138
+ attention_mask=None,
139
+ decoder_input_ids=None,
140
+ decoder_attention_mask=None,
141
+ head_mask=None,
142
+ decoder_head_mask=None,
143
+ cross_attn_head_mask=None,
144
+ encoder_outputs=None,
145
+ past_key_values=None,
146
+ inputs_embeds=None,
147
+ decoder_inputs_embeds=None,
148
+ labels=None,
149
+ use_cache=None,
150
+ output_attentions=None,
151
+ output_hidden_states=None,
152
+ return_dict=None,
153
+ ):
154
+
155
+ encoder_hidden_states = encoder_outputs[0]
156
+
157
+ if past_key_values is not None:
158
+ if decoder_input_ids is not None:
159
+ decoder_input_ids = decoder_input_ids[:, -1:]
160
+ if decoder_inputs_embeds is not None:
161
+ decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
162
+
163
+ if past_key_values is None:
164
+
165
+ # runs only for the first time:
166
+ init_onnx_outputs = self.decoder_init(
167
+ decoder_input_ids, attention_mask, encoder_hidden_states
168
+ )
169
+
170
+ logits, past_key_values = init_onnx_outputs
171
+
172
+ else:
173
+
174
+ onnx_outputs = self.decoder(
175
+ decoder_input_ids,
176
+ attention_mask,
177
+ encoder_hidden_states,
178
+ past_key_values,
179
+ )
180
+
181
+ logits, past_key_values = onnx_outputs
182
+
183
+ return Seq2SeqLMOutput(logits=logits, past_key_values=past_key_values)
184
+
185
+ class ModelLoad:
186
+ def __init__(self, model_card,file_names):
187
+ self.model_card=model_card
188
+ self.file_names=file_names
189
+
190
+ def model_file_downloader(self,model_card,filename):
191
+ config_file_url = hf_hub_url(model_card, filename)
192
+ model_file = cached_download(config_file_url)
193
+ return model_file
194
+
195
+ def inference_session(self,file_name):
196
+ model_file=self.model_file_downloader(self.model_card,file_name)
197
+ options = SessionOptions()
198
+ options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
199
+ return InferenceSession(model_file,options=options)
200
+
201
+ def __call__(self,model_config):
202
+ model=model_config([*map(self.inference_session,
203
+ self.file_names)])
204
+ return model
205
+
206
+ model_loader=ModelLoad(model_card,model_file_names)
207
+ blender_onnx_model=model_loader(OnnxBlender)