DiffusionDemo / app.py
akameswa's picture
Update app.py
ed8a0f3 verified
import os
import base64
import gradio as gr
from PIL import Image
from src.util import *
from io import BytesIO
from src.pipelines import *
from threading import Thread
from dash import Dash, dcc, html, Input, Output, no_update, callback
# Dash application that serves the interactive plot; its page holds the graph,
# a hover tooltip, two containers filled by the hover callback, and the
# hidden controls used to trigger an image download.
app = Dash(__name__)
app.layout = html.Div(
    children=[
        dcc.Graph(
            id="graph",
            figure=fig,
            clear_on_unhover=True,
            style={"height": "90vh"},
        ),
        dcc.Tooltip(id="tooltip"),
        # Text of the hovered point.
        html.Div(id="word-emb-txt", style={"background-color": "white"}),
        # Embedding visualisation of the hovered point's text.
        html.Div(id="word-emb-vis"),
        html.Div(
            children=[
                html.Button(id="btn-download-image", hidden=True),
                dcc.Download(id="download-image"),
            ]
        ),
    ],
    className="container",
)
@callback(
    Output("tooltip", "show"),
    Output("tooltip", "bbox"),
    Output("tooltip", "children"),
    Output("tooltip", "direction"),
    Output("word-emb-txt", "children"),
    Output("word-emb-vis", "children"),
    Input("graph", "hoverData"),
)
def display_hover(hoverData):
    """Populate the tooltip and the word-embedding panel for the hovered point.

    Args:
        hoverData: Hover payload of the ``graph`` component, or ``None`` when
            nothing is hovered.

    Returns:
        A 6-tuple matching the declared outputs: tooltip visibility, tooltip
        bounding box, tooltip children (image plus caption), tooltip
        direction, the hovered text, and the embedding visualisation
        children. When ``hoverData`` is ``None`` the tooltip is hidden and
        every other output is left untouched.
    """
    if hoverData is None:
        return (False,) + (no_update,) * 5

    # Only the first hovered point is used.
    point = hoverData["points"][0]
    box = point["bbox"]
    point_number = point["pointNumber"]

    thumbnail = html.Img(src=images[point_number], style={"width": "250px"})
    caption = html.P(
        point["text"],
        style={
            "color": "black",
            "font-size": "20px",
            "text-align": "center",
            "background-color": "white",
            "margin": "5px",
        },
    )
    embedding_strip = html.Img(
        src=generate_word_emb_vis(point["text"]),
        style={"width": "100%", "height": "25px"},
    )

    return True, box, [thumbnail, caption], "left", point["text"], [embedding_strip]
@callback(
    Output("download-image", "data"),
    Input("graph", "clickData"),
)
def download_image(clickData):
    """Send the image of the clicked point to the browser as a PNG download.

    The image is re-encoded to PNG in memory and returned with
    ``dcc.send_bytes``. Nothing is written to the server's working directory,
    so the point text can no longer be used as a filesystem path and
    concurrent clicks cannot overwrite each other's files.

    Args:
        clickData: Click payload of the ``graph`` component, or ``None`` when
            nothing has been clicked.

    Returns:
        ``no_update`` when there is no click; otherwise the download payload
        for a file named after the clicked point's text.
    """
    if clickData is None:
        return no_update

    # Only the first clicked point is used.
    click_data = clickData["points"][0]
    index = click_data["pointNumber"]
    txt = click_data["text"]

    # The stored image is split on "," and the second part is base64-decoded
    # (presumably a "data:...;base64,<payload>" URI - confirm against the
    # code that fills `images`).
    img_decoded = base64.b64decode(images[index].split(",")[1])

    # Replace characters that are not valid in file names on common platforms
    # so the suggested download name is always usable.
    unsafe = '\\/:*?"<>|'
    safe_name = "".join(
        "_" if ch in unsafe or ord(ch) < 32 else ch for ch in str(txt)
    ).strip()
    if not safe_name:
        safe_name = "image"

    png_buffer = BytesIO()
    # The context manager releases the decoder's resources once encoded.
    with Image.open(BytesIO(img_decoded)) as img:
        img.save(png_buffer, format="PNG")

    return dcc.send_bytes(png_buffer.getvalue(), f"{safe_name}.png")
