|
|
|
|
|
import queue |
|
import gradio as gr |
|
import random |
|
import torch |
|
from collections import defaultdict |
|
from diffusers import DiffusionPipeline |
|
from functools import partial |
|
from itertools import zip_longest |
|
from typing import List |
|
from PIL import Image |
|
|
|
# Text of the single choice offered by every per-image radio selector.
SELECT_LABEL = "Select as seed"

# Checkpoint identifier handed to DiffusionPipeline.from_pretrained below.
MODEL_ID = "CompVis/ldm-text2im-large-256"

# Sampling settings forwarded unchanged to each pipeline call in infer_grid.
STEPS = 5
ETA = 0.3
GUIDANCE_SCALE = 6

# The pipeline is built once at import time and reused by every request.
ldm = DiffusionPipeline.from_pretrained(MODEL_ID)

import torch

# Startup diagnostic: report whether torch can see a CUDA device.
print(f"cuda: {torch.cuda.is_available()}")
|
|
|
# Root Blocks app. The css rule caps `.container` at 800px wide and centres it.
with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
    # State dictionary passed into and returned from the callbacks:
    #   "selected" - index of the grid cell chosen as seed, or -1 for none;
    #   "seeds"    - one seed per grid cell, replaced after every grid run.
    state = gr.Variable(
        {
            "selected": -1,
            "seeds": [random.randint(0, 2 ** 32 - 1) for _ in range(6)],
        }
    )
|
|
|
def infer_seeded_image(prompt, seed):
    """Generate a single image for ``prompt`` from one fixed ``seed``.

    The prompt and seed are printed to stdout, then the work is delegated to
    ``infer_grid`` with a one-cell grid pinned to ``seed``.

    Returns:
        The only image produced by that ``infer_grid`` call.
    """
    print(f"Prompt: {prompt}, seed: {seed}")
    # infer_grid returns (images, seeds); only the images are needed here.
    generated, _used_seeds = infer_grid(prompt, n=1, seeds=[seed])
    return generated[0]
|
|
|
def infer_grid(prompt, n=6, seeds=None):
    """Generate images for ``prompt``, one pipeline call per seed.

    Args:
        prompt: Text prompt, passed to the pipeline as a one-element batch.
        n: Number of images to generate when no more than ``n`` seeds are given.
        seeds: Optional seeds for the leading images. Positions without a seed
            (fewer than ``n`` supplied) receive a fresh random 32-bit seed. If
            more than ``n`` seeds are supplied, one image is generated for
            each of them.

    Returns:
        An ``(images, seeds)`` pair of equal-length lists where ``seeds[i]``
        is the seed that was set before ``images[i]`` was generated.
    """
    # ``None`` rather than a literal ``[]`` default: a mutable default is a
    # single list object shared by every call of the function.
    seeds = [] if seeds is None else seeds

    images = []
    used_seeds = []
    # zip_longest pads the shorter side with None, so the loop runs
    # max(n, len(seeds)) times and a missing seed arrives as None.
    for _, seed in zip_longest(range(n), seeds, fillvalue=None):
        if seed is None:
            seed = random.randint(0, 2**32 - 1)
        # Seed torch's default RNG immediately before the pipeline call.
        torch.manual_seed(seed)
        # NOTE(review): autocast targets "cuda", yet the visible code never
        # moves the pipeline to a GPU - confirm the intended device.
        with torch.autocast("cuda"):
            sample = ldm(
                [prompt],
                num_inference_steps=STEPS,
                eta=ETA,
                guidance_scale=GUIDANCE_SCALE,
            )["sample"]
        images.append(sample[0])
        used_seeds.append(seed)
    return images, used_seeds
|
|
|
def infer(prompt, state):
    """Run one generation and build the values for every output component.

    If ``state["selected"]`` holds a grid index, a single image is generated
    from the seed stored at that index and the seeded box is made visible.
    Otherwise a new grid is generated, its seeds are written to
    ``state["seeds"]`` and the grid box is made visible.

    Returns:
        A flat list made of:
        - the six grid images (all ``None`` in seeded mode),
        - the seeded image (``None`` in grid mode),
        - a visibility update for the grid box, then one for the seeded box,
        - the state dictionary.
    """
    selected = state["selected"]
    if selected > -1:
        # Seeded mode: reuse the seed remembered for the chosen grid cell.
        grid_images = [None] * 6
        image_with_seed = infer_seeded_image(prompt, state["seeds"][selected])
        show_grid, show_seeded = False, True
    else:
        # Grid mode: generate fresh images and remember the seeds they used.
        grid_images, fresh_seeds = infer_grid(prompt)
        state["seeds"] = fresh_seeds
        image_with_seed = None
        show_grid, show_seeded = True, False

    box_updates = [
        gr.Box.update(visible=show_grid),
        gr.Box.update(visible=show_seeded),
    ]
    return grid_images + [image_with_seed] + box_updates + [state]
