Spaces:
Running
Running
import json | |
from collections import deque | |
from dataclasses import dataclass | |
from typing import Optional | |
import gradio as gr | |
import websockets | |
from gradio.processing_utils import decode_base64_to_image, encode_pil_to_base64 | |
from PIL import Image | |
from websockets.sync.client import connect | |
from constants import DESCRIPTION, WS_ADDRESS, LOGO | |
from utils import replace_background | |
from gradio_examples import EXAMPLES | |
MAX_QUEUE_SIZE = 4 | |
class GenerationState: | |
prompts: deque | |
responses: deque | |
def get_initial_state() -> GenerationState: | |
return GenerationState( | |
prompts=deque(maxlen=MAX_QUEUE_SIZE), | |
responses=deque(maxlen=MAX_QUEUE_SIZE), | |
) | |
def load_initial_state(request: gr.Request) -> GenerationState: | |
print("Loading initial state for", request.client.host) | |
return get_initial_state() | |
async def put_to_queue( | |
image: Optional[Image.Image], | |
prompt: str, | |
seed: int, | |
strength: float, | |
state: GenerationState, | |
): | |
prompts_queue = state.prompts | |
if prompt and image is not None: | |
prompts_queue.append((image, prompt, seed, strength)) | |
return state | |
def send_inference_request(state: GenerationState) -> Image.Image: | |
prompts_queue = state.prompts | |
response_queue = state.responses | |
if len(prompts_queue) == 0: | |
return state | |
image, prompt, seed, strength = prompts_queue.popleft() | |
original_image_size = image.size | |
image = replace_background(image.resize((512, 512))) | |
arguments = { | |
"prompt": prompt, | |
"image_url": encode_pil_to_base64(image), | |
"strength": strength, | |
"negative_prompt": "cartoon, illustration, animation. face. male, female", | |
"seed": seed, | |
"guidance_scale": 1, | |
"num_inference_steps": 4, | |
"sync_mode": 1, | |
"num_images": 1, | |
} | |
connection = connect(WS_ADDRESS) | |
connection.send(json.dumps(arguments)) | |
try: | |
response = json.loads(connection.recv()) | |
except websockets.exceptions.ConnectionClosedOK: | |
print("Connection closed, reconnecting...") | |
# TODO: This is a hacky way to reconnect, but it works for now | |
# Ideally, we should be able to reconnect to the same connection | |
# and not have to create a new one | |
connection = connect(WS_ADDRESS) | |
try: | |
response = json.loads(connection.recv()) | |
except websockets.exceptions.ConnectionClosedOK: | |
print("Connection closed again, aborting...") | |
return state | |
# TODO: If a new connection is created, the response do not contain the images. | |
if "images" in response: | |
response_queue.append((response, original_image_size)) | |
return state | |
def update_output_image(state: GenerationState): | |
image_update = gr.update() | |
inference_time_update = gr.update() | |
response_queue = state.responses | |
if len(response_queue) > 0: | |
response, original_image_size = response_queue.popleft() | |
generated_image = decode_base64_to_image(response["images"][0]["url"]) | |
inference_time = response["timings"]["inference"] | |
generated_image.resize(original_image_size).save("generated.png") | |
image_update = gr.update(value=generated_image.resize(original_image_size)) | |
inference_time_update = gr.update(value=round(inference_time, 4)) | |
return image_update, inference_time_update, state | |
with gr.Blocks(css="style.css", title=f"Realtime Latent Consistency Model") as demo: | |
generation_state = gr.State(get_initial_state()) | |
gr.HTML(f'<div style="width: 70px;">{LOGO}</div>') | |
gr.Markdown(DESCRIPTION) | |
with gr.Row(variant="default"): | |
input_image = gr.Image( | |
tool="color-sketch", | |
source="canvas", | |
label="Initial Image", | |
type="pil", | |
height=512, | |
width=512, | |
brush_radius=40.0, | |
) | |
output_image = gr.Image( | |
label="Generated Image", | |
type="pil", | |
interactive=False, | |
elem_id="output_image", | |
) | |
with gr.Row(): | |
with gr.Column(scale=23): | |
prompt_box = gr.Textbox(label="Prompt") | |
with gr.Column(scale=1): | |
inference_time_box = gr.Number( | |
label="Inference Time (s)", interactive=False | |
) | |
with gr.Accordion(label="Advanced Options", open=False): | |
with gr.Row(): | |
with gr.Column(): | |
strength = gr.Slider( | |
label="Strength", | |
minimum=0.1, | |
maximum=1.0, | |
step=0.05, | |
value=0.8, | |
info=""" | |
Strength of the initial image that will be applied during inference. | |
""", | |
) | |
with gr.Column(): | |
seed = gr.Slider( | |
label="Seed", | |
minimum=0, | |
maximum=2**31 - 1, | |
step=1, | |
randomize=True, | |
info=""" | |
Seed for the random number generator. | |
""", | |
) | |
demo.load( | |
load_initial_state, | |
outputs=[generation_state], | |
) | |
demo.load( | |
send_inference_request, | |
inputs=[generation_state], | |
outputs=[generation_state], | |
every=0.1, | |
) | |
demo.load( | |
update_output_image, | |
inputs=[generation_state], | |
outputs=[output_image, inference_time_box, generation_state], | |
every=0.1, | |
) | |
for event in [input_image.change, prompt_box.change, strength.change, seed.change]: | |
event( | |
put_to_queue, | |
[input_image, prompt_box, seed, strength, generation_state], | |
[generation_state], | |
show_progress=False, | |
queue=True, | |
) | |
gr.Markdown("## Example Prompts") | |
gr.Examples(examples=EXAMPLES, inputs=[prompt_box], label="Examples") | |
if __name__ == "__main__": | |
demo.queue(concurrency_count=4, api_open=False).launch() | |