# Gradio "Reconstruction" tab: encode/decode NSynth STFT batches through the VAE
# and compare original vs. reconstructed spectrograms, phases and audio.
import gradio as gr

from data_generation.nsynth import get_nsynth_dataloader
from webUI.natural_language_guided_STFT.utils import (
    InputBatch2Encode_STFT,
    encodeBatch2GradioOutput_STFT,
    latent_representation_to_Gradio_image,
)
def get_recSTFT_module(gradioWebUI, reconstruction_state):
    """Build the "Reconstruction" tab of the Gradio web UI.

    The tab draws a batch of STFT samples from an NSynth split, pushes it
    through the VAE encoder / quantizer / decoder, and displays original vs.
    reconstructed spectrograms, phases and audio (with and without the
    original amplitude, "WOA").

    Args:
        gradioWebUI: Shared configuration object providing the VAE modules
            (encoder, quantizer, decoder), STFT resolutions, `squared` flag
            and audio sample rate.
        reconstruction_state: ``gr.State`` threaded through both event
            handlers (used as an encode-cache slot).
    """
    # Pull only the configuration this tab actually uses.
    VAE_scale = gradioWebUI.VAE_scale
    # Latent width; decoder output resolution is rebuilt as width * VAE_scale.
    width = int(gradioWebUI.time_resolution / VAE_scale)
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_encoder = gradioWebUI.VAE_encoder
    VAE_decoder = gradioWebUI.VAE_decoder
    squared = gradioWebUI.squared
    sample_rate = gradioWebUI.sample_rate

    # Radio choice -> NSynth split name used to build the HDF5 path.
    _SOURCE_TO_SPLIT = {
        "text2sound_trainSTFT": "train",
        "text2sound_validSTFT": "valid",
        "text2sound_testSTFT": "test",
    }

    def generate_reconstruction_samples(sample_source, batchsize_slider, encodeCache,
                                        reconstruction_samples):
        """Encode/decode a fresh batch and populate every output widget.

        Also caches all per-sample results in `reconstruction_samples` so the
        sample-index slider can re-display any item without re-running the VAE.
        """
        vae_batchsize = int(batchsize_slider)

        try:
            split = _SOURCE_TO_SPLIT[sample_source]
        except KeyError:
            # "synthetic" / "external" sources are not wired up yet.
            raise NotImplementedError()
        training_dataset_path = f'data/NSynth/nsynth-STFT-{split}-52.hdf5'  # Make sure to use your actual path
        iterator = get_nsynth_dataloader(training_dataset_path, batch_size=vae_batchsize, shuffle=True,
                                         get_latent_representation=False, with_meta_data=False,
                                         task="STFT")
        spectrogram_batch = next(iter(iterator))

        (origin_flipped_log_spectrums, origin_flipped_phases, origin_signals,
         latent_representations, quantized_latent_representations) = InputBatch2Encode_STFT(
            VAE_encoder, spectrogram_batch, resolution=(512, width * VAE_scale),
            quantizer=VAE_quantizer, squared=squared)

        # Fall back to the raw latents BEFORE indexing into the quantized ones,
        # so a quantizer-less VAE cannot crash the image loop below.
        if quantized_latent_representations is None:
            quantized_latent_representations = latent_representations

        latent_representation_gradio_images = []
        quantized_latent_representation_gradio_images = []
        for i in range(vae_batchsize):
            latent_representation_gradio_images.append(
                latent_representation_to_Gradio_image(latent_representations[i]))
            quantized_latent_representation_gradio_images.append(
                latent_representation_to_Gradio_image(quantized_latent_representations[i]))

        (reconstruction_flipped_log_spectrums, reconstruction_flipped_phases,
         reconstruction_signals, reconstruction_flipped_log_spectrums_WOA,
         reconstruction_flipped_phases_WOA, reconstruction_signals_WOA) = encodeBatch2GradioOutput_STFT(
            VAE_decoder,
            quantized_latent_representations,
            resolution=(512, width * VAE_scale),
            original_STFT_batch=spectrogram_batch)

        # Cache the whole batch for show_reconstruction_sample.
        reconstruction_samples["origin_flipped_log_spectrums"] = origin_flipped_log_spectrums
        reconstruction_samples["origin_flipped_phases"] = origin_flipped_phases
        reconstruction_samples["origin_signals"] = origin_signals
        reconstruction_samples["latent_representation_gradio_images"] = latent_representation_gradio_images
        reconstruction_samples["quantized_latent_representation_gradio_images"] = \
            quantized_latent_representation_gradio_images
        reconstruction_samples["reconstruction_flipped_log_spectrums"] = reconstruction_flipped_log_spectrums
        reconstruction_samples["reconstruction_flipped_phases"] = reconstruction_flipped_phases
        reconstruction_samples["reconstruction_signals"] = reconstruction_signals
        reconstruction_samples["reconstruction_flipped_log_spectrums_WOA"] = reconstruction_flipped_log_spectrums_WOA
        reconstruction_samples["reconstruction_flipped_phases_WOA"] = reconstruction_flipped_phases_WOA
        reconstruction_samples["reconstruction_signals_WOA"] = reconstruction_signals_WOA
        reconstruction_samples["sampleRate"] = sample_rate

        # Show the first sample of the batch and reveal the index slider,
        # resized to the new batch size.
        return {origin_amplitude_image_output: origin_flipped_log_spectrums[0],
                origin_phase_image_output: origin_flipped_phases[0],
                origin_audio_output: (sample_rate, origin_signals[0]),
                latent_representation_image_output: latent_representation_gradio_images[0],
                quantized_latent_representation_image_output: quantized_latent_representation_gradio_images[0],
                reconstruction_amplitude_image_output: reconstruction_flipped_log_spectrums[0],
                reconstruction_phase_image_output: reconstruction_flipped_phases[0],
                reconstruction_audio_output: (sample_rate, reconstruction_signals[0]),
                reconstruction_amplitude_image_output_WOA: reconstruction_flipped_log_spectrums_WOA[0],
                reconstruction_phase_image_output_WOA: reconstruction_flipped_phases_WOA[0],
                reconstruction_audio_output_WOA: (sample_rate, reconstruction_signals_WOA[0]),
                sample_index_slider: gr.update(minimum=0, maximum=vae_batchsize - 1, value=0, step=1.0,
                                               label="Sample index.",
                                               info="Slide to view other samples", scale=1, visible=True),
                reconstruction_state: encodeCache,
                reconstruction_samples_state: reconstruction_samples}

    def show_reconstruction_sample(sample_index, encodeCache_state, reconstruction_samples_state):
        """Re-display sample `sample_index` from the cached batch (no VAE re-run)."""
        idx = int(sample_index)
        samples = reconstruction_samples_state
        sr = samples["sampleRate"]
        # Order must match the `outputs=` list of sample_index_slider.change below.
        return (samples["origin_flipped_log_spectrums"][idx],
                samples["origin_flipped_phases"][idx],
                (sr, samples["origin_signals"][idx]),
                samples["latent_representation_gradio_images"][idx],
                samples["quantized_latent_representation_gradio_images"][idx],
                samples["reconstruction_flipped_log_spectrums"][idx],
                samples["reconstruction_flipped_phases"][idx],
                (sr, samples["reconstruction_signals"][idx]),
                samples["reconstruction_flipped_log_spectrums_WOA"][idx],
                samples["reconstruction_flipped_phases_WOA"][idx],
                (sr, samples["reconstruction_signals_WOA"][idx]),
                encodeCache_state,
                reconstruction_samples_state)

    with gr.Tab("Reconstruction"):
        # Per-session cache of the last generated batch (filled by the handlers above).
        reconstruction_samples_state = gr.State(value={})
        gr.Markdown("Test reconstruction.")
        with gr.Row(variant="panel"):
            with gr.Column():
                # Default fixed: it must be one of `choices`
                # (was the typo "text2sound_trainf").
                sample_source_radio = gr.Radio(
                    choices=["synthetic", "external", "text2sound_trainSTFT", "text2sound_testSTFT", "text2sound_validSTFT"],
                    value="text2sound_trainSTFT", info="Info placeholder", scale=2)
                batchsize_slider = gr.Slider(minimum=1., maximum=16., value=4., step=1.,
                                             label="batchsize")
            with gr.Column():
                generate_button = gr.Button(variant="primary", value="Generate reconstruction samples", scale=1)
                # Hidden until a batch exists; generate_reconstruction_samples
                # re-ranges and reveals it.
                sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, label="Sample index.",
                                                info="Slide to view other samples", scale=1, visible=False)
        with gr.Row(variant="panel"):
            with gr.Column():
                origin_amplitude_image_output = gr.Image(label="Spectrogram", type="numpy", height=300, width=100, scale=1)
                origin_phase_image_output = gr.Image(label="Phase", type="numpy", height=300, width=100, scale=1)
                origin_audio_output = gr.Audio(type="numpy", label="Play the example!")
            with gr.Column():
                reconstruction_amplitude_image_output = gr.Image(label="Spectrogram", type="numpy", height=300, width=100, scale=1)
                reconstruction_phase_image_output = gr.Image(label="Phase", type="numpy", height=300, width=100, scale=1)
                reconstruction_audio_output = gr.Audio(type="numpy", label="Play the example!")
            with gr.Column():
                reconstruction_amplitude_image_output_WOA = gr.Image(label="Spectrogram", type="numpy", height=300, width=100, scale=1)
                reconstruction_phase_image_output_WOA = gr.Image(label="Phase", type="numpy", height=300, width=100, scale=1)
                reconstruction_audio_output_WOA = gr.Audio(type="numpy", label="Play the example!")
        with gr.Row(variant="panel", equal_height=True):
            latent_representation_image_output = gr.Image(label="latent_representation", type="numpy", height=300, width=100)
            quantized_latent_representation_image_output = gr.Image(label="quantized", type="numpy", height=300, width=100)

        generate_button.click(generate_reconstruction_samples,
                              inputs=[sample_source_radio, batchsize_slider, reconstruction_state,
                                      reconstruction_samples_state],
                              outputs=[origin_amplitude_image_output, origin_phase_image_output, origin_audio_output,
                                       latent_representation_image_output, quantized_latent_representation_image_output,
                                       reconstruction_amplitude_image_output, reconstruction_phase_image_output, reconstruction_audio_output,
                                       reconstruction_amplitude_image_output_WOA, reconstruction_phase_image_output_WOA, reconstruction_audio_output_WOA,
                                       sample_index_slider, reconstruction_state, reconstruction_samples_state])

        sample_index_slider.change(show_reconstruction_sample,
                                   inputs=[sample_index_slider, reconstruction_state, reconstruction_samples_state],
                                   outputs=[origin_amplitude_image_output, origin_phase_image_output, origin_audio_output,
                                            latent_representation_image_output, quantized_latent_representation_image_output,
                                            reconstruction_amplitude_image_output, reconstruction_phase_image_output, reconstruction_audio_output,
                                            reconstruction_amplitude_image_output_WOA, reconstruction_phase_image_output_WOA, reconstruction_audio_output_WOA,
                                            reconstruction_state, reconstruction_samples_state])