sayakpaul HF staff commited on
Commit
6c04d23
1 Parent(s): 94913a9

improve runtime.

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. convert.py +16 -7
app.py CHANGED
@@ -27,7 +27,7 @@ def run(hf_token, text_encoder_weights, unet_weights, repo_prefix):
27
  text_encoder_weights = None
28
  if unet_weights == "":
29
  unet_weights = None
30
- print(f"unet_weights: {unet_weights}")
31
  pipeline = run_conversion(text_encoder_weights, unet_weights)
32
  output_path = "kerascv_sd_diffusers_pipeline"
33
  pipeline.save_pretrained(output_path)
 
27
  text_encoder_weights = None
28
  if unet_weights == "":
29
  unet_weights = None
30
+
31
  pipeline = run_conversion(text_encoder_weights, unet_weights)
32
  output_path = "kerascv_sd_diffusers_pipeline"
33
  pipeline.save_pretrained(output_path)
convert.py CHANGED
@@ -13,6 +13,7 @@ PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"
13
  REVISION = None
14
  NON_EMA_REVISION = None
15
  IMG_HEIGHT = IMG_WIDTH = 512
 
16
 
17
 
18
  def initialize_pt_models():
@@ -34,17 +35,25 @@ def initialize_pt_models():
34
  return pt_text_encoder, pt_vae, pt_unet, pt_safety_checker
35
 
36
 
37
- def initialize_tf_models():
38
  """Initializes the separate models of Stable Diffusion from KerasCV and downloads
39
  their pre-trained weights."""
40
  tf_sd_model = keras_cv.models.StableDiffusion(
41
  img_height=IMG_HEIGHT, img_width=IMG_WIDTH
42
  )
43
- _ = tf_sd_model.text_to_image("Cartoon") # To download the weights.
44
-
45
- tf_text_encoder = tf_sd_model.text_encoder
 
 
 
46
  tf_vae = tf_sd_model.image_encoder
47
- tf_unet = tf_sd_model.diffusion_model
 
 
 
 
 
48
  return tf_sd_model, tf_text_encoder, tf_vae, tf_unet
49
 
50
 
@@ -55,11 +64,11 @@ def run_conversion(text_encoder_weights: str = None, unet_weights: str = None):
55
 
56
  if text_encoder_weights is not None:
57
  print("Loading fine-tuned text encoder weights.")
58
- text_encoder_weights_path = tf.keras.utils.get_file(text_encoder_weights)
59
  tf_text_encoder.load_weights(text_encoder_weights_path)
60
  if unet_weights is not None:
61
  print("Loading fine-tuned UNet weights.")
62
- unet_weights_path = tf.keras.utils.get_file(unet_weights)
63
  tf_unet.load_weights(unet_weights_path)
64
 
65
  text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)
 
13
  REVISION = None
14
  NON_EMA_REVISION = None
15
  IMG_HEIGHT = IMG_WIDTH = 512
16
+ MAX_SEQ_LENGTH = 77
17
 
18
 
19
  def initialize_pt_models():
 
35
  return pt_text_encoder, pt_vae, pt_unet, pt_safety_checker
36
 
37
 
38
+ def initialize_tf_models(text_encoder_weights: str, unet_weights: str):
39
  """Initializes the separate models of Stable Diffusion from KerasCV and downloads
40
  their pre-trained weights."""
41
  tf_sd_model = keras_cv.models.StableDiffusion(
42
  img_height=IMG_HEIGHT, img_width=IMG_WIDTH
43
  )
44
+ if text_encoder_weights is None:
45
+ tf_text_encoder = tf_sd_model.text_encoder
46
+ else:
47
+ tf_text_encoder = keras_cv.models.stable_diffusion.TextEncoder(
48
+ MAX_SEQ_LENGTH, download_weights=False
49
+ )
50
  tf_vae = tf_sd_model.image_encoder
51
+ if unet_weights is None:
52
+ tf_unet = tf_sd_model.diffusion_model
53
+ else:
54
+ tf_unet = keras_cv.models.stable_diffusion.DiffusionModel(
55
+ IMG_HEIGHT, IMG_WIDTH, MAX_SEQ_LENGTH, download_weights=False
56
+ )
57
  return tf_sd_model, tf_text_encoder, tf_vae, tf_unet
58
 
59
 
 
64
 
65
  if text_encoder_weights is not None:
66
  print("Loading fine-tuned text encoder weights.")
67
+ text_encoder_weights_path = tf.keras.utils.get_file(origin=text_encoder_weights)
68
  tf_text_encoder.load_weights(text_encoder_weights_path)
69
  if unet_weights is not None:
70
  print("Loading fine-tuned UNet weights.")
71
+ unet_weights_path = tf.keras.utils.get_file(origin=unet_weights)
72
  tf_unet.load_weights(unet_weights_path)
73
 
74
  text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)