Spaces:
Running
Running
| import json | |
| from collections import deque | |
| from dataclasses import dataclass | |
| import threading | |
| from typing import Optional | |
| import gradio as gr | |
| import websockets | |
| from gradio.processing_utils import decode_base64_to_image, encode_pil_to_base64 | |
| from PIL import Image | |
| from websockets.sync.client import connect | |
| from constants import DESCRIPTION, WS_ADDRESS, LOGO | |
| from utils import replace_background | |
| from gradio_examples import EXAMPLES | |
# Capacity of each deque built by get_initial_state(); once a deque is full,
# appending a new entry discards the oldest one.
MAX_QUEUE_SIZE: int = 4
@dataclass
class GenerationState:
    """Pair of queues that carries work between the UI callbacks.

    The class must be a dataclass: it is instantiated with keyword arguments
    (``GenerationState(prompts=..., responses=...)``), which a plain class
    with bare annotations would reject with ``TypeError``.
    """

    # Pending (image, prompt, seed, strength) tuples waiting to be sent for inference.
    prompts: deque
    # (response, original_image_size) tuples waiting to be rendered in the UI.
    responses: deque
def get_initial_state() -> GenerationState:
    """Build a fresh state whose two queues are empty and bounded.

    Returns:
        A new ``GenerationState``; both deques use ``MAX_QUEUE_SIZE`` as
        their ``maxlen``.
    """
    pending_prompts = deque(maxlen=MAX_QUEUE_SIZE)
    pending_responses = deque(maxlen=MAX_QUEUE_SIZE)
    return GenerationState(prompts=pending_prompts, responses=pending_responses)
def load_initial_state(request: gr.Request) -> GenerationState:
    """Log who connected and hand back a brand-new generation state.

    Args:
        request: The incoming Gradio request; only ``request.client.host``
            is read, and only for logging.

    Returns:
        The state produced by ``get_initial_state()``.
    """
    client_host = request.client.host
    print("Loading initial state for", client_host)
    thread_count = threading.active_count()
    print("Total number of active threads", thread_count)
    return get_initial_state()
async def put_to_queue(
    image: Optional[Image.Image],
    prompt: str,
    seed: int,
    strength: float,
    state: GenerationState,
):
    """Queue one generation request when both an image and a prompt exist.

    Args:
        image: The input image, or ``None`` when there is none yet.
        prompt: The text prompt; an empty prompt is not queued.
        seed: Seed value stored with the request.
        strength: Strength value stored with the request.
        state: The state whose ``prompts`` deque receives the request.

    Returns:
        The same ``state`` object, whether or not anything was queued.
    """
    pending = state.prompts
    # Guard clause: nothing is queued without both a prompt and an image.
    if not prompt or image is None:
        return state
    request_entry = (image, prompt, seed, strength)
    pending.append(request_entry)
    return state
def _request_inference(arguments: dict) -> dict:
    """Send one request over a fresh websocket and return the decoded reply.

    The connection is used as a context manager so it is closed when this
    function returns or raises, instead of being leaked on every call.
    """
    with connect(WS_ADDRESS) as connection:
        connection.send(json.dumps(arguments))
        return json.loads(connection.recv())


def send_inference_request(state: GenerationState) -> GenerationState:
    """Pop the oldest queued prompt, run inference and queue the response.

    Args:
        state: Holds the ``prompts`` deque to consume from and the
            ``responses`` deque to append to.

    Returns:
        The same ``state`` object. It is returned unchanged when no prompt
        is queued, and without a new response when the server closed the
        connection twice or the reply has no ``"images"`` key.
    """
    prompts_queue = state.prompts
    response_queue = state.responses
    if not prompts_queue:
        return state

    image, prompt, seed, strength = prompts_queue.popleft()
    # Remember the caller's size so the result can be scaled back later.
    original_image_size = image.size
    image = replace_background(image.resize((512, 512)))

    arguments = {
        "prompt": prompt,
        "image_url": encode_pil_to_base64(image),
        "strength": strength,
        "negative_prompt": "cartoon, illustration, animation. face. male, female",
        "seed": seed,
        "guidance_scale": 1,
        "num_inference_steps": 4,
        "sync_mode": 1,
        "num_images": 1,
    }

    try:
        response = _request_inference(arguments)
    except websockets.exceptions.ConnectionClosedOK:
        print("Connection closed, reconnecting...")
        # A newly opened connection has not seen the request, so the retry
        # must send it again rather than only waiting for a reply.
        try:
            response = _request_inference(arguments)
        except websockets.exceptions.ConnectionClosedOK:
            print("Connection closed again, aborting...")
            return state

    # Replies without images are dropped; there is nothing to display.
    if "images" in response:
        response_queue.append((response, original_image_size))
    return state
