|
import os |
|
import random |
|
import gradio as gr |
|
import torch |
|
from diffusers import ( |
|
DiffusionPipeline, |
|
DDIMScheduler, |
|
DPMSolverMultistepScheduler, |
|
EulerAncestralDiscreteScheduler, |
|
EulerDiscreteScheduler, |
|
LMSDiscreteScheduler, |
|
PNDMScheduler, |
|
UniPCMultistepScheduler, |
|
) |
|
from diffusers.utils import make_image_grid |
|
|
|
# Hugging Face access token used to download the model weights. Indexing
# os.environ directly means startup fails with KeyError if it is not set.
ACCESS_TOKEN = os.environ["ACCESS_TOKEN"]

# Choose the device first so the weight dtype can follow it: float16 is a GPU
# memory/speed optimisation, and several CPU operators lack (or are very slow
# in) half precision, so use float32 when CUDA is unavailable.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# NOTE(review): trust_remote_code=True executes Python code shipped with the
# model repository; keep it only for repositories you trust.
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/japanese-stable-diffusion-xl",
    trust_remote_code=True,
    torch_dtype=torch_dtype,
    use_auth_token=ACCESS_TOKEN,
)

pipeline.to(device)
|
# Scheduler names offered to the user, mapped to the diffusers scheduler class
# each one instantiates. Insertion order matters: the UI lists the keys in
# this order.
SCHEDULER_MAPPING = dict(
    [
        ("ddim", DDIMScheduler),
        ("plms", PNDMScheduler),
        ("lms", LMSDiscreteScheduler),
        ("euler", EulerDiscreteScheduler),
        ("euler_ancestral", EulerAncestralDiscreteScheduler),
        ("dpm_solver++", DPMSolverMultistepScheduler),
        ("unipc", UniPCMultistepScheduler),
    ]
)
|
# Name of the scheduler currently installed on the shared pipeline; infer()
# compares against it so the scheduler is only rebuilt when the choice changes.
# NOTE(review): assumes the pretrained pipeline ships with an Euler scheduler —
# confirm, otherwise selecting "euler" never triggers a swap.
noise_scheduler_name = "euler"

# SDXL base resolutions keyed by their width/height aspect ratio. The key is a
# string because it is what the user picks from a dropdown. Each row below is
# (ratio, width, height); the resulting dict keeps the row order.
_SD_XL_BASE_RESOLUTIONS = (
    ("0.5", 704, 1408),
    ("0.52", 704, 1344),
    ("0.57", 768, 1344),
    ("0.6", 768, 1280),
    ("0.68", 832, 1216),
    ("0.72", 832, 1152),
    ("0.78", 896, 1152),
    ("0.82", 896, 1088),
    ("0.88", 960, 1088),
    ("0.94", 960, 1024),
    ("1.0", 1024, 1024),
    ("1.07", 1024, 960),
    ("1.13", 1088, 960),
    ("1.21", 1088, 896),
    ("1.29", 1152, 896),
    ("1.38", 1152, 832),
    ("1.46", 1216, 832),
    ("1.67", 1280, 768),
    ("1.75", 1344, 768),
    ("1.91", 1344, 704),
    ("2.0", 1408, 704),
    ("2.09", 1472, 704),
    ("2.4", 1536, 640),
    ("2.5", 1600, 640),
    ("2.89", 1664, 576),
    ("3.0", 1728, 576),
)

SD_XL_BASE_RATIOS = {
    ratio: (width, height) for ratio, width, height in _SD_XL_BASE_RESOLUTIONS
}
|
|
|
|
|
def set_noise_scheduler(name) -> None:
    """Install a scheduler of the kind registered under *name* on the pipeline.

    The replacement is constructed from the config of the scheduler currently
    attached to ``pipeline``, so settings shared with it are reused. An unknown
    *name* raises ``KeyError`` from the ``SCHEDULER_MAPPING`` lookup.
    """
    scheduler_cls = SCHEDULER_MAPPING[name]
    current_config = pipeline.scheduler.config
    pipeline.scheduler = scheduler_cls.from_config(current_config)
|
|
|
|
|
def _resolve_seed(seed):
    """Return the generation seed as an int.

    ``None``, an empty/whitespace-only string, or ``"random"`` (any case,
    surrounding whitespace ignored) draws a fresh seed in ``[0, 2**32 - 1]``.
    Anything else is converted with ``int()``, which raises ``ValueError`` for
    non-integer text.
    """
    if isinstance(seed, str):
        seed = seed.strip()
        if seed.lower() in ("", "random"):
            seed = None
    if seed is None:
        # randint's upper bound is inclusive; subtract one so the seed always
        # fits in an unsigned 32-bit integer.
        return random.randint(0, 2**32 - 1)
    return int(seed)