|
|
|
def update_state(selected_index: int, value, state):
    """Handle a change event from the radio at ``selected_index``.

    Args:
        selected_index: Position of the radio that fired; the caller binds it
            with ``functools.partial``.
        value: The new value of that radio.
        state: State dictionary; ``state["selected"]`` is updated in place
            when the radio was actually chosen.

    Returns:
        Five ``gr.Radio`` updates, one for each of the other radios, followed
        by the state dictionary.
    """
    radio_was_cleared = value == ''
    if radio_was_cleared:
        # NO_VALUE presumably tells Gradio to leave the other radios as they
        # are - confirm against the installed Gradio version.
        value_for_others = gr.components._Keywords.NO_VALUE
    else:
        # A real selection: blank out the other radios and remember the index.
        # A radio merely being reset to '' must not overwrite the selection.
        value_for_others = ''
        state["selected"] = selected_index
    return [gr.Radio.update(value=value_for_others) for _ in range(5)] + [state]
|
|
|
def clear_seed(state):
    """Leave seeded mode: blank every radio and switch back to the grid.

    Returns:
        Six empty-string radio values, an update showing the grid box, an
        update hiding the seeded box, and the state dictionary with
        ``"selected"`` reset to -1.
    """
    state["selected"] = -1
    blank_radios = [''] * 6
    show_grid = gr.Box.update(visible=True)
    hide_seeded = gr.Box.update(visible=False)
    return blank_radios + [show_grid, hide_seeded, state]
|
|
|
def image_block():
    """Create a non-interactive, unlabeled image component.

    The ``rounded`` 4-tuple enables rounding for its first two entries and
    disables it for the last two (presumably the top and bottom corners -
    confirm against Gradio's ``style`` documentation).
    """
    image = gr.Image(interactive=False, show_label=False)
    return image.style(rounded=(True, True, False, False))
|
|
|
def radio_block():
    """Create the one-choice radio that is paired with each grid image.

    The only choice is ``SELECT_LABEL``; the component is interactive, has no
    label, and is styled with ``container=False``.
    """
    return gr.Radio(
        choices=[SELECT_LABEL],
        interactive=True,
        show_label=False,
    ).style(container=False)
|
|
|
# Re-open the Blocks context of ``demo`` so that every component and event
# created below is registered with that app.
# NOTE(review): the nesting of the layout containers is reconstructed from the
# order of the statements; confirm it against the intended page layout.
with demo:
    gr.Markdown(
        """
<h1><center>Latent Diffusion Demo</center></h1>
<p>Type anything to generate a few images that represent your prompt.
Select one of the results to use as a <b>seed</b> for the next generation:
you can try variations of your prompt starting from the same state and see how it changes.
For example, <i>Labrador in the style of Vermeer</i> could be tweaked to
<i>Labrador in the style of Picasso</i> or <i>Lynx in the style of Van Gogh</i>.
If your prompts are similar, the tweaked result should also have a similar structure
but different details or style.</p>
"""
    )
    with gr.Group():
        # Prompt row: a one-line textbox followed by the "Run" button.
        with gr.Box():
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Enter your prompt", show_label=False, max_lines=1
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Run").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )

        # Grid mode: two rows of three cells, each cell holding an image and
        # its "select as seed" radio. Components are created in the same
        # order as six hand-written cells would be.
        images = []
        selectors = []
        grid = gr.Box()
        with grid:
            for _row in range(2):
                with gr.Row():
                    for _col in range(3):
                        with gr.Box().style(border=None):
                            images.append(image_block())
                            selectors.append(radio_block())

        # Each radio reports its own index to update_state and receives the
        # other five radios as outputs so they can be reset.
        for i, radio in enumerate(selectors):
            others = [other for other in selectors if other is not radio]
            radio.change(
                partial(update_state, i),
                inputs=[radio, state],
                outputs=others + [state],
                queue=False,
            )

        # Seeded mode: a single image plus a button leading back to the grid.
        seeded_box = gr.Box()
        with seeded_box:
            seeded_image = image_block()
            clear_seed_button = gr.Button("Return to Grid")
        # The app starts in grid mode, so the seeded box is hidden initially.
        seeded_box.visible = False
        clear_seed_button.click(
            clear_seed,
            inputs=[state],
            outputs=selectors + [grid, seeded_box] + [state],
        )

        # Output order must match the list returned by infer: six grid images,
        # the seeded image, the two boxes, then the state.
        all_images = images + [seeded_image]
        boxes = [grid, seeded_box]
        infer_outputs = all_images + boxes + [state]

        # Pressing Enter in the textbox and clicking "Run" do the same thing.
        for register in (text.submit, btn.click):
            register(infer, inputs=[text, state], outputs=infer_outputs)

demo.launch(enable_queue=True)