Dimitre commited on
Commit
4ebc3ec
1 Parent(s): 943b9fb

Loading model with 'from_pretrained_keras'

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -5,20 +5,20 @@ import gradio as gr
5
  import keras_cv
6
  import numpy as np
7
  import tensorflow as tf
 
8
 
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__file__)
11
- prompt_token = os.environ.get("TOKEN", "<token>")
12
- text_encoder_path = os.environ.get(
13
- "TEXT_ENCODER", "./models/example/text_encoder/keras"
14
- )
15
 
16
  logger.info(f'Inversed token used: "{prompt_token}"')
17
- logger.info(f'Loading text encoder from: "{text_encoder_path}"')
18
 
19
  stable_diffusion = keras_cv.models.StableDiffusion()
20
  stable_diffusion.tokenizer.add_tokens(prompt_token)
21
- loaded_text_encoder_ = tf.keras.models.load_model(text_encoder_path)
 
22
  stable_diffusion._text_encoder = loaded_text_encoder_
23
  stable_diffusion._text_encoder.compile(jit_compile=True)
24
 
 
5
  import keras_cv
6
  import numpy as np
7
  import tensorflow as tf
8
+ from huggingface_hub import from_pretrained_keras
9
 
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__file__)
12
+ prompt_token = "<token>"
13
+ text_encoder_url = "Dimitre/stablediffusion-canarinho_pistola"
 
 
14
 
15
  logger.info(f'Inversed token used: "{prompt_token}"')
16
+ logger.info(f'Loading text encoder from: "{text_encoder_url}"')
17
 
18
  stable_diffusion = keras_cv.models.StableDiffusion()
19
  stable_diffusion.tokenizer.add_tokens(prompt_token)
20
+ # loaded_text_encoder_ = tf.keras.models.load_model(text_encoder_path)
21
+ loaded_text_encoder_ = from_pretrained_keras(text_encoder_url)
22
  stable_diffusion._text_encoder = loaded_text_encoder_
23
  stable_diffusion._text_encoder.compile(jit_compile=True)
24