import json
import os

import torch
from omegaconf import OmegaConf

from ttts.classifier.model import AudioMiniEncoderWithClassifierHead
from ttts.diffusion.aa_model import AA_diffusion
from ttts.diffusion.model import DiffusionTts
from ttts.gpt.model import UnifiedVoice
from ttts.vqvae.xtts_dvae import DiscreteVAE


def load_model(model_name, model_path, config_path, device):
    """Load a ttts model checkpoint onto `device` and return it in eval mode."""
    config_path = os.path.expanduser(config_path)
    model_path = os.path.expanduser(model_path)

    # Configs may be plain JSON or OmegaConf YAML.
    if config_path.endswith('.json'):
        with open(config_path) as f:
            config = json.load(f)
    else:
        config = OmegaConf.load(config_path)

    if model_name == 'vqvae':
        model = DiscreteVAE(**config['vqvae'])
    elif model_name == 'gpt':
        model = UnifiedVoice(**config['gpt'])
    elif model_name == 'diffusion':
        # model = DiffusionTts(**config['diffusion'])
        model = AA_diffusion(config)
    elif model_name == 'classifier':
        model = AudioMiniEncoderWithClassifierHead(**config['classifier'])
    # elif model_name == 'clvp':
    else:
        raise ValueError(f"Unknown model name: {model_name}")

    # Every checkpoint stores its weights under the 'model' key.
    sd = torch.load(model_path, map_location=device)['model']
    model.load_state_dict(sd, strict=True)
    return model.to(device).eval()
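

# Example usage (a minimal sketch, not part of the original module): the
# checkpoint and config paths below are hypothetical placeholders; 'gpt'
# selects the UnifiedVoice branch above.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    gpt = load_model(
        'gpt',
        '~/ttts/checkpoints/gpt.pth',  # hypothetical checkpoint path
        '~/ttts/configs/config.json',  # hypothetical config path
        device,
    )
    print(type(gpt).__name__)  # -> UnifiedVoice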