| | |
"""
Convert an original WavTokenizer checkpoint to HuggingFace format.

Usage:
    python convert_wavtokenizer.py \
        --config_path configs/wavtokenizer_smalldata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml \
        --checkpoint_path checkpoints/wavtokenizer_small_320_24k_4096.ckpt \
        --output_dir ./wavtokenizer_hf_converted

This will create a HuggingFace-compatible model directory that can be loaded with:
    model = AutoModel.from_pretrained("./wavtokenizer_hf_converted", trust_remote_code=True)
"""
| |
|
| | import argparse |
| | import json |
| | import os |
| | import shutil |
| | from pathlib import Path |
| |
|
| | import torch |
| | import yaml |
| |
|
| |
|
def _init_args(parent, name):
    """Return ``parent[name]['init_args']``, treating missing or null levels as empty dicts."""
    section = (parent or {}).get(name) or {}
    return section.get('init_args') or {}


def _copy_support_file(src_dir, filename, output_dir, required):
    """Copy ``filename`` from ``src_dir`` to ``output_dir``; warn when a required file is absent."""
    src = src_dir / filename
    if src.exists():
        shutil.copy(src, output_dir)
        print(f"Copied: {filename}")
    elif required:
        # config.json's auto_map points at this module, so a silent skip would
        # leave an output directory that looks complete but is missing code.
        print(
            f"WARNING: {filename} not found in {src_dir}; config.json's auto_map "
            f"references it, so copy it into {output_dir} before loading the model"
        )


