import spaces
import os
import random
import math
import torch
# Global backend flags, set at import time:
# - deterministic cuDNN kernels with autotuning disabled, which favors
#   run-to-run reproducibility for a given seed;
# - TF32 allowed for CUDA matmuls (faster on GPUs that support it, at slightly
#   reduced matmul precision).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = True
import numpy as np
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
    StableDiffusionXLPipeline,
)
from diffusers.schedulers.scheduling_euler_ancestral_discrete import (
    EulerAncestralDiscreteScheduler,
)
# NOTE(review): AttnProcessor2_0 is not referenced anywhere in the code visible
# here — confirm whether this import is still needed.
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
# Optionally load a local .env file (development convenience). python-dotenv is
# not required in production, so only a missing package is tolerated; any other
# error raised while loading the file is no longer silently swallowed.
try:
    from dotenv import load_dotenv
except ImportError:
    print("failed to import dotenv (this is not a problem on the production)")
else:
    load_dotenv()
# Required configuration, read from the environment (which load_dotenv may have
# populated from a .env file). Missing values fail fast with an explicit error
# rather than `assert`, which is stripped when Python runs with -O.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("HF_TOKEN environment variable is not set")
IMAGE_MODEL_REPO_ID = os.environ.get(
    "IMAGE_MODEL_REPO_ID", "OnomaAIResearch/Illustrious-xl-early-release-v0"
)
DART_V3_REPO_ID = os.environ.get("DART_V3_REPO_ID")
if not DART_V3_REPO_ID:
    raise RuntimeError("DART_V3_REPO_ID environment variable is not set")
# Any casing of "true" enables sequential CPU offload in load_pipeline().
CPU_OFFLOAD = os.environ.get("CPU_OFFLOAD", "False").lower() == "true"
MAX_SEED = np.iinfo(np.int32).max  # largest seed the UI can produce (int32 max)
MAX_IMAGE_SIZE = 2048  # upper bound of the width/height sliders, in pixels
# Prompt template for the Dart tag generator, filled in by generate_prompt():
# `{aspect_ratio}` receives a token returned by get_aspect_ratio() and
# `{subject}` the subject tag picked in the UI. The adjacent string literals are
# concatenated into a single string.
# NOTE(review): the empty "" literals and the bare `#` comment lines contribute
# nothing as written; they look like placeholders whose content may have been
# lost — confirm the intended sections against the Dart v3 prompt format.
TEMPLATE = (
    "<|bos|>"
    #
    "<|rating:general|>"
    "{aspect_ratio}"
    "<|length:medium|>"
    #
    ""
    #
    ""
    #
    "{subject}"
)
# Default prompt suffix appended to every generated prompt (editable in the UI).
QUALITY_TAGS = "original style"
# Default negative prompt for the image pipeline (editable in the UI).
NEGATIVE_PROMPT = "lowres, blurry, watermark, signature, copyright, logo, artistic error, bad anatomy, bad hands, retro, 2000s, 2010s, 2011s, 2012s, 2013s"
# Tags the Dart generator is forbidden to emit: two software "(medium)" tags,
# followed by every year tag from 2005 through 2020 inclusive.
BAN_TAGS = [
    "photoshop (medium)",
    "clip studio paint (medium)",
    *(str(year) for year in range(2005, 2021)),
]
device = "cuda" if torch.cuda.is_available() else "cpu"

# Dart v3 prompt/tag generator. It is kept on the CPU in bfloat16 and frozen
# (eval mode, no gradients) because it is only used for sampling.
dart = AutoModelForCausalLM.from_pretrained(
    DART_V3_REPO_ID,
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
    use_cache=True,
    device_map="cpu",
)
dart = dart.eval()
dart = dart.requires_grad_(False)
# NOTE(review): `dart.generate` is resolved on the wrapped module, so generation
# may not go through the compiled forward — confirm torch.compile is effective.
dart = torch.compile(dart)
# Authenticate the tokenizer the same way as the model so a private or gated
# DART_V3_REPO_ID resolves for both.
tokenizer = AutoTokenizer.from_pretrained(DART_V3_REPO_ID, token=HF_TOKEN)

# `bad_words_ids` for dart.generate(): one token-id sequence per banned tag.
# convert_tokens_to_ids can yield None for an unknown tag when the tokenizer has
# no unk token, and generate() rejects such entries, so those are skipped.
BAN_TOKENS = [
    ids
    for ids in (tokenizer.convert_tokens_to_ids([tag]) for tag in BAN_TAGS)
    if None not in ids
]
def load_pipeline():
    """Build and place the SDXL text-to-image pipeline.

    Loads the ``madebyollin/sdxl-vae-fp16-fix`` VAE and the
    ``IMAGE_MODEL_REPO_ID`` checkpoint in float16 through the
    ``lpw_stable_diffusion_xl`` custom pipeline. It then replaces the scheduler
    with an Euler-ancestral one built from the checkpoint's scheduler config.
    Finally it puts the pipeline on ``device``: with sequential CPU offload when
    ``CPU_OFFLOAD`` is set (local runs), otherwise by moving it wholesale
    (Spaces).

    Returns:
        The ready-to-call pipeline.
    """
    fp16_vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix",
        torch_dtype=torch.float16,
    )
    sdxl = StableDiffusionXLPipeline.from_pretrained(
        IMAGE_MODEL_REPO_ID,
        custom_pipeline="lpw_stable_diffusion_xl",
        torch_dtype=torch.float16,
        vae=fp16_vae,
        add_watermarker=False,
        use_safetensors=True,
    )
    base_scheduler_config = sdxl.scheduler.config
    sdxl.scheduler = EulerAncestralDiscreteScheduler.from_config(base_scheduler_config)

    if not CPU_OFFLOAD:
        sdxl.to(device)  # for spaces
    else:
        sdxl.enable_sequential_cpu_offload(gpu_id=0, device=device)  # local
    return sdxl
