|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | """Module to generate OpenELM output given a model and an input prompt.""" | 
					
						
						|  | import os | 
					
						
						|  | import logging | 
					
						
						|  | import time | 
					
						
						|  | import argparse | 
					
						
						|  | from typing import Optional, Union | 
					
						
						|  | import torch | 
					
						
						|  |  | 
					
						
						|  | from transformers import AutoTokenizer, AutoModelForCausalLM | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def generate( | 
					
						
						|  | prompt: str, | 
					
						
						|  | model: Union[str, AutoModelForCausalLM], | 
					
						
						|  | hf_access_token: str = None, | 
					
						
						|  | tokenizer: Union[str, AutoTokenizer] = 'meta-llama/Llama-2-7b-hf', | 
					
						
						|  | device: Optional[str] = None, | 
					
						
						|  | max_length: int = 1024, | 
					
						
						|  | assistant_model: Optional[Union[str, AutoModelForCausalLM]] = None, | 
					
						
						|  | generate_kwargs: Optional[dict] = None, | 
					
						
						|  | ) -> str: | 
					
						
						|  | """ Generates output given a prompt. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | prompt: The string prompt. | 
					
						
						|  | model: The LLM Model. If a string is passed, it should be the path to | 
					
						
						|  | the hf converted checkpoint. | 
					
						
						|  | hf_access_token: Hugging face access token. | 
					
						
						|  | tokenizer: Tokenizer instance. If model is set as a string path, | 
					
						
						|  | the tokenizer will be loaded from the checkpoint. | 
					
						
						|  | device: String representation of device to run the model on. If None | 
					
						
						|  | and cuda available it would be set to cuda:0 else cpu. | 
					
						
						|  | max_length: Maximum length of tokens, input prompt + generated tokens. | 
					
						
						|  | assistant_model: If set, this model will be used for | 
					
						
						|  | speculative generation. If a string is passed, it should be the | 
					
						
						|  | path to the hf converted checkpoint. | 
					
						
						|  | generate_kwargs: Extra kwargs passed to the hf generate function. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | output_text: output generated as a string. | 
					
						
						|  | generation_time: generation time in seconds. | 
					
						
						|  |  | 
					
						
						|  | Raises: | 
					
						
						|  | ValueError: If device is set to CUDA but no CUDA device is detected. | 
					
						
						|  | ValueError: If tokenizer is not set. | 
					
						
						|  | ValueError: If hf_access_token is not specified. | 
					
						
						|  | """ | 
					
						
						|  | if not device: | 
					
						
						|  | if torch.cuda.is_available() and torch.cuda.device_count(): | 
					
						
						|  | device = "cuda:0" | 
					
						
						|  | logging.warning( | 
					
						
						|  | 'inference device is not set, using cuda:0, %s', | 
					
						
						|  | torch.cuda.get_device_name(0) | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | device = 'cpu' | 
					
						
						|  | logging.warning( | 
					
						
						|  | ( | 
					
						
						|  | 'No CUDA device detected, using cpu, ' | 
					
						
						|  | 'expect slower speeds.' | 
					
						
						|  | ) | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | if 'cuda' in device and not torch.cuda.is_available(): | 
					
						
						|  | raise ValueError('CUDA device requested but no CUDA device detected.') | 
					
						
						|  |  | 
					
						
						|  | if not tokenizer: | 
					
						
						|  | raise ValueError('Tokenizer is not set in the generate function.') | 
					
						
						|  |  | 
					
						
						|  | if not hf_access_token: | 
					
						
						|  | raise ValueError(( | 
					
						
						|  | 'Hugging face access token needs to be specified. ' | 
					
						
						|  | 'Please refer to https://huggingface.co/docs/hub/security-tokens' | 
					
						
						|  | ' to obtain one.' | 
					
						
						|  | ) | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | if isinstance(model, str): | 
					
						
						|  | checkpoint_path = model | 
					
						
						|  | model = AutoModelForCausalLM.from_pretrained( | 
					
						
						|  | checkpoint_path, | 
					
						
						|  | trust_remote_code=True | 
					
						
						|  | ) | 
					
						
						|  | model.to(device).eval() | 
					
						
						|  | if isinstance(tokenizer, str): | 
					
						
						|  | tokenizer = AutoTokenizer.from_pretrained( | 
					
						
						|  | tokenizer, | 
					
						
						|  | token=hf_access_token, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | draft_model = None | 
					
						
						|  | if assistant_model: | 
					
						
						|  | draft_model = assistant_model | 
					
						
						|  | if isinstance(assistant_model, str): | 
					
						
						|  | draft_model = AutoModelForCausalLM.from_pretrained( | 
					
						
						|  | assistant_model, | 
					
						
						|  | trust_remote_code=True | 
					
						
						|  | ) | 
					
						
						|  | draft_model.to(device).eval() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | tokenized_prompt = tokenizer(prompt) | 
					
						
						|  | tokenized_prompt = torch.tensor( | 
					
						
						|  | tokenized_prompt['input_ids'], | 
					
						
						|  | device=device | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | tokenized_prompt = tokenized_prompt.unsqueeze(0) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | stime = time.time() | 
					
						
						|  | output_ids = model.generate( | 
					
						
						|  | tokenized_prompt, | 
					
						
						|  | max_length=max_length, | 
					
						
						|  | pad_token_id=0, | 
					
						
						|  | assistant_model=draft_model, | 
					
						
						|  | **(generate_kwargs if generate_kwargs else {}), | 
					
						
						|  | ) | 
					
						
						|  | generation_time = time.time() - stime | 
					
						
						|  |  | 
					
						
						|  | output_text = tokenizer.decode( | 
					
						
						|  | output_ids[0].tolist(), | 
					
						
						|  | skip_special_tokens=True | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | return output_text, generation_time | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def openelm_generate_parser(): | 
					
						
						|  | """Argument Parser""" | 
					
						
						|  |  | 
					
						
						|  | class KwargsParser(argparse.Action): | 
					
						
						|  | """Parser action class to parse kwargs of form key=value""" | 
					
						
						|  | def __call__(self, parser, namespace, values, option_string=None): | 
					
						
						|  | setattr(namespace, self.dest, dict()) | 
					
						
						|  | for val in values: | 
					
						
						|  | if '=' not in val: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | ( | 
					
						
						|  | 'Argument parsing error, kwargs are expected in' | 
					
						
						|  | ' the form of key=value.' | 
					
						
						|  | ) | 
					
						
						|  | ) | 
					
						
						|  | kwarg_k, kwarg_v = val.split('=') | 
					
						
						|  | try: | 
					
						
						|  | converted_v = int(kwarg_v) | 
					
						
						|  | except ValueError: | 
					
						
						|  | try: | 
					
						
						|  | converted_v = float(kwarg_v) | 
					
						
						|  | except ValueError: | 
					
						
						|  | converted_v = kwarg_v | 
					
						
						|  | getattr(namespace, self.dest)[kwarg_k] = converted_v | 
					
						
						|  |  | 
					
						
						|  | parser = argparse.ArgumentParser('OpenELM Generate Module') | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | '--model', | 
					
						
						|  | dest='model', | 
					
						
						|  | help='Path to the hf converted model.', | 
					
						
						|  | required=True, | 
					
						
						|  | type=str, | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | '--hf_access_token', | 
					
						
						|  | dest='hf_access_token', | 
					
						
						|  | help='Hugging face access token, starting with "hf_".', | 
					
						
						|  | type=str, | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | '--prompt', | 
					
						
						|  | dest='prompt', | 
					
						
						|  | help='Prompt for LLM call.', | 
					
						
						|  | default='', | 
					
						
						|  | type=str, | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | '--device', | 
					
						
						|  | dest='device', | 
					
						
						|  | help='Device used for inference.', | 
					
						
						|  | type=str, | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | '--max_length', | 
					
						
						|  | dest='max_length', | 
					
						
						|  | help='Maximum length of tokens.', | 
					
						
						|  | default=256, | 
					
						
						|  | type=int, | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | '--assistant_model', | 
					
						
						|  | dest='assistant_model', | 
					
						
						|  | help=( | 
					
						
						|  | ( | 
					
						
						|  | 'If set, this is used as a draft model ' | 
					
						
						|  | 'for assisted speculative generation.' | 
					
						
						|  | ) | 
					
						
						|  | ), | 
					
						
						|  | type=str, | 
					
						
						|  | ) | 
					
						
						|  | parser.add_argument( | 
					
						
						|  | '--generate_kwargs', | 
					
						
						|  | dest='generate_kwargs', | 
					
						
						|  | help='Additional kwargs passed to the HF generate function.', | 
					
						
						|  | type=str, | 
					
						
						|  | nargs='*', | 
					
						
						|  | action=KwargsParser, | 
					
						
						|  | ) | 
					
						
						|  | return parser.parse_args() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if __name__ == '__main__': | 
					
						
						|  | args = openelm_generate_parser() | 
					
						
						|  | prompt = args.prompt | 
					
						
						|  |  | 
					
						
						|  | output_text, genertaion_time = generate( | 
					
						
						|  | prompt=prompt, | 
					
						
						|  | model=args.model, | 
					
						
						|  | device=args.device, | 
					
						
						|  | max_length=args.max_length, | 
					
						
						|  | assistant_model=args.assistant_model, | 
					
						
						|  | generate_kwargs=args.generate_kwargs, | 
					
						
						|  | hf_access_token=args.hf_access_token, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | print_txt = ( | 
					
						
						|  | f'\r\n{"=" * os.get_terminal_size().columns}\r\n' | 
					
						
						|  | '\033[1m Prompt + Generated Output\033[0m\r\n' | 
					
						
						|  | f'{"-" * os.get_terminal_size().columns}\r\n' | 
					
						
						|  | f'{output_text}\r\n' | 
					
						
						|  | f'{"-" * os.get_terminal_size().columns}\r\n' | 
					
						
						|  | '\r\nGeneration took' | 
					
						
						|  | f'\033[1m\033[92m {round(genertaion_time, 2)} \033[0m' | 
					
						
						|  | 'seconds.\r\n' | 
					
						
						|  | ) | 
					
						
						|  | print(print_txt) | 
					
						
						|  |  |