michaelriedl's picture
Added types for fusion
bba17cc
import gradio as gr
import torch
import numpy as np
from PIL import Image
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
# Number of images generated per generate_images() call; also used as the
# column count of the results gallery.
MAX_IMAGES = 1
def generate_images(
    type1: str,
    type2: str,
    hp_num: int,
    attack_num: int,
    defense_num: int,
    sp_attack_num: int,
    sp_defense_num: int,
    speed_num: int,
) -> list:
    """Generates a sprite based on the input types and stats.

    The inputs are formatted into a single text prompt, which is passed to
    the module-level ``pipe`` once per requested image (``MAX_IMAGES`` times).

    Parameters
    ----------
    type1 : str
        First type, inserted into the prompt as ``type1``.
    type2 : str
        Second type, inserted into the prompt as ``type2``.
    hp_num : int
        HP stat.
    attack_num : int
        Attack stat.
    defense_num : int
        Defense stat.
    sp_attack_num : int
        Special attack stat.
    sp_defense_num : int
        Special defense stat.
    speed_num : int
        Speed stat.

    Returns
    -------
    list
        List of PIL images.

    Notes
    -----
    The six stats are converted with ``int()`` before use. UI sliders may
    deliver whole numbers as floats (e.g. ``50.0``); converting keeps the
    numeric formatting of the prompt stable ("hp: 50", never "hp: 50.0").
    """
    # Normalise the stats to integers (see Notes)
    hp_num, attack_num, defense_num, sp_attack_num, sp_defense_num, speed_num = (
        int(stat)
        for stat in (
            hp_num,
            attack_num,
            defense_num,
            sp_attack_num,
            sp_defense_num,
            speed_num,
        )
    )
    # Calculate the base total
    base_total = (
        hp_num + attack_num + defense_num + sp_attack_num + sp_defense_num + speed_num
    )
    # Create the text prompt (adjacent f-strings are concatenated as-is)
    prompt = (
        f"type1: {type1}, type2: {type2}, base_total: {base_total}, "
        f"hp: {hp_num}, attack: {attack_num}, defense: {defense_num}, "
        f"sp_attack: {sp_attack_num}, sp_defense: {sp_defense_num}, "
        f"speed: {speed_num}"
    )
    # Generate the images
    images_list = []
    for _ in range(MAX_IMAGES):
        image = pipe(
            prompt,
            height=288,
            width=288,
            num_inference_steps=10,
            guidance_scale=7.5,
            # NOTE(review): "scale" is presumably the LoRA strength applied by
            # the attention processors - confirm against the diffusers docs.
            cross_attention_kwargs={"scale": 1.0},
        ).images[0]
        # Rebuild the image from its pixel array before returning it
        images_list.append(Image.fromarray(np.array(image)))
    return images_list
# Model identifiers: the base checkpoint and the LoRA weights applied on top
model_base = "stabilityai/stable-diffusion-2-base"
lora_model_path = "michaelriedl/MonsterForgeFusion-sd-2-base"

# Build the Stable Diffusion pipeline from the base checkpoint
pipe = StableDiffusionPipeline.from_pretrained(
    model_base,
    torch_dtype=torch.float32,
    use_safetensors=False,
    local_files_only=False,
)
# Swap in the DPM-Solver multistep scheduler, reusing the existing scheduler config
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
# Load the fine-tuned attention processors into the UNet
pipe.unet.load_attn_procs(lora_model_path)

# Container that the interface components are added to below
demo = gr.Blocks()
# Create the interface

# Types offered by both dropdowns; defined once so the two menus cannot
# drift apart.
TYPE_CHOICES = (
    "bug",
    "dark",
    "dragon",
    "electric",
    "fairy",
    "fighting",
    "fire",
    "flying",
    "ghost",
    "grass",
    "ground",
    "ice",
    "normal",
    "poison",
    "psychic",
    "rock",
    "steel",
    "water",
)


def _stat_slider(label):
    """Creates the slider for a single stat.

    Must be called inside the layout context the slider should belong to.

    Parameters
    ----------
    label : str
        Label shown for the slider.

    Returns
    -------
    gr.Slider
        Slider from 1 to 100 in steps of 1, initially set to 50.
    """
    return gr.Slider(minimum=1, maximum=100, value=50, step=1, label=label)


with demo:
    # Page header
    gr.HTML(
        """
<div style="text-align: center; margin: 0 auto;">
<p style="margin-bottom: 14px; line-height: 23px;">
Gradio demo for MonsterForgeFusion models. This was built with LoRA fine-tuning of Stable Diffusion models.
</p>
</div>
"""
    )
    with gr.Column():
        # Output: generated images
        with gr.Row():
            gallery = gr.Gallery(
                columns=MAX_IMAGES, preview=True, object_fit="scale-down"
            )
        # Inputs: the two types (each dropdown gets its own copy of the list)
        with gr.Row():
            type1 = gr.Dropdown(
                list(TYPE_CHOICES),
                value="steel",
                label="Type 1",
            )
            type2 = gr.Dropdown(
                list(TYPE_CHOICES),
                value="fire",
                label="Type 2",
            )
        # Inputs: the six stats, two sliders per row
        with gr.Row():
            hp_num = _stat_slider("HP")
            attack_num = _stat_slider("Attack")
        with gr.Row():
            defense_num = _stat_slider("Defense")
            sp_attack_num = _stat_slider("Special Attack")
        with gr.Row():
            sp_defense_num = _stat_slider("Special Defense")
            speed_num = _stat_slider("Speed")
        # Clicking the button runs generate_images and shows the result in
        # the gallery; the input order must match the function's parameters.
        gen_btn = gr.Button("Generate")
        gen_btn.click(
            fn=generate_images,
            inputs=[
                type1,
                type2,
                hp_num,
                attack_num,
                defense_num,
                sp_attack_num,
                sp_defense_num,
                speed_num,
            ],
            outputs=gallery,
        )
    # Page footer
    gr.HTML(
        """
<div class="footer">
<div style='text-align: center;'>MonsterForgeFusion by <a href='https://michaelriedl.com/' target='_blank'>Michael Riedl</a></div>
</div>
"""
    )
# Launch the interface. This is a top-level statement with no __main__ guard,
# so it runs whenever this file is executed or imported.
demo.launch()