# The SDXL pipeline is only built when a CUDA device is visible at import time;
# otherwise `pipe` stays None.
pipe = None
if torch.cuda.is_available():
    pipe = load_pipeline()
    print("Loaded pipeline")
def get_aspect_ratio(width: int, height: int) -> str:
    """Map an image size to the Dart aspect-ratio token.

    The ratio ``width / height`` is bucketed as follows:

    - ultra_tall: up to 1/sqrt(3), inclusive
    - tall: up to 8/9, inclusive
    - square: strictly below 9/8
    - wide: strictly below sqrt(3)
    - ultra_wide: everything above that

    The buckets are symmetric around 1.
    """
    ratio = width / height
    root3 = math.sqrt(3)
    # (upper bound, whether the bound itself belongs to this bucket, token)
    buckets = (
        (1 / root3, True, "<|aspect_ratio:ultra_tall|>"),
        (8 / 9, True, "<|aspect_ratio:tall|>"),
        (9 / 8, False, "<|aspect_ratio:square|>"),
        (root3, False, "<|aspect_ratio:wide|>"),
    )
    for bound, inclusive, token in buckets:
        if ratio < bound or (inclusive and ratio == bound):
            return token
    return "<|aspect_ratio:ultra_wide|>"
@torch.inference_mode
def generate_prompt(subject: str, aspect_ratio: str):
    """Sample a comma-separated tag prompt from the Dart model.

    Args:
        subject: Subject tag substituted into ``TEMPLATE`` (e.g. "1girl").
        aspect_ratio: Aspect-ratio token, as returned by ``get_aspect_ratio``.

    Returns:
        The output tokens decoded one by one (special tokens skipped, blank
        pieces dropped) and joined with ", ".
    """
    # `tokenizer(...)` replaces the deprecated `encode_plus`; the result is the
    # same for a single string.
    inputs = tokenizer(
        TEMPLATE.format(aspect_ratio=aspect_ratio, subject=subject),
        return_tensors="pt",
    )
    input_ids = inputs["input_ids"]
    print("input_ids:", input_ids)
    output_ids = dart.generate(
        input_ids,
        # An explicit mask avoids generate() having to infer one (and warning).
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=256,
        do_sample=True,
        temperature=1.0,
        top_p=1.0,
        top_k=100,
        num_beams=1,
        bad_words_ids=BAN_TOKENS,
    )[0]
    # NOTE(review): `input_ids` is 2-D, so `len(input_ids)` is the batch size
    # (1), not the prompt length. Only the first token of the sequence is
    # dropped; the rest of the prompt (presumably including the subject tag)
    # stays in `generated`, and special tokens are removed when decoding below.
    # If only newly sampled tags are wanted, slice with input_ids.shape[-1] and
    # prepend the subject explicitly.
    generated = output_ids[len(input_ids) :]
    # Decode each token separately so that every tag becomes one list item.
    pieces = tokenizer.batch_decode(generated, skip_special_tokens=True)
    decoded = ", ".join(piece for piece in pieces if piece.strip() != "")
    print("decoded:", decoded)
    return decoded
def format_prompt(prompt: str, prompt_suffix: str):
    """Append *prompt_suffix* to *prompt* as one comma-separated tag list.

    Empty or whitespace-only parts are skipped. An empty suffix (or an empty
    generated prompt) therefore no longer leaves a dangling ", " in the final
    prompt. Non-empty parts are used verbatim.

    Returns:
        The non-empty parts joined with ", " (an empty string if both are empty).
    """
    parts = [part for part in (prompt, prompt_suffix) if part and part.strip()]
    return ", ".join(parts)
@spaces.GPU(duration=20)
@torch.inference_mode
def generate_image(
    prompt: str,
    negative_prompt: str,
    generator,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
):
    """Render one image with the SDXL pipeline.

    Args:
        prompt: Positive prompt.
        negative_prompt: Negative prompt.
        generator: torch.Generator that seeds the sampling.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps.

    Returns:
        The first image of the pipeline output.

    Raises:
        gr.Error: If the pipeline was not loaded because no CUDA device was
            available when the module was imported.
    """
    if pipe is None:
        # Without this check the call below fails with the opaque
        # "'NoneType' object is not callable".
        raise gr.Error(
            "The image pipeline is not loaded (no CUDA device was available at startup)."
        )
    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    )
    return output.images[0]
