File size: 10,483 Bytes
e6346a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7854be
97b6f36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98eb218
e6346a3
 
 
 
 
8505dc9
01188ff
 
98eb218
97b6f36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98eb218
 
 
 
 
97b6f36
98eb218
97b6f36
 
 
98eb218
 
97b6f36
 
 
98eb218
97b6f36
f7854be
8505dc9
98eb218
 
97b6f36
a16f46b
98eb218
97b6f36
98eb218
97b6f36
 
 
 
 
 
98eb218
 
97b6f36
a16f46b
98eb218
 
97b6f36
 
 
98eb218
 
 
 
 
 
 
 
 
 
 
 
e6346a3
98eb218
 
 
 
e6346a3
 
98eb218
 
97b6f36
8505dc9
98eb218
97b6f36
 
 
 
 
 
 
 
 
 
 
 
98eb218
97b6f36
98eb218
a16f46b
98eb218
 
 
 
 
 
 
 
 
 
 
 
97b6f36
 
a16f46b
98eb218
 
 
 
 
 
 
97b6f36
 
d607f42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97b6f36
 
 
 
d607f42
 
97b6f36
 
d607f42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98eb218
97b6f36
 
 
a16f46b
97b6f36
 
e6346a3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
import sys
import os
# Debug mode is switched on either by passing ``--debug`` on the command line
# or by setting the environment variable DEBUG to the exact string 'True'.
debug_mode = ('--debug' in sys.argv) or (os.environ.get('DEBUG') == 'True')

if not debug_mode:
    print("Running in normal mode. Using package from site-packages.")
else:
    # Put a local checkout of the package at the front of the import search
    # path so it shadows any installed copy. The path is relative, so it is
    # resolved against the current working directory.
    local_package_path = "../../GaMaDHaNi"
    sys.path.insert(0, local_package_path)
    print(f"Running in debug mode. Using package from: {local_package_path}")

import spaces
import gradio as gr
import numpy as np
import torch
import librosa
import matplotlib.pyplot as plt
import pandas as pd
from functools import partial
import gin
import torchaudio
from absl import app
from torch.nn.functional import interpolate
import logging
import crepe
from hmmlearn import hmm
import soundfile as sf
import pdb
from gamadhani.utils.generate_utils import load_pitch_fns, load_audio_fns
import gamadhani.utils.pitch_to_audio_utils as p2a
from gamadhani.utils.utils import get_device


# Reconfigure the root logger (force=True replaces any handlers installed by
# imported libraries) so INFO-level progress messages are shown.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True,
)
# Directories holding checkpoint, gin config and quantile transformer for the
# two model stages loaded further below.
pitch_path = 'models/diffusion_pitch/'   # pitch diffusion model
audio_path = 'models/pitch_to_audio/'    # pitch-to-audio model
device = get_device()

def predict_voicing(confidence):
    """
    Decode a voiced/unvoiced state for every frame from voicing confidence.

    A two-state Gaussian HMM with hand-fixed (not fitted) parameters is
    decoded along its Viterbi path. The strong self-transition probabilities
    make the decoded state change only rarely.
    Adapted from https://github.com/marl/crepe/pull/26

    Parameters
    ----------
    confidence : np.ndarray [shape=(N,)]
        per-frame voicing confidence, i.e. the confidence that a pitch is
        present in the frame

    Returns
    -------
    voicing_states : np.ndarray [shape=(N,)]
        decoded state per frame: 0 if unvoiced, 1 if voiced
    """
    model = hmm.GaussianHMM(n_components=2)

    # Parameters are set by hand because the model is never fitted.
    model.startprob_ = np.array([0.5, 0.5])           # uniform prior over states
    model.covars_ = np.array([[0.25], [0.25]])        # same variance for both states
    model.transmat_ = np.array([[0.99, 0.01],
                                [0.01, 0.99]])        # sticky: favours long runs
    model.means_ = np.array([[0.0], [1.0]])           # unvoiced ~ 0, voiced ~ 1
    model.n_features = 1

    observations = confidence.reshape(-1, 1)
    decoded = model.predict(observations, [len(confidence)])
    return np.array(decoded)

