Spaces:
No application file
No application file
import argparse | |
import inspect | |
import os | |
import numpy as np | |
import torch | |
from torch.nn import functional as F | |
from transformers import CLIPConfig, CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5Tokenizer | |
from diffusers import DDPMScheduler, IFPipeline, IFSuperResolutionPipeline, UNet2DConditionModel | |
from diffusers.pipelines.deepfloyd_if.safety_checker import IFSafetyChecker | |
try: | |
from omegaconf import OmegaConf | |
except ImportError: | |
raise ImportError( | |
"OmegaConf is required to convert the IF checkpoints. Please install it with `pip install" " OmegaConf`." | |
) | |
def parse_args():
    """Define and parse the command-line interface of the IF conversion script.

    Only ``--p_head_path`` and ``--w_head_path`` are mandatory. Every other option
    defaults to ``None``; ``main`` skips any stage whose inputs were not supplied.

    Returns:
        The ``argparse.Namespace`` parsed from ``sys.argv``.
    """
    cli = argparse.ArgumentParser()

    # One output directory per pipeline stage.
    for flag in ("--dump_path", "--dump_path_stage_2", "--dump_path_stage_3"):
        cli.add_argument(flag, required=False, default=None, type=str)

    cli.add_argument("--unet_config", required=False, default=None, type=str, help="Path to unet config file")

    # Original UNet checkpoints for stages 1, 2 and 3.
    checkpoint_flags = (
        ("--unet_checkpoint_path", "Path to unet checkpoint file"),
        ("--unet_checkpoint_path_stage_2", "Path to stage 2 unet checkpoint file"),
        ("--unet_checkpoint_path_stage_3", "Path to stage 3 unet checkpoint file"),
    )
    for flag, description in checkpoint_flags:
        cli.add_argument(flag, required=False, default=None, type=str, help=description)

    # Safety-checker head weights are always needed.
    for flag in ("--p_head_path", "--w_head_path"):
        cli.add_argument(flag, type=str, required=True)

    return cli.parse_args()
def main(args):
    """Convert every IF stage for which the CLI supplied enough inputs.

    The T5 tokenizer/encoder, the CLIP image processor and the safety checker are
    built once and shared by all converted pipelines. Stage 1 needs ``unet_config``,
    ``unet_checkpoint_path`` and ``dump_path``. Stages 2 and 3 each need their own
    checkpoint path and dump path.
    """
    tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
    text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl")
    feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
    safety_checker = convert_safety_checker(p_head_path=args.p_head_path, w_head_path=args.w_head_path)
    shared_components = (tokenizer, text_encoder, feature_extractor, safety_checker)

    stage_1_inputs = (args.unet_config, args.unet_checkpoint_path, args.dump_path)
    if all(value is not None for value in stage_1_inputs):
        convert_stage_1_pipeline(*shared_components, args)

    for stage in (2, 3):
        checkpoint_path = getattr(args, f"unet_checkpoint_path_stage_{stage}")
        output_path = getattr(args, f"dump_path_stage_{stage}")
        if checkpoint_path is not None and output_path is not None:
            convert_super_res_pipeline(*shared_components, args, stage=stage)
def convert_stage_1_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args):
    """Build the stage-1 ``IFPipeline`` from the original checkpoint and save it to ``args.dump_path``.

    Args:
        tokenizer, text_encoder, feature_extractor, safety_checker: shared pipeline components.
        args: parsed CLI namespace; reads ``unet_config``, ``unet_checkpoint_path`` and ``dump_path``.
    """
    stage_1_unet = get_stage_1_unet(args.unet_config, args.unet_checkpoint_path)

    # Epsilon-prediction DDPM with a learned variance range and thresholding
    # (ratio 0.95, sample_max_value 1.5).
    noise_scheduler = DDPMScheduler(
        variance_type="learned_range",
        beta_schedule="squaredcos_cap_v2",
        prediction_type="epsilon",
        thresholding=True,
        dynamic_thresholding_ratio=0.95,
        sample_max_value=1.5,
    )

    pipeline = IFPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=stage_1_unet,
        scheduler=noise_scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        requires_safety_checker=True,
    )
    pipeline.save_pretrained(args.dump_path)
def convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage):
    """Build a super-resolution ``IFSuperResolutionPipeline`` for stage 2 or 3 and save it.

    Args:
        tokenizer, text_encoder, feature_extractor, safety_checker: shared pipeline components.
        args: parsed CLI namespace. For stage N it reads ``unet_checkpoint_path_stage_N``
            and ``dump_path_stage_N``.
        stage: 2 or 3.

    Raises:
        ValueError: if ``stage`` is neither 2 nor 3.
    """
    if stage == 2:
        unet_checkpoint_path = args.unet_checkpoint_path_stage_2
        sample_size = None
        dump_path = args.dump_path_stage_2
    elif stage == 3:
        unet_checkpoint_path = args.unet_checkpoint_path_stage_3
        # Overrides the sample size taken from the stage-3 config in get_super_res_unet.
        sample_size = 1024
        dump_path = args.dump_path_stage_3
    else:
        # Explicit raise rather than `assert False`: asserts are stripped under
        # `python -O`, which would turn this into an UnboundLocalError below.
        raise ValueError(f"`stage` must be 2 or 3, got {stage!r}")

    unet = get_super_res_unet(unet_checkpoint_path, verify_param_count=False, sample_size=sample_size)

    # Scheduler used to noise the low-resolution conditioning image.
    image_noising_scheduler = DDPMScheduler(
        beta_schedule="squaredcos_cap_v2",
    )

    scheduler = DDPMScheduler(
        variance_type="learned_range",
        beta_schedule="squaredcos_cap_v2",
        prediction_type="epsilon",
        thresholding=True,
        dynamic_thresholding_ratio=0.95,
        sample_max_value=1.0,
    )

    pipe = IFSuperResolutionPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        image_noising_scheduler=image_noising_scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        requires_safety_checker=True,
    )

    pipe.save_pretrained(dump_path)
def get_stage_1_unet(unet_config, unet_checkpoint_path):
    """Build the stage-1 ``UNet2DConditionModel`` and load the converted original weights into it.

    Args:
        unet_config: path to the original OmegaConf config. Its ``params`` node describes the UNet.
        unet_checkpoint_path: path to the original ``torch.load``-able state dict.

    Returns:
        The UNet with converted weights loaded.
    """
    original_unet_config = OmegaConf.load(unet_config)
    original_unet_config = original_unet_config.params

    unet_diffusers_config = create_unet_diffusers_config(original_unet_config)

    unet = UNet2DConditionModel(**unet_diffusers_config)

    # Load on the CPU. `unet` is never moved to another device here, and
    # `load_state_dict` copies the values into its existing parameters. Materialising
    # the whole checkpoint on the GPU first would only waste (and possibly exhaust)
    # device memory.
    # NOTE(review): torch.load unpickles the file -- only convert checkpoints you trust.
    unet_checkpoint = torch.load(unet_checkpoint_path, map_location="cpu")

    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
        unet_checkpoint, unet_diffusers_config, path=unet_checkpoint_path
    )

    unet.load_state_dict(converted_unet_checkpoint)

    return unet
