cwitkowitz commited on
Commit
c1112a0
1 Parent(s): 106218a

Updated app in accordance with timbre-trap updates and chunk-based processing.

Browse files
Files changed (3) hide show
  1. app.py +22 -16
  2. tt-demo.pt → models/tt-orig.pt +2 -2
  3. requirements.txt +1 -1
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from pyharp import ModelCard, build_endpoint
2
 
3
  import gradio as gr
@@ -5,7 +6,19 @@ import torchaudio
5
  import torch
6
  import os
7
 
8
- timbre_trap = torch.load('tt-demo.pt', map_location='cpu')
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  card = ModelCard(
11
  name='Timbre-Trap',
@@ -26,28 +39,20 @@ def process_fn(audio_path, de_timbre):
26
  audio = audio.unsqueeze(0)
27
  # Determine original number of samples
28
  n_samples = audio.size(-1)
29
- # Pad audio to next multiple of block length
30
- audio = timbre_trap.sliCQ.pad_to_block_length(audio)
31
 
32
- # Encode raw audio into latent vectors
33
- latents, embeddings, _ = timbre_trap.encode(audio)
34
- # Apply skip connections if they are turned on
35
- embeddings = timbre_trap.apply_skip_connections(embeddings)
36
  # Obtain transcription or reconstructed spectral coefficients
37
- coefficients = timbre_trap.decode(latents, embeddings, de_timbre)
 
38
 
39
- # Invert reconstructed spectral coefficients
40
- audio = timbre_trap.sliCQ.decode(coefficients)
41
  # Trim to original number of samples
42
  audio = audio[..., :n_samples]
43
  # Remove batch dimension
44
  audio = audio.squeeze(0)
45
 
46
- if de_timbre and audio.abs().max():
47
- # Low-pass filter the audio to remove ringing
48
- audio = torchaudio.functional.lowpass_biquad(audio, 22050, 8000)
49
- # Normalize audio to [-1, 1]
50
- audio /= audio.abs().max()
51
 
52
  # Resample audio back to the original sampling rate
53
  audio = torchaudio.functional.resample(audio, 22050, fs)
@@ -62,6 +67,7 @@ def process_fn(audio_path, de_timbre):
62
  return save_path
63
 
64
 
 
65
  with gr.Blocks() as demo:
66
  inputs = [
67
  gr.Audio(
@@ -81,8 +87,8 @@ with gr.Blocks() as demo:
81
  )
82
  ]
83
 
 
84
  output = gr.Audio(label='Audio Output', type='filepath')
85
-
86
  widgets = build_endpoint(inputs, output, process_fn, card)
87
 
88
  demo.queue()
 
1
+ from timbre_trap.framework.modules import TimbreTrap
2
  from pyharp import ModelCard, build_endpoint
3
 
4
  import gradio as gr
 
6
  import torch
7
  import os
8
 
9
+
10
+ model = TimbreTrap(sample_rate=22050,
11
+ n_octaves=9,
12
+ bins_per_octave=60,
13
+ secs_per_block=3,
14
+ latent_size=128,
15
+ model_complexity=2,
16
+ skip_connections=False)
17
+ model.eval()
18
+
19
+ model_path_orig = os.path.join('models', 'tt-orig.pt')
20
+ tt_weights_orig = torch.load(model_path_orig, map_location='cpu')
21
+ model.load_state_dict(tt_weights_orig)
22
 
23
  card = ModelCard(
24
  name='Timbre-Trap',
 
39
  audio = audio.unsqueeze(0)
40
  # Determine original number of samples
41
  n_samples = audio.size(-1)
 
 
42
 
 
 
 
 
43
  # Obtain transcription or reconstructed spectral coefficients
44
+ coefficients = model.chunked_inference(audio, de_timbre)
45
+ #coefficients = model.inference(audio, de_timbre)
46
 
47
+ # Invert coefficients to produce audio
48
+ audio = model.sliCQ.decode(coefficients)
49
  # Trim to original number of samples
50
  audio = audio[..., :n_samples]
51
  # Remove batch dimension
52
  audio = audio.squeeze(0)
53
 
54
+ # Low-pass filter the audio in attempt to remove artifacts
55
+ audio = torchaudio.functional.lowpass_biquad(audio, 22050, 8000)
 
 
 
56
 
57
  # Resample audio back to the original sampling rate
58
  audio = torchaudio.functional.resample(audio, 22050, fs)
 
67
  return save_path
68
 
69
 
70
+ # Build Gradio endpoint
71
  with gr.Blocks() as demo:
72
  inputs = [
73
  gr.Audio(
 
87
  )
88
  ]
89
 
90
+ # Build endpoint
91
  output = gr.Audio(label='Audio Output', type='filepath')
 
92
  widgets = build_endpoint(inputs, output, process_fn, card)
93
 
94
  demo.queue()
tt-demo.pt → models/tt-orig.pt RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2f4575c6642348eda3d2e7ff280eece5036e5922e0dacfd25e8dfeb10fd52842
3
- size 11399295
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c3bafd33a831d61e8ee9051d6c5b4c5d483e6a7669ca9df85ac6ab304cb9fe3
3
+ size 11353410
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  -e git+https://github.com/audacitorch/pyharp.git#egg=pyharp
2
- -e git+https://github.com/sony/timbre-trap.git@release#egg=timbre-trap
3
  torchaudio
4
  torch
5
  cqt_pytorch
 
1
  -e git+https://github.com/audacitorch/pyharp.git#egg=pyharp
2
+ -e git+https://github.com/sony/timbre-trap.git@updates#egg=timbre-trap
3
  torchaudio
4
  torch
5
  cqt_pytorch