File size: 4,667 Bytes
19b22c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""Perform inference of one model on one input prompt and measure time and energy."""

from __future__ import annotations

from typing import Literal

import tyro
import rich
import torch
from fastchat.serve.inference import generate_stream
from fastchat.model.model_adapter import load_model, get_conversation_template
from zeus.monitor import ZeusMonitor

SYSTEM_PROMPTS = {
    "chat": (
        "A chat between a human user (prompter) and an artificial intelligence (AI) assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
    "chat-concise": (
        "A chat between a human user (prompter) and an artificial intelligence (AI) assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions. "
        "The assistnat's answers are concise but high-quality."
    ),
    "instruct": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
    ),
    "instruct-concise": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
        "The response should be concise but high-quality."
    ),
}


def main(
    model_path: str,
    input_prompt: str,
    device_index: int = 0,
    task: Literal[tuple(SYSTEM_PROMPTS)] = "chat",  # type: ignore
    load_8bit: bool = False,
    temperature: float = 0.7,
    repitition_penalty: float = 1.0,
    max_new_tokens: int = 512,
) -> None:
    """Run the main routine.

    Code structure is based on
    https://github.com/lm-sys/FastChat/blob/57dea54055/fastchat/serve/inference.py#L249

    Args:
        model_path: Path to or Huggingface Hub Id of the model.
        input_prompt: Input prompt to use for inference.
        device_index: Index of the GPU to use for inference.
        task: Type of task to perform inference on.
        load_8bit: Whether to load the model in 8-bit mode.
        temperature: Temperature to use for sampling.
        repitition_penalty: Repitition penalty to use for the model.
        max_new_tokens: Maximum numbers of tokens to generate, ignoring the prompt.
    """
    # NOTE(JW): ChatGLM is implemented as a special case in FastChat inference.
    # Also, it's primarily a model that's fine-tuned for Chinese, so it doesn't
    # make sense to prompt it in English and talk about its verbosity.
    if "chatglm" in model_path.lower():
        raise ValueError("ChatGLM is not supported.")

    # Set the device.
    torch.cuda.set_device(f"cuda:{device_index}")

    # Load the model (Huggingface PyTorch) and tokenizer (Huggingface).
    model, tokenizer = load_model(
        model_path=model_path,
        device="cuda",
        num_gpus=1,
        max_gpu_memory=None,
        load_8bit=load_8bit,
        cpu_offloading=False,
        gptq_config=None,
        debug=False,
    )

    # Chats are accumulated in a conversation helper object.
    conv = get_conversation_template(model_path)

    # Standardize the system prompt for every model.
    conv.system = SYSTEM_PROMPTS[task]
    conv.messages = []
    conv.offset = 0

    # Construct the input prompt.
    conv.append_message(conv.roles[0], input_prompt)
    conv.append_message(conv.roles[1], "")
    prompt = conv.get_prompt()

    # Generate the ouptut from the model.
    gen_params = {
        "model": model_path,
        "prompt": prompt,
        "temperature": temperature,
        "repitition_penalty": repitition_penalty,
        "max_new_tokens": max_new_tokens,
        "stop": conv.stop_str,
        "stop_token_ids": conv.stop_token_ids,
        "echo": False,
    }
    output_stream = generate_stream(model, tokenizer, gen_params, device="cuda")
    output = {}

    # Inference and measurement!
    monitor = ZeusMonitor(gpu_indices=[torch.cuda.current_device()])
    monitor.begin_window("inference")
    for output in output_stream:
        pass
    measurements = monitor.end_window("inference")
    
    # Print the input and output.
    rich.print(f"\n[u]Prompt[/u]:\n{prompt.strip()}\n")
    output_text = output["text"]
    rich.print(f"\n[u]Response[/u]:\n{output_text.strip()}\n")

    # Print numbers.
    num_tokens = len(tokenizer.encode(output_text))
    rich.print(measurements)
    rich.print(f"Number of tokens: {num_tokens}")
    rich.print(f"Tokens per seconds: {num_tokens / measurements.time:.2f}")
    rich.print(f"Joules per token: {measurements.total_energy / num_tokens:.2f}")
    rich.print(f"Average power consumption: {measurements.total_energy / measurements.time:.2f}")


if __name__ == "__main__":
    tyro.cli(main)