Ahsen Khaliq committed on
Commit
7c0b7db
1 Parent(s): 8c22980

Create app.py

Files changed (1)
  1. app.py +86 -0
app.py ADDED
@@ -0,0 +1,86 @@
+ import os
+ os.system("git clone https://github.com/v-iashin/SpecVQGAN")
+ os.system("pip install pytorch-lightning==1.2.10 omegaconf==2.0.6 streamlit==0.80 matplotlib==3.4.1 albumentations==0.5.2 SoundFile torch")
+
+ import sys
+ from pathlib import Path
+ import soundfile
+ import torch
+ import gradio as gr
+
+ os.chdir("SpecVQGAN")
+ # make the freshly cloned repo importable after the chdir
+ sys.path.insert(0, os.getcwd())
+
+ from feature_extraction.demo_utils import (calculate_codebook_bitrate,
+                                            extract_melspectrogram,
+                                            get_audio_file_bitrate,
+                                            get_duration,
+                                            load_neural_audio_codec)
+ from sample_visualization import tensor_to_plt
+ from torch.utils.data.dataloader import default_collate
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
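+
+ # SpecVQGAN as a neural audio codec: the VQGAN encodes a mel spectrogram into a
+ # small grid of discrete codebook indices, and a vocoder renders spectrograms
+ # back into waveforms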
+
+ model_name = '2021-05-19T22-16-54_vggsound_codebook'
+ log_dir = './logs'
+ # loading the models might take a few minutes
+ config, model, vocoder = load_neural_audio_codec(model_name, log_dir, device)
+
+ def inference(audio):
+     # Select an audio file
+     input_wav = audio.name
+
+     # Spectrogram Extraction
+     model_sr = config.data.params.sample_rate
+     duration = get_duration(input_wav)
+     spec = extract_melspectrogram(input_wav, sr=model_sr, duration=duration)
+     print(f'Audio Duration: {duration} seconds')
+     print('Original Spectrogram Shape:', spec.shape)
+
+     # Prepare Input
+     spectrogram = {'input': spec}
+     batch = default_collate([spectrogram])
+     batch['image'] = batch['input'].to(device)
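+     # the codec reads the spectrogram under the 'image' key, since the VQGAN
+     # treats mel spectrograms as one-channel images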
+     x = model.get_input(batch, 'image')
+
+     with torch.no_grad():
+         quant_z, diff, info = model.encode(x)
+         xrec = model.decode(quant_z)
+
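+     # info[2] holds the codebook indices: the F x T grid of integers printed
+     # below is the entire compressed representation of the clip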
+     print('Compressed representation (it is all you need to recover the audio):')
+     F, T = quant_z.shape[-2:]
+     print(info[2].reshape(F, T))
+
+     # Calculate Bitrate
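+     # roughly F*T*log2(codebook size) bits for the clip, reported per second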
+     bitrate = calculate_codebook_bitrate(duration, quant_z, model.quantize.n_e)
+     orig_bitrate = get_audio_file_bitrate(input_wav)
+
+     # Save and Display
+     x = x.squeeze(0)
+     xrec = xrec.squeeze(0)
+     # specs are in [-1, 1]; rescale them to [0, 1] for the vocoder
+     wav_x = vocoder((x + 1) / 2).squeeze().detach().cpu().numpy()
+     wav_xrec = vocoder((xrec + 1) / 2).squeeze().detach().cpu().numpy()
+     # Create a temp folder which will hold the results
+     tmp_dir = os.path.join('./tmp/neural_audio_codec', Path(input_wav).parent.stem)
+     os.makedirs(tmp_dir, exist_ok=True)
+     # Save paths
+     x_save_path = Path(tmp_dir) / 'vocoded_orig_spec.wav'
+     xrec_save_path = Path(tmp_dir) / f'specvqgan_{bitrate:.2f}kbps.wav'
+     # Save
+     soundfile.write(x_save_path, wav_x, model_sr, 'PCM_16')
+     soundfile.write(xrec_save_path, wav_xrec, model_sr, 'PCM_16')
+     # return the paths the files were actually written to
+     return str(x_save_path), str(xrec_save_path)
+
+ title = "SpecVQGAN Neural Audio Codec"
+ description = "Demo for the SpecVQGAN neural audio codec: upload an audio clip to compress it to discrete codes and reconstruct it. Read more at the links below."
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2110.08791'>Taming Visually Guided Sound Generation</a> | <a href='https://github.com/v-iashin/SpecVQGAN'>Github Repo</a></p>"
+
+ gr.Interface(
+     inference,
+     gr.inputs.Audio(type="file", label="Input Audio"),
+     [gr.outputs.Audio(type="file", label="Original audio"),
+      gr.outputs.Audio(type="file", label="Reconstructed audio")],
+     title=title,
+     description=description,
+     article=article,
+     enable_queue=True
+ ).launch(debug=True)