# app.py — Japanese Stable Diffusion XL Gradio demo
# (last update by mkshing, commit 5df33b1)
import os
import random
from huggingface_hub import login
import gradio as gr
import torch
from diffusers import (
DiffusionPipeline,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
UniPCMultistepScheduler,
)
from diffusers.utils import make_image_grid
# The model repo is gated, so a Hugging Face access token is mandatory.
# Reading with [] (not .get) fails fast at startup if the env var is unset.
ACCESS_TOKEN = os.environ["ACCESS_TOKEN"]
login(token=ACCESS_TOKEN)

# Load the JSDXL pipeline once at import time (module-level side effect:
# downloads weights on first run). trust_remote_code is required because the
# repo ships its own pipeline class; fp16 halves VRAM use on GPU.
# NOTE(review): `use_auth_token` is deprecated in recent diffusers releases in
# favor of `token=` — confirm against the pinned diffusers version.
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/japanese-stable-diffusion-xl",
    trust_remote_code=True,
    torch_dtype=torch.float16,
    use_auth_token=ACCESS_TOKEN
)
# Prefer GPU when available; fp16 on CPU is slow but still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline.to(device)
# UI scheduler name -> diffusers scheduler class. Used by set_noise_scheduler()
# to hot-swap the pipeline's sampler at request time.
SCHEDULER_MAPPING = {
    "ddim": DDIMScheduler,
    "plms": PNDMScheduler,
    "lms": LMSDiscreteScheduler,
    "euler": EulerDiscreteScheduler,
    "euler_ancestral": EulerAncestralDiscreteScheduler,
    "dpm_solver++": DPMSolverMultistepScheduler,
    "unipc": UniPCMultistepScheduler,
}
# Name of the scheduler currently installed on the pipeline; infer() compares
# against this so the scheduler is only rebuilt when the selection changes.
noise_scheduler_name = "euler"
# Supported (width, height) pairs keyed by their aspect ratio (width/height)
# rendered as a string — these are the resolution buckets SDXL was trained on,
# each keeping the pixel count close to 1024x1024.
SD_XL_BASE_RATIOS = {
    "0.5": (704, 1408),
    "0.52": (704, 1344),
    "0.57": (768, 1344),
    "0.6": (768, 1280),
    "0.68": (832, 1216),
    "0.72": (832, 1152),
    "0.78": (896, 1152),
    "0.82": (896, 1088),
    "0.88": (960, 1088),
    "0.94": (960, 1024),
    "1.0": (1024, 1024),
    "1.07": (1024, 960),
    "1.13": (1088, 960),
    "1.21": (1088, 896),
    "1.29": (1152, 896),
    "1.38": (1152, 832),
    "1.46": (1216, 832),
    "1.67": (1280, 768),
    "1.75": (1344, 768),
    "1.91": (1344, 704),
    "2.0": (1408, 704),
    "2.09": (1472, 704),
    "2.4": (1536, 640),
    "2.5": (1600, 640),
    "2.89": (1664, 576),
    "3.0": (1728, 576),
    # "small": (512, 512),  # for testing
}
def set_noise_scheduler(name) -> None:
    """Install the scheduler registered under *name* on the global pipeline.

    The replacement is built via ``from_config`` from the current scheduler's
    config, so existing sampling settings carry over to the new scheduler.
    Raises KeyError if *name* is not a key of SCHEDULER_MAPPING.
    """
    scheduler_cls = SCHEDULER_MAPPING[name]
    new_scheduler = scheduler_cls.from_config(pipeline.scheduler.config)
    pipeline.scheduler = new_scheduler
def infer(
    prompt,
    scale=7.5,
    steps=40,
    ratio="1.0",
    n_samples=1,
    seed="random",
    negative_prompt="",
    scheduler_name="euler",
):
    """Generate images for *prompt* and return ``(images, sampling_info)``.

    Args:
        prompt: Text prompt (the model is Japanese-native).
        scale: Classifier-free guidance scale; coerced to float.
        steps: Number of denoising steps; coerced to int.
        ratio: Key into ``SD_XL_BASE_RATIOS`` selecting (width, height).
        n_samples: Images per prompt; coerced to int.
        seed: Integer-like value, or the string "random" to draw a fresh seed.
        negative_prompt: Empty string disables the negative prompt.
        scheduler_name: Key into ``SCHEDULER_MAPPING``.

    Returns:
        Tuple of (list of generated images, dict echoing the seed used so a
        result can be reproduced).
    """
    global noise_scheduler_name
    # Only rebuild the scheduler when the user's selection actually changed.
    if noise_scheduler_name != scheduler_name:
        set_noise_scheduler(scheduler_name)
        noise_scheduler_name = scheduler_name
    # Gradio components may deliver strings/floats; normalize types up front.
    scale = float(scale)
    steps = int(steps)
    W, H = SD_XL_BASE_RATIOS[ratio]
    n_samples = int(n_samples)
    if seed == "random":
        # Fix: randint is inclusive on both ends, so the previous upper bound
        # of 2**32 was one past the top of the 32-bit seed space.
        seed = random.randint(0, 2**32 - 1)
    seed = int(seed)
    images = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt if len(negative_prompt) > 0 else None,
        guidance_scale=scale,
        # Seed a fresh generator per call so identical seeds reproduce results.
        generator=torch.Generator(device=device).manual_seed(seed),
        num_images_per_prompt=n_samples,
        num_inference_steps=steps,
        height=H,
        width=W,
    ).images
    return (
        images,
        {
            # Echo the seed actually used (relevant when it was "random").
            "seed": seed,
        },
    )
