# Inference script: remove audio effects from an input file using RemFX models.
import hydra
import torch
import torchaudio
from omegaconf import DictConfig

from remfx.models import RemFXChainInference
# NOTE(review): the decorator was missing, so `main()` at the bottom of the
# file raised TypeError (no `cfg`). config_path/config_name assumed from the
# standard RemFX repo layout — confirm against the project's cfg directory.
@hydra.main(version_base=None, config_path="cfg", config_name="config")
def main(cfg: DictConfig) -> None:
    """Remove audio effects from an input file with a chain of RemFX models.

    Loads one checkpointed removal model per effect listed in ``cfg.ckpts``
    plus an effect classifier, runs the inference chain on ``cfg.audio_input``,
    and writes the processed audio to ``cfg.output_path`` (default
    ``./output.wav``) at ``cfg.sample_rate``.
    """
    # Choose the device once; reused for every checkpoint and the input audio.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("Loading models...")
    models = {}
    for effect in cfg.ckpts:
        model = hydra.utils.instantiate(cfg.ckpts[effect].model, _convert_="partial")
        ckpt_path = cfg.ckpts[effect].ckpt_path
        state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()  # inference only: freeze dropout / batch-norm statistics
        models[effect] = model

    classifier = hydra.utils.instantiate(cfg.classifier, _convert_="partial")
    state_dict = torch.load(cfg.classifier_ckpt, map_location=device)["state_dict"]
    classifier.load_state_dict(state_dict)
    classifier.to(device)
    classifier.eval()

    inference_model = RemFXChainInference(
        models,
        sample_rate=cfg.sample_rate,
        num_bins=cfg.num_bins,
        effect_order=cfg.inference_effects_ordering,
        classifier=classifier,
        shuffle_effect_order=cfg.inference_effects_shuffle,
        use_all_effect_models=cfg.inference_use_all_effect_models,
    )

    audio_file = cfg.audio_input
    print("Loading", audio_file)
    audio, sr = torchaudio.load(audio_file)
    # Resample to the rate the models expect.
    audio = torchaudio.transforms.Resample(sr, cfg.sample_rate)(audio)
    # Downmix to mono, then add a batch dimension -> (1, 1, samples).
    audio = audio.mean(0, keepdim=True)
    audio = audio.unsqueeze(0)
    audio = audio.to(device)

    # The chain consumes (input, target, ...) batches; the input doubles as a
    # dummy target since only the prediction is used here.
    batch = [audio, audio, None, None]
    with torch.no_grad():  # no gradients needed at inference time
        _, y = inference_model(batch, 0, verbose=True)
    y = y.cpu()

    output_path = cfg.output_path if "output_path" in cfg else "./output.wav"
    print("Saving output to", output_path)
    torchaudio.save(output_path, y[0], sample_rate=cfg.sample_rate)
# Entry point: run the inference script only when executed directly.
if __name__ == "__main__":
    main()