remzicam committed on
Commit cfa232e
1 Parent(s): bc2ec1a

Update blender_model.py

Code updated so that it works with newer versions of transformers.

Files changed (1)
  1. blender_model.py +101 -53
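
Background for the compatibility fix: newer huggingface_hub releases deprecated and eventually dropped cached_download, which the old code combined with hf_hub_url; hf_hub_download is the single-call replacement this commit adopts. A minimal sketch of the change (repo and file names taken from this repo; the model_file variable is illustrative):

    from huggingface_hub import hf_hub_download

    # old, removed API:
    #   model_file = cached_download(hf_hub_url(repo_id, filename))
    # new API: one call resolves the URL and caches the file locally
    model_file = hf_hub_download(
        repo_id="remzicam/xs_blenderbot_onnx",
        filename="blenderbot_small-90M-encoder-quantized.onnx",
    )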
blender_model.py CHANGED
@@ -1,35 +1,30 @@
- from transformers import (
-     AutoConfig,
-     BlenderbotSmallForConditionalGeneration,
-     logging
- )
- from transformers.modeling_outputs import (
-     Seq2SeqLMOutput,
-     BaseModelOutput,
- )
- from huggingface_hub import hf_hub_url, cached_download
- from onnxruntime import (GraphOptimizationLevel,
-                          InferenceSession,
-                          SessionOptions)

  from torch import from_numpy
  from torch.nn import Module
- from functools import reduce
- from operator import iconcat

- #supress huggingface warnings
- logging.set_verbosity_error()

- model_vocab_size=30000
- model_card="remzicam/xs_blenderbot_onnx"
- model_file_names=["blenderbot_small-90M-encoder-quantized.onnx",
-                   "blenderbot_small-90M-decoder-quantized.onnx",
-                   "blenderbot_small-90M-init-decoder-quantized.onnx"]

  class BlenderEncoder(Module):
      def __init__(self, encoder_sess):
          super().__init__()
          self.encoder = encoder_sess

      def forward(
          self,
@@ -113,25 +108,53 @@ class BlenderDecoder(Module):
  class OnnxBlender(BlenderbotSmallForConditionalGeneration):
      """creates a Blender model using onnx sessions (encode, decoder & init_decoder)"""

-     def __init__(self, onnx_model_sessions):
-         config = AutoConfig.from_pretrained("facebook/blenderbot_small-90M")
-         config.vocab_size=model_vocab_size
          super().__init__(config)

-         assert len(onnx_model_sessions) == 3, "all three models should be given"

-         encoder_sess, decoder_sess, decoder_sess_init = onnx_model_sessions

          self.encoder = BlenderEncoder(encoder_sess)
          self.decoder = BlenderDecoder(decoder_sess)
          self.decoder_init = BlenderDecoderInit(decoder_sess_init)

      def get_encoder(self):
          return self.encoder

      def get_decoder(self):
          return self.decoder
-
      def forward(
          self,
          input_ids=None,
@@ -151,9 +174,9 @@ class OnnxBlender(BlenderbotSmallForConditionalGeneration):
          output_hidden_states=None,
          return_dict=None,
      ):
-
          encoder_hidden_states = encoder_outputs[0]
-
          if past_key_values is not None:
              if decoder_input_ids is not None:
                  decoder_input_ids = decoder_input_ids[:, -1:]
@@ -182,26 +205,51 @@ class OnnxBlender(BlenderbotSmallForConditionalGeneration):

          return Seq2SeqLMOutput(logits=logits, past_key_values=past_key_values)

- class ModelLoad:
-     def __init__(self, model_card,file_names):
-         self.model_card=model_card
-         self.file_names=file_names
-
-     def model_file_downloader(self,model_card,filename):
-         config_file_url = hf_hub_url(model_card, filename)
-         model_file = cached_download(config_file_url)
-         return model_file
-
-     def inference_session(self,file_name):
-         model_file=self.model_file_downloader(self.model_card,file_name)
-         options = SessionOptions()
-         options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
-         return InferenceSession(model_file,options=options)
-
-     def __call__(self,model_config):
-         model=model_config([*map(self.inference_session,
-                                  self.file_names)])
-         return model
-
- model_loader=ModelLoad(model_card,model_file_names)
- blender_onnx_model=model_loader(OnnxBlender)
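For context on this removal: the old ModelLoad helper pinned each session's graph optimization level to ORT_ENABLE_ALL, while the new onnx_sessions_starter below creates sessions with onnxruntime's defaults (recent onnxruntime releases already apply all graph optimizations by default). If pinning the level explicitly were still wanted, a minimal sketch (the session_with_optimizations helper is illustrative, not part of this commit):

    from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions

    def session_with_optimizations(model_file: str) -> InferenceSession:
        # pin full graph optimizations, as the removed ModelLoad class did
        options = SessionOptions()
        options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
        return InferenceSession(model_file, sess_options=options)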
+ from functools import reduce
+ from operator import iconcat
+ from typing import List

+ from huggingface_hub import hf_hub_download
+ from onnxruntime import InferenceSession
  from torch import from_numpy
  from torch.nn import Module
+ from transformers import (AutoConfig, BlenderbotSmallForConditionalGeneration,
+                           BlenderbotSmallTokenizer)
+ from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput

+ model_vocab_size = 30000
+ original_repo_id = "facebook/blenderbot_small-90M"
+ repo_id = "remzicam/xs_blenderbot_onnx"
+ model_file_names = [
+     "blenderbot_small-90M-encoder-quantized.onnx",
+     "blenderbot_small-90M-decoder-quantized.onnx",
+     "blenderbot_small-90M-init-decoder-quantized.onnx",
+ ]

  class BlenderEncoder(Module):
      def __init__(self, encoder_sess):
          super().__init__()
          self.encoder = encoder_sess
+         self.main_input_name = "input_ids"

      def forward(
          self,

  class OnnxBlender(BlenderbotSmallForConditionalGeneration):
      """creates a Blender model using onnx sessions (encode, decoder & init_decoder)"""

+     def __init__(self, original_repo_id, repo_id, file_names):
+         config = AutoConfig.from_pretrained(original_repo_id)
+         config.vocab_size = model_vocab_size
          super().__init__(config)

+         self.files = self.files_downloader(repo_id, file_names)
+         self.onnx_model_sessions = self.onnx_sessions_starter(self.files)
+         assert len(self.onnx_model_sessions) == 3, "all three models should be given"

+         encoder_sess, decoder_sess, decoder_sess_init = self.onnx_model_sessions

          self.encoder = BlenderEncoder(encoder_sess)
          self.decoder = BlenderDecoder(decoder_sess)
          self.decoder_init = BlenderDecoderInit(decoder_sess_init)

+     @staticmethod
+     def files_downloader(repo_id: str, file_names: List[str]) -> List[str]:
+         """Downloads files from huggingface given file names
+
+         Args:
+             repo_id (str): repo name at huggingface.
+             file_names (List[str]): The names of the files in the repo.
+
+         Returns:
+             List[str]: Local paths of files
+         """
+         return [hf_hub_download(repo_id, file) for file in file_names]
+
+     @staticmethod
+     def onnx_sessions_starter(files: List[str]) -> List[object]:
+         """initiates onnx inference sessions
+
+         Args:
+             files (List[str]): Local paths of files
+
+         Returns:
+             List[object]: onnx sessions for each file
+         """
+         return [*map(InferenceSession, files)]
+
      def get_encoder(self):
          return self.encoder

      def get_decoder(self):
          return self.decoder
+
      def forward(
          self,
          input_ids=None,
 
          output_hidden_states=None,
          return_dict=None,
      ):
+
          encoder_hidden_states = encoder_outputs[0]
+
          if past_key_values is not None:
              if decoder_input_ids is not None:
                  decoder_input_ids = decoder_input_ids[:, -1:]
 

          return Seq2SeqLMOutput(logits=logits, past_key_values=past_key_values)

+
+ class TextGenerationPipeline:
+     """Pipeline for text generation of blenderbot model.
+
+     Returns:
+         str: generated text
+     """
+
+     # load tokenizer and the model
+     tokenizer = BlenderbotSmallTokenizer.from_pretrained(original_repo_id)
+     model = OnnxBlender(original_repo_id, repo_id, model_file_names)
+
+     def __init__(self, **kwargs):
+         """Specifies text generation parameters.
+
+         For example, max_length=100 generates text shorter than
+         100 tokens. Visit
+         https://huggingface.co/docs/transformers/main_classes/text_generation
+         for more parameters.
+         """
+         self.__dict__.update(kwargs)
+
+     def preprocess(self, text):
+         """Tokenizes input text.
+
+         Args:
+             text (str): user specified text
+
+         Returns:
+             torch.Tensor: text representation as tensors
+         """
+         return self.tokenizer(text, return_tensors="pt")
+
+     def postprocess(self, outputs) -> str:
+         """Converts tensors into text.
+
+         Args:
+             outputs (torch.Tensor): model text generation output
+
+         Returns:
+             str: generated text
+         """
+         return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     def __call__(self, text: str) -> str:
+         """Generates text from input text.
+
+         Args:
+             text (str): user specified text
+
+         Returns:
+             str: generated text
+         """
+         tokenized_text = self.preprocess(text)
+         output = self.model.generate(**tokenized_text, **self.__dict__)
+         return self.postprocess(output)
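
A minimal usage sketch of the new TextGenerationPipeline (assuming the file above is importable as blender_model; max_length is just one example of a generate() parameter):

    from blender_model import TextGenerationPipeline

    # any keyword argument accepted by model.generate() can be passed here
    bot = TextGenerationPipeline(max_length=100)
    print(bot("hello, how are you doing today?"))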