File size: 1,033 Bytes
7dee6cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5a259d
7dee6cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from huggingface_hub import from_pretrained_keras
from keras_cv import models
import gradio as gr

from tensorflow import keras

# Use Keras mixed precision (float16 compute, float32 variables). This must run
# before the model below is constructed: the global policy applies to layers
# created after it is set.
keras.mixed_precision.set_global_policy("mixed_float16")

# prepare model
resolution = 512  # generated image width and height, in pixels
sd_dreambooth_model = models.StableDiffusion(
    img_width=resolution, img_height=resolution
)
# Load the DreamBooth-fine-tuned diffusion network from the Hugging Face Hub
# repo "keras-dreambooth/nuthatch-bird-demo".
db_diffusion_model = from_pretrained_keras("keras-dreambooth/nuthatch-bird-demo")
# Replace the pipeline's diffusion network with the fine-tuned one, so that
# text_to_image() in this file uses the DreamBooth weights.
# NOTE(review): `_diffusion_model` is a private KerasCV attribute; this relies on
# StableDiffusion building its default network lazily only when this attribute
# is unset — confirm against the pinned keras_cv version.
sd_dreambooth_model._diffusion_model = db_diffusion_model

# generate images
def infer(prompt):
    """Generate two images for ``prompt`` with the DreamBooth-tuned model.

    Args:
        prompt: Text prompt. Include the ``sks bird`` token to invoke the
            fine-tuned nuthatch concept.

    Returns:
        The result of ``StableDiffusion.text_to_image`` for a batch of two,
        handed unchanged to the Gradio gallery. Presumably a NumPy array of
        RGB images per the KerasCV docs — confirm against the pinned version.
    """
    return sd_dreambooth_model.text_to_image(prompt, batch_size=2)
    
# customize interface
title = "Dreambooth Demo on Nuthatch bird Images"
description = "This is a dreambooth model fine-tuned on nuthatch bird images. To try it, input the concept with {sks bird}."
examples = [["sks bird flying"]]

# Result gallery for the images returned by `infer`. Per the Gradio 3.x docs,
# `grid` is the number of images per row at increasing screen-width
# breakpoints: 1 on the narrowest screens, 2 on all wider ones.
# NOTE(review): `Component.style()` is deprecated in later Gradio 3.x releases
# and removed in Gradio 4 (`gr.Gallery(columns=...)` replaces it); confirm the
# pinned Gradio version before upgrading.
output = gr.Gallery(label="Outputs").style(grid=(1, 2))

# One text box in, the gallery out; enable Gradio's request queue, then serve.
gr.Interface(
    infer,
    inputs=["text"],
    outputs=[output],
    title=title,
    description=description,
    examples=examples,
).queue().launch()