"""Gradio demo: turn a hand-drawn sketch into an image with Pix2Pix-Turbo.

The user draws on a 512x512 canvas. Editing the sketch, submitting the prompt,
or picking a different style re-runs the model, and the result is shown in the
output column. Both columns expose download buttons whose links are refreshed
with PNG data URIs on every run.
"""
import base64
import os
import pdb
import random
import sys
from io import BytesIO

import gradio as gr
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision import transforms

from src.model import make_1step_sched
from src.pix2pix_turbo import Pix2Pix_Turbo

# NOTE(review): os, sys, pdb, random, transforms and make_1step_sched are not
# used in this script; they are kept so nothing that imports them breaks.

# Loaded once at import time and shared by every request.
model = Pix2Pix_Turbo("sketch_to_image_stochastic")

# Prompt templates; "{prompt}" is replaced by the user's text in `run`.
style_list = [
    {
        "name": "No Style",
        "prompt": "{prompt}",
    },
    {
        "name": "Cinematic",
        "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
    },
    {
        "name": "3D Model",
        "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
    },
    {
        "name": "Anime",
        "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
    },
    {
        "name": "Digital Art",
        "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
    },
    {
        "name": "Photographic",
        "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
    },
    {
        "name": "Pixel art",
        "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
    },
    {
        "name": "Fantasy art",
        "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
    },
    {
        "name": "Neonpunk",
        "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
    },
    {
        "name": "Manga",
        "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
    },
]

styles = {style["name"]: style["prompt"] for style in style_list}
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "Fantasy art"
MAX_SEED = np.iinfo(np.int32).max


def pil_image_to_data_uri(img, format='PNG'):
    """Encode a PIL image as a ``data:image/<format>;base64,...`` URI.

    Args:
        img: image to encode.
        format: PIL save format (default ``'PNG'``); also used, lower-cased,
            as the MIME subtype.

    Returns:
        The data URI string.
    """
    buffered = BytesIO()
    img.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/{format.lower()};base64,{img_str}"