def convert_wavtokenizer(config_path: str, checkpoint_path: str, output_dir: str) -> None:
    """Convert WavTokenizer checkpoint to HuggingFace format.

    Reads the YAML training config and the checkpoint, then writes into
    ``output_dir`` (created if needed):

    * ``config.json`` - HuggingFace-style config assembled from the
      ``init_args`` of the YAML ``model`` section (head, backbone, quantizer,
      feature_extractor). Keys absent from the YAML fall back to the defaults
      written below; several fields (encoder dims/rates, attention settings,
      backbone kernel size, ...) are fixed constants that are not read from
      the YAML at all.
    * ``pytorch_model.bin`` - the checkpoint's state dict with a leading
      ``model.`` stripped from every key that has it.
    * ``configuration_wavtokenizer.py``, ``modeling_wavtokenizer.py`` and
      ``README.md`` - copied from the directory containing this script when
      they exist there.

    Args:
        config_path: Path to the WavTokenizer YAML config file.
        checkpoint_path: Path to the checkpoint file readable by ``torch.load``.
        output_dir: Directory that receives the converted model files.
    """
    print(f"Loading config from: {config_path}")
    print(f"Loading checkpoint from: {checkpoint_path}")

    # Explicit encoding so the result does not depend on the platform locale.
    with open(config_path, 'r', encoding='utf-8') as f:
        # safe_load yields None for an empty document; normalise to a dict.
        yaml_cfg = yaml.safe_load(f) or {}

    model_args = _init_args(yaml_cfg, 'model')

    head_args = _init_args(model_args, 'head')
    backbone_args = _init_args(model_args, 'backbone')
    quantizer_args = _init_args(model_args, 'quantizer')
    feature_extractor_args = _init_args(model_args, 'feature_extractor')

    # Values that feed more than one config field are read once.
    n_fft = head_args.get('n_fft', 1280)
    backbone_dim = backbone_args.get('dim', 512)

    hf_config = {
        "_name_or_path": "WavTokenizerSmall",
        "architectures": ["WavTokenizer"],
        "auto_map": {
            "AutoConfig": "configuration_wavtokenizer.WavTokenizerConfig",
            "AutoModel": "modeling_wavtokenizer.WavTokenizer",
        },
        "model_type": "wavtokenizer",

        # Audio / STFT settings
        "sample_rate": feature_extractor_args.get('sample_rate', 24000),
        "n_fft": n_fft,
        "hop_length": head_args.get('hop_length', 320),
        "n_mels": feature_extractor_args.get('n_mels', 128),
        "padding": head_args.get('padding', 'center'),

        # Encoder settings (encoder_dim / encoder_rates are hard-coded)
        "feature_dim": backbone_dim,
        "encoder_dim": 64,
        "encoder_rates": [8, 5, 4, 2],
        "latent_dim": backbone_args.get('input_channels', 512),

        # Quantizer settings
        "codebook_size": quantizer_args.get('codebook_size', 4096),
        "codebook_dim": quantizer_args.get('codebook_dim', 8),
        "num_quantizers": quantizer_args.get('num_quantizers', 1),

        # Backbone settings (kernel size / layer-scale init are hard-coded)
        "backbone_type": "vocos",
        "backbone_dim": backbone_dim,
        "backbone_num_blocks": backbone_args.get('num_layers', 8),
        "backbone_intermediate_dim": backbone_args.get('intermediate_dim', 1536),
        "backbone_kernel_size": 7,
        "backbone_layer_scale_init_value": 1e-6,

        # Head settings
        "head_type": "istft",
        "head_dim": n_fft // 2 + 1,

        # Attention settings (hard-coded apart from the dimension)
        "use_attention": True,
        "attention_dim": backbone_dim,
        "attention_heads": 8,
        "attention_layers": 1,

        "torch_dtype": "float32",
        "transformers_version": "4.40.0",
    }

    os.makedirs(output_dir, exist_ok=True)

    config_out_path = os.path.join(output_dir, "config.json")
    with open(config_out_path, 'w', encoding='utf-8') as f:
        json.dump(hf_config, f, indent=2)
    print(f"Saved config to: {config_out_path}")

    print("Loading checkpoint...")
    # NOTE(review): torch.load unpickles the file, so only convert checkpoints
    # from a trusted source. Recent PyTorch releases changed the default of
    # `weights_only`; if loading fails on a checkpoint that carries non-tensor
    # objects, confirm the behaviour of the installed PyTorch version.
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    # Use the nested 'state_dict' entry when present; otherwise the loaded
    # object itself is treated as the state dict.
    state_dict = ckpt.get('state_dict', ckpt)

    # Strip a leading 'model.' from each key that has it.
    prefix = 'model.'
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    model_out_path = os.path.join(output_dir, "pytorch_model.bin")
    torch.save(new_state_dict, model_out_path)
    print(f"Saved model weights to: {model_out_path}")

    # Ship the custom code referenced by auto_map alongside the weights.
    script_dir = Path(__file__).parent
    _copy_support_file(script_dir, "configuration_wavtokenizer.py", output_dir, required=True)
    _copy_support_file(script_dir, "modeling_wavtokenizer.py", output_dir, required=True)
    _copy_support_file(script_dir, "README.md", output_dir, required=False)

    print(f"\nConversion complete! Model saved to: {output_dir}")
    print("\nTo load the model:")
    print(f' model = AutoModel.from_pretrained("{output_dir}", trust_remote_code=True)')
| |
|
| |
|
def main():
    """Parse the command-line options and run the checkpoint conversion.

    ``--config_path`` and ``--checkpoint_path`` are mandatory;
    ``--output_dir`` defaults to ``./wavtokenizer_hf_converted``.
    """
    cli = argparse.ArgumentParser(
        description="Convert WavTokenizer checkpoint to HuggingFace format"
    )

    # The two input paths share the same option shape, so register them from a table.
    required_inputs = (
        ("--config_path", "Path to WavTokenizer YAML config file"),
        ("--checkpoint_path", "Path to WavTokenizer .ckpt checkpoint file"),
    )
    for flag, help_text in required_inputs:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    cli.add_argument(
        "--output_dir",
        type=str,
        default="./wavtokenizer_hf_converted",
        help="Output directory for HuggingFace model",
    )

    opts = cli.parse_args()
    convert_wavtokenizer(opts.config_path, opts.checkpoint_path, opts.output_dir)
| |
|
| |
|
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()