| |
| |
| |
| |
| """Verify converted safetensors match original .pt checkpoints.""" |
|
|
| import json |
| import sys |
| from pathlib import Path |
|
|
# Prepend the relative directory 'original' to the import path, presumably so
# that the `models` package imported further down resolves from there.
# NOTE(review): the path is relative to the current working directory, so the
# script must be launched from the directory that contains 'original'.
sys.path.insert(0, 'original')
|
|
| import torch |
| from safetensors.torch import load_file |
|
|
| from models.anomaly_transformer import get_anomaly_transformer |
|
|
|
|
def _compare_state_dicts(expected: dict, actual: dict) -> bool:
    """Print every difference between two state dicts; return True if identical.

    Args:
        expected: Reference mapping of parameter/buffer name to tensor.
        actual: Mapping to check against ``expected``.

    Returns:
        ``True`` when both mappings have the same keys and every tensor has the
        same shape, dtype and exact values; ``False`` otherwise.
    """
    ok = True
    for key, ref in expected.items():
        got = actual.get(key)
        if got is None:
            print(f'  MISSING: {key}')
            ok = False
            continue
        # Check shape/dtype first: subtracting tensors of different shapes
        # would either raise or silently broadcast into a meaningless diff.
        if ref.shape != got.shape or ref.dtype != got.dtype:
            print(
                f'  MISMATCH: {key} (shape/dtype {tuple(ref.shape)} {ref.dtype}'
                f' vs {tuple(got.shape)} {got.dtype})'
            )
            ok = False
            continue
        if not torch.equal(ref, got):
            # Upcast before subtracting: bool tensors do not support '-', and
            # low-precision dtypes would round the reported difference.
            diff = (ref.double() - got.double()).abs().max().item()
            print(f'  MISMATCH: {key} (max diff={diff})')
            ok = False

    extra = set(actual) - set(expected)
    if extra:
        # Sorted so the report is deterministic across runs.
        print(f'  EXTRA keys: {sorted(extra)}')
        ok = False
    return ok


def verify(dataset: str) -> bool:
    """Verify a single converted checkpoint against its original ``.pt`` file.

    Loads the original pickled object from ``{dataset}_parameters.pt`` and takes
    its ``state_dict()``. It then rebuilds the model from
    ``{dataset}/config.json``, loads ``{dataset}/model.safetensors`` into it, and
    compares the resulting state dict with the original one (key names, shapes,
    dtypes and exact values).

    Problems are printed rather than raised, so one broken dataset does not
    stop the remaining ones from being checked.

    Args:
        dataset: Dataset name. It is used to locate the three input files
            relative to the current working directory.

    Returns:
        ``True`` if the converted checkpoint matches exactly, ``False``
        otherwise (including when an input file is missing or the safetensors
        weights cannot be loaded into the rebuilt model).
    """
    pt_path = Path(f'{dataset}_parameters.pt')
    config_path = Path(dataset) / 'config.json'
    safetensors_path = Path(dataset) / 'model.safetensors'

    missing_files = [
        p for p in (pt_path, config_path, safetensors_path) if not p.is_file()
    ]
    if missing_files:
        for p in missing_files:
            print(f'  MISSING FILE: {p}')
        print(f'{dataset}: FAIL')
        return False

    # weights_only=False unpickles arbitrary Python objects (the checkpoint
    # exposes .state_dict(), presumably a full nn.Module). Only run this on
    # checkpoints from a trusted source.
    original = torch.load(pt_path, map_location='cpu', weights_only=False)
    original_sd = original.state_dict()

    with open(config_path, encoding='utf-8') as f:
        config = json.load(f)

    model = get_anomaly_transformer(
        input_d_data=config['input_d_data'],
        output_d_data=config['output_d_data'],
        patch_size=config['patch_size'],
        d_embed=config['d_embed'],
        hidden_dim_rate=config['hidden_dim_rate'],
        max_seq_len=config['max_seq_len'],
        positional_encoding=config['positional_encoding'],
        relative_position_embedding=config['relative_position_embedding'],
        transformer_n_layer=config['transformer_n_layer'],
        transformer_n_head=config['transformer_n_head'],
        dropout=config['dropout'],
    )

    saved_sd = load_file(str(safetensors_path))
    try:
        # strict=True (the default) raises RuntimeError on missing/unexpected
        # keys or shape mismatches; report that as a failure instead of
        # aborting the whole run.
        model.load_state_dict(saved_sd)
    except RuntimeError as err:
        print(f'  LOAD ERROR: {err}')
        print(f'{dataset}: FAIL')
        return False
    loaded_sd = model.state_dict()

    ok = _compare_state_dicts(original_sd, loaded_sd)

    status = 'OK' if ok else 'FAIL'
    print(f'{dataset}: {status}')
    return ok
|
|
|
|
def main() -> None:
    """Run ``verify`` on each known dataset and exit with status 1 on any failure.

    Every dataset is checked, in order, before the summary line is printed.
    """
    outcomes = []
    for name in ('MSL', 'SMAP', 'SWaT', 'WADI'):
        outcomes.append(verify(name))

    passed = all(outcomes)
    print(f'\nAll passed: {passed}')
    if not passed:
        sys.exit(1)


if __name__ == '__main__':
    main()
|
|