def extract_pitch(audio, unvoice=True, sr=16000, frame_shift_ms=10, log=True):
    """
    Estimate an f0 contour for ``audio`` with CREPE.

    Parameters
    ----------
    audio : np.ndarray
        input waveform
    unvoice : bool
        if True, frames that ``predict_voicing`` labels unvoiced are set to 0
    sr : int
        sampling rate of ``audio`` in Hz
    frame_shift_ms : int
        CREPE step size in milliseconds
    log : bool
        if True CREPE runs with verbose=1, otherwise verbose=0

    Returns
    -------
    (time, f0, confidence)
        frame times, frequency estimates (zeroed on unvoiced frames when
        ``unvoice`` is set) and CREPE's per-frame confidence
    """
    verbosity = 1 if log else 0
    time, frequency, confidence, _ = crepe.predict(
        audio,
        sr=sr,
        viterbi=True,
        step_size=frame_shift_ms,
        verbose=verbosity,
    )
    if not unvoice:
        return time, frequency, confidence

    voiced_mask = predict_voicing(confidence)
    return time, frequency * voiced_mask, confidence

def generate_pitch_reinterp(pitch, pitch_model, invert_pitch_fn, num_samples, num_steps, noise_std=0.4,
                            seq_len=1200, clip_value=5.19):
    '''Generate pitch values for the melodic reinterpretation task.

    The trailing ``seq_len`` frames of ``pitch`` are perturbed with Gaussian
    noise, clipped to ``[-clip_value, clip_value]`` and handed to
    ``pitch_model.sample_sdedit``.

    Args:
        pitch: array or tensor whose last axis is time (the app passes shape
            (1, 1, T)); only the last ``seq_len`` frames are used. Shorter
            inputs are used in full.
        pitch_model: object exposing ``device`` and
            ``sample_sdedit(x, num_samples, num_steps)``.
        invert_pitch_fn: callable taking ``f0=`` model-space pitch; its first
            returned element is kept (pitch values in Hz).
        num_samples: number of samples requested from the model.
        num_steps: number of sampling steps.
        noise_std: standard deviation of the added Gaussian noise.
        seq_len: number of trailing frames to reinterpret (default 1200).
        clip_value: symmetric clipping bound applied after adding noise
            (default 5.19, the range of the model).

    Returns:
        ``(samples, inverted_pitches)``: the raw model samples, and a
        one-element list holding ``invert_pitch_fn`` applied to the first
        sample.
    '''
    device = pitch_model.device
    segment = torch.as_tensor(pitch[..., -seq_len:], dtype=torch.float32).to(device)
    # Noise is drawn with the segment's own shape. This works for inputs
    # shorter than seq_len and gives every batch row independent noise.
    noise = torch.randn_like(segment) * noise_std
    noisy_pitch = torch.clamp(segment + noise, -clip_value, clip_value)
    samples = pitch_model.sample_sdedit(noisy_pitch, num_samples, num_steps)
    inverted_pitches = [invert_pitch_fn(f0=samples.detach().cpu().numpy()[0])[0]]   # pitch values in Hz

    return samples, inverted_pitches

def generate_audio(audio_model, f0s, invert_audio_fn, singers=(3,), num_steps=100, strength=3):
    '''Generate audio given pitch values.

    Args:
        audio_model: object exposing ``device`` and ``sample_cfg``.
        f0s: pitch conditioning; its first dimension is the batch size.
        invert_audio_fn: callable converting the model's samples to audio.
        singers: singer id(s). Each id is repeated ``f0s.shape[0]`` times via
            ``np.repeat``, so a single id yields one label per batch element.
            A tuple is used as the default to avoid a shared mutable default.
        num_steps: number of sampling steps.
        strength: value forwarded to ``sample_cfg`` as ``strength`` (default 3).

    Returns:
        The output of ``invert_audio_fn`` applied to the generated samples.
    '''
    batch_size = f0s.shape[0]
    singer_tensor = torch.tensor(np.repeat(singers, repeats=batch_size)).to(audio_model.device)
    samples, _, _ = audio_model.sample_cfg(batch_size, f0=f0s, num_steps=num_steps,
                                           singer=singer_tensor, strength=strength)
    return invert_audio_fn(samples)
    
