Spaces:
Runtime error
Runtime error
File size: 6,679 Bytes
1b92e8f 2821e52 1b92e8f d833a17 1b92e8f d833a17 1b92e8f 56d047b 1b92e8f 5aeb32a 56d047b 5aeb32a 1b92e8f 56d047b bc16d21 1b92e8f 56d047b 1b92e8f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import torch
from diffusers.loaders import AttnProcsLayers
from transformers import CLIPTextModel, CLIPTokenizer
from modules.beats.BEATs import BEATs, BEATsConfig
from modules.AudioToken.embedder import FGAEmbedder
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers import StableDiffusionPipeline
import numpy as np
import gradio as gr
class AudioTokenWrapper(torch.nn.Module):
    """Simple wrapper module for Stable Diffusion that holds all the models together.

    Bundles the CLIP tokenizer and text encoder, the UNet, the VAE, the BEATs
    audio encoder and the ``FGAEmbedder``. Every sub-model is switched to eval
    mode. Stable Diffusion weights come from ``CompVis/stable-diffusion-v1-4``;
    the BEATs, embedder and (optional) LoRA weights are read from ``models/``.

    A placeholder token ``<*>`` is added to the tokenizer and the text
    encoder's embedding table is resized to make room for it.

    Args:
        lora: When truthy, install LoRA attention processors on the UNet and
            load their weights from ``models/lora_layers_learned_embeds.bin``.
        device: Used as ``map_location`` when reading the local checkpoint
            files. The wrapper itself is NOT moved to this device here.

    Raises:
        ValueError: If the tokenizer already contains the placeholder token,
            or if a UNet attention processor name has an unexpected prefix.
    """

    def __init__(
        self,
        lora,
        device,
    ):
        super().__init__()

        # Load scheduler and models
        self.tokenizer = CLIPTokenizer.from_pretrained(
            "CompVis/stable-diffusion-v1-4", subfolder="tokenizer"
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", revision=None
        )
        self.unet = UNet2DConditionModel.from_pretrained(
            "CompVis/stable-diffusion-v1-4", subfolder="unet", revision=None
        )
        self.vae = AutoencoderKL.from_pretrained(
            "CompVis/stable-diffusion-v1-4", subfolder="vae", revision=None
        )

        # map_location keeps this load consistent with the other checkpoint
        # loads below and lets a checkpoint saved from GPU tensors be read on
        # a CPU-only host.
        checkpoint = torch.load(
            'models/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt',
            map_location=device,
        )
        cfg = BEATsConfig(checkpoint['cfg'])
        self.aud_encoder = BEATs(cfg)
        self.aud_encoder.load_state_dict(checkpoint['model'])
        # Only feature extraction is used, so the prediction head is dropped.
        self.aud_encoder.predictor = None

        input_size = 768 * 3
        self.embedder = FGAEmbedder(input_size=input_size, output_size=768)

        self.vae.eval()
        self.unet.eval()
        self.text_encoder.eval()
        self.aud_encoder.eval()

        if lora:
            # Set correct lora layers
            lora_attn_procs = {}
            for name in self.unet.attn_processors.keys():
                # Self-attention (attn1) has no cross-attention dimension.
                cross_attention_dim = None if name.endswith(
                    "attn1.processor") else self.unet.config.cross_attention_dim
                if name.startswith("mid_block"):
                    hidden_size = self.unet.config.block_out_channels[-1]
                elif name.startswith("up_blocks"):
                    block_id = int(name[len("up_blocks.")])
                    hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
                elif name.startswith("down_blocks"):
                    block_id = int(name[len("down_blocks.")])
                    hidden_size = self.unet.config.block_out_channels[block_id]
                else:
                    # Without this branch an unknown prefix would silently
                    # reuse the hidden size of the previous processor.
                    raise ValueError(f"Unexpected attention processor name: {name}")
                lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size,
                                                          cross_attention_dim=cross_attention_dim)

            self.unet.set_attn_processor(lora_attn_procs)
            self.lora_layers = AttnProcsLayers(self.unet.attn_processors)
            self.lora_layers.eval()
            lora_layers_learned_embeds = 'models/lora_layers_learned_embeds.bin'
            self.lora_layers.load_state_dict(torch.load(lora_layers_learned_embeds, map_location=device))
            # NOTE(review): the same file is loaded a second time through the
            # UNet; one of the two loads looks redundant - confirm before
            # removing either.
            self.unet.load_attn_procs(lora_layers_learned_embeds)

        self.embedder.eval()
        embedder_learned_embeds = 'models/embedder_learned_embeds.bin'
        self.embedder.load_state_dict(torch.load(embedder_learned_embeds, map_location=device))

        self.placeholder_token = '<*>'
        num_added_tokens = self.tokenizer.add_tokens(self.placeholder_token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {self.placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )
        self.placeholder_token_id = self.tokenizer.convert_tokens_to_ids(self.placeholder_token)
        # Resize the token embeddings as we are adding new special tokens to the tokenizer
        self.text_encoder.resize_token_embeddings(len(self.tokenizer))
def _get_pipeline():
    """Return the Stable Diffusion pipeline, building it on first use.

    The pipeline is assembled from the tokenizer, text encoder, VAE and UNet
    held by the module-level ``model`` and moved to the module-level
    ``device``. It is cached on this function so later requests reuse it
    instead of calling ``from_pretrained`` again.
    """
    pipeline = getattr(_get_pipeline, "_pipeline", None)
    if pipeline is None:
        pipeline = StableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            tokenizer=model.tokenizer,
            text_encoder=model.text_encoder,
            vae=model.vae,
            unet=model.unet,
        ).to(device)
        _get_pipeline._pipeline = pipeline
    return pipeline


