ernestchu committed
Commit 32bac05
1 Parent(s): 19b2e5e
.gitignore ADDED
@@ -0,0 +1,7 @@
+flagged
+__pycache__
+.DS_Store
+*.swp
+*.egg-info
+build
+
app.py ADDED
@@ -0,0 +1,123 @@
+import os
+from tsmnet import Stretcher
+import gradio as gr
+from gradio import processing_utils
+import torch
+import torchaudio
+
+model_root = './weights'
+available_models = ['general', 'pop-music', 'classical-music', 'speech']
+working_sr = 22050
+
+def prepare_models():
+    # Load one Stretcher per available weight file.
+    return {
+        weight: Stretcher(os.path.join(model_root, f'{weight}.pt'))
+        for weight in available_models
+    }
+
+def prepare_audio_file(rec, audio_file, yt_url):
+    # Return whichever input source the user actually provided.
+    if rec is not None:
+        return rec
+    if audio_file is not None:
+        return audio_file
+    if yt_url != '':
+        # The YouTube tab is disabled in the UI below; this branch is a stub.
+        raise gr.Error('YouTube support is under construction!')
+    raise gr.Error('No audio found!')
+
+
+def run(rec, audio_file, yt_url, speed, model, start_time, end_time):
+    audio_file = prepare_audio_file(rec, audio_file, yt_url)
+    if speed == 1:
+        return processing_utils.audio_from_file(audio_file)
+
+    model = models[model]
+
+    x, sr = torchaudio.load(audio_file)
+    x = torchaudio.transforms.Resample(orig_freq=sr, new_freq=working_sr)(x)
+    sr = working_sr
+    # Trim to the requested segment before stretching.
+    x = x[..., int(start_time * sr):int(end_time * sr)]
+
+    x = model(x, speed).cpu()
+
+    torchaudio.save(audio_file, x, sr)
+
+    return processing_utils.audio_from_file(audio_file)
+
+
+# @@@@@@@ Start of the program @@@@@@@@
+
+models = prepare_models()
+
+with gr.Blocks() as demo:
+    gr.Markdown('# TSM-Net')
+    gr.Markdown('---')
+    with gr.Row():
+        with gr.Column():
+            with gr.Tab('From microphone'):
+                rec_box = gr.Audio(label='Recording', source='microphone', type='filepath')
+            with gr.Tab('From file'):
+                audio_file_box = gr.Audio(label='Audio sample', type='filepath')
+            with gr.Tab('From YouTube'):
+                yt_url_box = gr.Textbox(label='YouTube URL', placeholder='Under Construction', interactive=False)
+
+            # Keep the three input sources mutually exclusive.
+            rec_box.change(lambda: [None] * 2, outputs=[audio_file_box, yt_url_box])
+            audio_file_box.change(lambda: [None] * 2, outputs=[rec_box, yt_url_box])
+            yt_url_box.input(lambda: [None] * 2, outputs=[rec_box, audio_file_box])
+
+            speed_box = gr.Slider(label='Playback speed', minimum=0, maximum=2, value=1)
+            with gr.Accordion('Fine-grained settings', open=False):
+                with gr.Row():
+                    gr.Textbox(label='', value='Trim audio sample', interactive=False)
+                    start_time_box = gr.Number(label='Start', value=0)
+                    end_time_box = gr.Number(label='End', value=20)
+
+                model_box = gr.Dropdown(label='Model weight', choices=available_models, value=available_models[0])
+
+            submit_btn = gr.Button('Submit')
+
+        with gr.Column():
+            with gr.Accordion('Hint', open=False):
+                gr.Markdown('You can find more settings under **Fine-grained settings**.')
+                gr.Markdown('- Feeling slow? Try adjusting the start/end timestamps.')
+                gr.Markdown('- Low audio quality? Try switching to a more suitable model weight.')
+            outputs = gr.Audio(label='Output')
+
+    submit_btn.click(fn=run, inputs=[
+        rec_box,
+        audio_file_box,
+        yt_url_box,
+        speed_box,
+        model_box,
+        start_time_box,
+        end_time_box,
+    ], outputs=outputs)
+
+    with gr.Accordion('Read more ...', open=False):
+        gr.Markdown('---')
+        gr.Markdown(
+            'We propose a novel approach to time-scale modification '
+            'of audio signals. Traditional methods rely on the framing technique, '
+            'while spectral approaches use the short-time Fourier transform to '
+            'preserve frequency during temporal stretching. TSM-Net, our '
+            'neural-network model, encodes the raw audio into a high-level latent '
+            'representation we call the Neuralgram, in which one vector represents '
+            '1024 audio samples. It is inspired by the framing technique but '
+            'addresses the clipping artifacts. Since the Neuralgram is a '
+            'two-dimensional matrix of real values, we can apply existing '
+            'image-resizing techniques to it and decode the result with our neural '
+            'decoder to obtain the time-scaled audio. Both the encoder and decoder '
+            'are trained with GANs, which show fair generalization ability on the '
+            'scaled Neuralgrams. Our method yields few artifacts and opens a new '
+            'possibility in the research of modern time-scale modification. '
+            'Please find more details in our '
+            '<a href="https://arxiv.org/abs/2210.17152" target="_blank">paper</a>.'
+        )
+
+demo.queue(4)
+demo.launch(server_name='0.0.0.0')
+
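The YouTube tab above is an acknowledged stub. Since requirements.txt below already ships yt-dlp, a download helper along these lines could presumably back prepare_audio_file one day; download_youtube_audio and the output name are hypothetical, and only documented yt-dlp options are used:

import yt_dlp

def download_youtube_audio(url, out='yt-audio'):
    # Hypothetical helper: grab the best audio-only stream and let ffmpeg
    # (listed in packages.txt) extract it to wav.
    opts = {
        'format': 'bestaudio/best',
        'outtmpl': out + '.%(ext)s',
        'postprocessors': [{'key': 'FFmpegExtractAudio', 'preferredcodec': 'wav'}],
    }
    with yt_dlp.YoutubeDL(opts) as ydl:
        ydl.download([url])
    return out + '.wav'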
packages.txt ADDED
@@ -0,0 +1 @@
+ffmpeg
requirements.txt ADDED
@@ -0,0 +1,7 @@
+./tsmnet
+torch
+torchvision
+torchaudio
+yt-dlp
+wget
+
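Note: ./tsmnet is a local path dependency, so pip builds the bundled package from tsmnet/setup.py below; that is what makes "from tsmnet import Stretcher" in app.py importable. yt-dlp and wget appear to be staged for the still-unfinished YouTube tab.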
tsmnet/setup.py ADDED
@@ -0,0 +1,7 @@
+from setuptools import setup
+
+setup(
+    name='tsmnet',
+    version='1.0.0',
+    packages=['tsmnet'],
+)
tsmnet/tsmnet/__init__.py ADDED
@@ -0,0 +1 @@
+from tsmnet.interface import load_model, Neuralgram, Stretcher
tsmnet/tsmnet/dataset.py ADDED
@@ -0,0 +1,79 @@
+import torch
+import torch.utils.data
+import torch.nn.functional as F
+import torchaudio
+
+
+from pathlib import Path
+import numpy as np
+import random
+
+
+def files_to_list(filename):
+    """
+    Takes a text file of filenames and makes a list of filenames
+    """
+    with open(filename, encoding="utf-8") as f:
+        files = f.readlines()
+
+    files = [f.rstrip() for f in files]
+    return files
+
+
+class AudioDataset(torch.utils.data.Dataset):
+    """
+    Loads audio files and returns fixed-length, randomly cropped segments.
+    """
+
+    def __init__(self, training_files, segment_length, sampling_rate, augment=True):
+        self.sampling_rate = sampling_rate
+        self.segment_length = segment_length
+        self.audio_files = files_to_list(training_files)
+        self.audio_files = [Path(training_files).parent / x for x in self.audio_files]
+        random.seed(1234)
+        random.shuffle(self.audio_files)
+        self.augment = augment
+
+    def __getitem__(self, index):
+        # Read audio
+        filename = self.audio_files[index]
+        try:
+            audio, sampling_rate = self.load_wav_to_torch(filename)
+        except RuntimeError:
+            # there are lots of corrupted files in FMA
+            print(f'Found corrupted file: {filename}, using empty data instead')
+            audio = torch.tensor([])
+        # Take a random segment, zero-padding if the file is too short
+        if audio.size(0) >= self.segment_length:
+            max_audio_start = audio.size(0) - self.segment_length
+            audio_start = random.randint(0, max_audio_start)
+            audio = audio[audio_start : audio_start + self.segment_length]
+        else:
+            audio = F.pad(
+                audio, (0, self.segment_length - audio.size(0)), "constant"
+            ).data
+
+        # audio = audio / 32768.0
+        return audio.unsqueeze(0)
+
+    def __len__(self):
+        return len(self.audio_files)
+
+    def load_wav_to_torch(self, full_path):
+        """
+        Loads audio into a torch array, resampled to self.sampling_rate
+        """
+        data, sampling_rate = torchaudio.load(str(full_path))
+        data = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=self.sampling_rate)(data)
+        sampling_rate = self.sampling_rate
+
+        if len(data.shape) > 1:
+            # convert to mono by picking a random channel
+            data = data[random.randint(0, data.shape[0]-1)]
+
+        if self.augment:
+            # random amplitude augmentation
+            amplitude = np.random.uniform(low=0.3, high=1.0)
+            data = data * amplitude
+
+        return data.float(), sampling_rate
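For reference, a short sketch of how this dataset would be consumed; train_files.txt is a hypothetical file list (one wav path per line, relative to the list's own directory), and the numbers mirror weights/args.yml:

from torch.utils.data import DataLoader
from tsmnet.dataset import AudioDataset

dataset = AudioDataset('train_files.txt', segment_length=8192, sampling_rate=22050)
loader = DataLoader(dataset, batch_size=2, shuffle=True)   # batch_size from args.yml
for batch in loader:
    print(batch.shape)   # torch.Size([2, 1, 8192]): a batch of mono segments
    break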
tsmnet/tsmnet/interface.py ADDED
@@ -0,0 +1,81 @@
+from tsmnet.modules import Autoencoder
+
+from torchvision.transforms.functional import resize
+from torchvision.transforms import InterpolationMode
+from pathlib import Path
+import yaml
+import torch
+import os
+
+
+def get_default_device():
+    if torch.cuda.is_available():
+        return "cuda"
+    else:
+        return "cpu"
+
+
+def load_model(path, device=get_default_device()):
+    """
+    Args:
+        path (str or Path): path to an autoencoder checkpoint; an args.yml
+            with the training hyperparameters must sit in the same folder
+        device (str or torch.device): device to load the model
+    """
+    with open(os.path.join(os.path.dirname(path), "args.yml"), "r") as f:
+        args = yaml.unsafe_load(f)
+    netA = Autoencoder([int(n) for n in args.compress_ratios], args.ngf, args.n_residual_layers).to(device)
+    netA.load_state_dict(torch.load(path, map_location=device))
+    return netA
+
+
+class Neuralgram:
+    def __init__(
+        self,
+        path,
+        device=None,
+    ):
+        if device is None:
+            device = get_default_device()
+        self.device = device
+        self.netA = load_model(path, device)
+
+    def __call__(self, audio):
+        """
+        Performs audio to Neuralgram conversion (see Autoencoder.encoder in tsmnet/modules.py)
+        Args:
+            audio (torch.tensor): PyTorch tensor containing audio (batch_size, timesteps)
+        Returns:
+            torch.tensor: Neuralgram computed on input audio (batch_size, channels, timesteps)
+        """
+        with torch.no_grad():
+            return self.netA.encoder(torch.as_tensor(audio).unsqueeze(1).to(self.device))
+
+    def inverse(self, neu):
+        """
+        Performs Neuralgram to audio conversion
+        Args:
+            neu (torch.tensor): PyTorch tensor containing Neuralgram (batch_size, channels, timesteps)
+        Returns:
+            torch.tensor: Inverted raw audio (batch_size, timesteps)
+        """
+        with torch.no_grad():
+            return self.netA.decoder(neu.to(self.device)).squeeze(1)
+
+
+class Stretcher:
+    def __init__(self, path, device=None):
+        self.neuralgram = Neuralgram(path, device)
+
+    @torch.no_grad()
+    def __call__(self, audio, rate, interpolation=InterpolationMode.NEAREST):  # NEAREST | BILINEAR | BICUBIC
+        if rate == 1:
+            return audio.numpy() if isinstance(audio, torch.Tensor) else audio
+        neu = self.neuralgram(audio)
+        # Resize the Neuralgram along time; rate > 1 shortens, rate < 1 lengthens
+        neu_resized = resize(
+            neu,
+            (*neu.shape[1:-1], int(neu.shape[-1] * (1 / rate))),
+            interpolation
+        )
+        return self.neuralgram.inverse(neu_resized)
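For reference, a minimal offline sketch of the interface above; the wav file names are placeholders, and the calls mirror what app.py does:

import torch
import torchaudio
from tsmnet import Stretcher

stretcher = Stretcher('./weights/general.pt')   # also reads weights/args.yml next to it
x, sr = torchaudio.load('input.wav')            # x: (channels, samples)
x = torchaudio.transforms.Resample(orig_freq=sr, new_freq=22050)(x)
y = stretcher(x, 1.25)                          # rate > 1 plays faster at the same pitch
# rate == 1 returns numpy, so normalize to a tensor before saving
torchaudio.save('output.wav', torch.as_tensor(y).cpu(), 22050)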
tsmnet/tsmnet/modules.py ADDED
@@ -0,0 +1,186 @@
+import torch.nn as nn
+import torch.nn.functional as F
+import torch
+from torch.nn.utils import weight_norm
+import numpy as np
+
+
+def weights_init(m):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(0.0, 0.02)
+    elif classname.find("BatchNorm2d") != -1:
+        m.weight.data.normal_(1.0, 0.02)
+        m.bias.data.fill_(0)
+
+
+def WNConv1d(*args, **kwargs):
+    return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+class ResnetBlock(nn.Module):
+    def __init__(self, dim, dilation=1):
+        super().__init__()
+        self.block = nn.Sequential(
+            nn.Tanh(),
+            nn.ReflectionPad1d(dilation),
+            WNConv1d(dim, dim, kernel_size=3, dilation=dilation),
+            nn.Tanh(),
+            WNConv1d(dim, dim, kernel_size=1),
+        )
+        self.shortcut = WNConv1d(dim, dim, kernel_size=1)
+
+    def forward(self, x):
+        return self.shortcut(x) + self.block(x)
+
+
+class Autoencoder(nn.Module):
+    def __init__(self, compress_ratios, ngf, n_residual_layers):
+        super().__init__()
+
+        self.encoder = self.makeEncoder(compress_ratios, ngf, n_residual_layers)
+        self.decoder = self.makeDecoder([r for r in reversed(compress_ratios)], ngf, n_residual_layers)
+
+        self.apply(weights_init)
+
+    def makeEncoder(self, ratios, ngf, n_residual_layers):
+        mult = 1
+
+        model = [
+            nn.ReflectionPad1d(3),
+            WNConv1d(1, ngf, kernel_size=7, padding=0),
+            nn.Tanh(),
+        ]
+
+        # Downsample to Neuralgram scale
+        for i, r in enumerate(ratios):
+            mult *= 2
+
+            for j in range(n_residual_layers - 1, -1, -1):
+                model += [ResnetBlock(mult * ngf // 2, dilation=3 ** j)]
+
+            model += [
+                nn.Tanh(),
+                WNConv1d(
+                    mult * ngf // 2,
+                    mult * ngf,
+                    kernel_size=r * 2,
+                    stride=r,
+                    padding=r // 2 + r % 2
+                ),
+            ]
+
+        model += [nn.Tanh()]
+
+        return nn.Sequential(*model)
+
+    def makeDecoder(self, ratios, ngf, n_residual_layers):
+        mult = int(2 ** len(ratios))
+
+        model = []
+
+        # Upsample back to raw-audio scale
+        for i, r in enumerate(ratios):
+            model += [
+                nn.Tanh(),
+                WNConvTranspose1d(
+                    mult * ngf,
+                    mult * ngf // 2,
+                    kernel_size=r * 2,
+                    stride=r,
+                    padding=r // 2 + r % 2,
+                    output_padding=r % 2
+                ),
+            ]
+
+            for j in range(n_residual_layers):
+                model += [ResnetBlock(mult * ngf // 2, dilation=3 ** j)]
+
+            mult //= 2
+
+        model += [
+            nn.Tanh(),
+            nn.ReflectionPad1d(3),
+            WNConv1d(ngf, 1, kernel_size=7, padding=0),
+            nn.Tanh(),
+        ]
+
+        return nn.Sequential(*model)
+
+    def forward(self, x):
+        return self.decoder(self.encoder(x))
+
+
+class NLayerDiscriminator(nn.Module):
+    def __init__(self, ndf, n_layers, downsampling_factor):
+        super().__init__()
+        model = nn.ModuleDict()
+
+        model["layer_0"] = nn.Sequential(
+            nn.ReflectionPad1d(7),
+            WNConv1d(1, ndf, kernel_size=15),
+            nn.Tanh(),
+        )
+
+        nf = ndf
+        stride = downsampling_factor
+        for n in range(1, n_layers + 1):
+            nf_prev = nf
+            nf = min(nf * stride, 1024)
+
+            model["layer_%d" % n] = nn.Sequential(
+                WNConv1d(
+                    nf_prev,
+                    nf,
+                    kernel_size=stride * 10 + 1,
+                    stride=stride,
+                    padding=stride * 5,
+                    groups=nf_prev // 4,
+                ),
+                nn.Tanh(),
+            )
+
+        nf_prev = nf  # keep channel counts consistent after the loop
+        nf = min(nf * 2, 1024)
+        model["layer_%d" % (n_layers + 1)] = nn.Sequential(
+            WNConv1d(nf_prev, nf, kernel_size=5, stride=1, padding=2),
+            nn.Tanh(),
+        )
+
+        model["layer_%d" % (n_layers + 2)] = WNConv1d(
+            nf, 1, kernel_size=3, stride=1, padding=1
+        )
+
+        self.model = model
+
+    def forward(self, x):
+        results = []
+        for key, layer in self.model.items():
+            x = layer(x)
+            results.append(x)
+        return results
+
+
+class Discriminator(nn.Module):
+    def __init__(self, num_D, ndf, n_layers, downsampling_factor):
+        super().__init__()
+        self.model = nn.ModuleDict()
+        for i in range(num_D):
+            self.model[f"disc_{i}"] = NLayerDiscriminator(
+                ndf, n_layers, downsampling_factor
+            )
+
+        # Each sub-discriminator sees the audio at half the previous resolution
+        self.downsample = nn.AvgPool1d(4, stride=2, padding=1, count_include_pad=False)
+        self.apply(weights_init)
+
+    def forward(self, x):
+        results = []
+        for key, disc in self.model.items():
+            results.append(disc(x))
+            x = self.downsample(x)
+        return results
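As a sanity check on the architecture, a small sketch assuming the hyperparameters shipped in weights/args.yml (compress_ratios '22488', ngf 32, n_residual_layers 1). The strides multiply to 2*2*4*8*8 = 1024, so one Neuralgram vector covers 1024 samples:

import torch
from tsmnet.modules import Autoencoder

netA = Autoencoder([2, 2, 4, 8, 8], ngf=32, n_residual_layers=1)
x = torch.randn(1, 1, 8192)   # (batch, 1, samples); seq_len from args.yml
neu = netA.encoder(x)
print(neu.shape)              # torch.Size([1, 1024, 8]): 1024x shorter in time
y = netA(x)                   # full roundtrip through encoder + decoder
print(y.shape)                # torch.Size([1, 1, 8192]): original length restored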
tsmnet/tsmnet/utils.py ADDED
@@ -0,0 +1,13 @@
+import scipy.io.wavfile
+
+
+def save_sample(file_path, sampling_rate, audio):
+    """Helper function to save a sample
+
+    Args:
+        file_path (str or pathlib.Path): save file path
+        sampling_rate (int): sampling rate of audio (usually 22050)
+        audio (torch.FloatTensor): torch array containing audio in [-1, 1]
+    """
+    audio = (audio.numpy() * 32768).astype("int16")
+    scipy.io.wavfile.write(file_path, sampling_rate, audio)
weights/args.yml ADDED
@@ -0,0 +1,24 @@
+!!python/object:argparse.Namespace
+batch_size: 2
+compress_ratios: '22488'
+cond_disc: false
+data_path: !!python/object/apply:pathlib.PosixPath
+- /
+- home
+- b073040018
+- Datasets
+downsamp_factor: 4
+epochs: 3000
+lambda_feat: 10
+load_path: logs-all/weights
+log_interval: 100
+n_layers_D: 4
+n_residual_layers: 1
+n_test_samples: 8
+ndf: 16
+ngf: 32
+num_D: 3
+project: tsmnet-all
+save_interval: 1000
+save_path: logs-all2
+seq_len: 8192
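For orientation: load_model in tsmnet/tsmnet/interface.py parses compress_ratios digit by digit, so the string '22488' becomes the ratio list [2, 2, 4, 8, 8]; the product 2*2*4*8*8 = 1024 is exactly the "one vector represents 1024 audio samples" figure quoted in app.py. seq_len 8192 is the training segment length consumed by tsmnet/tsmnet/dataset.py.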
weights/classical-music.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c736e5c7414354ad2789b4d8dd6d3ab2d5813f52fb0982818a4fff8887d2eeba
+size 100400811
weights/general.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e70b0ca672ab2008da3517ae3eb524135a1ef5685d59cc034084316a665f69f6
+size 100400920
weights/pop-music.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f3010d34e0d538ecb4c63c8bc89ad4023630dc36e2746bb71b799026d2b03ad4
+size 100400898
weights/speech.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e29674ce2312e1ba8f9071348de84031e8afbb08412cbc8088b7365f2162f497
+size 100400879