# WeixuanYuan's picture
# Update webUI/natural_language_guided_4/text2sound.py
# 8ab6976 verified
import gradio as gr
import numpy as np
from model.DiffSynthSampler import DiffSynthSampler
from tools import safe_int
from webUI.natural_language_guided_4.utils import latent_representation_to_Gradio_image, \
encodeBatch2GradioOutput_STFT, add_instrument, resize_image_to_aspect_ratio
def get_text2sound_module(gradioWebUI, text2sound_state, virtual_instruments_state):
    """Build the "Text2sound" tab and wire its event handlers.

    The tab samples a batch of sounds from a text prompt with the diffusion model,
    lets the user browse the batch with a slider, and saves a chosen sample as a
    virtual instrument.

    Args:
        gradioWebUI: Object exposing the loaded models (``uNet``, ``VAE_quantizer``,
            ``VAE_decoder``, ``CLAP``, ``CLAP_tokenizer``) and the generation settings
            read below (resolutions, ``VAE_scale``, ``timesteps``, ``device``, ...), plus
            factory methods for the shared parameter widgets.
        text2sound_state: Gradio state component (presumably ``gr.State``) holding the
            per-session generation results dict filled by ``diffusion_random_sample``.
        virtual_instruments_state: Gradio state component holding the dict of saved
            virtual instruments, updated by ``save_virtual_instrument``.
    """
    # Load configurations
    uNet = gradioWebUI.uNet
    freq_resolution, time_resolution = gradioWebUI.freq_resolution, gradioWebUI.time_resolution
    VAE_scale = gradioWebUI.VAE_scale
    # The latent height is fixed. The latent width depends on the requested duration,
    # so it is computed per request inside diffusion_random_sample.
    height, channels = int(freq_resolution / VAE_scale), gradioWebUI.channels
    timesteps = gradioWebUI.timesteps
    VAE_quantizer = gradioWebUI.VAE_quantizer
    VAE_decoder = gradioWebUI.VAE_decoder
    CLAP = gradioWebUI.CLAP
    CLAP_tokenizer = gradioWebUI.CLAP_tokenizer
    device = gradioWebUI.device
    sample_rate = gradioWebUI.sample_rate
    noise_strategy = gradioWebUI.noise_strategy

    def _embed_text(text):
        """Return the CLAP text feature of a single prompt (row 0 of a batch of one)."""
        return CLAP.get_text_features(**CLAP_tokenizer([text], padding=True, return_tensors="pt"))[0]

    def _display_sample(text2sound_dict, sample_index):
        """Map the five display components to the stored outputs of sample ``sample_index``.

        The components referenced here are created further down in the enclosing
        function. They are resolved when the callback runs, after the UI is built.
        """
        return {
            text2sound_latent_representation_image:
                text2sound_dict["latent_representation_gradio_images"][sample_index],
            text2sound_quantized_latent_representation_image:
                text2sound_dict["quantized_latent_representation_gradio_images"][sample_index],
            text2sound_sampled_spectrogram_image: resize_image_to_aspect_ratio(
                text2sound_dict["new_sound_spectrogram_gradio_images"][sample_index], 1.55, 1),
            text2sound_sampled_phase_image: resize_image_to_aspect_ratio(
                text2sound_dict["new_sound_phase_gradio_images"][sample_index], 1.55, 1),
            text2sound_sampled_audio: text2sound_dict["new_sound_rec_signals_gradio"][sample_index],
        }

    def diffusion_random_sample(text2sound_prompts, text2sound_negative_prompts, text2sound_batchsize,
                                text2sound_duration,
                                text2sound_guidance_scale, text2sound_sampler,
                                text2sound_sample_steps, text2sound_seed,
                                text2sound_dict):
        """Sample a batch of sounds for the prompts and store the results in ``text2sound_dict``.

        Returns:
            A dict of component updates:
              * the display components, filled with the first sample;
              * the seed actually used;
              * the updated state;
              * the sample-index slider, made visible and ranging over the batch.
        """
        text2sound_sample_steps = int(text2sound_sample_steps)
        # safe_int falls back to 12345678 when the textbox value cannot be used.
        text2sound_seed = safe_int(text2sound_seed, 12345678)
        # The latent width scales with the selected duration: (duration + 1) / 4 of time_resolution.
        width = int(time_resolution * ((text2sound_duration + 1) / 4) / VAE_scale)
        text2sound_batchsize = int(text2sound_batchsize)

        text2sound_embedding = _embed_text(text2sound_prompts).to(device)
        # NOTE(review): int() truncates fractional guidance scales.
        # Confirm the guidance slider only produces whole numbers.
        CFG = int(text2sound_guidance_scale)
        mySampler = DiffSynthSampler(timesteps, height=height, channels=channels, noise_strategy=noise_strategy)
        negative_condition = _embed_text(text2sound_negative_prompts)
        mySampler.activate_classifier_free_guidance(CFG, negative_condition.to(device))
        mySampler.respace(list(np.linspace(0, timesteps - 1, text2sound_sample_steps, dtype=np.int32)))

        # Every sample in the batch is conditioned on the same prompt embedding.
        condition = text2sound_embedding.repeat(text2sound_batchsize, 1)
        latent_representations, _initial_noise = \
            mySampler.sample(model=uNet, shape=(text2sound_batchsize, channels, height, width), seed=text2sound_seed,
                             return_tensor=True, condition=condition, sampler=text2sound_sampler)
        # Keep the last entry of the returned sequence (presumably the final denoising step).
        latent_representations = latent_representations[-1]

        quantized_latent_representations, _loss, (_, _, _) = VAE_quantizer(latent_representations)
        flipped_log_spectrums, flipped_phases, rec_signals, _, _, _ = encodeBatch2GradioOutput_STFT(
            VAE_decoder,
            quantized_latent_representations,
            # Decode at the latent size scaled back up by VAE_scale on both axes.
            # This replaces the former hard-coded 512 rows. NOTE(review): the two agree
            # when freq_resolution / VAE_scale * VAE_scale == 512.
            resolution=(height * VAE_scale, width * VAE_scale),
            original_STFT_batch=None,
        )

        batch = range(text2sound_batchsize)
        text2sound_dict["latent_representation_gradio_images"] = [
            latent_representation_to_Gradio_image(latent_representations[i]) for i in batch]
        text2sound_dict["quantized_latent_representation_gradio_images"] = [
            latent_representation_to_Gradio_image(quantized_latent_representations[i]) for i in batch]
        text2sound_dict["new_sound_spectrogram_gradio_images"] = [flipped_log_spectrums[i] for i in batch]
        text2sound_dict["new_sound_phase_gradio_images"] = [flipped_phases[i] for i in batch]
        text2sound_dict["new_sound_rec_signals_gradio"] = [(sample_rate, rec_signals[i]) for i in batch]

        # Raw tensors and sampling settings kept in the state for saving an instrument.
        # They are presumably read by add_instrument in save_virtual_instrument.
        text2sound_dict["latent_representations"] = latent_representations.to("cpu").detach().numpy()
        text2sound_dict["quantized_latent_representations"] = quantized_latent_representations.to(
            "cpu").detach().numpy()
        text2sound_dict["condition"] = condition.to("cpu").detach().numpy()
        text2sound_dict["negative_condition"] = negative_condition.to("cpu").detach().numpy()
        text2sound_dict["guidance_scale"] = CFG
        text2sound_dict["sampler"] = text2sound_sampler

        return {**_display_sample(text2sound_dict, 0),
                text2sound_seed_textbox: text2sound_seed,
                text2sound_state: text2sound_dict,
                text2sound_sample_index_slider: gr.update(minimum=0, maximum=text2sound_batchsize - 1, value=0, step=1,
                                                          visible=True,
                                                          label="Sample index.",
                                                          info="Swipe to view other samples")}

    def show_random_sample(sample_index, text2sound_dict):
        """Display sample ``sample_index`` of the most recently generated batch."""
        sample_index = int(sample_index)
        text2sound_dict["sample_index"] = sample_index
        return _display_sample(text2sound_dict, sample_index)

    def save_virtual_instrument(sample_index, virtual_instrument_name, text2sound_dict, virtual_instruments_dict):
        """Save sample ``sample_index`` of the current batch as a virtual instrument named ``virtual_instrument_name``."""
        # Cast to int as show_random_sample does: the slider is created with step=1.0,
        # so it may deliver a float.
        virtual_instruments_dict = add_instrument(text2sound_dict, virtual_instruments_dict, virtual_instrument_name,
                                                  int(sample_index))
        return {virtual_instruments_state: virtual_instruments_dict,
                text2sound_instrument_name_textbox: gr.Textbox(label="Instrument name", lines=1,
                                                               placeholder=f"Saved as {virtual_instrument_name}!")}

    with gr.Tab("Text2sound"):
        gr.Markdown("Use neural networks to select random sounds using your favorite instrument!")
        with gr.Row(variant="panel"):
            with gr.Column(scale=3):
                text2sound_prompts_textbox = gr.Textbox(label="Positive prompt", lines=2, value="string")
                text2sound_negative_prompts_textbox = gr.Textbox(label="Negative prompt", lines=2, value="")

            with gr.Column(scale=1):
                text2sound_sampling_button = gr.Button(variant="primary",
                                                       value="Generate a batch of samples and show "
                                                             "the first one",
                                                       scale=1)
                # Hidden until the first batch is generated. diffusion_random_sample then
                # resizes it to the batch and makes it visible.
                text2sound_sample_index_slider = gr.Slider(minimum=0, maximum=3, value=0, step=1.0, visible=False,
                                                           label="Sample index",
                                                           info="Swipe to view other samples")

        with gr.Row(variant="panel"):
            with gr.Column(variant="panel", scale=1):
                text2sound_sample_steps_slider = gradioWebUI.get_sample_steps_slider()
                text2sound_sampler_radio = gradioWebUI.get_sampler_radio()
                text2sound_batchsize_slider = gradioWebUI.get_batchsize_slider()
                text2sound_duration_slider = gradioWebUI.get_duration_slider()
                text2sound_guidance_scale_slider = gradioWebUI.get_guidance_scale_slider()
                text2sound_seed_textbox = gradioWebUI.get_seed_textbox()

            with gr.Column(variant="panel", scale=1):
                with gr.Row(variant="panel", ):
                    text2sound_sampled_spectrogram_image = gr.Image(label="Sampled spectrogram", type="numpy", )
                    text2sound_sampled_phase_image = gr.Image(label="Sampled phase", type="numpy")
                text2sound_sampled_audio = gr.Audio(type="numpy", label="Play",
                                                    scale=1)

                with gr.Row(variant="panel", ):
                    text2sound_instrument_name_textbox = gr.Textbox(label="Instrument name", lines=2,
                                                                    placeholder="Name of your instrument",
                                                                    scale=1)
                    text2sound_save_instrument_button = gr.Button(variant="primary",
                                                                  value="Save instrument",
                                                                  scale=1)

        # Latent previews: created invisible, but still updated by the callbacks.
        with gr.Row(variant="panel"):
            text2sound_latent_representation_image = gr.Image(label="Sampled latent representation", type="numpy",
                                                              height=200, width=100, visible=False)
            text2sound_quantized_latent_representation_image = gr.Image(label="Quantized latent representation",
                                                                        type="numpy", height=200, width=100,
                                                                        visible=False)

    text2sound_sampling_button.click(diffusion_random_sample,
                                     inputs=[text2sound_prompts_textbox,
                                             text2sound_negative_prompts_textbox,
                                             text2sound_batchsize_slider,
                                             text2sound_duration_slider,
                                             text2sound_guidance_scale_slider, text2sound_sampler_radio,
                                             text2sound_sample_steps_slider,
                                             text2sound_seed_textbox,
                                             text2sound_state],
                                     outputs=[text2sound_latent_representation_image,
                                              text2sound_quantized_latent_representation_image,
                                              text2sound_sampled_spectrogram_image,
                                              text2sound_sampled_phase_image,
                                              text2sound_sampled_audio,
                                              text2sound_seed_textbox,
                                              text2sound_state,
                                              text2sound_sample_index_slider])

    text2sound_save_instrument_button.click(save_virtual_instrument,
                                            inputs=[text2sound_sample_index_slider,
                                                    text2sound_instrument_name_textbox,
                                                    text2sound_state,
                                                    virtual_instruments_state],
                                            outputs=[virtual_instruments_state,
                                                     text2sound_instrument_name_textbox])

    text2sound_sample_index_slider.change(show_random_sample,
                                          inputs=[text2sound_sample_index_slider, text2sound_state],
                                          outputs=[text2sound_latent_representation_image,
                                                   text2sound_quantized_latent_representation_image,
                                                   text2sound_sampled_spectrogram_image,
                                                   text2sound_sampled_phase_image,
                                                   text2sound_sampled_audio])