def on_generate(
    subject: str,
    suffix: str,
    negative_prompt: str,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Handle "Pull gacha": sample a prompt with Dart, then render it.

    Args:
        subject: Subject tag chosen in the dropdown.
        suffix: Text appended to the generated prompt.
        negative_prompt: Negative prompt for the image pipeline.
        seed: Seed from the slider; ignored when ``randomize_seed`` is truthy.
        randomize_seed: When truthy, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels. Together with ``width`` it also picks
            the Dart aspect-ratio token.
        guidance_scale: Guidance scale for the image pipeline.
        num_inference_steps: Number of denoising steps.
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        ``(image, prompt, seed)``. The seed actually used is returned so the
        seed slider can display it.
    """
    # Gradio sliders pass numbers as floats; torch.Generator.manual_seed rejects
    # floats, and the pipeline expects integer sizes and step counts.
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    generator = torch.Generator().manual_seed(seed)
    ar = get_aspect_ratio(width, height)
    print("ar:", ar)
    prompt = format_prompt(generate_prompt(subject, ar), suffix)
    print(prompt)
    image = generate_image(
        prompt,
        negative_prompt,
        generator,
        width,
        height,
        guidance_scale,
        num_inference_steps,
    )
    return image, prompt, seed
def on_retry(
    prompt: str,
    negative_prompt: str,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Handle "Retry": re-render the shown prompt without sampling a new one.

    Args:
        prompt: Previously generated prompt (read from the prompt textbox).
        negative_prompt: Negative prompt for the image pipeline.
        seed: Seed from the slider; ignored when ``randomize_seed`` is truthy.
        randomize_seed: When truthy, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale for the image pipeline.
        num_inference_steps: Number of denoising steps.
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        ``(image, prompt, seed)``, with the seed actually used.

    Raises:
        gr.Error: If there is no prompt to retry yet.
    """
    if not (prompt and prompt.strip()):
        # Nothing has been generated yet; don't spend GPU time on an empty prompt.
        raise gr.Error("No prompt to retry yet. Pull the gacha first.")
    # Gradio sliders pass numbers as floats; the generator and pipeline need ints.
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    generator = torch.Generator().manual_seed(seed)
    print(prompt)
    image = generate_image(
        prompt,
        negative_prompt,
        generator,
        width,
        height,
        guidance_scale,
        num_inference_steps,
    )
    return image, prompt, seed
# Stylesheet passed to gr.Blocks. It centers the element with id
# "col-container" horizontally and caps its width at 640px.
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
# Gradio UI, all inside one centered column:
# - a subject picker and the "Pull gacha" button;
# - the result image;
# - a collapsed panel showing the generated prompt, with a Retry button;
# - collapsed advanced sampling settings.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
# IllustriousXL Random Gacha
Image model: [IllustriousXL v0.1](https://huggingface.co/OnomaAIResearch/Illustrious-xl-early-release-v0)
""")
        with gr.Row():
            # Subject tag that on_generate passes into the Dart prompt template.
            subject_radio = gr.Dropdown(
                label="Subject",
                choices=["1girl", "2girls", "1boy", "no humans"],
                value="1girl",
            )
            run_button = gr.Button("Pull gacha", variant="primary", scale=0)
        result = gr.Image(label="Gacha result", show_label=False)
        with gr.Accordion("Generation details", open=False):
            with gr.Row():
                # Read-only: written by both handlers, and read back by Retry.
                prompt_txt = gr.Textbox(label="Generated prompt", interactive=False)
                retry_button = gr.Button("🔄 Retry", scale=0)
        with gr.Accordion("Advanced Settings", open=False):
            prompt_suffix = gr.Text(
                label="Prompt suffix",
                visible=True,
                value=QUALITY_TAGS,
            )
            negative_prompt = gr.Text(
                label="Negative prompt",
                placeholder="Enter a negative prompt",
                visible=True,
                value=NEGATIVE_PROMPT,
            )
            # Also an output of both handlers, so it shows the seed actually used.
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=640,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=960,  # default portrait size is 960x1344
                )
                height = gr.Slider(
                    label="Height",
                    minimum=640,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=1344,  # default portrait size is 960x1344
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.5,
                    value=6.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=20,
                    maximum=40,
                    step=1,
                    value=28,
                )
    # "Pull gacha": sample a fresh prompt with Dart, then render it.
    gr.on(
        triggers=[run_button.click],
        fn=on_generate,
        inputs=[
            subject_radio,
            prompt_suffix,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, prompt_txt, seed],
    )
    # "Retry": re-render the prompt currently in prompt_txt without re-running
    # Dart (a new seed is drawn when "Randomize seed" is checked).
    gr.on(
        triggers=[retry_button.click],
        fn=on_retry,
        inputs=[
            prompt_txt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, prompt_txt, seed],
    )
# Enable the request queue and start the Gradio server.
demo.queue().launch()