File size: 3,634 Bytes
7c0b7db
 
a803368
7c0b7db
 
 
 
9868c23
7c0b7db
39540ac
 
7c0b7db
 
 
 
 
 
 
 
0df1e05
7c0b7db
 
a803368
7c0b7db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43b9586
 
7c0b7db
 
 
696e7d4
 
 
 
96a2826
7c0b7db
3fb4243
7c0b7db
 
f483f0f
 
7c0b7db
 
 
96a2826
f483f0f
 
7c0b7db
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import sys
from pathlib import Path

# Absolute path of the SpecVQGAN checkout, resolved against the launch
# directory before the os.chdir further down changes the cwd.
SPECVQGAN_DIR = Path('SpecVQGAN').resolve()

# Runtime bootstrap: fetch the SpecVQGAN sources and install pinned deps.
# os.system does not raise on failure, so a failed step only surfaces later
# (typically as an ImportError). Skip the clone when a checkout is already
# present, where `git clone` would just fail.
if not (SPECVQGAN_DIR / '.git').exists():
    os.system("git clone https://github.com/v-iashin/SpecVQGAN")
os.system("pip install pytorch-lightning==1.2.10 omegaconf==2.0.6 streamlit==0.80 matplotlib==3.4.1 albumentations==0.5.2 SoundFile torch torchvision librosa gdown")

import soundfile
import torch
import gradio as gr

# Make SpecVQGAN's top-level modules (feature_extraction, sample_visualization)
# importable. The entry must be absolute: a relative './SpecVQGAN' can be
# resolved against the *current* cwd at lookup time, which after the os.chdir
# below would point at SpecVQGAN/SpecVQGAN and break imports done later.
sys.path.append(str(SPECVQGAN_DIR))
from feature_extraction.demo_utils import (calculate_codebook_bitrate,
                                           extract_melspectrogram,
                                           get_audio_file_bitrate,
                                           get_duration,
                                           load_neural_audio_codec)
from sample_visualization import tensor_to_plt
from torch.utils.data.dataloader import default_collate

# Run from inside the checkout so its relative paths (e.g. ./logs) resolve.
os.chdir(SPECVQGAN_DIR)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Download a file from Google Drive into the checkout (the cwd by now).
# NOTE(review): presumably this is the 'example.wav' used by the Gradio
# examples — confirm.
os.system("gdown https://drive.google.com/uc?id=1KGof44Sx4yIn4Hohpp9-VVTh2zGucKeY")

# Codec checkpoint name and directory handed to load_neural_audio_codec.
model_name = '2021-05-19T22-16-54_vggsound_codebook'
log_dir = './logs'
# loading the models might take a few minutes
config, model, vocoder = load_neural_audio_codec(model_name, log_dir, device)

def inference(audio):
    """Compress and reconstruct one clip with the SpecVQGAN neural audio codec.

    The clip's mel spectrogram is encoded into codebook entries and decoded
    back. The original and reconstructed spectrograms are then both vocoded
    to waveforms and written as 16-bit PCM WAV files into the current
    working directory.

    Args:
        audio: The uploaded clip. Either a file-like object exposing
            ``.name`` (e.g. a temp-file wrapper) or a path given as ``str``
            or ``os.PathLike``.

    Returns:
        ``(orig_wav_path, rec_wav_path, orig_fig, rec_fig)``:
            - the path of the vocoded original;
            - the path of the codec reconstruction;
            - the ``tensor_to_plt`` plots of the two spectrograms.
    """
    # Select an Audio. Paths are used as-is. Anything else is expected to be
    # a temp-file wrapper whose ``.name`` is the path on disk.
    if isinstance(audio, (str, os.PathLike)):
        input_wav = os.fspath(audio)
    else:
        input_wav = audio.name

    # Spectrogram Extraction at the sample rate from the codec's config.
    model_sr = config.data.params.sample_rate
    duration = get_duration(input_wav)
    spec = extract_melspectrogram(input_wav, sr=model_sr, duration=duration)
    print(f'Audio Duration: {duration} seconds')
    print('Original Spectrogram Shape:', spec.shape)

    # Prepare Input: collate a batch of one and move it to the model device.
    batch = default_collate([{'input': spec}])
    batch['image'] = batch['input'].to(device)
    x = model.get_input(batch, 'image')

    with torch.no_grad():
        quant_z, _diff, info = model.encode(x)
        xrec = model.decode(quant_z)

    print('Compressed representation (it is all you need to recover the audio):')
    # Lay the codes out on quant_z's trailing (F, T) grid. Assumes info[2]
    # holds the flattened codebook indices, as the message above implies.
    F, T = quant_z.shape[-2:]
    print(info[2].reshape(F, T))

    # Calculate Bitrate of the compressed representation (used in the file name).
    bitrate = calculate_codebook_bitrate(duration, quant_z, model.quantize.n_e)
    orig_bitrate = get_audio_file_bitrate(input_wav)
    # NOTE(review): the unit returned by get_audio_file_bitrate is not visible here.
    print('Original file bitrate:', orig_bitrate)

    # Drop the batch dimension.
    x = x.squeeze(0)
    xrec = xrec.squeeze(0)
    # Vocode under no_grad so no autograd graph is recorded just to be thrown
    # away. Specs are in [-1, 1]; (s + 1) / 2 maps them into [0, 1].
    with torch.no_grad():
        wav_x = vocoder((x + 1) / 2).squeeze().detach().cpu().numpy()
        wav_xrec = vocoder((xrec + 1) / 2).squeeze().detach().cpu().numpy()

    # Fixed names in the cwd: concurrent requests would overwrite each other.
    x_save_path = 'vocoded_orig_spec.wav'
    xrec_save_path = f'specvqgan_{bitrate:.2f}kbps.wav'
    soundfile.write(x_save_path, wav_x, model_sr, 'PCM_16')
    soundfile.write(xrec_save_path, wav_xrec, model_sr, 'PCM_16')
    return (
        x_save_path,
        xrec_save_path,
        tensor_to_plt(x, flip_dims=(2,)),
        tensor_to_plt(xrec, flip_dims=(2,)),
    )

# Text rendered above and below the demo.
title = "SpecVQGAN Neural Audio Codec"
description = "Gradio demo for Spectrogram VQGAN as a Neural Audio Codec. To use it, simply add your audio, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2110.08791' target='_blank'>Taming Visually Guided Sound Generation</a> | <a href='https://github.com/v-iashin/SpecVQGAN' target='_blank'>Github Repo</a></p>"

# One example row per clip. The relative path is resolved against the cwd at
# launch time; cache_examples=True below runs `inference` on it at startup.
examples = [['example.wav']]

# Single audio input; the four outputs line up with inference's return tuple.
input_audio = gr.Audio(type="file", label="Input Audio")
output_components = [
    gr.Audio(type="file", label="Original audio"),
    gr.Audio(type="file", label="Reconstructed audio"),
    gr.Plot(label="Original Spectrogram:"),
    gr.Plot(label="Reconstructed Spectrogram:"),
]

demo = gr.Interface(
    inference,
    input_audio,
    output_components,
    title=title,
    description=description,
    article=article,
    enable_queue=True,
    examples=examples,
    cache_examples=True,
)
demo.launch(debug=True)