"""Benchmark LLAMA2_WRAPPER: measure initialization time, generation speed
(tokens/sec), and peak memory usage across repeated generation runs."""

import os
import time
import argparse

from dotenv import load_dotenv
from memory_profiler import memory_usage
from tqdm import tqdm

from llama2_wrapper import LLAMA2_WRAPPER


def run_iteration(
    llama2_wrapper, prompt_example, DEFAULT_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS
):
    def generation():
        generator = llama2_wrapper.run(
            prompt_example,
            [],  # empty chat history
            DEFAULT_SYSTEM_PROMPT,
            DEFAULT_MAX_NEW_TOKENS,
            1,  # temperature
            0.95,  # top_p
            50,  # top_k
        )
        # Drain the stream; each yield is the response so far, so the last
        # value holds the complete model response.
        model_response = None
        for model_response in generator:
            pass
        if model_response is None:  # generator produced nothing
            return 0, ""
        return llama2_wrapper.get_token_length(model_response), model_response

    tic = time.perf_counter()
    # Run generation under memory_profiler to capture peak memory (MiB)
    # alongside the function's return value.
    mem_usage, (output_token_length, model_response) = memory_usage(
        (generation,), max_usage=True, retval=True
    )
    toc = time.perf_counter()

    generation_time = toc - tic
    tokens_per_second = output_token_length / generation_time

    return generation_time, tokens_per_second, mem_usage, model_response


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--iter", type=int, default=5, help="Number of iterations")
    parser.add_argument("--model_path", type=str, default="", help="Model path")
    parser.add_argument(
        "--backend_type",
        type=str,
        default="",
        help="Backend options: llama.cpp, gptq, transformers",
    )
    # argparse's type=bool treats any non-empty string as True, so expose
    # this as a proper flag instead.
    parser.add_argument(
        "--load_in_8bit",
        action="store_true",
        help="Whether to use bitsandbytes 8 bit.",
    )
    args = parser.parse_args()

    load_dotenv()
    DEFAULT_SYSTEM_PROMPT = os.getenv("DEFAULT_SYSTEM_PROMPT", "")
    MAX_MAX_NEW_TOKENS = int(os.getenv("MAX_MAX_NEW_TOKENS", 2048))
    DEFAULT_MAX_NEW_TOKENS = int(os.getenv("DEFAULT_MAX_NEW_TOKENS", 1024))
    MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", 4000))
    MODEL_PATH = os.getenv("MODEL_PATH")
    assert MODEL_PATH is not None, f"MODEL_PATH is required, got: {MODEL_PATH}"
    BACKEND_TYPE = os.getenv("BACKEND_TYPE")
    assert BACKEND_TYPE is not None, f"BACKEND_TYPE is required, got: {BACKEND_TYPE}"
    # distutils.util.strtobool was removed in Python 3.12; parse the env
    # flag directly instead.
    LOAD_IN_8BIT = os.getenv("LOAD_IN_8BIT", "True").strip().lower() in (
        "1",
        "true",
        "yes",
        "on",
    )

    # Command-line arguments override the .env configuration.
    if args.model_path != "":
        MODEL_PATH = args.model_path
    if args.backend_type != "":
        BACKEND_TYPE = args.backend_type
    if args.load_in_8bit:
        LOAD_IN_8BIT = True

    # Initialization
    init_tic = time.perf_counter()
    llama2_wrapper = LLAMA2_WRAPPER(
        model_path=MODEL_PATH,
        backend_type=BACKEND_TYPE,
        max_tokens=MAX_INPUT_TOKEN_LENGTH,
        load_in_8bit=LOAD_IN_8BIT,
        # verbose=True,
    )
    init_toc = time.perf_counter()
    initialization_time = init_toc - init_tic

    total_time = 0.0
    total_tokens_per_second = 0.0
    total_memory_gen = 0.0
    prompt_example = (
        "Can you explain briefly to me what is the Python programming language?"
    )

    # Cold run to warm up the model; excluded from the averages.
    print("Performing cold run...")
    run_iteration(
        llama2_wrapper, prompt_example, DEFAULT_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS
    )

    # Timed runs
    print(f"Performing {args.iter} timed runs...")
    completed = 0
    model_response = None
    for _ in tqdm(range(args.iter)):
        try:
            gen_time, tokens_per_sec, mem_gen, model_response = run_iteration(
                llama2_wrapper,
                prompt_example,
                DEFAULT_SYSTEM_PROMPT,
                DEFAULT_MAX_NEW_TOKENS,
            )
            total_time += gen_time
            total_tokens_per_second += tokens_per_sec
            total_memory_gen += mem_gen
            completed += 1
        except Exception as e:
            # Stop benchmarking on the first failure, but keep the results
            # collected so far.
            print(f"Iteration failed: {e}")
            break

    if completed == 0:
        print("No iterations completed successfully.")
        return

    # Average only over the iterations that actually completed.
    avg_time = total_time / completed
    avg_tokens_per_second = total_tokens_per_second / completed
    avg_memory_gen = total_memory_gen / completed

    print(f"Last model response: {model_response}")
    print(f"Initialization time: {initialization_time:0.4f} seconds.")
    print(
        f"Average generation time over {completed} iterations: {avg_time:0.4f} seconds."
    )
    print(
        f"Average speed over {completed} iterations: {avg_tokens_per_second:0.4f} tokens/sec."
    )
    print(f"Average memory usage during generation: {avg_memory_gen:.2f} MiB")


if __name__ == "__main__":
    main()