# Example prompts shown in the UI (Japanese — the model is Japanese-native).
# NOTE(review): each example row supplies only the prompt, while the
# gr.Examples call later is wired with all eight input components — confirm
# the pinned Gradio version tolerates this length mismatch (otherwise pass
# inputs=[prompt] there).
examples = [
    ["柴犬、カラフルアート"],
    ["満面の笑みのお爺さん、スケッチ"],
    ["星空の中の1匹の鹿、アート"],
    ["ジャングルに立っている日本男性のポートレート"],
    ["茶色の猫のイラスト、アニメ"],
    ["舞妓さんのポートレート、デジタルアート"],
]
# ---- Gradio UI -----------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Japanese Stable Diffusion XL Demo")
    gr.Markdown(
        """[Japanese Stable Diffusion XL](https://huggingface.co/stabilityai/japanese-stable-diffusion-xl) is a Japanese-version SDXL by [Stability AI](https://ja.stability.ai/).
- Blog: https://ja.stability.ai/blog/japanese-stable-diffusion-xl
- Twitter: https://twitter.com/StabilityAI_JP
- Discord: https://discord.com/invite/StableJP"""
    )
    gr.Markdown(
        "### You can also try JSDXL on Google Colab [here](https://colab.research.google.com/github/Stability-AI/model-demo-notebooks/blob/main/japanese_stable_diffusion_xl.ipynb). "
    )
    with gr.Group():
        with gr.Row():
            # Single-line prompt box with the Run button on the same row.
            prompt = gr.Textbox(
                label="prompt",
                max_lines=1,
                show_label=False,
                placeholder="Enter your prompt",
                container=False,
            )
            btn = gr.Button("Run", scale=0)
    gallery = gr.Gallery(label="Generated images", show_label=False)
    with gr.Accordion(label="sampling info", open=False):
        # Displays the dict returned by infer() (currently just the seed).
        info = gr.JSON(label="sampling_info")
    with gr.Accordion(open=False, label="Advanced options"):
        scale = gr.Number(value=7.5, label="cfg_scale")
        # Hidden: steps pinned to 25 for this hosted demo.
        steps = gr.Number(value=25, label="steps", visible=False)
        size_ratio = gr.Dropdown(
            choices=list(SD_XL_BASE_RATIOS.keys()),
            value="1.0",
            label="size ratio",
            multiselect=False,
        )
        n_samples = gr.Slider(
            minimum=1,
            maximum=2,
            value=2,
            label="n_samples",
        )
        seed = gr.Text(
            value="random",
            label="seed (integer or 'random')",
        )
        negative_prompt = gr.Textbox(
            label="negative prompt",
            value="",
        )
        # Hidden: scheduler locked to "euler" in the UI.
        noise_scheduler = gr.Dropdown(
            list(SCHEDULER_MAPPING.keys()), value="euler", visible=False
        )
    # Order must match infer()'s positional parameter order exactly.
    inputs = [
        prompt,
        scale,
        steps,
        size_ratio,
        n_samples,
        seed,
        negative_prompt,
        noise_scheduler,
    ]
    outputs = [gallery, info]
    # Pressing Enter in the prompt box and clicking Run both trigger infer().
    prompt.submit(infer, inputs=inputs, outputs=outputs)
    btn.click(infer, inputs=inputs, outputs=outputs)
    # NOTE(review): each row of `examples` holds 1 value but `inputs` holds 8
    # components — verify this renders as intended in the pinned Gradio
    # version (otherwise pass inputs=[prompt]).
    gr.Examples(examples=examples, inputs=inputs, outputs=outputs, fn=infer)
# queue() serializes concurrent requests (one GPU); launch() starts the server.
demo.queue().launch()