import einops
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
from torch.backends.cuda import SDPBackend

from sgm.modules.attention import BasicTransformerBlock, SpatialTransformer


def benchmark_attn():
    # Let's define a helpful benchmarking function:
    # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
        t0 = benchmark.Timer(
            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
        )
        return t0.blocked_autorange().mean * 1e6

    # Let's define the hyper-parameters of our input
    batch_size = 32
    max_sequence_len = 1024
    num_heads = 32
    embed_dimension = 32

    dtype = torch.float16

    query = torch.rand(
        batch_size,
        num_heads,
        max_sequence_len,
        embed_dimension,
        device=device,
        dtype=dtype,
    )
    key = torch.rand(
        batch_size,
        num_heads,
        max_sequence_len,
        embed_dimension,
        device=device,
        dtype=dtype,
    )
    value = torch.rand(
        batch_size,
        num_heads,
        max_sequence_len,
        embed_dimension,
        device=device,
        dtype=dtype,
    )

    print("q/k/v shape:", query.shape, key.shape, value.shape)

    # Let's explore the speed of each of the 3 implementations
    from torch.backends.cuda import sdp_kernel

    # Helpful argument mapper
    backend_map = {
        SDPBackend.MATH: {
            "enable_math": True,
            "enable_flash": False,
            "enable_mem_efficient": False,
        },
        SDPBackend.FLASH_ATTENTION: {
            "enable_math": False,
            "enable_flash": True,
            "enable_mem_efficient": False,
        },
        SDPBackend.EFFICIENT_ATTENTION: {
            "enable_math": False,
            "enable_flash": False,
            "enable_mem_efficient": True,
        },
    }

    from torch.profiler import ProfilerActivity, profile, record_function

    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

    print(
        f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
    )
    with profile(
        activities=activities, record_shapes=False, profile_memory=True
    ) as prof:
        with record_function("Default detailed stats"):
            for _ in range(25):
                o = F.scaled_dot_product_attention(query, key, value)

    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # Benchmark inside the sdp_kernel context, otherwise the default backend
    # (not the math backend) is what actually gets measured.
    with sdp_kernel(**backend_map[SDPBackend.MATH]):
        print(
            f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
        )
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function("Math implementation stats"):
                for _ in range(25):
                    o = F.scaled_dot_product_attention(query, key, value)

        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
        # Guard both the benchmark and the profiling run: if FlashAttention is
        # unavailable, every call in this context raises a RuntimeError.
        try:
            print(
                f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
            )
            with profile(
                activities=activities, record_shapes=False, profile_memory=True
            ) as prof:
                with record_function("FlashAttention stats"):
                    for _ in range(25):
                        o = F.scaled_dot_product_attention(query, key, value)

            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        except RuntimeError:
            print("FlashAttention is not supported. See warnings for reasons.")

    with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
        # Same guard as above: skip profiling when the backend is unavailable.
        try:
            print(
                f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
            )
            with profile(
                activities=activities, record_shapes=False, profile_memory=True
            ) as prof:
                with record_function("EfficientAttention stats"):
                    for _ in range(25):
                        o = F.scaled_dot_product_attention(query, key, value)

            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        except RuntimeError:
            print("EfficientAttention is not supported. See warnings for reasons.")


def run_model(model, x, context):
    return model(x, context)


def benchmark_transformer_blocks():
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
        t0 = benchmark.Timer(
            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
        )
        return t0.blocked_autorange().mean * 1e6

    checkpoint = True
    use_compile = False  # renamed from `compile` to avoid shadowing the builtin

    batch_size = 32
    h, w = 64, 64
    context_len = 77
    embed_dimension = 1024
    context_dim = 1024
    d_head = 64
    transformer_depth = 4
    n_heads = embed_dimension // d_head

    dtype = torch.float16

    model_native = SpatialTransformer(
        embed_dimension,
        n_heads,
        d_head,
        context_dim=context_dim,
        use_linear=True,
        use_checkpoint=checkpoint,
        attn_type="softmax",
        depth=transformer_depth,
        sdp_backend=SDPBackend.FLASH_ATTENTION,
    ).to(device)
    model_efficient_attn = SpatialTransformer(
        embed_dimension,
        n_heads,
        d_head,
        context_dim=context_dim,
        use_linear=True,
        depth=transformer_depth,
        use_checkpoint=checkpoint,
        attn_type="softmax-xformers",
    ).to(device)
    if not checkpoint and use_compile:
        print("compiling models")
        model_native = torch.compile(model_native)
        model_efficient_attn = torch.compile(model_efficient_attn)

    x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
    c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)

    from torch.profiler import ProfilerActivity, profile, record_function

    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

    with torch.autocast("cuda"):
        print(
            f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
        )
        print(
            f"The efficient-attn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
        )

        print(75 * "+")
        print("NATIVE")
        print(75 * "+")
        torch.cuda.reset_peak_memory_stats()
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function("NativeAttention stats"):
                for _ in range(25):
                    model_native(x, c)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")

        print(75 * "+")
        print("Xformers")
        print(75 * "+")
        torch.cuda.reset_peak_memory_stats()
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function("xformers stats"):
                for _ in range(25):
                    model_efficient_attn(x, c)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
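
# A hedged sketch, not part of the original script: newer PyTorch releases
# (2.3+) deprecate torch.backends.cuda.sdp_kernel in favor of
# torch.nn.attention.sdpa_kernel, which takes the desired backend(s) directly
# instead of a dict of enable flags. The exact version cutoff is an
# assumption; check the docs for your installed PyTorch. The function name
# here is hypothetical, for illustration only.
def force_flash_attention_example(query, key, value):
    # Roughly equivalent to sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION])
    # in benchmark_attn above.
    from torch.nn.attention import SDPBackend, sdpa_kernel

    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        return F.scaled_dot_product_attention(query, key, value)
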
block") def test01(): # conv1x1 vs linear from sgm.util import count_params conv = torch.nn.Conv2d(3, 32, kernel_size=1).cuda() print(count_params(conv)) linear = torch.nn.Linear(3, 32).cuda() print(count_params(linear)) print(conv.weight.shape) # use same initialization linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1)) linear.bias = torch.nn.Parameter(conv.bias) print(linear.weight.shape) x = torch.randn(11, 3, 64, 64).cuda() xr = einops.rearrange(x, "b c h w -> b (h w) c").contiguous() print(xr.shape) out_linear = linear(xr) print(out_linear.mean(), out_linear.shape) out_conv = conv(x) print(out_conv.mean(), out_conv.shape) print("done with test01.\n") def test02(): # try cosine flash attention import time torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.benchmark = True print("testing cosine flash attention...") DIM = 1024 SEQLEN = 4096 BS = 16 print(" softmax (vanilla) first...") model = BasicTransformerBlock( dim=DIM, n_heads=16, d_head=64, dropout=0.0, context_dim=None, attn_mode="softmax", ).cuda() try: x = torch.randn(BS, SEQLEN, DIM).cuda() tic = time.time() y = model(x) toc = time.time() print(y.shape, toc - tic) except RuntimeError as e: # likely oom print(str(e)) print("\n now flash-cosine...") model = BasicTransformerBlock( dim=DIM, n_heads=16, d_head=64, dropout=0.0, context_dim=None, attn_mode="flash-cosine", ).cuda() x = torch.randn(BS, SEQLEN, DIM).cuda() tic = time.time() y = model(x) toc = time.time() print(y.shape, toc - tic) print("done with test02.\n") if __name__ == "__main__": # test01() # test02() # test03() # benchmark_attn() benchmark_transformer_blocks() print("done.")