def greet(audio):
    """Generate an image from an audio clip.

    The clip is encoded with ``model.aud_encoder``, mapped by
    ``model.embedder`` to an embedding that is written into the text
    encoder's row for the placeholder token, and the prompt
    ``'a photo of <*>'`` is then rendered with Stable Diffusion.

    Relies on the module-level ``model`` and ``device`` set up by the script
    entry point.

    Args:
        audio: Value delivered by the Gradio audio input; its last element is
            taken as the sample array.

    Returns:
        The first image produced by the pipeline.

    Raises:
        ValueError: If no audio was supplied.
    """
    if audio is None:
        raise ValueError("No audio was provided; record or upload a clip first.")

    samples = audio[-1]
    if samples.ndim > 1:
        # Multi-channel clips arrive as (samples, channels) per the Gradio
        # Audio docs; average the channels so the encoder gets one waveform.
        samples = samples.mean(axis=1)
    # Scale by 2**15 to map 16-bit integer samples into [-1, 1).
    # NOTE(review): the sample rate (first element of `audio`) is discarded;
    # confirm the encoder tolerates the rate Gradio delivers.
    waveform = samples.astype(np.float32, order='C') / 32768.0

    weight_dtype = torch.float32
    prompt = 'a photo of <*>'

    audio_values = torch.unsqueeze(torch.tensor(waveform), dim=0).to(device).to(dtype=weight_dtype)

    # Inference only: avoid building an autograd graph for the encoders.
    with torch.no_grad():
        aud_features = model.aud_encoder.extract_features(audio_values)[1]
        audio_token = model.embedder(aud_features)

        # Overwrite the placeholder token's embedding in place so the prompt
        # token '<*>' carries the audio-derived embedding.
        token_embeds = model.text_encoder.get_input_embeddings().weight.data
        token_embeds[model.placeholder_token_id] = audio_token.clone()

    pipeline = _get_pipeline()
    image = pipeline(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
    return image
if __name__ == "__main__":
    # Script entry point: build the models once, then serve the Gradio demo.
    lora = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # greet() sends the waveform to `device`, so every sub-model, including
    # the audio encoder and embedder, has to live there too. Without the
    # explicit move they stay on CPU and a CUDA host hits a device mismatch.
    model = AudioTokenWrapper(lora, device).to(device)

    description = """<p>
This is a demo of <a href='https://pages.cs.huji.ac.il/adiyoss-lab/AudioToken' target='_blank'>AudioToken: Adaptation of Text-Conditioned Diffusion Models for Audio-to-Image Generation</a>.<br><br>
In recent years, image generation has shown a great leap in performance, where diffusion models play a central role. Although generating high-quality images, such models are mainly conditioned on textual descriptions. This begs the question: "how can we adopt such models to be conditioned on other modalities?". We propose a novel method utilizing latent diffusion models trained for text-to-image-generation to generate images conditioned on audio recordings. Using a pre-trained audio encoding model, the proposed method encodes audio into a new token, which can be considered as an adaptation layer between the audio and text representations. Such a modeling paradigm requires a small number of trainable parameters, making the proposed approach appealing for lightweight optimization.<br><br>
For more information, please see the original <a href='https://arxiv.org/abs/2305.13050' target='_blank'>paper</a> and <a href='https://github.com/guyyariv/AudioToken' target='_blank'>repo</a>.
</p>"""

    # Sample clips offered in the UI; each entry is one value for the single
    # audio input.
    examples = [
        ["assets/train.wav"],
        ["assets/dog barking.wav"],
        ["assets/airplane.wav"],
        ["assets/electric guitar.wav"],
        ["assets/female sings.wav"],
    ]

    demo = gr.Interface(
        fn=greet,
        inputs="audio",
        outputs="image",
        title='AudioToken',
        description=description,
        examples=examples
    )

    demo.launch()
|