hysts's picture
hysts HF staff
Migrate from yapf to black
11cddad
raw
history blame
6.47 kB
#!/usr/bin/env python
from __future__ import annotations
import os
import gradio as gr
import PIL.Image
from model import Model
# Markdown intro text for the demo page; each sentence sits on its own line.
DESCRIPTION = (
    "# Attend-and-Excite\n"
    "This is a demo for [Attend-and-Excite](https://arxiv.org/abs/2301.13826).\n"
    "Attend-and-Excite performs attention-based generative semantic guidance to mitigate subject neglect in Stable Diffusion.\n"
    "Select a prompt and a set of indices matching the subjects you wish to strengthen (the `Check token indices` cell can help map between a word and its index).\n"
)
# Single module-level model instance; the callbacks in this file call into it.
model = Model()


def process_example(
    prompt: str,
    indices_to_alter_str: str,
    seed: int,
    apply_attend_and_excite: bool,
    num_steps: int = 50,
    guidance_scale: float = 7.5,
) -> tuple[list[tuple[int, str]], PIL.Image.Image]:
    """Build the token table and generated image for one example row.

    Used as the ``gr.Examples`` callback, which supplies only the first four
    arguments; the sampling settings fall back to their defaults (50 steps,
    CFG scale 7.5).

    Args:
        prompt: Text prompt to tokenize and generate from.
        indices_to_alter_str: Token indices to strengthen, as typed in the UI
            (e.g. ``"2,6"``); parsed by ``model.run``.
        seed: Random seed forwarded to ``model.run``.
        apply_attend_and_excite: Whether ``model.run`` should apply
            Attend-and-Excite guidance.
        num_steps: Number of sampling steps forwarded to ``model.run``.
        guidance_scale: Classifier-free guidance scale forwarded to ``model.run``.

    Returns:
        ``(token_table, image)``: the result of ``model.get_token_table(prompt)``
        and the result of ``model.run(...)``.
    """
    token_table = model.get_token_table(prompt)
    result = model.run(
        prompt,
        indices_to_alter_str,
        seed,
        apply_attend_and_excite,
        num_steps,
        guidance_scale,
    )
    return token_table, result
# UI layout: controls on the left, result image on the right, examples below,
# followed by the event wiring that connects them to `model`.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )

    with gr.Row():
        with gr.Column():
            prompt = gr.Text(
                label="Prompt",
                max_lines=1,
                placeholder="A pod of dolphins leaping out of the water in an ocean with a ship on the background",
            )
            with gr.Accordion(label="Check token indices", open=False):
                show_token_indices_button = gr.Button("Show token indices")
                token_indices_table = gr.Dataframe(label="Token indices", headers=["Index", "Token"], col_count=2)
            token_indices_str = gr.Text(
                label="Token indices (a comma-separated list of indices of the tokens you wish to alter)",
                max_lines=1,
                placeholder="4,16",
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=100000,
                step=1,
                value=0,
            )
            apply_attend_and_excite = gr.Checkbox(label="Apply Attend-and-Excite", value=True)
            # minimum=1: zero sampling steps is not a meaningful request, and
            # step-ratio-based diffusers schedulers divide by this value when
            # setting timesteps. NOTE(review): confirm against the scheduler
            # Model uses.
            num_steps = gr.Slider(
                label="Number of steps",
                minimum=1,
                maximum=100,
                step=1,
                value=50,
            )
            guidance_scale = gr.Slider(
                label="CFG scale",
                minimum=0,
                maximum=50,
                step=0.1,
                value=7.5,
            )
            run_button = gr.Button("Generate")
        with gr.Column():
            result = gr.Image(label="Result")

    with gr.Row():
        # Each prompt is listed twice: with Attend-and-Excite on, then off,
        # so the two outputs can be compared side by side.
        examples = [
            [example_prompt, example_indices, example_seed, apply_flag]
            for example_prompt, example_indices, example_seed in (
                ("A mouse and a red car", "2,6", 2098),
                ("A horse and a dog", "2,5", 123),
                ("A painting of an elephant with glasses", "5,7", 123),
                ("A playful kitten chasing a butterfly in a wildflower meadow", "3,6,10", 123),
                (
                    "A grizzly bear catching a salmon in a crystal clear river surrounded by a forest",
                    "2,6,15",
                    123,
                ),
                (
                    "A pod of dolphins leaping out of the water in an ocean with a ship on the background",
                    "4,16",
                    123,
                ),
            )
            for apply_flag in (True, False)
        ]
        gr.Examples(
            examples=examples,
            inputs=[
                prompt,
                token_indices_str,
                seed,
                apply_attend_and_excite,
            ],
            outputs=[
                token_indices_table,
                result,
            ],
            fn=process_example,
            cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
            examples_per_page=20,
        )

    show_token_indices_button.click(
        fn=model.get_token_table,
        inputs=prompt,
        outputs=token_indices_table,
        queue=False,
        api_name=False,
    )

    inputs = [
        prompt,
        token_indices_str,
        seed,
        apply_attend_and_excite,
        num_steps,
        guidance_scale,
    ]
    # Every generation trigger first refreshes the token table (unqueued), then
    # runs the model. Only the "Generate" button exposes the generation step as
    # a named API endpoint ("run"). Registration order matches the triggers below.
    for trigger, run_api_name in (
        (prompt.submit, False),
        (token_indices_str.submit, False),
        (run_button.click, "run"),
    ):
        trigger(
            fn=model.get_token_table,
            inputs=prompt,
            outputs=token_indices_table,
            queue=False,
            api_name=False,
        ).then(
            fn=model.run,
            inputs=inputs,
            outputs=result,
            api_name=run_api_name,
        )
if __name__ == "__main__":
    # Enable Gradio's request queue (max_size=10), then start the server on
    # whatever object queue() hands back.
    queued_demo = demo.queue(max_size=10)
    queued_demo.launch()