def infer(
    prompt,
    scale=7.5,
    steps=40,
    ratio="1.0",
    n_samples=1,
    seed="random",
    negative_prompt="",
    scheduler_name="euler",
):
    """Generate images for *prompt* with the shared module-level pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        scale: Guidance scale (coerced to float).
        steps: Number of inference steps (coerced to int).
        ratio: Key into ``SD_XL_BASE_RATIOS`` selecting (width, height).
        n_samples: Images to generate for the prompt (coerced to int).
        seed: Integer or integer string; ``"random"`` (case-insensitive,
            whitespace ignored), an empty string or ``None`` picks a fresh seed.
        negative_prompt: Optional negative prompt; ``""`` or ``None`` disables it.
        scheduler_name: Key into ``SCHEDULER_MAPPING``. The pipeline's scheduler
            is rebuilt only when this differs from the one currently installed.

    Returns:
        ``(images, {"seed": seed})``: the pipeline's ``.images`` output and the
        integer seed actually used, so a result can be reproduced.

    Raises:
        KeyError: if *ratio* or *scheduler_name* is not a known key.
        ValueError: if *seed* is neither random/empty nor an integer.
    """
    global noise_scheduler_name

    # The scheduler lives on the shared pipeline; swap it only on change and
    # record the new name only after the swap succeeded.
    if noise_scheduler_name != scheduler_name:
        set_noise_scheduler(scheduler_name)
        noise_scheduler_name = scheduler_name

    scale = float(scale)
    steps = int(steps)
    W, H = SD_XL_BASE_RATIOS[ratio]
    n_samples = int(n_samples)
    seed = _resolve_seed(seed)

    images = pipeline(
        prompt=prompt,
        # Treat both "" and None as "no negative prompt".
        negative_prompt=negative_prompt or None,
        guidance_scale=scale,
        generator=torch.Generator(device=device).manual_seed(seed),
        num_images_per_prompt=n_samples,
        num_inference_steps=steps,
        height=H,
        width=W,
    ).images

    return (
        images,
        {
            "seed": seed,
        },
    )
|
|
|
|
|
# Example prompts shown in the demo. Each example row supplies a single value,
# which lands in the first input component (the prompt).
_EXAMPLE_PROMPTS = (
    "柴犬、カラフルアート",  # "Shiba Inu, colorful art"
    "満面の笑みのお爺さん、スケッチ",  # "Old man with a big smile, sketch"
    "星空の中の1匹の鹿、アート",  # "A single deer in a starry sky, art"
    "ジャングルに立っている日本男性のポートレート",  # "Portrait of a Japanese man standing in a jungle"
    "茶色の猫のイラスト、アニメ",  # "Illustration of a brown cat, anime"
    "舞妓さんのポートレート、デジタルアート",  # "Portrait of a maiko, digital art"
)

examples = [[text] for text in _EXAMPLE_PROMPTS]
|
# Build the Gradio UI. Components are laid out in the order they are created
# inside each `with` context, so the statements below are order-sensitive.
with gr.Blocks() as demo:
    gr.Markdown("# Japanese Stable Diffusion XL Demo")
    gr.Markdown(
        """[Japanese Stable Diffusion XL](https://huggingface.co/stabilityai/japanese-stable-diffusion-xl) is a Japanese-version SDXL by [Stability AI](https://ja.stability.ai/).
- Blog: https://ja.stability.ai/blog/japanese-stable-diffusion-xl
- Twitter: https://twitter.com/StabilityAI_JP
- Discord: https://discord.com/invite/StableJP"""
    )
    gr.Markdown(
        "### You can also try JSDXL on Google Colab [here](https://colab.research.google.com/github/Stability-AI/model-demo-notebooks/blob/main/japanese_stable_diffusion_xl.ipynb). "
    )
    # Prompt textbox and Run button on one row, with the image gallery below.
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(
                label="prompt",
                max_lines=1,
                show_label=False,
                placeholder="Enter your prompt",
                container=False,
            )
            btn = gr.Button("Run", scale=0)
        gallery = gr.Gallery(label="Generated images", show_label=False)
    # Collapsed panel showing the metadata dict returned by infer() (the seed).
    with gr.Accordion(label="sampling info", open=False):
        info = gr.JSON(label="sampling_info")
    with gr.Accordion(open=False, label="Advanced options"):
        scale = gr.Number(value=7.5, label="cfg_scale")
        # Hidden: requests from this UI use 25 steps (not infer()'s default 40).
        steps = gr.Number(value=25, label="steps", visible=False)
        size_ratio = gr.Dropdown(
            choices=list(SD_XL_BASE_RATIOS.keys()),
            value="1.0",
            label="size ratio",
            multiselect=False,
        )
        n_samples = gr.Slider(
            minimum=1,
            maximum=3,
            value=2,
            label="n_samples",
        )
        seed = gr.Text(
            value="random",
            label="seed (integer or 'random')",
        )
        negative_prompt = gr.Textbox(
            label="negative prompt",
            value="",
        )
        # Hidden: the scheduler stays "euler" for requests from this UI.
        noise_scheduler = gr.Dropdown(
            list(SCHEDULER_MAPPING.keys()), value="euler", visible=False
        )

    # Order must match infer()'s positional parameters:
    # (prompt, scale, steps, ratio, n_samples, seed, negative_prompt, scheduler_name).
    inputs = [
        prompt,
        scale,
        steps,
        size_ratio,
        n_samples,
        seed,
        negative_prompt,
        noise_scheduler,
    ]
    # infer() returns (images, {"seed": ...}) which map onto these two outputs.
    outputs = [gallery, info]
    # Pressing Enter in the prompt box and clicking Run both trigger generation.
    prompt.submit(infer, inputs=inputs, outputs=outputs)
    btn.click(infer, inputs=inputs, outputs=outputs)
    # Each example row provides only the prompt value; the other inputs
    # presumably keep their current values — TODO confirm against the
    # installed Gradio version's Examples behavior.
    gr.Examples(examples=examples, inputs=inputs, outputs=outputs, fn=infer)

# Enable request queuing and start the web server.
# NOTE(review): share=True requests a public share link and debug=True keeps
# the process attached in the foreground; confirm both are intended here.
demo.queue().launch(debug=True, share=True, show_error=True)