def run(image, prompt, prompt_template, style_name):
    """Generate an image from the current sketch and refresh download links.

    Args:
        image: PIL image from the sketch canvas, or None when it is empty.
        prompt: the user's free-text prompt.
        prompt_template: style template containing a ``{prompt}`` placeholder.
        style_name: selected style name. It is unused here because the
            template textbox already holds the selected style's template.

    Returns:
        A tuple ``(output_image, sketch_link_update, output_link_update)``.
        The two updates set the download buttons' ``link`` to PNG data URIs.
    """
    print("sketch updated")
    if image is None:
        # Nothing drawn yet: show a blank white 512x512 image and point both
        # download buttons at it.
        blank = Image.new("L", (512, 512), 255)
        blank_uri = pil_image_to_data_uri(blank)
        return blank, gr.update(link=blank_uri), gr.update(link=blank_uri)

    prompt = prompt_template.replace("{prompt}", prompt)
    image = image.convert("RGB")
    # Binarize the sketch: channels above 0.5 become 1, the rest 0.
    image_t = TF.to_tensor(image) > 0.5
    with torch.no_grad():
        c_t = image_t.unsqueeze(0).cuda().float()
        # Fixed seed so the noise map is identical on every call.
        torch.manual_seed(42)
        _, _, height, width = c_t.shape
        # 4 channels at 1/8 resolution, presumably the model's latent size.
        noise = torch.randn((1, 4, height // 8, width // 8), device=c_t.device)
        output_image = model(c_t, prompt, deterministic=False, r=0.5, noise_map=noise)
        # The *0.5 + 0.5 maps a [-1, 1] output range to [0, 1] for PIL.
        output_pil = TF.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)

    # Invert before exporting, presumably to undo the canvas's invert_colors=True.
    input_sketch_uri = pil_image_to_data_uri(Image.fromarray(255 - np.array(image)))
    output_image_uri = pil_image_to_data_uri(output_pil)
    return output_pil, gr.update(link=input_sketch_uri), gr.update(link=output_image_uri)


def update_canvas(use_line, use_eraser):
    """Return a canvas update whose brush matches the active tool.

    The eraser paints white with radius 20, and the pencil paints black with
    radius 4. The pencil wins when both flags are set, and it is also the
    default when neither is set. That default matches the canvas's initial
    brush and avoids referencing unassigned brush settings.
    """
    brush_color, brush_size = "#000000", 4
    if use_eraser and not use_line:
        brush_color, brush_size = "#ffffff", 20
    return gr.update(brush_radius=brush_size, brush_color=brush_color, interactive=True)


def upload_sketch(file):
    """Load an uploaded file as a grayscale image and place it on the canvas.

    NOTE(review): not connected to any event in this file.
    """
    # Close the source file; convert() returns an independent image.
    with Image.open(file.name) as uploaded:
        gray = uploaded.convert("L")
    return gr.update(value=gray, source="upload", interactive=True)


# Client-side helpers, installed once on page load via demo.load(_js=...).
scripts = """
async () => {
    globalThis.theSketchDownloadFunction = () => {
        console.log("test")
        const link = document.createElement("a");
        const dataUri = document.getElementById('download_sketch').href
        link.setAttribute("href", dataUri)
        link.setAttribute("download", "sketch.png")
        document.body.appendChild(link); // Required for Firefox
        link.click();
        document.body.removeChild(link); // Clean up

        // also call the output download function
        theOutputDownloadFunction();
        return false
    }

    globalThis.theOutputDownloadFunction = () => {
        console.log("test output download function")
        const link = document.createElement("a");
        const dataUri = document.getElementById('download_output').href
        link.setAttribute("href", dataUri);
        link.setAttribute("download", "output.png");
        document.body.appendChild(link); // Required for Firefox
        link.click();
        document.body.removeChild(link); // Clean up
        return false
    }

    // Click the nth button of the sketch canvas's built-in toolbar.
    // NOTE: the selector depends on Gradio's generated svelte class names and
    // stops matching if the Gradio version changes.
    const clickCanvasToolbarButton = (nth) => {
        const button = document.querySelector(
            '#input_image > div.image-container.svelte-p3y7hu > div.svelte-s6ybro > button:nth-child(' + nth + ')');
        if (!button) {
            console.warn("canvas toolbar button " + nth + " not found");
            return;
        }
        // Create a new 'click' event
        const event = new MouseEvent('click', {
            'view': window,
            'bubbles': true,
            'cancelable': true
        });
        button.dispatchEvent(event);
    }

    globalThis.UNDO_SKETCH_FUNCTION = () => {
        console.log("undo sketch function")
        clickCanvasToolbarButton(1);
    }

    globalThis.DELETE_SKETCH_FUNCTION = () => {
        console.log("delete sketch function")
        clickCanvasToolbarButton(2);
    }

    globalThis.togglePencil = () => {
        const el_pencil = document.getElementById('my-toggle-pencil');
        el_pencil.classList.toggle('clicked');
        // simulate a click on the gradio button
        const btn_gradio = document.querySelector("#cb-line > label > input");
        const event = new MouseEvent('click', {
            'view': window,
            'bubbles': true,
            'cancelable': true
        });
        btn_gradio.dispatchEvent(event);
        if (el_pencil.classList.contains('clicked')) {
            document.getElementById('my-toggle-eraser').classList.remove('clicked');
            document.getElementById('my-div-pencil').style.backgroundColor = "gray";
            document.getElementById('my-div-eraser').style.backgroundColor = "white";
        }
        else {
            document.getElementById('my-toggle-eraser').classList.add('clicked');
            document.getElementById('my-div-pencil').style.backgroundColor = "white";
            document.getElementById('my-div-eraser').style.backgroundColor = "gray";
        }
    }

    globalThis.toggleEraser = () => {
        const element = document.getElementById('my-toggle-eraser');
        element.classList.toggle('clicked');
        // simulate a click on the gradio button
        const btn_gradio = document.querySelector("#cb-eraser > label > input");
        const event = new MouseEvent('click', {
            'view': window,
            'bubbles': true,
            'cancelable': true
        });
        btn_gradio.dispatchEvent(event);
        if (element.classList.contains('clicked')) {
            document.getElementById('my-toggle-pencil').classList.remove('clicked');
            document.getElementById('my-div-pencil').style.backgroundColor = "white";
            document.getElementById('my-div-eraser').style.backgroundColor = "gray";
        }
        else {
            document.getElementById('my-toggle-pencil').classList.add('clicked');
            document.getElementById('my-div-pencil').style.backgroundColor = "gray";
            document.getElementById('my-div-eraser').style.backgroundColor = "white";
        }
    }
}
"""

with gr.Blocks(css="style.css") as demo:
    # these are hidden buttons that are used to trigger the canvas changes
    # (the togglePencil / toggleEraser JS clicks them).
    line = gr.Checkbox(label="line", value=False, elem_id="cb-line")
    eraser = gr.Checkbox(label="eraser", value=False, elem_id="cb-eraser")
    with gr.Row(elem_id="main_row"):
        with gr.Column(elem_id="column_input"):
            gr.Markdown("## INPUT", elem_id="input_header")
            image = gr.Image(
                source="canvas",
                tool="color-sketch",
                type="pil",
                image_mode="L",
                invert_colors=True,
                shape=(512, 512),
                brush_radius=4,
                height=440,
                width=440,
                brush_color="#000000",
                interactive=True,
                show_download_button=True,
                elem_id="input_image",
                show_label=False,
            )
            download_sketch = gr.Button("Download sketch", scale=1, elem_id="download_sketch")
            # NOTE(review): this HTML payload is empty, yet the JS above expects
            # elements such as #my-toggle-pencil / #my-div-eraser, presumably
            # defined here originally -- confirm the markup was not lost.
            gr.HTML("\n")
            # gr.Markdown("## Prompt", elem_id="tools_header")
            prompt = gr.Textbox(label="Prompt", value="", show_label=True)
            with gr.Row():
                style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
                prompt_temp = gr.Textbox(
                    label="Prompt Style Template",
                    value=styles[DEFAULT_STYLE_NAME],
                    scale=2,
                    max_lines=1,
                )

        with gr.Column(elem_id="column_output"):
            gr.Markdown("## OUTPUT", elem_id="output_header")
            result = gr.Image(
                label="Result",
                height=440,
                width=440,
                elem_id="output_image",
                show_label=False,
                show_download_button=True,
            )
            download_output = gr.Button("Download output", elem_id="download_output")

    # Keep the two tool checkboxes mutually exclusive, then push the matching
    # brush settings to the canvas.
    eraser.change(fn=lambda x: gr.update(value=not x), inputs=[eraser], outputs=[line]).then(
        update_canvas, [line, eraser], [image]
    )
    line.change(fn=lambda x: gr.update(value=not x), inputs=[line], outputs=[eraser]).then(
        update_canvas, [line, eraser], [image]
    )

    # Install the global JS helpers when the page loads.
    demo.load(None, None, None, _js=scripts)

    inputs = [image, prompt, prompt_temp, style]
    outputs = [result, download_sketch, download_output]
    prompt.submit(fn=run, inputs=inputs, outputs=outputs)
    # Picking a style fills in its template, then regenerates.
    style.change(lambda x: styles[x], inputs=[style], outputs=[prompt_temp]).then(
        fn=run, inputs=inputs, outputs=outputs,
    )
    image.change(run, inputs=inputs, outputs=outputs,)

if __name__ == "__main__":
    demo.queue().launch(debug=True)