import torch
import numpy as np
import gradio as gr
from scipy import signal
from diffusers.utils import logging

# Only report errors from diffusers; called before the remaining diffusers imports.
logging.set_verbosity_error()

from diffusers.loaders import AttnProcsLayers
from transformers import CLIPTextModel, CLIPTokenizer
from modules.beats.BEATs import BEATs, BEATsConfig
from modules.AudioToken.embedder import FGAEmbedder
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers import StableDiffusionPipeline
from diffusers import (
    DDPMScheduler,
    DDIMScheduler,
    PNDMScheduler,
    LMSDiscreteScheduler,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    DEISMultistepScheduler,
    UniPCMultistepScheduler,
    HeunDiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    KDPM2DiscreteScheduler,
)

# Hub repository every Stable Diffusion component is loaded from.
REPO_ID = "philz1337/reliberate"

# Scheduler name (as offered in the UI dropdown, in this order) -> diffusers class.
# AudioTokenWrapper also exposes each instance as an attribute of the same name
# (``model.ddpm``, ``model.unipc``, ...).
SCHEDULERS = {
    "ddim": DDIMScheduler,
    "ddpm": DDPMScheduler,
    "pndm": PNDMScheduler,
    "lms": LMSDiscreteScheduler,
    "euler_anc": EulerAncestralDiscreteScheduler,
    "euler": EulerDiscreteScheduler,
    "dpm": DPMSolverMultistepScheduler,
    "dpms": DPMSolverSinglestepScheduler,
    "deis": DEISMultistepScheduler,
    "unipc": UniPCMultistepScheduler,
    "heun": HeunDiscreteScheduler,
    "kdpm2_anc": KDPM2AncestralDiscreteScheduler,
    "kdpm2": KDPM2DiscreteScheduler,
}

BEATS_CHECKPOINT = "models/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt"
LORA_WEIGHTS = "models/lora_layers_learned_embeds.bin"
EMBEDDER_WEIGHTS = "models/embedder_learned_embeds.bin"

# Audio is resampled to this rate before it is passed to the BEATs encoder.
TARGET_SAMPLE_RATE = 16000

# Fixed generator seed so the same audio always yields the same image.
DEFAULT_SEED = 23229249375547


