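"""Distributed test comparing OpenSora's sequence-parallel attention blocks with
their non-parallel counterparts.

The test spawns 2 processes (see ``spawn(run_dist, nprocs=2)`` below), splits the
input along the sequence dimension, and checks that forward outputs, parameter
gradients, and input gradients match within tolerance. Running it therefore
requires at least 2 visible CUDA devices.
"""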
import colossalai
import torch
import torch.distributed as dist
from colossalai.testing import spawn
from opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward
from opensora.acceleration.parallel_states import set_sequence_parallel_group
from opensora.models.layers.blocks import (
    Attention,
    MultiHeadCrossAttention,
    SeqParallelAttention,
    SeqParallelMultiHeadCrossAttention,
)
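

# NOTE: reference sketch only; the test itself uses the helpers imported from
# opensora.acceleration.communications, whose implementation may differ in detail.
# The sketch assumes split_forward_gather_backward keeps only this rank's chunk of
# the sequence in the forward pass and all-gathers the gradient in the backward
# pass, with grad_scale ("down" divides, "up" multiplies) rescaling the gradient by
# the world size so that the split/gather pair is gradient-preserving end to end.
class _SplitForwardGatherBackwardSketch(torch.autograd.Function):
    """Illustrative only; not used by the test below."""

    @staticmethod
    def forward(ctx, input_, group, dim, grad_scale):
        ctx.group, ctx.dim, ctx.grad_scale = group, dim, grad_scale
        world_size = dist.get_world_size(group)
        rank = dist.get_rank(group)
        # forward: keep only this rank's slice of the given dimension
        return torch.chunk(input_, world_size, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        # assumed grad_scale semantics: "down" divides, "up" multiplies
        if ctx.grad_scale == "down":
            grad_output = grad_output / world_size
        elif ctx.grad_scale == "up":
            grad_output = grad_output * world_size
        # backward: reassemble the full-size gradient via all-gather
        chunks = [torch.empty_like(grad_output) for _ in range(world_size)]
        dist.all_gather(chunks, grad_output.contiguous(), group=ctx.group)
        return torch.cat(chunks, dim=ctx.dim), None, None, None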


def run_attention(rank, world_size):
    # create the sequence-parallel and the reference attention module with the
    # same seed so that their weights are initialized identically
    torch.manual_seed(1024)
    set_sequence_parallel_group(dist.group.WORLD)
    seq_parallel_attention = SeqParallelAttention(
        dim=256,
        num_heads=4,
        qkv_bias=True,
        enable_flashattn=False,
    ).cuda()

    torch.manual_seed(1024)
    attention = Attention(
        dim=256,
        num_heads=4,
        qkv_bias=True,
        enable_flashattn=False,
    ).cuda()

    # create identical inputs for both branches
    torch.manual_seed(1024)
    x = torch.randn(4, 64, 256).cuda()
    seq_x = x.clone().detach()
    x.requires_grad = True
    x.retain_grad()
    seq_x.requires_grad = True
    seq_x.retain_grad()

    # split the input along the sequence dimension for the sequence-parallel branch
    sub_seq_x = split_forward_gather_backward(seq_x, dist.group.WORLD, dim=1, grad_scale="down")

    # run forward and gather the sequence-parallel output back to full length
    out = attention(x)
    sub_seq_out = seq_parallel_attention(sub_seq_x)
    seq_out = gather_forward_split_backward(sub_seq_out, dist.group.WORLD, dim=1, grad_scale="up")
    assert torch.allclose(seq_out, out, atol=1e-7), f"{seq_out}\nvs\n{out}"

    # run backward
    seq_out.mean().backward()
    out.mean().backward()

    # each rank only sees part of the sequence, so its parameter gradients are
    # partial; all-reduce and average them across the sequence-parallel group
    for p in seq_parallel_attention.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, group=dist.group.WORLD)
            p.grad.div_(world_size)

    # check parameter gradients
    for p1, p2 in zip(seq_parallel_attention.parameters(), attention.parameters()):
        assert torch.allclose(p1.grad, p2.grad, atol=1e-7), f"{p1.grad}\nvs\n{p2.grad}"

    # check input gradients
    assert torch.allclose(x.grad, seq_x.grad, atol=1e-7), f"{x.grad}\nvs\n{seq_x.grad}"


def run_cross_attention(rank, world_size):
    # create the sequence-parallel and the reference cross-attention module
    torch.manual_seed(1024)
    set_sequence_parallel_group(dist.group.WORLD)
    seq_parallel_attention = SeqParallelMultiHeadCrossAttention(
        d_model=256,
        num_heads=4,
    ).cuda().to(torch.bfloat16)

    torch.manual_seed(1024)
    attention = MultiHeadCrossAttention(
        d_model=256,
        num_heads=4,
    ).cuda().to(torch.bfloat16)

    # make sure the weights are the same
    for p1, p2 in zip(seq_parallel_attention.parameters(), attention.parameters()):
        p1.data.copy_(p2.data)

    # create inputs: x is the sequence, y is the cross-attention conditioning
    torch.manual_seed(1024)
    x = torch.randn(4, 64, 256).cuda().to(torch.bfloat16)
    y = torch.randn(4, 32, 256).cuda().to(torch.bfloat16)
    # mask = [2, 10, 8, 16]  # optional per-sample valid lengths; masking is disabled in this test
    mask = None

    seq_x = x.clone().detach()
    seq_y = y.clone().detach()

    # set grad
    x.requires_grad = True
    x.retain_grad()
    seq_x.requires_grad = True
    seq_x.retain_grad()
    y.requires_grad = True
    y.retain_grad()
    seq_y.requires_grad = True
    seq_y.retain_grad()

    # split x by sequence; y is replicated on every rank
    sub_seq_x = split_forward_gather_backward(seq_x, dist.group.WORLD, dim=1, grad_scale="down")

    # run forward and gather the sequence-parallel output back to full length
    out = attention(x, y, mask)
    sub_seq_out = seq_parallel_attention(sub_seq_x, seq_y, mask)
    seq_out = gather_forward_split_backward(sub_seq_out, dist.group.WORLD, dim=1, grad_scale="up")
    assert torch.allclose(seq_out, out, rtol=1e-5, atol=1e-6), f"\n{seq_out}\nvs\n{out}"

    # run backward
    seq_out.mean().backward()
    out.mean().backward()

    # all-reduce and average the partial parameter gradients across the
    # sequence-parallel group
    for name, p in seq_parallel_attention.named_parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, group=dist.group.WORLD)
            p.grad.div_(world_size)
        else:
            print(f"grad of {name} is None")

    # check parameter gradients
    for (n1, p1), (n2, p2) in zip(seq_parallel_attention.named_parameters(), attention.named_parameters()):
        assert torch.allclose(p1.grad, p2.grad, rtol=1e-3, atol=1e-4), f"\n{n1}\nvs\n{n2}:\n{p1.grad}\nvs\n{p2.grad}"

    # check input gradients
    assert torch.allclose(x.grad, seq_x.grad, atol=1e-7), f"{x.grad}\nvs\n{seq_x.grad}"
    assert torch.allclose(y.grad, seq_y.grad, atol=1e-7), f"{y.grad}\nvs\n{seq_y.grad}"


def run_dist(rank, world_size, port):
    colossalai.launch({}, rank=rank, world_size=world_size, host="localhost", port=port)
    # run_attention(rank, world_size)
    run_cross_attention(rank, world_size)


def test_seq_parallel_attention():
    spawn(run_dist, nprocs=2)


if __name__ == "__main__":
    test_seq_parallel_attention()