fix: PT text encoder init.
convert.py CHANGED (+12 -10)
@@ -15,7 +15,7 @@ IMG_HEIGHT = IMG_WIDTH = 512
 MAX_SEQ_LENGTH = 77
 
 
-def initialize_pt_models():
+def initialize_pt_models(placeholder_token: str):
     """Initializes the separate models of Stable Diffusion from diffusers and downloads
     their pre-trained weights."""
     pt_text_encoder = CLIPTextModel.from_pretrained(
@@ -32,6 +32,16 @@ def initialize_pt_models():
         PRETRAINED_CKPT, subfolder="safety_checker", revision=NON_EMA_REVISION
     )
 
+    if placeholder_token is not None:
+        num_added_tokens = pt_tokenizer.add_tokens(placeholder_token)
+        if num_added_tokens == 0:
+            raise ValueError(
+                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
+                " `placeholder_token` that is not already in the tokenizer."
+            )
+        # Resize the token embeddings as we are adding new special tokens to the tokenizer
+        pt_text_encoder.resize_token_embeddings(len(pt_tokenizer))
+
     return pt_text_encoder, pt_tokenizer, pt_vae, pt_unet, pt_safety_checker
 
 
@@ -93,7 +103,7 @@ def run_conversion(
         pt_vae,
         pt_unet,
         pt_safety_checker,
-    ) = initialize_pt_models()
+    ) = initialize_pt_models(populate_text_encoder)
     tf_text_encoder, tf_unet, tf_tokenizer = initialize_tf_models(
         text_encoder_weights, unet_weights, placeholder_token
     )
@@ -103,14 +113,6 @@ def run_conversion(
         print("Initializing a new text encoder with the placeholder token...")
         tf_text_encoder = create_new_text_encoder(tf_text_encoder, tf_tokenizer)
 
-        print("Adding the placeholder token to PT CLIPTokenizer...")
-        num_added_tokens = pt_tokenizer.add_tokens(placeholder_token)
-        if num_added_tokens == 0:
-            raise ValueError(
-                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
-                " `placeholder_token` that is not already in the tokenizer."
-            )
-
     if text_encoder_weights is not None:
         print("Loading fine-tuned text encoder weights.")
         text_encoder_weights_path = tf.keras.utils.get_file(origin=text_encoder_weights)
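In short, the commit moves the placeholder-token handling out of run_conversion and into initialize_pt_models, and adds the resize_token_embeddings call that was missing before: add_tokens grows the tokenizer's vocabulary, but the text encoder's embedding matrix keeps its pre-trained size until it is resized, so the new token id would index past the table. Below is a minimal standalone sketch of that pattern using the standard transformers API; the checkpoint name and placeholder token are illustrative stand-ins (the script itself uses PRETRAINED_CKPT), not the commit's exact code.

from transformers import CLIPTextModel, CLIPTokenizer

CKPT = "CompVis/stable-diffusion-v1-4"  # illustrative; the script uses PRETRAINED_CKPT

tokenizer = CLIPTokenizer.from_pretrained(CKPT, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(CKPT, subfolder="text_encoder")

placeholder_token = "<cat-toy>"  # hypothetical placeholder token
num_added_tokens = tokenizer.add_tokens(placeholder_token)
if num_added_tokens == 0:
    raise ValueError(f"Tokenizer already contains the token {placeholder_token}.")

# The embedding table still has the pre-trained vocab size at this point, so an
# input containing the new token id would index past its last row. Resizing
# appends (randomly initialized) rows for the newly added tokens.
text_encoder.resize_token_embeddings(len(tokenizer))

# Sanity check: encode a prompt that uses the placeholder token.
ids = tokenizer(f"a photo of {placeholder_token}", return_tensors="pt").input_ids
out = text_encoder(ids)
print(out.last_hidden_state.shape)  # (1, seq_len, hidden_dim)

Doing the resize inside initialize_pt_models, right where the token is added, keeps the tokenizer and the PT text encoder consistent before any weights are ported, which appears to be the "PT text encoder init" fix the commit message refers to.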