Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 6,843 Bytes
19c8b95 19b56c2 19c8b95 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import os
import torch
import traceback
class ModelsManager(object):
    """Create, cache and relocate model wrapper instances on demand.

    Wrappers are kept in ``models_bank`` as
    ``{model_key: {instance_index: wrapper}}``. Every public method
    lower-cases the model key before using it, so lookups are
    case-insensitive.
    """

    def __init__(self, logger, PROD, device="cpu"):
        """Store the shared logger / PROD flag and resolve the initial device.

        Args:
            logger: Object exposing ``info`` and ``error``; it is also handed
                to every wrapper this manager constructs.
            PROD: Flag forwarded unchanged to every wrapper constructor.
            device: Device string accepted by ``torch.device``.
        """
        super(ModelsManager, self).__init__()
        self.models_bank = {}
        self.logger = logger
        self.PROD = PROD
        self.device_label = device
        self.device = torch.device(device)

    @staticmethod
    def _import_model_class(model_key):
        """Return the wrapper class for an already lower-cased ``model_key``.

        The imports are deliberately kept as literal statements inside each
        branch so that a model package is only imported once it is requested.

        Raises:
            ValueError: If ``model_key`` is not a known model.
        """
        if model_key == "hifigan":
            from python.hifigan.model import HiFi_GAN
            return HiFi_GAN
        if model_key == "big_waveglow":
            from python.big_waveglow.model import BIG_WaveGlow
            return BIG_WaveGlow
        if model_key == "256_waveglow":
            from python.waveglow.model import WaveGlow
            return WaveGlow
        if model_key == "fastpitch":
            from python.fastpitch.model import FastPitch
            return FastPitch
        if model_key in ("fastpitch1_1", "s2s_fastpitch1_1"):
            # Both keys are served by the same wrapper class.
            from python.fastpitch1_1.model import FastPitch1_1
            return FastPitch1_1
        if model_key == "xvapitch":
            from python.xvapitch.model import xVAPitch
            return xVAPitch
        if model_key == "wav2vec2":
            from python.wav2vec2.model import Wav2Vec2
            return Wav2Vec2
        if model_key == "speaker_rep":
            from python.xvapitch.speaker_rep.model import ResNetSpeakerEncoder
            return ResNetSpeakerEncoder
        if model_key == "nuwave2":
            from python.nuwave2.model import Nuwave2Model
            return Nuwave2Model
        if model_key == "deepfilternet2":
            from python.deepfilternet2.model import DeepFilter2Model
            return DeepFilter2Model
        # Raising a plain string is itself a TypeError; use a real exception
        # so the message actually shows up in the logged traceback.
        raise ValueError(f'Model not recognized: {model_key}')

    def init_model(self, model_key, instance_index=0):
        """Construct and register the wrapper for ``model_key`` if needed.

        Nothing happens when a wrapper is already registered under
        ``instance_index`` and reports ``isReady``. Any failure (unknown key,
        import error, constructor error) is logged through ``logger.info``
        and NOT re-raised, so callers must not assume the wrapper exists
        after this returns.

        Args:
            model_key: Model name; matched case-insensitively.
            instance_index: Slot under which the wrapper is stored.
        """
        model_key = model_key.lower()
        try:
            registered = self.models_bank.get(model_key, {})
            if instance_index in registered and registered[instance_index].isReady:
                return

            self.logger.info(f'ModelsManager: Initializing model: {model_key}')

            model_class = self._import_model_class(model_key)
            instances = self.models_bank.setdefault(model_key, {})
            instance = model_class(self.logger, self.PROD, self.device, self)
            instances[instance_index] = instance

            # Best effort: move an inner `.model`, when the wrapper has one.
            try:
                instance.model = instance.model.to(self.device)
            except AttributeError:
                pass  # wrapper exposes no `.model` / `.to`
            except Exception:
                self.logger.info(traceback.format_exc())

            # Best effort: move the wrapper itself, when it supports `.to`.
            try:
                instances[instance_index] = instance.to(self.device)
            except AttributeError:
                pass  # wrapper has no `.to`
            except Exception:
                self.logger.info(traceback.format_exc())
        except Exception:
            self.logger.info(traceback.format_exc())

    def load_model(self, model_key, ckpt_path, instance_index=0, **kwargs):
        """Load ``ckpt_path`` into the wrapper for ``model_key``.

        The wrapper is initialised first when it is not registered yet. The
        checkpoint is only read when the wrapper's ``ckpt_path`` differs from
        the requested one.

        Args:
            model_key: Model name; matched case-insensitively, consistent
                with ``init_model`` and ``models``.
            ckpt_path: Path of the checkpoint file.
            instance_index: Slot of the wrapper to load into.
            **kwargs: Forwarded to the wrapper's loading method.

        Raises:
            FileNotFoundError: If ``ckpt_path`` does not exist.
            KeyError: If the wrapper could not be initialised.
        """
        model_key = model_key.lower()
        if model_key not in self.models_bank or instance_index not in self.models_bank[model_key]:
            self.init_model(model_key, instance_index)

        if not os.path.exists(ckpt_path):
            self.logger.error('Checkpoint not found!')
            raise FileNotFoundError(f'Checkpoint not found: {ckpt_path}')

        model = self.models_bank[model_key][instance_index]
        if model.ckpt_path != ckpt_path:
            self.logger.info(f'ModelsManager: Loading model checkpoint: {model_key}, {ckpt_path}')
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            ckpt = torch.load(ckpt_path, map_location="cpu")
            try:
                model.load_checkpoint(ckpt_path, ckpt, **kwargs)
            except Exception:
                # Fall back to the alternative loader if `load_checkpoint`
                # is missing or fails.
                model.load_state_dict(ckpt_path, ckpt, **kwargs)

    def set_device(self, device, instance_index=0):
        """Switch the active device and notify registered wrappers.

        ``"gpu"`` is treated as an alias for ``"cuda:0"``. When the requested
        device is already active nothing happens. Only wrappers stored under
        ``instance_index`` are notified; models without such an instance are
        skipped.

        Args:
            device: Device string accepted by ``torch.device``, or ``"gpu"``.
            instance_index: Slot whose wrappers get ``set_device`` called.
        """
        if device == "gpu":
            device = "cuda:0"
        if self.device_label == device:
            return

        self.device_label = device
        self.device = torch.device(device)
        self.logger.info(f'ModelsManager: Changing device to: {device}')

        # Iterate over a snapshot in case a wrapper touches the bank.
        for instances in list(self.models_bank.values()):
            if instance_index in instances:
                instances[instance_index].set_device(self.device)

    def models(self, key, instance_index=0):
        """Return the wrapper for ``key``, initialising it on first access.

        Args:
            key: Model name; matched case-insensitively.
            instance_index: Slot of the wrapper to return.

        Raises:
            KeyError: If the wrapper is absent and could not be initialised.
        """
        key = key.lower()
        if key not in self.models_bank or instance_index not in self.models_bank[key]:
            self.init_model(key, instance_index=instance_index)
        return self.models_bank[key][instance_index]
|