"""Perform inference of one model on one input prompt and measure time and energy.""" from __future__ import annotations from typing import Literal import tyro import rich import torch from fastchat.serve.inference import generate_stream from fastchat.model.model_adapter import load_model, get_conversation_template from zeus.monitor import ZeusMonitor SYSTEM_PROMPTS = { "chat": ( "A chat between a human user (prompter) and an artificial intelligence (AI) assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions." ), "chat-concise": ( "A chat between a human user (prompter) and an artificial intelligence (AI) assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions. " "The assistnat's answers are concise but high-quality." ), "instruct": ( "Below is an instruction that describes a task. " "Write a response that appropriately completes the request." ), "instruct-concise": ( "Below is an instruction that describes a task. " "Write a response that appropriately completes the request." "The response should be concise but high-quality." ), } def main( model_path: str, input_prompt: str, device_index: int = 0, task: Literal[tuple(SYSTEM_PROMPTS)] = "chat", # type: ignore load_8bit: bool = False, temperature: float = 0.7, repitition_penalty: float = 1.0, max_new_tokens: int = 512, ) -> None: """Run the main routine. Code structure is based on https://github.com/lm-sys/FastChat/blob/57dea54055/fastchat/serve/inference.py#L249 Args: model_path: Path to or Huggingface Hub Id of the model. input_prompt: Input prompt to use for inference. device_index: Index of the GPU to use for inference. task: Type of task to perform inference on. load_8bit: Whether to load the model in 8-bit mode. temperature: Temperature to use for sampling. repitition_penalty: Repitition penalty to use for the model. max_new_tokens: Maximum numbers of tokens to generate, ignoring the prompt. """ # NOTE(JW): ChatGLM is implemented as a special case in FastChat inference. # Also, it's primarily a model that's fine-tuned for Chinese, so it doesn't # make sense to prompt it in English and talk about its verbosity. if "chatglm" in model_path.lower(): raise ValueError("ChatGLM is not supported.") # Set the device. torch.cuda.set_device(f"cuda:{device_index}") # Load the model (Huggingface PyTorch) and tokenizer (Huggingface). model, tokenizer = load_model( model_path=model_path, device="cuda", num_gpus=1, max_gpu_memory=None, load_8bit=load_8bit, cpu_offloading=False, gptq_config=None, debug=False, ) # Chats are accumulated in a conversation helper object. conv = get_conversation_template(model_path) # Standardize the system prompt for every model. conv.system = SYSTEM_PROMPTS[task] conv.messages = [] conv.offset = 0 # Construct the input prompt. conv.append_message(conv.roles[0], input_prompt) conv.append_message(conv.roles[1], "") prompt = conv.get_prompt() # Generate the ouptut from the model. gen_params = { "model": model_path, "prompt": prompt, "temperature": temperature, "repitition_penalty": repitition_penalty, "max_new_tokens": max_new_tokens, "stop": conv.stop_str, "stop_token_ids": conv.stop_token_ids, "echo": False, } output_stream = generate_stream(model, tokenizer, gen_params, device="cuda") output = {} # Inference and measurement! monitor = ZeusMonitor(gpu_indices=[torch.cuda.current_device()]) monitor.begin_window("inference") for output in output_stream: pass measurements = monitor.end_window("inference") # Print the input and output. rich.print(f"\n[u]Prompt[/u]:\n{prompt.strip()}\n") output_text = output["text"] rich.print(f"\n[u]Response[/u]:\n{output_text.strip()}\n") # Print numbers. num_tokens = len(tokenizer.encode(output_text)) rich.print(measurements) rich.print(f"Number of tokens: {num_tokens}") rich.print(f"Tokens per seconds: {num_tokens / measurements.time:.2f}") rich.print(f"Joules per token: {measurements.total_energy / num_tokens:.2f}") rich.print(f"Average power consumption: {measurements.total_energy / measurements.time:.2f}") if __name__ == "__main__": tyro.cli(main)