# Source: Hugging Face Space (page status: "Running on Zero").
# Run this script to convert the Stable Cascade (lite) model weights into diffusers pipelines.
import argparse
import os
from contextlib import nullcontext

import torch
from safetensors.torch import load_file
from transformers import (
    AutoTokenizer,
    CLIPConfig,
    CLIPImageProcessor,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)

from diffusers import (
    DDPMWuerstchenScheduler,
    StableCascadeCombinedPipeline,
    StableCascadeDecoderPipeline,
    StableCascadePriorPipeline,
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available


if is_accelerate_available():
    from accelerate import init_empty_weights
# Command-line interface. Checkpoint files are looked up as <model_path>/<stage_*_name>;
# each converted pipeline is written to its *_output_path.
parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
# Required: without it the checkpoint paths would be built from the literal string "None".
parser.add_argument("--model_path", type=str, required=True, help="Location of Stable Cascade weights")
parser.add_argument(
    "--stage_c_name", type=str, default="stage_c_lite.safetensors", help="Name of stage c checkpoint file"
)
parser.add_argument(
    "--stage_b_name", type=str, default="stage_b_lite.safetensors", help="Name of stage b checkpoint file"
)
parser.add_argument("--skip_stage_c", action="store_true", help="Skip converting stage c")
parser.add_argument("--skip_stage_b", action="store_true", help="Skip converting stage b")
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
parser.add_argument(
    "--prior_output_path",
    default="stable-cascade-prior-lite",
    type=str,
    help="Directory to save the prior pipeline to (also pushed to the Hub with --push_to_hub)",
)
parser.add_argument(
    "--decoder_output_path",
    type=str,
    default="stable-cascade-decoder-lite",
    help="Directory to save the decoder pipeline to (also pushed to the Hub with --push_to_hub)",
)
parser.add_argument(
    "--combined_output_path",
    type=str,
    default="stable-cascade-combined-lite",
    help="Directory to save the combined pipeline to (also pushed to the Hub with --push_to_hub)",
)
parser.add_argument(
    "--save_combined",
    action="store_true",
    help="Also save a combined prior + decoder pipeline (requires converting both stages)",
)
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")
args = parser.parse_args()
# Validate the stage-selection flags before doing any expensive downloads or conversions.
if args.skip_stage_b and args.skip_stage_c:
    raise ValueError("At least one stage should be converted")
if (args.skip_stage_b or args.skip_stage_c) and args.save_combined:
    raise ValueError("Cannot skip stages when creating a combined pipeline")

model_path = args.model_path
if model_path is None:
    # Guard explicitly: os.path.join(None, ...) would fail with an opaque TypeError.
    raise ValueError("--model_path must point to the directory containing the Stable Cascade checkpoints")
device = "cpu"

# Map the requested save variant to the dtype the pipelines are cast to before saving, so the
# variant label on disk matches the stored precision. No variant (or an unrecognised one)
# keeps full float32 precision, as before.
VARIANT_DTYPES = {"bf16": torch.bfloat16, "fp16": torch.float16}
dtype = VARIANT_DTYPES.get(args.variant, torch.float32)

# set paths to model weights
prior_checkpoint_path = os.path.join(model_path, args.stage_c_name)
decoder_checkpoint_path = os.path.join(model_path, args.stage_b_name)
# Shared components, reused by the prior, decoder and combined pipelines:
# the CLIP text encoder + tokenizer come from the LAION ViT-bigG-14 checkpoint.
clip_text_repo = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
config = CLIPConfig.from_pretrained(clip_text_repo)
# Propagate the top-level projection_dim into the text sub-config the text encoder is built from.
config.text_config.projection_dim = config.projection_dim
text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_text_repo, config=config.text_config)
tokenizer = AutoTokenizer.from_pretrained(clip_text_repo)

# Image conditioning: default CLIP preprocessing plus the OpenAI ViT-L/14 vision tower.
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

# One scheduler instance serves both the prior and the decoder.
scheduler = DDPMWuerstchenScheduler()

# With accelerate installed, models are constructed under init_empty_weights and their weights
# are filled in later via load_model_dict_into_meta; otherwise build them normally.
if is_accelerate_available():
    ctx = init_empty_weights
else:
    ctx = nullcontext