# Root Gradio UI: the tabs, components and event handlers of the demo are
# declared inside this Blocks context, which is launched at the end of the file.
with gr.Blocks() as demo:
gr.Markdown("## Stable Diffusion Demo")
# First top-level tab.
with gr.Tab("Latent Space"):
# "Beginner" sub-tab: a prompt and a seed as inputs, a single image as output.
with gr.TabItem("Beginner"):
gr.Markdown("Generate images from text.")
with gr.Row():
with gr.Column():
prompt_beginner = gr.Textbox(
lines=1,
label="Prompt",
value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
)
with gr.Row():
seed_beginner = gr.Slider(
minimum=0, maximum=100, step=1, value=14, label="Seed"
)
# The plot is pre-rendered for the slider's default seed (14).
seed_vis_beginner = gr.Plot(
value=generate_seed_vis(14), label="Seed"
)
generate_images_button_beginner = gr.Button("Generate Image")
with gr.Column():
images_output_beginner = gr.Image(label="Image")
@generate_images_button_beginner.click(
    inputs=[prompt_beginner, seed_beginner],
    outputs=[images_output_beginner],
)
def generate_images_wrapper(prompt, seed, progress=gr.Progress()):
    """Handle the Beginner tab's button: generate one image for a prompt and seed.

    Calls ``display_poke_images`` with a fixed 8 inference steps and with
    poking and intermediate output disabled, then returns the first element
    of its two-element result for the image component. ``progress`` is not
    used in the body.
    """
    generated, _unused = display_poke_images(
        prompt,
        seed,
        num_inference_steps=8,
        poke=False,
        intermediate=False,
    )
    return generated
# Re-render the seed plot whenever the Beginner tab's seed slider changes.
seed_beginner.change(
fn=generate_seed_vis, inputs=[seed_beginner], outputs=[seed_vis_beginner]
)
# "Denoising" sub-tab: prompt, step count and seed as inputs; a gallery,
# a GIF and a ZIP download as outputs.
with gr.TabItem("Denoising"):
gr.Markdown("Observe the intermediate images during denoising.")
# Explanatory HTML for this tab is loaded from a file.
gr.HTML(read_html("html/denoising.html"))
with gr.Row():
with gr.Column():
prompt_denoise = gr.Textbox(
lines=1,
label="Prompt",
value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
)
num_inference_steps_denoise = gr.Slider(
minimum=2,
maximum=100,
step=1,
value=8,
label="Number of Inference Steps",
)
with gr.Row():
seed_denoise = gr.Slider(
minimum=0, maximum=100, step=1, value=14, label="Seed"
)
# The plot is pre-rendered for the slider's default seed (14).
seed_vis_denoise = gr.Plot(
value=generate_seed_vis(14), label="Seed"
)
generate_images_button_denoise = gr.Button("Generate Images")
with gr.Column():
images_output_denoise = gr.Gallery(label="Images", selected_index=0, height=512)
gif_denoise = gr.Image(label="GIF")
zip_output_denoise = gr.File(label="Download ZIP")
@generate_images_button_denoise.click(
    inputs=[prompt_denoise, seed_denoise, num_inference_steps_denoise],
    outputs=[images_output_denoise, gif_denoise, zip_output_denoise],
)
def generate_images_wrapper(prompt, seed, num_inference_steps, progress=gr.Progress()):
    """Handle the Denoising tab's button: generate the frames and export them.

    Calls ``display_poke_images`` with intermediate output enabled, exports
    the frames as a ZIP (together with the tab settings) and as a GIF, and
    returns the frames plus the paths of the two exported files.
    """
    frames, _unused = display_poke_images(
        prompt, seed, num_inference_steps, poke=False, intermediate=True
    )

    archive_stem = "denoising"
    export_as_zip(
        frames,
        archive_stem,
        {
            "Tab": "Denoising",
            "Prompt": prompt,
            "Number of Inference Steps": num_inference_steps,
            "Seed": seed,
        },
    )

    # Report the GIF export through the Gradio progress tracker.
    progress(1, desc="Exporting as gif")
    export_as_gif(frames, filename="denoising.gif")

    return frames, "outputs/denoising.gif", f"outputs/{archive_stem}.zip"
