"""Convert the original Wuerstchen checkpoints into diffusers pipelines.

Run inside the root directory of the official source code,
https://github.com/dome272/wuerstchen/, so that ``vqgan.VQModel`` is importable.

The original checkpoints are read from ``models/``:

- ``vqgan_f4_v1_500k.pt``: VQGAN weights (under key ``"state_dict"``)
- ``model_v2_stage_b.pt``: decoder weights (under key ``"state_dict"``)
- ``model_v3_stage_c.pt``: prior weights (under key ``"ema_state_dict"``)

Three pipelines are written with ``save_pretrained``:
``warp-ai/wuerstchen-prior``, ``warp-ai/wuerstchen`` and
``warp-ai/WuerstchenCombinedPipeline``.
"""

import os

import torch
from transformers import AutoTokenizer, CLIPTextModel
from vqgan import VQModel

from diffusers import (
    DDPMWuerstchenScheduler,
    WuerstchenCombinedPipeline,
    WuerstchenDecoderPipeline,
    WuerstchenPriorPipeline,
)
from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior

MODEL_PATH = "models/"
DEVICE = "cpu"

# CLIP text encoder / tokenizer handed to the prior pipeline.
PRIOR_CLIP_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
# CLIP text encoder / tokenizer handed to the decoder ("generator") pipeline.
DECODER_CLIP_ID = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"

# Fused q/k/v projection parameters (torch.nn.MultiheadAttention naming); each is
# split into separate to_q / to_k / to_v entries.
_FUSED_QKV_SUFFIXES = ("attn.in_proj_weight", "attn.in_proj_bias")
# Output projection parameters, renamed one-to-one.
_OUT_PROJ_RENAMES = {
    "attn.out_proj.weight": "to_out.0.weight",
    "attn.out_proj.bias": "to_out.0.bias",
}


def _load_checkpoint(filename, key):
    """Load ``MODEL_PATH/filename`` and return the state dict stored under ``key``."""
    # NOTE: torch.load unpickles the file; only run this on checkpoints you trust.
    checkpoint = torch.load(os.path.join(MODEL_PATH, filename), map_location=DEVICE)
    return checkpoint[key]


def convert_attention_state_dict(orig_state_dict):
    """Rewrite fused attention parameters into the separate-projection layout.

    Keys ending in ``attn.in_proj_weight`` / ``attn.in_proj_bias`` are chunked
    into three parts along dim 0 and stored as ``to_q``, ``to_k``, ``to_v``
    (same prefix, same ``weight``/``bias`` suffix). Keys ending in
    ``attn.out_proj.weight`` / ``attn.out_proj.bias`` are renamed to
    ``to_out.0.weight`` / ``to_out.0.bias``. Every other entry is copied as-is.

    Only the trailing suffix is matched and rewritten, so a key that merely
    ends in ``in_proj_weight`` (without ``attn.``) is passed through intact
    instead of having its three chunks overwrite each other under one name.

    Args:
        orig_state_dict: Mapping of parameter name to tensor.

    Returns:
        A new dict with the converted names; the input is not modified.
    """
    state_dict = {}
    for key, value in orig_state_dict.items():
        if key.endswith(_FUSED_QKV_SUFFIXES):
            # "<prefix>attn.in_proj_weight" -> ("<prefix>", "weight")
            prefix, param = key.rsplit("attn.in_proj_", 1)
            # Unpacking raises if dim 0 does not split into exactly three chunks.
            q, k, v = value.chunk(3, 0)
            state_dict[f"{prefix}to_q.{param}"] = q
            state_dict[f"{prefix}to_k.{param}"] = k
            state_dict[f"{prefix}to_v.{param}"] = v
            continue
        for old_suffix, new_suffix in _OUT_PROJ_RENAMES.items():
            if key.endswith(old_suffix):
                key = key[: -len(old_suffix)] + new_suffix
                break
        state_dict[key] = value
    return state_dict


def main():
    """Convert every checkpoint and save the prior, decoder and combined pipelines."""
    # VQGAN: load into the reference implementation first (load_state_dict is
    # strict by default, so a mismatched checkpoint fails here), then rename
    # the codebook entry for PaellaVQModel.
    paella_vqmodel = VQModel()
    state_dict = _load_checkpoint("vqgan_f4_v1_500k.pt", "state_dict")
    paella_vqmodel.load_state_dict(state_dict)
    state_dict["vquantizer.embedding.weight"] = state_dict.pop("vquantizer.codebook.weight")
    vqmodel = PaellaVQModel(
        num_vq_embeddings=paella_vqmodel.codebook_size,
        latent_channels=paella_vqmodel.c_latent,
    )
    vqmodel.load_state_dict(state_dict)

    # CLIP text encoder and tokenizer for the prior.
    text_encoder = CLIPTextModel.from_pretrained(PRIOR_CLIP_ID)
    tokenizer = AutoTokenizer.from_pretrained(PRIOR_CLIP_ID)

    # CLIP text encoder and tokenizer for the decoder (generator).
    gen_text_encoder = CLIPTextModel.from_pretrained(DECODER_CLIP_ID).to(DEVICE)
    gen_tokenizer = AutoTokenizer.from_pretrained(DECODER_CLIP_ID)

    # Decoder
    decoder = WuerstchenDiffNeXt()
    decoder.load_state_dict(
        convert_attention_state_dict(_load_checkpoint("model_v2_stage_b.pt", "state_dict"))
    )

    # Prior (weights taken from the "ema_state_dict" entry of the checkpoint)
    prior_model = WuerstchenPrior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(DEVICE)
    prior_model.load_state_dict(
        convert_attention_state_dict(_load_checkpoint("model_v3_stage_c.pt", "ema_state_dict"))
    )

    # One scheduler instance is shared by all pipelines, as before.
    scheduler = DDPMWuerstchenScheduler()

    # Prior pipeline
    prior_pipeline = WuerstchenPriorPipeline(
        prior=prior_model, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
    )
    prior_pipeline.save_pretrained("warp-ai/wuerstchen-prior")

    # Decoder pipeline
    decoder_pipeline = WuerstchenDecoderPipeline(
        text_encoder=gen_text_encoder,
        tokenizer=gen_tokenizer,
        vqgan=vqmodel,
        decoder=decoder,
        scheduler=scheduler,
    )
    decoder_pipeline.save_pretrained("warp-ai/wuerstchen")

    # Wuerstchen (combined) pipeline
    wuerstchen_pipeline = WuerstchenCombinedPipeline(
        # Decoder
        text_encoder=gen_text_encoder,
        tokenizer=gen_tokenizer,
        decoder=decoder,
        scheduler=scheduler,
        vqgan=vqmodel,
        # Prior
        prior_tokenizer=tokenizer,
        prior_text_encoder=text_encoder,
        prior=prior_model,
        prior_scheduler=scheduler,
    )
    wuerstchen_pipeline.save_pretrained("warp-ai/WuerstchenCombinedPipeline")


if __name__ == "__main__":
    main()