Text Generation
Transformers
Safetensors
openelm
custom_code
mahyar-najibi committed on
Commit
200856b
1 Parent(s): 139d5ba

Add a generate module for OpenELM models.

Files changed (1)
  1. generate_openelm.py +239 -0
generate_openelm.py ADDED
@@ -0,0 +1,239 @@
"""Module to generate OpenELM output given a model and an input prompt."""
import os
import logging
import time
import argparse
from typing import Optional, Tuple, Union
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM


def generate(
    prompt: str,
    model: Union[str, AutoModelForCausalLM],
    hf_security_token: Optional[str] = None,
    tokenizer: Union[str, AutoTokenizer] = 'meta-llama/Llama-2-7b-hf',
    device: Optional[str] = None,
    max_length: int = 1024,
    speculative_model: Optional[Union[str, AutoModelForCausalLM]] = None,
    generate_kwargs: Optional[dict] = None,
) -> Tuple[str, float]:
    """Generates output given a prompt.

    Args:
        prompt: The string prompt.
        model: The LLM model. If a string is passed, it should be the path to
            the hf converted checkpoint.
        hf_security_token: Hugging Face security token.
        tokenizer: Tokenizer instance. If model is set as a string path,
            the tokenizer will be loaded from the checkpoint.
        device: String representation of the device to run the model on. If
            None and CUDA is available, it is set to 'cuda:0'; otherwise 'cpu'.
        max_length: Maximum length of tokens, input prompt + generated tokens.
        speculative_model: If set, this model will be used for
            speculative generation. If a string is passed, it should be the
            path to the hf converted checkpoint.
        generate_kwargs: Extra kwargs passed to the generate function.

    Returns:
        output_text: Output generated as a string.
        generation_time: Generation time in seconds.

    Raises:
        ValueError: If device is set to CUDA but no CUDA device is detected.
        FileNotFoundError: If model or speculative_model are strings but
            the model paths do not exist.
        ValueError: If hf_security_token is not specified.
    """
    if not device:
        if torch.cuda.is_available() and torch.cuda.device_count():
            device = "cuda:0"
            logging.warning(
                'inference device is not set, using cuda:0, %s',
                torch.cuda.get_device_name(0)
            )
        else:
            device = 'cpu'
            logging.warning(
                'no CUDA device detected, using cpu, expect slower speeds.'
            )

    if 'cuda' in device and not torch.cuda.is_available():
        raise ValueError('CUDA device requested but no CUDA device detected.')

    if isinstance(model, str) and (not model or not os.path.exists(model)):
        raise FileNotFoundError(f'Model checkpoint does not exist at {model}.')

    if (isinstance(speculative_model, str) and
            (not speculative_model or not os.path.exists(speculative_model))):
        raise FileNotFoundError(
            'Speculative checkpoint path does not exist at '
            f'{speculative_model}.'
        )
    if not tokenizer and not isinstance(model, str):
        raise ValueError('Tokenizer is not set in the generate function.')

    if not hf_security_token:
        raise ValueError(
            'Hugging Face security token needs to be specified. '
            'Please refer to https://huggingface.co/docs/hub/security-tokens'
            ' to obtain one.'
        )

    if isinstance(model, str):
        checkpoint_path = model
        model = AutoModelForCausalLM.from_pretrained(
            checkpoint_path,
            trust_remote_code=True
        )
    model.to(device).eval()
    if isinstance(tokenizer, str):
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer,
            token=hf_security_token,
        )

    # Speculative mode
    draft_model = None
    if speculative_model:
        draft_model = speculative_model
        if isinstance(speculative_model, str):
            draft_model = AutoModelForCausalLM.from_pretrained(
                speculative_model,
                trust_remote_code=True
            )
        draft_model.to(device).eval()

    # Prepare the prompt
    tokenized_prompt = tokenizer(prompt)
    tokenized_prompt = torch.tensor(
        tokenized_prompt['input_ids'],
        device=device
    )

    tokenized_prompt = tokenized_prompt.unsqueeze(0)

    # Generate
    stime = time.time()
    output_ids = model.generate(
        tokenized_prompt,
        max_length=max_length,
        pad_token_id=0,
        assistant_model=draft_model,
        **(generate_kwargs if generate_kwargs else {}),
    )
    generation_time = time.time() - stime

    output_text = tokenizer.decode(
        output_ids[0].tolist(),
        skip_special_tokens=True
    )

    return output_text, generation_time


def openelm_generate_parser():
    """Argument Parser"""

    class KwargsParser(argparse.Action):
        """Parser action class to parse kwargs of form key=value"""
        def __call__(self, parser, namespace, values, option_string=None):
            setattr(namespace, self.dest, dict())
            for val in values:
                if '=' not in val:
                    raise ValueError(
                        'Argument parsing error, kwargs are expected in'
                        ' the form of key=value.'
                    )
                kwarg_k, kwarg_v = val.split('=')
                try:
                    converted_v = int(kwarg_v)
                except ValueError:
                    try:
                        converted_v = float(kwarg_v)
                    except ValueError:
                        converted_v = kwarg_v
                getattr(namespace, self.dest)[kwarg_k] = converted_v

    parser = argparse.ArgumentParser('OpenELM Generate Module')
    parser.add_argument(
        '--checkpoint',
        dest='checkpoint_path',
        help='Path to the model hf converted checkpoint.',
        required=True,
        type=str,
    )
    parser.add_argument(
        '--hf_security_token',
        dest='hf_security_token',
        help='HF security token, starting with "hf_".',
        type=str,
    )
    parser.add_argument(
        '--prompt',
        dest='prompt',
        help='Prompt for the LLM call.',
        default='',
        type=str,
    )
    parser.add_argument(
        '--device',
        dest='device',
        help='Device used for inference.',
        type=str,
    )
    parser.add_argument(
        '--max_length',
        dest='max_length',
        help='Maximum length of tokens, input prompt + generated tokens.',
        default=256,
        type=int,
    )
    parser.add_argument(
        '--speculative_model_ckpt',
        dest='speculative_model_ckpt',
        help=(
            'If set, this is used as a draft model for speculative generation.'
        ),
        type=str,
    )
    parser.add_argument(
        '--generate_kwargs',
        dest='generate_kwargs',
        help='Additional kwargs passed to the HF generate function.',
        type=str,
        nargs='*',
        action=KwargsParser,
    )
    return parser.parse_args()


if __name__ == '__main__':
    args = openelm_generate_parser()
    prompt = args.prompt

    output_text, generation_time = generate(
        prompt=prompt,
        model=args.checkpoint_path,
        device=args.device,
        max_length=args.max_length,
        speculative_model=args.speculative_model_ckpt,
        generate_kwargs=args.generate_kwargs,
        hf_security_token=args.hf_security_token,
    )

    print_txt = (
        f'\r\n{"=" * os.get_terminal_size().columns}\r\n'
        '\033[1m Prompt + Generated Output\033[0m\r\n'
        f'{"-" * os.get_terminal_size().columns}\r\n'
        f'{output_text}\r\n'
        f'{"-" * os.get_terminal_size().columns}\r\n'
        '\r\nGeneration took'
        f'\033[1m\033[92m {round(generation_time, 2)} \033[0m'
        'seconds.\r\n'
    )
    print(print_txt)
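
For context, a minimal usage sketch of the added module when imported from another script. The checkpoint path, prompt, token, and repetition_penalty value below are illustrative placeholders, not values from this commit; the model path must point at a local hf converted checkpoint, since generate() checks os.path.exists on string model arguments.

from generate_openelm import generate

# Placeholder values: a locally converted OpenELM checkpoint and a valid
# Hugging Face access token (required to load the gated Llama-2 tokenizer).
output_text, generation_time = generate(
    prompt='Once upon a time there was',
    model='/path/to/openelm-checkpoint',
    hf_security_token='hf_...',
    max_length=256,
    generate_kwargs={'repetition_penalty': 1.2},  # forwarded to model.generate()
)
print(output_text, round(generation_time, 2))

The same call is exposed on the command line by the module's parser, e.g. python generate_openelm.py --checkpoint /path/to/openelm-checkpoint --prompt "Once upon a time there was" --hf_security_token hf_... --generate_kwargs repetition_penalty=1.2, where --generate_kwargs takes key=value pairs handled by KwargsParser.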