def convert_safety_checker(p_head_path, w_head_path):
    """Assemble an ``IFSafetyChecker`` from the original p/w head arrays plus CLIP ViT-L/14 weights.

    Args:
        p_head_path: path to an ``np.load``-able archive with ``weights`` and ``biases`` arrays.
        w_head_path: same layout, for the w head.

    Returns:
        An ``IFSafetyChecker`` with all weights loaded.
    """
    state_dict = {}
    state_dict.update(_load_safety_head(p_head_path, "p_head"))
    state_dict.update(_load_safety_head(w_head_path, "w_head"))

    # The vision tower comes from the public CLIP checkpoint, nested under `vision_model.`.
    vision_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
    for key, value in vision_model.state_dict().items():
        state_dict[f"vision_model.{key}"] = value

    config = CLIPConfig.from_pretrained("openai/clip-vit-large-patch14")
    safety_checker = IFSafetyChecker(config)
    safety_checker.load_state_dict(state_dict)

    return safety_checker


def _load_safety_head(path, prefix):
    """Read one head's ``weights``/``biases`` arrays and return them as ``{prefix}.weight``/``{prefix}.bias`` tensors."""
    # np.load on an archive returns an NpzFile that holds the file open until closed.
    # The context manager releases it once both arrays have been read into memory.
    with np.load(path) as archive:
        weight = torch.from_numpy(archive["weights"]).unsqueeze(0)
        bias = torch.from_numpy(archive["biases"]).unsqueeze(0)
    return {f"{prefix}.weight": weight, f"{prefix}.bias": bias}
