import csv
import itertools
import math
import os
import time

import torch
import torch.nn as nn
from flash_attn.utils.generation import InferenceParams
from tqdm import tqdm

from tests.test_select_block import create_block, Config, SparseConfig
from HybridTensor.utils.utils import arg_parser, _get_device, sparse_index, generate_random_BH_index, get_gpu_name
from HybridTensor.utils.profiling import cuda_profiler
|
def run_simulation(args, batch_size, seq_len, index_size, attn_topk, device, dtype):
    """Benchmark one sparse vs. one dense transformer block via CUDA-graph replay.

    Builds a sparse block and a regular (dense) block, runs a prefill of
    ``seq_len`` tokens through both, captures a single decode step of each
    into a CUDA graph, and times graph replays.

    Args:
        args: parsed CLI namespace; only ``args.in_features`` (model width) is read.
        batch_size: batch size for this run.
        seq_len: prefill sequence length.
        index_size: number of active MLP neurons for the sparse block.
        attn_topk: attention top-k fraction passed to the sparse config.
        device: torch device the blocks and inputs are created on.
        dtype: parameter/activation dtype for the blocks and inputs.

    Returns:
        Tuple ``(regular_graph_time_ms, sparse_graph_time_ms, speedup)`` where
        ``speedup = regular_graph_time_ms / sparse_graph_time_ms``.
    """
    config = Config()
    sp_config = SparseConfig()
    sp_config.attn_topk = attn_topk

    config.hidden_size = args.in_features
    config.num_attention_heads = args.in_features // 128  # head_dim fixed at 128
    config.use_heuristic = False

    sparse_block = create_block(config, sp_config, layer_idx=0, process_group=None, device=device, dtype=dtype)
    sparse_block.eval()
    sparse_block.mlp_topk = index_size

    # NOTE(review): `regular_config` aliases `config` (no copy), so the flags set
    # below are also visible on the config the sparse block was built from. This
    # is harmless only if create_block snapshots config at construction — confirm.
    regular_config = config
    regular_config.att_sparse = False
    regular_config.mlp_sparse = False
    regular_block = create_block(regular_config, None, layer_idx=0, process_group=None, device=device, dtype=dtype)
    regular_block.eval()

    max_seqlen = seq_len + 16  # headroom for decode steps beyond the prefill
    max_batch_size = batch_size
    in_features = args.in_features
    head_dim = 128

    inference_params = InferenceParams(max_seqlen=max_seqlen, max_batch_size=max_batch_size)
    process_group = None
    sequence_parallel = False

    heads = config.num_attention_heads
    selected_heads = heads // 2

    # Build the sparse-activation index vector. Bug fix: use the per-run
    # `index_size` / `batch_size` parameters instead of the fixed
    # `args.index_size` / `args.batch_size`, so the sweep in __main__
    # actually varies them run to run.
    total_neurons = args.in_features * 4
    test_index_vec = torch.empty((total_neurons,), device=device, dtype=torch.int32)
    active_indices = sparse_index(index_size, total_neurons)[0]
    test_index_vec[:index_size] = active_indices
    if index_size < total_neurons:
        test_index_vec[index_size:] = 0  # zero-pad the inactive tail

    # NOTE(review): test_index_vec / test_bh_idx / test_index_size are built but
    # never consumed below — presumably the blocks generate their own indices.
    # Kept to preserve RNG consumption; confirm whether they can be removed.
    test_bh_idx = generate_random_BH_index(batch_size, heads, selected_heads)
    test_index_size = index_size

    mixer_kwargs = (
        {"seqlen": seq_len}
        if process_group is not None and sequence_parallel
        else {}
    )
    if inference_params is not None:
        mixer_kwargs["inference_params"] = inference_params

    with torch.no_grad():
        # Prefill: run the full sequence through both blocks to populate KV caches.
        original_seq = torch.randn(batch_size, seq_len, in_features, device=device, dtype=dtype)
        output_sparse = sparse_block(original_seq, mixer_kwargs=mixer_kwargs)
        output_regular = regular_block(original_seq, mixer_kwargs=mixer_kwargs)

        # Decode: single-token step positioned right after the prefill.
        mixer_kwargs["inference_params"].seqlen_offset = seq_len
        input_x = torch.randn(batch_size, 1, in_features, device=device, dtype=dtype)
        out_decode_sparse = sparse_block(input_x, mixer_kwargs=mixer_kwargs)

        # Reset the offset so the regular block decodes at the same position.
        mixer_kwargs["inference_params"].seqlen_offset = seq_len
        out_decode_regular = regular_block(input_x, mixer_kwargs=mixer_kwargs)

        # --- CUDA-graph capture: regular block ---
        input_x_static = input_x.clone()
        output_regular_static = torch.empty((batch_size, 1, in_features), device=device, dtype=dtype)

        _ = regular_block(input_x_static, mixer_kwargs=mixer_kwargs)  # warm-up before capture
        torch.cuda.synchronize()
        graph_regular = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph_regular):
            res = regular_block(input_x_static, mixer_kwargs=mixer_kwargs)
            if isinstance(res, tuple):
                res = res[0]
            output_regular_static.copy_(res)

        # --- CUDA-graph capture: sparse block ---
        mixer_kwargs["inference_params"].seqlen_offset = seq_len
        temp = sparse_block(input_x_static, mixer_kwargs=mixer_kwargs)  # warm-up + output-shape probe
        if isinstance(temp, tuple):
            temp = temp[0]
        output_sparse_static = torch.empty_like(temp)

        torch.cuda.synchronize()

        mixer_kwargs["inference_params"].seqlen_offset = seq_len
        graph_sparse = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph_sparse):
            res = sparse_block(input_x_static, mixer_kwargs=mixer_kwargs)
            if isinstance(res, tuple):
                res = res[0]
            output_sparse_static.copy_(res)

        # Warm-up replays so the timed loop excludes first-replay overhead.
        for _ in range(5):
            graph_regular.replay()
            graph_sparse.replay()
        torch.cuda.synchronize()

        num_replays = 10

        start = time.time()
        for _ in range(num_replays):
            graph_regular.replay()
        torch.cuda.synchronize()
        regular_graph_time = (time.time() - start) * 1000 / num_replays

        start = time.time()
        for _ in range(num_replays):
            graph_sparse.replay()
        torch.cuda.synchronize()
        sparse_graph_time = (time.time() - start) * 1000 / num_replays
        speedup = regular_graph_time / sparse_graph_time

    return regular_graph_time, sparse_graph_time, speedup
