mahyar-najibi commited on
Commit
255f07f
1 Parent(s): bea1998

Add the generate module.

Browse files
Files changed (1) hide show
  1. generate_openelm.py +240 -0
generate_openelm.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # For licensing see accompanying LICENSE file.
3
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4
+ #
5
+
6
+ """Module to generate OpenELM output given a model and an input prompt."""
7
+ import os
8
+ import logging
9
+ import time
10
+ import argparse
11
+ from typing import Optional, Union
12
+ import torch
13
+
14
+ from transformers import AutoTokenizer, AutoModelForCausalLM
15
+
16
+
17
+ def generate(
18
+ prompt: str,
19
+ model: Union[str, AutoModelForCausalLM],
20
+ hf_access_token: str = None,
21
+ tokenizer: Union[str, AutoTokenizer] = 'meta-llama/Llama-2-7b-hf',
22
+ device: Optional[str] = None,
23
+ max_length: int = 1024,
24
+ assistant_model: Optional[Union[str, AutoModelForCausalLM]] = None,
25
+ generate_kwargs: Optional[dict] = None,
26
+ ) -> str:
27
+ """ Generates output given a prompt.
28
+
29
+ Args:
30
+ prompt: The string prompt.
31
+ model: The LLM Model. If a string is passed, it should be the path to
32
+ the hf converted checkpoint.
33
+ hf_access_token: Hugging face access token.
34
+ tokenizer: Tokenizer instance. If model is set as a string path,
35
+ the tokenizer will be loaded from the checkpoint.
36
+ device: String representation of device to run the model on. If None
37
+ and cuda available it would be set to cuda:0 else cpu.
38
+ max_length: Maximum length of tokens, input prompt + generated tokens.
39
+ assistant_model: If set, this model will be used for
40
+ speculative generation. If a string is passed, it should be the
41
+ path to the hf converted checkpoint.
42
+ generate_kwargs: Extra kwargs passed to the hf generate function.
43
+
44
+ Returns:
45
+ output_text: output generated as a string.
46
+ generation_time: generation time in seconds.
47
+
48
+ Raises:
49
+ ValueError: If device is set to CUDA but no CUDA device is detected.
50
+ ValueError: If tokenizer is not set.
51
+ ValueError: If hf_access_token is not specified.
52
+ """
53
+ if not device:
54
+ if torch.cuda.is_available() and torch.cuda.device_count():
55
+ device = "cuda:0"
56
+ logging.warning(
57
+ 'inference device is not set, using cuda:0, %s',
58
+ torch.cuda.get_device_name(0)
59
+ )
60
+ else:
61
+ device = 'cpu'
62
+ logging.warning(
63
+ (
64
+ 'No CUDA device detected, using cpu, '
65
+ 'expect slower speeds.'
66
+ )
67
+ )
68
+
69
+ if 'cuda' in device and not torch.cuda.is_available():
70
+ raise ValueError('CUDA device requested but no CUDA device detected.')
71
+
72
+ if not tokenizer:
73
+ raise ValueError('Tokenizer is not set in the generate function.')
74
+
75
+ if not hf_access_token:
76
+ raise ValueError((
77
+ 'Hugging face access token needs to be specified. '
78
+ 'Please refer to https://huggingface.co/docs/hub/security-tokens'
79
+ ' to obtain one.'
80
+ )
81
+ )
82
+
83
+ if isinstance(model, str):
84
+ checkpoint_path = model
85
+ model = AutoModelForCausalLM.from_pretrained(
86
+ checkpoint_path,
87
+ trust_remote_code=True
88
+ )
89
+ model.to(device).eval()
90
+ if isinstance(tokenizer, str):
91
+ tokenizer = AutoTokenizer.from_pretrained(
92
+ tokenizer,
93
+ token=hf_access_token,
94
+ )
95
+
96
+ # Speculative mode
97
+ draft_model = None
98
+ if assistant_model:
99
+ draft_model = assistant_model
100
+ if isinstance(assistant_model, str):
101
+ draft_model = AutoModelForCausalLM.from_pretrained(
102
+ assistant_model,
103
+ trust_remote_code=True
104
+ )
105
+ draft_model.to(device).eval()
106
+
107
+ # Prepare the prompt
108
+ tokenized_prompt = tokenizer(prompt)
109
+ tokenized_prompt = torch.tensor(
110
+ tokenized_prompt['input_ids'],
111
+ device=device
112
+ )
113
+
114
+ tokenized_prompt = tokenized_prompt.unsqueeze(0)
115
+
116
+ # Generate
117
+ stime = time.time()
118
+ output_ids = model.generate(
119
+ tokenized_prompt,
120
+ max_length=max_length,
121
+ pad_token_id=0,
122
+ assistant_model=draft_model,
123
+ **(generate_kwargs if generate_kwargs else {}),
124
+ )
125
+ generation_time = time.time() - stime
126
+
127
+ output_text = tokenizer.decode(
128
+ output_ids[0].tolist(),
129
+ skip_special_tokens=True
130
+ )
131
+
132
+ return output_text, generation_time
133
+
134
+
135
+ def openelm_generate_parser():
136
+ """Argument Parser"""
137
+
138
+ class KwargsParser(argparse.Action):
139
+ """Parser action class to parse kwargs of form key=value"""
140
+ def __call__(self, parser, namespace, values, option_string=None):
141
+ setattr(namespace, self.dest, dict())
142
+ for val in values:
143
+ if '=' not in val:
144
+ raise ValueError(
145
+ (
146
+ 'Argument parsing error, kwargs are expected in'
147
+ ' the form of key=value.'
148
+ )
149
+ )
150
+ kwarg_k, kwarg_v = val.split('=')
151
+ try:
152
+ converted_v = int(kwarg_v)
153
+ except ValueError:
154
+ try:
155
+ converted_v = float(kwarg_v)
156
+ except ValueError:
157
+ converted_v = kwarg_v
158
+ getattr(namespace, self.dest)[kwarg_k] = converted_v
159
+
160
+ parser = argparse.ArgumentParser('OpenELM Generate Module')
161
+ parser.add_argument(
162
+ '--model',
163
+ dest='model',
164
+ help='Path to the hf converted model.',
165
+ required=True,
166
+ type=str,
167
+ )
168
+ parser.add_argument(
169
+ '--hf_access_token',
170
+ dest='hf_access_token',
171
+ help='Hugging face access token, starting with "hf_".',
172
+ type=str,
173
+ )
174
+ parser.add_argument(
175
+ '--prompt',
176
+ dest='prompt',
177
+ help='Prompt for LLM call.',
178
+ default='',
179
+ type=str,
180
+ )
181
+ parser.add_argument(
182
+ '--device',
183
+ dest='device',
184
+ help='Device used for inference.',
185
+ type=str,
186
+ )
187
+ parser.add_argument(
188
+ '--max_length',
189
+ dest='max_length',
190
+ help='Maximum length of tokens.',
191
+ default=256,
192
+ type=int,
193
+ )
194
+ parser.add_argument(
195
+ '--assistant_model',
196
+ dest='assistant_model',
197
+ help=(
198
+ (
199
+ 'If set, this is used as a draft model '
200
+ 'for assisted speculative generation.'
201
+ )
202
+ ),
203
+ type=str,
204
+ )
205
+ parser.add_argument(
206
+ '--generate_kwargs',
207
+ dest='generate_kwargs',
208
+ help='Additional kwargs passed to the HF generate function.',
209
+ type=str,
210
+ nargs='*',
211
+ action=KwargsParser,
212
+ )
213
+ return parser.parse_args()
214
+
215
+
216
+ if __name__ == '__main__':
217
+ args = openelm_generate_parser()
218
+ prompt = args.prompt
219
+
220
+ output_text, genertaion_time = generate(
221
+ prompt=prompt,
222
+ model=args.model,
223
+ device=args.device,
224
+ max_length=args.max_length,
225
+ assistant_model=args.assistant_model,
226
+ generate_kwargs=args.generate_kwargs,
227
+ hf_access_token=args.hf_access_token,
228
+ )
229
+
230
+ print_txt = (
231
+ f'\r\n{"=" * os.get_terminal_size().columns}\r\n'
232
+ '\033[1m Prompt + Generated Output\033[0m\r\n'
233
+ f'{"-" * os.get_terminal_size().columns}\r\n'
234
+ f'{output_text}\r\n'
235
+ f'{"-" * os.get_terminal_size().columns}\r\n'
236
+ '\r\nGeneration took'
237
+ f'\033[1m\033[92m {round(genertaion_time, 2)} \033[0m'
238
+ 'seconds.\r\n'
239
+ )
240
+ print(print_txt)