# MonsterForge / app.py
# Last commit 4e39abd by michaelriedl: "Added medium model"
import gradio as gr
from transformers import AutoModel
import torch
import numpy as np
from PIL import Image
# Number of sprites sampled per click of a "Generate" button.
MAX_IMAGES = 4


def _to_uint8_hwc(image: torch.Tensor) -> np.ndarray:
    """Convert one generated image from CHW floats to an HWC ``uint8`` array.

    Values are clamped to [0, 1], scaled to [0, 255] and rounded to the
    nearest integer. Rounding (instead of the truncation a bare ``np.uint8``
    cast performs) avoids biasing pixel values downward by up to one
    intensity level.

    Parameters
    ----------
    image : torch.Tensor
        A single image laid out as (channels, height, width). The tensor is
        not modified.

    Returns
    -------
    np.ndarray
        The image as a (height, width, channels) ``uint8`` array.
    """
    pixels = image.detach().clamp(0.0, 1.0).cpu().numpy()
    return (pixels * 255.0 + 0.5).astype(np.uint8).transpose(1, 2, 0)


def _generate_sprites(
    model,
    mode: str,
    color_indexed: bool,
    color_num: int,
    num_images: int = MAX_IMAGES,
) -> list:
    """Sample sprites from ``model`` and return them as PIL images.

    Parameters
    ----------
    model
        A loaded MonsterForge generator exposing ``model.model.latent_dim``.
        Calling it on a ``(1, latent_dim)`` latent tensor is assumed to return
        a batch shaped (1, channels, height, width) with ``len(mode)``
        channels -- TODO confirm against the model repositories.
    mode : str
        PIL mode of the full-color images (``"RGB"`` or ``"RGBA"``).
    color_indexed : bool
        Whether to also return an adaptive-palette ("P" mode) copy of each
        sprite.
    color_num : int
        Number of colors in the adaptive palette. Float values (as delivered
        by Gradio sliders) are truncated to ``int``.
    num_images : int, optional
        Number of sprites to sample, by default ``MAX_IMAGES``.

    Returns
    -------
    list
        ``num_images`` PIL images, or ``2 * num_images`` when
        ``color_indexed`` is true. In that case, each palette-indexed image is
        immediately followed by its full-color original.
    """
    latent_dim = model.model.latent_dim
    # Gradio documents Slider values as floats, but Pillow's quantizer only
    # accepts an integer palette size.
    palette_size = int(color_num)
    images_list = []
    for _ in range(num_images):
        # One independent random latent vector per sprite.
        latents = torch.randn((1, latent_dim))
        with torch.no_grad():
            generated_image = model(latents)
        color_image = Image.fromarray(_to_uint8_hwc(generated_image[0]), mode)
        if color_indexed:
            indexed_image = color_image.convert(
                "P", palette=Image.ADAPTIVE, colors=palette_size
            )
            images_list.append(indexed_image)
        # The full-color sprite is always included, so indexed mode shows both
        # versions side by side.
        images_list.append(color_image)
    return images_list


def generate_small(color_indexed: bool, color_num: int) -> list:
    """Generate a gallery of small RGB sprites with the small model.

    Parameters
    ----------
    color_indexed : bool
        Whether to also return a palette-indexed copy of each sprite.
    color_num : int
        Number of colors in the adaptive palette used when ``color_indexed``.

    Returns
    -------
    list
        ``MAX_IMAGES`` PIL images, or ``2 * MAX_IMAGES`` when
        ``color_indexed`` is true (each indexed image followed by its
        full-color original).
    """
    return _generate_sprites(model_small, "RGB", color_indexed, color_num)


def generate_med(color_indexed: bool, color_num: int) -> list:
    """Generate a gallery of medium RGBA sprites with the medium model.

    Parameters
    ----------
    color_indexed : bool
        Whether to also return a palette-indexed copy of each sprite.
    color_num : int
        Number of colors in the adaptive palette used when ``color_indexed``.

    Returns
    -------
    list
        ``MAX_IMAGES`` PIL images, or ``2 * MAX_IMAGES`` when
        ``color_indexed`` is true (each indexed image followed by its
        full-color original).
    """
    return _generate_sprites(model_med, "RGBA", color_indexed, color_num)
# Top-level Blocks container; components are attached to it further down.
demo = gr.Blocks()


def _load_generator(repo_id):
    """Load a MonsterForge generator and put it into evaluation mode.

    ``trust_remote_code=True`` lets transformers execute the model code that
    ships with the repository, so only use it with repositories you trust.
    """
    generator = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
    # Inference only: switch the module to eval behavior (affects layers such
    # as dropout or batch-norm, if the model has any).
    generator.eval()
    return generator


# Generators backing the "Small Sprite" and "Medium Sprite" tabs.
model_small = _load_generator("michaelriedl/MonsterForge-small")
model_med = _load_generator("michaelriedl/MonsterForge-medium")
def _build_sprite_tab(title, generate_fn):
    """Add one sprite-generator tab to the demo and wire up its button.

    Call this inside the ``gr.Tabs()`` context so the new tab is placed there.

    Parameters
    ----------
    title : str
        Label shown on the tab.
    generate_fn : callable
        Click handler taking ``(color_indexed, color_num)`` and returning the
        list of images to display in the gallery.

    Returns
    -------
    tuple
        The ``(gallery, color_index, color_num, gen_btn)`` components.
    """
    with gr.TabItem(title):
        with gr.Column():
            with gr.Row():
                gallery = gr.Gallery(columns=4, object_fit="scale-down")
            with gr.Row():
                color_index = gr.Checkbox(label="Color indexed", value=False)
                color_num = gr.Slider(
                    minimum=8,
                    maximum=32,
                    value=32,
                    step=4,
                    label="Number of colors in the palette",
                )
            gen_btn = gr.Button("Generate")
            gen_btn.click(
                fn=generate_fn,
                inputs=[color_index, color_num],
                outputs=gallery,
            )
    return gallery, color_index, color_num, gen_btn


# Lay out the page: header text, one tab per model, then the footer.
with demo:
    gr.HTML(
        """
<div style="text-align: center; margin: 0 auto;">
<p style="margin-bottom: 14px; line-height: 23px;">
Gradio demo for MonsterForge models. This was built with Lightweight GAN using the implementation from <a href='https://github.com/lucidrains/lightweight-gan' target='_blank'>lucidrains</a>.
</p>
</div>
"""
    )
    with gr.Tabs():
        (
            gallery_small,
            color_index_small,
            color_num_small,
            gen_btn_small,
        ) = _build_sprite_tab("Small Sprite", generate_small)
        (
            gallery_med,
            color_index_med,
            color_num_med,
            gen_btn_med,
        ) = _build_sprite_tab("Medium Sprite", generate_med)
    gr.HTML(
        """
<div class="footer">
<div style='text-align: center;'>MonsterForge by <a href='https://michaelriedl.com/' target='_blank'>Michael Riedl</a></div>
</div>
"""
    )
# Launch the interface only when this file is executed as a script
# (e.g. ``python app.py``); importing the module does not start a server.
if __name__ == "__main__":
    demo.launch()