mwitiderrick commited on
Commit
ec34a12
1 Parent(s): ebf1213

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import from_pretrained_keras
2
+ import keras_cv
3
+ import gradio as gr
4
+ from tensorflow import keras
5
+
6
+ keras.mixed_precision.set_global_policy("mixed_float16")
7
+ # load keras model
8
+ resolution = 512
9
+ dreambooth_model = keras_cv.models.StableDiffusion(
10
+ img_width=resolution, img_height=resolution, jit_compile=True,
11
+ )
12
+ loaded_diffusion_model = from_pretrained_keras("keras-dreambooth/living_room_dreambooth_diffusion_model")
13
+ dreambooth_model._diffusion_model = loaded_diffusion_model
14
+
15
+
16
+ def generate_images(prompt: str, negative_prompt:str, num_imgs_to_gen: int, num_steps: int):
17
+ """
18
+ This function is used to generate images using our fine-tuned keras dreambooth stable diffusion model.
19
+ Args:
20
+ prompt (str): The text input given by the user based on which images will be generated.
21
+ num_imgs_to_gen (int): The number of images to be generated using given prompt.
22
+ num_steps (int): The number of denoising steps
23
+ Returns:
24
+ generated_img (List): List of images that were generated using the model
25
+ """
26
+ generated_img = dreambooth_model.text_to_image(
27
+ prompt,
28
+ negative_prompt=negative_prompt,
29
+ batch_size=num_imgs_to_gen,
30
+ num_steps=num_steps,
31
+ )
32
+
33
+ return generated_img
34
+
35
+ with gr.Blocks() as demo:
36
+ gr.HTML("<h2 style=\"font-size: 2em; font-weight: bold\" align=\"center\">Keras Dreambooth - Dreambooth Fantasy Demo</h2>")
37
+ with gr.Row():
38
+ with gr.Column():
39
+ prompt = gr.Textbox(lines=1, value="sks living_room with maroon sofas", label="Base Prompt")
40
+ negative_prompt = gr.Textbox(lines=1, value="", label="Negative Prompt")
41
+ samples = gr.Slider(minimum=1, maximum=10, default=1, step=1, label="Number of Image")
42
+ num_steps = gr.Slider(label="Inference Steps",value=50)
43
+ run = gr.Button(value="Run")
44
+ with gr.Column():
45
+ gallery = gr.Gallery(label="Outputs").style(grid=(1,2))
46
+
47
+ run.click(generate_images, inputs=[prompt,negative_prompt, samples, num_steps], outputs=gallery)
48
+
49
+ gr.Examples([["A phot of sks living_room with maroon sofas","", 3, 75]],
50
+ [prompt,negative_prompt, samples,num_steps], gallery, generate_images)
51
+ gr.Markdown('\n Demo created by: <a href=\"https://huggingface.co/mwitiderrick/\">Derrick Mwiti</a>')
52
+
53
+ demo.launch(debug=True)