| |
|
if __name__ == "__main__":
    args = arg_parser()
    device = _get_device(0)
    dtype = torch.float16
    gpu_name = get_gpu_name()

    # Sweep configuration.
    batch_sizes = [1, 8, 16, 32]
    seq_lengths = [1024, 2048]

    # Activation fractions -> absolute neuron counts (hidden MLP dim = 4x width).
    index_size_p = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5]
    total_neurons = args.in_features * 4
    index_sizes = [int(total_neurons * p) for p in index_size_p]

    # Round each index size up to a multiple of 128 (exact integer arithmetic,
    # equivalent to the ceil-based rounding but with no float division).
    index_sizes = [((size + 127) // 128) * 128 for size in index_sizes]

    attn_topks = [0.3, 0.4, 0.5]

    total_runs = len(batch_sizes) * len(seq_lengths) * len(index_sizes) * len(attn_topks)
    output_file = f"results/simulations/{gpu_name}_select_block_{args.in_features}_inference_sim.csv"
    # Bug fix: create the output directory up front so open() below cannot fail
    # with FileNotFoundError on a fresh checkout.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    with open(output_file, mode='w', newline='') as csv_file:
        fieldnames = ["in_features", "batch_size", "seq_len", "index_size", "neuron_activation", "attn_topk",
                      "regular_graph_time_ms", "sparse_graph_time_ms", "speedup"]
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        writer.writeheader()

        # One progress bar across the whole sweep (previously `total_runs` was
        # computed but unused and only the outer loop was tracked).
        sweep = itertools.product(batch_sizes, seq_lengths, index_sizes, attn_topks)
        for batch_size, seq_len, index_size, attn_topk in tqdm(sweep, total=total_runs, desc="Simulations"):
            reg_time, spa_time, speedup = run_simulation(args, batch_size, seq_len, index_size, attn_topk, device, dtype)
            writer.writerow({
                "in_features": args.in_features,
                "batch_size": batch_size,
                "seq_len": seq_len,
                "index_size": index_size,
                "neuron_activation": index_size / total_neurons,
                "attn_topk": attn_topk,
                "regular_graph_time_ms": reg_time,
                "sparse_graph_time_ms": spa_time,
                "speedup": speedup
            })
            # Flush after every row so partial results survive a crash mid-sweep.
            csv_file.flush()
    print(f"Simulation complete. Results saved to {output_file}")