# Re-render the seed plot whenever the Denoising tab's seed slider changes.
seed_denoise.change(
fn=generate_seed_vis, inputs=[seed_denoise], outputs=[seed_vis_denoise]
)
# "Seeds" sub-tab: prompt, number of seeds and step count as inputs;
# a gallery and a ZIP download as outputs.
with gr.TabItem("Seeds"):
gr.Markdown(
"Understand how different starting points in latent space can lead to different images."
)
# Explanatory HTML for this tab is loaded from a file.
gr.HTML(read_html("html/seeds.html"))
with gr.Row():
with gr.Column():
prompt_seed = gr.Textbox(
lines=1,
label="Prompt",
value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
)
num_images_seed = gr.Slider(
minimum=1, maximum=100, step=1, value=5, label="Number of Seeds"
)
num_inference_steps_seed = gr.Slider(
minimum=2,
maximum=100,
step=1,
value=8,
label="Number of Inference Steps per Image",
)
generate_images_button_seed = gr.Button("Generate Images")
with gr.Column():
images_output_seed = gr.Gallery(label="Images", selected_index=0, height=512)
zip_output_seed = gr.File(label="Download ZIP")
# The button calls display_seed_images with (prompt, steps, number of seeds)
# and fills the gallery and the ZIP download.
generate_images_button_seed.click(
fn=display_seed_images,
inputs=[prompt_seed, num_inference_steps_seed, num_images_seed],
outputs=[images_output_seed, zip_output_seed],
)
# "Perturbations" sub-tab: prompt, perturbation count and size, step count and
# seed as inputs; a gallery and a ZIP download as outputs.
with gr.TabItem("Perturbations"):
gr.Markdown("Explore different perturbations from a point in latent space.")
# Explanatory HTML for this tab is loaded from a file.
gr.HTML(read_html("html/perturbations.html"))
with gr.Row():
with gr.Column():
prompt_perturb = gr.Textbox(
lines=1,
label="Prompt",
value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
)
num_images_perturb = gr.Slider(
minimum=0,
maximum=100,
step=1,
value=5,
label="Number of Perturbations",
)
perturbation_size_perturb = gr.Slider(
minimum=0,
maximum=1,
step=0.1,
value=0.1,
label="Perturbation Size",
)
num_inference_steps_perturb = gr.Slider(
minimum=2,
maximum=100,
step=1,
value=8,
label="Number of Inference Steps per Image",
)
with gr.Row():
seed_perturb = gr.Slider(
minimum=0, maximum=100, step=1, value=14, label="Seed"
)
# The plot is pre-rendered for the slider's default seed (14).
seed_vis_perturb = gr.Plot(
value=generate_seed_vis(14), label="Seed"
)
generate_images_button_perturb = gr.Button("Generate Images")
with gr.Column():
images_output_perturb = gr.Gallery(label="Image", selected_index=0, height=512)
zip_output_perturb = gr.File(label="Download ZIP")
# The button calls display_perturb_images with the inputs in the listed order
# and fills the gallery and the ZIP download.
generate_images_button_perturb.click(
fn=display_perturb_images,
inputs=[
prompt_perturb,
seed_perturb,
num_inference_steps_perturb,
num_images_perturb,
perturbation_size_perturb,
],
outputs=[images_output_perturb, zip_output_perturb],
)
# Re-render the seed plot whenever the seed slider changes.
seed_perturb.change(
fn=generate_seed_vis, inputs=[seed_perturb], outputs=[seed_vis_perturb]
)
# "Circular" sub-tab: prompt, number of steps, start/end angle, step count and
# seed as inputs; a gallery, a GIF and a ZIP download as outputs.
with gr.TabItem("Circular"):
gr.Markdown(
"Generate a circular path in latent space and observe how the images vary along the path."
)
# Explanatory HTML for this tab is loaded from a file.
gr.HTML(read_html("html/circular.html"))
with gr.Row():
with gr.Column():
prompt_circular = gr.Textbox(
lines=1,
label="Prompt",
value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
)
num_images_circular = gr.Slider(
minimum=2,
maximum=100,
step=1,
value=5,
label="Number of Steps around the Circle",
)
with gr.Row():
start_degree_circular = gr.Slider(
minimum=0,
maximum=360,
step=1,
value=0,
label="Start Angle",
info="Enter the value in degrees",
)
end_degree_circular = gr.Slider(
minimum=0,
maximum=360,
step=1,
value=360,
label="End Angle",
info="Enter the value in degrees",
)
# This textbox is only ever written by calculate_step_size (see the .change
# handlers below); it is not passed to display_circular_images.
# NOTE(review): the initial value 180 / 4 is hard-coded rather than computed
# by calculate_step_size from the default sliders - confirm the two agree.
step_size_circular = gr.Textbox(
label="Step Size", value=180 / 4
)
num_inference_steps_circular = gr.Slider(
minimum=2,
maximum=100,
step=1,
value=8,
label="Number of Inference Steps per Image",
)
with gr.Row():
seed_circular = gr.Slider(
minimum=0, maximum=100, step=1, value=14, label="Seed"
)
# The plot is pre-rendered for the slider's default seed (14).
seed_vis_circular = gr.Plot(
value=generate_seed_vis(14), label="Seed"
)
generate_images_button_circular = gr.Button("Generate Images")
with gr.Column():
images_output_circular = gr.Gallery(label="Image", selected_index=0)
gif_circular = gr.Image(label="GIF")
zip_output_circular = gr.File(label="Download ZIP")
# Changing the number of steps or either angle recomputes the step-size
# textbox; the three registrations below are identical apart from the trigger.
num_images_circular.change(
fn=calculate_step_size,
inputs=[num_images_circular, start_degree_circular, end_degree_circular],
outputs=[step_size_circular],
)
start_degree_circular.change(
fn=calculate_step_size,
inputs=[num_images_circular, start_degree_circular, end_degree_circular],
outputs=[step_size_circular],
)
end_degree_circular.change(
fn=calculate_step_size,
inputs=[num_images_circular, start_degree_circular, end_degree_circular],
outputs=[step_size_circular],
)
# The button calls display_circular_images with the inputs in the listed order
# and fills the gallery, the GIF and the ZIP download.
generate_images_button_circular.click(
fn=display_circular_images,
inputs=[
prompt_circular,
seed_circular,
num_inference_steps_circular,
num_images_circular,
start_degree_circular,
end_degree_circular,
],
outputs=[images_output_circular, gif_circular, zip_output_circular],
)
# Re-render the seed plot whenever the seed slider changes.
seed_circular.change(
fn=generate_seed_vis, inputs=[seed_circular], outputs=[seed_vis_circular]
)
# "Poke" sub-tab: prompt, step count, seed and the poke rectangle (centre,
# height, width) as inputs; two images and a ZIP download as outputs.
with gr.TabItem("Poke"):
gr.Markdown("Perturb a region in the image and observe the effect.")
# Explanatory HTML for this tab is loaded from a file.
gr.HTML(read_html("html/poke.html"))
with gr.Row():
with gr.Column():
prompt_poke = gr.Textbox(
lines=1,
label="Prompt",
value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
)
num_inference_steps_poke = gr.Slider(
minimum=2,
maximum=100,
step=1,
value=8,
label="Number of Inference Steps per Image",
)
with gr.Row():
seed_poke = gr.Slider(
minimum=0, maximum=100, step=1, value=14, label="Seed"
)
# The plot is pre-rendered for the slider's default seed (14).
seed_vis_poke = gr.Plot(
value=generate_seed_vis(14), label="Seed"
)
# All four poke sliders range over 0..64 in steps of 1.
pokeX = gr.Slider(
label="pokeX",
minimum=0,
maximum=64,
step=1,
value=32,
info="X coordinate of poke center",
)
pokeY = gr.Slider(
label="pokeY",
minimum=0,
maximum=64,
step=1,
value=32,
info="Y coordinate of poke center",
)
pokeHeight = gr.Slider(
label="pokeHeight",
minimum=0,
maximum=64,
step=1,
value=8,
info="Height of the poke",
)
pokeWidth = gr.Slider(
label="pokeWidth",
minimum=0,
maximum=64,
step=1,
value=8,
info="Width of the poke",
)
generate_images_button_poke = gr.Button("Generate Images")
with gr.Column():
# Initial previews use the slider defaults (32, 32, 8, 8); visualize_poke is
# called once per component and each takes one element of its result.
original_images_output_poke = gr.Image(
value=visualize_poke(32, 32, 8, 8)[0], label="Original Image"
)
poked_images_output_poke = gr.Image(
value=visualize_poke(32, 32, 8, 8)[1], label="Poked Image"
)
zip_output_poke = gr.File(label="Download ZIP")
# Moving any of the four poke sliders refreshes both preview images; the four
# registrations below are identical apart from the trigger.
pokeX.change(
visualize_poke,
inputs=[pokeX, pokeY, pokeHeight, pokeWidth],
outputs=[original_images_output_poke, poked_images_output_poke],
)
pokeY.change(
visualize_poke,
inputs=[pokeX, pokeY, pokeHeight, pokeWidth],
outputs=[original_images_output_poke, poked_images_output_poke],
)
pokeHeight.change(
visualize_poke,
inputs=[pokeX, pokeY, pokeHeight, pokeWidth],
outputs=[original_images_output_poke, poked_images_output_poke],
)
pokeWidth.change(
visualize_poke,
inputs=[pokeX, pokeY, pokeHeight, pokeWidth],
outputs=[original_images_output_poke, poked_images_output_poke],
)
# Re-render the seed plot whenever the seed slider changes.
seed_poke.change(
fn=generate_seed_vis, inputs=[seed_poke], outputs=[seed_vis_poke]
)
@generate_images_button_poke.click(
    inputs=[
        prompt_poke,
        seed_poke,
        num_inference_steps_poke,
        pokeX,
        pokeY,
        pokeHeight,
        pokeWidth,
    ],
    outputs=[
        original_images_output_poke,
        poked_images_output_poke,
        zip_output_poke,
    ],
)
def generate_images_wrapper(
    prompt,
    seed,
    num_inference_steps,
    pokeX=pokeX,
    pokeY=pokeY,
    pokeHeight=pokeHeight,
    pokeWidth=pokeWidth,
):
    """Handle the Poke tab's button: run the poke pipeline and export both images.

    Calls ``display_poke_images`` with poking enabled and discards its
    two-element result, then takes the original and poked images from
    ``visualize_poke``. Both images are exported, with the tab settings, to a
    ZIP archive.

    Returns:
        The original image, the poked image, and the path of the ZIP archive.
    """
    # The result is unpacked and dropped; the images shown come from
    # visualize_poke below.
    _first, _second = display_poke_images(
        prompt,
        seed,
        num_inference_steps,
        poke=True,
        pokeX=pokeX,
        pokeY=pokeY,
        pokeHeight=pokeHeight,
        pokeWidth=pokeWidth,
        intermediate=False,
    )
    original_img, poked_img = visualize_poke(pokeX, pokeY, pokeHeight, pokeWidth)

    archive_stem = "poke"
    settings = {
        "Tab": "Poke",
        "Prompt": prompt,
        "Number of Inference Steps per Image": num_inference_steps,
        "Seed": seed,
        "PokeX": pokeX,
        "PokeY": pokeY,
        "PokeHeight": pokeHeight,
        "PokeWidth": pokeWidth,
    }
    labelled_images = [
        (original_img, "Original Image"),
        (poked_img, "Poked Image"),
    ]
    export_as_zip(labelled_images, archive_stem, settings)

    return original_img, poked_img, f"outputs/{archive_stem}.zip"