def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
    """Derive ``UNet2DConditionModel`` constructor kwargs from an original IF stage-1 UNet config.

    Args:
        original_unet_config: the ``params`` node of the original OmegaConf config.
        class_embed_type: explicit class-embedding type. When ``None`` it is inferred
            from ``num_classes`` (``"sequential"`` becomes ``"projection"``).

    Returns:
        dict of keyword arguments for ``UNet2DConditionModel``.

    Raises:
        NotImplementedError: for a ``num_classes`` other than ``"sequential"``. This is
            only checked when ``class_embed_type`` is ``None``.
    """
    cfg = original_unet_config

    # Express attention resolutions as image_size // res so they can be compared with
    # the per-level factor 1, 2, 4, ... used below.
    attention_factors = [cfg.image_size // int(res) for res in parse_list(cfg.attention_resolutions)]
    block_out_channels = [cfg.model_channels * mult for mult in parse_list(cfg.channel_mult)]

    def choose_block(factor, attention_block, updown_block, plain_block):
        # Attention takes priority; otherwise `resblock_updown` selects the resnet resampler.
        if factor in attention_factors:
            return attention_block
        return updown_block if cfg.resblock_updown else plain_block

    level_factors = [2**level for level in range(len(block_out_channels))]
    down_block_types = [
        choose_block(factor, "SimpleCrossAttnDownBlock2D", "ResnetDownsampleBlock2D", "DownBlock2D")
        for factor in level_factors
    ]
    # The up path visits the same levels, last level first.
    up_block_types = [
        choose_block(factor, "SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D", "UpBlock2D")
        for factor in reversed(level_factors)
    ]

    head_dim = cfg.num_head_channels
    use_linear_projection = cfg.use_linear_in_transformer if "use_linear_in_transformer" in cfg else False
    if use_linear_projection and head_dim is None:
        # stable diffusion 2-base-512 and 2-768
        head_dim = [5, 10, 20, 20]

    projection_class_embeddings_input_dim = None
    if class_embed_type is None and "num_classes" in cfg:
        if cfg.num_classes == "sequential":
            class_embed_type = "projection"
            assert "adm_in_channels" in cfg
            projection_class_embeddings_input_dim = cfg.adm_in_channels
        else:
            raise NotImplementedError(f"Unknown conditional unet num_classes config: {cfg.num_classes}")

    config = {
        "sample_size": cfg.image_size,
        "in_channels": cfg.in_channels,
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": cfg.num_res_blocks,
        "cross_attention_dim": cfg.encoder_channels,
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
        "out_channels": cfg.out_channels,
        "up_block_types": tuple(up_block_types),
        "upcast_attention": False,  # TODO: guessing
        "cross_attention_norm": "group_norm",
        "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
        "addition_embed_type": "text",
        "act_fn": "gelu",
    }

    if cfg.use_scale_shift_norm:
        config["resnet_time_scale_shift"] = "scale_shift"

    if "encoder_dim" in cfg:
        config["encoder_hid_dim"] = cfg.encoder_dim

    return config
def convert_ldm_unet_checkpoint(unet_state_dict, config, path=None):
    """
    Takes a state dict and a config, and returns a converted checkpoint.

    Renames the original stage-1 UNet parameters (``time_embed.*``, ``label_emb.*``,
    ``input_blocks.*``, ``middle_block.*``, ``output_blocks.*``, ``out.*``,
    ``encoder_proj.*``, ``encoder_pooling.*``) to the diffusers ``UNet2DConditionModel``
    key layout.

    Args:
        unet_state_dict: the original state dict. NOTE: it is mutated. Downsampling
            convs, ``encoder_proj.*`` and ``encoder_pooling.*`` entries are popped here,
            and fused ``qkv``/``encoder_kv`` attention entries are popped by
            ``assign_attention_to_checkpoint``.
        config: diffusers config dict. This function reads ``class_embed_type``,
            ``layers_per_block`` and ``down_block_types``. The attention helper reads
            ``attention_head_dim`` and ``only_cross_attention``.
        path: not used by this function.

    Returns:
        A new dict mapping diffusers parameter names to tensors.

    Raises:
        NotImplementedError: for an unsupported ``config["class_embed_type"]``.
        KeyError: if an expected original key is missing.
    """
    new_checkpoint = {}

    # Timestep embedding: original `time_embed.0` / `time_embed.2` -> linear_1 / linear_2.
    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    if config["class_embed_type"] in [None, "identity"]:
        # No parameters to port
        ...
    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
        # Both variants are read from the original `label_emb.0` module.
        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
    else:
        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    # input_blocks.0 is the input conv; `out.0` / `out.2` become the output norm / conv.
    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    # (counts distinct "input_blocks.<n>" prefixes, then groups keys by block index)
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    # NOTE: unlike the input/output lookups, this match has no trailing "." -- it would
    # only become ambiguous with 10+ middle blocks.
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
        for layer_id in range(num_output_blocks)
    }

    # input_blocks.0 was mapped to conv_in above, so the down path starts at index 1.
    for i in range(1, num_input_blocks):
        # Each diffusers down block spans `layers_per_block + 1` original input blocks.
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        # Sub-module 0 is the resnet (excluding a plain `.op` downsampling conv);
        # sub-module 1, if present, is the attention layer.
        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        # A plain downsampling conv (`.op`) maps directly onto the block's downsampler conv.
        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)

        # TODO need better check than i in [4, 8, 12, 16]
        # NOTE(review): with the indexing above, these indices are the extra (last) slot
        # of each down block only when layers_per_block == 3 -- confirm for other configs.
        block_type = config["down_block_types"][block_id]
        if (block_type == "ResnetDownsampleBlock2D" or block_type == "SimpleCrossAttnDownBlock2D") and i in [
            4,
            8,
            12,
            16,
        ]:
            # Resnet-style downsampler: its resnet weights go under `downsamplers.0`.
            meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.downsamplers.0"}
        else:
            meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}

        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            old_path = f"input_blocks.{i}.1"
            new_path = f"down_blocks.{block_id}.attentions.{layer_in_block_id}"

            # Split (and pop) the fused qkv / encoder_kv tensors first; the remaining
            # attention parameters are renamed just below.
            assign_attention_to_checkpoint(
                new_checkpoint=new_checkpoint,
                unet_state_dict=unet_state_dict,
                old_path=old_path,
                new_path=new_path,
                config=config,
            )

            paths = renew_attention_paths(attentions)
            meta_path = {"old": old_path, "new": new_path}
            assign_to_checkpoint(
                paths,
                new_checkpoint,
                unet_state_dict,
                additional_replacements=[meta_path],
                config=config,
            )

    # Middle block layout: 0 = resnet, 1 = attention, 2 = resnet.
    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    old_path = "middle_block.1"
    new_path = "mid_block.attentions.0"

    assign_attention_to_checkpoint(
        new_checkpoint=new_checkpoint,
        unet_state_dict=unet_state_dict,
        old_path=old_path,
        new_path=new_path,
        config=config,
    )

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(num_output_blocks):
        # Each diffusers up block spans `layers_per_block + 1` original output blocks.
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        # Strip the "output_blocks.<i>" prefix, then group parameter names by sub-module index.
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        # len(output_block_list) == 1 -> resnet
        # len(output_block_list) == 2 -> resnet, attention
        # len(output_block_list) == 3 -> resnet, attention, upscale resnet

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}

            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            # A sub-module whose only parameters are `conv.bias` and `conv.weight` is a
            # plain upsampling conv; map it straight onto the block's upsampler.
            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                # (When that conv sits at sub-module 1, `attentions` holds exactly its two
                # conv keys, so they must not be treated as an attention layer.)
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                old_path = f"output_blocks.{i}.1"
                new_path = f"up_blocks.{block_id}.attentions.{layer_in_block_id}"

                assign_attention_to_checkpoint(
                    new_checkpoint=new_checkpoint,
                    unet_state_dict=unet_state_dict,
                    old_path=old_path,
                    new_path=new_path,
                    config=config,
                )

                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": old_path,
                    "new": new_path,
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )

            # A third sub-module is a resnet-style upsampler.
            if len(output_block_list) == 3:
                resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.2" in key]
                paths = renew_resnet_paths(resnets)
                meta_path = {"old": f"output_blocks.{i}.2", "new": f"up_blocks.{block_id}.upsamplers.0"}
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            # Resnet-only output block: copy each renamed key directly.
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    # Optional `encoder_proj` -> `encoder_hid_proj`.
    if "encoder_proj.weight" in unet_state_dict:
        new_checkpoint["encoder_hid_proj.weight"] = unet_state_dict.pop("encoder_proj.weight")
        new_checkpoint["encoder_hid_proj.bias"] = unet_state_dict.pop("encoder_proj.bias")

    # Optional text pooling: encoder_pooling.{0,1,2,3} -> add_embedding.{norm1,pool,proj,norm2}.
    if "encoder_pooling.0.weight" in unet_state_dict:
        new_checkpoint["add_embedding.norm1.weight"] = unet_state_dict.pop("encoder_pooling.0.weight")
        new_checkpoint["add_embedding.norm1.bias"] = unet_state_dict.pop("encoder_pooling.0.bias")

        new_checkpoint["add_embedding.pool.positional_embedding"] = unet_state_dict.pop(
            "encoder_pooling.1.positional_embedding"
        )
        new_checkpoint["add_embedding.pool.k_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.k_proj.weight")
        new_checkpoint["add_embedding.pool.k_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.k_proj.bias")
        new_checkpoint["add_embedding.pool.q_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.q_proj.weight")
        new_checkpoint["add_embedding.pool.q_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.q_proj.bias")
        new_checkpoint["add_embedding.pool.v_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.v_proj.weight")
        new_checkpoint["add_embedding.pool.v_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.v_proj.bias")

        new_checkpoint["add_embedding.proj.weight"] = unet_state_dict.pop("encoder_pooling.2.weight")
        new_checkpoint["add_embedding.proj.bias"] = unet_state_dict.pop("encoder_pooling.2.bias")

        new_checkpoint["add_embedding.norm2.weight"] = unet_state_dict.pop("encoder_pooling.3.weight")
        new_checkpoint["add_embedding.norm2.bias"] = unet_state_dict.pop("encoder_pooling.3.bias")

    return new_checkpoint