@spaces.GPU(duration=150)
def generate(pitch, num_samples=1, num_steps=100, singers=(3,), outfolder='temp', audio_seq_len=750, pitch_qt=None):
    '''Reinterpret a pitch contour and synthesize audio for it.

    Uses the module-level ``pitch_model``/``invert_pitch_fn`` and
    ``audio_model``/``invert_audio_fn``.

    Args:
        pitch: model-space pitch passed to ``generate_pitch_reinterp``.
        num_samples: number of pitch samples to draw.
        num_steps: sampling steps, used for both the pitch and audio models.
        singers: singer id(s) forwarded to ``generate_audio``.
        outfolder: unused; kept for backward compatibility.
        audio_seq_len: target length passed to ``p2a.interpolate_pitch``.
        pitch_qt: optional quantile transformer. When given, the model output
            is mapped back with ``inverse_transform`` before synthesis.

    Returns:
        ``((16000, audio), figure, pitch_vals)``: the first generated waveform
        with its sample rate, a matplotlib figure of the generated contour, and
        the first contour (flattened) with 0 replaced by NaN.
    '''
    logging.log(logging.INFO, 'Generate function')
    logging.log(logging.INFO, 'Generating pitch')
    pitch, _ = generate_pitch_reinterp(pitch, pitch_model, invert_pitch_fn,
                                       num_samples=num_samples, num_steps=num_steps)
    if pitch_qt is not None:
        # A quantile transformer was supplied: undo the quantile transformation
        # so the contour is back in the pitch domain.
        def undo_qt(x, min_clip=200):
            values = pitch_qt.inverse_transform(x.reshape(-1, 1)).reshape(1, -1)
            values = np.around(values)  # round to nearest integer, done in preprocessing of pitch contour fed into model
            values[values < min_clip] = np.nan  # marked NaN; later replaced by the silent token
            return values
        pitch = torch.tensor(np.array([undo_qt(x) for x in pitch.detach().cpu().numpy()])).to(pitch_model.device)
    # interpolate pitch values to match the audio model's input size
    interpolated_pitch = p2a.interpolate_pitch(pitch=pitch, audio_seq_len=audio_seq_len)
    interpolated_pitch = torch.nan_to_num(interpolated_pitch, nan=196)  # replace nan values with silent token
    interpolated_pitch = interpolated_pitch.squeeze(1)  # remove the extra dimension to match the audio model's input
    logging.log(logging.INFO, 'Generating audio')
    audio = generate_audio(audio_model, interpolated_pitch, invert_audio_fn, singers=singers, num_steps=num_steps)
    audio = audio.detach().cpu().numpy()
    pitch = pitch.detach().cpu().numpy()

    # First generated contour, compared frame by frame. The previous code
    # compared only frame 0, so a leading zero blanked the whole contour.
    first_contour = pitch[0].flatten()
    pitch_vals = np.where(first_contour == 0, np.nan, first_contour)

    # generate plot of model output to display on interface
    model_output_plot, ax = plt.subplots()
    ax.plot(pitch_vals, label='Model Output')
    plt.close(model_output_plot)
    return (16000, audio[0]), model_output_plot, pitch_vals

# Pitch stage: model, quantile transformer, the task function used to prepare
# the user's contour, and the function mapping model output back to Hz.
pitch_model, pitch_qt, pitch_task_fn, invert_pitch_fn, _ = load_pitch_fns(
    os.path.join(pitch_path, 'last.ckpt'),
    model_type='diffusion',
    config_path=os.path.join(pitch_path, 'config.gin'),
    qt_path=os.path.join(pitch_path, 'qt.joblib'),
    device=device,
)

# Pitch-to-audio stage. NOTE(review): audio_qt and audio_seq_len appear unused
# below (generate() has its own audio_seq_len default) — confirm.
audio_model, audio_qt, audio_seq_len, invert_audio_fn = load_audio_fns(
    os.path.join(audio_path, 'last.ckpt'),
    qt_path=os.path.join(audio_path, 'qt.joblib'),
    config_path=os.path.join(audio_path, 'config.gin'),
    device=device,
)

# generate() with the demo's fixed settings bound; callers pass only the pitch.
partial_generate = partial(
    generate,
    num_samples=1,
    num_steps=100,
    singers=[3],
    outfolder=None,
    pitch_qt=pitch_qt,
)

