Spaces:
Running on Zero
Running on Zero
| import torch | |
| # from vocos import Vocos | |
| from singer.model import Singer | |
def _select_state_dict(checkpoint):
    """Return the weights stored in ``checkpoint``, preferring the EMA copy.

    Falls back to ``model_state_dict`` and finally to treating the whole
    checkpoint as a bare state dict.
    """
    if "ema_model_state_dict" in checkpoint:
        # EMA wrappers may save non-parameter bookkeeping entries next to the
        # weights (e.g. "initted"/"step" in ema_pytorch); those keys would make
        # the strict load_state_dict below fail.
        # NOTE(review): confirm against the trainer's EMA checkpoint layout.
        return {
            key: value
            for key, value in checkpoint["ema_model_state_dict"].items()
            if key not in ("initted", "step")
        }
    if "model_state_dict" in checkpoint:
        return checkpoint["model_state_dict"]
    return checkpoint


def _strip_wrapper_prefixes(state_dict, prefixes=("module.", "ema_model.")):
    """Return a copy of ``state_dict`` with wrapper prefixes removed from its keys.

    "module." is what (Distributed)DataParallel prepends to parameter names;
    "ema_model." is the submodule name an EMA wrapper may prepend. Prefixes are
    stripped repeatedly, so nested wrappers ("module.ema_model.x") unwrap fully.
    Keys without any of the prefixes are kept unchanged.
    """
    cleaned = {}
    for key, value in state_dict.items():
        stripped = True
        while stripped:
            stripped = False
            for prefix in prefixes:
                if key.startswith(prefix):
                    key = key[len(prefix):]
                    stripped = True
        cleaned[key] = value
    return cleaned


def load_model(model_cls, model_cfg, ckpt_path, vocab_char_map, device="cuda"):
    """Build a ``Singer`` model around a ``model_cls`` backbone and load a checkpoint.

    Args:
        model_cls: Backbone class. It is instantiated with ``model_cfg.model.arch``
            as keyword arguments plus ``text_num_embeds`` and ``mel_dim``.
        model_cfg: Config object exposing ``model.arch`` (a mapping of backbone
            kwargs) and ``model.mel_spec`` (which must have ``n_mel_channels``).
        ckpt_path: Path to a checkpoint readable by ``torch.load``. It may hold
            ``ema_model_state_dict`` (preferred), ``model_state_dict``, or be a
            bare state dict.
        vocab_char_map: Token vocabulary. Its length sets the number of text
            embeddings, and it is passed to ``Singer`` unchanged.
        device: Device the loaded model is moved to.

    Returns:
        The ``Singer`` model on ``device``, switched to eval mode.

    Raises:
        RuntimeError: From ``load_state_dict`` if checkpoint keys or shapes do
            not match the model (loading is strict).
    """
    mel_spec_kwargs = model_cfg.model.mel_spec
    backbone = model_cls(
        **model_cfg.model.arch,
        text_num_embeds=len(vocab_char_map),
        mel_dim=mel_spec_kwargs.n_mel_channels,
    )
    model = Singer(
        transformer=backbone,
        mel_spec_kwargs=mel_spec_kwargs,
        vocab_char_map=vocab_char_map,
    )

    # Load tensors onto the CPU first so checkpoints saved on any device can be
    # restored, then move the assembled model to the target device.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state_dict = _strip_wrapper_prefixes(_select_state_dict(checkpoint))
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    return model
def _load_local_vocos(local_path, device):
    """Build a Vocos vocoder from local files, load its weights, and move it to ``device``.

    ``local_path`` may be either of two forms:
    - a directory holding ``config.yaml`` and ``pytorch_model.bin`` (the layout
      ``Vocos.from_pretrained`` downloads);
    - the path to the config file itself, with ``pytorch_model.bin`` beside it.

    Raises:
        ValueError: If ``local_path`` is empty or None.
        FileNotFoundError: If the config or weights file does not exist.
    """
    import os

    if not local_path:
        raise ValueError("local_path must be provided when is_local is True")
    if os.path.isdir(local_path):
        config_path = os.path.join(local_path, "config.yaml")
        weights_path = os.path.join(local_path, "pytorch_model.bin")
    else:
        config_path = local_path
        weights_path = os.path.join(os.path.dirname(local_path), "pytorch_model.bin")
    for path in (config_path, weights_path):
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Vocos file not found: {path}")

    # Imported lazily so the non-Vocos branches do not require the package.
    from vocos import Vocos

    # from_hparams only builds the architecture from the YAML config; the
    # trained weights must be loaded separately (as from_pretrained does).
    # NOTE(review): Encodec-feature Vocos checkpoints need extra feature
    # extractor keys (see Vocos.from_pretrained); only mel-feature models are
    # handled here.
    vocoder = Vocos.from_hparams(config_path)
    vocoder.load_state_dict(torch.load(weights_path, map_location="cpu"))
    return vocoder.eval().to(device)


def load_vocoder(vocoder_name, is_local, local_path, device="cuda"):
    """Load a vocoder by name.

    Args:
        vocoder_name: Which vocoder to load.
            - "vocos" is supported.
            - "bigvgan" is a placeholder that raises NotImplementedError.
            - Any other name prints a warning and, when ``is_local`` is set, is
              loaded as a local Vocos model.
        is_local: Load from ``local_path`` instead of the Hugging Face Hub.
        local_path: Directory containing ``config.yaml`` and
            ``pytorch_model.bin``, or the path to the config file itself.
        device: Device the vocoder is moved to.

    Returns:
        The loaded vocoder in eval mode on ``device``.

    Raises:
        NotImplementedError: For "bigvgan".
        ValueError: For an unknown name without ``is_local``, or when
            ``is_local`` is set but ``local_path`` is empty.
        FileNotFoundError: If a local config or weights file does not exist.
    """
    if vocoder_name == "vocos":
        if is_local:
            return _load_local_vocos(local_path, device)
        from vocos import Vocos

        # from_pretrained downloads config + weights and returns an eval-mode model.
        return Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)

    if vocoder_name == "bigvgan":
        # TODO: implement BigVGAN loading (may require importing bigvgan).
        raise NotImplementedError("BigVGAN loading not implemented yet")

    print(
        f"Warning: Unknown vocoder {vocoder_name}, trying to load from local path if provided"
    )
    if is_local:
        # Best effort: assume an unknown local vocoder uses the Vocos file layout.
        return _load_local_vocos(local_path, device)
    raise ValueError(f"Unknown vocoder: {vocoder_name}")