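"""Benchmark script for llama2_wrapper.

Measures initialization time and, over several iterations, generation
latency, throughput (tokens/sec), and peak memory usage, then reports
the averages.

Example invocation (script name and model path are illustrative):

    python benchmark.py --iter 5 --backend_type llama.cpp \
        --model_path ./models/llama-2-7b-chat.ggmlv3.q4_0.bin
"""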
import os
import time
import argparse

from dotenv import load_dotenv
from memory_profiler import memory_usage
from tqdm import tqdm

from llama2_wrapper import LLAMA2_WRAPPER


def run_iteration(
    llama2_wrapper, prompt_example, DEFAULT_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS
):
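    """Run one timed generation; return (seconds, tokens/sec, peak MiB, last response)."""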
    def generation():
        generator = llama2_wrapper.run(
            prompt_example,
            [],  # no chat history
            DEFAULT_SYSTEM_PROMPT,
            DEFAULT_MAX_NEW_TOKENS,
            1,  # temperature
            0.95,  # top_p
            50,  # top_k
        )
        # run() streams partial responses; drain the generator and keep
        # only the last (complete) one. The original pre-fetch of the first
        # yield returned None when the generator produced a single item.
        model_response = None
        for model_response in generator:
            pass
        return llama2_wrapper.get_token_length(model_response), model_response

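    # memory_usage(max_usage=True, retval=True) returns the call's peak memory
    # in MiB together with generation()'s own return value.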
    tic = time.perf_counter()
    mem_usage, (output_token_length, model_response) = memory_usage(
        (generation,), max_usage=True, retval=True
    )
    toc = time.perf_counter()

    generation_time = toc - tic
    tokens_per_second = output_token_length / generation_time

    return generation_time, tokens_per_second, mem_usage, model_response


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--iter", type=int, default=5, help="Number of timed iterations.")
    parser.add_argument(
        "--model_path", type=str, default="", help="Model path (overrides .env)."
    )
    parser.add_argument(
        "--backend_type",
        type=str,
        default="",
        help="Backend options: llama.cpp, gptq, transformers",
    )
    parser.add_argument(
        "--load_in_8bit",
        # type=bool treats any non-empty string as True ("--load_in_8bit False"
        # would enable it); a store_true flag is the idiomatic fix.
        action="store_true",
        help="Whether to use bitsandbytes 8-bit loading.",
    )

    args = parser.parse_args()

    load_dotenv()
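    # Values from .env act as defaults; the CLI flags parsed above override them.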

    DEFAULT_SYSTEM_PROMPT = os.getenv("DEFAULT_SYSTEM_PROMPT", "")
    MAX_MAX_NEW_TOKENS = int(os.getenv("MAX_MAX_NEW_TOKENS", 2048))  # unused here
    DEFAULT_MAX_NEW_TOKENS = int(os.getenv("DEFAULT_MAX_NEW_TOKENS", 1024))
    MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", 4000))

    # CLI flags take precedence; fail only if neither a flag nor .env sets a
    # value. (Previously the asserts ran before the CLI overrides, so passing
    # --model_path without a .env entry crashed.)
    MODEL_PATH = args.model_path or os.getenv("MODEL_PATH")
    assert MODEL_PATH, "MODEL_PATH is required: set it in .env or pass --model_path."
    BACKEND_TYPE = args.backend_type or os.getenv("BACKEND_TYPE")
    assert BACKEND_TYPE, "BACKEND_TYPE is required: set it in .env or pass --backend_type."

    # distutils (and with it strtobool) was removed in Python 3.12, so parse
    # the boolean environment value directly.
    LOAD_IN_8BIT = args.load_in_8bit or os.getenv("LOAD_IN_8BIT", "True").lower() in (
        "true",
        "1",
        "yes",
    )

    # Initialization
    init_tic = time.perf_counter()
    llama2_wrapper = LLAMA2_WRAPPER(
        model_path=MODEL_PATH,
        backend_type=BACKEND_TYPE,
        max_tokens=MAX_INPUT_TOKEN_LENGTH,
        load_in_8bit=LOAD_IN_8BIT,
        # verbose=True,
    )

    init_toc = time.perf_counter()
    initialization_time = init_toc - init_tic

    total_time = 0
    total_tokens_per_second = 0
    total_memory_gen = 0

    prompt_example = (
        "Can you explain briefly to me what is the Python programming language?"
    )

    # Cold (warm-up) run: keeps one-time startup costs out of the timed runs.
    print("Performing cold run...")
    run_iteration(
        llama2_wrapper, prompt_example, DEFAULT_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS
    )

    # Timed runs
    print(f"Performing {args.iter} timed runs...")
    completed = 0
    model_response = None
    for _ in tqdm(range(args.iter)):
        try:
            gen_time, tokens_per_sec, mem_gen, model_response = run_iteration(
                llama2_wrapper,
                prompt_example,
                DEFAULT_SYSTEM_PROMPT,
                DEFAULT_MAX_NEW_TOKENS,
            )
            total_time += gen_time
            total_tokens_per_second += tokens_per_sec
            total_memory_gen += mem_gen
            completed += 1
        except Exception as e:
            # The original bare except swallowed the error silently; report it
            # and keep the partial results. Averaging over the number of
            # completed runs also fixes the old divide-by-(i + 1), which
            # counted the failed iteration.
            print(f"Iteration failed ({e!r}); stopping early.")
            break
    if completed == 0:
        print("All iterations failed; nothing to report.")
        return
    avg_time = total_time / completed
    avg_tokens_per_second = total_tokens_per_second / completed
    avg_memory_gen = total_memory_gen / completed

    print(f"Last model response: {model_response}")
    print(f"Initialization time: {initialization_time:0.4f} seconds.")
    print(
        f"Average generation time over {completed} iterations: {avg_time:0.4f} seconds."
    )
    print(
        f"Average speed over {completed} iterations: {avg_tokens_per_second:0.4f} tokens/sec."
    )
    print(f"Average memory usage during generation: {avg_memory_gen:.2f} MiB")


if __name__ == "__main__":
    main()