# "Guidance" sub-tab: prompt, step count, guidance scale values and seed as
# inputs; a gallery and a ZIP download as outputs.
with gr.TabItem("Guidance"):
gr.Markdown("Observe the effect of different guidance scales.")
# Explanatory HTML for this tab is loaded from a file.
gr.HTML(read_html("html/guidance.html"))
with gr.Row():
with gr.Column():
prompt_guidance = gr.Textbox(
lines=1,
label="Prompt",
value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
)
num_inference_steps_guidance = gr.Slider(
minimum=2,
maximum=100,
step=1,
value=8,
label="Number of Inference Steps per Image",
)
# The scales are entered as free text; the default is a comma-separated list.
guidance_scale_values = gr.Textbox(
lines=1, value="1, 8, 20, 30", label="Guidance Scale Values"
)
with gr.Row():
seed_guidance = gr.Slider(
minimum=0, maximum=100, step=1, value=14, label="Seed"
)
# The plot is pre-rendered for the slider's default seed (14).
seed_vis_guidance = gr.Plot(
value=generate_seed_vis(14), label="Seed"
)
generate_images_button_guidance = gr.Button("Generate Images")
with gr.Column():
images_output_guidance = gr.Gallery(
label="Images", selected_index=0,
height=512,
)
zip_output_guidance = gr.File(label="Download ZIP")
# The button calls display_guidance_images with the inputs in the listed order
# and fills the gallery and the ZIP download.
generate_images_button_guidance.click(
fn=display_guidance_images,
inputs=[
prompt_guidance,
seed_guidance,
num_inference_steps_guidance,
guidance_scale_values,
],
outputs=[images_output_guidance, zip_output_guidance],
)
# Re-render the seed plot whenever the seed slider changes.
seed_guidance.change(
fn=generate_seed_vis, inputs=[seed_guidance], outputs=[seed_vis_guidance]
)
# "Inpainting" sub-tab: an uploaded image with a drawn region, prompt, step
# count and seed as inputs; one image and a ZIP download as outputs.
with gr.TabItem("Inpainting"):
gr.Markdown("Inpaint the image based on the prompt.")
# Explanatory HTML for this tab is loaded from a file.
gr.HTML(read_html("html/inpainting.html"))
with gr.Row():
with gr.Column():
# Sketchpad configured for uploads, a single brush colour, and PIL output.
uploaded_img_inpaint = gr.Sketchpad(
sources="upload", brush=gr.Brush(colors=["#ffff00"]), type="pil", label="Upload"
)
prompt_inpaint = gr.Textbox(
lines=1, label="Prompt", value="sunglasses"
)
num_inference_steps_inpaint = gr.Slider(
minimum=2,
maximum=100,
step=1,
value=8,
label="Number of Inference Steps per Image",
)
with gr.Row():
seed_inpaint = gr.Slider(
minimum=0, maximum=100, step=1, value=14, label="Seed"
)
# The plot is pre-rendered for the slider's default seed (14).
seed_vis_inpaint = gr.Plot(
value=generate_seed_vis(14), label="Seed"
)
inpaint_button = gr.Button("Inpaint")
with gr.Column():
images_output_inpaint = gr.Image(label="Output")
zip_output_inpaint = gr.File(label="Download ZIP")
# The button calls inpaint with (sketchpad value, steps, seed, prompt) - note
# that the prompt comes last - and fills the output image and ZIP download.
inpaint_button.click(
fn=inpaint,
inputs=[
uploaded_img_inpaint,
num_inference_steps_inpaint,
seed_inpaint,
prompt_inpaint,
],
outputs=[images_output_inpaint, zip_output_inpaint],
)
# Re-render the seed plot whenever the seed slider changes.
seed_inpaint.change(
fn=generate_seed_vis, inputs=[seed_inpaint], outputs=[seed_vis_inpaint]
)
# Second top-level tab.
with gr.Tab("CLIP Space"):
# "Embeddings" sub-tab: embeds the Dash plot in an iframe and offers word and
# axis controls that update it.
with gr.TabItem("Embeddings"):
gr.Markdown(
"Visualize text embedding space in 3D with input texts and output images based on the chosen axis."
)
# Explanatory HTML for this tab is loaded from a file.
gr.HTML(read_html("html/embeddings.html"))
with gr.Row():
# The iframe source is dash_tunnel, a name not defined in this file
# (presumably provided by one of the star imports - confirm).
output = gr.HTML(
f"""
<iframe id="html" src="{dash_tunnel}" style="width:100%; height:700px;"></iframe>
"""
)
with gr.Row():
word2add_rem = gr.Textbox(lines=1, label="Add/Remove word")
word2change = gr.Textbox(lines=1, label="Change image for word")
clear_words_button = gr.Button(value="Clear words")
# Six rows of axis controls follow. Each row has an axis name, a direction
# dropdown whose initial value is read from whichAxisMap, positive and
# negative word lists, and its own submit button.
with gr.Accordion("Custom Semantic Dimensions", open=False):
with gr.Row():
axis_name_1 = gr.Textbox(label="Axis name", value="gender")
which_axis_1 = gr.Dropdown(
choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
value=whichAxisMap["which_axis_1"],
label="Axis direction",
)
from_words_1 = gr.Textbox(
lines=1,
label="Positive",
value="prince husband father son uncle",
)
to_words_1 = gr.Textbox(
lines=1,
label="Negative",
value="princess wife mother daughter aunt",
)
submit_1 = gr.Button("Submit")
with gr.Row():
axis_name_2 = gr.Textbox(label="Axis name", value="age")
which_axis_2 = gr.Dropdown(
choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
value=whichAxisMap["which_axis_2"],
label="Axis direction",
)
from_words_2 = gr.Textbox(
lines=1, label="Positive", value="man woman king queen father"
)
to_words_2 = gr.Textbox(
lines=1, label="Negative", value="boy girl prince princess son"
)
submit_2 = gr.Button("Submit")
# Row 3 has a default axis name but no default word lists.
with gr.Row():
axis_name_3 = gr.Textbox(label="Axis name", value="residual")
which_axis_3 = gr.Dropdown(
choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
value=whichAxisMap["which_axis_3"],
label="Axis direction",
)
from_words_3 = gr.Textbox(lines=1, label="Positive")
to_words_3 = gr.Textbox(lines=1, label="Negative")
submit_3 = gr.Button("Submit")
with gr.Row():
axis_name_4 = gr.Textbox(label="Axis name", value="number")
which_axis_4 = gr.Dropdown(
choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
value=whichAxisMap["which_axis_4"],
label="Axis direction",
)
from_words_4 = gr.Textbox(
lines=1,
label="Positive",
value="boys girls cats puppies computers",
)
to_words_4 = gr.Textbox(
lines=1, label="Negative", value="boy girl cat puppy computer"
)
submit_4 = gr.Button("Submit")
with gr.Row():
axis_name_5 = gr.Textbox(label="Axis name", value="royalty")
which_axis_5 = gr.Dropdown(
choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
value=whichAxisMap["which_axis_5"],
label="Axis direction",
)
from_words_5 = gr.Textbox(
lines=1,
label="Positive",
value="king queen prince princess duchess",
)
to_words_5 = gr.Textbox(
lines=1, label="Negative", value="man woman boy girl woman"
)
submit_5 = gr.Button("Submit")
# Row 6 is left blank: no default name or word lists.
with gr.Row():
axis_name_6 = gr.Textbox(label="Axis name")
which_axis_6 = gr.Dropdown(
choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
value=whichAxisMap["which_axis_6"],
label="Axis direction",
)
from_words_6 = gr.Textbox(lines=1, label="Positive")
to_words_6 = gr.Textbox(lines=1, label="Negative")
submit_6 = gr.Button("Submit")
@word2add_rem.submit(inputs=[word2add_rem], outputs=[output, word2add_rem])
def add_rem_word_and_clear(words):
    """Pass the submitted text to ``add_rem_word`` and empty the textbox.

    Returns:
        The result of ``add_rem_word`` for the ``output`` component and an
        empty string that clears the input box.
    """
    refreshed_view = add_rem_word(words)
    return refreshed_view, ""
