import gradio as gr from transformers import AutoModel import torch import numpy as np from PIL import Image MAX_IMAGES = 4 def generate_small(color_indexed: bool, color_num: int) -> list: """Generates a small sprite. Parameters ---------- color_indexed : bool Whether to use color indexing. color_num : int Number of colors in the palette. Returns ------- list List of PIL images. """ # Get the latent dimension latent_dim = model_small.model.latent_dim # Initialize the list of images images_list = [] # Generate MAX_IMAGES images for _ in range(MAX_IMAGES): # Generate a random latent vector latents = torch.randn((1, latent_dim)) # Generate the image with torch.no_grad(): generated_image = model_small(latents) # Clamp the image to [0, 1] generated_image = generated_image.clamp_(0.0, 1.0).cpu().numpy() # Convert the generated image to PIL image color_image = Image.fromarray( np.uint8(generated_image[0] * 255).transpose(1, 2, 0), "RGB" ) # Convert to color indexed image if needed if color_indexed: # Convert using adaptive palette of given color depth color_image_indexed = color_image.convert( "P", palette=Image.ADAPTIVE, colors=color_num ) # Add the color indexed image to the list images_list.append(color_image_indexed) # Add the image to the list images_list.append(color_image) return images_list def generate_med(color_indexed: bool, color_num: int) -> list: """Generates a medium sprite. Parameters ---------- color_indexed : bool Whether to use color indexing. color_num : int Number of colors in the palette. Returns ------- list List of PIL images. """ # Get the latent dimension latent_dim = model_med.model.latent_dim # Initialize the list of images images_list = [] # Generate MAX_IMAGES images for _ in range(MAX_IMAGES): # Generate a random latent vector latents = torch.randn((1, latent_dim)) # Generate the image with torch.no_grad(): generated_image = model_med(latents) # Clamp the image to [0, 1] generated_image = generated_image.clamp_(0.0, 1.0).cpu().numpy() # Convert the generated image to PIL image color_image = Image.fromarray( np.uint8(generated_image[0] * 255).transpose(1, 2, 0), "RGBA" ) # Convert to color indexed image if needed if color_indexed: # Convert using adaptive palette of given color depth color_image_indexed = color_image.convert( "P", palette=Image.ADAPTIVE, colors=color_num ) # Add the color indexed image to the list images_list.append(color_image_indexed) # Add the image to the list images_list.append(color_image) return images_list # Create the demo interface demo = gr.Blocks() # Create the small model model_small = AutoModel.from_pretrained( "michaelriedl/MonsterForge-small", trust_remote_code=True ) model_small.eval() # Create the medium model model_med = AutoModel.from_pretrained( "michaelriedl/MonsterForge-medium", trust_remote_code=True ) model_med.eval() # Create the interface with demo: gr.HTML( """
Gradio demo for MonsterForge models. This was built with Lightweight GAN using the implementation from lucidrains.