Dimitre commited on
Commit
943b9fb
1 Parent(s): dfe1993

App and requirement files

Browse files
Files changed (2) hide show
  1. app.py +57 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+ import gradio as gr
5
+ import keras_cv
6
+ import numpy as np
7
+ import tensorflow as tf
8
+
9
+ logging.basicConfig(level=logging.INFO)
10
+ logger = logging.getLogger(__file__)
11
+ prompt_token = os.environ.get("TOKEN", "<token>")
12
+ text_encoder_path = os.environ.get(
13
+ "TEXT_ENCODER", "./models/example/text_encoder/keras"
14
+ )
15
+
16
+ logger.info(f'Inversed token used: "{prompt_token}"')
17
+ logger.info(f'Loading text encoder from: "{text_encoder_path}"')
18
+
19
+ stable_diffusion = keras_cv.models.StableDiffusion()
20
+ stable_diffusion.tokenizer.add_tokens(prompt_token)
21
+ loaded_text_encoder_ = tf.keras.models.load_model(text_encoder_path)
22
+ stable_diffusion._text_encoder = loaded_text_encoder_
23
+ stable_diffusion._text_encoder.compile(jit_compile=True)
24
+
25
+
26
+ def generate_fn(input_prompt: str) -> np.ndarray:
27
+ """Generates images from a text prompt
28
+
29
+ Args:
30
+ input_prompt (str): Text input prompt
31
+
32
+ Returns:
33
+ np.ndarray: Generated image
34
+ """
35
+ generated = stable_diffusion.text_to_image(
36
+ prompt=input_prompt, batch_size=1, num_steps=1
37
+ )
38
+ return generated[0]
39
+
40
+
41
+ iface = gr.Interface(
42
+ fn=generate_fn,
43
+ title="Textual Inversion",
44
+ description="Textual Inversion Demo",
45
+ article="Note: Keras-cv uses lazy intialization, so the first use will be slower while the model is initialized.",
46
+ inputs=gr.Textbox(
47
+ label="Prompt",
48
+ show_label=False,
49
+ max_lines=2,
50
+ placeholder="Enter your prompt",
51
+ elem_id="input-prompt",
52
+ ),
53
+ outputs=gr.Image(),
54
+ )
55
+
56
+ if __name__ == "__main__":
57
+ app, local_url, share_url = iface.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ numpy
2
+ gradio
3
+ keras-cv==0.4.2
4
+ tensorflow-datasets
5
+ tensorflow==2.11.0