import gradio as gr
import numpy as np
import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
from PIL import Image

# Number of images produced per click; also used as the gallery column count.
MAX_IMAGES = 1

# Choices offered by both the "Type 1" and "Type 2" dropdowns.
TYPE_CHOICES = (
    "bug",
    "dark",
    "dragon",
    "electric",
    "fairy",
    "fighting",
    "fire",
    "flying",
    "ghost",
    "grass",
    "ground",
    "ice",
    "normal",
    "poison",
    "psychic",
    "rock",
    "steel",
    "water",
)


def generate_images(
    type1: str,
    type2: str,
    hp_num: int,
    attack_num: int,
    defense_num: int,
    sp_attack_num: int,
    sp_defense_num: int,
    speed_num: int,
) -> list:
    """Generates a sprite based on the input stats.

    The arguments are formatted into a single text prompt which is passed to
    the module-level ``pipe``. ``pipe`` is created further down in this file,
    so this function must only be called after the module has finished
    loading.

    Parameters
    ----------
    type1 : str
        Primary type, inserted into the prompt as ``type1``.
    type2 : str
        Secondary type, inserted into the prompt as ``type2``.
    hp_num : int
        HP stat.
    attack_num : int
        Attack stat.
    defense_num : int
        Defense stat.
    sp_attack_num : int
        Special attack stat.
    sp_defense_num : int
        Special defense stat.
    speed_num : int
        Speed stat.

    Returns
    -------
    list
        List of PIL images containing ``MAX_IMAGES`` entries.
    """
    # Initialize the images list
    images_list = []

    # Calculate the base total (sum of the six stats)
    base_total = (
        hp_num
        + attack_num
        + defense_num
        + sp_attack_num
        + sp_defense_num
        + speed_num
    )

    # Create the text prompt
    prompt = (
        f"type1: {type1}, type2: {type2}, base_total: {base_total}, "
        f"hp: {hp_num}, attack: {attack_num}, defense: {defense_num}, "
        f"sp_attack: {sp_attack_num}, sp_defense: {sp_defense_num}, "
        f"speed: {speed_num}"
    )

    # Generate the images, one pipeline call per image
    for _ in range(MAX_IMAGES):
        image = pipe(
            prompt,
            height=288,
            width=288,
            num_inference_steps=10,
            guidance_scale=7.5,
            # NOTE(review): presumably the LoRA attention scale - confirm.
            cross_attention_kwargs={"scale": 1.0},
        ).images[0]
        # Rebuild the image from its array form before handing it to the UI.
        images_list.append(Image.fromarray(np.array(image)))

    return images_list


def _stat_slider(label: str) -> gr.Slider:
    """Create one stat slider (range 1-100, default 50, step 1)."""
    return gr.Slider(
        minimum=1,
        maximum=100,
        value=50,
        step=1,
        label=label,
    )


# Create the demo interface
demo = gr.Blocks()

# Set the models to load
model_base = "stabilityai/stable-diffusion-2-base"
lora_model_path = "michaelriedl/MonsterForgeFusion-sd-2-base"

# Create the pipeline
pipe = StableDiffusionPipeline.from_pretrained(
    model_base,
    torch_dtype=torch.float32,
    use_safetensors=False,
    local_files_only=False,
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.unet.load_attn_procs(lora_model_path)

# Create the interface
# NOTE(review): the nesting of the layout below was reconstructed from a
# whitespace-collapsed source - confirm it matches the intended layout.
with demo:
    gr.HTML(
        """

Gradio demo for MonsterForgeFusion models. This was built with LoRA fine-tuning of Stable Diffusion models.

"""
    )
    with gr.Column():
        with gr.Row():
            gallery = gr.Gallery(
                columns=MAX_IMAGES, preview=True, object_fit="scale-down"
            )
        with gr.Row():
            type1 = gr.Dropdown(
                list(TYPE_CHOICES),
                value="steel",
                label="Type 1",
            )
            type2 = gr.Dropdown(
                list(TYPE_CHOICES),
                value="fire",
                label="Type 2",
            )
        with gr.Row():
            hp_num = _stat_slider("HP")
            attack_num = _stat_slider("Attack")
        with gr.Row():
            defense_num = _stat_slider("Defense")
            sp_attack_num = _stat_slider("Special Attack")
        with gr.Row():
            sp_defense_num = _stat_slider("Special Defense")
            speed_num = _stat_slider("Speed")
        gen_btn = gr.Button("Generate")
        # The input order must match the parameter order of generate_images.
        gen_btn.click(
            fn=generate_images,
            inputs=[
                type1,
                type2,
                hp_num,
                attack_num,
                defense_num,
                sp_attack_num,
                sp_defense_num,
                speed_num,
            ],
            outputs=gallery,
        )
    gr.HTML(
        """ """
    )

# Launch the interface
demo.launch()