import torch
import torch.nn as nn

from omini.rotation.layer import Linear, Rotation
def test_rotation_merge():
    """
    Test that merging the rotation adapter produces the same output as the
    unmerged version, and that unmerging restores the original behavior and
    the original base weights.

    Returns:
        bool: True if merged output, unmerged output, and restored weights
        all match the pre-merge state within tolerance.
    """
    print("=" * 60)
    print("Testing Rotation Layer Merge")
    print("=" * 60)

    # Fixed seed so the test is deterministic.
    torch.manual_seed(42)

    # Layer / adapter configuration.
    in_features = 512
    out_features = 1024
    r = 4
    num_rotations = 4
    T = 1.0
    batch_size = 8
    seq_len = 16

    base_layer = nn.Linear(in_features, out_features, bias=True)

    # Wrap the base layer with the rotation adapter under test.
    rotation_layer = Linear(
        base_layer=base_layer,
        adapter_name="default",
        r=r,
        T=T,
        num_rotations=num_rotations,
    )

    x = torch.randn(batch_size, seq_len, in_features)

    print("\n" + "-" * 60)
    print("Test 1: Computing output BEFORE merge")
    print("-" * 60)
    rotation_layer.eval()
    with torch.no_grad():
        output_before = rotation_layer(x)

    print(f"Output shape: {output_before.shape}")
    print(f"Output mean: {output_before.mean().item():.6f}")
    print(f"Output std: {output_before.std().item():.6f}")
    print(f"Output min: {output_before.min().item():.6f}")
    print(f"Output max: {output_before.max().item():.6f}")

    # Snapshot the base weights so we can verify they are modified by
    # merge() and restored by unmerge().
    original_weight = base_layer.weight.data.clone()

    print("\n" + "-" * 60)
    print("Test 2: Merging adapter")
    print("-" * 60)
    rotation_layer.merge(safe_merge=True, adapter_names=["default"])
    print("✓ Adapter merged successfully")
    print(f"✓ Merged adapters: {rotation_layer.merged_adapters}")

    # Merging should actually change the base weights.
    weight_diff = (base_layer.weight.data - original_weight).abs().max().item()
    print(f"Max weight change: {weight_diff:.6e}")

    print("\n" + "-" * 60)
    print("Test 3: Computing output AFTER merge")
    print("-" * 60)
    with torch.no_grad():
        output_after = rotation_layer(x)

    print(f"Output shape: {output_after.shape}")
    print(f"Output mean: {output_after.mean().item():.6f}")
    print(f"Output std: {output_after.std().item():.6f}")
    print(f"Output min: {output_after.min().item():.6f}")
    print(f"Output max: {output_after.max().item():.6f}")

    print("\n" + "-" * 60)
    print("Test 4: Comparing outputs")
    print("-" * 60)

    abs_diff = (output_after - output_before).abs()
    # Small epsilon avoids division by zero where the output is near zero.
    rel_diff = abs_diff / (output_before.abs() + 1e-8)

    max_abs_diff = abs_diff.max().item()
    mean_abs_diff = abs_diff.mean().item()
    max_rel_diff = rel_diff.max().item()
    mean_rel_diff = rel_diff.mean().item()

    print(f"Max absolute difference: {max_abs_diff:.6e}")
    print(f"Mean absolute difference: {mean_abs_diff:.6e}")
    print(f"Max relative difference: {max_rel_diff:.6e}")
    print(f"Mean relative difference: {mean_rel_diff:.6e}")

    atol = 1e-4
    rtol = 1e-3

    are_close = torch.allclose(output_before, output_after, atol=atol, rtol=rtol)

    if are_close:
        print(f"\n✅ PASS: Outputs are identical (within atol={atol}, rtol={rtol})")
    else:
        print("\n❌ FAIL: Outputs differ significantly")
        print(f"  Expected: atol < {atol}, rtol < {rtol}")
        print(f"  Got: max_abs_diff = {max_abs_diff:.6e}, max_rel_diff = {max_rel_diff:.6e}")

    print("\n" + "-" * 60)
    print("Test 5: Testing unmerge")
    print("-" * 60)
    rotation_layer.unmerge()
    print("✓ Adapter unmerged")
    print(f"✓ Merged adapters: {rotation_layer.merged_adapters}")

    with torch.no_grad():
        output_unmerged = rotation_layer(x)

    unmerge_diff = (output_unmerged - output_before).abs().max().item()
    print(f"Max difference after unmerge: {unmerge_diff:.6e}")

    unmerge_close = torch.allclose(output_before, output_unmerged, atol=atol, rtol=rtol)
    if unmerge_close:
        print("✅ PASS: Unmerge restored original behavior")
    else:
        print("❌ FAIL: Unmerge did not restore original behavior")

    # Unmerge should also restore the original base weights (within float
    # round-off accumulated by merge-then-subtract).
    weight_restored_diff = (base_layer.weight.data - original_weight).abs().max().item()
    print(f"Max weight difference after unmerge: {weight_restored_diff:.6e}")

    weight_restored = torch.allclose(base_layer.weight.data, original_weight, atol=1e-5)
    if weight_restored:
        print("✅ PASS: Original weights restored")
    else:
        print("❌ FAIL: Original weights not fully restored")

    print("\n" + "=" * 60)
    print("Test Summary")
    print("=" * 60)
    return are_close and unmerge_close and weight_restored
