# sdui / docker-entrypoint.py
# NOTE: the following lines are residue from the Hugging Face file viewer
# (author avatar, commit "up" 1659e0c, raw/history/blame links, size 8.26 kB);
# kept as comments so the module remains valid Python.
# atikur-rabbi's picture · up · 1659e0c · raw · history · blame · No virus · 8.26 kB
#!/usr/bin/env python
import argparse, datetime, inspect, os, warnings
import numpy as np
import torch
from PIL import Image
from diffusers import (
OnnxStableDiffusionPipeline,
OnnxStableDiffusionInpaintPipeline,
OnnxStableDiffusionImg2ImgPipeline,
StableDiffusionDepth2ImgPipeline,
StableDiffusionPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionUpscalePipeline,
schedulers,
)
def iso_date_time():
    """Return the current local time as an ISO-8601 string (used for log lines)."""
    now = datetime.datetime.now()
    return now.isoformat()
def load_image(path):
    """Open input/<path> as an RGB PIL image, log the load time, and return it."""
    rgb = Image.open(os.path.join("input", path)).convert("RGB")
    print(f"loaded image from {path}:", iso_date_time(), flush=True)
    return rgb
def remove_unused_args(p):
    """Build the kwargs for ``p.pipeline``, keeping only the names its
    signature actually accepts (the diffuser pipelines take different
    subsets of these arguments)."""
    candidates = {
        "prompt": p.prompt,
        "negative_prompt": p.negative_prompt,
        "image": p.image,
        "mask_image": p.mask,
        "height": p.height,
        "width": p.width,
        "num_images_per_prompt": p.samples,
        "num_inference_steps": p.steps,
        "guidance_scale": p.scale,
        "image_guidance_scale": p.image_scale,
        "strength": p.strength,
        "generator": p.generator,
    }
    accepted = inspect.signature(p.pipeline).parameters.keys()
    return {name: value for name, value in candidates.items() if name in accepted}
def stable_diffusion_pipeline(p):
    """Resolve and load the diffusers pipeline described by the parsed CLI args.

    Mutates the namespace ``p`` in place (dtype, diffuser, revision, image,
    mask, token, seed, generator, pipeline) and returns it.
    """
    # float16 halves memory use; chosen with --half.
    p.dtype = torch.float16 if p.half else torch.float32

    # Base pipeline class and HF Hub revision: ONNX models live under the
    # "onnx" revision, torch models under "fp16" or "main".
    if p.onnx:
        p.diffuser = OnnxStableDiffusionPipeline
        p.revision = "onnx"
    else:
        p.diffuser = StableDiffusionPipeline
        p.revision = "fp16" if p.half else "main"

    # Model ids that need a dedicated image-input pipeline class.
    models = argparse.Namespace(
        **{
            "depth2img": ["stabilityai/stable-diffusion-2-depth"],
            "pix2pix": ["timbrooks/instruct-pix2pix"],
            "upscalers": ["stabilityai/stable-diffusion-x4-upscaler"],
        }
    )

    # An input image switches to an image-to-image variant; ONNX takes
    # precedence, otherwise the class is picked by model id.
    if p.image is not None:
        if p.revision == "onnx":
            p.diffuser = OnnxStableDiffusionImg2ImgPipeline
        elif p.model in models.depth2img:
            p.diffuser = StableDiffusionDepth2ImgPipeline
        elif p.model in models.pix2pix:
            p.diffuser = StableDiffusionInstructPix2PixPipeline
        elif p.model in models.upscalers:
            p.diffuser = StableDiffusionUpscalePipeline
        else:
            p.diffuser = StableDiffusionImg2ImgPipeline
        p.image = load_image(p.image)

    # A mask overrides the choice above: inpainting wins when both
    # --image and --mask are given.
    if p.mask is not None:
        if p.revision == "onnx":
            p.diffuser = OnnxStableDiffusionInpaintPipeline
        else:
            p.diffuser = StableDiffusionInpaintPipeline
        p.mask = load_image(p.mask)

    # Fall back to a token file when --token is not passed.
    # NOTE(review): raises FileNotFoundError if token.txt is absent — confirm
    # the container image always ships one.
    if p.token is None:
        with open("token.txt") as f:
            p.token = f.read().replace("\n", "")

    # seed==0 means "pick a random seed".
    if p.seed == 0:
        p.seed = torch.random.seed()

    # ONNX uses numpy's RNG, which only accepts 32-bit seeds; keep the high
    # 32 bits of an oversized seed. Torch uses a device-bound Generator.
    if p.revision == "onnx":
        p.seed = p.seed >> 32 if p.seed > 2**32 - 1 else p.seed
        p.generator = np.random.RandomState(p.seed)
    else:
        p.generator = torch.Generator(device=p.device).manual_seed(p.seed)

    print("load pipeline start:", iso_date_time(), flush=True)

    # Downloading/loading the checkpoint emits deprecation chatter; keep logs clean.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=FutureWarning)
        pipeline = p.diffuser.from_pretrained(
            p.model,
            torch_dtype=p.dtype,
            revision=p.revision,
            use_auth_token=p.token,
        ).to(p.device)

    # Optional scheduler override, looked up by class name in diffusers.schedulers.
    if p.scheduler is not None:
        scheduler = getattr(schedulers, p.scheduler)
        pipeline.scheduler = scheduler.from_config(pipeline.scheduler.config)

    # Opt-in memory/safety toggles from the CLI flags.
    if p.skip:
        pipeline.safety_checker = None

    if p.attention_slicing:
        pipeline.enable_attention_slicing()

    if p.xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()

    if p.vae_tiling:
        pipeline.vae.enable_tiling()

    p.pipeline = pipeline

    print("loaded models after:", iso_date_time(), flush=True)

    return p
