# Author: michaelriedl
# Last commit: bba17cc — "Added types for fusion" (file size 5.76 kB)
import gradio as gr
import torch
import numpy as np
from PIL import Image
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
# How many images each "Generate" click produces; also used as the Gallery's column count.
MAX_IMAGES: int = 1
def _build_prompt(
    type1: str,
    type2: str,
    hp_num: int,
    attack_num: int,
    defense_num: int,
    sp_attack_num: int,
    sp_defense_num: int,
    speed_num: int,
) -> str:
    """Build the text prompt that conditions the LoRA-tuned model.

    The key names and their order presumably mirror the captions used during
    fine-tuning (NOTE(review): confirm before changing the format).
    """
    # The prompt also carries the sum of the six individual stats.
    base_total = (
        hp_num + attack_num + defense_num + sp_attack_num + sp_defense_num + speed_num
    )
    return (
        f"type1: {type1}, type2: {type2}, base_total: {base_total}, "
        f"hp: {hp_num}, attack: {attack_num}, defense: {defense_num}, "
        f"sp_attack: {sp_attack_num}, sp_defense: {sp_defense_num}, "
        f"speed: {speed_num}"
    )


def generate_images(
    type1: str,
    type2: str,
    hp_num: int,
    attack_num: int,
    defense_num: int,
    sp_attack_num: int,
    sp_defense_num: int,
    speed_num: int,
) -> list:
    """Generate sprite images conditioned on the selected types and stats.

    Parameters
    ----------
    type1 : str
        Primary type chosen in the UI (e.g. ``"steel"``).
    type2 : str
        Secondary type chosen in the UI (e.g. ``"fire"``).
    hp_num, attack_num, defense_num, sp_attack_num, sp_defense_num, speed_num : int
        Individual stat values; their sum is also sent as ``base_total``.

    Returns
    -------
    list
        ``MAX_IMAGES`` PIL images produced by the module-level ``pipe``.
    """
    prompt = _build_prompt(
        type1,
        type2,
        hp_num,
        attack_num,
        defense_num,
        sp_attack_num,
        sp_defense_num,
        speed_num,
    )
    # The pipeline's default output_type is "pil", so its images can go straight
    # to the Gallery; the previous PIL -> numpy -> PIL round trip only copied data.
    return [
        pipe(
            prompt,
            height=288,
            width=288,
            num_inference_steps=10,
            guidance_scale=7.5,
            cross_attention_kwargs={"scale": 1.0},
        ).images[0]
        for _ in range(MAX_IMAGES)
    ]
# Create the demo interface
demo = gr.Blocks()
# Base Stable Diffusion checkpoint and the LoRA attention weights applied on top of it
model_base = "stabilityai/stable-diffusion-2-base"
lora_model_path = "michaelriedl/MonsterForgeFusion-sd-2-base"
# Create the pipeline.
# NOTE(review): use_safetensors=False forces the pickle-based weight format, which
# can execute arbitrary code if the checkpoint source is ever untrusted; consider
# use_safetensors=None/True if the repository ships .safetensors files.
pipe = StableDiffusionPipeline.from_pretrained(
    model_base, torch_dtype=torch.float32, use_safetensors=False, local_files_only=False
)
# Replace the default scheduler with DPM-Solver multistep, reusing the original config
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
# Attach the fine-tuned LoRA attention processors to the UNet
pipe.unet.load_attn_procs(lora_model_path)
# Use a GPU when one is available (no-op on CPU-only hosts). Done after LoRA
# loading so the newly attached processor weights move along with the pipeline.
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
# Create the interface
# Type options shared by both dropdowns, defined once so the two lists cannot drift apart.
TYPE_CHOICES = (
    "bug",
    "dark",
    "dragon",
    "electric",
    "fairy",
    "fighting",
    "fire",
    "flying",
    "ghost",
    "grass",
    "ground",
    "ice",
    "normal",
    "poison",
    "psychic",
    "rock",
    "steel",
    "water",
)


def _stat_slider(label: str) -> gr.Slider:
    """Create an integer stat slider ranging 1-100 with a default of 50."""
    return gr.Slider(
        minimum=1,
        maximum=100,
        value=50,
        step=1,
        label=label,
    )


with demo:
    gr.HTML(
        """
<div style="text-align: center; margin: 0 auto;">
<p style="margin-bottom: 14px; line-height: 23px;">
Gradio demo for MonsterForgeFusion models. This was built with LoRA fine-tuning of Stable Diffusion models.
</p>
</div>
"""
    )
    with gr.Column():
        with gr.Row():
            gallery = gr.Gallery(
                columns=MAX_IMAGES, preview=True, object_fit="scale-down"
            )
        with gr.Row():
            # list() hands each dropdown its own copy of the shared choices
            type1 = gr.Dropdown(list(TYPE_CHOICES), value="steel", label="Type 1")
            type2 = gr.Dropdown(list(TYPE_CHOICES), value="fire", label="Type 2")
        with gr.Row():
            hp_num = _stat_slider("HP")
            attack_num = _stat_slider("Attack")
        with gr.Row():
            defense_num = _stat_slider("Defense")
            sp_attack_num = _stat_slider("Special Attack")
        with gr.Row():
            sp_defense_num = _stat_slider("Special Defense")
            speed_num = _stat_slider("Speed")
        gen_btn = gr.Button("Generate")
        # Inputs are passed positionally, so this order must match generate_images' parameters
        gen_btn.click(
            fn=generate_images,
            inputs=[
                type1,
                type2,
                hp_num,
                attack_num,
                defense_num,
                sp_attack_num,
                sp_defense_num,
                speed_num,
            ],
            outputs=gallery,
        )
    gr.HTML(
        """
<div class="footer">
<div style='text-align: center;'>MonsterForgeFusion by <a href='https://michaelriedl.com/' target='_blank'>Michael Riedl</a></div>
</div>
"""
    )
# Launch the interface only when executed as a script, so importing this
# module (e.g. from tests or tooling that looks up `demo`) does not start a server.
if __name__ == "__main__":
    demo.launch()