import json
import threading
from collections import deque
from dataclasses import dataclass
from typing import Optional

import gradio as gr
import websockets
from gradio.processing_utils import decode_base64_to_image, encode_pil_to_base64
from PIL import Image
from websockets.sync.client import connect

from constants import DESCRIPTION, LOGO, WS_ADDRESS
from gradio_examples import EXAMPLES
from utils import replace_background

MAX_QUEUE_SIZE = 4


@dataclass
class GenerationState:
    """Per-session queues of pending prompts and received responses."""

    prompts: deque
    responses: deque


def get_initial_state() -> GenerationState:
    # Bounded deques: once full, the oldest entry is silently dropped.
    return GenerationState(
        prompts=deque(maxlen=MAX_QUEUE_SIZE),
        responses=deque(maxlen=MAX_QUEUE_SIZE),
    )


def load_initial_state(request: gr.Request) -> GenerationState:
    print("Loading initial state for", request.client.host)
    print("Total number of active threads:", threading.active_count())
    return get_initial_state()


async def put_to_queue(
    image: Optional[Image.Image],
    prompt: str,
    seed: int,
    strength: float,
    state: GenerationState,
):
    # Enqueue a generation request; called on every input change.
    prompts_queue = state.prompts

    if prompt and image is not None:
        prompts_queue.append((image, prompt, seed, strength))

    return state


def send_inference_request(state: GenerationState) -> GenerationState:
    prompts_queue = state.prompts
    response_queue = state.responses

    if len(prompts_queue) == 0:
        return state

    image, prompt, seed, strength = prompts_queue.popleft()
    original_image_size = image.size
    image = replace_background(image.resize((512, 512)))

    arguments = {
        "prompt": prompt,
        "image_url": encode_pil_to_base64(image),
        "strength": strength,
        "negative_prompt": "cartoon, illustration, animation. face. male, female",
        "seed": seed,
        "guidance_scale": 1,
        "num_inference_steps": 4,
        "sync_mode": 1,
        "num_images": 1,
    }

    connection = connect(WS_ADDRESS)
    connection.send(json.dumps(arguments))

    try:
        response = json.loads(connection.recv())
    except websockets.exceptions.ConnectionClosedOK:
        print("Connection closed, reconnecting...")
        # TODO: Reuse a single long-lived connection instead of opening a new
        # one on every failure.
        connection = connect(WS_ADDRESS)
        # A fresh connection knows nothing about the previous request, so the
        # arguments must be re-sent before waiting for a response.
        connection.send(json.dumps(arguments))
        try:
            response = json.loads(connection.recv())
        except websockets.exceptions.ConnectionClosedOK:
            print("Connection closed again, aborting...")
            return state

    if "images" in response:
        response_queue.append((response, original_image_size))

    return state


def update_output_image(state: GenerationState):
    image_update = gr.update()
    inference_time_update = gr.update()

    response_queue = state.responses

    if len(response_queue) > 0:
        response, original_image_size = response_queue.popleft()
        generated_image = decode_base64_to_image(response["images"][0]["url"])
        inference_time = response["timings"]["inference"]

        image_update = gr.update(value=generated_image.resize(original_image_size))
        inference_time_update = gr.update(value=round(inference_time, 4))

    return image_update, inference_time_update, state


with gr.Blocks(css="style.css", title="Realtime Latent Consistency Model") as demo:
    generation_state = gr.State(get_initial_state())

    gr.HTML(
        f"""
{LOGO}
        """
    )
    gr.Markdown(DESCRIPTION)

    with gr.Row(variant="default"):
        input_image = gr.Image(
            tool="color-sketch",
            source="canvas",
            label="Initial Image",
            type="pil",
            height=512,
            width=512,
            brush_radius=40.0,
        )
        output_image = gr.Image(
            label="Generated Image",
            type="pil",
            interactive=False,
            elem_id="output_image",
        )

    with gr.Row():
        with gr.Column(scale=23):
            prompt_box = gr.Textbox(label="Prompt", value=EXAMPLES[0])
        with gr.Column(scale=1):
            inference_time_box = gr.Number(
                label="Inference Time (s)", interactive=False
            )

    with gr.Accordion(label="Advanced Options", open=False):
        with gr.Row():
            with gr.Column():
                strength = gr.Slider(
                    label="Strength",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.05,
                    value=0.8,
                    info="Strength of the initial image that will be applied during inference.",
                )
            with gr.Column():
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=2**31 - 1,
                    step=1,
                    randomize=True,
                    info="Seed for the random number generator.",
                )

    # Initialize per-session state once the page loads.
    demo.load(
        load_initial_state,
        outputs=[generation_state],
    )
    # Poll the prompt queue and forward requests to the inference endpoint.
    demo.load(
        send_inference_request,
        inputs=[generation_state],
        outputs=[generation_state],
        every=0.1,
    )
    # Poll the response queue and refresh the output image.
    demo.load(
        update_output_image,
        inputs=[generation_state],
        outputs=[output_image, inference_time_box, generation_state],
        every=0.1,
    )

    # Any input change enqueues a new generation request.
    for event in [input_image.change, prompt_box.change, strength.change, seed.change]:
        event(
            put_to_queue,
            [input_image, prompt_box, seed, strength, generation_state],
            [generation_state],
            show_progress=False,
            queue=True,
        )

    gr.Markdown("## Example Prompts")
    gr.Examples(examples=EXAMPLES, inputs=[prompt_box], label="Examples")


if __name__ == "__main__":
    demo.queue(concurrency_count=20, api_open=False).launch(max_threads=8192)
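
# ---------------------------------------------------------------------------
# For reference: a minimal standalone sketch of the websocket round trip that
# send_inference_request performs above. This is an illustration, not part of
# the app; it assumes WS_ADDRESS points at the same inference endpoint, and
# the prompt/seed values and the base64 placeholder below are made up.
#
#   import json
#   from websockets.sync.client import connect
#   from constants import WS_ADDRESS
#
#   with connect(WS_ADDRESS) as ws:
#       ws.send(json.dumps({
#           "prompt": "a watercolor landscape",
#           "image_url": "data:image/png;base64,...",  # placeholder input image
#           "strength": 0.8,
#           "seed": 42,
#           "guidance_scale": 1,
#           "num_inference_steps": 4,
#           "sync_mode": 1,
#           "num_images": 1,
#       }))
#       payload = json.loads(ws.recv())
#       if "images" in payload:
#           print("inference took", payload["timings"]["inference"], "s")
# ---------------------------------------------------------------------------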