# aet_demo/app.py — Hugging Face Spaces demo (commit 2d8f24b, author saefro991)
import pathlib
import yaml
import torch
import torchaudio
import numpy as np
from lightning_module import SSLDualLightningModule
import gradio as gr
import subprocess
import requests
def normalize_waveform(wav, sr, db=-3):
    """Peak-normalize a 1-D waveform via the sox ``norm`` effect.

    Args:
        wav: 1-D waveform tensor.
        sr: sample rate of ``wav`` in Hz.
        db: target peak level in dB passed to sox ``norm`` (default -3).

    Returns:
        The normalized waveform as a 1-D tensor.
    """
    effects = [["norm", str(db)]]
    # apply_effects_tensor expects a (channels, time) tensor, so add and
    # strip a channel dimension around the call.
    normalized, _ = torchaudio.sox_effects.apply_effects_tensor(
        wav.unsqueeze(0), sr, effects
    )
    return normalized.squeeze(0)
def download_file_from_google_drive(id, destination):
    """Download a file from Google Drive to *destination*.

    Google Drive interposes a virus-scan warning page for large files; in
    that case the first response carries a confirmation cookie, and the
    request is re-issued with the token to obtain the actual payload.

    Args:
        id: the Google Drive file id.
        destination: local path to write the downloaded file to.
    """
    base_url = "https://docs.google.com/uc?export=download"
    session = requests.Session()

    response = session.get(base_url, params={'id': id}, stream=True)
    token = get_confirm_token(response)
    if token:
        # Large-file path: repeat the request with the confirmation token.
        response = session.get(
            base_url, params={'id': id, 'confirm': token}, stream=True
        )

    save_response_content(response, destination)
def get_confirm_token(response):
    """Return the value of the first ``download_warning*`` cookie on *response*.

    Google Drive sets such a cookie when a download needs explicit
    confirmation. Returns ``None`` when no matching cookie is present.
    """
    matches = (
        value
        for name, value in response.cookies.items()
        if name.startswith('download_warning')
    )
    return next(matches, None)
def save_response_content(response, destination):
    """Stream the body of *response* to the file at *destination*.

    Reads the response in fixed-size chunks so arbitrarily large downloads
    never need to fit in memory.
    """
    chunk_size = 32768
    with open(destination, "wb") as out:
        for piece in response.iter_content(chunk_size):
            # requests yields empty keep-alive chunks; skip them.
            if piece:
                out.write(piece)
def calc_spectrogram(wav, config):
    """Compute a log-compressed mel spectrogram of *wav*.

    All analysis parameters (sample rate, FFT/window/hop sizes, mel filter
    range and count, clamp floor, compression factor) come from
    ``config["preprocess"]``.

    Returns:
        The log mel spectrogram as a float32 tensor.
    """
    pp = config["preprocess"]
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=pp["sampling_rate"],
        n_fft=pp["fft_length"],
        win_length=pp["frame_length"],
        hop_length=pp["frame_shift"],
        f_min=pp["fmin"],
        f_max=pp["fmax"],
        n_mels=pp["n_mels"],
        power=1,
        center=True,
        norm="slaney",
        mel_scale="slaney",
    )
    # Clamp to a floor before the log to avoid log(0), then apply the
    # compression factor.
    amplitude = torch.clamp_min(mel(wav), pp["min_magnitude"])
    return torch.log(amplitude * pp["comp_factor"]).to(torch.float32)
def transfer(audio):
    """Transfer the channel (audio-effect) characteristics of the bundled
    source recording onto user-provided audio.

    Args:
        audio: Gradio "audio" input — a ``(sample_rate, waveform ndarray)``
            tuple.

    Returns:
        ``(sample_rate, waveform ndarray)`` tuple for the Gradio audio output,
        at the source recording's sample rate.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Reference recording whose channel feature will be extracted.
    wp_src = pathlib.Path("aet_sample/src.wav")
    wav_src, sr = torchaudio.load(wp_src)

    # User audio: scale peak to 1/1.1 (headroom), then resample to the
    # source sample rate.
    sr_inp, wav_tar = audio
    wav_tar = wav_tar / (np.max(np.abs(wav_tar)) * 1.1)
    wav_tar = torch.from_numpy(wav_tar.astype(np.float32))
    resampler = torchaudio.transforms.Resample(
        orig_freq=sr_inp,
        new_freq=sr,
    )
    wav_tar = resampler(wav_tar)

    config_path = pathlib.Path("configs/test/melspec/ssl_tono.yaml")
    # Use a context manager so the config file handle is closed (it was
    # leaked before).
    with open(config_path, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    melspec_src = calc_spectrogram(
        normalize_waveform(wav_src.squeeze(0), sr), config
    )
    wav_tar = normalize_waveform(wav_tar.squeeze(0), sr)

    ckpt_path = pathlib.Path("tono_aet_melspec.ckpt").resolve()
    # load_from_checkpoint is a Lightning classmethod — call it on the class
    # instead of building (and discarding) a throwaway instance first.
    src_model = SSLDualLightningModule.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        config=config,
        strict=False,
    ).eval()
    encoder_src = src_model.encoder.to(device)
    channelfeats_src = src_model.channelfeats.to(device)
    channel_src = src_model.channel.to(device)

    with torch.no_grad():
        _, enc_hidden_src = encoder_src(
            melspec_src.unsqueeze(0).unsqueeze(1).transpose(2, 3).to(device)
        )
        chfeats_src = channelfeats_src(enc_hidden_src)
        # Move the target waveform to the model's device; the original left
        # it on CPU, which fails whenever CUDA is available.
        wav_transfer = channel_src(wav_tar.unsqueeze(0).to(device), chfeats_src)
    wav_transfer = wav_transfer.cpu().detach().numpy()[0, :]
    return sr, wav_transfer
if __name__ == "__main__":
    # Fetch the pretrained SelfRemaster checkpoint and the HiFi-GAN weights
    # before the interface starts serving requests.
    subprocess.run(["curl", "-OL", "https://sarulab.sakura.ne.jp/saeki/selfremaster/pretrained/tono_aet_melspec.ckpt"])
    download_file_from_google_drive("10OJ2iznutxzp8MEIS6lBVaIS_g5c_70V", "hifigan/hifigan_melspec_universal")

    demo = gr.Interface(
        fn=transfer,
        inputs="audio",
        outputs=gr.outputs.Audio(type="numpy"),
        examples=[
            ["aet_sample/tar.wav"]
        ],
        layout="horizontal",
        title='Audio effect transfer with SelfRemaster',
        description='Extracting the channel feature of a historical audio recording with a pretrained SelfRemaster and adding it to any high-quality audio. (Source audio is aet_sample/src.wav)',
    )
    demo.launch()