| import torch |
| import os |
| from safetensors.torch import load_file |
|
|
def count_layers(state_dict, exclude_prefixes=None):
    """
    Count unique layers in a state dict.

    A "layer" is the module prefix of a parameter key — everything before
    the final dot; a key with no dot counts as its own layer.

    Returns a ``(total_layers, custom_layers)`` tuple, where the custom
    count skips keys matching any prefix in *exclude_prefixes* (a key is
    excluded when it equals a prefix or starts with ``prefix + '.'``).
    """
    if exclude_prefixes is None:
        exclude_prefixes = []

    all_modules = set()
    own_modules = set()

    for name in state_dict:
        prefix, dot, _ = name.rpartition('.')
        # Keys without a dot have an empty prefix — fall back to the key itself.
        module = prefix if dot else name
        all_modules.add(module)

        pretrained = any(
            name == pre or name.startswith(pre + '.')
            for pre in exclude_prefixes
        )
        if not pretrained:
            own_modules.add(module)

    return len(all_modules), len(own_modules)
|
|
def count_parameters(state_dict, exclude_prefixes=None):
    """
    Sum element counts of all tensors in a state dict.

    Returns a ``(total_params, custom_params)`` tuple; the custom total
    skips keys matching any prefix in *exclude_prefixes* (a key is
    excluded when it equals a prefix or starts with ``prefix + '.'``).
    """
    if exclude_prefixes is None:
        exclude_prefixes = []

    total = 0
    custom = 0

    for name, tensor in state_dict.items():
        count = tensor.numel()
        total += count

        keep = not any(
            name == pre or name.startswith(pre + '.')
            for pre in exclude_prefixes
        )
        if keep:
            custom += count

    return total, custom
|
|
| |
# Checkpoints to inspect: "file" is the checkpoint path (.pth/.pt are loaded
# with torch.load, .safetensors with safetensors); "exclude" lists key
# prefixes to omit from the "custom" layer/parameter counts — presumably the
# frozen/pretrained submodules of each model (TODO confirm against training code).
models_config = {
    "Bass": {
        "file": "bass_sota.pth",
        "exclude": ["audio_encoder"]
    },
    "Drums": {
        "file": "drums.safetensors",
        "exclude": ["wavlm"]
    },
    "Vocals": {
        "file": "vocals.pt",
        "exclude": []  # no excluded prefixes: all params count as custom
    }
}
|
|
# Print one report row per configured model: layer counts, custom parameter
# count, and the checkpoint file the numbers came from.
print(f"{'='*100}")
print(f"{'MODEL':<10} | {'TOTAL LAYERS':<15} | {'CUSTOM LAYERS':<15} | {'CUSTOM PARAMS':<15} | {'FILE'}")
print(f"{'='*100}")


for model_name, cfg in models_config.items():
    filename = cfg["file"]
    exclude = cfg["exclude"]

    if not os.path.exists(filename):
        # Missing checkpoint: report it and keep the table going.
        # BUG FIX: the FILE column previously printed the literal "(unknown)"
        # instead of the actual path; it now shows `filename`.
        print(f"{model_name:<10} | {'MISSING':<15} | {'N/A':<15} | {'N/A':<15} | {filename}")
        continue


    try:
        if filename.endswith(".safetensors"):
            data = load_file(filename, device='cpu')
        else:
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # only load checkpoints from trusted sources.
            data = torch.load(filename, map_location='cpu', weights_only=False)

        # Unwrap common checkpoint layouts so `data` is the bare state dict.
        # Order matters and matches the original precedence:
        # "model" > "model_state_dict" > "state_dict".
        if isinstance(data, dict):
            for wrapper_key in ("model", "model_state_dict", "state_dict"):
                if wrapper_key in data:
                    data = data[wrapper_key]
                    break

        total_l, custom_l = count_layers(data, exclude)
        total_p, custom_p = count_parameters(data, exclude)

        print(f"{model_name:<10} | {total_l:<15} | {custom_l:<15} | {custom_p:<15,} | {filename}")

    except Exception as e:
        # Top-level boundary: one bad checkpoint must not abort the report.
        # Error text is truncated to keep the table readable.
        print(f"{model_name:<10} | {'ERROR':<15} | {'N/A':<15} | {'N/A':<15} | {filename} - {str(e)[:30]}...")


print(f"{'='*100}")
|
|