from PIL import Image
import numpy as np
import gradio as gr
import spaces
import torch
from tqdm import tqdm

from controlnet import QRControlNet
from game_of_life import GameOfLife
from utils import resize_image, generate_image_from_grid


@spaces.GPU(duration=80)
def generate_all_images(
    gol_grids: list[np.ndarray],
    source_image: Image.Image,
    num_inference_steps: int,
    controlnet_conditioning_scale: float,
    strength: float,
    prompt: str,
    negative_prompt: str,
    seed: int,
    guidance_scale: float,
    img_size: int,
):
    # device = "mps"
    # device = "cpu"
    device = "cuda"
    print(f"Using {device=}")

    # Initialize the controlnet (this can take a while the first time it's run)
    controlnet = QRControlNet(device=device)

    controlnet_conditioning_scale = float(controlnet_conditioning_scale)
    source_image = resize_image(source_image, resolution=img_size)

    # Run ControlNet on each Game of Life frame, conditioned on the source image
    images = []
    for grid in tqdm(gol_grids):
        grid_inverse = 1 - grid  # invert the grid for controlnet
        grid_inverse_image = generate_image_from_grid(grid_inverse, img_size=img_size)

        image = controlnet.generate_image(
            source_image=source_image,
            control_image=grid_inverse_image,
            num_inference_steps=num_inference_steps,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            strength=strength,
            prompt=prompt,
            negative_prompt=negative_prompt,
            seed=seed,
            guidance_scale=guidance_scale,
            img_size=img_size,
        )
        images.append(image)

    return images


def make_gif(images: list[Image.Image], gif_path):
    images[0].save(
        gif_path,
        save_all=True,
        append_images=images[1:],
        duration=200,  # Duration between frames in milliseconds
        loop=0,  # Loop forever
    )
    return gif_path


def generate(
    source_image,
    prompt,
    negative_prompt,
    seed,
    num_inference_steps,
    num_gol_steps,
    gol_grid_dim,
    img_size,
    controlnet_conditioning_scale,
    strength,
    guidance_scale,
):
    # Compute the Game of Life first
    gol = GameOfLife()
    gol.set_random_state(dim=(gol_grid_dim, gol_grid_dim), p=0.5, seed=seed)
    gol.generate_n_steps(n=num_gol_steps)
    gol_grids = gol.game_history

    # Generate the gif for the original Game of Life
    gol_images = [
        generate_image_from_grid(grid, img_size=img_size) for grid in gol_grids
    ]
    path_gol_gif = make_gif(gol_images, "gol_original.gif")

    # Generate the gif for the ControlNet Game of Life
    controlnet_images = generate_all_images(
        gol_grids=gol_grids,
        source_image=source_image,
        num_inference_steps=num_inference_steps,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        strength=strength,
        prompt=prompt,
        negative_prompt=negative_prompt,
        seed=seed,
        guidance_scale=guidance_scale,
        img_size=img_size,
    )
    path_gol_controlnet = make_gif(controlnet_images, "gol_controlnet.gif")

    return path_gol_controlnet, path_gol_gif


# Gradio UI components and their default values
source_image = gr.Image(label="Source Image", type="pil", value="sky-gol-image.jpeg")
output_controlnet = gr.Image(label="ControlNet Game of Life")
output_gol = gr.Image(label="Original Game of Life")
prompt = gr.Textbox(
    label="Prompt", value="clear sky with clouds, high quality, background 4k"
)
negative_prompt = gr.Textbox(
    label="Negative Prompt",
    value="ugly, disfigured, low quality, blurry, nsfw, qr code",
)
seed = gr.Number(label="Seed", value=42)
num_inference_steps = gr.Number(label="ControlNet Inference Steps per Frame", value=30)
num_gol_steps = gr.Slider(
    label="Number of Game of Life Steps to Generate",
    minimum=2,
    maximum=100,
    step=1,
    value=6,
)
gol_grid_dim = gr.Number(
    label="Game of Life Grid Dimension",
    value=10,
)
img_size = gr.Number(label="Image Size (pixels)", value=512)
controlnet_conditioning_scale = gr.Slider(
    label="ControlNet Conditioning Scale",
    minimum=0.1,
    maximum=10.0,
    value=2.0,
)
strength = gr.Slider(label="Strength", minimum=0.1, maximum=1.0, value=0.9)
guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=100, value=20)

demo = gr.Interface(
    fn=generate,
    inputs=[
        source_image,
        prompt,
        negative_prompt,
        seed,
        num_inference_steps,
        num_gol_steps,
        gol_grid_dim,
        img_size,
        controlnet_conditioning_scale,
        strength,
        guidance_scale,
    ],
    outputs=[output_controlnet, output_gol],
    title="ControlNet Game of Life",
    description="""Generate a Game of Life grid and then use ControlNet to enhance the image based on the grid, a reference image, and a prompt.

    For more information, check out this [blog post](https://www.jerpint.io/blog/diffusion-gol/).

    Generating frames can be slow and eat up GPU usage; for longer runtimes, you can check out the [colab](https://colab.research.google.com/github/jerpint/jerpint.github.io/blob/master/colabs/gol_diffusion.ipynb) implementation.
    """,
)

demo.launch(debug=True)
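
# Hypothetical programmatic usage (a sketch, not part of the app): generate()
# can also be called directly, bypassing the Gradio UI. The values below simply
# mirror the widget defaults defined above; generate_all_images assumes a CUDA
# device is available.
#
# controlnet_gif, gol_gif = generate(
#     source_image=Image.open("sky-gol-image.jpeg"),
#     prompt="clear sky with clouds, high quality, background 4k",
#     negative_prompt="ugly, disfigured, low quality, blurry, nsfw, qr code",
#     seed=42,
#     num_inference_steps=30,
#     num_gol_steps=6,
#     gol_grid_dim=10,
#     img_size=512,
#     controlnet_conditioning_scale=2.0,
#     strength=0.9,
#     guidance_scale=20,
# )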