import torch

from audiocraft.modules.rope import RotaryEmbedding
from audiocraft.modules.transformer import StreamingTransformer, set_efficient_attention_backend


def test_rope():
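    """rotate_qk should preserve the [B, T, H, C] shape of the query and key tensors."""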
    set_efficient_attention_backend('torch')
    B, T, H, C = 8, 75, 16, 128

    rope = RotaryEmbedding(dim=C)
    xq = torch.rand((B, T, H, C))
    xk = torch.rand((B, T, H, C))
    xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)

    assert list(xq_out.shape) == [B, T, H, C]
    assert list(xk_out.shape) == [B, T, H, C]


def test_rope_io_dtypes():
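    """rotate_qk outputs should keep the dtype of their inputs, whatever precision the rope runs at internally."""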
    set_efficient_attention_backend('torch')
    B, T, H, C = 8, 75, 16, 128

    rope_32 = RotaryEmbedding(dim=C, dtype=torch.float32)
    rope_64 = RotaryEmbedding(dim=C, dtype=torch.float64)

    # bfloat16 inputs with both 32- and 64-bit precision ropes.
    xq_16 = torch.rand((B, T, H, C)).to(torch.bfloat16)
    xk_16 = torch.rand((B, T, H, C)).to(torch.bfloat16)
    xq_out, xk_out = rope_32.rotate_qk(xq_16, xk_16)
    assert xq_out.dtype == torch.bfloat16
    xq_out, xk_out = rope_64.rotate_qk(xq_16, xk_16)
    assert xq_out.dtype == torch.bfloat16

    # float32 inputs with both 32- and 64-bit precision ropes.
    xq_32 = torch.rand((B, T, H, C)).to(torch.float32)
    xk_32 = torch.rand((B, T, H, C)).to(torch.float32)
    xq_out, xk_out = rope_32.rotate_qk(xq_32, xk_32)
    assert xq_out.dtype == torch.float32
    xq_out, xk_out = rope_64.rotate_qk(xq_32, xk_32)
    assert xq_out.dtype == torch.float32


def test_transformer_with_rope():
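    """A forward pass with either 'rope' or 'sin_rope' positional embeddings preserves the input shape."""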
    set_efficient_attention_backend('torch')
    torch.manual_seed(1234)
    for pos in ['rope', 'sin_rope']:
        tr = StreamingTransformer(
            16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
            positional_embedding=pos)
        tr.eval()
        steps = 12
        x = torch.randn(3, steps, 16)

        out = tr(x)
        assert list(out.shape) == list(x.shape)


@torch.no_grad()
def test_rope_streaming():
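    """Feeding the input one step at a time in streaming mode should match a single full forward pass."""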
    set_efficient_attention_backend('torch')
    torch.manual_seed(1234)
    tr = StreamingTransformer(
        16, 4, 2, causal=True, dropout=0.,
        custom=True, positional_embedding='rope')
    tr.eval()
    steps = 12
    x = torch.randn(3, steps, 16)

    ref = tr(x)

    with tr.streaming():
        outs = []
        frame_sizes = [1] * steps

        for frame_size in frame_sizes:
            frame = x[:, :frame_size]
            x = x[:, frame_size:]
            outs.append(tr(frame))

    out = torch.cat(outs, dim=1)
    assert list(out.shape) == [3, steps, 16]
    delta = torch.norm(out - ref) / torch.norm(out)
    assert delta < 1e-6, delta


@torch.no_grad()
def test_rope_streaming_past_context():
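    """Streaming should match the full forward pass whether or not past_context limits the attention window."""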
    set_efficient_attention_backend('torch')
    torch.manual_seed(1234)

    for context in [None, 10]:
        tr = StreamingTransformer(
            16, 4, 1 if context else 2,
            causal=True, past_context=context, custom=True,
            dropout=0., positional_embedding='rope')
        tr.eval()

        steps = 20
        x = torch.randn(3, steps, 16)
        ref = tr(x)

        with tr.streaming():
            outs = []
            frame_sizes = [1] * steps

            for frame_size in frame_sizes:
                frame = x[:, :frame_size]
                x = x[:, frame_size:]
                outs.append(tr(frame))

        out = torch.cat(outs, dim=1)
        assert list(out.shape) == [3, steps, 16]
        delta = torch.norm(out - ref) / torch.norm(out)
        assert delta < 1e-6, delta


def test_rope_memory_efficient():
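    """The memory-efficient attention path should match the custom attention path once weights are shared."""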
    set_efficient_attention_backend('torch')
    torch.manual_seed(1234)
    tr = StreamingTransformer(
        16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
        positional_embedding='rope')
    tr_mem_efficient = StreamingTransformer(
        16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1,
        positional_embedding='rope')
    tr_mem_efficient.load_state_dict(tr.state_dict())
    tr.eval()
    steps = 12
    x = torch.randn(3, steps, 16)

    with torch.no_grad():
        y = tr(x)
        y2 = tr_mem_efficient(x)

    # Compare at float32 tolerance, the rope default precision.
    assert torch.allclose(y, y2, atol=1e-7), (y - y2).norm()


def test_rope_with_xpos():
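    """rotate_qk with xpos enabled should also preserve the [B, T, H, C] input shape."""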
    set_efficient_attention_backend('torch')
    B, T, H, C = 8, 75, 16, 128

    rope = RotaryEmbedding(dim=C, xpos=True)
    xq = torch.rand((B, T, H, C))
    xk = torch.rand((B, T, H, C))
    xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)

    assert list(xq_out.shape) == [B, T, H, C]
    assert list(xk_out.shape) == [B, T, H, C]


def test_positional_scale():
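    """With xpos and scale=0.0, rotate_qk should act as the identity on queries and keys."""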
    set_efficient_attention_backend('torch')
    B, T, H, C = 8, 75, 16, 128

    rope = RotaryEmbedding(dim=C, xpos=True, scale=0.0)
    xq = torch.rand((B, T, H, C))
    xk = torch.rand((B, T, H, C))
    xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)

    assert torch.allclose(xq, xq_out)
    assert torch.allclose(xk, xk_out)