"""IllustriousXL random "gacha" Gradio app.

Samples a tag prompt with the Dart v3 causal language model (kept on CPU),
appends a user-editable suffix, and renders the result with a Stable
Diffusion XL pipeline.
"""

# `spaces` stays the first import, as in the original file
# (NOTE(review): Hugging Face ZeroGPU guidance is to import it before torch
# touches CUDA — confirm before reordering).
import spaces
import os
import random
import math

import torch

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = True

import numpy as np
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
    StableDiffusionXLPipeline,
)
from diffusers.schedulers.scheduling_euler_ancestral_discrete import (
    EulerAncestralDiscreteScheduler,
)
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr

# python-dotenv is optional; it is only used to read a local .env file.
# Only a missing package is tolerated — errors inside load_dotenv() surface.
try:
    from dotenv import load_dotenv
except ImportError:
    print("failed to import dotenv (this is not a problem on the production)")
else:
    load_dotenv()

# Required configuration. Validated with explicit raises rather than
# `assert`, which is stripped when Python runs with -O.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("HF_TOKEN environment variable is required")

IMAGE_MODEL_REPO_ID = os.environ.get(
    "IMAGE_MODEL_REPO_ID", "OnomaAIResearch/Illustrious-xl-early-release-v0"
)

DART_V3_REPO_ID = os.environ.get("DART_V3_REPO_ID", None)
if not DART_V3_REPO_ID:
    raise RuntimeError("DART_V3_REPO_ID environment variable is required")

CPU_OFFLOAD = os.environ.get("CPU_OFFLOAD", "False").lower() == "true"

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

# Prompt prefix fed to the Dart model, which continues it with generated tags.
TEMPLATE = (
    "<|bos|>"
    # "<|rating:general|>"
    "{aspect_ratio}"
    "<|length:medium|>"
    # NOTE(review): the subject section is disabled, so the "Subject"
    # dropdown value never reaches the generated prompt (str.format ignores
    # the unused `subject` keyword). Confirm whether this is intentional.
    # "{subject}"
)

QUALITY_TAGS = "original style"
NEGATIVE_PROMPT = "lowres, blurry, watermark, signature, copyright, logo, artistic error, bad anatomy, bad hands, retro, 2000s, 2010s, 2011s, 2012s, 2013s"

# Tags the prompt model is prevented from emitting (via `bad_words_ids`).
BAN_TAGS = [
    "photoshop (medium)",
    "clip studio paint (medium)",
    "2005",  # year tags
    "2006",
    "2007",
    "2008",
    "2009",
    "2010",
    "2011",
    "2012",
    "2013",
    "2014",
    "2015",
    "2016",
    "2017",
    "2018",
    "2019",
    "2020",
]

device = "cuda" if torch.cuda.is_available() else "cpu"

# The prompt model is loaded on CPU; only the image pipeline is moved to `device`.
dart = AutoModelForCausalLM.from_pretrained(
    DART_V3_REPO_ID,
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
    use_cache=True,
    device_map="cpu",
)
dart = dart.eval()
dart = dart.requires_grad_(False)
dart = torch.compile(dart)
# Pass the token explicitly, matching the model load above.
tokenizer = AutoTokenizer.from_pretrained(DART_V3_REPO_ID, token=HF_TOKEN)

# `bad_words_ids` takes a list of token-id sequences; each banned tag is
# looked up as a single token, giving a one-element sequence per tag.
BAN_TOKENS = [tokenizer.convert_tokens_to_ids([tag]) for tag in BAN_TAGS]


def load_pipeline():
    """Build the SDXL pipeline with the fp16-fix VAE and Euler-ancestral scheduler.

    When CPU_OFFLOAD is set, sequential CPU offload is enabled (local runs);
    otherwise the whole pipeline is moved to ``device`` (Spaces).
    """
    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix",
        torch_dtype=torch.float16,
    )
    pipe = StableDiffusionXLPipeline.from_pretrained(
        IMAGE_MODEL_REPO_ID,
        vae=vae,
        torch_dtype=torch.float16,
        use_safetensors=True,
        add_watermarker=False,
        custom_pipeline="lpw_stable_diffusion_xl",
    )
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

    if CPU_OFFLOAD:  # local
        pipe.enable_sequential_cpu_offload(gpu_id=0, device=device)
    else:
        pipe.to(device)  # for spaces

    return pipe


# The image pipeline is only loaded when CUDA is available at startup;
# generate_image reports a UI error when it is missing.
if torch.cuda.is_available():
    pipe = load_pipeline()
    print("Loaded pipeline")
else:
    pipe = None


def get_aspect_ratio(width: int, height: int) -> str:
    """Map an image size to the Dart aspect-ratio control token.

    Buckets on width/height: <= 1/sqrt(3) ultra_tall, <= 8/9 tall,
    < 9/8 square, < sqrt(3) wide, otherwise ultra_wide.
    """
    ar = width / height
    if ar <= 1 / math.sqrt(3):
        return "<|aspect_ratio:ultra_tall|>"
    elif ar <= 8 / 9:
        return "<|aspect_ratio:tall|>"
    elif ar < 9 / 8:
        return "<|aspect_ratio:square|>"
    elif ar < math.sqrt(3):
        return "<|aspect_ratio:wide|>"
    else:
        return "<|aspect_ratio:ultra_wide|>"


