# Scraped from a Hugging Face Spaces page; the Space's status read "Runtime error".
import os | |
import random | |
from huggingface_hub import login | |
import gradio as gr | |
import torch | |
from diffusers import ( | |
DiffusionPipeline, | |
DDIMScheduler, | |
DPMSolverMultistepScheduler, | |
EulerAncestralDiscreteScheduler, | |
EulerDiscreteScheduler, | |
LMSDiscreteScheduler, | |
PNDMScheduler, | |
UniPCMultistepScheduler, | |
) | |
from diffusers.utils import make_image_grid | |
# Authenticate with the Hugging Face Hub using the token from the environment.
# login() stores the token for the Hub client, so from_pretrained below can
# download the repository without an explicit (deprecated) use_auth_token.
ACCESS_TOKEN = os.environ["ACCESS_TOKEN"]
login(token=ACCESS_TOKEN)

# Choose the device before loading so the weight dtype can follow it:
# float16 weights are only used on CUDA, because many CPU kernels have no
# half-precision implementation and inference would fail on a CPU fallback.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# The repository ships custom pipeline code, hence trust_remote_code=True.
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/japanese-stable-diffusion-xl",
    trust_remote_code=True,
    torch_dtype=torch_dtype,
)
pipeline.to(device)
# Sampler names selectable from the UI, mapped to the diffusers scheduler
# class used to rebuild `pipeline.scheduler` (see set_noise_scheduler).
SCHEDULER_MAPPING = dict(
    [
        ("ddim", DDIMScheduler),
        ("plms", PNDMScheduler),
        ("lms", LMSDiscreteScheduler),
        ("euler", EulerDiscreteScheduler),
        ("euler_ancestral", EulerAncestralDiscreteScheduler),
        ("dpm_solver++", DPMSolverMultistepScheduler),
        ("unipc", UniPCMultistepScheduler),
    ]
)

# Name of the sampler currently considered installed on `pipeline`; infer()
# compares against it so the scheduler is only rebuilt when the name changes.
# NOTE(review): assumes the checkpoint's own scheduler is Euler — if it is
# not, requests for "euler" keep the checkpoint's scheduler. Confirm.
noise_scheduler_name = "euler"
# Output sizes offered in the UI, keyed by the aspect ratio width / height
# (rounded to two decimals, as a string); each value is the (width, height)
# pair that infer() passes to the pipeline.
SD_XL_BASE_RATIOS = {
    ratio_key: (width, height)
    for ratio_key, width, height in (
        ("0.5", 704, 1408),
        ("0.52", 704, 1344),
        ("0.57", 768, 1344),
        ("0.6", 768, 1280),
        ("0.68", 832, 1216),
        ("0.72", 832, 1152),
        ("0.78", 896, 1152),
        ("0.82", 896, 1088),
        ("0.88", 960, 1088),
        ("0.94", 960, 1024),
        ("1.0", 1024, 1024),
        ("1.07", 1024, 960),
        ("1.13", 1088, 960),
        ("1.21", 1088, 896),
        ("1.29", 1152, 896),
        ("1.38", 1152, 832),
        ("1.46", 1216, 832),
        ("1.67", 1280, 768),
        ("1.75", 1344, 768),
        ("1.91", 1344, 704),
        ("2.0", 1408, 704),
        ("2.09", 1472, 704),
        ("2.4", 1536, 640),
        ("2.5", 1600, 640),
        ("2.89", 1664, 576),
        ("3.0", 1728, 576),
        # ("small", 512, 512),  # for testing
    )
}
def set_noise_scheduler(name) -> None:
    """Install a freshly built *name* scheduler on the shared pipeline.

    The new scheduler is created from the current scheduler's config via
    ``from_config``. An unknown *name* raises KeyError from SCHEDULER_MAPPING.
    """
    scheduler_cls = SCHEDULER_MAPPING[name]
    current_config = pipeline.scheduler.config
    pipeline.scheduler = scheduler_cls.from_config(current_config)
