File size: 3,486 Bytes
08a7b0a
 
 
 
 
 
 
 
 
 
924289b
08a7b0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adfe545
08a7b0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from huggingface_hub import from_pretrained_keras
from keras_cv import models
from tensorflow import keras
import tensorflow as tf
import gradio as gr


# Switch the global Keras dtype policy to "mixed_float16" before any model
# is constructed further down in this module.
keras.mixed_precision.set_global_policy("mixed_float16")

# Example prompts; entry 0 pre-fills the "Prompt" textbox of the app.
stable_prompt_list = ["a photo of sshh shoe"]

# Example negative prompts; entry 0 pre-fills the "Negative Prompt" textbox.
stable_negative_prompt_list = ["bad, ugly", "deformed"]

def keras_stable_diffusion(
    model_path: str,
    prompt: str,
    negative_prompt: str,
    guidance_scale: float,
    num_inference_step: int,
    height: int,
    width: int,
):
    """Generate images for a text prompt with a fine-tuned diffusion model.

    A KerasCV ``StableDiffusion`` pipeline is created at the requested
    resolution, its diffusion model is replaced by the one loaded through
    ``from_pretrained_keras(model_path)``, and ``text_to_image`` is run.

    Args:
        model_path: Model identifier forwarded to ``from_pretrained_keras``.
        prompt: Text describing the desired image.
        negative_prompt: Text describing what the image should avoid.
        guidance_scale: Forwarded as ``unconditional_guidance_scale``; may be
            fractional (the UI slider moves in steps of 0.1).
        num_inference_step: Forwarded as ``num_steps``.
        height: Output image height in pixels.
        width: Output image width in pixels.

    Returns:
        The value returned by ``StableDiffusion.text_to_image``.
    """
    # Height and width must go to their matching keyword arguments; crossing
    # them transposes the output for every non-square request.
    sd_dreambooth_model = models.StableDiffusion(
        img_height=height,
        img_width=width,
    )

    # Replace the stock diffusion network with the fine-tuned weights.
    db_diffusion_model = from_pretrained_keras(model_path)
    sd_dreambooth_model._diffusion_model = db_diffusion_model

    generated_images = sd_dreambooth_model.text_to_image(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_steps=num_inference_step,
        unconditional_guidance_scale=guidance_scale,
    )

    # Reset Keras global state so models built for this request do not
    # accumulate across repeated calls.
    tf.keras.backend.clear_session()

    return generated_images

def keras_stable_diffusion_app():
    """Build the text-to-image UI and wire its button to the generator.

    The layout has two columns: the left one holds the prompt boxes, an
    "Advanced Options" accordion with the sampling/size sliders and the
    trigger button; the right one holds the output gallery. Clicking the
    button calls ``keras_stable_diffusion`` with the current widget values.

    The ``gr.Blocks`` created here is neither bound to a name nor returned.
    NOTE(review): presumably this is meant to be called inside an enclosing
    ``gr.Blocks`` context owned by the caller -- confirm against the caller.
    """
    with gr.Blocks():
        with gr.Row():
            with gr.Column():
                # Event inputs have to be Gradio components, so the fixed
                # model id is carried in a hidden State instead of a bare
                # Python string.
                keras_text2image_model_path = gr.State(
                    value="ashishtanwer/shoe"
                )

                keras_text2image_prompt = gr.Textbox(
                    lines=1,
                    value=stable_prompt_list[0],
                    label='Prompt'
                )

                keras_text2image_negative_prompt = gr.Textbox(
                    lines=1,
                    value=stable_negative_prompt_list[0],
                    label='Negative Prompt'
                )

                with gr.Accordion("Advanced Options", open=False):
                    keras_text2image_guidance_scale = gr.Slider(
                        minimum=0.1,
                        maximum=15,
                        step=0.1,
                        value=7.5,
                        label='Guidance Scale'
                    )

                    keras_text2image_num_inference_step = gr.Slider(
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=50,
                        label='Num Inference Step'
                    )

                    keras_text2image_height = gr.Slider(
                        minimum=128,
                        maximum=1280,
                        step=32,
                        value=512,
                        label='Image Height'
                    )

                    # Labelled as width so the two size sliders can be told
                    # apart in the UI.
                    keras_text2image_width = gr.Slider(
                        minimum=128,
                        maximum=1280,
                        step=32,
                        value=512,
                        label='Image Width'
                    )

                keras_text2image_predict = gr.Button(value='Generator')

            with gr.Column():
                output_image = gr.Gallery(label='Output')

        # The order of inputs must match the positional parameters of
        # keras_stable_diffusion.
        keras_text2image_predict.click(
            fn=keras_stable_diffusion,
            inputs=[
                keras_text2image_model_path,
                keras_text2image_prompt,
                keras_text2image_negative_prompt,
                keras_text2image_guidance_scale,
                keras_text2image_num_inference_step,
                keras_text2image_height,
                keras_text2image_width
            ],
            outputs=output_image
        )