class AudioTokenWrapper(torch.nn.Module):
    """Simple wrapper module for Stable Diffusion that holds all the models together.

    Loads the tokenizer, text encoder, UNet, VAE and one instance of every
    scheduler in ``SCHEDULERS`` from ``repo_id``. It also loads the BEATs audio
    encoder and the FGA embedder from the local checkpoints under ``models/``,
    and optionally installs learned LoRA attention processors on the UNet.
    Finally it registers the ``<*>`` placeholder token, whose embedding row
    ``greet`` overwrites with the audio token.

    Args:
        lora: If true, install ``LoRAAttnProcessor``s on every UNet attention
            layer and load their weights from ``LORA_WEIGHTS``.
        device: ``map_location`` used when loading the LoRA and embedder weights.
        repo_id: Hub repository the diffusion components are loaded from.

    Raises:
        ValueError: If the tokenizer already contains the placeholder token.
    """

    def __init__(self, lora: bool, device, repo_id: str = REPO_ID):
        super().__init__()
        self.repo_id = repo_id

        # Load scheduler and models. Each scheduler is kept both in
        # ``self.schedulers`` and as an attribute named after its key.
        self.schedulers = {}
        for name, scheduler_cls in SCHEDULERS.items():
            scheduler = scheduler_cls.from_pretrained(self.repo_id, subfolder="scheduler")
            self.schedulers[name] = scheduler
            setattr(self, name, scheduler)

        self.tokenizer = CLIPTokenizer.from_pretrained(self.repo_id, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(
            self.repo_id, subfolder="text_encoder", revision=None
        )
        self.unet = UNet2DConditionModel.from_pretrained(
            self.repo_id, subfolder="unet", revision=None
        )
        self.vae = AutoencoderKL.from_pretrained(self.repo_id, subfolder="vae", revision=None)

        # Load onto the CPU: the encoder is constructed here without a device
        # and moved together with the whole wrapper afterwards, so loading must
        # not require CUDA even if the checkpoint was saved from GPU tensors.
        checkpoint = torch.load(BEATS_CHECKPOINT, map_location="cpu")
        cfg = BEATsConfig(checkpoint["cfg"])
        self.aud_encoder = BEATs(cfg)
        self.aud_encoder.load_state_dict(checkpoint["model"])
        # NOTE(review): presumably drops the classification head; only
        # extract_features() is used by greet().
        self.aud_encoder.predictor = None

        input_size = 768 * 3
        self.embedder = FGAEmbedder(input_size=input_size, output_size=768)

        self.vae.eval()
        self.unet.eval()
        self.text_encoder.eval()
        self.aud_encoder.eval()

        if lora:
            # Set correct lora layers: one LoRAAttnProcessor per UNet attention
            # processor, sized from the block the processor belongs to.
            lora_attn_procs = {}
            for name in self.unet.attn_processors:
                cross_attention_dim = (
                    None
                    if name.endswith("attn1.processor")
                    else self.unet.config.cross_attention_dim
                )
                if name.startswith("mid_block"):
                    hidden_size = self.unet.config.block_out_channels[-1]
                elif name.startswith("up_blocks"):
                    block_id = int(name[len("up_blocks.")])
                    hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
                elif name.startswith("down_blocks"):
                    block_id = int(name[len("down_blocks.")])
                    hidden_size = self.unet.config.block_out_channels[block_id]
                lora_attn_procs[name] = LoRAAttnProcessor(
                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
                )
            self.unet.set_attn_processor(lora_attn_procs)
            self.lora_layers = AttnProcsLayers(self.unet.attn_processors)
            self.lora_layers.eval()
            self.lora_layers.load_state_dict(torch.load(LORA_WEIGHTS, map_location=device))
            self.unet.load_attn_procs(LORA_WEIGHTS)

        self.embedder.eval()
        self.embedder.load_state_dict(torch.load(EMBEDDER_WEIGHTS, map_location=device))

        self.placeholder_token = "<*>"
        num_added_tokens = self.tokenizer.add_tokens(self.placeholder_token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {self.placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )
        self.placeholder_token_id = self.tokenizer.convert_tokens_to_ids(self.placeholder_token)
        # Resize the token embeddings as we are adding new special tokens to the tokenizer
        self.text_encoder.resize_token_embeddings(len(self.tokenizer))


def greet(audio, steps=25, scheduler="ddpm", seed=DEFAULT_SEED):
    """Generate an image conditioned on an audio clip.

    Args:
        audio: ``(sample_rate, samples)`` tuple as delivered by the ``"audio"``
            Gradio input. ``samples`` is 1-D (mono) or 2-D with channels on
            axis 1. Integer PCM is scaled to [-1, 1) by its full-scale value;
            floating-point samples are used as-is.
        steps: Number of diffusion inference steps (cast to ``int``).
        scheduler: Key of ``SCHEDULERS`` selecting the sampler.
        seed: Seed for the torch generator; the default keeps results fixed.

    Returns:
        The first entry of the pipeline's ``images`` output.

    Raises:
        gr.Error: If no (or empty) audio was supplied.
        ValueError: If ``scheduler`` is not a key of ``SCHEDULERS``.
    """
    if audio is None:
        raise gr.Error("Please provide an audio recording.")
    sample_rate, audio = audio
    if audio.size == 0:
        raise gr.Error("Please provide an audio recording.")

    try:
        use_sched = model.schedulers[scheduler]
    except KeyError:
        raise ValueError(
            f"Unknown scheduler {scheduler!r}; expected one of {list(SCHEDULERS)}"
        ) from None

    if np.issubdtype(audio.dtype, np.integer):
        # Full-scale PCM -> [-1, 1): 2**15 = 32768 for int16, 2**31 for int32.
        full_scale = float(2 ** (8 * audio.dtype.itemsize - 1))
        audio = audio.astype(np.float32, order="C") / full_scale
    else:
        audio = audio.astype(np.float32, order="C")

    if audio.ndim == 2:
        # Downmix to mono by averaging the channels.
        audio = audio.mean(axis=1)

    if sample_rate != TARGET_SAMPLE_RATE:
        resample_ratio = TARGET_SAMPLE_RATE / sample_rate
        new_length = int(len(audio) * resample_ratio)
        # FFT-based resampling to the encoder's expected rate.
        audio = signal.resample(audio, new_length)

    weight_dtype = torch.float32
    prompt = f"a photo of {model.placeholder_token}"
    # Batch of one mono clip: shape (1, num_samples).
    audio_values = torch.tensor(audio).unsqueeze(0).to(device=device, dtype=weight_dtype)

    # NOTE(review): the original author observed that running under no_grad
    # was needed for deterministic results; the reason is unconfirmed.
    with torch.no_grad():
        aud_features = model.aud_encoder.extract_features(audio_values)[1]
        audio_token = model.embedder(aud_features)

        # Overwrite the placeholder's embedding row in place so "<*>" in the
        # prompt is encoded as the audio token.
        token_embeds = model.text_encoder.get_input_embeddings().weight.data
        token_embeds[model.placeholder_token_id] = audio_token.clone()

        generator = torch.Generator(device=device)
        generator.manual_seed(seed)

        # NOTE(review): the pipeline wrapper is rebuilt on every request;
        # consider caching it and only swapping ``pipeline.scheduler``.
        pipeline = StableDiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path=model.repo_id,
            tokenizer=model.tokenizer,
            text_encoder=model.text_encoder,
            vae=model.vae,
            unet=model.unet,
            scheduler=use_sched,
            safety_checker=None,
        ).to(device)
        try:
            pipeline.enable_xformers_memory_efficient_attention()
        except (ModuleNotFoundError, ValueError):
            # xformers is optional: diffusers raises ModuleNotFoundError when it
            # is not installed and ValueError when CUDA is unavailable (e.g. the
            # CPU fallback chosen below). Keep the default attention instead.
            pass
        image = pipeline(
            prompt,
            num_inference_steps=int(steps),
            guidance_scale=8.5,
            generator=generator,
        ).images[0]
    return image


lora = False
repo_id = REPO_ID
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = AudioTokenWrapper(lora, device, repo_id=repo_id)
model = model.to(device)

description = """

This is a demo of AudioToken: Adaptation of Text-Conditioned Diffusion Models for Audio-to-Image Generation.

A novel method utilizing latent diffusion models trained for text-to-image-generation to generate images conditioned on audio recordings. Using a pre-trained audio encoding model, the proposed method encodes audio into a new token, which can be considered as an adaptation layer between the audio and text representations.

For more information, please see the original paper and repo.

"""

examples = [
    # ["assets/train.wav"],
    # ["assets/dog barking.wav"],
    # ["assets/airplane taking off.wav"],
    # ["assets/electric guitar.wav"],
    # ["assets/female sings.wav"],
]

my_demo = gr.Interface(
    fn=greet,
    inputs=[
        "audio",
        # At least one step: the schedulers cannot run with zero steps.
        gr.Slider(minimum=1, maximum=100, value=25, step=1, label="diffusion steps"),
        gr.Dropdown(choices=list(SCHEDULERS), value="unipc"),
    ],
    outputs="image",
    title="AudioToken",
    description=description,
    examples=examples,
)
my_demo.launch()