def shave_segments(path, n_shave_prefix_segments=1):
    """Drop dot-separated segments from ``path``.

    A non-negative ``n_shave_prefix_segments`` removes that many leading segments.
    A negative value removes ``abs(n_shave_prefix_segments)`` trailing segments.

    >>> shave_segments("a.b.c", 1)
    'b.c'
    >>> shave_segments("a.b.c", -1)
    'a.b'
    """
    segments = path.split(".")
    if n_shave_prefix_segments < 0:
        kept = segments[:n_shave_prefix_segments]
    else:
        kept = segments[n_shave_prefix_segments:]
    return ".".join(kept)
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """Map original ResBlock parameter names to diffusers resnet names (local renaming only).

    Args:
        old_list: iterable of original parameter names.
        n_shave_prefix_segments: passed to ``shave_segments`` for each renamed name.

    Returns:
        A list of ``{"old": original_name, "new": renamed_name}`` dicts, in input order.
    """
    renames = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )

    def rename(name):
        # Substring replacements are applied in the fixed order above.
        for before, after in renames:
            name = name.replace(before, after)
        return shave_segments(name, n_shave_prefix_segments=n_shave_prefix_segments)

    return [{"old": original, "new": rename(original)} for original in old_list]
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """Map original attention parameter names to diffusers attention names (local renaming only).

    Names containing ``qkv`` or ``encoder_kv`` are skipped. Those fused tensors are
    split by ``assign_attention_to_checkpoint`` instead.

    Args:
        old_list: iterable of original parameter names.
        n_shave_prefix_segments: passed to ``shave_segments`` for each renamed name.

    Returns:
        A list of ``{"old": original_name, "new": renamed_name}`` dicts for the kept names.
    """
    renames = (
        ("norm.weight", "group_norm.weight"),
        ("norm.bias", "group_norm.bias"),
        ("proj_out.weight", "to_out.0.weight"),
        ("proj_out.bias", "to_out.0.bias"),
        ("norm_encoder.weight", "norm_cross.weight"),
        ("norm_encoder.bias", "norm_cross.bias"),
    )

    mapping = []
    for original in old_list:
        if "qkv" in original or "encoder_kv" in original:
            continue
        renamed = original
        for before, after in renames:
            renamed = renamed.replace(before, after)
        renamed = shave_segments(renamed, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": original, "new": renamed})
    return mapping
def assign_attention_to_checkpoint(new_checkpoint, unet_state_dict, old_path, new_path, config):
    """Split an original fused attention projection into diffusers q/k/v and added k/v weights.

    Pops ``{old_path}.qkv.*`` and ``{old_path}.encoder_kv.*`` from ``unet_state_dict``.
    Weights are sliced with ``[:, :, 0]`` before splitting, and the pieces are written
    into ``new_checkpoint`` under ``{new_path}.to_q/to_k/to_v`` and
    ``{new_path}.add_k_proj/add_v_proj``. When ``config["only_cross_attention"]`` is
    truthy, the ``qkv`` tensors yield only ``to_q``.
    """
    fused_weight = unet_state_dict.pop(f"{old_path}.qkv.weight")[:, :, 0]
    fused_bias = unet_state_dict.pop(f"{old_path}.qkv.bias")

    cross_only = "only_cross_attention" in config and config["only_cross_attention"]

    weight_parts, bias_parts = split_attentions(
        weight=fused_weight,
        bias=fused_bias,
        split=1 if cross_only else 3,
        chunk_size=config["attention_head_dim"],
    )

    if cross_only:
        new_checkpoint[f"{new_path}.to_q.weight"] = weight_parts[0]
        new_checkpoint[f"{new_path}.to_q.bias"] = bias_parts[0]
    else:
        q_w, k_w, v_w = weight_parts
        q_b, k_b, v_b = bias_parts
        for target, part_weight, part_bias in (("to_q", q_w, q_b), ("to_k", k_w, k_b), ("to_v", v_w, v_b)):
            new_checkpoint[f"{new_path}.{target}.weight"] = part_weight
            new_checkpoint[f"{new_path}.{target}.bias"] = part_bias

    # Projections for the encoder (text) hidden states.
    kv_weight = unet_state_dict.pop(f"{old_path}.encoder_kv.weight")[:, :, 0]
    kv_bias = unet_state_dict.pop(f"{old_path}.encoder_kv.bias")

    (add_k_w, add_v_w), (add_k_b, add_v_b) = split_attentions(
        weight=kv_weight,
        bias=kv_bias,
        split=2,
        chunk_size=config["attention_head_dim"],
    )

    new_checkpoint[f"{new_path}.add_k_proj.weight"] = add_k_w
    new_checkpoint[f"{new_path}.add_k_proj.bias"] = add_k_b
    new_checkpoint[f"{new_path}.add_v_proj.weight"] = add_v_w
    new_checkpoint[f"{new_path}.add_v_proj.bias"] = add_v_b
def assign_to_checkpoint(paths, checkpoint, old_checkpoint, additional_replacements=None, config=None):
    """
    Final conversion step: apply global and caller-supplied renames to the locally
    converted names in ``paths``, then copy the matching tensors into ``checkpoint``.

    Args:
        paths: list of ``{"old": original_key, "new": locally_renamed_key}`` dicts.
        checkpoint: destination dict, updated in place.
        old_checkpoint: source mapping indexed by the ``"old"`` keys (not modified here).
        additional_replacements: optional list of ``{"old": ..., "new": ...}`` substring
            replacements, applied after the fixed ``middle_block.*`` renames.
        config: not used by this function.

    Weights whose final name contains ``proj_attn.weight`` or ``to_out.0.weight`` are
    sliced with ``[:, :, 0]`` to turn a conv 1D weight into a linear one. Everything
    else is stored as the same object.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    mid_block_renames = (
        ("middle_block.0", "mid_block.resnets.0"),
        ("middle_block.1", "mid_block.attentions.0"),
        ("middle_block.2", "mid_block.resnets.1"),
    )

    for entry in paths:
        target = entry["new"]
        for before, after in mid_block_renames:
            target = target.replace(before, after)
        if additional_replacements is not None:
            for replacement in additional_replacements:
                target = target.replace(replacement["old"], replacement["new"])

        source_tensor = old_checkpoint[entry["old"]]
        from_conv1d = "proj_attn.weight" in target or "to_out.0.weight" in target
        checkpoint[target] = source_tensor[:, :, 0] if from_conv1d else source_tensor
def split_attentions(*, weight, bias, split, chunk_size):
    """De-interleave a fused projection into ``split`` separate weight/bias tensors.

    The rows of ``weight`` (and the entries of ``bias``) are read in consecutive chunks
    of ``chunk_size``. Chunk ``k`` is assigned to output ``k % split``, and chunks
    landing in the same output are concatenated in their original order.

    Args:
        weight: tensor with at least 2 dims whose first dimension holds the fused rows.
        bias: tensor indexed along the same first dimension as ``weight``.
        split: number of outputs (the callers use 3 for q/k/v and 2 for k/v).
        chunk_size: rows per chunk (the callers pass ``config["attention_head_dim"]``).

    Returns:
        ``(weights, biases)``: two lists of length ``split``. An entry is ``None`` when
        no chunk was assigned to that output.

    Raises:
        IndexError: (for CPU tensors) if ``weight.shape[0]`` is not a multiple of
            ``chunk_size``, since the last chunk then indexes past the end.
    """
    # Collect each output's row indices first, then gather every output with a single
    # indexing op. Re-concatenating a growing tensor once per chunk would copy
    # quadratically.
    chunk_indices = [[] for _ in range(split)]
    for chunk_number, start in enumerate(range(0, weight.shape[0], chunk_size)):
        chunk_indices[chunk_number % split].append(torch.arange(start, start + chunk_size))

    weights = [None] * split
    biases = [None] * split
    for out_idx, chunks in enumerate(chunk_indices):
        if not chunks:
            continue
        rows = torch.cat(chunks)
        weights[out_idx] = weight[rows, :]
        biases[out_idx] = bias[rows]

    return weights, biases
def parse_list(value):
    """Parse a config value into a list of ints.

    Accepts a comma-separated string (``"1,2,4"``) or any non-bytes sequence, such as
    a list or tuple. OmegaConf's ``ListConfig`` is expected to qualify, since it
    implements ``MutableSequence``. Elements are always converted with ``int``, so a
    list such as ``["2", "4"]`` no longer leaks strings into callers that multiply
    with them.

    Raises:
        ValueError: for any other type, or for elements ``int`` cannot parse.
    """
    from collections.abc import Sequence

    if isinstance(value, str):
        return [int(v) for v in value.split(",")]
    if isinstance(value, Sequence) and not isinstance(value, (bytes, bytearray)):
        return [int(v) for v in value]
    raise ValueError(f"Can't parse list for type: {type(value)}")
# below is copy and pasted from original convert_if_stage_2.py script
def get_super_res_unet(unet_checkpoint_path, verify_param_count=True, sample_size=None):
    """Build a super-resolution (stage 2/3) ``UNet2DConditionModel`` and load the original weights.

    Args:
        unet_checkpoint_path: directory containing the original ``config.yml`` and
            ``pytorch_model.bin``.
        verify_param_count: if true, call the module-level ``verify_param_count`` helper
            with the checkpoint directory and the derived config (slow).
        sample_size: overrides the config's ``image_size`` when given.

    Returns:
        The UNet with the converted weights loaded.

    Raises:
        RuntimeError: if ``verify_param_count`` is requested but no callable
            module-level ``verify_param_count`` helper exists.
        ValueError: if the converted keys differ from the keys the model expects.
    """
    orig_path = unet_checkpoint_path

    original_unet_config = OmegaConf.load(os.path.join(orig_path, "config.yml"))
    original_unet_config = original_unet_config.params

    unet_diffusers_config = superres_create_unet_diffusers_config(original_unet_config)
    # Same width as the last entry of block_out_channels: model_channels * last multiplier.
    unet_diffusers_config["time_embedding_dim"] = (
        original_unet_config.model_channels * parse_list(original_unet_config.channel_mult)[-1]
    )
    if original_unet_config.encoder_dim != original_unet_config.encoder_channels:
        unet_diffusers_config["encoder_hid_dim"] = original_unet_config.encoder_dim
    unet_diffusers_config["class_embed_type"] = "timestep"
    unet_diffusers_config["addition_embed_type"] = "text"

    unet_diffusers_config["time_embedding_act_fn"] = "gelu"
    unet_diffusers_config["resnet_skip_time_act"] = True
    unet_diffusers_config["resnet_out_scale_factor"] = 1 / 0.7071
    unet_diffusers_config["mid_block_scale_factor"] = 1 / 0.7071
    # Defaults to True unless the config carries an integer `disable_self_attentions`,
    # whose truthiness is used instead.
    unet_diffusers_config["only_cross_attention"] = (
        bool(original_unet_config.disable_self_attentions)
        if (
            "disable_self_attentions" in original_unet_config
            and isinstance(original_unet_config.disable_self_attentions, int)
        )
        else True
    )

    if sample_size is None:
        unet_diffusers_config["sample_size"] = original_unet_config.image_size
    else:
        # The second upscaler unet's sample size is incorrectly specified
        # in the config and is instead hardcoded in source
        unet_diffusers_config["sample_size"] = sample_size

    # NOTE(review): torch.load unpickles the file -- only convert checkpoints you trust.
    unet_checkpoint = torch.load(os.path.join(unet_checkpoint_path, "pytorch_model.bin"), map_location="cpu")

    if verify_param_count:
        # The boolean parameter shadows the module-level `verify_param_count` helper,
        # so calling the bare name would try to call a bool (TypeError). Look the helper
        # up in the module namespace explicitly.
        # NOTE(review): assumes `verify_param_count(orig_path, config)` is defined
        # elsewhere in this module.
        param_count_check = globals().get("verify_param_count")
        if not callable(param_count_check):
            raise RuntimeError(
                "`verify_param_count=True` requires a module-level `verify_param_count` helper."
            )
        # check that architecture matches - is a bit slow
        param_count_check(orig_path, unet_diffusers_config)

    converted_unet_checkpoint = superres_convert_ldm_unet_checkpoint(
        unet_checkpoint, unet_diffusers_config, path=unet_checkpoint_path
    )

    model = UNet2DConditionModel(**unet_diffusers_config)

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    converted_keys = set(converted_unet_checkpoint.keys())
    expected_keys = set(model.state_dict().keys())
    diff_e_c = expected_keys - converted_keys
    diff_c_e = converted_keys - expected_keys
    if diff_e_c:
        raise ValueError(f"Expected, but not converted: {diff_e_c}")
    if diff_c_e:
        raise ValueError(f"Converted, but not expected: {diff_c_e}")

    model.load_state_dict(converted_unet_checkpoint)

    return model
def superres_create_unet_diffusers_config(original_unet_config):
    """Derive ``UNet2DConditionModel`` kwargs for a super-resolution UNet from its original config.

    The result has no ``sample_size`` or ``addition_embed_type`` (``get_super_res_unet``
    sets those afterwards). ``layers_per_block`` is ``tuple(num_res_blocks)``. The class
    embedding is inferred from ``num_classes`` only.

    Args:
        original_unet_config: the ``params`` node of the original OmegaConf config.

    Returns:
        dict of keyword arguments for ``UNet2DConditionModel``.

    Raises:
        NotImplementedError: for a ``num_classes`` other than ``"sequential"``.
    """
    params = original_unet_config

    # Attention resolutions expressed as image_size // res, comparable with the
    # per-level factors computed below.
    attn_factors = [params.image_size // int(res) for res in parse_list(params.attention_resolutions)]
    multipliers = parse_list(params.channel_mult)
    block_out_channels = [params.model_channels * mult for mult in multipliers]

    down_choices = ("SimpleCrossAttnDownBlock2D", "ResnetDownsampleBlock2D", "DownBlock2D")
    up_choices = ("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D", "UpBlock2D")

    def pick_block(factor, choices):
        with_attention, with_resample, plain = choices
        if factor in attn_factors:
            return with_attention
        if params.resblock_updown:
            return with_resample
        return plain

    # Level k of the down path has factor 2**k; the up path revisits the levels in reverse.
    factors = [1 << level for level in range(len(block_out_channels))]
    down_block_types = [pick_block(factor, down_choices) for factor in factors]
    up_block_types = [pick_block(factor, up_choices) for factor in factors[::-1]]

    head_dim = params.num_head_channels
    use_linear_projection = (
        params.use_linear_in_transformer if "use_linear_in_transformer" in params else False
    )
    if use_linear_projection and head_dim is None:
        # stable diffusion 2-base-512 and 2-768
        head_dim = [5, 10, 20, 20]

    class_embed_type = None
    projection_class_embeddings_input_dim = None
    if "num_classes" in params:
        if params.num_classes == "sequential":
            class_embed_type = "projection"
            assert "adm_in_channels" in params
            projection_class_embeddings_input_dim = params.adm_in_channels
        else:
            raise NotImplementedError(f"Unknown conditional unet num_classes config: {params.num_classes}")

    config = {
        "in_channels": params.in_channels,
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": tuple(params.num_res_blocks),
        "cross_attention_dim": params.encoder_channels,
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
        "out_channels": params.out_channels,
        "up_block_types": tuple(up_block_types),
        "upcast_attention": False,  # TODO: guessing
        "cross_attention_norm": "group_norm",
        "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
        "act_fn": "gelu",
    }

    if params.use_scale_shift_norm:
        config["resnet_time_scale_shift"] = "scale_shift"

    return config
def superres_convert_ldm_unet_checkpoint(unet_state_dict, config, path=None, extract_ema=False):
    """
    Takes a state dict and a config, and returns a converted checkpoint.

    Renames the parameters of an original-format super-resolution UNet state dict
    (``time_embed.*``, ``aug_proj.*``, ``encoder_proj.*``, ``encoder_pooling.*``,
    ``input_blocks.*``, ``middle_block.*``, ``output_blocks.*``, ``out.*``) to the
    diffusers ``UNet2DConditionModel`` naming (``time_embedding.*``, ``class_embedding.*``,
    ``encoder_hid_proj.*``, ``add_embedding.*``, ``down_blocks.*``, ``mid_block.*``,
    ``up_blocks.*``, ``conv_in``/``conv_norm_out``/``conv_out``).

    Args:
        unet_state_dict: Original state dict (parameter name -> value). NOTE: it is
            mutated -- ``input_blocks.{i}.0.op.*`` entries are popped from it below. It is
            also passed to the ``assign_*`` helpers; whether they modify it is not visible
            from this function.
        config: diffusers UNet config dict. Reads ``class_embed_type``,
            ``layers_per_block`` (an int, or one entry per down block) and
            ``down_block_types``, and forwards it to the assignment helpers.
        path: Not referenced by this function (the name is rebound as a loop variable in
            the output-block section at the end).
        extract_ema: Not referenced by this function.

    Returns:
        dict: New state dict keyed by diffusers parameter names.

    Raises:
        NotImplementedError: If ``config["class_embed_type"]`` is not ``None``,
            ``"timestep"`` or ``"projection"``.
    """
    new_checkpoint = {}

    # Timestep embedding MLP: original indices 0 and 2 (1 is parameter-free) -> linear_1 / linear_2.
    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    if config["class_embed_type"] is None:
        # No parameters to port
        ...
    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
        # The original augmentation projection (aug_proj) becomes the diffusers class embedding.
        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["aug_proj.0.weight"]
        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["aug_proj.0.bias"]
        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["aug_proj.2.weight"]
        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["aug_proj.2.bias"]
    else:
        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    # Optional projection of the text-encoder hidden states.
    if "encoder_proj.weight" in unet_state_dict:
        new_checkpoint["encoder_hid_proj.weight"] = unet_state_dict["encoder_proj.weight"]
        new_checkpoint["encoder_hid_proj.bias"] = unet_state_dict["encoder_proj.bias"]

    # Optional pooled text embedding: remap each encoder_pooling.<k> sub-module by prefix.
    if "encoder_pooling.0.weight" in unet_state_dict:
        mapping = {
            "encoder_pooling.0": "add_embedding.norm1",
            "encoder_pooling.1": "add_embedding.pool",
            "encoder_pooling.2": "add_embedding.proj",
            "encoder_pooling.3": "add_embedding.norm2",
        }
        for key in unet_state_dict.keys():
            if key.startswith("encoder_pooling"):
                # Prefix is "encoder_pooling.<k>" (assumes a single-digit k, which the mapping covers).
                prefix = key[: len("encoder_pooling.0")]
                new_key = key.replace(prefix, mapping[prefix])
                new_checkpoint[new_key] = unet_state_dict[key]

    # input_blocks.0.0 is the stem convolution, so the down path below starts at input block 1.
    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    # Output head: out.0 -> conv_norm_out, out.2 -> conv_out (out.1 carries no parameters here).
    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    # NOTE(review): no trailing "." in the filter, so "middle_block.1" would also match
    # "middle_block.1x"; harmless while the middle block has fewer than 10 sub-modules.
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
        for layer_id in range(num_output_blocks)
    }

    # Each down block occupies (layers_per_block + 1) consecutive input blocks. The cumulative
    # sums give the (1-based, after conv_in) index of the last input block of every down block;
    # those indices are treated as downsampler positions below.
    if not isinstance(config["layers_per_block"], int):
        layers_per_block_list = [e + 1 for e in config["layers_per_block"]]
        layers_per_block_cumsum = list(np.cumsum(layers_per_block_list))
        downsampler_ids = layers_per_block_cumsum
    else:
        # TODO need better check than i in [4, 8, 12, 16]
        # These hard-coded ids only line up with layers_per_block == 3 (blocks of 4 input slots).
        downsampler_ids = [4, 8, 12, 16]

    for i in range(1, num_input_blocks):
        # Locate input block i within the diffusers down path (offset by 1 for conv_in).
        if isinstance(config["layers_per_block"], int):
            layers_per_block = config["layers_per_block"]
            block_id = (i - 1) // (layers_per_block + 1)
            layer_in_block_id = (i - 1) % (layers_per_block + 1)
        else:
            # First down block whose cumulative end lies beyond i - 1.
            block_id = next(k for k, n in enumerate(layers_per_block_cumsum) if (i - 1) < n)
            passed_blocks = layers_per_block_cumsum[block_id - 1] if block_id > 0 else 0
            layer_in_block_id = (i - 1) - passed_blocks

        # Sub-module 0 is the resnet (excluding a plain ".op" downsampling conv); sub-module 1,
        # when present, is the attention.
        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        # A plain ".op" conv maps straight to the block's downsampler; it is popped so the
        # remaining keys of this input block are only resnet/attention parameters.
        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)

        # For block types whose downsampler is itself a resnet, the resnet at a downsampler
        # position is routed to downsamplers.0 instead of resnets.<n>.
        block_type = config["down_block_types"][block_id]
        if (
            block_type == "ResnetDownsampleBlock2D" or block_type == "SimpleCrossAttnDownBlock2D"
        ) and i in downsampler_ids:
            meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.downsamplers.0"}
        else:
            meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}

        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            old_path = f"input_blocks.{i}.1"
            new_path = f"down_blocks.{block_id}.attentions.{layer_in_block_id}"

            # Attention weights go through a dedicated helper first, then the generic
            # path-renaming assignment handles the remaining attention keys.
            assign_attention_to_checkpoint(
                new_checkpoint=new_checkpoint,
                unet_state_dict=unet_state_dict,
                old_path=old_path,
                new_path=new_path,
                config=config,
            )

            paths = renew_attention_paths(attentions)
            meta_path = {"old": old_path, "new": new_path}
            assign_to_checkpoint(
                paths,
                new_checkpoint,
                unet_state_dict,
                additional_replacements=[meta_path],
                config=config,
            )

    # Middle block layout: middle_block.0 resnet, middle_block.1 attention, middle_block.2 resnet.
    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    # No explicit meta_path here; the mid-block renaming presumably comes from
    # assign_to_checkpoint's built-in replacements (helper not visible here -- verify).
    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    old_path = "middle_block.1"
    new_path = "mid_block.attentions.0"

    assign_attention_to_checkpoint(
        new_checkpoint=new_checkpoint,
        unet_state_dict=unet_state_dict,
        old_path=old_path,
        new_path=new_path,
        config=config,
    )

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    # The up path mirrors the down path, so the per-block depths are taken in reverse order.
    if not isinstance(config["layers_per_block"], int):
        layers_per_block_list = list(reversed([e + 1 for e in config["layers_per_block"]]))
        layers_per_block_cumsum = list(np.cumsum(layers_per_block_list))

    for i in range(num_output_blocks):
        # Output blocks have no stem offset: output block 0 is the first layer of up block 0.
        if isinstance(config["layers_per_block"], int):
            layers_per_block = config["layers_per_block"]
            block_id = i // (layers_per_block + 1)
            layer_in_block_id = i % (layers_per_block + 1)
        else:
            block_id = next(k for k, n in enumerate(layers_per_block_cumsum) if i < n)
            passed_blocks = layers_per_block_cumsum[block_id - 1] if block_id > 0 else 0
            layer_in_block_id = i - passed_blocks

        # Group this output block's keys by sub-module index. shave_segments(name, 2) presumably
        # strips the "output_blocks.<i>" prefix -- helper not visible here.
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        # len(output_block_list) == 1 -> resnet
        # len(output_block_list) == 2 -> resnet, attention or resnet, upscale resnet
        # len(output_block_list) == 3 -> resnet, attention, upscale resnet

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]

            # A second sub-module with resnet-style "in_layers" is an upscaling resnet, not attention.
            has_attention = True
            if len(output_block_list) == 2 and any("in_layers" in k for k in output_block_list["1"]):
                has_attention = False

            maybe_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}

            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            # A sub-module with exactly a conv weight+bias is a plain upsampling conv.
            # NOTE(review): `index` is the position in dict insertion order, which is assumed to
            # equal the sub-module number in the original key.
            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # this layer was no attention
                has_attention = False
                maybe_attentions = []

            if has_attention:
                old_path = f"output_blocks.{i}.1"
                new_path = f"up_blocks.{block_id}.attentions.{layer_in_block_id}"

                assign_attention_to_checkpoint(
                    new_checkpoint=new_checkpoint,
                    unet_state_dict=unet_state_dict,
                    old_path=old_path,
                    new_path=new_path,
                    config=config,
                )

                paths = renew_attention_paths(maybe_attentions)
                meta_path = {
                    "old": old_path,
                    "new": new_path,
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )

            # The last sub-module is an upscaling resnet when there are three sub-modules, or when
            # sub-module 1 turned out not to be attention (and was not a plain conv).
            if len(output_block_list) == 3 or (not has_attention and len(maybe_attentions) > 0):
                layer_id = len(output_block_list) - 1
                resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.{layer_id}" in key]
                paths = renew_resnet_paths(resnets)
                meta_path = {"old": f"output_blocks.{i}.{layer_id}", "new": f"up_blocks.{block_id}.upsamplers.0"}
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            # Single sub-module: a lone resnet, copied key by key.
            # NOTE: `path` here shadows the (unused) `path` parameter of this function.
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint
def verify_param_count(orig_path, unet_diffusers_config):
    """Check that a converted diffusers UNet matches the original IF stage model's parameter counts.

    The original DeepFloyd IF stage II or stage III model is instantiated on CPU from
    ``orig_path``. A fresh ``UNet2DConditionModel(**unet_diffusers_config)`` is then compared
    against it sub-module by sub-module with ``assert_param_count``, and finally as a whole.

    Args:
        orig_path: Path/name of the original checkpoint. It must contain ``-II-`` or ``-III-``;
            that marker selects the stage class and the block slicing used for the comparison.
        unet_diffusers_config: Keyword arguments for ``UNet2DConditionModel``.

    Raises:
        ValueError: If ``orig_path`` contains neither ``-II-`` nor ``-III-``.
        AssertionError: From ``assert_param_count`` on the first parameter-count mismatch.
    """
    is_stage_2 = "-II-" in orig_path
    is_stage_3 = "-III-" in orig_path
    if not (is_stage_2 or is_stage_3):
        # Fail fast: without a stage marker neither the model class nor the block layout is known.
        # (Asserting a non-empty f-string is always truthy, so it can never serve as this check.)
        raise ValueError(f"Weird name. Should have -II- or -III- in path: {orig_path}")

    if is_stage_2:
        from deepfloyd_if.modules import IFStageII

        orig_model = IFStageII(device="cpu", dir_or_name=orig_path).model
    else:
        from deepfloyd_if.modules import IFStageIII

        orig_model = IFStageIII(device="cpu", dir_or_name=orig_path).model

    unet = UNet2DConditionModel(**unet_diffusers_config)

    # in params
    assert_param_count(unet.time_embedding, orig_model.time_embed)
    assert_param_count(unet.conv_in, orig_model.input_blocks[:1])

    # downblocks: the first three are laid out identically in both stages
    assert_param_count(unet.down_blocks[0], orig_model.input_blocks[1:4])
    assert_param_count(unet.down_blocks[1], orig_model.input_blocks[4:7])
    assert_param_count(unet.down_blocks[2], orig_model.input_blocks[7:11])

    if is_stage_2:
        assert_param_count(unet.down_blocks[3], orig_model.input_blocks[11:17])
        assert_param_count(unet.down_blocks[4], orig_model.input_blocks[17:])
    if is_stage_3:
        assert_param_count(unet.down_blocks[3], orig_model.input_blocks[11:15])
        assert_param_count(unet.down_blocks[4], orig_model.input_blocks[15:20])
        assert_param_count(unet.down_blocks[5], orig_model.input_blocks[20:])

    # mid block
    assert_param_count(unet.mid_block, orig_model.middle_block)

    # up blocks
    if is_stage_2:
        assert_param_count(unet.up_blocks[0], orig_model.output_blocks[:6])
        assert_param_count(unet.up_blocks[1], orig_model.output_blocks[6:12])
        assert_param_count(unet.up_blocks[2], orig_model.output_blocks[12:16])
        assert_param_count(unet.up_blocks[3], orig_model.output_blocks[16:19])
        assert_param_count(unet.up_blocks[4], orig_model.output_blocks[19:])
    if is_stage_3:
        assert_param_count(unet.up_blocks[0], orig_model.output_blocks[:5])
        assert_param_count(unet.up_blocks[1], orig_model.output_blocks[5:10])
        assert_param_count(unet.up_blocks[2], orig_model.output_blocks[10:14])
        assert_param_count(unet.up_blocks[3], orig_model.output_blocks[14:18])
        assert_param_count(unet.up_blocks[4], orig_model.output_blocks[18:21])
        assert_param_count(unet.up_blocks[5], orig_model.output_blocks[21:24])

    # out params
    assert_param_count(unet.conv_norm_out, orig_model.out[0])
    assert_param_count(unet.conv_out, orig_model.out[2])

    # make sure all model architecture has same param count
    assert_param_count(unet, orig_model)
def assert_param_count(model_1, model_2):
    """Raise ``AssertionError`` unless both modules hold the same total number of parameter elements.

    Args:
        model_1: Any object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        model_2: Any object exposing ``parameters()``.

    Raises:
        AssertionError: If the summed ``numel()`` of the two modules' parameters differ; the
            message names both classes and both counts.
    """
    count_1 = sum(p.numel() for p in model_1.parameters())
    count_2 = sum(p.numel() for p in model_2.parameters())
    if count_1 != count_2:
        # Raised explicitly instead of via `assert` so the check still runs under `python -O`.
        raise AssertionError(f"{model_1.__class__}: {count_1} != {model_2.__class__}: {count_2}")
def superres_check_against_original(dump_path, unet_checkpoint_path):
    """Compare a converted super-resolution UNet with the original IF model on one seeded input.

    The converted diffusers ``UNet2DConditionModel`` is loaded from ``dump_path``. The original
    DeepFloyd IF stage II / III model is loaded from ``unet_checkpoint_path``; both run on CUDA.
    Each model gets the same seeded random input, and the output shape and summed absolute
    difference of the two predictions are printed.

    Args:
        dump_path: Directory the converted UNet was saved to (passed to ``from_pretrained``).
        unet_checkpoint_path: Original checkpoint path/name; must contain ``-II-`` or ``-III-``.

    Raises:
        ValueError: If ``unet_checkpoint_path`` contains neither ``-II-`` nor ``-III-``.
    """
    import gc

    orig_path = unet_checkpoint_path
    # Resolve the original stage class before loading anything heavy. An unrecognised path
    # would otherwise leave the original model undefined.
    if "-II-" in orig_path:
        from deepfloyd_if.modules import IFStageII as orig_stage_cls
    elif "-III-" in orig_path:
        from deepfloyd_if.modules import IFStageIII as orig_stage_cls
    else:
        raise ValueError(f"Weird name. Should have -II- or -III- in path: {orig_path}")

    model = UNet2DConditionModel.from_pretrained(dump_path)
    model.to("cuda")

    orig_model = orig_stage_cls(device="cuda", dir_or_name=orig_path, model_kwargs={"precision": "fp32"}).model

    batch_size = 1
    # The UNet input is [noisy sample, upscaled low-res image] concatenated on dim 1.
    channels = model.in_channels // 2
    # Fixed check resolution; the model's configured sample_size is not used here.
    height = 1024
    width = 1024

    torch.manual_seed(0)

    latents = torch.randn((batch_size, channels, height, width), device=model.device)
    image_small = torch.randn((batch_size, channels, height // 4, width // 4), device=model.device)

    # Only pass `antialias` on torch versions whose F.interpolate accepts it. The resize itself
    # runs unconditionally so `image_upscaled` is always defined.
    interpolate_antialias = {}
    if "antialias" in inspect.signature(F.interpolate).parameters:
        interpolate_antialias["antialias"] = True
    image_upscaled = F.interpolate(
        image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
    )

    latent_model_input = torch.cat([latents, image_upscaled], dim=1).to(model.dtype)
    t = torch.tensor([5], device=model.device).to(model.dtype)

    seq_len = 64
    encoder_hidden_states = torch.randn((batch_size, seq_len, model.config.encoder_hid_dim), device=model.device).to(
        model.dtype
    )

    # The timestep value doubles as the augmentation / class label fed to both models.
    fake_class_labels = torch.tensor([t], device=model.device).to(model.dtype)

    with torch.no_grad():
        out = orig_model(latent_model_input, t, aug_steps=fake_class_labels, text_emb=encoder_hidden_states)

    # Release the original model's GPU memory before running the converted one.
    orig_model.to("cpu")
    del orig_model
    torch.cuda.empty_cache()
    gc.collect()
    print(50 * "=")

    with torch.no_grad():
        noise_pred = model(
            sample=latent_model_input,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=fake_class_labels,
            timestep=t,
        ).sample

    print("Out shape", noise_pred.shape)
    print("Diff", (out - noise_pred).abs().sum())
if __name__ == "__main__":
    # Script entry point: parse the CLI flags, then run the conversion with them.
    cli_args = parse_args()
    main(cli_args)