Spaces:
Sleeping
Sleeping
import librosa | |
import numpy as np | |
import torch | |
import gradio as gr | |
from scipy.ndimage import zoom | |
from model.DiffSynthSampler import DiffSynthSampler | |
from tools import adjust_audio_length, safe_int, pad_STFT, encode_stft | |
from webUI.natural_language_guided_4.utils import latent_representation_to_Gradio_image, InputBatch2Encode_STFT, \ | |
encodeBatch2GradioOutput_STFT, add_instrument, average_np_arrays | |
def get_triangle_mask(height, width, slope=8 / 3):
    """Return a binary triangular mask of shape ``(height, width)``.

    Element ``[i, j]`` is ``1.0`` when the row index ``i`` is strictly less than
    ``slope * j`` and ``0.0`` otherwise. Column 0 is therefore always zero, and
    larger column indices switch on progressively more rows.

    Args:
        height: Number of rows.
        width: Number of columns.
        slope: Rows per column of the separating line. Defaults to 8/3, the
            value that was previously hard-coded.

    Returns:
        ``np.ndarray`` of dtype float64 and shape ``(height, width)``.
    """
    rows = np.arange(height)[:, np.newaxis]
    cols = np.arange(width)[np.newaxis, :]
    # Broadcast comparison: a vectorised equivalent of the former double Python loop.
    return (rows < slope * cols).astype(np.float64)