def stable_diffusion_inference(p):
    """Run the loaded pipeline ``p.iters`` times and save each generated
    image under output/ with a filename encoding the sampling parameters."""
    # Truncate so the prompt plus suffix stays within filesystem name limits.
    prefix = p.prompt.replace(" ", "_")[:170]
    for run in range(p.iters):
        images = p.pipeline(**remove_unused_args(p)).images
        for offset, img in enumerate(images, start=1):
            idx = run * p.samples + offset
            out = f"{prefix}__steps_{p.steps}__scale_{p.scale:.2f}__seed_{p.seed}__n_{idx}.png"
            img.save(os.path.join("output", out))
    print("completed pipeline:", iso_date_time(), flush=True)
def main():
    """Parse CLI arguments, build the requested pipeline, and run inference.

    The prompt may be given positionally (PROMPT) or via --prompt; the
    positional form wins when both are present. Exits with a usage error
    when no prompt is supplied at all.
    """
    parser = argparse.ArgumentParser(description="Create images from a text prompt.")
    parser.add_argument(
        "--attention-slicing",
        action="store_true",
        help="Use less memory at the expense of inference speed",
    )
    parser.add_argument(
        "--device",
        type=str,
        nargs="?",
        default="cuda",
        help="The cpu or cuda device to use to render images",
    )
    parser.add_argument(
        "--half",
        action="store_true",
        help="Use float16 (half-sized) tensors instead of float32",
    )
    parser.add_argument(
        "--height", type=int, nargs="?", default=512, help="Image height in pixels"
    )
    parser.add_argument(
        "--image",
        type=str,
        nargs="?",
        help="The input image to use for image-to-image diffusion",
    )
    parser.add_argument(
        "--image-scale",
        type=float,
        nargs="?",
        help="How closely the image should follow the original image",
    )
    parser.add_argument(
        "--iters",
        type=int,
        nargs="?",
        default=1,
        help="Number of times to run pipeline",
    )
    parser.add_argument(
        "--mask",
        type=str,
        nargs="?",
        help="The input mask to use for diffusion inpainting",
    )
    parser.add_argument(
        "--model",
        type=str,
        nargs="?",
        default="CompVis/stable-diffusion-v1-4",
        help="The model used to render images",
    )
    parser.add_argument(
        "--negative-prompt",
        type=str,
        nargs="?",
        help="The prompt to not render into an image",
    )
    parser.add_argument(
        "--onnx",
        action="store_true",
        help="Use the onnx runtime for inference",
    )
    parser.add_argument(
        "--prompt", type=str, nargs="?", help="The prompt to render into an image"
    )
    parser.add_argument(
        "--samples",
        type=int,
        nargs="?",
        default=1,
        help="Number of images to create per run",
    )
    parser.add_argument(
        "--scale",
        type=float,
        nargs="?",
        default=7.5,
        help="How closely the image should follow the prompt",
    )
    parser.add_argument(
        "--scheduler",
        type=str,
        nargs="?",
        help="Override the scheduler used to denoise the image",
    )
    parser.add_argument(
        "--seed", type=int, nargs="?", default=0, help="RNG seed for repeatability"
    )
    parser.add_argument(
        "--skip",
        action="store_true",
        help="Skip the safety checker",
    )
    parser.add_argument(
        "--steps", type=int, nargs="?", default=50, help="Number of sampling steps"
    )
    parser.add_argument(
        "--strength",
        type=float,
        default=0.75,
        help="Diffusion strength to apply to the input image",
    )
    parser.add_argument(
        "--token", type=str, nargs="?", help="Huggingface user access token"
    )
    parser.add_argument(
        "--vae-tiling",
        action="store_true",
        help="Use less memory when generating ultra-high resolution images",
    )
    parser.add_argument(
        "--width", type=int, nargs="?", default=512, help="Image width in pixels"
    )
    parser.add_argument(
        "--xformers-memory-efficient-attention",
        action="store_true",
        help="Use less memory but require the xformers library",
    )
    parser.add_argument(
        "prompt0",
        metavar="PROMPT",
        type=str,
        nargs="?",
        help="The prompt to render into an image",
    )

    args = parser.parse_args()

    # The positional prompt takes precedence over --prompt.
    if args.prompt0 is not None:
        args.prompt = args.prompt0

    # Fail fast with a usage message instead of crashing later with an
    # AttributeError when the inference step calls prompt.replace().
    if args.prompt is None:
        parser.error("a prompt is required (positional PROMPT or --prompt)")

    pipeline = stable_diffusion_pipeline(args)
    stable_diffusion_inference(pipeline)
# Standard script guard: run only when executed directly (e.g. as the
# container entrypoint), not when imported as a module.
if __name__ == "__main__":
    main()