"""Port KerasCV Stable Diffusion weights into a diffusers ``StableDiffusionPipeline``.

PyTorch components are downloaded from ``PRETRAINED_CKPT``; when fine-tuned
KerasCV text-encoder and/or UNet weights are supplied, they are loaded into the
KerasCV models and copied over to the corresponding PyTorch models.
"""

from typing import Optional

import keras_cv
import tensorflow as tf
from diffusers import (
    AutoencoderKL,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)
from transformers import CLIPTextModel

from conversion_utils import populate_text_encoder, populate_unet

# Hugging Face Hub checkpoint providing the PyTorch models (and any pipeline
# components not passed explicitly to StableDiffusionPipeline.from_pretrained).
PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"
# Hub revision (branch / tag / commit) for every component except the UNet.
# None means the repository's default revision.
REVISION = None
# Hub revision used for the UNet only (presumably a branch holding non-EMA
# weights, judging by the name -- confirm). None means the default revision.
NON_EMA_REVISION = None
# Image resolution the KerasCV models are built for.
IMG_HEIGHT = IMG_WIDTH = 512
# Token sequence length the KerasCV text encoder and UNet are built for.
MAX_SEQ_LENGTH = 77


def initialize_pt_models():
    """Instantiate the diffusers/transformers Stable Diffusion components.

    Pre-trained weights are downloaded from ``PRETRAINED_CKPT``. The UNet is
    fetched at ``NON_EMA_REVISION``; all other components at ``REVISION``.

    Returns:
        Tuple ``(text_encoder, vae, unet, safety_checker)`` of PyTorch models.
    """
    pt_text_encoder = CLIPTextModel.from_pretrained(
        PRETRAINED_CKPT, subfolder="text_encoder", revision=REVISION
    )
    pt_vae = AutoencoderKL.from_pretrained(
        PRETRAINED_CKPT, subfolder="vae", revision=REVISION
    )
    pt_unet = UNet2DConditionModel.from_pretrained(
        PRETRAINED_CKPT, subfolder="unet", revision=NON_EMA_REVISION
    )
    # The safety checker is an ordinary pipeline component, so it follows
    # REVISION like the text encoder and VAE; only the UNet uses
    # NON_EMA_REVISION.
    pt_safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        PRETRAINED_CKPT, subfolder="safety_checker", revision=REVISION
    )
    return pt_text_encoder, pt_vae, pt_unet, pt_safety_checker


def initialize_tf_models(
    text_encoder_weights: Optional[str], unet_weights: Optional[str]
):
    """Instantiate the KerasCV Stable Diffusion components.

    Args:
        text_encoder_weights: Location of fine-tuned text-encoder weights, or
            None. When given, a fresh ``TextEncoder`` is built without
            downloading pre-trained weights, so the caller can load the
            fine-tuned weights into it. When None, the text encoder of the
            pre-trained KerasCV ``StableDiffusion`` model is used.
        unet_weights: Same as ``text_encoder_weights``, for the diffusion
            model (UNet).

    Returns:
        Tuple ``(sd_model, text_encoder, vae, unet)`` of KerasCV models.
    """
    tf_sd_model = keras_cv.models.StableDiffusion(
        img_height=IMG_HEIGHT, img_width=IMG_WIDTH
    )

    if text_encoder_weights is None:
        tf_text_encoder = tf_sd_model.text_encoder
    else:
        # Skip the pre-trained download; fine-tuned weights are loaded later.
        tf_text_encoder = keras_cv.models.stable_diffusion.TextEncoder(
            MAX_SEQ_LENGTH, download_weights=False
        )

    # NOTE(review): this is the KerasCV image *encoder*; run_conversion does
    # not port it (the PyTorch VAE keeps its pre-trained weights).
    tf_vae = tf_sd_model.image_encoder

    if unet_weights is None:
        tf_unet = tf_sd_model.diffusion_model
    else:
        # Skip the pre-trained download; fine-tuned weights are loaded later.
        tf_unet = keras_cv.models.stable_diffusion.DiffusionModel(
            IMG_HEIGHT, IMG_WIDTH, MAX_SEQ_LENGTH, download_weights=False
        )

    return tf_sd_model, tf_text_encoder, tf_vae, tf_unet


def run_conversion(
    text_encoder_weights: Optional[str] = None,
    unet_weights: Optional[str] = None,
) -> StableDiffusionPipeline:
    """Build a diffusers pipeline, optionally with fine-tuned KerasCV weights.

    Args:
        text_encoder_weights: Download origin (passed to
            ``tf.keras.utils.get_file``) of fine-tuned KerasCV text-encoder
            weights. If None, the pre-trained PyTorch text encoder is kept.
        unet_weights: Download origin of fine-tuned KerasCV UNet weights. If
            None, the pre-trained PyTorch UNet is kept.

    Returns:
        A ``StableDiffusionPipeline`` built from ``PRETRAINED_CKPT`` at
        ``REVISION``, using the (possibly ported) PyTorch text encoder and
        UNet together with the pre-trained VAE and safety checker.
    """
    pt_text_encoder, pt_vae, pt_unet, pt_safety_checker = initialize_pt_models()
    _tf_sd_model, tf_text_encoder, _tf_vae, tf_unet = initialize_tf_models(
        text_encoder_weights, unet_weights
    )
    print("Pre-trained model weights downloaded.")

    if text_encoder_weights is not None:
        print("Loading fine-tuned text encoder weights.")
        text_encoder_weights_path = tf.keras.utils.get_file(
            origin=text_encoder_weights
        )
        tf_text_encoder.load_weights(text_encoder_weights_path)
        # Convert the Keras weights into a PyTorch state dict and load it.
        text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)
        pt_text_encoder.load_state_dict(text_encoder_state_dict_from_tf)
        print("Populated PT text encoder from TF weights.")

    if unet_weights is not None:
        print("Loading fine-tuned UNet weights.")
        unet_weights_path = tf.keras.utils.get_file(origin=unet_weights)
        tf_unet.load_weights(unet_weights_path)
        # Convert the Keras weights into a PyTorch state dict and load it.
        unet_state_dict_from_tf = populate_unet(tf_unet)
        pt_unet.load_state_dict(unet_state_dict_from_tf)
        print("Populated PT UNet from TF weights.")

    print("Weights ported, preparing StableDiffusionPipeline...")
    # Use the same REVISION as the individually loaded components so that the
    # remaining pipeline components come from a matching checkpoint revision.
    pipeline = StableDiffusionPipeline.from_pretrained(
        PRETRAINED_CKPT,
        unet=pt_unet,
        text_encoder=pt_text_encoder,
        vae=pt_vae,
        safety_checker=pt_safety_checker,
        revision=REVISION,
    )
    return pipeline