xVASynth / resources /app /python /models_manager.py
Pendrokar's picture
manual model download
19b56c2
raw
history blame contribute delete
No virus
6.84 kB
import os
import torch
import traceback
class ModelsManager(object):
    """Lazily create, cache and hand out model wrapper instances.

    Instances live in ``self.models_bank``, a two-level dict keyed first by
    the lower-cased model key and then by an instance index, so several
    instances of the same model type can coexist.
    """

    # model key -> (module path, class name). The modules are imported only
    # when a model of that type is first requested.
    _MODEL_REGISTRY = {
        "hifigan": ("python.hifigan.model", "HiFi_GAN"),
        "big_waveglow": ("python.big_waveglow.model", "BIG_WaveGlow"),
        "256_waveglow": ("python.waveglow.model", "WaveGlow"),
        "fastpitch": ("python.fastpitch.model", "FastPitch"),
        "fastpitch1_1": ("python.fastpitch1_1.model", "FastPitch1_1"),
        "xvapitch": ("python.xvapitch.model", "xVAPitch"),
        # Same class as "fastpitch1_1", registered under a separate key.
        "s2s_fastpitch1_1": ("python.fastpitch1_1.model", "FastPitch1_1"),
        "wav2vec2": ("python.wav2vec2.model", "Wav2Vec2"),
        "speaker_rep": ("python.xvapitch.speaker_rep.model", "ResNetSpeakerEncoder"),
        "nuwave2": ("python.nuwave2.model", "Nuwave2Model"),
        "deepfilternet2": ("python.deepfilternet2.model", "DeepFilter2Model"),
    }

    def __init__(self, logger, PROD, device="cpu"):
        """Create an empty manager.

        Args:
            logger: Object providing ``info`` and ``error`` methods.
            PROD: Flag forwarded unchanged to every model constructor.
            device: Device string accepted by ``torch.device``.
        """
        super().__init__()
        self.models_bank = {}
        self.logger = logger
        self.PROD = PROD
        self.device_label = device
        self.device = torch.device(device)

    def init_model(self, model_key, instance_index=0):
        """Instantiate the model for ``model_key`` unless a ready one exists.

        Failures (unknown key, import error, constructor error) are not
        raised; the traceback is written to the logger instead, and the
        bank is left without an entry for this instance.

        Args:
            model_key: Model type name; matched case-insensitively.
            instance_index: Slot under which the instance is stored.
        """
        # Local import, in keeping with the lazy imports used for the models.
        import importlib

        model_key = model_key.lower()
        try:
            instances = self.models_bank.get(model_key)
            if (
                instances is not None
                and instance_index in instances
                and instances[instance_index].isReady
            ):
                return

            self.logger.info(f'ModelsManager: Initializing model: {model_key}')

            if model_key not in self._MODEL_REGISTRY:
                # Must be a real exception type: raising a plain string is a
                # TypeError and would hide this message from the log.
                raise ValueError(f'Model not recognized: {model_key}')

            module_path, class_name = self._MODEL_REGISTRY[model_key]
            model_cls = getattr(importlib.import_module(module_path), class_name)

            instances = self.models_bank.setdefault(model_key, {})
            instances[instance_index] = model_cls(self.logger, self.PROD, self.device, self)

            self._move_to_device(model_key, instance_index)
        except Exception:
            self.logger.info(traceback.format_exc())

    def _move_to_device(self, model_key, instance_index):
        """Best-effort move of a stored instance (and its ``.model``) to ``self.device``."""
        instances = self.models_bank[model_key]
        instance = instances[instance_index]

        try:
            instance.model = instance.model.to(self.device)
        except AttributeError:
            # Wrapper exposes no ``.model`` / ``.to``; nothing to move here.
            pass
        except Exception:
            # A real failure while moving: keep going, but leave a trace.
            self.logger.info(traceback.format_exc())

        try:
            instances[instance_index] = instance.to(self.device)
        except AttributeError:
            # Wrapper itself has no ``.to``.
            pass
        except Exception:
            self.logger.info(traceback.format_exc())

    def load_model(self, model_key, ckpt_path, instance_index=0, **kwargs):
        """Load a checkpoint into the given model instance, creating it if needed.

        The checkpoint is only read when ``ckpt_path`` differs from the
        instance's current ``ckpt_path`` attribute.

        Args:
            model_key: Model type name; matched case-insensitively.
            ckpt_path: Path to the checkpoint file.
            instance_index: Slot of the instance to load into.
            **kwargs: Forwarded to the instance's loading method.

        Raises:
            FileNotFoundError: If ``ckpt_path`` does not exist.
            KeyError: If the model could not be initialised.
        """
        # init_model() and models() store/look up lower-cased keys; do the same
        # here so a mixed-case key does not miss the bank entry.
        model_key = model_key.lower()

        if model_key not in self.models_bank or instance_index not in self.models_bank[model_key]:
            self.init_model(model_key, instance_index)

        if not os.path.exists(ckpt_path):
            self.logger.error('Checkpoint not found!')
            raise FileNotFoundError(f'Checkpoint not found: {ckpt_path}')

        instance = self.models_bank[model_key][instance_index]
        if instance.ckpt_path != ckpt_path:
            self.logger.info(f'ModelsManager: Loading model checkpoint: {model_key}, {ckpt_path}')
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            ckpt = torch.load(ckpt_path, map_location="cpu")
            try:
                instance.load_checkpoint(ckpt_path, ckpt, **kwargs)
            except Exception:
                # Fallback kept from the original code; presumably for wrappers
                # that expose load_state_dict with this signature instead.
                instance.load_state_dict(ckpt_path, ckpt, **kwargs)

    def set_device(self, device, instance_index=0):
        """Switch the manager, and the stored instances at ``instance_index``, to ``device``.

        Args:
            device: Device string; ``"gpu"`` is treated as ``"cuda:0"``.
            instance_index: Instance slot whose models are moved. Models
                that have no instance in this slot are skipped.
        """
        if device == "gpu":
            device = "cuda:0"
        if self.device_label == device:
            return

        # Build the torch device first so an invalid string cannot leave
        # device_label and device out of sync.
        new_device = torch.device(device)
        self.device_label = device
        self.device = new_device
        self.logger.info(f'ModelsManager: Changing device to: {device}')

        for instances in list(self.models_bank.values()):
            if instance_index in instances:
                instances[instance_index].set_device(self.device)

    def models(self, key, instance_index=0):
        """Return the instance for ``key``, initialising it on first access.

        Args:
            key: Model type name; matched case-insensitively.
            instance_index: Slot of the instance to return.

        Raises:
            KeyError: If the model could not be initialised.
        """
        key = key.lower()
        if key not in self.models_bank or instance_index not in self.models_bank[key]:
            self.init_model(key, instance_index=instance_index)
        return self.models_bank[key][instance_index]