def test_multiple_merges():
    """
    Test merging and unmerging the rotation adapter multiple times.

    Runs three merge/unmerge cycles and checks that the layer output matches
    the reference (adapter active, unmerged) output after every merge and
    after every unmerge — i.e. the cycles do not drift the weights.

    Returns:
        bool: True if every cycle passes both checks.
    """
    print("\n" + "=" * 60)
    print("Testing Multiple Merge/Unmerge Cycles")
    print("=" * 60)

    torch.manual_seed(42)

    in_features = 256
    out_features = 512
    r = 4
    num_rotations = 4

    base_layer = nn.Linear(in_features, out_features, bias=True)
    rotation_layer = Linear(
        base_layer=base_layer,
        adapter_name="default",
        r=r,
        T=1.0,
        num_rotations=num_rotations,
    )

    x = torch.randn(4, 8, in_features)
    rotation_layer.eval()

    # Reference output with the adapter applied dynamically (unmerged).
    with torch.no_grad():
        original_output = rotation_layer(x)

    all_passed = True
    for cycle in range(3):
        print(f"\nCycle {cycle + 1}:")

        rotation_layer.merge(safe_merge=True)
        with torch.no_grad():
            merged_output = rotation_layer(x)

        merge_close = torch.allclose(original_output, merged_output, atol=1e-4, rtol=1e-3)
        print(f"  Merge: {'✅ PASS' if merge_close else '❌ FAIL'}")

        rotation_layer.unmerge()
        with torch.no_grad():
            unmerged_output = rotation_layer(x)

        unmerge_close = torch.allclose(original_output, unmerged_output, atol=1e-4, rtol=1e-3)
        print(f"  Unmerge: {'✅ PASS' if unmerge_close else '❌ FAIL'}")

        all_passed = all_passed and merge_close and unmerge_close

    return all_passed
def test_with_different_dtypes():
    """
    Test that merge matches the unmerged forward pass for float32, float16,
    and bfloat16 layers.

    Lower-precision dtypes get looser tolerances since the merge arithmetic
    accumulates more round-off there.

    Returns:
        bool: True if every dtype passes.
    """
    print("\n" + "=" * 60)
    print("Testing Different Data Types")
    print("=" * 60)

    torch.manual_seed(42)

    dtypes = [torch.float32, torch.float16, torch.bfloat16]
    all_passed = True

    for dtype in dtypes:
        print(f"\nTesting with dtype: {dtype}")

        in_features = 256
        out_features = 512
        r = 4
        num_rotations = 4

        base_layer = nn.Linear(in_features, out_features, bias=True)
        base_layer = base_layer.to(dtype)

        rotation_layer = Linear(
            base_layer=base_layer,
            adapter_name="default",
            r=r,
            T=1.0,
            num_rotations=num_rotations,
        )
        rotation_layer = rotation_layer.to(dtype)

        x = torch.randn(4, 8, in_features, dtype=dtype)
        rotation_layer.eval()

        with torch.no_grad():
            output_before = rotation_layer(x)
            rotation_layer.merge(safe_merge=True)
            output_after = rotation_layer(x)

        # Tolerances scaled to the dtype's precision.  The original code had
        # identical float16 and bfloat16 branches, collapsed here.
        if dtype == torch.float32:
            atol, rtol = 1e-5, 1e-4
        else:  # float16 / bfloat16
            atol, rtol = 1e-2, 1e-2

        are_close = torch.allclose(output_before, output_after, atol=atol, rtol=rtol)

        if are_close:
            print("  ✅ PASS")
        else:
            max_diff = (output_after - output_before).abs().max().item()
            print(f"  ❌ FAIL (max diff: {max_diff:.6e})")

        all_passed = all_passed and are_close

    return all_passed
if __name__ == "__main__":
    print("\n" + "=" * 60)
    print("ROTATION LAYER MERGE TEST SUITE")
    print("=" * 60)

    # Run every test and collect its boolean result by name.
    results = {}
    results["basic_merge"] = test_rotation_merge()
    results["multiple_cycles"] = test_multiple_merges()
    results["different_dtypes"] = test_with_different_dtypes()

    print("\n" + "=" * 60)
    print("FINAL SUMMARY")
    print("=" * 60)

    for test_name, passed in results.items():
        status = "✅ PASS" if passed else "❌ FAIL"
        print(f"{test_name}: {status}")

    all_passed = all(results.values())
    print("\n" + "=" * 60)
    if all_passed:
        print("🎉 ALL TESTS PASSED!")
    else:
        print("⚠️ SOME TESTS FAILED")
    print("=" * 60)