Spaces:
Running
Running
File size: 4,667 Bytes
19b22c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
"""Perform inference of one model on one input prompt and measure time and energy."""
from __future__ import annotations
from typing import Literal
import tyro
import rich
import torch
from fastchat.serve.inference import generate_stream
from fastchat.model.model_adapter import load_model, get_conversation_template
from zeus.monitor import ZeusMonitor
SYSTEM_PROMPTS = {
"chat": (
"A chat between a human user (prompter) and an artificial intelligence (AI) assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
"chat-concise": (
"A chat between a human user (prompter) and an artificial intelligence (AI) assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions. "
"The assistnat's answers are concise but high-quality."
),
"instruct": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request."
),
"instruct-concise": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request."
"The response should be concise but high-quality."
),
}
def main(
model_path: str,
input_prompt: str,
device_index: int = 0,
task: Literal[tuple(SYSTEM_PROMPTS)] = "chat", # type: ignore
load_8bit: bool = False,
temperature: float = 0.7,
repitition_penalty: float = 1.0,
max_new_tokens: int = 512,
) -> None:
"""Run the main routine.
Code structure is based on
https://github.com/lm-sys/FastChat/blob/57dea54055/fastchat/serve/inference.py#L249
Args:
model_path: Path to or Huggingface Hub Id of the model.
input_prompt: Input prompt to use for inference.
device_index: Index of the GPU to use for inference.
task: Type of task to perform inference on.
load_8bit: Whether to load the model in 8-bit mode.
temperature: Temperature to use for sampling.
repitition_penalty: Repitition penalty to use for the model.
max_new_tokens: Maximum numbers of tokens to generate, ignoring the prompt.
"""
# NOTE(JW): ChatGLM is implemented as a special case in FastChat inference.
# Also, it's primarily a model that's fine-tuned for Chinese, so it doesn't
# make sense to prompt it in English and talk about its verbosity.
if "chatglm" in model_path.lower():
raise ValueError("ChatGLM is not supported.")
# Set the device.
torch.cuda.set_device(f"cuda:{device_index}")
# Load the model (Huggingface PyTorch) and tokenizer (Huggingface).
model, tokenizer = load_model(
model_path=model_path,
device="cuda",
num_gpus=1,
max_gpu_memory=None,
load_8bit=load_8bit,
cpu_offloading=False,
gptq_config=None,
debug=False,
)
# Chats are accumulated in a conversation helper object.
conv = get_conversation_template(model_path)
# Standardize the system prompt for every model.
conv.system = SYSTEM_PROMPTS[task]
conv.messages = []
conv.offset = 0
# Construct the input prompt.
conv.append_message(conv.roles[0], input_prompt)
conv.append_message(conv.roles[1], "")
prompt = conv.get_prompt()
# Generate the ouptut from the model.
gen_params = {
"model": model_path,
"prompt": prompt,
"temperature": temperature,
"repitition_penalty": repitition_penalty,
"max_new_tokens": max_new_tokens,
"stop": conv.stop_str,
"stop_token_ids": conv.stop_token_ids,
"echo": False,
}
output_stream = generate_stream(model, tokenizer, gen_params, device="cuda")
output = {}
# Inference and measurement!
monitor = ZeusMonitor(gpu_indices=[torch.cuda.current_device()])
monitor.begin_window("inference")
for output in output_stream:
pass
measurements = monitor.end_window("inference")
# Print the input and output.
rich.print(f"\n[u]Prompt[/u]:\n{prompt.strip()}\n")
output_text = output["text"]
rich.print(f"\n[u]Response[/u]:\n{output_text.strip()}\n")
# Print numbers.
num_tokens = len(tokenizer.encode(output_text))
rich.print(measurements)
rich.print(f"Number of tokens: {num_tokens}")
rich.print(f"Tokens per seconds: {num_tokens / measurements.time:.2f}")
rich.print(f"Joules per token: {measurements.total_energy / num_tokens:.2f}")
rich.print(f"Average power consumption: {measurements.total_energy / measurements.time:.2f}")
if __name__ == "__main__":
tyro.cli(main)
|