"""Stable Diffusion teaching demo.

Builds a Gradio app (``demo``) with tabs for exploring the diffusion latent
space and the CLIP text-embedding space, plus a Dash app (``app``) that renders
an interactive embedding plot. When run as a script, the Dash server is started
in a background thread on 127.0.0.1:8000 and the Gradio app is launched with a
public share link.

Many names used below (``fig``, ``images``, ``whichAxisMap``, ``read_html``,
``generate_seed_vis``, ``display_*_images``, ...) are not defined in this file;
they presumably come from the star imports of ``src.util`` and
``src.pipelines`` — NOTE(review): confirm.
"""

# Import order is kept as-is on purpose: the star imports below may shadow
# earlier names, so reordering could change which binding wins.
import os
import base64
import gradio as gr
from PIL import Image
from src.util import *
from io import BytesIO
from src.pipelines import *
from threading import Thread
from dash import Dash, dcc, html, Input, Output, no_update, callback

# Default text shown in most prompt boxes.
DEFAULT_PROMPT = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"

# Options for every "Axis direction" dropdown; "---" means "not shown".
AXIS_CHOICES = ["X - Axis", "Y - Axis", "Z - Axis", "---"]


# ---------------------------------------------------------------------------
# Dash app: interactive embedding plot with hover tooltips and click-download.
# ---------------------------------------------------------------------------
app = Dash(__name__)

app.layout = html.Div(
    className="container",
    children=[
        dcc.Graph(
            id="graph", figure=fig, clear_on_unhover=True, style={"height": "90vh"}
        ),
        dcc.Tooltip(id="tooltip"),
        html.Div(id="word-emb-txt", style={"background-color": "white"}),
        html.Div(id="word-emb-vis"),
        html.Div(
            [
                html.Button(id="btn-download-image", hidden=True),
                dcc.Download(id="download-image"),
            ]
        ),
    ],
)


@callback(
    Output("tooltip", "show"),
    Output("tooltip", "bbox"),
    Output("tooltip", "children"),
    Output("tooltip", "direction"),
    Output("word-emb-txt", "children"),
    Output("word-emb-vis", "children"),
    Input("graph", "hoverData"),
)
def display_hover(hoverData):
    """Show a tooltip for the hovered point and update the word-embedding strip.

    The tooltip holds the point's image (``images[pointNumber]``) and its text.
    Below the graph, the hovered word and the image returned by
    ``generate_word_emb_vis`` for it are shown. When nothing is hovered, the
    tooltip is hidden and every other output is left untouched.
    """
    if hoverData is None:
        return False, no_update, no_update, no_update, no_update, no_update

    hover_data = hoverData["points"][0]
    bbox = hover_data["bbox"]
    direction = "left"
    index = hover_data["pointNumber"]

    children = [
        html.Img(
            src=images[index],
            style={"width": "250px"},
        ),
        html.P(
            hover_data["text"],
            style={
                "color": "black",
                "font-size": "20px",
                "text-align": "center",
                "background-color": "white",
                "margin": "5px",
            },
        ),
    ]

    emb_children = [
        html.Img(
            src=generate_word_emb_vis(hover_data["text"]),
            style={"width": "100%", "height": "25px"},
        ),
    ]

    return True, bbox, children, direction, hover_data["text"], emb_children


@callback(
    Output("download-image", "data"),
    Input("graph", "clickData"),
)
def download_image(clickData):
    """Send the clicked point's image to the browser as ``<word>.png``.

    The PNG is encoded in memory and handed to ``dcc.send_bytes``. It used to
    be saved under the current working directory first, which left a stray
    file behind on every click. The file name comes from user-entered words,
    so names such as ``../x`` could also write outside that directory.
    """
    if clickData is None:
        return no_update

    click_data = clickData["points"][0]
    index = click_data["pointNumber"]
    txt = click_data["text"]

    # images[index] looks like a data URI ("<header>,<base64 payload>"); keep
    # only the payload.
    img_decoded = base64.b64decode(images[index].split(",")[1])
    img = Image.open(BytesIO(img_decoded))

    # Re-encode as PNG so the downloaded bytes match the ".png" extension.
    return dcc.send_bytes(
        lambda buffer: img.save(buffer, format="PNG"), f"{txt}.png"
    )