@word2change.submit(inputs=[word2change], outputs=[output, word2change])
def change_word_and_clear(word):
    """Pass the submitted word to ``change_word`` and empty the textbox.

    Returns:
        The result of ``change_word`` for the ``output`` component and an
        empty string that clears the input box.
    """
    refreshed_view = change_word(word)
    return refreshed_view, ""
# The button calls clear_words with no inputs; its result replaces the `output` HTML.
clear_words_button.click(fn=clear_words, outputs=[output])
# One entry per row of the "Custom Semantic Dimensions" accordion:
# (key in whichAxisMap, axis-name box, direction dropdown,
#  positive words box, negative words box, submit button).
_axis_rows = [
    ("which_axis_1", axis_name_1, which_axis_1, from_words_1, to_words_1, submit_1),
    ("which_axis_2", axis_name_2, which_axis_2, from_words_2, to_words_2, submit_2),
    ("which_axis_3", axis_name_3, which_axis_3, from_words_3, to_words_3, submit_3),
    ("which_axis_4", axis_name_4, which_axis_4, from_words_4, to_words_4, submit_4),
    ("which_axis_5", axis_name_5, which_axis_5, from_words_5, to_words_5, submit_5),
    ("which_axis_6", axis_name_6, which_axis_6, from_words_6, to_words_6, submit_6),
]


