"""
Convert a BLIP Diffusion checkpoint from LAVIS into a diffusers `BlipDiffusionPipeline` and save it with
`save_pretrained`.

This script requires you to build `LAVIS` from source, since the pip version doesn't have BLIP Diffusion. Follow the
instructions here: https://github.com/salesforce/LAVIS/tree/main.
"""

import argparse
import os
import tempfile

import torch
from lavis.models import load_model_and_preprocess
from transformers import CLIPTokenizer
from transformers.models.blip_2.configuration_blip_2 import Blip2Config

from diffusers import (
    AutoencoderKL,
    PNDMScheduler,
    UNet2DConditionModel,
)
from diffusers.pipelines import BlipDiffusionPipeline
from diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
from diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
from diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
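

# Architecture hyperparameters for the diffusers Q-Former. These mirror the original LAVIS BLIP Diffusion release:
# a CLIP ViT-L/14-style vision tower and a Q-Former with 16 learned query tokens.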
BLIP2_CONFIG = {
    "vision_config": {
        "hidden_size": 1024,
        "num_hidden_layers": 23,
        "num_attention_heads": 16,
        "image_size": 224,
        "patch_size": 14,
        "intermediate_size": 4096,
        "hidden_act": "quick_gelu",
    },
    "qformer_config": {
        "cross_attention_frequency": 1,
        "encoder_hidden_size": 1024,
        "vocab_size": 30523,
    },
    "num_query_tokens": 16,
}
blip2config = Blip2Config(**BLIP2_CONFIG)
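

# Instantiate a randomly initialized diffusers Q-Former from the config above; the converted weights are loaded later.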
def qformer_model_from_original_config():
    qformer = Blip2QFormerModel(blip2config)
    return qformer
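

# Remap the Q-Former embedding weights (word/position embeddings plus their LayerNorm) to diffusers key names.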
def embeddings_from_original_checkpoint(model, diffuser_embeddings_prefix, original_embeddings_prefix):
    embeddings = {}
    embeddings.update(
        {
            f"{diffuser_embeddings_prefix}.word_embeddings.weight": model[
                f"{original_embeddings_prefix}.word_embeddings.weight"
            ]
        }
    )
    embeddings.update(
        {
            f"{diffuser_embeddings_prefix}.position_embeddings.weight": model[
                f"{original_embeddings_prefix}.position_embeddings.weight"
            ]
        }
    )
    embeddings.update(
        {f"{diffuser_embeddings_prefix}.LayerNorm.weight": model[f"{original_embeddings_prefix}.LayerNorm.weight"]}
    )
    embeddings.update(
        {f"{diffuser_embeddings_prefix}.LayerNorm.bias": model[f"{original_embeddings_prefix}.LayerNorm.bias"]}
    )
    return embeddings
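

# Remap the Q-Former projection head (dense1 -> dense2 -> LayerNorm) to diffusers key names.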
def proj_layer_from_original_checkpoint(model, diffuser_proj_prefix, original_proj_prefix):
    proj_layer = {}
    proj_layer.update({f"{diffuser_proj_prefix}.dense1.weight": model[f"{original_proj_prefix}.dense1.weight"]})
    proj_layer.update({f"{diffuser_proj_prefix}.dense1.bias": model[f"{original_proj_prefix}.dense1.bias"]})
    proj_layer.update({f"{diffuser_proj_prefix}.dense2.weight": model[f"{original_proj_prefix}.dense2.weight"]})
    proj_layer.update({f"{diffuser_proj_prefix}.dense2.bias": model[f"{original_proj_prefix}.dense2.bias"]})
    proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.weight": model[f"{original_proj_prefix}.LayerNorm.weight"]})
    proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.bias": model[f"{original_proj_prefix}.LayerNorm.bias"]})
    return proj_layer
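

# Remap one attention block. The original checkpoint stores query/key/value under `.self.`; diffusers uses
# `.attention.`.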
def attention_from_original_checkpoint(model, diffuser_attention_prefix, original_attention_prefix):
    attention = {}
    attention.update(
        {f"{diffuser_attention_prefix}.attention.query.weight": model[f"{original_attention_prefix}.self.query.weight"]}
    )
    attention.update(
        {f"{diffuser_attention_prefix}.attention.query.bias": model[f"{original_attention_prefix}.self.query.bias"]}
    )
    attention.update(
        {f"{diffuser_attention_prefix}.attention.key.weight": model[f"{original_attention_prefix}.self.key.weight"]}
    )
    attention.update(
        {f"{diffuser_attention_prefix}.attention.key.bias": model[f"{original_attention_prefix}.self.key.bias"]}
    )
    attention.update(
        {f"{diffuser_attention_prefix}.attention.value.weight": model[f"{original_attention_prefix}.self.value.weight"]}
    )
    attention.update(
        {f"{diffuser_attention_prefix}.attention.value.bias": model[f"{original_attention_prefix}.self.value.bias"]}
    )
    attention.update(
        {f"{diffuser_attention_prefix}.output.dense.weight": model[f"{original_attention_prefix}.output.dense.weight"]}
    )
    attention.update(
        {f"{diffuser_attention_prefix}.output.dense.bias": model[f"{original_attention_prefix}.output.dense.bias"]}
    )
    attention.update(
        {
            f"{diffuser_attention_prefix}.output.LayerNorm.weight": model[
                f"{original_attention_prefix}.output.LayerNorm.weight"
            ]
        }
    )
    attention.update(
        {
            f"{diffuser_attention_prefix}.output.LayerNorm.bias": model[
                f"{original_attention_prefix}.output.LayerNorm.bias"
            ]
        }
    )
    return attention
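

# Remap a feed-forward output block (dense projection plus LayerNorm) to diffusers key names.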
def output_layers_from_original_checkpoint(model, diffuser_output_prefix, original_output_prefix):
    output_layers = {}
    output_layers.update({f"{diffuser_output_prefix}.dense.weight": model[f"{original_output_prefix}.dense.weight"]})
    output_layers.update({f"{diffuser_output_prefix}.dense.bias": model[f"{original_output_prefix}.dense.bias"]})
    output_layers.update(
        {f"{diffuser_output_prefix}.LayerNorm.weight": model[f"{original_output_prefix}.LayerNorm.weight"]}
    )
    output_layers.update(
        {f"{diffuser_output_prefix}.LayerNorm.bias": model[f"{original_output_prefix}.LayerNorm.bias"]}
    )
    return output_layers
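

# Remap every Q-Former encoder layer: self-attention, cross-attention, and the separate text- and query-path MLPs.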
def encoder_from_original_checkpoint(model, diffuser_encoder_prefix, original_encoder_prefix):
    encoder = {}
    for i in range(blip2config.qformer_config.num_hidden_layers):
        encoder.update(
            attention_from_original_checkpoint(
                model, f"{diffuser_encoder_prefix}.{i}.attention", f"{original_encoder_prefix}.{i}.attention"
            )
        )
        encoder.update(
            attention_from_original_checkpoint(
                model, f"{diffuser_encoder_prefix}.{i}.crossattention", f"{original_encoder_prefix}.{i}.crossattention"
            )
        )

        encoder.update(
            {
                f"{diffuser_encoder_prefix}.{i}.intermediate.dense.weight": model[
                    f"{original_encoder_prefix}.{i}.intermediate.dense.weight"
                ]
            }
        )
        encoder.update(
            {
                f"{diffuser_encoder_prefix}.{i}.intermediate.dense.bias": model[
                    f"{original_encoder_prefix}.{i}.intermediate.dense.bias"
                ]
            }
        )
        encoder.update(
            {
                f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.weight": model[
                    f"{original_encoder_prefix}.{i}.intermediate_query.dense.weight"
                ]
            }
        )
        encoder.update(
            {
                f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.bias": model[
                    f"{original_encoder_prefix}.{i}.intermediate_query.dense.bias"
                ]
            }
        )

        encoder.update(
            output_layers_from_original_checkpoint(
                model, f"{diffuser_encoder_prefix}.{i}.output", f"{original_encoder_prefix}.{i}.output"
            )
        )
        encoder.update(
            output_layers_from_original_checkpoint(
                model, f"{diffuser_encoder_prefix}.{i}.output_query", f"{original_encoder_prefix}.{i}.output_query"
            )
        )
    return encoder
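

# Remap one transformer block of the vision tower; the original fused `in_proj` weights become the diffusers `qkv`.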
def visual_encoder_layer_from_original_checkpoint(model, diffuser_prefix, original_prefix):
    visual_encoder_layer = {}

    visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.weight": model[f"{original_prefix}.ln_1.weight"]})
    visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.bias": model[f"{original_prefix}.ln_1.bias"]})
    visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.weight": model[f"{original_prefix}.ln_2.weight"]})
    visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.bias": model[f"{original_prefix}.ln_2.bias"]})
    visual_encoder_layer.update(
        {f"{diffuser_prefix}.self_attn.qkv.weight": model[f"{original_prefix}.attn.in_proj_weight"]}
    )
    visual_encoder_layer.update(
        {f"{diffuser_prefix}.self_attn.qkv.bias": model[f"{original_prefix}.attn.in_proj_bias"]}
    )
    visual_encoder_layer.update(
        {f"{diffuser_prefix}.self_attn.projection.weight": model[f"{original_prefix}.attn.out_proj.weight"]}
    )
    visual_encoder_layer.update(
        {f"{diffuser_prefix}.self_attn.projection.bias": model[f"{original_prefix}.attn.out_proj.bias"]}
    )
    visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.weight": model[f"{original_prefix}.mlp.c_fc.weight"]})
    visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.bias": model[f"{original_prefix}.mlp.c_fc.bias"]})
    visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.weight": model[f"{original_prefix}.mlp.c_proj.weight"]})
    visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.bias": model[f"{original_prefix}.mlp.c_proj.bias"]})

    return visual_encoder_layer
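

# Remap the vision tower: class/position embeddings, patch convolution, pre/post LayerNorms, and all transformer
# blocks.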
def visual_encoder_from_original_checkpoint(model, diffuser_prefix, original_prefix):
    visual_encoder = {}

    visual_encoder.update(
        {
            f"{diffuser_prefix}.embeddings.class_embedding": model[f"{original_prefix}.class_embedding"]
            .unsqueeze(0)
            .unsqueeze(0)
        }
    )
    visual_encoder.update(
        {
            f"{diffuser_prefix}.embeddings.position_embedding": model[
                f"{original_prefix}.positional_embedding"
            ].unsqueeze(0)
        }
    )
    visual_encoder.update(
        {f"{diffuser_prefix}.embeddings.patch_embedding.weight": model[f"{original_prefix}.conv1.weight"]}
    )
    visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.weight": model[f"{original_prefix}.ln_pre.weight"]})
    visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.bias": model[f"{original_prefix}.ln_pre.bias"]})

    for i in range(blip2config.vision_config.num_hidden_layers):
        visual_encoder.update(
            visual_encoder_layer_from_original_checkpoint(
                model, f"{diffuser_prefix}.encoder.layers.{i}", f"{original_prefix}.transformer.resblocks.{i}"
            )
        )

    # The post-LayerNorm lives under `blip.ln_vision` in the original checkpoint, outside the visual-encoder prefix.
    visual_encoder.update({f"{diffuser_prefix}.post_layernorm.weight": model["blip.ln_vision.weight"]})
    visual_encoder.update({f"{diffuser_prefix}.post_layernorm.bias": model["blip.ln_vision.bias"]})

    return visual_encoder
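

# Assemble the complete diffusers-format Q-Former state dict from the original LAVIS state dict.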
def qformer_original_checkpoint_to_diffusers_checkpoint(model):
    qformer_checkpoint = {}
    qformer_checkpoint.update(embeddings_from_original_checkpoint(model, "embeddings", "blip.Qformer.bert.embeddings"))
    qformer_checkpoint.update({"query_tokens": model["blip.query_tokens"]})
    qformer_checkpoint.update(proj_layer_from_original_checkpoint(model, "proj_layer", "proj_layer"))
    qformer_checkpoint.update(
        encoder_from_original_checkpoint(model, "encoder.layer", "blip.Qformer.bert.encoder.layer")
    )
    qformer_checkpoint.update(visual_encoder_from_original_checkpoint(model, "visual_encoder", "blip.visual_encoder"))
    return qformer_checkpoint
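

# Build the diffusers Q-Former and load the remapped weights into it.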
def get_qformer(model):
    print("loading qformer")

    qformer = qformer_model_from_original_config()
    qformer_diffusers_checkpoint = qformer_original_checkpoint_to_diffusers_checkpoint(model)

    load_checkpoint_to_model(qformer_diffusers_checkpoint, qformer)

    print("done loading qformer")
    return qformer
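

# Load a converted state dict into `model`, round-tripping through a temporary file so the in-memory dict can be
# freed first; `strict=False` ignores key mismatches between the checkpoint and the module.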
def load_checkpoint_to_model(checkpoint, model):
    with tempfile.NamedTemporaryFile(delete=False) as file:
        torch.save(checkpoint, file.name)
        del checkpoint
        model.load_state_dict(torch.load(file.name), strict=False)

    os.remove(file.name)
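

# Assemble and save the full BlipDiffusionPipeline: the converted Q-Former plus Stable Diffusion v1-5 components.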
def save_blip_diffusion_model(model, args):
    qformer = get_qformer(model)
    qformer.eval()

    # NOTE: `runwayml/stable-diffusion-v1-5` may no longer be available on the Hugging Face Hub; the community mirror
    # `stable-diffusion-v1-5/stable-diffusion-v1-5` is a drop-in substitute if these downloads fail.
    text_encoder = ContextCLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

    unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
    vae.eval()
    text_encoder.eval()
    scheduler = PNDMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        set_alpha_to_one=False,
        skip_prk_steps=True,
    )
    tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
    image_processor = BlipImageProcessor()
    blip_diffusion = BlipDiffusionPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        scheduler=scheduler,
        qformer=qformer,
        image_processor=image_processor,
    )
    blip_diffusion.save_pretrained(args.checkpoint_path)
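

# Load the original LAVIS BLIP Diffusion model on CPU and convert its state dict.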
def main(args):
    model, _, _ = load_model_and_preprocess("blip_diffusion", "base", device="cpu", is_eval=True)
    save_blip_diffusion_model(model.state_dict(), args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
    args = parser.parse_args()

    main(args)