File size: 7,860 Bytes
ce3552b 3f6a115 94ca5b6 b05fbf1 94ca5b6 b05fbf1 527bf99 ce3552b 98c2a34 b05fbf1 94ca5b6 b05fbf1 94ca5b6 d75def4 94ca5b6 941bc2f ce3552b 94ca5b6 527bf99 fc8546c ce3552b cb48ec9 ce3552b d42756a ce3552b b05fbf1 9e3f56f b05fbf1 9e3f56f b05fbf1 ce3552b 527bf99 b05fbf1 527bf99 b05fbf1 527bf99 ce3552b 5f09b9a ce3552b d09bee2 e84f604 79488ea b05fbf1 2443dbf b05fbf1 da9710b fce9e32 94ca5b6 d1c3f6b 94ca5b6 d2336e9 1007b31 fce9e32 251a915 b05fbf1 2bd0d19 b05fbf1 fce9e32 da9710b b05fbf1 2104e5b 527bf99 b05fbf1 62ee62a b05fbf1 94ca5b6 527bf99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import gradio as gr
import numpy as np
import cv2
from PIL import Image
import torch
import base64
import requests
import random
import os
from io import BytesIO
from region_control import MultiDiffusion, get_views, preprocess_mask, seed_everything
from sketch_helper import get_high_freq_colors, color_quantization, create_binary_matrix
# Maximum number of distinct sketch colors (regions) the UI exposes prompt rows for.
MAX_COLORS = 12
# Pipeline shared by all requests; process_generation may replace it when a
# different model id is requested. "2.1" is presumably the Stable Diffusion
# version tag understood by MultiDiffusion (the UI default model id is
# stabilityai/stable-diffusion-2-1-base) — confirm in region_control.
sd = MultiDiffusion("cuda", "2.1")
# SPACE_ID is only set when running on Hugging Face Spaces; reading it with
# .get() lets the app start locally instead of failing with a KeyError.
is_shared_ui = "weizmannscience/multidiffusion-region-based" in os.environ.get("SPACE_ID", "")
is_gpu_associated = torch.cuda.is_available()
# Placeholder element for the sketch canvas. get_js_colors reads `_data` from
# this same "canvas-root" element; the element is presumably filled in by the
# externally loaded sketch-canvas.js (not visible here).
canvas_html = "<div id='canvas-root' style='max-width:400px; margin: 0 auto'></div>"
# Client-side hook: downloads sketch-canvas.js from the Hugging Face dataset URL
# below and injects it into <head> as an ES module via a Blob object URL.
load_js = """
async () => {
const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/sketch-canvas.js"
fetch(url)
.then(res => res.text())
.then(text => {
const script = document.createElement('script');
script.type = "module"
script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' }));
document.head.appendChild(script);
});
}
"""
# Client-side preprocessing hook: ignores its argument and returns the `_data`
# property of the "canvas-root" element as a one-element list. The payload's
# shape is defined by sketch-canvas.js; process_sketch expects 'image' and
# 'colors' keys in it.
get_js_colors = """
async (canvasData) => {
const canvasEl = document.getElementById("canvas-root");
return [canvasEl._data]
}
"""
# Client-side hook: resizes the canvas for the chosen aspect ratio by calling
# `_updateCanvas`, a global presumably defined by sketch-canvas.js. The pixel
# sizes mirror the aspect table in process_generation (argument order
# presumably width, height).
set_canvas_size ="""
async (aspect) => {
if(aspect ==='square'){
_updateCanvas(512,512)
}
if(aspect ==='horizontal'){
_updateCanvas(768,512)
}
if(aspect ==='vertical'){
_updateCanvas(512,768)
}
}
"""
def process_sketch(canvas_data, binary_matrixes):
    """Turn the client-side sketch into per-color region masks and refresh the UI.

    Args:
        canvas_data: Payload produced by the ``get_js_colors`` JS hook. As read
            here it must hold ``'image'`` (a base64 data URL of the sketch) and
            ``'colors'`` (a list of ``"rgb(r, g, b)"`` strings).
        binary_matrixes: Session state list. It is cleared and refilled in place
            with one ``create_binary_matrix`` result per non-white color, capped
            at ``MAX_COLORS``. The result is presumably a mask file path (it is
            embedded as ``file/...`` below) — confirm in sketch_helper.

    Returns:
        A flat list matching the Gradio outputs, in this order:
        - an update making the post-sketch column visible;
        - the refreshed ``binary_matrixes``;
        - ``MAX_COLORS`` row-visibility updates;
        - ``MAX_COLORS`` color-swatch HTML updates.
    """
    binary_matrixes.clear()

    # The image is a data URL; the base64 payload follows the first comma.
    data_url = canvas_data['image']
    image_bytes = base64.b64decode(data_url.split(',')[1])
    im2arr = np.array(Image.open(BytesIO(image_bytes)).convert("RGB"))

    # Colors arrive as "rgb(r, g, b)": strip "rgb(" and ")" and parse the ints.
    sketch_colors = [tuple(map(int, rgb[4:-1].split(','))) for rgb in canvas_data['colors']]

    swatches = []
    for r, g, b in sketch_colors:
        if (r, g, b) == (255, 255, 255):
            # Pure white is skipped — presumably the untouched canvas background.
            continue
        if len(swatches) >= MAX_COLORS:
            # Only MAX_COLORS prompt rows exist in the UI; extra colors would
            # overflow the update lists below, so they are ignored.
            break
        binary_matrix = create_binary_matrix(im2arr, (r, g, b))
        binary_matrixes.append(binary_matrix)
        swatches.append(gr.update(value=f'<div style="display:flex;align-items: center;justify-content: center"><img width="20%" style="margin-right: 1em" src="file/{binary_matrix}" /><div class="color-bg-item" style="background-color: rgb({r},{g},{b})"></div></div>'))

    default_swatch = '<div class="color-bg-item" style="background-color: black"></div>'
    # Show one row per detected color and hide the rest.
    visibilities = [gr.update(visible=n < len(swatches)) for n in range(MAX_COLORS)]
    # Build a distinct update dict per slot (never share one dict object).
    color_updates = [
        swatches[n] if n < len(swatches) else gr.update(value=default_swatch)
        for n in range(MAX_COLORS)
    ]
    return [gr.update(visible=True), binary_matrixes, *visibilities, *color_updates]
