# Animagine-XL / app.py
# Committed by justmalhar: "Duplicate from Linaqruf/Animagine-XL" (commit a4e562b)
#!/usr/bin/env python
from __future__ import annotations
import os
import random
import gradio as gr
import numpy as np
import PIL.Image
import torch
from diffusers.models import AutoencoderKL
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
# Markdown header rendered at the top of the page; a warning is appended
# when no CUDA device is present.
DESCRIPTION = '# Animagine XL' + (
    ''
    if torch.cuda.is_available()
    else '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'
)

# Inclusive upper bound for seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Example outputs are pre-computed only with a GPU AND CACHE_EXAMPLES=1.
CACHE_EXAMPLES = (torch.cuda.is_available()
                  and os.environ.get('CACHE_EXAMPLES') == '1')

# Largest width/height selectable in the UI; overridable via environment.
MAX_IMAGE_SIZE = int(os.environ.get('MAX_IMAGE_SIZE', '2048'))

# Opt-in switches: each is enabled only when its variable is exactly '1'.
USE_TORCH_COMPILE = os.environ.get('USE_TORCH_COMPILE') == '1'
ENABLE_CPU_OFFLOAD = os.environ.get('ENABLE_CPU_OFFLOAD') == '1'

# Model id handed to StableDiffusionXLPipeline.from_pretrained.
MODEL = "Linaqruf/animagine-xl"
# Target device for the pipeline. The model itself is only loaded when CUDA is
# available; on CPU `pipe` stays None and the UI shows a warning instead.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
pipe = None
if torch.cuda.is_available():
    # Half-precision safetensors weights from the fp16 variant of MODEL.
    pipe = StableDiffusionXLPipeline.from_pretrained(
        MODEL,
        variant='fp16',
        use_safetensors=True,
        torch_dtype=torch.float16,
    )
    # Replace the default scheduler with Euler Ancestral, reusing its config.
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipe.scheduler.config)
    if not ENABLE_CPU_OFFLOAD:
        pipe.to(device)
        # Optional UNet compilation, only on the fully-on-GPU path.
        if USE_TORCH_COMPILE:
            pipe.unet = torch.compile(
                pipe.unet, mode='reduce-overhead', fullgraph=True)
    else:
        pipe.enable_model_cpu_offload()
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Choose the seed for the next generation.

    Returns a random integer in ``[0, MAX_SEED]`` (both ends inclusive) when
    *randomize_seed* is true; otherwise returns *seed* unchanged.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
def generate(prompt: str,
             negative_prompt: str = '',
             prompt_2: str = '',
             negative_prompt_2: str = '',
             use_prompt_2: bool = False,
             seed: int = 0,
             width: int = 1024,
             height: int = 1024,
             target_width: int = 1024,
             target_height: int = 1024,
             original_width: int = 4096,
             original_height: int = 4096,
             guidance_scale_base: float = 12.0,
             num_inference_steps_base: int = 50) -> PIL.Image.Image:
    """Run the SDXL pipeline once and return the first generated image.

    Args:
        prompt: Main text prompt.
        negative_prompt: Negative prompt; an empty string means "none".
        prompt_2: Secondary prompt, forwarded only when ``use_prompt_2``.
        negative_prompt_2: Secondary negative prompt, forwarded only when
            ``use_prompt_2``; an empty string means "none".
        use_prompt_2: Whether the secondary prompts are used at all.
        seed: Seed for the (CPU) ``torch.Generator``.
        width, height: Output image size in pixels.
        target_width, target_height: Passed to the pipeline as ``target_size``.
        original_width, original_height: Passed as ``original_size``.
        guidance_scale_base: Classifier-free guidance scale.
        num_inference_steps_base: Number of denoising steps.

    Returns:
        The first image of the pipeline output (``output_type='pil'``).

    Raises:
        gr.Error: if the pipeline was not loaded because no CUDA device is
            available.
    """
    if pipe is None:
        # `pipe` is only created when CUDA is available. Without this guard the
        # call below fails with an opaque "'NoneType' object is not callable".
        raise gr.Error(
            'This demo does not work on CPU: no CUDA device is available, '
            'so the model was not loaded.')

    # Values coming from Gradio sliders may arrive as floats; manual_seed and
    # the pipeline's size/step arguments need ints.
    generator = torch.Generator().manual_seed(int(seed))

    # Empty strings mean "no negative prompt"; the pipeline is given None.
    negative_prompt = negative_prompt or None  # type: ignore
    if use_prompt_2:
        negative_prompt_2 = negative_prompt_2 or None  # type: ignore
    else:
        # Secondary prompts are ignored entirely unless explicitly enabled.
        prompt_2 = None  # type: ignore
        negative_prompt_2 = None  # type: ignore

    return pipe(prompt=prompt,
                negative_prompt=negative_prompt,
                prompt_2=prompt_2,
                negative_prompt_2=negative_prompt_2,
                width=int(width),
                height=int(height),
                target_size=(int(target_width), int(target_height)),
                original_size=(int(original_width), int(original_height)),
                guidance_scale=guidance_scale_base,
                num_inference_steps=int(num_inference_steps_base),
                generator=generator,
                output_type='pil').images[0]
