RemFx / scripts /remfx_detect.py
mattricesound's picture
Fix gpu issues in remfx_detect script
6c005f0
raw
history blame
No virus
2.11 kB
import hydra
from omegaconf import DictConfig
import torch
from remfx.models import RemFXChainInference
import torchaudio
@hydra.main(
version_base=None,
config_path="../cfg",
config_name="config.yaml",
)
def main(cfg: DictConfig):
print("Loading models...")
models = {}
for effect in cfg.ckpts:
model = hydra.utils.instantiate(cfg.ckpts[effect].model, _convert_="partial")
ckpt_path = cfg.ckpts[effect].ckpt_path
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
model.load_state_dict(state_dict)
model.to(device)
models[effect] = model
classifier = hydra.utils.instantiate(cfg.classifier, _convert_="partial")
ckpt_path = cfg.classifier_ckpt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
classifier.load_state_dict(state_dict)
classifier.to(device)
inference_model = RemFXChainInference(
models,
sample_rate=cfg.sample_rate,
num_bins=cfg.num_bins,
effect_order=cfg.inference_effects_ordering,
classifier=classifier,
shuffle_effect_order=cfg.inference_effects_shuffle,
use_all_effect_models=cfg.inference_use_all_effect_models,
)
audio_file = cfg.audio_input
print("Loading", audio_file)
audio, sr = torchaudio.load(audio_file)
# Resample
audio = torchaudio.transforms.Resample(sr, cfg.sample_rate)(audio)
# Convert to mono
audio = audio.mean(0, keepdim=True)
# Add dimension for batch
audio = audio.unsqueeze(0)
audio = audio.to(device)
batch = [audio, audio, None, None]
_, y = inference_model(batch, 0, verbose=True)
y = y.cpu()
if "output_path" in cfg:
output_path = cfg.output_path
else:
output_path = "./output.wav"
print("Saving output to", output_path)
torchaudio.save(output_path, y[0], sample_rate=cfg.sample_rate)
if __name__ == "__main__":
main()