def _make_set_axis_handler(own_key, other_keys):
    """Build the submit handler for the row stored under ``own_key``.

    A factory is used so each handler keeps its own ``own_key`` and
    ``other_keys`` instead of sharing the loop variables.

    Args:
        own_key: Key in ``whichAxisMap`` of the row being submitted.
        other_keys: Keys of the remaining rows, in the order of the dropdowns
            the handler updates.

    Returns:
        The handler to register on the row's submit button.
    """

    def set_axis_wrapper(axis_name, which_axis, from_words, to_words):
        """Assign ``which_axis`` to this row and report the other rows' directions.

        Returns:
            The result of ``set_axis`` followed by the current direction of
            every other row, in ``other_keys`` order.
        """
        # A direction may belong to one row only: reset every row that
        # currently holds the submitted direction before claiming it.
        for ax in whichAxisMap:
            if whichAxisMap[ax] == which_axis:
                whichAxisMap[ax] = "---"
        whichAxisMap[own_key] = which_axis
        return (
            set_axis(axis_name, which_axis, from_words, to_words),
            *(whichAxisMap[key] for key in other_keys),
        )

    return set_axis_wrapper


# Register one handler per row. Each writes the plot output plus the direction
# dropdowns of all the other rows, in row order.
for _own_key, _name_box, _direction_box, _positive_box, _negative_box, _submit in _axis_rows:
    _other_rows = [row for row in _axis_rows if row[0] != _own_key]
    _submit.click(
        fn=_make_set_axis_handler(_own_key, [row[0] for row in _other_rows]),
        inputs=[_name_box, _direction_box, _positive_box, _negative_box],
        outputs=[output] + [row[2] for row in _other_rows],
    )
