sayakpaul HF staff commited on
Commit
214abe6
·
1 Parent(s): 3bd4a93

fix: PT text encoder init.

Browse files
Files changed (1) hide show
  1. convert.py +12 -10
convert.py CHANGED
@@ -15,7 +15,7 @@ IMG_HEIGHT = IMG_WIDTH = 512
15
  MAX_SEQ_LENGTH = 77
16
 
17
 
18
- def initialize_pt_models():
19
  """Initializes the separate models of Stable Diffusion from diffusers and downloads
20
  their pre-trained weights."""
21
  pt_text_encoder = CLIPTextModel.from_pretrained(
@@ -32,6 +32,16 @@ def initialize_pt_models():
32
  PRETRAINED_CKPT, subfolder="safety_checker", revision=NON_EMA_REVISION
33
  )
34
 
 
 
 
 
 
 
 
 
 
 
35
  return pt_text_encoder, pt_tokenizer, pt_vae, pt_unet, pt_safety_checker
36
 
37
 
@@ -93,7 +103,7 @@ def run_conversion(
93
  pt_vae,
94
  pt_unet,
95
  pt_safety_checker,
96
- ) = initialize_pt_models()
97
  tf_text_encoder, tf_unet, tf_tokenizer = initialize_tf_models(
98
  text_encoder_weights, unet_weights, placeholder_token
99
  )
@@ -103,14 +113,6 @@ def run_conversion(
103
  print("Initializing a new text encoder with the placeholder token...")
104
  tf_text_encoder = create_new_text_encoder(tf_text_encoder, tf_tokenizer)
105
 
106
- print("Adding the placeholder token to PT CLIPTokenizer...")
107
- num_added_tokens = pt_tokenizer.add_tokens(placeholder_token)
108
- if num_added_tokens == 0:
109
- raise ValueError(
110
- f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
111
- " `placeholder_token` that is not already in the tokenizer."
112
- )
113
-
114
  if text_encoder_weights is not None:
115
  print("Loading fine-tuned text encoder weights.")
116
  text_encoder_weights_path = tf.keras.utils.get_file(origin=text_encoder_weights)
 
15
  MAX_SEQ_LENGTH = 77
16
 
17
 
18
+ def initialize_pt_models(placeholder_token: str):
19
  """Initializes the separate models of Stable Diffusion from diffusers and downloads
20
  their pre-trained weights."""
21
  pt_text_encoder = CLIPTextModel.from_pretrained(
 
32
  PRETRAINED_CKPT, subfolder="safety_checker", revision=NON_EMA_REVISION
33
  )
34
 
35
+ if placeholder_token is not None:
36
+ num_added_tokens = pt_tokenizer.add_tokens(placeholder_token)
37
+ if num_added_tokens == 0:
38
+ raise ValueError(
39
+ f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
40
+ " `placeholder_token` that is not already in the tokenizer."
41
+ )
42
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
43
+ pt_text_encoder.resize_token_embeddings(len(pt_tokenizer))
44
+
45
  return pt_text_encoder, pt_tokenizer, pt_vae, pt_unet, pt_safety_checker
46
 
47
 
 
103
  pt_vae,
104
  pt_unet,
105
  pt_safety_checker,
106
+ ) = initialize_pt_models(populate_text_encoder)
107
  tf_text_encoder, tf_unet, tf_tokenizer = initialize_tf_models(
108
  text_encoder_weights, unet_weights, placeholder_token
109
  )
 
113
  print("Initializing a new text encoder with the placeholder token...")
114
  tf_text_encoder = create_new_text_encoder(tf_text_encoder, tf_tokenizer)
115
 
 
 
 
 
 
 
 
 
116
  if text_encoder_weights is not None:
117
  print("Loading fine-tuned text encoder weights.")
118
  text_encoder_weights_path = tf.keras.utils.get_file(origin=text_encoder_weights)