def update_output_image(state: GenerationState):
    """Turn the oldest queued response into updates for the output widgets.

    Args:
        state: Holds the ``responses`` deque to consume from.

    Returns:
        A ``(image_update, inference_time_update, state)`` tuple. Both
        updates are empty ``gr.update()`` values when no response is queued.
    """
    pending_responses = state.responses
    if not pending_responses:
        return gr.update(), gr.update(), state

    payload, target_size = pending_responses.popleft()
    decoded_image = decode_base64_to_image(payload["images"][0]["url"])
    elapsed = payload["timings"]["inference"]
    # Scale the result back to the size recorded alongside the response.
    resized_image = decoded_image.resize(target_size)
    return (
        gr.update(value=resized_image),
        gr.update(value=round(elapsed, 4)),
        state,
    )
# UI layout and event wiring for the demo.
with gr.Blocks(css="style.css", title="Realtime Latent Consistency Model") as demo:
    # State component holding the prompt/response queues passed to every callback.
    generation_state = gr.State(get_initial_state())

    # Header: logo followed by the description text.
    gr.HTML(f'<div style="width: 70px;">{LOGO}</div>')
    gr.Markdown(DESCRIPTION)

    # Sketch canvas next to the read-only generated image.
    with gr.Row(variant="default"):
        input_image = gr.Image(
            tool="color-sketch",
            source="canvas",
            label="Initial Image",
            type="pil",
            height=512,
            width=512,
            brush_radius=40.0,
        )
        output_image = gr.Image(
            label="Generated Image",
            type="pil",
            interactive=False,
            elem_id="output_image",
        )

    # Prompt text box with the inference time shown beside it.
    with gr.Row():
        with gr.Column(scale=23):
            prompt_box = gr.Textbox(label="Prompt", value=EXAMPLES[0])
        with gr.Column(scale=1):
            inference_time_box = gr.Number(
                label="Inference Time (s)", interactive=False
            )

    with gr.Accordion(label="Advanced Options", open=False):
        with gr.Row():
            with gr.Column():
                strength = gr.Slider(
                    label="Strength",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.05,
                    value=0.8,
                    info="""
                    Strength of the initial image that will be applied during inference.
                    """,
                )
            with gr.Column():
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=2**31 - 1,
                    step=1,
                    randomize=True,
                    info="""
                    Seed for the random number generator.
                    """,
                )

    # On page load: install a fresh state for the connecting client.
    demo.load(
        load_initial_state,
        outputs=[generation_state],
    )
    # Every 0.1 s: send the oldest queued prompt for inference.
    demo.load(
        send_inference_request,
        inputs=[generation_state],
        outputs=[generation_state],
        every=0.1,
    )
    # Every 0.1 s: show the oldest queued response in the output widgets.
    demo.load(
        update_output_image,
        inputs=[generation_state],
        outputs=[output_image, inference_time_box, generation_state],
        every=0.1,
    )

    # Any change to the image, prompt, strength or seed queues a new request.
    for event in (input_image.change, prompt_box.change, strength.change, seed.change):
        event(
            put_to_queue,
            inputs=[input_image, prompt_box, seed, strength, generation_state],
            outputs=[generation_state],
            show_progress=False,
            queue=True,
        )

    gr.Markdown("## Example Prompts")
    gr.Examples(examples=EXAMPLES, inputs=[prompt_box], label="Examples")
if __name__ == "__main__":
    # Enable the request queue first, then start serving the app.
    queued_demo = demo.queue(concurrency_count=20, api_open=False)
    queued_demo.launch(max_threads=8192)