# Fast-GeCo / app.py
import gradio as gr
import spaces
import numpy as np
import torch
from fastgeco.model import ScoreModel
from geco.util.other import pad_spec
import os
import torchaudio
from speechbrain.lobes.models.dual_path import Encoder, SBTransformerBlock, Dual_Path_Model, Decoder
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sample_rate = 8000
num_spks = 2
ckpt_path = 'ckpts/'

def load_sepformer(ckpt_path):
    """Build the SepFormer separator and load its pretrained weights."""
    encoder = Encoder(
        kernel_size=160,
        out_channels=256,
        in_channels=1,
    )
    # Intra- and inter-chunk transformer blocks for the dual-path mask network.
    SBtfintra = SBTransformerBlock(
        num_layers=8,
        d_model=256,
        nhead=8,
        d_ffn=1024,
        dropout=0,
        use_positional_encoding=True,
        norm_before=True,
    )
    SBtfinter = SBTransformerBlock(
        num_layers=8,
        d_model=256,
        nhead=8,
        d_ffn=1024,
        dropout=0,
        use_positional_encoding=True,
        norm_before=True,
    )
    masknet = Dual_Path_Model(
        num_spks=num_spks,
        in_channels=256,
        out_channels=256,
        num_layers=2,
        K=250,
        intra_model=SBtfintra,
        inter_model=SBtfinter,
        norm='ln',
        linear_layer_after_inter_intra=False,
        skip_around_intra=True,
    )
    decoder = Decoder(
        in_channels=256,
        out_channels=1,
        kernel_size=160,
        stride=80,
        bias=False,
    )

    # map_location keeps the checkpoints loadable on CPU-only machines.
    encoder_weights = torch.load(os.path.join(ckpt_path, 'encoder.ckpt'), map_location=device)
    encoder.load_state_dict(encoder_weights)
    masknet_weights = torch.load(os.path.join(ckpt_path, 'masknet.ckpt'), map_location=device)
    masknet.load_state_dict(masknet_weights)
    decoder_weights = torch.load(os.path.join(ckpt_path, 'decoder.ckpt'), map_location=device)
    decoder.load_state_dict(decoder_weights)

    encoder = encoder.eval().to(device)
    masknet = masknet.eval().to(device)
    decoder = decoder.eval().to(device)
    return encoder, masknet, decoder

def load_fastgeco(ckpt_path):
    checkpoint_file = os.path.join(ckpt_path, 'fastgeco.ckpt')
    model = ScoreModel.load_from_checkpoint(
        checkpoint_file,
        batch_size=1, num_workers=0, kwargs=dict(gpu=False)
    )
    model.eval(no_ema=False)
    model.to(device)
    return model
encoder, masknet, decoder = load_sepformer(ckpt_path)
fastgeco_model = load_fastgeco(ckpt_path)

@spaces.GPU
def separate(test_file, encoder, masknet, decoder):
    """Run SepFormer on the input mixture and return the raw source estimates."""
    with torch.no_grad():
        print('Processing with SepFormer...')
        mix, fs_file = torchaudio.load(test_file)
        mix = mix.to(device)

        fs_model = sample_rate
        # Resample (and downmix to mono) if the file's rate differs from the model's.
        if fs_file != fs_model:
            print(
                "Resampling the audio from {} Hz to {} Hz".format(
                    fs_file, fs_model
                )
            )
            tf = torchaudio.transforms.Resample(
                orig_freq=fs_file, new_freq=fs_model
            ).to(device)
            mix = mix.mean(dim=0, keepdim=True)
            mix = tf(mix)
        mix = mix.to(device)

        # Separation: encode, estimate one mask per speaker, apply the masks.
        mix_w = encoder(mix)
        est_mask = masknet(mix_w)
        mix_w = torch.stack([mix_w] * num_spks)
        sep_h = mix_w * est_mask

        # Decoding: back to the time domain, one waveform per speaker.
        est_sources = torch.cat(
            [
                decoder(sep_h[i]).unsqueeze(-1)
                for i in range(num_spks)
            ],
            dim=-1,
        )

        # Peak-normalize each estimated source.
        est_sources = (
            est_sources / est_sources.abs().max(dim=1, keepdim=True)[0]
        ).squeeze()
    return est_sources, mix

@spaces.GPU
def correct(model, est_sources, mix):
    """Refine each separated source with a single Fast-GeCo reverse-diffusion step."""
    with torch.no_grad():
        print('Processing with Fast-GeCo...')
        N = 1  # number of reverse steps (Fast-GeCo uses a single corrective step)
        reverse_starting_point = 0.5
        output = []
        for idx in range(num_spks):
            y = est_sources[:, idx].unsqueeze(0)  # SepFormer estimate ("noisy" input)
            m = mix
            # Trim both signals to a common length.
            min_leng = min(y.shape[-1], m.shape[-1])
            y = y[..., :min_leng]
            m = m[..., :min_leng]
            T_orig = y.size(1)

            # Normalize by the estimate's peak so the model sees a consistent scale.
            norm_factor = y.abs().max()
            y = y / norm_factor
            m = m / norm_factor

            # STFT both signals and pad the spectrograms to the network's expected shape.
            Y = torch.unsqueeze(model._forward_transform(model._stft(y.to(device))), 0)
            Y = pad_spec(Y)
            M = torch.unsqueeze(model._forward_transform(model._stft(m.to(device))), 0)
            M = pad_spec(M)

            # Initialize the reverse process at t = reverse_starting_point.
            timesteps = torch.linspace(reverse_starting_point, 0.03, N, device=Y.device)
            std = model.sde._std(reverse_starting_point * torch.ones((Y.shape[0],), device=Y.device))
            z = torch.randn_like(Y)
            X_t = Y + z * std[:, None, None, None]

            # One Euler-Maruyama step of the reverse SDE, keeping only the mean.
            t = timesteps[0]
            dt = timesteps[-1]
            f, g = model.sde.sde(X_t, t, Y)
            vec_t = torch.ones(Y.shape[0], device=Y.device) * t
            mean_x_tm1 = X_t - (f - g**2 * model.forward(X_t, vec_t, Y, M, vec_t[:, None, None, None])) * dt  # mu(x_{t-1})
            sample = mean_x_tm1.squeeze()

            # Back to a waveform, undo the input normalization, then peak-normalize.
            x_hat = model.to_audio(sample.squeeze(), T_orig)
            x_hat = x_hat * norm_factor
            new_norm_factor = x_hat.abs().max()
            x_hat = x_hat / new_norm_factor
            x_hat = x_hat.squeeze().cpu().numpy()
            output.append(x_hat)
    return (output[0], sample_rate), (output[1], sample_rate)
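
# For reference: the single step in `correct` is the mean update of the usual
# Euler-Maruyama discretization of the reverse-time SDE,
#     dx = [f(x, t) - g(t)^2 * score(x, t)] dt + g(t) dw,
# i.e. x_{t-dt} ~ x_t - (f - g^2 * score) * dt, with the diffusion noise term
# dropped on this final (and only) step so the output is the posterior mean.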

@spaces.GPU
def process_audio(test_file):
    result, mix = separate(test_file, encoder, masknet, decoder)
    audio1, audio2 = correct(fastgeco_model, result, mix)
    return audio1, audio2

# CSS styling (optional)
css = """
#col-container {
    margin: 0 auto;
    max-width: 1280px;
}
"""

# Gradio Blocks layout
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # Fast-GeCo: Noise-robust Speech Separation with Fast Generative Correction
        Separate noisy mixture speech with a generative correction method; only two speakers are supported for now.
        Learn more about 🟣**Fast-GeCo** on the [Fast-GeCo Repo](https://github.com/WangHelin1997/Fast-GeCo/).
        """)

        with gr.Tab("Speech Separation"):
            # Input: upload an audio file
            with gr.Row():
                gt_file_input = gr.Audio(label="Upload Audio to Separate", type="filepath", value="demo/item0_mix.wav")
                button = gr.Button("Generate", scale=1)

            # Output components for the separated audio
            with gr.Row():
                result1 = gr.Audio(label="Separated Audio 1", type="numpy")
                result2 = gr.Audio(label="Separated Audio 2", type="numpy")

        # Wire the button to the processing function
        button.click(
            fn=process_audio,
            inputs=[
                gt_file_input,
            ],
            outputs=[result1, result2]
        )

# Launch the Gradio demo
demo.launch()
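
# A minimal sketch of programmatic use (hypothetical, for local testing without
# the web UI; assumes the `soundfile` package is installed and the demo file exists):
#
#   import soundfile as sf
#   (wav1, sr1), (wav2, sr2) = process_audio("demo/item0_mix.wav")
#   sf.write("speaker1.wav", wav1, sr1)
#   sf.write("speaker2.wav", wav2, sr2)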