# Example prompts listed under the result panel (comma-separated tag style).
# Only the main prompt is filled from these; see the gr.Examples call.
examples = [
    ('face focus, cute, masterpiece, best quality, 1girl, green hair, '
     'sweater, looking at viewer, upper body, beanie, outdoors, night, '
     'turtleneck'),
    ('face focus, bishounen, masterpiece, best quality, 1boy, green hair, '
     'sweater, looking at viewer, upper body, beanie, outdoors, night, '
     'turtleneck'),
]
# choices = [
# "Vertical (9:16)",
# "Portrait (4:5)",
# "Square (1:1)",
# "Photo (4:3)",
# "Landscape (3:2)",
# "Widescreen (16:9)",
# "Cinematic (21:9)",
# ]
# choice_to_size = {
# "Vertical (9:16)": (768, 1344),
# "Portrait (4:5)": (912, 1144),
# "Square (1:1)": (1024, 1024),
# "Photo (4:3)": (1184, 888),
# "Landscape (3:2)": (1256, 832),
# "Widescreen (16:9)": (1368, 768),
# "Cinematic (21:9)": (1568, 672),
# }
# Gradio UI: prompt inputs and generation settings on the left, the generated
# image and example prompts on the right.
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value='Duplicate Space for private use',
                       elem_id='duplicate-button',
                       visible=os.getenv('SHOW_DUPLICATE_BUTTON') == '1')
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Text(
                label='Prompt',
                max_lines=1,
                placeholder='Enter your prompt',
            )
            negative_prompt = gr.Text(
                label='Negative Prompt',
                max_lines=1,
                placeholder='Enter a negative prompt',
                value='lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry',
            )
            # The secondary prompt boxes start hidden; the checkbox toggles them.
            use_prompt_2 = gr.Checkbox(
                label='Use prompt 2',
                value=False
            )
            prompt_2 = gr.Text(
                label='Prompt 2',
                max_lines=1,
                placeholder='Enter your prompt',
                visible=False,
            )
            negative_prompt_2 = gr.Text(
                label='Negative prompt 2',
                max_lines=1,
                placeholder='Enter a negative prompt',
                visible=False,
            )
            # with gr.Row():
            #     aspect_ratio = gr.Dropdown(choices=choices, label="Aspect Ratio Preset", value=choices[2])
            with gr.Row():
                width = gr.Slider(
                    label='Width',
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label='Height',
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            with gr.Accordion(label='Advanced Config', open=False):
                # Forwarded by generate() to the pipeline as original_size /
                # target_size.
                with gr.Accordion(label='Conditioning Resolution', open=False):
                    with gr.Row():
                        original_width = gr.Slider(
                            label='Original Width',
                            minimum=1024,
                            maximum=4096,
                            step=32,
                            value=4096,
                        )
                        original_height = gr.Slider(
                            label='Original Height',
                            minimum=1024,
                            maximum=4096,
                            step=32,
                            value=4096,
                        )
                    with gr.Row():
                        target_width = gr.Slider(
                            label='Target Width',
                            minimum=1024,
                            maximum=4096,
                            step=32,
                            value=1024,
                        )
                        target_height = gr.Slider(
                            label='Target Height',
                            minimum=1024,
                            maximum=4096,
                            step=32,
                            value=1024,
                        )
                seed = gr.Slider(label='Seed',
                                 minimum=0,
                                 maximum=MAX_SEED,
                                 step=1,
                                 value=0)
                randomize_seed = gr.Checkbox(label='Randomize seed', value=True)
                with gr.Row():
                    guidance_scale_base = gr.Slider(
                        label='Guidance scale',
                        minimum=1,
                        maximum=20,
                        step=0.1,
                        value=12.0)
                    num_inference_steps_base = gr.Slider(
                        label='Number of inference steps',
                        minimum=10,
                        maximum=100,
                        step=1,
                        value=50)
        with gr.Column(scale=2):
            # NOTE(review): a nested gr.Blocks used purely as a container;
            # kept as-is to preserve the existing layout.
            with gr.Blocks():
                run_button = gr.Button('Generate')
                result = gr.Image(label='Result', show_label=False)
            # Examples fill only the main prompt; generate() falls back to its
            # own defaults for every other argument.
            gr.Examples(examples=examples,
                        inputs=prompt,
                        outputs=result,
                        fn=generate,
                        cache_examples=CACHE_EXAMPLES)

    # Show or hide both secondary prompt boxes together with the checkbox.
    use_prompt_2.change(
        fn=lambda enabled: (gr.update(visible=enabled),
                            gr.update(visible=enabled)),
        inputs=use_prompt_2,
        outputs=[prompt_2, negative_prompt_2],
        queue=False,
        api_name=False,
    )

    # Order must match generate()'s positional parameters.
    inputs = [
        prompt,
        negative_prompt,
        prompt_2,
        negative_prompt_2,
        use_prompt_2,
        seed,
        width,
        height,
        target_width,
        target_height,
        original_width,
        original_height,
        guidance_scale_base,
        num_inference_steps_base,
    ]

    # Every trigger first (optionally) re-rolls the seed, then generates.
    # Only submitting the main prompt is exposed as the public 'run' API.
    generation_triggers = [
        (prompt.submit, 'run'),
        (negative_prompt.submit, False),
        (prompt_2.submit, False),
        (negative_prompt_2.submit, False),
        (run_button.click, False),
    ]
    for trigger, generate_api_name in generation_triggers:
        trigger(
            fn=randomize_seed_fn,
            inputs=[seed, randomize_seed],
            outputs=seed,
            queue=False,
            api_name=False,
        ).then(
            fn=generate,
            inputs=inputs,
            outputs=result,
            api_name=generate_api_name,
        )

demo.queue(max_size=20).launch()