def infer(
    prompt,
    scale=7.5,
    steps=40,
    ratio="1.0",
    n_samples=1,
    seed="random",
    negative_prompt="",
    scheduler_name="euler",
):
    """Generate images for *prompt* with the shared JSDXL pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        scale: Guidance scale (coerced with ``float``); None means 7.5.
        steps: Number of inference steps (coerced with ``int``); None means 40.
        ratio: Key into ``SD_XL_BASE_RATIOS`` selecting (width, height).
        n_samples: Images per prompt (coerced with ``int``); None means 1.
        seed: Integer (or integer string), or ``"random"`` in any case /
            blank / None to draw a fresh seed.
        negative_prompt: Negative prompt; empty or None means none.
        scheduler_name: Key into ``SCHEDULER_MAPPING``.

    Returns:
        ``(images, {"seed": seed})``: the generated images and the seed used,
        so the result can be reproduced.

    Raises:
        KeyError: if *ratio* or *scheduler_name* is not a known key.
        ValueError: if *seed* is neither random/blank nor an integer.
    """
    global noise_scheduler_name

    # Rebuild the scheduler only when a different sampler is requested; the
    # name is recorded only after the swap succeeds.
    if noise_scheduler_name != scheduler_name:
        set_noise_scheduler(scheduler_name)
        noise_scheduler_name = scheduler_name

    # A cleared gr.Number arrives as None; fall back to this function's
    # defaults instead of failing inside float()/int().
    scale = float(7.5 if scale is None else scale)
    steps = int(40 if steps is None else steps)
    n_samples = int(1 if n_samples is None else n_samples)
    W, H = SD_XL_BASE_RATIOS[ratio]

    # Treat "random" in any case, surrounding whitespace, or an empty box as a
    # request for a new seed; anything else must parse as an integer.
    seed_text = "" if seed is None else str(seed).strip()
    if seed_text.lower() in ("", "random"):
        # randint's upper bound is inclusive; keep the seed within 32 bits.
        seed = random.randint(0, 2**32 - 1)
    else:
        seed = int(seed_text)

    images = pipeline(
        prompt=prompt,
        # Pass None rather than an empty string when no negative prompt is set.
        negative_prompt=negative_prompt or None,
        guidance_scale=scale,
        generator=torch.Generator(device=device).manual_seed(seed),
        num_images_per_prompt=n_samples,
        num_inference_steps=steps,
        height=H,
        width=W,
    ).images
    return images, {"seed": seed}
# Example prompts for gr.Examples; each row fills only the first input
# (the prompt textbox).
examples = [
    [example_prompt]
    for example_prompt in (
        "柴犬、カラフルアート",  # Shiba Inu, colorful art
        "満面の笑みのお爺さん、スケッチ",  # old man with a broad smile, sketch
        "星空の中の1匹の鹿、アート",  # a single deer in a starry sky, art
        "ジャングルに立っている日本男性のポートレート",  # portrait of a Japanese man standing in a jungle
        "茶色の猫のイラスト、アニメ",  # illustration of a brown cat, anime
        "舞妓さんのポートレート、デジタルアート",  # portrait of a maiko, digital art
    )
]
# Gradio UI: a prompt box with a Run button, an image gallery, a collapsible
# JSON view of the sampling info returned by infer(), and advanced options.
with gr.Blocks() as demo:
    gr.Markdown("# Japanese Stable Diffusion XL Demo")
    gr.Markdown(
        """[Japanese Stable Diffusion XL](https://huggingface.co/stabilityai/japanese-stable-diffusion-xl) is a Japanese-version SDXL by [Stability AI](https://ja.stability.ai/).
- Blog: https://ja.stability.ai/blog/japanese-stable-diffusion-xl
- Twitter: https://twitter.com/StabilityAI_JP
- Discord: https://discord.com/invite/StableJP"""
    )
    gr.Markdown(
        "### You can also try JSDXL on Google Colab [here](https://colab.research.google.com/github/Stability-AI/model-demo-notebooks/blob/main/japanese_stable_diffusion_xl.ipynb). "
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(
                label="prompt",
                max_lines=1,
                show_label=False,
                placeholder="Enter your prompt",
                container=False,
            )
            btn = gr.Button("Run", scale=0)
        gallery = gr.Gallery(label="Generated images", show_label=False)
        with gr.Accordion(label="sampling info", open=False):
            info = gr.JSON(label="sampling_info")
    with gr.Accordion(open=False, label="Advanced options"):
        scale = gr.Number(value=7.5, label="cfg_scale")
        # Hidden: the step count is fixed from the UI's point of view.
        steps = gr.Number(value=25, label="steps", visible=False)
        size_ratio = gr.Dropdown(
            choices=list(SD_XL_BASE_RATIOS.keys()),
            value="1.0",
            label="size ratio",
            multiselect=False,
        )
        # step=1 keeps the slider on whole image counts; without it Gradio may
        # pick a fractional step and infer() would silently truncate via int().
        n_samples = gr.Slider(
            minimum=1,
            maximum=2,
            value=2,
            step=1,
            label="n_samples",
        )
        seed = gr.Text(
            value="random",
            label="seed (integer or 'random')",
        )
        negative_prompt = gr.Textbox(
            label="negative prompt",
            value="",
        )
        # Hidden: keeps the sampler fixed while still routing it through infer().
        noise_scheduler = gr.Dropdown(
            list(SCHEDULER_MAPPING.keys()), value="euler", visible=False
        )

    # Order must match infer()'s positional parameters.
    inputs = [
        prompt,
        scale,
        steps,
        size_ratio,
        n_samples,
        seed,
        negative_prompt,
        noise_scheduler,
    ]
    outputs = [gallery, info]
    prompt.submit(infer, inputs=inputs, outputs=outputs)
    btn.click(infer, inputs=inputs, outputs=outputs)
    gr.Examples(examples=examples, inputs=inputs, outputs=outputs, fn=infer)

demo.queue().launch()