Spaces:
Runtime error
Runtime error
import einops | |
import torch | |
import torch.nn.functional as F | |
import torch.utils.benchmark as benchmark | |
from torch.backends.cuda import SDPBackend | |
from sgm.modules.attention import BasicTransformerBlock, SpatialTransformer | |
def benchmark_attn():
    """Time and profile ``F.scaled_dot_product_attention`` under each SDPA backend.

    Adapted from the PyTorch SDPA tutorial:
    https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html

    Builds random float16 query/key/value tensors of shape
    (batch_size, num_heads, max_sequence_len, embed_dimension) = (32, 32, 1024, 32).
    For the default dispatch and for each of the math, FlashAttention and
    memory-efficient backends, it prints the mean runtime measured by
    ``torch.utils.benchmark`` followed by a profiler table over 25 calls.

    Timing and profiling for a forced backend both run inside that backend's
    ``sdp_kernel`` context. If the FlashAttention or memory-efficient backend
    raises ``RuntimeError`` for these inputs, it is reported as unsupported and
    its profiling run is skipped (profiling would raise the same error).
    """
    from torch.backends.cuda import sdp_kernel
    from torch.profiler import ProfilerActivity, profile, record_function

    device = "cuda" if torch.cuda.is_available() else "cpu"

    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
        # Mean runtime of f(*args, **kwargs) in microseconds.
        t0 = benchmark.Timer(
            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
        )
        return t0.blocked_autorange().mean * 1e6

    # Hyper-parameters of the synthetic input.
    batch_size = 32
    max_sequence_len = 1024
    num_heads = 32
    embed_dimension = 32
    dtype = torch.float16
    shape = (batch_size, num_heads, max_sequence_len, embed_dimension)
    query = torch.rand(shape, device=device, dtype=dtype)
    key = torch.rand(shape, device=device, dtype=dtype)
    value = torch.rand(shape, device=device, dtype=dtype)
    print("q/k/v shape:", query.shape, key.shape, value.shape)

    # Keyword arguments for sdp_kernel that enable exactly one backend each.
    backend_map = {
        SDPBackend.MATH: {
            "enable_math": True,
            "enable_flash": False,
            "enable_mem_efficient": False,
        },
        SDPBackend.FLASH_ATTENTION: {
            "enable_math": False,
            "enable_flash": True,
            "enable_mem_efficient": False,
        },
        SDPBackend.EFFICIENT_ATTENTION: {
            "enable_math": False,
            "enable_flash": False,
            "enable_mem_efficient": True,
        },
    }

    # Only request CUDA profiling when a GPU is actually present.
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)

    def time_and_profile(description, profile_label, unsupported_name=None):
        """Time SDPA, then profile 25 calls, under whichever backend is active.

        ``description`` completes the sentence "The <description> runs in ...".
        When ``unsupported_name`` is given, a ``RuntimeError`` from the timing run
        is reported as "<unsupported_name> is not supported" and profiling is
        skipped. Otherwise the error propagates.
        """
        try:
            elapsed_us = benchmark_torch_function_in_microseconds(
                F.scaled_dot_product_attention, query, key, value
            )
        except RuntimeError:
            if unsupported_name is None:
                raise
            print(f"{unsupported_name} is not supported. See warnings for reasons.")
            return
        print(f"The {description} runs in {elapsed_us:.3f} microseconds")
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function(profile_label):
                for _ in range(25):
                    F.scaled_dot_product_attention(query, key, value)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    time_and_profile("default implementation", "Default detailed stats")
    # Each timing call must run inside its sdp_kernel context; otherwise it
    # silently measures the default backend instead.
    with sdp_kernel(**backend_map[SDPBackend.MATH]):
        time_and_profile("math implementation", "Math implementation stats")
    with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
        time_and_profile(
            "flash attention implementation", "FlashAttention stats", "FlashAttention"
        )
    with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
        time_and_profile(
            "memory efficient implementation",
            "EfficientAttention stats",
            "EfficientAttention",
        )
def run_model(model, x, context):
    """Call ``model`` with ``x`` and ``context`` as positional arguments and return its output."""
    output = model(x, context)
    return output
def benchmark_transformer_blocks():
    """Compare a ``SpatialTransformer`` using native SDPA with one using xformers.

    Builds two ``SpatialTransformer`` stacks with identical hyper-parameters:
    one with ``attn_type="softmax"`` and ``sdp_backend=SDPBackend.FLASH_ATTENTION``,
    and one with ``attn_type="softmax-xformers"``. Under autocast, it prints for
    each model:

    - the mean forward time;
    - a profiler table over 25 forward calls;
    - on CUDA only, the peak allocated memory since the reset.
    """
    from torch.profiler import ProfilerActivity, profile, record_function

    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_cuda = device == "cuda"

    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
        # Mean runtime of f(*args, **kwargs) in microseconds.
        t0 = benchmark.Timer(
            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
        )
        return t0.blocked_autorange().mean * 1e6

    checkpoint = True
    use_compile = False  # renamed from `compile` to avoid shadowing the builtin
    batch_size = 32
    h, w = 64, 64
    context_len = 77
    embed_dimension = 1024
    context_dim = 1024
    d_head = 64
    transformer_depth = 4
    n_heads = embed_dimension // d_head
    dtype = torch.float16

    model_native = SpatialTransformer(
        embed_dimension,
        n_heads,
        d_head,
        context_dim=context_dim,
        use_linear=True,
        use_checkpoint=checkpoint,
        attn_type="softmax",
        depth=transformer_depth,
        sdp_backend=SDPBackend.FLASH_ATTENTION,
    ).to(device)
    model_efficient_attn = SpatialTransformer(
        embed_dimension,
        n_heads,
        d_head,
        context_dim=context_dim,
        use_linear=True,
        depth=transformer_depth,
        use_checkpoint=checkpoint,
        attn_type="softmax-xformers",
    ).to(device)
    if not checkpoint and use_compile:
        print("compiling models")
        model_native = torch.compile(model_native)
        model_efficient_attn = torch.compile(model_efficient_attn)

    x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
    c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)

    # Only request CUDA profiling when a GPU is actually present.
    activities = [ProfilerActivity.CPU]
    if use_cuda:
        activities.append(ProfilerActivity.CUDA)

    def profile_model(title, model, profile_label, memory_label):
        """Print a banner, profile 25 forward passes and (on CUDA) report peak memory."""
        print(75 * "+")
        print(title)
        print(75 * "+")
        if use_cuda:
            # torch.cuda memory statistics are unavailable without a GPU.
            torch.cuda.reset_peak_memory_stats()
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function(profile_label):
                for _ in range(25):
                    model(x, c)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        if use_cuda:
            print(torch.cuda.max_memory_allocated() * 1e-9, memory_label)

    with torch.autocast(device):
        print(
            f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
        )
        print(
            f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
        )
        profile_model(
            "NATIVE", model_native, "NativeAttention stats", "GB used by native block"
        )
        profile_model(
            "Xformers",
            model_efficient_attn,
            "xformers stats",
            "GB used by xformers block",
        )
