Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import AutoModel | |
import torch | |
import numpy as np | |
from PIL import Image | |
# Number of latent samples drawn each time a "Generate" button is clicked.
MAX_IMAGES: int = 4
def _generate_sprites(
    model, image_mode: str, color_indexed: bool, color_num: int
) -> list:
    """Sample ``MAX_IMAGES`` sprites from ``model`` as PIL images.

    Shared implementation behind ``generate_small`` and ``generate_med``,
    which differ only in the model used and the PIL mode of its output.

    Parameters
    ----------
    model
        Loaded MonsterForge generator. Must expose ``model.model.latent_dim``
        and, when called on a ``(1, latent_dim)`` latent tensor, return a
        batch of channel-first images (only the first batch entry is used).
    image_mode : str
        PIL mode handed to ``Image.fromarray`` (``"RGB"`` or ``"RGBA"``); it
        must match the number of channels produced by ``model``.
    color_indexed : bool
        Whether to also build an adaptive-palette (``"P"`` mode) copy of
        each sprite.
    color_num : int
        Number of colors in the adaptive palette. Only used when
        ``color_indexed`` is true.

    Returns
    -------
    list
        List of PIL images. Without color indexing it holds ``MAX_IMAGES``
        full-color images; with it, each sprite contributes its indexed copy
        followed by its full-color image (``2 * MAX_IMAGES`` entries).
    """
    latent_dim = model.model.latent_dim
    images_list = []
    for _ in range(MAX_IMAGES):
        # Draw one fresh random latent vector per sprite.
        latents = torch.randn((1, latent_dim))
        with torch.no_grad():
            generated_image = model(latents)
        # Clamp to [0, 1], scale to 8-bit, and reorder (C, H, W) -> (H, W, C)
        # as PIL expects.
        pixels = generated_image.clamp_(0.0, 1.0).cpu().numpy()
        color_image = Image.fromarray(
            np.uint8(pixels[0] * 255).transpose(1, 2, 0), image_mode
        )
        if color_indexed:
            # The slider value may arrive as a float; Pillow's adaptive
            # quantizer needs an int palette size, so coerce explicitly.
            color_image_indexed = color_image.convert(
                "P", palette=Image.ADAPTIVE, colors=int(color_num)
            )
            images_list.append(color_image_indexed)
        # NOTE(review): the full-color image is appended even in indexed mode,
        # so the gallery shows indexed/full-color pairs. Confirm this
        # side-by-side comparison is intended; otherwise move this append
        # into an ``else`` branch.
        images_list.append(color_image)
    return images_list


def generate_small(color_indexed: bool, color_num: int) -> list:
    """Generates small sprites with the small MonsterForge model.

    Parameters
    ----------
    color_indexed : bool
        Whether to use color indexing.
    color_num : int
        Number of colors in the palette.

    Returns
    -------
    list
        List of PIL images: ``MAX_IMAGES`` RGB sprites, or indexed/RGB pairs
        when ``color_indexed`` is true.
    """
    return _generate_sprites(model_small, "RGB", color_indexed, color_num)


def generate_med(color_indexed: bool, color_num: int) -> list:
    """Generates medium sprites with the medium MonsterForge model.

    Parameters
    ----------
    color_indexed : bool
        Whether to use color indexing.
    color_num : int
        Number of colors in the palette.

    Returns
    -------
    list
        List of PIL images: ``MAX_IMAGES`` RGBA sprites, or indexed/RGBA pairs
        when ``color_indexed`` is true.
    """
    return _generate_sprites(model_med, "RGBA", color_indexed, color_num)
# Top-level Blocks container; the layout is populated inside `with demo:`.
demo = gr.Blocks()


def _load_generator(repo_id):
    """Load a MonsterForge checkpoint via ``AutoModel`` and switch it to eval mode.

    NOTE(review): ``trust_remote_code=True`` executes Python code shipped in
    the Hub repository. Consider pinning ``revision=`` to a reviewed commit.
    """
    generator = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
    generator.eval()
    return generator


# Load the small model first, then the medium one (same order as before).
model_small = _load_generator("michaelriedl/MonsterForge-small")
model_med = _load_generator("michaelriedl/MonsterForge-medium")
# Create the interface
def _build_sprite_tab(tab_title, generate_fn):
    """Add one generator tab to the enclosing ``gr.Tabs`` context.

    Lays out a gallery, a "Color indexed" checkbox, a palette-size slider and a
    Generate button, then wires the button so that clicking it calls
    ``generate_fn(checkbox_value, slider_value)`` and shows the result in the
    gallery.

    Returns
    -------
    tuple
        ``(gallery, color_checkbox, palette_slider, generate_button)``.
    """
    with gr.TabItem(tab_title):
        with gr.Column():
            with gr.Row():
                gallery = gr.Gallery(
                    columns=4,
                    object_fit="scale-down",
                )
            with gr.Row():
                color_checkbox = gr.Checkbox(label="Color indexed", value=False)
                palette_slider = gr.Slider(
                    minimum=8,
                    maximum=32,
                    value=32,
                    step=4,
                    label="Number of colors in the palette",
                )
            generate_button = gr.Button("Generate")
            generate_button.click(
                fn=generate_fn,
                inputs=[color_checkbox, palette_slider],
                outputs=gallery,
            )
    return gallery, color_checkbox, palette_slider, generate_button


with demo:
    # Header blurb shown above the tabs.
    gr.HTML(
        """
    <div style="text-align: center; margin: 0 auto;">
      <p style="margin-bottom: 14px; line-height: 23px;">
        Gradio demo for MonsterForge models. This was built with Lightweight GAN using the implementation from <a href='https://github.com/lucidrains/lightweight-gan' target='_blank'>lucidrains</a>.
      </p>
    </div>
"""
    )
    with gr.Tabs():
        # One tab per model; components stay bound at module level.
        (
            gallery_small,
            color_index_small,
            color_num_small,
            gen_btn_small,
        ) = _build_sprite_tab("Small Sprite", generate_small)
        (
            gallery_med,
            color_index_med,
            color_num_med,
            gen_btn_med,
        ) = _build_sprite_tab("Medium Sprite", generate_med)
    # Footer credit below the tabs.
    gr.HTML(
        """
    <div class="footer">
      <div style='text-align: center;'>MonsterForge by <a href='https://michaelriedl.com/' target='_blank'>Michael Riedl</a></div>
    </div>
"""
    )
# Launch the interface only when this file is run as a script (e.g.
# `python app.py`), so that importing the module does not start a server as
# a side effect.
if __name__ == "__main__":
    demo.launch()