import json
from argparse import ArgumentParser

import datasets
import torch
import transformers
from transformers import AutoModelForCausalLM, BatchEncoding

"""
Usage examples (with the best batch sizes on A100-80GB-400W)
============================================================
python -m benchmark_hf_model --model_name_or_path="Deci/DeciLM-7B" --batch_size=352
python -m benchmark_hf_model --model_name_or_path="mistralai/Mistral-7B-v0.1" --batch_size=192 --model_kwargs_json='{"use_flash_attention_2": true}'
python -m benchmark_hf_model --model_name_or_path="meta-llama/Llama-2-7b-hf" --batch_size=48 --model_kwargs_json='{"use_flash_attention_2": true}'
"""


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--warmup_iters",
        type=int,
        default=10,
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=5,
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
    )
    parser.add_argument(
        "--prompt_length",
        type=int,
        default=512,
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=512,
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        help="Model precision, one of: fp32, fp16 or bf16",
    )
    parser.add_argument(
        "--model_kwargs_json",
        type=str,
        default=None,
    )
    return parser.parse_args()


def main():
    args = parse_args()
    transformers.logging.set_verbosity_error()
    datasets.logging.set_verbosity_error()

    dict_precisions = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    if args.precision not in dict_precisions:
        raise ValueError(
            f"Invalid precision {args.precision}, choose from: fp16, fp32, bf16"
        )
    dtype = dict_precisions[args.precision]

    model_kwargs = {}
    if args.model_kwargs_json is not None:
        model_kwargs = json.loads(args.model_kwargs_json)

    print("loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        torch_dtype=dtype,
        **model_kwargs,
    )
    try:
        print(model.model.layers[0].self_attn)
    except Exception:
        print("couldn't print the model's attention module")

    # CUDA events measure GPU-side time, unskewed by host-side launch overhead.
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)

    model.cuda()
    model.eval()

    # Synthetic batch: each prompt is `prompt_length` copies of token id 1.
    prompt = torch.ones(args.prompt_length, dtype=torch.long)
    inputs = BatchEncoding({"input_ids": prompt.repeat(args.batch_size, 1)})
    inputs = inputs.to(model.device)

    # warmup
    print(f"warming up for {args.warmup_iters} iterations...")
    for _ in range(args.warmup_iters):
        with torch.no_grad():
            _ = model.generate(
                **inputs,
                max_new_tokens=1,
                do_sample=False,
                eos_token_id=-1234,  # unreachable token id, so generation never stops early
            )
    print("finished warmup")
    torch.cuda.synchronize()

    print(
        f"prefill ({args.prompt_length} tokens{f' x {args.batch_size} batch' if args.batch_size > 1 else ''}) "
        f"+ generation ({args.max_new_tokens} tokens{f' x {args.batch_size} batch' if args.batch_size > 1 else ''}):"
    )
    tokens_generated = args.max_new_tokens * args.batch_size
    prefill_and_generation = []
    for gen_iter in range(args.iterations):
        starter.record()
        with torch.no_grad():
            _ = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=False,
                eos_token_id=-1234,  # unreachable token id, so generation never stops early
            )
        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender) / 1000
        prefill_and_generation.append(t)
        print(f"    iter {gen_iter + 1}: {t:.03f} sec total, {tokens_generated / t:.02f} generated tokens/sec")

    aver = sum(prefill_and_generation) / len(prefill_and_generation)
    print(f"    average: {aver:.03f} sec total, {tokens_generated / aver:.02f} generated tokens/sec")
    print(f"These results are obtained for model '{args.model_name_or_path}' with {args.batch_size=}.")


if __name__ == "__main__":
    main()