def test01():
    """Check that a 1x1 ``Conv2d`` and a ``Linear`` layer with shared weights agree.

    Copies the conv's weight and bias into a ``Linear(3, 32)``. It then applies
    the conv to a random (11, 3, 64, 64) batch, and the linear layer to the same
    batch flattened to (batch, h*w, channels) tokens. It prints:

    - parameter counts;
    - weight and output shapes;
    - output means;
    - the maximum absolute elementwise difference between the two outputs.

    Runs on CUDA when available, otherwise on CPU.
    """
    from sgm.util import count_params

    device = "cuda" if torch.cuda.is_available() else "cpu"
    conv = torch.nn.Conv2d(3, 32, kernel_size=1).to(device)
    print(count_params(conv))
    linear = torch.nn.Linear(3, 32).to(device)
    print(count_params(linear))
    print(conv.weight.shape)
    # Use the same initialization: drop the two trailing 1x1 kernel dims of the
    # conv weight so it matches the Linear weight layout.
    linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
    linear.bias = torch.nn.Parameter(conv.bias)
    print(linear.weight.shape)
    x = torch.randn(11, 3, 64, 64).to(device)
    xr = einops.rearrange(x, "b c h w -> b (h w) c").contiguous()
    print(xr.shape)
    out_linear = linear(xr)
    print(out_linear.mean(), out_linear.shape)
    out_conv = conv(x)
    print(out_conv.mean(), out_conv.shape)
    # Equal means do not imply equal outputs: bring the conv output into the
    # same token layout and compare elementwise.
    out_conv_tokens = einops.rearrange(out_conv, "b c h w -> b (h w) c")
    max_abs_diff = (out_linear - out_conv_tokens).abs().max().item()
    print(f"max abs difference (linear vs conv1x1): {max_abs_diff:.3e}")
    print("done with test01.\n")
def test02():
    """Time one forward pass of ``BasicTransformerBlock`` in two attention modes.

    Enables TF32 matmuls/convolutions and cuDNN autotuning. It then runs a single
    forward pass of a 16-head ``BasicTransformerBlock`` (``dim=1024``,
    ``d_head=64``) on a random (16, 4096, 1024) batch, first with
    ``attn_mode="softmax"`` and then with ``attn_mode="flash-cosine"``. For each
    run it prints the output shape and the elapsed time in seconds.

    A ``RuntimeError`` raised during the softmax forward pass (presumably
    out-of-memory) is printed instead of propagated. Errors in the flash-cosine
    run propagate. Requires CUDA.
    """
    import time

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    print("testing cosine flash attention...")
    DIM = 1024
    SEQLEN = 4096
    BS = 16

    def build_block(attn_mode):
        # Same architecture for both runs; only the attention mode differs.
        return BasicTransformerBlock(
            dim=DIM,
            n_heads=16,
            d_head=64,
            dropout=0.0,
            context_dim=None,
            attn_mode=attn_mode,
        ).cuda()

    def time_forward(model):
        # Input, output and autograd graph are locals, so they are released when
        # this returns rather than lingering into the next run.
        x = torch.randn(BS, SEQLEN, DIM).cuda()
        # CUDA kernels run asynchronously; synchronize on both sides so the
        # interval covers the actual execution, not just the kernel launches.
        torch.cuda.synchronize()
        tic = time.perf_counter()
        y = model(x)
        torch.cuda.synchronize()
        toc = time.perf_counter()
        print(y.shape, toc - tic)

    print(" softmax (vanilla) first...")
    model = build_block("softmax")
    try:
        time_forward(model)
    except RuntimeError as e:
        # likely oom
        print(str(e))
    # Free the first block before building the second so it does not inflate
    # the second run's memory footprint.
    del model

    print("\n now flash-cosine...")
    time_forward(build_block("flash-cosine"))
    print("done with test02.\n")
if __name__ == "__main__":
    # Only the SpatialTransformer benchmark runs by default. Uncomment any of
    # these to run the other experiments:
    #   test01()            # 1x1 conv vs. linear projection
    #   test02()            # softmax vs. flash-cosine BasicTransformerBlock
    #   benchmark_attn()    # raw SDPA backend comparison
    benchmark_transformer_blocks()
    print("done.")