# "Interpolate" sub-tab: two prompts, number of interpolation steps, step count
# and seed as inputs; a gallery, a GIF and a ZIP download as outputs.
with gr.TabItem("Interpolate"):
gr.Markdown(
"Interpolate between the first and the second prompt, and observe how the output changes."
)
# Explanatory HTML for this tab is loaded from a file.
gr.HTML(read_html("html/interpolate.html"))
with gr.Row():
with gr.Column():
promptA = gr.Textbox(
lines=1,
label="First Prompt",
value="Self-portrait oil painting, a beautiful man with golden hair, 8k",
)
promptB = gr.Textbox(
lines=1,
label="Second Prompt",
value="Self-portrait oil painting, a beautiful woman with golden hair, 8k",
)
num_images_interpolate = gr.Slider(
minimum=0,
maximum=100,
step=1,
value=5,
label="Number of Interpolation Steps",
)
num_inference_steps_interpolate = gr.Slider(
minimum=2,
maximum=100,
step=1,
value=8,
label="Number of Inference Steps per Image",
)
with gr.Row():
seed_interpolate = gr.Slider(
minimum=0, maximum=100, step=1, value=14, label="Seed"
)
# The plot is pre-rendered for the slider's default seed (14).
seed_vis_interpolate = gr.Plot(
value=generate_seed_vis(14), label="Seed"
)
generate_images_button_interpolate = gr.Button("Generate Images")
with gr.Column():
images_output_interpolate = gr.Gallery(
label="Interpolated Images", selected_index=0,
height=512,
)
gif_interpolate = gr.Image(label="GIF")
zip_output_interpolate = gr.File(label="Download ZIP")
# The button calls display_interpolate_images with the seed first, then both
# prompts, the step count and the number of interpolation steps.
generate_images_button_interpolate.click(
fn=display_interpolate_images,
inputs=[
seed_interpolate,
promptA,
promptB,
num_inference_steps_interpolate,
num_images_interpolate,
],
outputs=[
images_output_interpolate,
gif_interpolate,
zip_output_interpolate,
],
)
# Re-render the seed plot whenever the seed slider changes.
seed_interpolate.change(
fn=generate_seed_vis,
inputs=[seed_interpolate],
outputs=[seed_vis_interpolate],
)
# "Negative" sub-tab: prompt, negative prompt, step count and seed as inputs;
# two images (without and with the negative prompt) and a ZIP download as outputs.
with gr.TabItem("Negative"):
gr.Markdown("Observe the effect of negative prompts.")
# Explanatory HTML for this tab is loaded from a file.
gr.HTML(read_html("html/negative.html"))
with gr.Row():
with gr.Column():
prompt_negative = gr.Textbox(
lines=1,
label="Prompt",
value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
)
neg_prompt = gr.Textbox(
lines=1, label="Negative Prompt", value="Yellow"
)
num_inference_steps_negative = gr.Slider(
minimum=2,
maximum=100,
step=1,
value=8,
label="Number of Inference Steps per Image",
)
with gr.Row():
seed_negative = gr.Slider(
minimum=0, maximum=100, step=1, value=14, label="Seed"
)
# The plot is pre-rendered for the slider's default seed (14).
seed_vis_negative = gr.Plot(
value=generate_seed_vis(14), label="Seed"
)
generate_images_button_negative = gr.Button("Generate Images")
with gr.Column():
images_output_negative = gr.Image(
label="Image without Negative Prompt"
)
images_neg_output_negative = gr.Image(
label="Image with Negative Prompt"
)
zip_output_negative = gr.File(label="Download ZIP")
# Re-render the seed plot whenever the seed slider changes.
seed_negative.change(
fn=generate_seed_vis, inputs=[seed_negative], outputs=[seed_vis_negative]
)
# The button calls display_negative_images with (prompt, seed, steps, negative
# prompt) and fills both images and the ZIP download.
generate_images_button_negative.click(
fn=display_negative_images,
inputs=[
prompt_negative,
seed_negative,
num_inference_steps_negative,
neg_prompt,
],
outputs=[
images_output_negative,
images_neg_output_negative,
zip_output_negative,
],
)
# Third top-level tab: static attribution and funding text.
with gr.Tab("Credits"):
gr.Markdown("""
Author: Adithya Kameswara Rao, Carnegie Mellon University.
Advisor: David S. Touretzky, Carnegie Mellon University.
This work was funded by a grant from NEOM Company, and by National Science Foundation award IIS-2112633.
""")
def run_dash():
    """Start the Dash app's server on 127.0.0.1, port 8000."""
    bind_options = {"host": "127.0.0.1", "port": "8000"}
    app.run(**bind_options)
# def run_gradio():
# demo.queue()
# _, _, public_url = demo.launch(share=True)
# return public_url
if __name__ == "__main__":
    # Serve the Dash app from a daemon thread so it does not keep the process
    # alive once the Gradio server below stops.
    thread = Thread(target=run_dash, daemon=True)
    thread.start()
    try:
        # Create the folder that the handlers' returned "outputs/..." paths
        # point into before any request is served.
        os.makedirs("outputs", exist_ok=True)
        queued_demo = demo.queue()
        queued_demo.launch(share=True)
    except KeyboardInterrupt:
        print("Server closed")