@spaces.GPU(duration=150)
def set_guide_and_generate(audio):
    '''Run melodic reinterpretation on a recording supplied through the UI.

    Args:
        audio: ``(sample_rate, samples)`` tuple from ``gr.Audio``, or ``None``.

    Returns:
        ``(generated_audio, user_input_plot, generated_pitch_plot)``. This is one
        value per output component wired to the submit button; all three are
        ``None`` when no audio is given.
    '''
    if audio is None:
        return None, None, None
    sr, samples = audio
    samples = np.asarray(samples, dtype=np.float32)
    if samples.ndim > 1:
        # NOTE(review): assumes multi-channel input is laid out as
        # (num_samples, num_channels), as gr.Audio delivers it. Downmixing to
        # mono keeps padding and resampling on the time axis only.
        samples = samples.mean(axis=1)

    # pad short recordings to at least 12 s at the original sample rate
    min_len = 12 * sr
    if len(samples) < min_len:
        samples = np.pad(samples, (0, min_len - len(samples)), mode='constant')

    # peak-normalise; skip silent input to avoid 0/0 -> NaN
    peak = np.max(np.abs(samples))
    if peak > 0:
        samples = samples / peak
    samples = librosa.resample(samples, orig_sr=sr, target_sr=16000)
    samples = samples[-12 * 16000:]  # consider only the last 12 s
    _, f0, _ = extract_pitch(samples)
    mic_f0 = f0.copy()  # save the user input pitch values for plotting
    logging.log(logging.INFO, 'Pitch extracted')
    f0 = pitch_task_fn(**{
        'inputs': {
            'pitch': {
                'data': torch.Tensor(f0),  # task function expects a tensor
                'sampling_rate': 100,
            }
        },
        'qt_transform': pitch_qt,
        'time_downsample': 1,  # pitch is extracted at 100 Hz, thus no downsampling
        'seq_len': None,
    })['sampled_sequence']
    f0 = torch.as_tensor(f0.reshape(1, 1, -1)).to(pitch_model.device).float()
    logging.log(logging.INFO, 'Calling generate function')
    out_audio, out_pitch_plot, _ = partial_generate(f0)

    # extract_pitch zeroes unvoiced frames; NaN leaves gaps in the plot instead
    mic_f0 = np.where(mic_f0 == 0, np.nan, mic_f0)
    user_input_plot, ax = plt.subplots()
    ax.plot(np.arange(len(mic_f0)), mic_f0, label='User Input')
    plt.close(user_input_plot)
    return out_audio, user_input_plot, out_pitch_plot

# Gradio interface: an input audio column and a generated-audio column, each
# with a collapsible pitch plot. The submit button runs set_guide_and_generate,
# whose three return values map onto the three output components.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("""
                    # GaMaDHaNi: HIERARCHICAL GENERATIVE MODELING OF MELODIC VOCAL CONTOURS IN HINDUSTANI CLASSICAL MUSIC
                    :book: Read more about the project [here](https://arxiv.org/pdf/2408.12658) <br>
                    :samples: Listen to the samples [here](https://snnithya.github.io/gamadhani-samples) <br>
                    """)
        gr.Markdown("""
                    ## Instructions
                    In this demo you can interact with the model in two ways: 
                    1. **Call and response**: The model will try to continue the idea that you input. This is similar to "primed generation" discussed in the paper.
                    2. **Melodic reinterpretation**: Akin to the idea of "coarse pitch conditioning" presented in the paper, you can input a pitch contour and the model will generate audio that is similar to but not exactly the same. <br><br>
                    **Upload an audio file or record your voice to get started!**
                    """)
        gr.Markdown("""
                    This is still a work in progress, so please feel free to share any weird or interesting examples, we would love to hear them! Contact us at [snnithya@mit.edu](mailto:snnithya@mit.edu).
                    """)

    with gr.Row():
        with gr.Column():
            audio = gr.Audio(label="Input")
            sbmt = gr.Button()
            with gr.Accordion("View Pitch Plot"):
                user_input = gr.Plot(label="User Input")
        with gr.Column():
            generated_audio = gr.Audio(label="Generated Audio")
            with gr.Accordion("View Pitch Plot"):
                generated_pitch = gr.Plot(label="Generated Pitch")
    example_description = gr.Textbox(label="Example Description", interactive=False)
    examples = gr.Examples(
        examples=[
            ["examples/ex1.wav"],
            ["examples/ex2.wav"],
            ["examples/ex3.wav"],
            ["examples/ex4.wav"],
            ["examples/ex5.wav"]
            # Add more examples as needed
        ],
        inputs=audio
    )

    sbmt.click(set_guide_and_generate, inputs=[audio], outputs=[generated_audio, user_input, generated_pitch])

def main(argv):
    """Launch the Gradio demo.

    ``argv`` is accepted so this can serve as a command-line entry point, but
    it is not used.
    """
    del argv  # unused
    demo.launch()


if __name__ == '__main__':
    main(sys.argv)