File size: 3,667 Bytes
ddc8a59
 
3304f7d
 
 
 
 
ddc8a59
3304f7d
 
ddc8a59
 
 
 
 
 
3304f7d
ddc8a59
 
 
 
3304f7d
ddc8a59
 
 
 
 
 
 
 
 
 
 
 
 
3304f7d
ddc8a59
 
 
3304f7d
 
 
 
ddc8a59
3304f7d
 
ddc8a59
 
 
 
 
 
 
 
3304f7d
ddc8a59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3304f7d
 
 
ddc8a59
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import keras_cv
import tensorflow as tf
from diffusers import (AutoencoderKL, StableDiffusionPipeline,
                       UNet2DConditionModel)
from diffusers.pipelines.stable_diffusion.safety_checker import \
    StableDiffusionSafetyChecker
from transformers import CLIPTextModel

from conversion_utils import (populate_text_encoder, populate_unet,
                              run_assertion)

# Checkpoint identifier passed to every diffusers `from_pretrained` call in this
# module, both for the individual sub-models and for the final pipeline.
PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"
# Revision passed when loading the diffusers text encoder and VAE.
# None defers to from_pretrained's own default revision.
REVISION = None
# Revision passed when loading the diffusers UNet; the name suggests the
# non-EMA weights variant — NOTE(review): confirm against the checkpoint repo.
NON_EMA_REVISION = None
# Output resolution handed to the KerasCV StableDiffusion model.
IMG_HEIGHT = IMG_WIDTH = 512


def initialize_pt_models():
    """Load the individual diffusers Stable Diffusion sub-models from PRETRAINED_CKPT.

    Each component is fetched from its own subfolder of the checkpoint. Only the
    UNet uses NON_EMA_REVISION; every other component uses REVISION.

    Returns:
        A 4-tuple ``(text_encoder, vae, unet, safety_checker)`` of PyTorch models.
    """
    pt_text_encoder = CLIPTextModel.from_pretrained(
        PRETRAINED_CKPT, subfolder="text_encoder", revision=REVISION
    )
    pt_vae = AutoencoderKL.from_pretrained(
        PRETRAINED_CKPT, subfolder="vae", revision=REVISION
    )
    # The UNet is the only component that has a separate (non-EMA) revision.
    pt_unet = UNet2DConditionModel.from_pretrained(
        PRETRAINED_CKPT, subfolder="unet", revision=NON_EMA_REVISION
    )
    # The safety checker is not a UNet, so it follows REVISION like the
    # text encoder and VAE (previously it wrongly used NON_EMA_REVISION).
    pt_safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        PRETRAINED_CKPT, subfolder="safety_checker", revision=REVISION
    )

    return pt_text_encoder, pt_vae, pt_unet, pt_safety_checker


def initialize_tf_models():
    """Build the KerasCV Stable Diffusion model and expose its sub-models.

    One throwaway text-to-image generation is run before the sub-models are
    read off the model, so that the pre-trained weights get fetched first.

    Returns:
        A 4-tuple ``(sd_model, text_encoder, image_encoder, diffusion_model)``.
    """
    sd_model = keras_cv.models.StableDiffusion(
        img_height=IMG_HEIGHT, img_width=IMG_WIDTH
    )
    # The generated image is discarded; the call exists only for its
    # weight-downloading side effect.
    sd_model.text_to_image("Cartoon")

    # Tuple elements are evaluated left to right, so the attributes are read
    # in the same order as before: text encoder, image encoder, diffusion model.
    return (
        sd_model,
        sd_model.text_encoder,
        sd_model.image_encoder,
        sd_model.diffusion_model,
    )


def run_conversion(text_encoder_weights: str = None, unet_weights: str = None):
    """Port KerasCV Stable Diffusion weights into a diffusers StableDiffusionPipeline.

    Args:
        text_encoder_weights: Optional location of fine-tuned KerasCV text-encoder
            weights. It is resolved with ``tf.keras.utils.get_file`` and loaded
            before conversion.
        unet_weights: Optional location of fine-tuned KerasCV UNet weights,
            handled the same way.

    Returns:
        A ``StableDiffusionPipeline`` built from PRETRAINED_CKPT. Its text encoder
        and UNet contain the parameters converted from KerasCV.
    """
    pt_text_encoder, pt_vae, pt_unet, pt_safety_checker = initialize_pt_models()
    # Only the TF text encoder and UNet are converted; the full model and the
    # image encoder are not needed here.
    _, tf_text_encoder, _, tf_unet = initialize_tf_models()
    print("Pre-trained model weights downloaded.")

    if text_encoder_weights is not None:
        print("Loading fine-tuned text encoder weights.")
        text_encoder_weights_path = tf.keras.utils.get_file(text_encoder_weights)
        tf_text_encoder.load_weights(text_encoder_weights_path)
    if unet_weights is not None:
        print("Loading fine-tuned UNet weights.")
        unet_weights_path = tf.keras.utils.get_file(unet_weights)
        tf_unet.load_weights(unet_weights_path)

    text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)
    unet_state_dict_from_tf = populate_unet(tf_unet)
    print("Conversion done, now running assertions...")

    # Fine-tuned weights have no diffusers reference to compare against.
    # A component is checked against its pre-trained PyTorch counterpart
    # only when it was NOT fine-tuned.
    if text_encoder_weights is None:
        run_assertion(pt_text_encoder.state_dict(), text_encoder_state_dict_from_tf)
    if unet_weights is None:
        # Compare against the UNet's own state dict (previously this wrongly
        # used the text encoder's state dict).
        run_assertion(pt_unet.state_dict(), unet_state_dict_from_tf)

    print(
        "Assertions successful, populating the converted parameters into the diffusers models..."
    )
    pt_text_encoder.load_state_dict(text_encoder_state_dict_from_tf)
    pt_unet.load_state_dict(unet_state_dict_from_tf)

    print("Parameters ported, preparing StableDiffusionPipeline...")
    pipeline = StableDiffusionPipeline.from_pretrained(
        PRETRAINED_CKPT,
        unet=pt_unet,
        text_encoder=pt_text_encoder,
        vae=pt_vae,
        safety_checker=pt_safety_checker,
        revision=REVISION,
    )
    return pipeline