@torch.inference_mode()
def generate_prompt(subject: str, aspect_ratio: str) -> str:
    """Sample a comma-separated tag prompt from the Dart model.

    Args:
        subject: Value substituted for ``{subject}`` in TEMPLATE (currently
            unused because that section of the template is commented out).
        aspect_ratio: Control token from ``get_aspect_ratio``.

    Returns:
        The newly generated tags joined with ", ", special tokens removed.
    """
    input_ids = tokenizer.encode_plus(
        TEMPLATE.format(aspect_ratio=aspect_ratio, subject=subject),
        return_tensors="pt",
    ).input_ids
    print("input_ids:", input_ids)

    output_ids = dart.generate(
        input_ids,
        max_new_tokens=256,
        do_sample=True,
        temperature=1.0,
        top_p=1.0,
        top_k=100,
        num_beams=1,
        bad_words_ids=BAN_TOKENS,
    )[0]

    # input_ids is batched, shape (1, prompt_len), and generate() returns
    # the prompt followed by new tokens. Slice by the sequence length, not
    # len(input_ids), which is the batch size (1) and would only drop <|bos|>.
    generated = output_ids[input_ids.shape[-1] :]
    # batch_decode over a 1-D tensor decodes each token id separately,
    # so each list entry is one tag (or "" for a skipped special token).
    decoded = ", ".join(
        token
        for token in tokenizer.batch_decode(generated, skip_special_tokens=True)
        if token.strip()
    )
    print("decoded:", decoded)
    return decoded


def format_prompt(prompt: str, prompt_suffix: str) -> str:
    """Join the generated prompt and the user suffix with ", ".

    Empty or whitespace-only parts are skipped, so an empty suffix does not
    leave a dangling separator.
    """
    return ", ".join(part for part in (prompt, prompt_suffix) if part and part.strip())


@spaces.GPU(duration=20)
@torch.inference_mode()
def generate_image(
    prompt: str,
    negative_prompt: str,
    generator,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
):
    """Render a single image with the SDXL pipeline and return it.

    Raises:
        gr.Error: if the pipeline was not loaded (CUDA unavailable at startup).
    """
    if pipe is None:
        raise gr.Error("Image pipeline is not loaded: CUDA is not available.")
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]
    return image


def _resolve_seed(seed, randomize_seed: bool) -> int:
    """Return a fresh random seed when requested, else ``seed`` coerced to int."""
    if randomize_seed:
        return random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; Generator.manual_seed needs an int.
    return int(seed)


def on_generate(
    subject: str,
    suffix: str,
    negative_prompt: str,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Handle "Pull gacha": sample a new prompt, then render it.

    Returns:
        (image, full prompt, seed actually used) for the result, prompt and
        seed components.
    """
    seed = _resolve_seed(seed, randomize_seed)
    generator = torch.Generator().manual_seed(seed)

    ar = get_aspect_ratio(width, height)
    print("ar:", ar)
    prompt = generate_prompt(subject, ar)
    prompt = format_prompt(prompt, suffix)
    print(prompt)

    image = generate_image(
        prompt,
        negative_prompt,
        generator,
        width,
        height,
        guidance_scale,
        num_inference_steps,
    )
    return image, prompt, seed


def on_retry(
    prompt: str,
    negative_prompt: str,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Handle "Retry": re-render the previously generated prompt.

    Returns:
        (image, prompt, seed actually used), matching ``on_generate``.
    """
    seed = _resolve_seed(seed, randomize_seed)
    generator = torch.Generator().manual_seed(seed)
    print(prompt)

    image = generate_image(
        prompt,
        negative_prompt,
        generator,
        width,
        height,
        guidance_scale,
        num_inference_steps,
    )
    return image, prompt, seed


css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
# IllustriousXL Random Gacha

Image model: [IllustriousXL v0.1](https://huggingface.co/OnomaAIResearch/Illustrious-xl-early-release-v0)
""")

        with gr.Row():
            subject_radio = gr.Dropdown(
                label="Subject",
                choices=["1girl", "2girls", "1boy", "no humans"],
                value="1girl",
            )
            run_button = gr.Button("Pull gacha", variant="primary", scale=0)

        result = gr.Image(label="Gacha result", show_label=False)

        with gr.Accordion("Generation details", open=False):
            with gr.Row():
                prompt_txt = gr.Textbox(label="Generated prompt", interactive=False)
                retry_button = gr.Button("🔄 Retry", scale=0)

        with gr.Accordion("Advanced Settings", open=False):
            prompt_suffix = gr.Text(
                label="Prompt suffix",
                visible=True,
                value=QUALITY_TAGS,
            )
            negative_prompt = gr.Text(
                label="Negative prompt",
                placeholder="Enter a negative prompt",
                visible=True,
                value=NEGATIVE_PROMPT,
            )

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=640,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=960,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=640,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=1344,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.5,
                    value=6.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=20,
                    maximum=40,
                    step=1,
                    value=28,
                )

    gr.on(
        triggers=[run_button.click],
        fn=on_generate,
        inputs=[
            subject_radio,
            prompt_suffix,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, prompt_txt, seed],
    )
    gr.on(
        triggers=[retry_button.click],
        fn=on_retry,
        inputs=[
            prompt_txt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, prompt_txt, seed],
    )

demo.queue().launch()