# NOTE: `spaces` is kept as the first import, as in the original; Hugging Face
# ZeroGPU guidance is to import it before torch/CUDA packages.
import spaces
import gradio as gr
import json
import torch
import wavio
from tqdm import tqdm
from huggingface_hub import snapshot_download
from models import AudioDiffusion, DDPMScheduler
from audioldm.audio.stft import TacotronSTFT
from audioldm.variational_autoencoder import AutoencoderKL
from pydub import AudioSegment
from gradio import Markdown
from diffusers import UNet2DConditionModel
from diffusers import DiffusionPipeline, AudioPipelineOutput
from transformers import CLIPTextModel, T5EncoderModel, AutoModel, T5Tokenizer, T5TokenizerFast
from typing import Union
from diffusers.utils.torch_utils import randn_tensor
from transformers import pipeline

# Korean -> English translator, used on prompts that contain Hangul.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")


class Tango2Pipeline(DiffusionPipeline):
    """Diffusers-style pipeline wrapping the TANGO 2 text-to-audio components.

    The VAE, text encoder, tokenizer, UNet and scheduler are registered as
    pipeline modules; ``__call__`` turns one prompt string into waveform(s).
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: T5EncoderModel,
        tokenizer: Union[T5Tokenizer, T5TokenizerFast],
        unet: UNet2DConditionModel,
        scheduler: DDPMScheduler,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
        )

    def _encode_prompt(self, prompt):
        """Encode a list of prompt strings with the text encoder.

        Returns:
            ``(encoder_hidden_states, boolean_encoder_mask)`` where the mask is
            True wherever the tokenizer's attention mask equals 1.
        """
        device = self.text_encoder.device
        batch = self.tokenizer(
            prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
        )
        input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)

        encoder_hidden_states = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)[0]

        boolean_encoder_mask = (attention_mask == 1).to(device)
        return encoder_hidden_states, boolean_encoder_mask

    def _encode_text_classifier_free(self, prompt, num_samples_per_prompt):
        """Encode prompts plus empty "unconditional" prompts for classifier-free guidance.

        Both the conditional and unconditional embeddings are repeated
        ``num_samples_per_prompt`` times along the batch dimension.

        Returns:
            ``(prompt_embeds, boolean_prompt_mask)`` with the unconditional rows
            first and the conditional rows second.
        """
        device = self.text_encoder.device
        batch = self.tokenizer(
            prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
        )
        input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)

        with torch.no_grad():
            prompt_embeds = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)[0]

        prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
        attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0)

        # Unconditional embeddings for classifier-free guidance, padded to the same
        # sequence length as the conditional embeddings so the two can be concatenated.
        uncond_tokens = [""] * len(prompt)
        max_length = prompt_embeds.shape[1]
        uncond_batch = self.tokenizer(
            uncond_tokens, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt",
        )
        uncond_input_ids = uncond_batch.input_ids.to(device)
        uncond_attention_mask = uncond_batch.attention_mask.to(device)

        with torch.no_grad():
            negative_prompt_embeds = self.text_encoder(
                input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
            )[0]

        negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
        uncond_attention_mask = uncond_attention_mask.repeat_interleave(num_samples_per_prompt, 0)

        # Concatenate unconditional and conditional embeddings into one batch so the
        # UNet can process both in a single forward pass.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        prompt_mask = torch.cat([uncond_attention_mask, attention_mask])
        boolean_prompt_mask = (prompt_mask == 1).to(device)

        return prompt_embeds, boolean_prompt_mask

    def prepare_latents(self, batch_size, inference_scheduler, num_channels_latents, dtype, device):
        """Sample initial Gaussian latents of shape (batch_size, num_channels_latents, 256, 16)."""
        shape = (batch_size, num_channels_latents, 256, 16)
        latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
        # Scale the initial noise by the standard deviation required by the scheduler.
        latents = latents * inference_scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    def inference(self, prompt, inference_scheduler, num_steps=20, guidance_scale=3, num_samples_per_prompt=1,
                  disable_progress=True):
        """Run the denoising loop and return the final latents.

        Args:
            prompt: list of prompt strings.
            inference_scheduler: scheduler providing ``set_timesteps``,
                ``scale_model_input`` and ``step``.
            num_steps: number of denoising steps; must be an int.
            guidance_scale: classifier-free guidance weight; guidance is only
                applied when it is greater than 1.0.
            num_samples_per_prompt: number of latents generated per prompt.
            disable_progress: hide the tqdm progress bar when True.
        """
        device = self.text_encoder.device
        classifier_free_guidance = guidance_scale > 1.0
        batch_size = len(prompt) * num_samples_per_prompt

        if classifier_free_guidance:
            prompt_embeds, boolean_prompt_mask = self._encode_text_classifier_free(prompt, num_samples_per_prompt)
        else:
            # Previously this called a non-existent `self._encode_text`, raising
            # AttributeError whenever guidance_scale <= 1.0 (reachable from the UI slider).
            prompt_embeds, boolean_prompt_mask = self._encode_prompt(prompt)
            prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
            boolean_prompt_mask = boolean_prompt_mask.repeat_interleave(num_samples_per_prompt, 0)

        inference_scheduler.set_timesteps(num_steps, device=device)
        timesteps = inference_scheduler.timesteps

        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size, inference_scheduler, num_channels_latents, prompt_embeds.dtype, device
        )

        num_warmup_steps = len(timesteps) - num_steps * inference_scheduler.order
        progress_bar = tqdm(range(num_steps), disable=disable_progress)

        for i, t in enumerate(timesteps):
            # Duplicate the latents when doing classifier-free guidance
            # (one copy for the unconditional pass, one for the conditional pass).
            latent_model_input = torch.cat([latents] * 2) if classifier_free_guidance else latents
            latent_model_input = inference_scheduler.scale_model_input(latent_model_input, t)

            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=prompt_embeds,
                encoder_attention_mask=boolean_prompt_mask
            ).sample

            # Perform guidance.
            if classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Compute the previous noisy sample x_t -> x_t-1.
            latents = inference_scheduler.step(noise_pred, t, latents).prev_sample

            # Advance the progress bar once per completed scheduler step.
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % inference_scheduler.order == 0):
                progress_bar.update(1)

        progress_bar.close()
        return latents

    @torch.no_grad()
    def __call__(self, prompt, steps=100, guidance=3, samples=1, disable_progress=True):
        """Generate audio for a single prompt string.

        Returns an ``AudioPipelineOutput`` whose ``audios`` field holds the
        waveform(s) decoded by the VAE.
        """
        # The decorator already disables autograd; no inner no_grad context is needed.
        latents = self.inference([prompt], self.scheduler, steps, guidance, samples, disable_progress=disable_progress)
        mel = self.vae.decode_first_stage(latents)
        wave = self.vae.decode_to_waveform(mel)
        return AudioPipelineOutput(audios=wave)


# Automatic device detection
if torch.cuda.is_available():
    device_type = "cuda"
    device_selection = "cuda:0"
else:
    device_type = "cpu"
    device_selection = "cpu"


class Tango:
    """Loads a TANGO checkpoint (VAE, STFT, diffusion model, scheduler) from the Hugging Face Hub."""

    def __init__(self, name="declare-lab/tango2", device=device_selection):
        path = snapshot_download(repo_id=name)

        vae_config = self._load_json("{}/vae_config.json".format(path))
        stft_config = self._load_json("{}/stft_config.json".format(path))
        main_config = self._load_json("{}/main_config.json".format(path))

        self.vae = AutoencoderKL(**vae_config).to(device)
        self.stft = TacotronSTFT(**stft_config).to(device)
        self.model = AudioDiffusion(**main_config).to(device)

        # NOTE(review): torch.load unpickles the checkpoint files; consider
        # weights_only=True if the installed torch version and checkpoints support it.
        vae_weights = torch.load("{}/pytorch_model_vae.bin".format(path), map_location=device)
        stft_weights = torch.load("{}/pytorch_model_stft.bin".format(path), map_location=device)
        main_weights = torch.load("{}/pytorch_model_main.bin".format(path), map_location=device)

        self.vae.load_state_dict(vae_weights)
        self.stft.load_state_dict(stft_weights)
        self.model.load_state_dict(main_weights)

        print("Successfully loaded checkpoint from:", name)

        self.vae.eval()
        self.stft.eval()
        self.model.eval()

        self.scheduler = DDPMScheduler.from_pretrained(main_config["scheduler_name"], subfolder="scheduler")

    @staticmethod
    def _load_json(file_path):
        """Parse a JSON file, making sure the file handle is closed afterwards."""
        with open(file_path, encoding="utf-8") as f:
            return json.load(f)

    def chunks(self, lst, n):
        """Yield successive n-sized chunks from a list."""
        for i in range(0, len(lst), n):
            yield lst[i:i + n]

    def generate(self, prompt, steps=100, guidance=3, samples=1, disable_progress=True):
        """Generate audio for a single prompt string and return the first waveform."""
        with torch.no_grad():
            latents = self.model.inference([prompt], self.scheduler, steps, guidance, samples,
                                           disable_progress=disable_progress)
            mel = self.vae.decode_first_stage(latents)
            wave = self.vae.decode_to_waveform(mel)
        return wave[0]

    def generate_for_batch(self, prompts, steps=200, guidance=3, samples=1, batch_size=8, disable_progress=True):
        """Generate audio for a list of prompt strings.

        Returns a flat list of waveforms when ``samples == 1``; otherwise a list
        of ``samples``-sized groups, one group per prompt.
        """
        outputs = []
        for k in tqdm(range(0, len(prompts), batch_size)):
            batch = prompts[k: k + batch_size]
            with torch.no_grad():
                latents = self.model.inference(batch, self.scheduler, steps, guidance, samples,
                                               disable_progress=disable_progress)
                mel = self.vae.decode_first_stage(latents)
                wave = self.vae.decode_to_waveform(mel)
                outputs.extend(wave)
        if samples == 1:
            return outputs
        return list(self.chunks(outputs, samples))


# Initialize TANGO on CPU, then move its modules to the detected device.
tango = Tango(device="cpu")
tango.vae.to(device_type)
tango.stft.to(device_type)
tango.model.to(device_type)

pipe = Tango2Pipeline(
    vae=tango.vae,
    text_encoder=tango.model.text_encoder,
    tokenizer=tango.model.tokenizer,
    unet=tango.model.unet,
    scheduler=tango.scheduler,
)

# First and last precomposed Hangul syllables (U+AC00 .. U+D7A3).
_HANGUL_FIRST, _HANGUL_LAST = "가", "힣"


def _contains_hangul(text):
    """Return True if ``text`` contains at least one precomposed Hangul syllable."""
    return any(_HANGUL_FIRST <= char <= _HANGUL_LAST for char in text)


@spaces.GPU(duration=60)
def gradio_generate(prompt, output_format, steps, guidance):
    """Generate audio for ``prompt`` and return the path of the written file.

    Args:
        prompt: text prompt; prompts containing Hangul are first translated to English.
        output_format: "wav" or "mp3".
        steps: number of denoising steps (from the Steps slider).
        guidance: classifier-free guidance scale (from the Guidance Scale slider).

    Returns:
        "temp.wav" or "temp.mp3" in the current working directory.
    """
    # Translate Korean prompts to English before generation.
    if _contains_hangul(prompt):
        prompt = translator(prompt)[0]['translation_text']
        print(f"Translated prompt: {prompt}")

    # int(): the step count is used with range() and the scheduler, which need an int.
    output_wave = pipe(prompt, int(steps), guidance).audios[0]

    # NOTE(review): these fixed file names are shared by every request; this is only
    # safe while this event processes one job at a time. Confirm the queue concurrency.
    output_filename = "temp.wav"
    wavio.write(output_filename, output_wave, rate=16000, sampwidth=2)

    if output_format == "mp3":
        AudioSegment.from_wav(output_filename).export("temp.mp3", format="mp3")
        output_filename = "temp.mp3"

    return output_filename


input_text = gr.Textbox(lines=2, label="Prompt")
output_format = gr.Radio(label="Output format", info="The file you can download", choices=["mp3", "wav"], value="wav")
output_audio = gr.Audio(label="Generated Audio", type="filepath")
denoising_steps = gr.Slider(minimum=100, maximum=200, value=200, step=1, label="Steps", interactive=True)
guidance_scale = gr.Slider(minimum=1, maximum=10, value=8, step=0.1, label="Guidance Scale", interactive=True)

css = """
footer {
    visibility: hidden;
}
"""

gr_interface = gr.Interface(
    fn=gradio_generate,
    inputs=[input_text, output_format, denoising_steps, guidance_scale],
    outputs=[output_audio],
    title="SoundAI by tango",
    theme="Yntec/HaleyCH_Theme_Orange",
    css=css,
    # String form; the boolean False was a deprecated alias for "never".
    allow_flagging="never",
    examples=[
        ["Quiet whispered conversation gradually fading into distant jet engine roar diminishing into silence"],
        ["Clear sound of bicycle tires crunching on loose gravel and dirt, followed by deep male laughter echoing"],
        ["Multiple ducks quacking loudly with splashing water and piercing wild animal shriek in background"],
        ["Powerful ocean waves crashing and receding on sandy beach with distant seagulls"],
        ["기관총 발사 소음"],
        ["Gentle female voice cooing and baby responding with happy gurgles and giggles"],
        ["Clear male voice speaking, sharp popping sound, followed by genuine group laughter"],
        ["Stream of water hitting empty ceramic cup, pitch rising as cup fills up"],
        ["Massive crowd erupting in thunderous applause and excited cheering"],
        ["Deep rolling thunder with bright lightning strikes crackling through sky"],
        ["Aggressive dog barking and distressed cat meowing as racing car roars past at high speed"],
        ["Peaceful stream bubbling and birds singing, interrupted by sudden explosive gunshot"],
        ["Man speaking outdoors, goat bleating loudly, metal gate scraping closed, ducks quacking frantically, wind howling into microphone"],
        ["Series of loud aggressive dog barks echoing"],
        ["Multiple distinct cat meows at different pitches"],
        ["Rhythmic wooden table tapping overlaid with steady water pouring sound"],
        ["Sustained crowd applause with camera clicks and amplified male announcer voice"],
        ["Two sharp gunshots followed by panicked birds taking flight with rapid wing flaps"],
        ["Melodic human whistling harmonizing with natural birdsong"],
        ["Deep rhythmic snoring with clear breathing patterns"],
        ["Multiple racing engines revving and accelerating with sharp whistle piercing through"],
        ["Massive stadium crowd cheering as thunder crashes and lightning strikes"],
        ["Heavy helicopter blades chopping through air with engine and wind noise"],
        ["Dog barking excitedly and man shouting as race car engine roars past"],
    ],
    cache_examples="lazy",  # Example outputs are generated and cached the first time each example is used.
)

# NOTE(review): the positional 10 binds to whatever the first parameter of
# `Blocks.queue` is in the installed Gradio version (it differs between major
# versions) — confirm the intent and pass it by keyword.
gr_interface.queue(10).launch()