def get_inpaint_with_text_module(gradioWebUI, inpaintWithText_state, virtual_instruments_state):
    """Build the "Inpaint" tab: text-guided inpainting of an uploaded or recorded sound.

    The user uploads or records a source sound. Its STFT is encoded into a VAE
    latent. The user then marks a region by painting on the spectrogram editor
    and/or with the time/frequency prototype sliders. The diffusion sampler
    regenerates that region (or its complement), conditioned on a CLAP text
    prompt.

    Args:
        gradioWebUI: Shared web-UI context. Models, resolutions and the slider
            factories used below are read from its attributes.
        inpaintWithText_state: ``gr.State`` holding this tab's per-session dict
            (encoded source latents, generated samples, ...).
        virtual_instruments_state: ``gr.State`` holding the saved virtual
            instruments; updated by the "Save instrument" button.
    """
    # Load configurations
    uNet = gradioWebUI.uNet
    freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    height, channels = int(freq_resolution / VAE_scale), gradioWebUI.channels
    timesteps = gradioWebUI.timesteps
    VAE_encoder = gradioWebUI.VAE_encoder
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_decoder = gradioWebUI.VAE_decoder
    CLAP = gradioWebUI.CLAP
    CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
    device = gradioWebUI.device
    squared = gradioWebUI.squared
    sample_rate = gradioWebUI.sample_rate
    noise_strategy = gradioWebUI.noise_strategy

    def receive_upload_origin_audio(sound2sound_duration, sound2sound_origin, inpaintWithText_dict):
        """Encode a newly uploaded or recorded source sound and display its spectrogram.

        Args:
            sound2sound_duration: Value of the duration slider; determines the number
                of latent columns the source is fitted to.
            sound2sound_origin: ``(sample_rate, samples)`` tuple from ``gr.Audio``, or
                ``None`` when the audio component has been cleared.
            inpaintWithText_dict: Per-session state dict. The encoded latents are
                stored in it for ``sound2sound_sample``.

        Returns:
            Dict mapping output components to their new values.
        """
        if sound2sound_origin is None:
            # `change` also fires when the audio is cleared; there is nothing to encode,
            # so keep the current displays and state untouched.
            return {inpaintWithText_state: inpaintWithText_dict}

        origin_sr, origin_audio = sound2sound_origin
        origin_audio = np.asarray(origin_audio, dtype=np.float64)
        if origin_audio.ndim > 1:
            # gr.Audio delivers multi-channel audio as (samples, channels); the STFT
            # pipeline below expects a 1-D signal, so down-mix to mono.
            origin_audio = origin_audio.mean(axis=1)
        # Peak-normalise; skip for silent/empty input to avoid dividing by zero.
        peak = np.max(np.abs(origin_audio)) if origin_audio.size else 0.0
        if peak > 0:
            origin_audio = origin_audio / peak

        # Latent width for the selected duration, and the sample count that gives
        # VAE_scale * width STFT frames at hop length 256.
        width = int(time_resolution * ((sound2sound_duration + 1) / 4) / VAE_scale)
        audio_length = 256 * (VAE_scale * width - 1)
        origin_audio = adjust_audio_length(origin_audio, audio_length, origin_sr, sample_rate)

        D = librosa.stft(origin_audio, n_fft=1024, hop_length=256, win_length=1024)
        padded_D = pad_STFT(D)
        encoded_D = encode_stft(padded_D)

        # Todo: justify batchsize to 1
        origin_spectrogram_batch_tensor = torch.from_numpy(
            np.repeat(encoded_D[np.newaxis, :, :, :], 1, axis=0)).float().to(device)

        # Todo: remove hard-coding
        (origin_flipped_log_spectrums, origin_flipped_phases, origin_signals,
         origin_latent_representations, quantized_origin_latent_representations) = InputBatch2Encode_STFT(
            VAE_encoder, origin_spectrogram_batch_tensor, resolution=(512, width * VAE_scale), quantizer=VAE_quantizer,
            squared=squared)

        # Render each latent image once and reuse it for both the state and the display.
        origin_latent_image = latent_representation_to_Gradio_image(origin_latent_representations[0])
        quantized_origin_latent_image = latent_representation_to_Gradio_image(
            quantized_origin_latent_representations[0])

        inpaintWithText_dict["origin_upload_latent_representations"] = origin_latent_representations.tolist()
        inpaintWithText_dict[
            "sound2sound_origin_upload_latent_representation_image"] = origin_latent_image.tolist()
        inpaintWithText_dict[
            "sound2sound_origin_upload_quantized_latent_representation_image"] = quantized_origin_latent_image.tolist()

        return {sound2sound_origin_spectrogram_image: origin_flipped_log_spectrums[0],
                sound2sound_origin_phase_image: origin_flipped_phases[0],
                sound2sound_origin_upload_latent_representation_image: origin_latent_image,
                sound2sound_origin_upload_quantized_latent_representation_image: quantized_origin_latent_image,
                sound2sound_origin_microphone_latent_representation_image: gr.update(),
                sound2sound_origin_microphone_quantized_latent_representation_image: gr.update(),
                inpaintWithText_state: inpaintWithText_dict}

    def sound2sound_sample(sound2sound_origin_spectrogram,
                           text2sound_prompts, text2sound_negative_prompts, sound2sound_batchsize,
                           sound2sound_guidance_scale, sound2sound_sampler,
                           sound2sound_sample_steps,
                           sound2sound_noising_strength, sound2sound_seed, sound2sound_inpaint_area,
                           mask_time_begin, mask_time_end, mask_frequency_begin, mask_frequency_end,
                           inpaintWithText_dict
                           ):
        """Regenerate the selected region of the source latent with a text-conditioned sampler.

        The mask is the union of the layers painted on the spectrogram editor and
        the rectangular prototype from the time/frequency sliders. With
        ``sound2sound_inpaint_area == "masked"`` the mask is inverted before it
        is passed to the sampler.

        Returns:
            Dict mapping output components to the first generated sample, plus the
            updated sample-index slider, seed and state.

        Raises:
            gr.Error: If no source sound has been encoded yet, or the noising
                strength is not positive.
        """
        if not inpaintWithText_dict or "origin_upload_latent_representations" not in inpaintWithText_dict:
            raise gr.Error("Please upload or record a source sound first.")

        # input preprocessing
        sound2sound_seed = safe_int(sound2sound_seed, 12345678)
        sound2sound_batchsize = int(sound2sound_batchsize)
        noising_strength = sound2sound_noising_strength
        if noising_strength <= 0:
            # The step count below is divided by the noising strength.
            raise gr.Error("Noising strength must be greater than 0.")
        sound2sound_sample_steps = int(sound2sound_sample_steps)
        # NOTE(review): int() truncates fractional guidance scales -- confirm this is intended.
        CFG = int(sound2sound_guidance_scale)

        text2sound_embedding = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_prompts], padding=True, return_tensors="pt"))[0].to(
                device)

        # Pixel mask from the painted layers: keep the last channel (presumably alpha of
        # RGBA layers) and treat any non-zero average value as painted.
        averaged_transparency = average_np_arrays(sound2sound_origin_spectrogram["layers"])
        averaged_transparency = averaged_transparency[:, :, -1]

        origin_latent_representations = torch.tensor(
            inpaintWithText_dict["origin_upload_latent_representations"]).repeat(sound2sound_batchsize, 1, 1, 1).to(
            device)

        merged_mask = np.where(averaged_transparency > 0, 1, 0)
        # Downsample the pixel mask to latent resolution.
        latent_mask = zoom(merged_mask, (1 / VAE_scale, 1 / VAE_scale))
        latent_mask = np.clip(latent_mask, 0, 1)
        # Add the rectangular prototype mask; the time sliders range over 0-4, which maps
        # onto time_resolution / VAE_scale latent columns.
        latent_mask[int(mask_frequency_begin):int(mask_frequency_end),
                    int(mask_time_begin * time_resolution / (VAE_scale * 4)):int(
                        mask_time_end * time_resolution / (VAE_scale * 4))] = 1
        if sound2sound_inpaint_area == "masked":
            latent_mask = 1 - latent_mask
        latent_mask = torch.from_numpy(latent_mask).unsqueeze(0).unsqueeze(1).repeat(sound2sound_batchsize, channels, 1,
                                                                                     1).float().to(device)
        # The editor shows a vertically flipped spectrogram (origin_flipped_log_spectrums),
        # so flip the mask along the frequency axis to match latent orientation.
        latent_mask = torch.flip(latent_mask, [2])

        mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
        unconditional_condition = \
            CLAP.get_text_features(**CLAP_tokenizer([text2sound_negative_prompts], padding=True, return_tensors="pt"))[
                0]
        mySampler.activate_classifier_free_guidance(CFG, unconditional_condition.to(device))

        # Presumably only the last `noising_strength` fraction of the schedule is run, so
        # respace to keep about `sound2sound_sample_steps` effective steps.
        normalized_sample_steps = int(sound2sound_sample_steps / noising_strength)
        mySampler.respace(list(np.linspace(0, timesteps - 1, normalized_sample_steps, dtype=np.int32)))

        # Todo: remove hard-coding
        width = origin_latent_representations.shape[-1]
        condition = text2sound_embedding.repeat(sound2sound_batchsize, 1)

        new_sound_latent_representations, initial_noise = \
            mySampler.inpaint_sample(model=uNet, shape=(sound2sound_batchsize, channels, height, width),
                                     seed=sound2sound_seed,
                                     noising_strength=noising_strength,
                                     guide_img=origin_latent_representations, mask=latent_mask, return_tensor=True,
                                     condition=condition, sampler=sound2sound_sampler)
        new_sound_latent_representations = new_sound_latent_representations[-1]

        # Quantize new sound latent representations
        quantized_new_sound_latent_representations, loss, (_, _, _) = VAE_quantizer(new_sound_latent_representations)
        new_sound_flipped_log_spectrums, new_sound_flipped_phases, new_sound_signals, _, _, _ = encodeBatch2GradioOutput_STFT(
            VAE_decoder,
            quantized_new_sound_latent_representations,
            resolution=(
                512,
                width * VAE_scale),
            original_STFT_batch=None
        )

        new_sound_latent_representation_gradio_images = []
        new_sound_quantized_latent_representation_gradio_images = []
        new_sound_spectrogram_gradio_images = []
        new_sound_phase_gradio_images = []
        new_sound_rec_signals_gradio = []
        for i in range(sound2sound_batchsize):
            new_sound_latent_representation_gradio_images.append(
                latent_representation_to_Gradio_image(new_sound_latent_representations[i]))
            new_sound_quantized_latent_representation_gradio_images.append(
                latent_representation_to_Gradio_image(quantized_new_sound_latent_representations[i]))
            new_sound_spectrogram_gradio_images.append(new_sound_flipped_log_spectrums[i])
            new_sound_phase_gradio_images.append(new_sound_flipped_phases[i])
            new_sound_rec_signals_gradio.append((sample_rate, new_sound_signals[i]))

        inpaintWithText_dict[
            "new_sound_latent_representation_gradio_images"] = new_sound_latent_representation_gradio_images
        inpaintWithText_dict[
            "new_sound_quantized_latent_representation_gradio_images"] = new_sound_quantized_latent_representation_gradio_images
        inpaintWithText_dict["new_sound_spectrogram_gradio_images"] = new_sound_spectrogram_gradio_images
        inpaintWithText_dict["new_sound_phase_gradio_images"] = new_sound_phase_gradio_images
        inpaintWithText_dict["new_sound_rec_signals_gradio"] = new_sound_rec_signals_gradio
        inpaintWithText_dict["latent_representations"] = new_sound_latent_representations.to("cpu").detach().numpy()
        inpaintWithText_dict["quantized_latent_representations"] = quantized_new_sound_latent_representations.to(
            "cpu").detach().numpy()
        inpaintWithText_dict["sampler"] = sound2sound_sampler

        return {sound2sound_new_sound_latent_representation_image: latent_representation_to_Gradio_image(
            new_sound_latent_representations[0]),
            sound2sound_new_sound_quantized_latent_representation_image: latent_representation_to_Gradio_image(
                quantized_new_sound_latent_representations[0]),
            sound2sound_new_sound_spectrogram_image: new_sound_flipped_log_spectrums[0],
            sound2sound_new_sound_phase_image: new_sound_flipped_phases[0],
            sound2sound_new_sound_audio: (sample_rate, new_sound_signals[0]),
            sound2sound_sample_index_slider: gr.update(minimum=0, maximum=sound2sound_batchsize - 1, value=0,
                                                       step=1.0,
                                                       visible=True,
                                                       label="Sample index",
                                                       info="Swipe to view other samples"),
            sound2sound_seed_textbox: sound2sound_seed,
            inpaintWithText_state: inpaintWithText_dict}

    def show_sound2sound_sample(sound2sound_sample_index, inpaintWithText_dict):
        """Display the generated sample selected by the sample-index slider."""
        sample_index = int(sound2sound_sample_index)
        return {sound2sound_new_sound_latent_representation_image:
                    inpaintWithText_dict["new_sound_latent_representation_gradio_images"][sample_index],
                sound2sound_new_sound_quantized_latent_representation_image:
                    inpaintWithText_dict["new_sound_quantized_latent_representation_gradio_images"][sample_index],
                sound2sound_new_sound_spectrogram_image: inpaintWithText_dict["new_sound_spectrogram_gradio_images"][
                    sample_index],
                sound2sound_new_sound_phase_image: inpaintWithText_dict["new_sound_phase_gradio_images"][
                    sample_index],
                sound2sound_new_sound_audio: inpaintWithText_dict["new_sound_rec_signals_gradio"][sample_index]}

    def save_virtual_instrument(sample_index, virtual_instrument_name, sound2sound_dict, virtual_instruments_dict):
        """Save the selected generated sample as a named virtual instrument via ``add_instrument``."""
        virtual_instruments_dict = add_instrument(sound2sound_dict, virtual_instruments_dict, virtual_instrument_name,
                                                  sample_index)
        return {virtual_instruments_state: virtual_instruments_dict,
                sound2sound_instrument_name_textbox: gr.Textbox(label="Instrument name", lines=1,
                                                                placeholder=f"Saved as {virtual_instrument_name}!")}

    with gr.Tab("Inpaint"):
        gr.Markdown("Upload a musical note and select the area by drawing on \"Input spectrogram\" for inpainting!")
        with gr.Row(variant="panel"):
            with gr.Column(scale=3):
                text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="organ")
                text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")
            with gr.Column(scale=1):
                sound2sound_sample_button = gr.Button(variant="primary", value="Generate", scale=1)
                sound2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
                                                            label="Sample index",
                                                            info="Swipe to view other samples")

        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                sound2sound_duration_slider = gradioWebUI.get_duration_slider()
                sound2sound_origin_audio = gr.Audio(
                    sources=["microphone", "upload"], label="Upload/Record source sound",
                    waveform_options=gr.WaveformOptions(
                        waveform_color="#01C6FF",
                        waveform_progress_color="#0066B4",
                        skip_length=1,
                        show_controls=False,
                    ),
                )
                with gr.Row(variant="panel"):
                    with gr.Tab("Sound2sound settings"):
                        sound2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
                        sound2sound_sampler_radio = gradioWebUI.get_sampler_radio()
                        sound2sound_batchsize_slider = gradioWebUI.get_batchsize_slider()
                        sound2sound_noising_strength_slider = gradioWebUI.get_noising_strength_slider()
                        sound2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
                        sound2sound_seed_textbox = gradioWebUI.get_seed_textbox()
                    with gr.Tab("Mask prototypes"):
                        with gr.Tab("Mask along time axis"):
                            mask_time_begin_slider = gr.Slider(minimum=0.0, maximum=4.00, value=0.0, step=0.01,
                                                               label="Begin time")
                            mask_time_end_slider = gr.Slider(minimum=0.0, maximum=4.00, value=0.0, step=0.01,
                                                             label="End time")
                        with gr.Tab("Mask along frequency axis"):
                            mask_frequency_begin_slider = gr.Slider(minimum=0, maximum=127, value=0, step=1,
                                                                    label="Begin freq pixel")
                            mask_frequency_end_slider = gr.Slider(minimum=0, maximum=127, value=0, step=1,
                                                                  label="End freq pixel")

            with gr.Column(scale=1):
                with gr.Row(variant="panel"):
                    sound2sound_origin_spectrogram_image = gr.ImageEditor(label="Input spectrogram (draw here!)",
                                                                          type="numpy",
                                                                          visible=True, height=600, scale=1)
                    sound2sound_new_sound_spectrogram_image = gr.Image(label="New sound spectrogram", type="numpy",
                                                                       height=600, scale=1)
                with gr.Row(variant="panel"):
                    sound2sound_inpaint_area_radio = gr.Radio(label="Inpainting area", choices=["masked", "unmasked"],
                                                              value="masked", scale=1)
                    sound2sound_new_sound_audio = gr.Audio(type="numpy", label="Play new sound", interactive=False,
                                                           waveform_options=gr.WaveformOptions(
                                                               waveform_color="#FFB6C1",
                                                               waveform_progress_color="#FF0000",
                                                               skip_length=1,
                                                               show_controls=False,
                                                           ), scale=1)
                with gr.Row(variant="panel"):
                    sound2sound_instrument_name_textbox = gr.Textbox(label="Instrument name", lines=1,
                                                                     placeholder="Name of your instrument")
                    sound2sound_save_instrument_button = gr.Button(variant="primary",
                                                                   value="Save instrument",
                                                                   scale=1)

        # Hidden components: written to by the handlers but not shown.
        with gr.Row(variant="panel"):
            sound2sound_origin_upload_latent_representation_image = gr.Image(label="Original latent representation",
                                                                             type="numpy", height=800,
                                                                             visible=False)
            sound2sound_origin_upload_quantized_latent_representation_image = gr.Image(
                label="Original quantized latent representation", type="numpy", height=800, visible=False)
            sound2sound_origin_microphone_latent_representation_image = gr.Image(label="Original latent representation",
                                                                                 type="numpy", height=800,
                                                                                 visible=False)
            sound2sound_origin_microphone_quantized_latent_representation_image = gr.Image(
                label="Original quantized latent representation", type="numpy", height=800, visible=False)
            sound2sound_new_sound_latent_representation_image = gr.Image(label="New latent representation",
                                                                         type="numpy", height=800, visible=False)
            sound2sound_new_sound_quantized_latent_representation_image = gr.Image(
                label="New sound quantized latent representation", type="numpy", height=800, visible=False)
            sound2sound_origin_phase_image = gr.Image(label="Original upload phase",
                                                      type="numpy", visible=False)
            sound2sound_new_sound_phase_image = gr.Image(label="New sound phase", type="numpy",
                                                         height=600, scale=1, visible=False)

        sound2sound_origin_audio.change(receive_upload_origin_audio,
                                        inputs=[sound2sound_duration_slider, sound2sound_origin_audio,
                                                inpaintWithText_state],
                                        outputs=[sound2sound_origin_spectrogram_image,
                                                 sound2sound_origin_phase_image,
                                                 sound2sound_origin_upload_latent_representation_image,
                                                 sound2sound_origin_upload_quantized_latent_representation_image,
                                                 sound2sound_origin_microphone_latent_representation_image,
                                                 sound2sound_origin_microphone_quantized_latent_representation_image,
                                                 inpaintWithText_state])

        sound2sound_sample_button.click(sound2sound_sample,
                                        inputs=[sound2sound_origin_spectrogram_image,
                                                text2sound_prompts_textbox,
                                                text2sound_negative_prompts_textbox,
                                                sound2sound_batchsize_slider,
                                                sound2sound_guidance_scale_slider,
                                                sound2sound_sampler_radio,
                                                sound2sound_sample_steps_slider,
                                                sound2sound_noising_strength_slider,
                                                sound2sound_seed_textbox,
                                                sound2sound_inpaint_area_radio,
                                                mask_time_begin_slider,
                                                mask_time_end_slider,
                                                mask_frequency_begin_slider,
                                                mask_frequency_end_slider,
                                                inpaintWithText_state],
                                        outputs=[sound2sound_new_sound_latent_representation_image,
                                                 sound2sound_new_sound_quantized_latent_representation_image,
                                                 sound2sound_new_sound_spectrogram_image,
                                                 sound2sound_new_sound_phase_image,
                                                 sound2sound_new_sound_audio,
                                                 sound2sound_sample_index_slider,
                                                 sound2sound_seed_textbox,
                                                 inpaintWithText_state])

        sound2sound_sample_index_slider.change(show_sound2sound_sample,
                                               inputs=[sound2sound_sample_index_slider, inpaintWithText_state],
                                               outputs=[sound2sound_new_sound_latent_representation_image,
                                                        sound2sound_new_sound_quantized_latent_representation_image,
                                                        sound2sound_new_sound_spectrogram_image,
                                                        sound2sound_new_sound_phase_image,
                                                        sound2sound_new_sound_audio])

        sound2sound_save_instrument_button.click(save_virtual_instrument,
                                                 inputs=[sound2sound_sample_index_slider,
                                                         sound2sound_instrument_name_textbox,
                                                         inpaintWithText_state,
                                                         virtual_instruments_state],
                                                 outputs=[virtual_instruments_state,
                                                          sound2sound_instrument_name_textbox])