# ---------------------------------------------------------------------------
# Gradio app.
# Several callbacks below intentionally share a function name
# (generate_images_wrapper / set_axis_wrapper): Gradio derives API endpoint
# names from the function name, so the names are kept stable.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## Stable Diffusion Demo")

    with gr.Tab("Latent Space"):
        with gr.TabItem("Beginner"):
            gr.Markdown("Generate images from text.")

            with gr.Row():
                with gr.Column():
                    prompt_beginner = gr.Textbox(
                        lines=1,
                        label="Prompt",
                        value=DEFAULT_PROMPT,
                    )

                    with gr.Row():
                        seed_beginner = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_beginner = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )

                    generate_images_button_beginner = gr.Button("Generate Image")

                with gr.Column():
                    images_output_beginner = gr.Image(label="Image")

            @generate_images_button_beginner.click(
                inputs=[prompt_beginner, seed_beginner],
                outputs=[images_output_beginner],
            )
            def generate_images_wrapper(prompt, seed, progress=gr.Progress()):
                """Generate one image with a fixed 8 inference steps.

                ``progress`` is the tracker Gradio injects; it is not updated
                here.
                """
                images, _ = display_poke_images(
                    prompt, seed, num_inference_steps=8, poke=False, intermediate=False
                )
                return images

            seed_beginner.change(
                fn=generate_seed_vis, inputs=[seed_beginner], outputs=[seed_vis_beginner]
            )

        with gr.TabItem("Denoising"):
            gr.Markdown("Observe the intermediate images during denoising.")
            gr.HTML(read_html("html/denoising.html"))

            with gr.Row():
                with gr.Column():
                    prompt_denoise = gr.Textbox(
                        lines=1,
                        label="Prompt",
                        value=DEFAULT_PROMPT,
                    )
                    num_inference_steps_denoise = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps",
                    )

                    with gr.Row():
                        seed_denoise = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_denoise = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )

                    generate_images_button_denoise = gr.Button("Generate Images")

                with gr.Column():
                    images_output_denoise = gr.Gallery(
                        label="Images", selected_index=0, height=512
                    )
                    gif_denoise = gr.Image(label="GIF")
                    zip_output_denoise = gr.File(label="Download ZIP")

            @generate_images_button_denoise.click(
                inputs=[prompt_denoise, seed_denoise, num_inference_steps_denoise],
                outputs=[images_output_denoise, gif_denoise, zip_output_denoise],
            )
            def generate_images_wrapper(
                prompt, seed, num_inference_steps, progress=gr.Progress()
            ):
                """Generate the intermediate denoising images and export them.

                Returns the image list, the path of the exported GIF and the
                path of the exported ZIP (which also stores this tab's
                settings).
                """
                images, _ = display_poke_images(
                    prompt, seed, num_inference_steps, poke=False, intermediate=True
                )

                fname = "denoising"
                tab_config = {
                    "Tab": "Denoising",
                    "Prompt": prompt,
                    "Number of Inference Steps": num_inference_steps,
                    "Seed": seed,
                }
                export_as_zip(images, fname, tab_config)
                progress(1, desc="Exporting as gif")
                export_as_gif(images, filename="denoising.gif")
                # The export helpers presumably write under "outputs/" — the
                # returned paths assume so (NOTE(review): confirm).
                return images, "outputs/denoising.gif", f"outputs/{fname}.zip"

            seed_denoise.change(
                fn=generate_seed_vis, inputs=[seed_denoise], outputs=[seed_vis_denoise]
            )

        with gr.TabItem("Seeds"):
            gr.Markdown(
                "Understand how different starting points in latent space can lead to different images."
            )
            gr.HTML(read_html("html/seeds.html"))

            with gr.Row():
                with gr.Column():
                    prompt_seed = gr.Textbox(
                        lines=1,
                        label="Prompt",
                        value=DEFAULT_PROMPT,
                    )
                    num_images_seed = gr.Slider(
                        minimum=1, maximum=100, step=1, value=5, label="Number of Seeds"
                    )
                    num_inference_steps_seed = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps per Image",
                    )
                    generate_images_button_seed = gr.Button("Generate Images")

                with gr.Column():
                    images_output_seed = gr.Gallery(
                        label="Images", selected_index=0, height=512
                    )
                    zip_output_seed = gr.File(label="Download ZIP")

            generate_images_button_seed.click(
                fn=display_seed_images,
                inputs=[prompt_seed, num_inference_steps_seed, num_images_seed],
                outputs=[images_output_seed, zip_output_seed],
            )

        with gr.TabItem("Perturbations"):
            gr.Markdown("Explore different perturbations from a point in latent space.")
            gr.HTML(read_html("html/perturbations.html"))

            with gr.Row():
                with gr.Column():
                    prompt_perturb = gr.Textbox(
                        lines=1,
                        label="Prompt",
                        value=DEFAULT_PROMPT,
                    )
                    num_images_perturb = gr.Slider(
                        minimum=0,
                        maximum=100,
                        step=1,
                        value=5,
                        label="Number of Perturbations",
                    )
                    perturbation_size_perturb = gr.Slider(
                        minimum=0,
                        maximum=1,
                        step=0.1,
                        value=0.1,
                        label="Perturbation Size",
                    )
                    num_inference_steps_perturb = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps per Image",
                    )

                    with gr.Row():
                        seed_perturb = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_perturb = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )

                    generate_images_button_perturb = gr.Button("Generate Images")

                with gr.Column():
                    images_output_perturb = gr.Gallery(
                        label="Image", selected_index=0, height=512
                    )
                    zip_output_perturb = gr.File(label="Download ZIP")

            generate_images_button_perturb.click(
                fn=display_perturb_images,
                inputs=[
                    prompt_perturb,
                    seed_perturb,
                    num_inference_steps_perturb,
                    num_images_perturb,
                    perturbation_size_perturb,
                ],
                outputs=[images_output_perturb, zip_output_perturb],
            )

            seed_perturb.change(
                fn=generate_seed_vis, inputs=[seed_perturb], outputs=[seed_vis_perturb]
            )

        with gr.TabItem("Circular"):
            gr.Markdown(
                "Generate a circular path in latent space and observe how the images vary along the path."
            )
            gr.HTML(read_html("html/circular.html"))

            with gr.Row():
                with gr.Column():
                    prompt_circular = gr.Textbox(
                        lines=1,
                        label="Prompt",
                        value=DEFAULT_PROMPT,
                    )
                    num_images_circular = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=5,
                        label="Number of Steps around the Circle",
                    )

                    with gr.Row():
                        start_degree_circular = gr.Slider(
                            minimum=0,
                            maximum=360,
                            step=1,
                            value=0,
                            label="Start Angle",
                            info="Enter the value in degrees",
                        )
                        end_degree_circular = gr.Slider(
                            minimum=0,
                            maximum=360,
                            step=1,
                            value=360,
                            label="End Angle",
                            info="Enter the value in degrees",
                        )
                        # Initialised with the same helper that refreshes it on
                        # change, fed the sliders' initial values (5, 0, 360),
                        # so the first displayed value is consistent.
                        step_size_circular = gr.Textbox(
                            label="Step Size", value=calculate_step_size(5, 0, 360)
                        )

                    num_inference_steps_circular = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps per Image",
                    )

                    with gr.Row():
                        seed_circular = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_circular = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )

                    generate_images_button_circular = gr.Button("Generate Images")

                with gr.Column():
                    images_output_circular = gr.Gallery(label="Image", selected_index=0)
                    gif_circular = gr.Image(label="GIF")
                    zip_output_circular = gr.File(label="Download ZIP")

            # Recompute the displayed step size whenever any of its inputs change.
            for circular_control in (
                num_images_circular,
                start_degree_circular,
                end_degree_circular,
            ):
                circular_control.change(
                    fn=calculate_step_size,
                    inputs=[
                        num_images_circular,
                        start_degree_circular,
                        end_degree_circular,
                    ],
                    outputs=[step_size_circular],
                )

            generate_images_button_circular.click(
                fn=display_circular_images,
                inputs=[
                    prompt_circular,
                    seed_circular,
                    num_inference_steps_circular,
                    num_images_circular,
                    start_degree_circular,
                    end_degree_circular,
                ],
                outputs=[images_output_circular, gif_circular, zip_output_circular],
            )

            seed_circular.change(
                fn=generate_seed_vis, inputs=[seed_circular], outputs=[seed_vis_circular]
            )

        with gr.TabItem("Poke"):
            gr.Markdown("Perturb a region in the image and observe the effect.")
            gr.HTML(read_html("html/poke.html"))

            with gr.Row():
                with gr.Column():
                    prompt_poke = gr.Textbox(
                        lines=1,
                        label="Prompt",
                        value=DEFAULT_PROMPT,
                    )
                    num_inference_steps_poke = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps per Image",
                    )

                    with gr.Row():
                        seed_poke = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_poke = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )

                    pokeX = gr.Slider(
                        label="pokeX",
                        minimum=0,
                        maximum=64,
                        step=1,
                        value=32,
                        info="X coordinate of poke center",
                    )
                    pokeY = gr.Slider(
                        label="pokeY",
                        minimum=0,
                        maximum=64,
                        step=1,
                        value=32,
                        info="Y coordinate of poke center",
                    )
                    pokeHeight = gr.Slider(
                        label="pokeHeight",
                        minimum=0,
                        maximum=64,
                        step=1,
                        value=8,
                        info="Height of the poke",
                    )
                    pokeWidth = gr.Slider(
                        label="pokeWidth",
                        minimum=0,
                        maximum=64,
                        step=1,
                        value=8,
                        info="Width of the poke",
                    )

                    generate_images_button_poke = gr.Button("Generate Images")

                with gr.Column():
                    # Compute the initial preview once and reuse both halves.
                    poke_preview = visualize_poke(32, 32, 8, 8)
                    original_images_output_poke = gr.Image(
                        value=poke_preview[0], label="Original Image"
                    )
                    poked_images_output_poke = gr.Image(
                        value=poke_preview[1], label="Poked Image"
                    )
                    zip_output_poke = gr.File(label="Download ZIP")

            # Refresh both previews whenever any poke-rectangle slider moves.
            for poke_control in (pokeX, pokeY, pokeHeight, pokeWidth):
                poke_control.change(
                    visualize_poke,
                    inputs=[pokeX, pokeY, pokeHeight, pokeWidth],
                    outputs=[original_images_output_poke, poked_images_output_poke],
                )

            seed_poke.change(
                fn=generate_seed_vis, inputs=[seed_poke], outputs=[seed_vis_poke]
            )

            @generate_images_button_poke.click(
                inputs=[
                    prompt_poke,
                    seed_poke,
                    num_inference_steps_poke,
                    pokeX,
                    pokeY,
                    pokeHeight,
                    pokeWidth,
                ],
                outputs=[
                    original_images_output_poke,
                    poked_images_output_poke,
                    zip_output_poke,
                ],
            )
            def generate_images_wrapper(
                prompt,
                seed,
                num_inference_steps,
                pokeX,
                pokeY,
                pokeHeight,
                pokeWidth,
            ):
                """Run the poke pipeline, then return both previews and a ZIP.

                Gradio passes every listed input positionally. The former
                defaults (the slider components themselves) were never used.
                """
                # The return value is discarded; the images shown come from
                # visualize_poke() below. Presumably display_poke_images stores
                # its results where visualize_poke picks them up — TODO confirm.
                _, _ = display_poke_images(
                    prompt,
                    seed,
                    num_inference_steps,
                    poke=True,
                    pokeX=pokeX,
                    pokeY=pokeY,
                    pokeHeight=pokeHeight,
                    pokeWidth=pokeWidth,
                    intermediate=False,
                )
                images, modImages = visualize_poke(pokeX, pokeY, pokeHeight, pokeWidth)

                fname = "poke"
                tab_config = {
                    "Tab": "Poke",
                    "Prompt": prompt,
                    "Number of Inference Steps per Image": num_inference_steps,
                    "Seed": seed,
                    "PokeX": pokeX,
                    "PokeY": pokeY,
                    "PokeHeight": pokeHeight,
                    "PokeWidth": pokeWidth,
                }
                imgs_list = [
                    (images, "Original Image"),
                    (modImages, "Poked Image"),
                ]
                export_as_zip(imgs_list, fname, tab_config)
                return images, modImages, f"outputs/{fname}.zip"

        with gr.TabItem("Guidance"):
            gr.Markdown("Observe the effect of different guidance scales.")
            gr.HTML(read_html("html/guidance.html"))

            with gr.Row():
                with gr.Column():
                    prompt_guidance = gr.Textbox(
                        lines=1,
                        label="Prompt",
                        value=DEFAULT_PROMPT,
                    )
                    num_inference_steps_guidance = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps per Image",
                    )
                    guidance_scale_values = gr.Textbox(
                        lines=1, value="1, 8, 20, 30", label="Guidance Scale Values"
                    )

                    with gr.Row():
                        seed_guidance = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_guidance = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )

                    generate_images_button_guidance = gr.Button("Generate Images")

                with gr.Column():
                    images_output_guidance = gr.Gallery(
                        label="Images",
                        selected_index=0,
                        height=512,
                    )
                    zip_output_guidance = gr.File(label="Download ZIP")

            generate_images_button_guidance.click(
                fn=display_guidance_images,
                inputs=[
                    prompt_guidance,
                    seed_guidance,
                    num_inference_steps_guidance,
                    guidance_scale_values,
                ],
                outputs=[images_output_guidance, zip_output_guidance],
            )

            seed_guidance.change(
                fn=generate_seed_vis, inputs=[seed_guidance], outputs=[seed_vis_guidance]
            )

        with gr.TabItem("Inpainting"):
            gr.Markdown("Inpaint the image based on the prompt.")
            gr.HTML(read_html("html/inpainting.html"))

            with gr.Row():
                with gr.Column():
                    # `sources` is documented as an iterable of source names.
                    uploaded_img_inpaint = gr.Sketchpad(
                        sources=["upload"],
                        brush=gr.Brush(colors=["#ffff00"]),
                        type="pil",
                        label="Upload",
                    )
                    prompt_inpaint = gr.Textbox(
                        lines=1, label="Prompt", value="sunglasses"
                    )
                    num_inference_steps_inpaint = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps per Image",
                    )

                    with gr.Row():
                        seed_inpaint = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_inpaint = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )

                    inpaint_button = gr.Button("Inpaint")

                with gr.Column():
                    images_output_inpaint = gr.Image(label="Output")
                    zip_output_inpaint = gr.File(label="Download ZIP")

            inpaint_button.click(
                fn=inpaint,
                inputs=[
                    uploaded_img_inpaint,
                    num_inference_steps_inpaint,
                    seed_inpaint,
                    prompt_inpaint,
                ],
                outputs=[images_output_inpaint, zip_output_inpaint],
            )

            seed_inpaint.change(
                fn=generate_seed_vis, inputs=[seed_inpaint], outputs=[seed_vis_inpaint]
            )

    with gr.Tab("CLIP Space"):
        with gr.TabItem("Embeddings"):
            gr.Markdown(
                "Visualize text embedding space in 3D with input texts and output images based on the chosen axis."
            )
            gr.HTML(read_html("html/embeddings.html"))

            with gr.Row():
                # NOTE(review): this HTML is empty in the file as seen. The
                # word/axis handlers below write into it. Presumably it was
                # meant to embed the Dash plot served by run_dash() on
                # 127.0.0.1:8000 — confirm.
                output = gr.HTML(
                    f""" """
                )

            with gr.Row():
                word2add_rem = gr.Textbox(lines=1, label="Add/Remove word")
                word2change = gr.Textbox(lines=1, label="Change image for word")
                clear_words_button = gr.Button(value="Clear words")

            with gr.Accordion("Custom Semantic Dimensions", open=False):
                with gr.Row():
                    axis_name_1 = gr.Textbox(label="Axis name", value="gender")
                    which_axis_1 = gr.Dropdown(
                        choices=AXIS_CHOICES,
                        value=whichAxisMap["which_axis_1"],
                        label="Axis direction",
                    )
                    from_words_1 = gr.Textbox(
                        lines=1,
                        label="Positive",
                        value="prince husband father son uncle",
                    )
                    to_words_1 = gr.Textbox(
                        lines=1,
                        label="Negative",
                        value="princess wife mother daughter aunt",
                    )
                    submit_1 = gr.Button("Submit")

                with gr.Row():
                    axis_name_2 = gr.Textbox(label="Axis name", value="age")
                    which_axis_2 = gr.Dropdown(
                        choices=AXIS_CHOICES,
                        value=whichAxisMap["which_axis_2"],
                        label="Axis direction",
                    )
                    from_words_2 = gr.Textbox(
                        lines=1, label="Positive", value="man woman king queen father"
                    )
                    to_words_2 = gr.Textbox(
                        lines=1, label="Negative", value="boy girl prince princess son"
                    )
                    submit_2 = gr.Button("Submit")

                with gr.Row():
                    axis_name_3 = gr.Textbox(label="Axis name", value="residual")
                    which_axis_3 = gr.Dropdown(
                        choices=AXIS_CHOICES,
                        value=whichAxisMap["which_axis_3"],
                        label="Axis direction",
                    )
                    from_words_3 = gr.Textbox(lines=1, label="Positive")
                    to_words_3 = gr.Textbox(lines=1, label="Negative")
                    submit_3 = gr.Button("Submit")

                with gr.Row():
                    axis_name_4 = gr.Textbox(label="Axis name", value="number")
                    which_axis_4 = gr.Dropdown(
                        choices=AXIS_CHOICES,
                        value=whichAxisMap["which_axis_4"],
                        label="Axis direction",
                    )
                    from_words_4 = gr.Textbox(
                        lines=1,
                        label="Positive",
                        value="boys girls cats puppies computers",
                    )
                    to_words_4 = gr.Textbox(
                        lines=1, label="Negative", value="boy girl cat puppy computer"
                    )
                    submit_4 = gr.Button("Submit")

                with gr.Row():
                    axis_name_5 = gr.Textbox(label="Axis name", value="royalty")
                    which_axis_5 = gr.Dropdown(
                        choices=AXIS_CHOICES,
                        value=whichAxisMap["which_axis_5"],
                        label="Axis direction",
                    )
                    from_words_5 = gr.Textbox(
                        lines=1,
                        label="Positive",
                        value="king queen prince princess duchess",
                    )
                    to_words_5 = gr.Textbox(
                        lines=1, label="Negative", value="man woman boy girl woman"
                    )
                    submit_5 = gr.Button("Submit")

                with gr.Row():
                    axis_name_6 = gr.Textbox(label="Axis name")
                    which_axis_6 = gr.Dropdown(
                        choices=AXIS_CHOICES,
                        value=whichAxisMap["which_axis_6"],
                        label="Axis direction",
                    )
                    from_words_6 = gr.Textbox(lines=1, label="Positive")
                    to_words_6 = gr.Textbox(lines=1, label="Negative")
                    submit_6 = gr.Button("Submit")

            @word2add_rem.submit(inputs=[word2add_rem], outputs=[output, word2add_rem])
            def add_rem_word_and_clear(words):
                """Add or remove the typed word(s), then clear the textbox."""
                return add_rem_word(words), ""

            @word2change.submit(inputs=[word2change], outputs=[output, word2change])
            def change_word_and_clear(word):
                """Change the image shown for the typed word, then clear the textbox."""
                return change_word(word), ""

            clear_words_button.click(fn=clear_words, outputs=[output])

            # One row per custom semantic axis: (name, direction, positive
            # words, negative words, submit button). Row i owns the
            # whichAxisMap key "which_axis_<i+1>".
            axis_rows = [
                (axis_name_1, which_axis_1, from_words_1, to_words_1, submit_1),
                (axis_name_2, which_axis_2, from_words_2, to_words_2, submit_2),
                (axis_name_3, which_axis_3, from_words_3, to_words_3, submit_3),
                (axis_name_4, which_axis_4, from_words_4, to_words_4, submit_4),
                (axis_name_5, which_axis_5, from_words_5, to_words_5, submit_5),
                (axis_name_6, which_axis_6, from_words_6, to_words_6, submit_6),
            ]
            which_axis_dropdowns = [row[1] for row in axis_rows]

            def _register_set_axis_handler(slot):
                """Wire the Submit button of axis row ``slot`` (0-based).

                On submit, any other row already using the chosen direction is
                reset to "---". This row's direction is recorded in
                ``whichAxisMap``, and the plot HTML is rebuilt with
                ``set_axis``. The other five dropdowns are refreshed from the
                map, in row order.
                """
                axis_name_box, which_axis_box, from_box, to_box, submit_button = (
                    axis_rows[slot]
                )
                own_key = f"which_axis_{slot + 1}"
                other_slots = [i for i in range(len(axis_rows)) if i != slot]
                other_keys = [f"which_axis_{i + 1}" for i in other_slots]
                other_dropdowns = [which_axis_dropdowns[i] for i in other_slots]

                @submit_button.click(
                    inputs=[axis_name_box, which_axis_box, from_box, to_box],
                    outputs=[output, *other_dropdowns],
                )
                def set_axis_wrapper(axis_name, which_axis, from_words, to_words):
                    # Only existing keys are reassigned, so mutating while
                    # iterating is safe.
                    for ax in whichAxisMap:
                        if whichAxisMap[ax] == which_axis:
                            whichAxisMap[ax] = "---"
                    whichAxisMap[own_key] = which_axis
                    # set_axis runs before the map is read for the dropdowns.
                    return (
                        set_axis(axis_name, which_axis, from_words, to_words),
                        *(whichAxisMap[key] for key in other_keys),
                    )

            # Register in row order, as the original per-row handlers were.
            for axis_slot in range(len(axis_rows)):
                _register_set_axis_handler(axis_slot)

        with gr.TabItem("Interpolate"):
            gr.Markdown(
                "Interpolate between the first and the second prompt, and observe how the output changes."
            )
            gr.HTML(read_html("html/interpolate.html"))

            with gr.Row():
                with gr.Column():
                    promptA = gr.Textbox(
                        lines=1,
                        label="First Prompt",
                        value="Self-portrait oil painting, a beautiful man with golden hair, 8k",
                    )
                    promptB = gr.Textbox(
                        lines=1,
                        label="Second Prompt",
                        value="Self-portrait oil painting, a beautiful woman with golden hair, 8k",
                    )
                    num_images_interpolate = gr.Slider(
                        minimum=0,
                        maximum=100,
                        step=1,
                        value=5,
                        label="Number of Interpolation Steps",
                    )
                    num_inference_steps_interpolate = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps per Image",
                    )

                    with gr.Row():
                        seed_interpolate = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_interpolate = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )

                    generate_images_button_interpolate = gr.Button("Generate Images")

                with gr.Column():
                    images_output_interpolate = gr.Gallery(
                        label="Interpolated Images",
                        selected_index=0,
                        height=512,
                    )
                    gif_interpolate = gr.Image(label="GIF")
                    zip_output_interpolate = gr.File(label="Download ZIP")

            generate_images_button_interpolate.click(
                fn=display_interpolate_images,
                inputs=[
                    seed_interpolate,
                    promptA,
                    promptB,
                    num_inference_steps_interpolate,
                    num_images_interpolate,
                ],
                outputs=[
                    images_output_interpolate,
                    gif_interpolate,
                    zip_output_interpolate,
                ],
            )

            seed_interpolate.change(
                fn=generate_seed_vis,
                inputs=[seed_interpolate],
                outputs=[seed_vis_interpolate],
            )

        with gr.TabItem("Negative"):
            gr.Markdown("Observe the effect of negative prompts.")
            gr.HTML(read_html("html/negative.html"))

            with gr.Row():
                with gr.Column():
                    prompt_negative = gr.Textbox(
                        lines=1,
                        label="Prompt",
                        value=DEFAULT_PROMPT,
                    )
                    neg_prompt = gr.Textbox(
                        lines=1, label="Negative Prompt", value="Yellow"
                    )
                    num_inference_steps_negative = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps per Image",
                    )

                    with gr.Row():
                        seed_negative = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_negative = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )

                    generate_images_button_negative = gr.Button("Generate Images")

                with gr.Column():
                    images_output_negative = gr.Image(
                        label="Image without Negative Prompt"
                    )
                    images_neg_output_negative = gr.Image(
                        label="Image with Negative Prompt"
                    )
                    zip_output_negative = gr.File(label="Download ZIP")

            seed_negative.change(
                fn=generate_seed_vis, inputs=[seed_negative], outputs=[seed_vis_negative]
            )

            generate_images_button_negative.click(
                fn=display_negative_images,
                inputs=[
                    prompt_negative,
                    seed_negative,
                    num_inference_steps_negative,
                    neg_prompt,
                ],
                outputs=[
                    images_output_negative,
                    images_neg_output_negative,
                    zip_output_negative,
                ],
            )

    with gr.Tab("Credits"):
        gr.Markdown(""" Author: Adithya Kameswara Rao, Carnegie Mellon University. Advisor: David S. Touretzky, Carnegie Mellon University. This work was funded by a grant from NEOM Company, and by National Science Foundation award IIS-2112633. 
""")


def run_dash():
    """Serve the Dash embedding explorer on 127.0.0.1:8000 (blocks until stopped)."""
    app.run(host="127.0.0.1", port="8000")


if __name__ == "__main__":
    # Daemon thread: the Dash server is torn down when the Gradio server exits.
    thread = Thread(target=run_dash, daemon=True)
    thread.start()

    try:
        # Export helpers and the returned file paths use "outputs/".
        os.makedirs("outputs", exist_ok=True)
        demo.queue().launch(share=True)
    except KeyboardInterrupt:
        print("Server closed")