Spaces:
Running
Running
import gradio as gr | |
import torch | |
import numpy as np | |
from PIL import Image | |
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler | |
# Number of sprites rendered per request; also used as the gallery column count.
MAX_IMAGES = 1


def generate_images(
    type1: str,
    type2: str,
    hp_num: int,
    attack_num: int,
    defense_num: int,
    sp_attack_num: int,
    sp_defense_num: int,
    speed_num: int,
) -> list:
    """Generate sprite images conditioned on the given types and base stats.

    The inputs are flattened into one comma-separated ``key: value`` text
    prompt (including a derived ``base_total``) and handed to the
    module-level diffusion pipeline ``pipe``.

    Parameters
    ----------
    type1 : str
        First type, inserted verbatim into the prompt.
    type2 : str
        Second type, inserted verbatim into the prompt.
    hp_num : int
        HP stat.
    attack_num : int
        Attack stat.
    defense_num : int
        Defense stat.
    sp_attack_num : int
        Special attack stat.
    sp_defense_num : int
        Special defense stat.
    speed_num : int
        Speed stat.

    Returns
    -------
    list
        List of ``MAX_IMAGES`` PIL images.
    """
    # The base total is the sum of the six individual stats
    base_total = (
        hp_num + attack_num + defense_num + sp_attack_num + sp_defense_num + speed_num
    )

    # Create the text prompt (split only for line length; the text is unchanged)
    prompt = (
        f"type1: {type1}, type2: {type2}, base_total: {base_total}, "
        f"hp: {hp_num}, attack: {attack_num}, defense: {defense_num}, "
        f"sp_attack: {sp_attack_num}, sp_defense: {sp_defense_num}, "
        f"speed: {speed_num}"
    )

    # One pipeline call per image. With the default output type the pipeline
    # already returns PIL images, so they are collected directly instead of
    # being round-tripped through a NumPy array.
    # NOTE(review): cross_attention_kwargs["scale"] is presumably the LoRA
    # weight -- confirm against the diffusers version in use.
    return [
        pipe(
            prompt,
            height=288,
            width=288,
            num_inference_steps=10,
            guidance_scale=7.5,
            cross_attention_kwargs={"scale": 1.0},
        ).images[0]
        for _ in range(MAX_IMAGES)
    ]
# Model identifiers: the base checkpoint and the LoRA fine-tuned weights
model_base = "stabilityai/stable-diffusion-2-base"
lora_model_path = "michaelriedl/MonsterForgeFusion-sd-2-base"

# Build the text-to-image pipeline from the base checkpoint
pipe = StableDiffusionPipeline.from_pretrained(
    model_base,
    torch_dtype=torch.float32,
    use_safetensors=False,
    local_files_only=False,
)
# Swap in the DPM-Solver multistep scheduler, reusing the current scheduler config
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
# Load the fine-tuned attention processors into the UNet
pipe.unet.load_attn_procs(lora_model_path)

# Top-level Blocks container that the UI is assembled in
demo = gr.Blocks()
# Options offered by both type dropdowns
_MONSTER_TYPES = (
    "bug",
    "dark",
    "dragon",
    "electric",
    "fairy",
    "fighting",
    "fire",
    "flying",
    "ghost",
    "grass",
    "ground",
    "ice",
    "normal",
    "poison",
    "psychic",
    "rock",
    "steel",
    "water",
)


def _stat_slider(label: str) -> gr.Slider:
    """Create the slider for one stat: range 1-100, step 1, initial value 50."""
    return gr.Slider(minimum=1, maximum=100, value=50, step=1, label=label)


# Assemble the interface inside the Blocks container
# NOTE(review): the nesting below was reconstructed from a flattened source;
# confirm that the Generate button belongs inside the column.
with demo:
    # Header text
    gr.HTML(
        """
        <div style="text-align: center; margin: 0 auto;">
        <p style="margin-bottom: 14px; line-height: 23px;">
        Gradio demo for MonsterForgeFusion models. This was built with LoRA fine-tuning of Stable Diffusion models.
        </p>
        </div>
        """
    )
    with gr.Column():
        # Output gallery
        with gr.Row():
            gallery = gr.Gallery(
                columns=MAX_IMAGES, preview=True, object_fit="scale-down"
            )
        # Type selectors
        with gr.Row():
            type1 = gr.Dropdown(list(_MONSTER_TYPES), value="steel", label="Type 1")
            type2 = gr.Dropdown(list(_MONSTER_TYPES), value="fire", label="Type 2")
        # Stat sliders, two per row
        with gr.Row():
            hp_num = _stat_slider("HP")
            attack_num = _stat_slider("Attack")
        with gr.Row():
            defense_num = _stat_slider("Defense")
            sp_attack_num = _stat_slider("Special Attack")
        with gr.Row():
            sp_defense_num = _stat_slider("Special Defense")
            speed_num = _stat_slider("Speed")
        gen_btn = gr.Button("Generate")

    # Clicking the button sends every control value to generate_images and
    # shows the returned images in the gallery
    gen_btn.click(
        fn=generate_images,
        inputs=[
            type1,
            type2,
            hp_num,
            attack_num,
            defense_num,
            sp_attack_num,
            sp_defense_num,
            speed_num,
        ],
        outputs=gallery,
    )
    # Footer with attribution
    gr.HTML(
        """
        <div class="footer">
        <div style='text-align: center;'>MonsterForgeFusion by <a href='https://michaelriedl.com/' target='_blank'>Michael Riedl</a></div>
        </div>
        """
    )

# Launch the interface
demo.launch()