"""Gradio demo for a DreamBooth model fine-tuned on nuthatch bird images.

Builds a KerasCV Stable Diffusion pipeline, replaces its diffusion model
with the one loaded from ``keras-dreambooth/nuthatch-bird-demo``, and
exposes text-to-image generation through a Gradio interface.
"""

from huggingface_hub import from_pretrained_keras
from keras_cv import models
import gradio as gr
from tensorflow import keras

# The global dtype policy is set before any model is constructed below.
keras.mixed_precision.set_global_policy("mixed_float16")

# --- Model preparation ---------------------------------------------------
# Square output: the same value is used for both image width and height.
resolution = 512
sd_dreambooth_model = models.StableDiffusion(
    img_width=resolution,
    img_height=resolution,
)

# Replace the pipeline's diffusion model with the fine-tuned one.
db_diffusion_model = from_pretrained_keras("keras-dreambooth/nuthatch-bird-demo")
sd_dreambooth_model._diffusion_model = db_diffusion_model


# --- Image generation ----------------------------------------------------
def infer(prompt: str):
    """Generate two images for ``prompt`` with the fine-tuned pipeline.

    Args:
        prompt: Text prompt supplied by the Gradio text input.

    Returns:
        Whatever ``text_to_image`` returns for a batch of two; presumably
        an array of images the gallery can render -- TODO confirm.
    """
    return sd_dreambooth_model.text_to_image(prompt, batch_size=2)


# NOTE(review): `.style(grid=...)` is tied to the installed Gradio version;
# confirm it is still supported before upgrading Gradio.
output = gr.Gallery(label="Outputs").style(grid=(1, 2))

# --- Interface customization ---------------------------------------------
title = "Dreambooth Demo on Nuthatch bird Images"
description = "This is a dreambooth model fine-tuned on nuthatch bird images. To try it, input the concept with {sks bird}."
examples = [["sks bird flying"]]

demo = gr.Interface(
    fn=infer,
    inputs=["text"],
    outputs=[output],
    title=title,
    description=description,
    examples=examples,
)
# Enable request queuing, then start the app.
demo.queue().launch()