import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor

from modules.AudioToken.embedder import FGAEmbedder
from modules.BEATs.BEATs import BEATs, BEATsConfig


class AudioTokenWrapper(torch.nn.Module):
"""Simple wrapper module for Stable Diffusion that holds all the models together"""
def __init__(
self,
args,
accelerator,
):
super().__init__()
# Load scheduler and models
from modules.clip_text_model.modeling_clip import CLIPTextModel
self.text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
self.unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
self.vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
)
checkpoint = torch.load(
'models/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt')
cfg = BEATsConfig(checkpoint['cfg'])
self.aud_encoder = BEATs(cfg)
self.aud_encoder.load_state_dict(checkpoint['model'])
self.aud_encoder.predictor = None
input_size = 768 * 3
if args.pretrained_model_name_or_path == "CompVis/stable-diffusion-v1-4":
self.embedder = FGAEmbedder(input_size=input_size, output_size=768)
else:
self.embedder = FGAEmbedder(input_size=input_size, output_size=1024)
self.vae.eval()
self.unet.eval()
self.text_encoder.eval()
self.aud_encoder.eval()
if 'lora' in args and args.lora:
# Set correct lora layers
lora_attn_procs = {}
for name in self.unet.attn_processors.keys():
cross_attention_dim = None if name.endswith(
"attn1.processor") else self.unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = self.unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = self.unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim)
self.unet.set_attn_processor(lora_attn_procs)
self.lora_layers = AttnProcsLayers(self.unet.attn_processors)
if args.data_set == 'train':
# Freeze vae, unet, text_enc and aud_encoder
self.vae.requires_grad_(False)
self.unet.requires_grad_(False)
self.text_encoder.requires_grad_(False)
self.aud_encoder.requires_grad_(False)
self.embedder.requires_grad_(True)
self.embedder.train()
if 'lora' in args and args.lora:
self.unet.train()
if args.data_set == 'test':
from transformers import CLIPTextModel
self.text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
self.embedder.eval()
embedder_learned_embeds = args.learned_embeds
self.embedder.load_state_dict(torch.load(embedder_learned_embeds, map_location=accelerator.device))
if 'lora' in args and args.lora:
self.lora_layers.eval()
lora_layers_learned_embeds = args.lora_learned_embeds
self.lora_layers.load_state_dict(torch.load(lora_layers_learned_embeds, map_location=accelerator.device))
self.unet.load_attn_procs(lora_layers_learned_embeds)
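

if __name__ == "__main__":
    # Hedged usage sketch (an assumption, not part of the original training or
    # inference scripts): the wrapper is driven by an argparse-style namespace
    # and an accelerate Accelerator. The attribute names below match what
    # __init__ reads; the concrete values are illustrative only.
    from argparse import Namespace

    from accelerate import Accelerator

    example_args = Namespace(
        pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
        revision=None,
        lora=False,        # set True to attach LoRA attention processors to the UNet
        data_set="train",  # "train" freezes everything except the FGA embedder
    )
    model = AudioTokenWrapper(example_args, Accelerator())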