def process_generation(model, binary_matrixes, boostrapping, aspect, steps, seed, master_prompt, negative_prompt, *prompts):
    """Run region-based MultiDiffusion generation for the current sketch.

    Args:
        model: Hugging Face model id from the model textbox. Blank falls back to
            the default id.
        binary_matrixes: Region masks collected by process_sketch (one per
            color), each passed to ``preprocess_mask``.
        boostrapping: Forwarded as ``sd.generate(..., bootstrapping=...)``.
        aspect: "square", "horizontal" or "vertical"; anything else is square.
        steps: Step count forwarded to ``sd.generate``.
        seed: RNG seed; -1 picks a random seed.
        master_prompt: Prompt paired with the background mask (index 0).
        negative_prompt: Negative prompt reused for every mask.
        *prompts: Per-region prompts, in the same order as ``binary_matrixes``.

    Returns:
        Whatever ``sd.generate`` returns (shown in the result image).

    Raises:
        gr.Error: If no colored region was drawn.
    """
    global sd
    default_model = "stabilityai/stable-diffusion-2-1-base"
    requested_model = (model or "").strip() or default_model
    # Remember which model `sd` holds, so a custom model is only built once and
    # re-entering the default id actually restores the default pipeline.
    loaded_model = getattr(process_generation, "_loaded_model", default_model)
    if requested_model != loaded_model:
        # Mirror the module-level startup call for the default model.
        version_or_id = "2.1" if requested_model == default_model else requested_model
        sd = MultiDiffusion("cuda", version_or_id)
        process_generation._loaded_model = requested_model

    if not binary_matrixes:
        # torch.cat([]) below would otherwise fail with an opaque RuntimeError.
        raise gr.Error("Draw at least one colored region and click \"I've finished my sketch\" before generating.")

    # Slider values may arrive as floats; the seeding and sampling code expects ints.
    seed = int(seed)
    if seed == -1:
        seed = random.randint(1, 2147483647)
    seed_everything(seed)
    steps = int(steps)
    bootstrapping = int(boostrapping)

    dimensions = {"square": (512, 512), "horizontal": (768, 512), "vertical": (512, 768)}
    width, height = dimensions.get(aspect, dimensions["square"])

    # Pair each region mask with its prompt so the counts always match.
    n_regions = min(len(binary_matrixes), len(prompts))
    mask_paths = binary_matrixes[:n_regions]
    all_prompts = [master_prompt, *prompts[:n_regions]]
    neg_prompts = [negative_prompt] * len(all_prompts)

    # Masks are built at 1/8 of the image size (presumably the latent resolution).
    fg_masks = torch.cat([preprocess_mask(path, height // 8, width // 8, "cuda") for path in mask_paths])
    # Background = area not covered by any region; clipped at 0 where the
    # foreground masks sum above 1.
    bg_mask = (1 - torch.sum(fg_masks, dim=0, keepdim=True)).clamp(min=0)
    masks = torch.cat([bg_mask, fg_masks])
    print(masks.size())

    return sd.generate(masks, all_prompts, neg_prompts, height, width, steps, bootstrapping=bootstrapping)
# Extra CSS passed to gr.Blocks(css=...). It centers the #color-bg swatch boxes,
# sizes each swatch and stretches the main button to full width. Gradio wraps
# this in its own <style> element, so it must contain plain CSS only.
css = '''
#color-bg{display:flex;justify-content: center;align-items: center;}
.color-bg-item{width: 100%; height: 32px}
#main_button{width:100%}
'''
# Build the Gradio UI: sketch canvas + per-color prompt rows on the left,
# generated image on the right, then wire the event handlers and launch.
with gr.Blocks(css=css) as demo:
    # Per-session list of region masks filled by process_sketch and consumed
    # by process_generation.
    binary_matrixes = gr.State([])
    gr.Markdown('''## Control your Stable Diffusion generation with Sketches (_beta_)
A beta version demo of [MultiDiffusion](https://arxiv.org/abs/2302.08113) region-based generation using Stable Diffusion 2.1 model. To get started, draw your masks and type your prompts. More details in the [project page](https://multidiffusion.github.io).
''')
    # SPACE_ID exists only on Hugging Face Spaces; avoid a KeyError when run locally.
    space_id = os.environ.get("SPACE_ID", "")
    if is_shared_ui:
        gr.HTML(f'''
<div style="margin-top:-20px">To skip the queue or try the technique with custom models, you may duplicate the space and associate an A10 GPU to it <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{space_id}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></div>
''')
    elif space_id and not is_gpu_associated:
        # The "duplicated Space" hint (and its settings link) only makes sense on Spaces.
        gr.HTML(f'''
<div>You have successfully duplicated the Space 🎉, but it is running on CPU - which may break this application. Go to the <a href="https://huggingface.co/spaces/{space_id}/settings">settings</a> page to associate a GPU to it</div>
''')
    with gr.Row():
        with gr.Box(elem_id="main-image"):
            # Filled client-side by the get_js_colors hook when the button is clicked.
            canvas_data = gr.JSON(value={}, visible=False)
            model = gr.Textbox(label="The id of any Hugging Face model in the diffusers format", value="stabilityai/stable-diffusion-2-1-base", visible=not is_shared_ui)
            canvas = gr.HTML(canvas_html)
            aspect = gr.Radio(["square", "horizontal", "vertical"], value="square", label="Aspect Ratio", visible=not is_shared_ui)
            button_run = gr.Button("I've finished my sketch", elem_id="main_button", interactive=True)

            prompts = []
            colors = []
            color_row = [None] * MAX_COLORS
            # Hidden until process_sketch reports the detected colors.
            with gr.Column(visible=False) as post_sketch:
                general_prompt = gr.Textbox(label="General Prompt")
                # One (swatch, prompt) row per possible color, toggled by process_sketch.
                for n in range(MAX_COLORS):
                    with gr.Row(visible=False) as color_row[n]:
                        with gr.Box(elem_id="color-bg"):
                            colors.append(gr.HTML('<div class="color-bg-item" style="background-color: black"></div>'))
                        prompts.append(gr.Textbox(label="Prompt for this mask"))
                with gr.Accordion("Advanced options", open=False):
                    negative_prompt = gr.Textbox(label="Global negative prompt for all prompts", value="low quality")
                    boostrapping = gr.Slider(label="Bootstrapping", minimum=1, maximum=100, value=10, step=1)
                    steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
                    seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, value=-1, step=1)
                final_run_btn = gr.Button("Generate!")

        out_image = gr.Image(label="Result").style(width=512, height=512)
    gr.Markdown('''
![Examples](https://multidiffusion.github.io/pics/tight.jpg)
''')

    aspect.change(None, inputs=[aspect], outputs=None, _js=set_canvas_size)
    button_run.click(process_sketch, inputs=[canvas_data, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors], _js=get_js_colors, queue=False)
    final_run_btn.click(process_generation, inputs=[model, binary_matrixes, boostrapping, aspect, steps, seed, general_prompt, negative_prompt, *prompts], outputs=out_image)
    demo.load(None, None, None, _js=load_js)

demo.launch(debug=True)