if not args.skip_stage_c:
    # Prior (stage C): load the original checkpoint and remap its keys to the diffusers layout.
    # Files ending in .safetensors are read with safetensors even without --use_safetensors,
    # because torch.load cannot parse that format (the default --stage_c_name is such a file).
    if args.use_safetensors or str(prior_checkpoint_path).endswith(".safetensors"):
        prior_orig_state_dict = load_file(prior_checkpoint_path, device=device)
    else:
        # SECURITY: torch.load unpickles the file and can execute arbitrary code;
        # only convert checkpoints from a trusted source.
        prior_orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)
    prior_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(prior_orig_state_dict)

    # Architecture hyper-parameters are hard-coded for the *lite* stage C model (matching the
    # default stage_c_lite checkpoint name). NOTE(review): non-lite checkpoints need other values.
    with ctx():
        prior_model = StableCascadeUNet(
            in_channels=16,
            out_channels=16,
            timestep_ratio_embedding_dim=64,
            patch_size=1,
            conditioning_dim=1536,
            block_out_channels=[1536, 1536],
            num_attention_heads=[24, 24],
            down_num_layers_per_block=[4, 12],
            up_num_layers_per_block=[12, 4],
            down_blocks_repeat_mappers=[1, 1],
            up_blocks_repeat_mappers=[1, 1],
            block_types_per_layer=[
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
            ],
            clip_text_in_channels=1280,
            clip_text_pooled_in_channels=1280,
            clip_image_in_channels=768,
            clip_seq=4,
            kernel_size=3,
            dropout=[0.1, 0.1],
            self_attn=True,
            timestep_conditioning_type=["sca", "crp"],
            switch_level=[False],
        )
    # Weights are loaded after leaving ctx(): with accelerate the model was built without real
    # weight storage, so the converted tensors are placed in via load_model_dict_into_meta.
    if is_accelerate_available():
        load_model_dict_into_meta(prior_model, prior_state_dict)
    else:
        prior_model.load_state_dict(prior_state_dict)

    # Prior pipeline
    prior_pipeline = StableCascadePriorPipeline(
        prior=prior_model,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        image_encoder=image_encoder,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
    )
    # Cast to the requested dtype and save (optionally pushing to the Hub) under the variant label.
    prior_pipeline.to(dtype).save_pretrained(
        args.prior_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )
if not args.skip_stage_b:
    # Decoder (stage B): load the original checkpoint and remap its keys to the diffusers layout.
    # Files ending in .safetensors are read with safetensors even without --use_safetensors,
    # because torch.load cannot parse that format (the default --stage_b_name is such a file).
    if args.use_safetensors or str(decoder_checkpoint_path).endswith(".safetensors"):
        decoder_orig_state_dict = load_file(decoder_checkpoint_path, device=device)
    else:
        # SECURITY: torch.load unpickles the file and can execute arbitrary code;
        # only convert checkpoints from a trusted source.
        decoder_orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)
    decoder_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(decoder_orig_state_dict)

    # Architecture hyper-parameters are hard-coded for the *lite* stage B model (matching the
    # default stage_b_lite checkpoint name). NOTE(review): non-lite checkpoints need other values.
    with ctx():
        decoder = StableCascadeUNet(
            in_channels=4,
            out_channels=4,
            timestep_ratio_embedding_dim=64,
            patch_size=2,
            conditioning_dim=1280,
            block_out_channels=[320, 576, 1152, 1152],
            down_num_layers_per_block=[2, 4, 14, 4],
            up_num_layers_per_block=[4, 14, 4, 2],
            down_blocks_repeat_mappers=[1, 1, 1, 1],
            up_blocks_repeat_mappers=[2, 2, 2, 2],
            # 0 heads on the first level, whose layers have no attention block.
            num_attention_heads=[0, 9, 18, 18],
            block_types_per_layer=[
                ["SDCascadeResBlock", "SDCascadeTimestepBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
            ],
            clip_text_pooled_in_channels=1280,
            clip_seq=4,
            effnet_in_channels=16,
            pixel_mapper_in_channels=3,
            kernel_size=3,
            dropout=[0, 0, 0.1, 0.1],
            self_attn=True,
            timestep_conditioning_type=["sca"],
        )
    # Weights are loaded after leaving ctx(): with accelerate the model was built without real
    # weight storage, so the converted tensors are placed in via load_model_dict_into_meta.
    if is_accelerate_available():
        load_model_dict_into_meta(decoder, decoder_state_dict)
    else:
        decoder.load_state_dict(decoder_state_dict)

    # VQGAN from Wuerstchen-V2
    vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")

    # Decoder pipeline
    decoder_pipeline = StableCascadeDecoderPipeline(
        decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
    )
    # Cast to the requested dtype and save (optionally pushing to the Hub) under the variant label.
    decoder_pipeline.to(dtype).save_pretrained(
        args.decoder_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )
if args.save_combined:
    # Stable Cascade combined pipeline. The earlier validation guarantees both stages were
    # converted, so decoder, vqmodel and prior_model all exist here. The decoder side and the
    # prior side share the same text encoder, tokenizer and scheduler instances.
    decoder_components = dict(
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        decoder=decoder,
        scheduler=scheduler,
        vqgan=vqmodel,
    )
    prior_components = dict(
        prior_text_encoder=text_encoder,
        prior_tokenizer=tokenizer,
        prior_prior=prior_model,
        prior_scheduler=scheduler,
        prior_image_encoder=image_encoder,
        prior_feature_extractor=feature_extractor,
    )
    stable_cascade_pipeline = StableCascadeCombinedPipeline(**decoder_components, **prior_components)

    # Cast to the requested dtype, then save (optionally pushing to the Hub) under the variant label.
    combined_pipeline_in_dtype = stable_cascade_pipeline.to(dtype)
    combined_pipeline_in_dtype.save_pretrained(
        args.combined_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )