import gradio as gr
from transformers import AutoModel
import torch
import numpy as np
from PIL import Image

# Number of sprites sampled per click of a "Generate" button.
MAX_IMAGES = 4


def _to_indexed(image: Image.Image, color_num: int) -> Image.Image:
    """Reduce *image* to an adaptive palette of at most *color_num* colors.

    Parameters
    ----------
    image : PIL.Image.Image
        Source image in "RGB" or "RGBA" mode.
    color_num : int
        Maximum palette size. Coerced to ``int`` because the UI slider may
        deliver a float, which Pillow's quantizer rejects.

    Returns
    -------
    PIL.Image.Image
        A mode "P" (palette) image.
    """
    colors = int(color_num)
    if image.mode == "RGBA":
        # convert("P", palette=ADAPTIVE) uses median-cut quantization, which
        # Pillow refuses for RGBA input ("image has wrong mode"); fast octree
        # is one of the quantizers that supports an alpha channel.
        return image.quantize(colors=colors, method=Image.FASTOCTREE)
    return image.convert("P", palette=Image.ADAPTIVE, colors=colors)


def _generate_sprites(model, mode: str, color_indexed: bool, color_num: int) -> list:
    """Sample MAX_IMAGES sprites from *model* and return them as PIL images.

    Parameters
    ----------
    model : torch.nn.Module
        Generator loaded via ``AutoModel``; must expose
        ``model.model.latent_dim``.
    mode : str
        PIL mode used to interpret the generator output ("RGB" or "RGBA").
        NOTE(review): must match the generator's channel count — confirm
        against the model card.
    color_indexed : bool
        If True, each sprite is returned palette-indexed instead of as the
        full-color image.
    color_num : int
        Number of colors in the palette (only used when ``color_indexed``).

    Returns
    -------
    list
        List of MAX_IMAGES PIL images.
    """
    latent_dim = model.model.latent_dim
    images_list = []
    for _ in range(MAX_IMAGES):
        # Draw a fresh random latent vector for each sprite
        latents = torch.randn((1, latent_dim))
        with torch.no_grad():
            generated_image = model(latents)
        # Clamp the image to [0, 1] before scaling to 8-bit
        generated_image = generated_image.clamp_(0.0, 1.0).cpu().numpy()
        # Output is indexed as (batch, channels, height, width); PIL wants
        # (height, width, channels).
        color_image = Image.fromarray(
            np.uint8(generated_image[0] * 255).transpose(1, 2, 0), mode
        )
        # Show either the indexed version or the full-color one, not both
        if color_indexed:
            images_list.append(_to_indexed(color_image, color_num))
        else:
            images_list.append(color_image)
    return images_list


def generate_small(color_indexed: bool, color_num: int) -> list:
    """Generates a small sprite.

    Parameters
    ----------
    color_indexed : bool
        Whether to use color indexing.
    color_num : int
        Number of colors in the palette.

    Returns
    -------
    list
        List of PIL images.
    """
    return _generate_sprites(model_small, "RGB", color_indexed, color_num)


def generate_med(color_indexed: bool, color_num: int) -> list:
    """Generates a medium sprite.

    Parameters
    ----------
    color_indexed : bool
        Whether to use color indexing.
    color_num : int
        Number of colors in the palette.

    Returns
    -------
    list
        List of PIL images.
    """
    return _generate_sprites(model_med, "RGBA", color_indexed, color_num)


# Create the demo interface
demo = gr.Blocks()

# Create the small model
model_small = AutoModel.from_pretrained(
    "michaelriedl/MonsterForge-small", trust_remote_code=True
)
model_small.eval()

# Create the medium model
model_med = AutoModel.from_pretrained(
    "michaelriedl/MonsterForge-medium", trust_remote_code=True
)
model_med.eval()

# Create the interface
with demo:
    gr.HTML(
        """

Gradio demo for MonsterForge models. This was built with Lightweight GAN using the implementation from lucidrains.

"""
    )
    with gr.Tabs():
        with gr.TabItem("Small Sprite"):
            with gr.Column():
                with gr.Row():
                    gallery_small = gr.Gallery(
                        columns=4,
                        object_fit="scale-down",
                    )
                with gr.Row():
                    color_index_small = gr.Checkbox(label="Color indexed", value=False)
                    color_num_small = gr.Slider(
                        minimum=8,
                        maximum=32,
                        value=32,
                        step=4,
                        label="Number of colors in the palette",
                    )
                gen_btn_small = gr.Button("Generate")
                gen_btn_small.click(
                    fn=generate_small,
                    inputs=[color_index_small, color_num_small],
                    outputs=gallery_small,
                )
        with gr.TabItem("Medium Sprite"):
            with gr.Column():
                with gr.Row():
                    gallery_med = gr.Gallery(
                        columns=4,
                        object_fit="scale-down",
                    )
                with gr.Row():
                    color_index_med = gr.Checkbox(label="Color indexed", value=False)
                    color_num_med = gr.Slider(
                        minimum=8,
                        maximum=32,
                        value=32,
                        step=4,
                        label="Number of colors in the palette",
                    )
                gen_btn_med = gr.Button("Generate")
                gen_btn_med.click(
                    fn=generate_med,
                    inputs=[color_index_med, color_num_med],
                    outputs=gallery_med,
                )
    gr.HTML(
        """ """
    )

# Launch the interface
demo.launch()