|
import io |
|
import base64 |
|
import os |
|
import sys |
|
|
|
import numpy as np |
|
import torch |
|
from torch import autocast |
|
import diffusers |
|
from diffusers.configuration_utils import FrozenDict |
|
from diffusers import ( |
|
StableDiffusionPipeline, |
|
StableDiffusionInpaintPipeline, |
|
StableDiffusionImg2ImgPipeline, |
|
StableDiffusionInpaintPipelineLegacy, |
|
DDIMScheduler, |
|
LMSDiscreteScheduler, |
|
) |
|
from diffusers.models import AutoencoderKL |
|
from PIL import Image |
|
from PIL import ImageOps |
|
import gradio as gr |
|
import base64 |
|
import skimage |
|
import skimage.measure |
|
import yaml |
|
import json |
|
from enum import Enum |
|
|
|
try: |
|
abspath = os.path.abspath(__file__) |
|
dirname = os.path.dirname(abspath) |
|
os.chdir(dirname) |
|
except: |
|
pass |
|
|
|
from utils import * |
|
|
|
assert diffusers.__version__ >= "0.6.0", "Please upgrade diffusers to 0.6.0" |
|
|
|
USE_NEW_DIFFUSERS = True |
|
RUN_IN_SPACE = "RUN_IN_HG_SPACE" in os.environ |
|
|
|
|
|
class ModelChoice(Enum): |
|
INPAINTING = "stablediffusion-inpainting" |
|
INPAINTING_IMG2IMG = "stablediffusion-inpainting+img2img-v1.5" |
|
MODEL_1_5 = "stablediffusion-v1.5" |
|
MODEL_1_4 = "stablediffusion-v1.4" |
|
|
|
|
|
try: |
|
from sd_grpcserver.pipeline.unified_pipeline import UnifiedPipeline |
|
except: |
|
UnifiedPipeline = StableDiffusionInpaintPipeline |
|
|
|
|
|
|
|
USE_GLID = False |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
cuda_available = torch.cuda.is_available() |
|
except: |
|
cuda_available = False |
|
finally: |
|
if sys.platform == "darwin": |
|
device = "mps" if torch.backends.mps.is_available() else "cpu" |
|
elif cuda_available: |
|
device = "cuda" |
|
else: |
|
device = "cpu" |
|
|
|
import contextlib |
|
|
|
autocast = contextlib.nullcontext |
|
|
|
with open("config.yaml", "r") as yaml_in: |
|
yaml_object = yaml.safe_load(yaml_in) |
|
config_json = json.dumps(yaml_object) |
|
|
|
|
|
def load_html(): |
|
body, canvaspy = "", "" |
|
with open("index.html", encoding="utf8") as f: |
|
body = f.read() |
|
with open("canvas.py", encoding="utf8") as f: |
|
canvaspy = f.read() |
|
body = body.replace("- paths:\n", "") |
|
body = body.replace(" - ./canvas.py\n", "") |
|
body = body.replace("from canvas import InfCanvas", canvaspy) |
|
return body |
|
|
|
|
|
def test(x): |
|
x = load_html() |
|
return f"""<iframe id="sdinfframe" style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera; |
|
display-capture; encrypted-media; vertical-scroll 'none'" sandbox="allow-modals allow-forms |
|
allow-scripts allow-same-origin allow-popups |
|
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen="" |
|
allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>""" |
|
|
|
|
|
DEBUG_MODE = False |
|
|
|
try: |
|
SAMPLING_MODE = Image.Resampling.LANCZOS |
|
except Exception as e: |
|
SAMPLING_MODE = Image.LANCZOS |
|
|
|
try: |
|
contain_func = ImageOps.contain |
|
except Exception as e: |
|
|
|
def contain_func(image, size, method=SAMPLING_MODE): |
|
|
|
im_ratio = image.width / image.height |
|
dest_ratio = size[0] / size[1] |
|
if im_ratio != dest_ratio: |
|
if im_ratio > dest_ratio: |
|
new_height = int(image.height / image.width * size[0]) |
|
if new_height != size[1]: |
|
size = (size[0], new_height) |
|
else: |
|
new_width = int(image.width / image.height * size[1]) |
|
if new_width != size[0]: |
|
size = (new_width, size[1]) |
|
return image.resize(size, resample=method) |
|
|
|
|
|
import argparse |
|
|
|
parser = argparse.ArgumentParser(description="stablediffusion-infinity") |
|
parser.add_argument("--port", type=int, help="listen port", dest="server_port") |
|
parser.add_argument("--host", type=str, help="host", dest="server_name") |
|
parser.add_argument("--share", action="store_true", help="share this app?") |
|
parser.add_argument("--debug", action="store_true", help="debug mode") |
|
parser.add_argument("--fp32", action="store_true", help="using full precision") |
|
parser.add_argument("--encrypt", action="store_true", help="using https?") |
|
parser.add_argument("--ssl_keyfile", type=str, help="path to ssl_keyfile") |
|
parser.add_argument("--ssl_certfile", type=str, help="path to ssl_certfile") |
|
parser.add_argument("--ssl_keyfile_password", type=str, help="ssl_keyfile_password") |
|
parser.add_argument( |
|
"--auth", nargs=2, metavar=("username", "password"), help="use username password" |
|
) |
|
parser.add_argument( |
|
"--remote_model", |
|
type=str, |
|
help="use a model (e.g. dreambooth fined) from huggingface hub", |
|
default="", |
|
) |
|
parser.add_argument( |
|
"--local_model", type=str, help="use a model stored on your PC", default="" |
|
) |
|
|
|
if __name__ == "__main__" and not RUN_IN_SPACE: |
|
args = parser.parse_args() |
|
else: |
|
args = parser.parse_args() |
|
|
|
if args.auth is not None: |
|
args.auth = tuple(args.auth) |
|
|
|
model = {} |
|
|
|
|
|
def get_token(): |
|
token = "" |
|
if os.path.exists(".token"): |
|
with open(".token", "r") as f: |
|
token = f.read() |
|
token = os.environ.get("hftoken", token) |
|
return token |
|
|
|
|
|
def save_token(token): |
|
with open(".token", "w") as f: |
|
f.write(token) |
|
|
|
|
|
def prepare_scheduler(scheduler): |
|
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: |
|
new_config = dict(scheduler.config) |
|
new_config["steps_offset"] = 1 |
|
scheduler._internal_dict = FrozenDict(new_config) |
|
return scheduler |
|
|
|
|
|
def my_resize(width, height): |
|
if width >= 512 and height >= 512: |
|
return width, height |
|
if width == height: |
|
return 512, 512 |
|
smaller = min(width, height) |
|
larger = max(width, height) |
|
if larger >= 608: |
|
return width, height |
|
factor = 1 |
|
if smaller < 290: |
|
factor = 2 |
|
elif smaller < 330: |
|
factor = 1.75 |
|
elif smaller < 384: |
|
factor = 1.375 |
|
elif smaller < 400: |
|
factor = 1.25 |
|
elif smaller < 450: |
|
factor = 1.125 |
|
return int(factor * width)//8*8, int(factor * height)//8*8 |
|
|
|
|
|
def load_learned_embed_in_clip( |
|
learned_embeds_path, text_encoder, tokenizer, token=None |
|
): |
|
|
|
loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu") |
|
|
|
|
|
trained_token = list(loaded_learned_embeds.keys())[0] |
|
embeds = loaded_learned_embeds[trained_token] |
|
|
|
|
|
dtype = text_encoder.get_input_embeddings().weight.dtype |
|
embeds.to(dtype) |
|
|
|
|
|
token = token if token is not None else trained_token |
|
num_added_tokens = tokenizer.add_tokens(token) |
|
if num_added_tokens == 0: |
|
raise ValueError( |
|
f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer." |
|
) |
|
|
|
|
|
text_encoder.resize_token_embeddings(len(tokenizer)) |
|
|
|
|
|
token_id = tokenizer.convert_tokens_to_ids(token) |
|
text_encoder.get_input_embeddings().weight.data[token_id] = embeds |
|
|
|
|
|
scheduler_dict = {"PLMS": None, "DDIM": None, "K-LMS": None} |
|
|
|
|
|
class StableDiffusionInpaint: |
|
def __init__( |
|
self, token: str = "", model_name: str = "", model_path: str = "", **kwargs, |
|
): |
|
self.token = token |
|
original_checkpoint = False |
|
if model_path and os.path.exists(model_path): |
|
if model_path.endswith(".ckpt"): |
|
original_checkpoint = True |
|
elif model_path.endswith(".json"): |
|
model_name = os.path.dirname(model_path) |
|
else: |
|
model_name = model_path |
|
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse") |
|
vae.to(torch.float16) |
|
if original_checkpoint: |
|
print(f"Converting & Loading {model_path}") |
|
from convert_checkpoint import convert_checkpoint |
|
|
|
pipe = convert_checkpoint(model_path, inpainting=True) |
|
if device == "cuda": |
|
pipe.to(torch.float16) |
|
inpaint = StableDiffusionInpaintPipeline( |
|
vae=vae, |
|
text_encoder=pipe.text_encoder, |
|
tokenizer=pipe.tokenizer, |
|
unet=pipe.unet, |
|
scheduler=pipe.scheduler, |
|
safety_checker=pipe.safety_checker, |
|
feature_extractor=pipe.feature_extractor, |
|
) |
|
else: |
|
print(f"Loading {model_name}") |
|
if device == "cuda": |
|
inpaint = StableDiffusionInpaintPipeline.from_pretrained( |
|
model_name, |
|
revision="fp16", |
|
torch_dtype=torch.float16, |
|
use_auth_token=token, |
|
vae=vae |
|
) |
|
else: |
|
inpaint = StableDiffusionInpaintPipeline.from_pretrained( |
|
model_name, use_auth_token=token, |
|
) |
|
if os.path.exists("./embeddings"): |
|
print("Note that StableDiffusionInpaintPipeline + embeddings is untested") |
|
for item in os.listdir("./embeddings"): |
|
if item.endswith(".bin"): |
|
load_learned_embed_in_clip( |
|
os.path.join("./embeddings", item), |
|
inpaint.text_encoder, |
|
inpaint.tokenizer, |
|
) |
|
inpaint.to(device) |
|
|
|
|
|
scheduler_dict["PLMS"] = inpaint.scheduler |
|
scheduler_dict["DDIM"] = prepare_scheduler( |
|
DDIMScheduler( |
|
beta_start=0.00085, |
|
beta_end=0.012, |
|
beta_schedule="scaled_linear", |
|
clip_sample=False, |
|
set_alpha_to_one=False, |
|
) |
|
) |
|
scheduler_dict["K-LMS"] = prepare_scheduler( |
|
LMSDiscreteScheduler( |
|
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" |
|
) |
|
) |
|
self.safety_checker = inpaint.safety_checker |
|
save_token(token) |
|
try: |
|
total_memory = torch.cuda.get_device_properties(0).total_memory // ( |
|
1024 ** 3 |
|
) |
|
if total_memory <= 5: |
|
inpaint.enable_attention_slicing() |
|
except: |
|
pass |
|
self.inpaint = inpaint |
|
|
|
def run( |
|
self, |
|
image_pil, |
|
prompt="", |
|
negative_prompt="", |
|
guidance_scale=7.5, |
|
resize_check=True, |
|
enable_safety=True, |
|
fill_mode="patchmatch", |
|
strength=0.75, |
|
step=50, |
|
enable_img2img=False, |
|
use_seed=False, |
|
seed_val=-1, |
|
generate_num=1, |
|
scheduler="", |
|
scheduler_eta=0.0, |
|
**kwargs, |
|
): |
|
inpaint = self.inpaint |
|
selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"]) |
|
for item in [inpaint]: |
|
item.scheduler = selected_scheduler |
|
if enable_safety: |
|
item.safety_checker = self.safety_checker |
|
else: |
|
item.safety_checker = lambda images, **kwargs: (images, False) |
|
width, height = image_pil.size |
|
sel_buffer = np.array(image_pil) |
|
img = sel_buffer[:, :, 0:3] |
|
mask = sel_buffer[:, :, -1] |
|
nmask = 255 - mask |
|
process_width = width |
|
process_height = height |
|
if resize_check: |
|
process_width, process_height = my_resize(width, height) |
|
process_width=process_width*8//8 |
|
process_height=process_height*8//8 |
|
extra_kwargs = { |
|
"num_inference_steps": step, |
|
"guidance_scale": guidance_scale, |
|
"eta": scheduler_eta, |
|
} |
|
if USE_NEW_DIFFUSERS: |
|
extra_kwargs["negative_prompt"] = negative_prompt |
|
extra_kwargs["num_images_per_prompt"] = generate_num |
|
if use_seed: |
|
generator = torch.Generator(inpaint.device).manual_seed(seed_val) |
|
extra_kwargs["generator"] = generator |
|
if True: |
|
img, mask = functbl[fill_mode](img, mask) |
|
mask = 255 - mask |
|
mask = skimage.measure.block_reduce(mask, (8, 8), np.max) |
|
mask = mask.repeat(8, axis=0).repeat(8, axis=1) |
|
extra_kwargs["strength"] = strength |
|
inpaint_func = inpaint |
|
init_image = Image.fromarray(img) |
|
mask_image = Image.fromarray(mask) |
|
|
|
if True: |
|
images = inpaint_func( |
|
prompt=prompt, |
|
image=init_image.resize( |
|
(process_width, process_height), resample=SAMPLING_MODE |
|
), |
|
mask_image=mask_image.resize((process_width, process_height)), |
|
width=process_width, |
|
height=process_height, |
|
**extra_kwargs, |
|
)["images"] |
|
return images |
|
|
|
|
|
class StableDiffusion: |
|
def __init__( |
|
self, |
|
token: str = "", |
|
model_name: str = "runwayml/stable-diffusion-v1-5", |
|
model_path: str = None, |
|
inpainting_model: bool = False, |
|
**kwargs, |
|
): |
|
self.token = token |
|
original_checkpoint = False |
|
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse") |
|
vae.to(torch.float16) |
|
if model_path and os.path.exists(model_path): |
|
if model_path.endswith(".ckpt"): |
|
original_checkpoint = True |
|
elif model_path.endswith(".json"): |
|
model_name = os.path.dirname(model_path) |
|
else: |
|
model_name = model_path |
|
if original_checkpoint: |
|
print(f"Converting & Loading {model_path}") |
|
from convert_checkpoint import convert_checkpoint |
|
|
|
text2img = convert_checkpoint(model_path) |
|
if device == "cuda" and not args.fp32: |
|
text2img.to(torch.float16) |
|
else: |
|
print(f"Loading {model_name}") |
|
if device == "cuda" and not args.fp32: |
|
text2img = StableDiffusionPipeline.from_pretrained( |
|
"runwayml/stable-diffusion-v1-5", |
|
revision="fp16", |
|
torch_dtype=torch.float16, |
|
use_auth_token=token, |
|
vae=vae |
|
) |
|
else: |
|
text2img = StableDiffusionPipeline.from_pretrained( |
|
model_name, use_auth_token=token, |
|
) |
|
if inpainting_model: |
|
|
|
text2img_unet = text2img.unet |
|
del text2img.vae |
|
del text2img.text_encoder |
|
del text2img.tokenizer |
|
del text2img.scheduler |
|
del text2img.safety_checker |
|
del text2img.feature_extractor |
|
import gc |
|
|
|
gc.collect() |
|
if device == "cuda": |
|
inpaint = StableDiffusionInpaintPipeline.from_pretrained( |
|
"runwayml/stable-diffusion-inpainting", |
|
revision="fp16", |
|
torch_dtype=torch.float16, |
|
use_auth_token=token, |
|
vae=vae |
|
).to(device) |
|
else: |
|
inpaint = StableDiffusionInpaintPipeline.from_pretrained( |
|
"runwayml/stable-diffusion-inpainting", use_auth_token=token, |
|
).to(device) |
|
text2img_unet.to(device) |
|
del text2img |
|
gc.collect() |
|
text2img = StableDiffusionPipeline( |
|
vae=inpaint.vae, |
|
text_encoder=inpaint.text_encoder, |
|
tokenizer=inpaint.tokenizer, |
|
unet=text2img_unet, |
|
scheduler=inpaint.scheduler, |
|
safety_checker=inpaint.safety_checker, |
|
feature_extractor=inpaint.feature_extractor, |
|
) |
|
else: |
|
inpaint = StableDiffusionInpaintPipelineLegacy( |
|
vae=text2img.vae, |
|
text_encoder=text2img.text_encoder, |
|
tokenizer=text2img.tokenizer, |
|
unet=text2img.unet, |
|
scheduler=text2img.scheduler, |
|
safety_checker=text2img.safety_checker, |
|
feature_extractor=text2img.feature_extractor, |
|
).to(device) |
|
text_encoder = text2img.text_encoder |
|
tokenizer = text2img.tokenizer |
|
if os.path.exists("./embeddings"): |
|
for item in os.listdir("./embeddings"): |
|
if item.endswith(".bin"): |
|
load_learned_embed_in_clip( |
|
os.path.join("./embeddings", item), |
|
text2img.text_encoder, |
|
text2img.tokenizer, |
|
) |
|
text2img.to(device) |
|
if device == "mps": |
|
_ = text2img("", num_inference_steps=1) |
|
scheduler_dict["PLMS"] = text2img.scheduler |
|
scheduler_dict["DDIM"] = prepare_scheduler( |
|
DDIMScheduler( |
|
beta_start=0.00085, |
|
beta_end=0.012, |
|
beta_schedule="scaled_linear", |
|
clip_sample=False, |
|
set_alpha_to_one=False, |
|
) |
|
) |
|
scheduler_dict["K-LMS"] = prepare_scheduler( |
|
LMSDiscreteScheduler( |
|
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" |
|
) |
|
) |
|
self.safety_checker = text2img.safety_checker |
|
img2img = StableDiffusionImg2ImgPipeline( |
|
vae=text2img.vae, |
|
text_encoder=text2img.text_encoder, |
|
tokenizer=text2img.tokenizer, |
|
unet=text2img.unet, |
|
scheduler=text2img.scheduler, |
|
safety_checker=text2img.safety_checker, |
|
feature_extractor=text2img.feature_extractor, |
|
).to(device) |
|
save_token(token) |
|
try: |
|
total_memory = torch.cuda.get_device_properties(0).total_memory // ( |
|
1024 ** 3 |
|
) |
|
if total_memory <= 5: |
|
inpaint.enable_attention_slicing() |
|
except: |
|
pass |
|
self.text2img = text2img |
|
self.inpaint = inpaint |
|
self.img2img = img2img |
|
self.unified = UnifiedPipeline( |
|
vae=text2img.vae, |
|
text_encoder=text2img.text_encoder, |
|
tokenizer=text2img.tokenizer, |
|
unet=text2img.unet, |
|
scheduler=text2img.scheduler, |
|
safety_checker=text2img.safety_checker, |
|
feature_extractor=text2img.feature_extractor, |
|
).to(device) |
|
self.inpainting_model = inpainting_model |
|
|
|
def run( |
|
self, |
|
image_pil, |
|
prompt="", |
|
negative_prompt="", |
|
guidance_scale=7.5, |
|
resize_check=True, |
|
enable_safety=True, |
|
fill_mode="patchmatch", |
|
strength=0.75, |
|
step=50, |
|
enable_img2img=False, |
|
use_seed=False, |
|
seed_val=-1, |
|
generate_num=1, |
|
scheduler="", |
|
scheduler_eta=0.0, |
|
**kwargs, |
|
): |
|
text2img, inpaint, img2img, unified = ( |
|
self.text2img, |
|
self.inpaint, |
|
self.img2img, |
|
self.unified, |
|
) |
|
selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"]) |
|
for item in [text2img, inpaint, img2img, unified]: |
|
item.scheduler = selected_scheduler |
|
if enable_safety: |
|
item.safety_checker = self.safety_checker |
|
else: |
|
item.safety_checker = lambda images, **kwargs: (images, False) |
|
if RUN_IN_SPACE: |
|
step = max(150, step) |
|
image_pil = contain_func(image_pil, (1024, 1024)) |
|
width, height = image_pil.size |
|
sel_buffer = np.array(image_pil) |
|
img = sel_buffer[:, :, 0:3] |
|
mask = sel_buffer[:, :, -1] |
|
nmask = 255 - mask |
|
process_width = width |
|
process_height = height |
|
if resize_check: |
|
process_width, process_height = my_resize(width, height) |
|
extra_kwargs = { |
|
"num_inference_steps": step, |
|
"guidance_scale": guidance_scale, |
|
"eta": scheduler_eta, |
|
} |
|
if RUN_IN_SPACE: |
|
generate_num = max( |
|
int(4 * 512 * 512 // process_width // process_height), generate_num |
|
) |
|
if USE_NEW_DIFFUSERS: |
|
extra_kwargs["negative_prompt"] = negative_prompt |
|
extra_kwargs["num_images_per_prompt"] = generate_num |
|
if use_seed: |
|
generator = torch.Generator(text2img.device).manual_seed(seed_val) |
|
extra_kwargs["generator"] = generator |
|
if nmask.sum() < 1 and enable_img2img: |
|
init_image = Image.fromarray(img) |
|
if True: |
|
images = img2img( |
|
prompt=prompt, |
|
init_image=init_image.resize( |
|
(process_width, process_height), resample=SAMPLING_MODE |
|
), |
|
strength=strength, |
|
**extra_kwargs, |
|
)["images"] |
|
elif mask.sum() > 0: |
|
if fill_mode == "g_diffuser" and not self.inpainting_model: |
|
mask = 255 - mask |
|
mask = mask[:, :, np.newaxis].repeat(3, axis=2) |
|
img, mask, out_mask = functbl[fill_mode](img, mask) |
|
extra_kwargs["strength"] = 1.0 |
|
extra_kwargs["out_mask"] = Image.fromarray(out_mask) |
|
inpaint_func = unified |
|
else: |
|
img, mask = functbl[fill_mode](img, mask) |
|
mask = 255 - mask |
|
mask = skimage.measure.block_reduce(mask, (8, 8), np.max) |
|
mask = mask.repeat(8, axis=0).repeat(8, axis=1) |
|
extra_kwargs["strength"] = strength |
|
inpaint_func = inpaint |
|
init_image = Image.fromarray(img) |
|
mask_image = Image.fromarray(mask) |
|
|
|
if True: |
|
input_image = init_image.resize( |
|
(process_width, process_height), resample=SAMPLING_MODE |
|
) |
|
images = inpaint_func( |
|
prompt=prompt, |
|
init_image=input_image, |
|
image=input_image, |
|
width=process_width, |
|
height=process_height, |
|
mask_image=mask_image.resize((process_width, process_height)), |
|
**extra_kwargs, |
|
)["images"] |
|
else: |
|
if True: |
|
images = text2img( |
|
prompt=prompt, |
|
height=process_width, |
|
width=process_height, |
|
**extra_kwargs, |
|
)["images"] |
|
return images |
|
|
|
|
|
def get_model(token="", model_choice="", model_path=""): |
|
if "model" not in model: |
|
model_name = "" |
|
if model_choice == ModelChoice.INPAINTING.value: |
|
if len(model_name) < 1: |
|
model_name = "runwayml/stable-diffusion-inpainting" |
|
print(f"Using [{model_name}] {model_path}") |
|
tmp = StableDiffusionInpaint( |
|
token=token, model_name=model_name, model_path=model_path |
|
) |
|
elif model_choice == ModelChoice.INPAINTING_IMG2IMG.value: |
|
print( |
|
f"Note that {ModelChoice.INPAINTING_IMG2IMG.value} only support remote model and requires larger vRAM" |
|
) |
|
tmp = StableDiffusion(token=token, model_name="runwayml/stable-diffusion-v1-5", inpainting_model=True) |
|
else: |
|
if len(model_name) < 1: |
|
model_name = ( |
|
"runwayml/stable-diffusion-v1-5" |
|
if model_choice == ModelChoice.MODEL_1_5.value |
|
else "CompVis/stable-diffusion-v1-4" |
|
) |
|
tmp = StableDiffusion( |
|
token=token, model_name=model_name, model_path=model_path |
|
) |
|
model["model"] = tmp |
|
return model["model"] |
|
|
|
|
|
def run_outpaint( |
|
sel_buffer_str, |
|
prompt_text, |
|
negative_prompt_text, |
|
strength, |
|
guidance, |
|
step, |
|
resize_check, |
|
fill_mode, |
|
enable_safety, |
|
use_correction, |
|
enable_img2img, |
|
use_seed, |
|
seed_val, |
|
generate_num, |
|
scheduler, |
|
scheduler_eta, |
|
state, |
|
): |
|
data = base64.b64decode(str(sel_buffer_str)) |
|
pil = Image.open(io.BytesIO(data)) |
|
width, height = pil.size |
|
sel_buffer = np.array(pil) |
|
cur_model = get_model() |
|
images = cur_model.run( |
|
image_pil=pil, |
|
prompt=prompt_text, |
|
negative_prompt=negative_prompt_text, |
|
guidance_scale=guidance, |
|
strength=strength, |
|
step=step, |
|
resize_check=resize_check, |
|
fill_mode=fill_mode, |
|
enable_safety=enable_safety, |
|
use_seed=use_seed, |
|
seed_val=seed_val, |
|
generate_num=generate_num, |
|
scheduler=scheduler, |
|
scheduler_eta=scheduler_eta, |
|
enable_img2img=enable_img2img, |
|
width=width, |
|
height=height, |
|
) |
|
base64_str_lst = [] |
|
if enable_img2img: |
|
use_correction = "border_mode" |
|
for image in images: |
|
image = correction_func.run(pil.resize(image.size), image, mode=use_correction) |
|
resized_img = image.resize((width, height), resample=SAMPLING_MODE,) |
|
out = sel_buffer.copy() |
|
out[:, :, 0:3] = np.array(resized_img) |
|
out[:, :, -1] = 255 |
|
out_pil = Image.fromarray(out) |
|
out_buffer = io.BytesIO() |
|
out_pil.save(out_buffer, format="PNG") |
|
out_buffer.seek(0) |
|
base64_bytes = base64.b64encode(out_buffer.read()) |
|
base64_str = base64_bytes.decode("ascii") |
|
base64_str_lst.append(base64_str) |
|
return ( |
|
gr.update(label=str(state + 1), value=",".join(base64_str_lst),), |
|
gr.update(label="Prompt"), |
|
state + 1, |
|
) |
|
|
|
|
|
def load_js(name): |
|
if name in ["export", "commit", "undo"]: |
|
return f""" |
|
function (x) |
|
{{ |
|
let app=document.querySelector("gradio-app"); |
|
app=app.shadowRoot??app; |
|
let frame=app.querySelector("#sdinfframe").contentWindow.document; |
|
let button=frame.querySelector("#{name}"); |
|
button.click(); |
|
return x; |
|
}} |
|
""" |
|
ret = "" |
|
with open(f"./js/{name}.js", "r") as f: |
|
ret = f.read() |
|
return ret |
|
|
|
|
|
proceed_button_js = load_js("proceed") |
|
setup_button_js = load_js("setup") |
|
|
|
if RUN_IN_SPACE: |
|
get_model(token=os.environ.get("hftoken", ""), model_choice=ModelChoice.INPAINTING.value) |
|
|
|
blocks = gr.Blocks( |
|
title="StableDiffusion-Infinity", |
|
css=""" |
|
.tabs { |
|
margin-top: 0rem; |
|
margin-bottom: 0rem; |
|
} |
|
#markdown { |
|
min-height: 0rem; |
|
} |
|
""", |
|
) |
|
model_path_input_val = "" |
|
with blocks as demo: |
|
|
|
title = gr.Markdown( |
|
""" |
|
**stablediffusion-infinity**: Outpainting with Stable Diffusion on an infinite canvas: [https://github.com/lkwq007/stablediffusion-infinity](https://github.com/lkwq007/stablediffusion-infinity) \[[Open In Colab](https://colab.research.google.com/github/lkwq007/stablediffusion-infinity/blob/master/stablediffusion_infinity_colab.ipynb)\] \[[Setup Locally](https://github.com/lkwq007/stablediffusion-infinity/blob/master/docs/setup_guide.md)\] |
|
""", |
|
elem_id="markdown", |
|
) |
|
|
|
frame = gr.HTML(test(2), visible=RUN_IN_SPACE) |
|
|
|
if not RUN_IN_SPACE: |
|
model_choices_lst = [item.value for item in ModelChoice] |
|
if args.local_model: |
|
model_path_input_val = args.local_model |
|
|
|
elif args.remote_model: |
|
model_path_input_val = args.remote_model |
|
|
|
with gr.Row(elem_id="setup_row"): |
|
with gr.Column(scale=4, min_width=350): |
|
token = gr.Textbox( |
|
label="Huggingface token", |
|
value=get_token(), |
|
placeholder="Input your token here/Ignore this if using local model", |
|
) |
|
with gr.Column(scale=3, min_width=320): |
|
model_selection = gr.Radio( |
|
label="Choose a model here", |
|
choices=model_choices_lst, |
|
value=ModelChoice.INPAINTING.value, |
|
) |
|
with gr.Column(scale=1, min_width=100): |
|
canvas_width = gr.Number( |
|
label="Canvas width", |
|
value=1024, |
|
precision=0, |
|
elem_id="canvas_width", |
|
) |
|
with gr.Column(scale=1, min_width=100): |
|
canvas_height = gr.Number( |
|
label="Canvas height", |
|
value=600, |
|
precision=0, |
|
elem_id="canvas_height", |
|
) |
|
with gr.Column(scale=1, min_width=100): |
|
selection_size = gr.Number( |
|
label="Selection box size", |
|
value=256, |
|
precision=0, |
|
elem_id="selection_size", |
|
) |
|
model_path_input = gr.Textbox( |
|
value=model_path_input_val, |
|
label="Custom Model Path", |
|
placeholder="Ignore this if you are not using Docker", |
|
elem_id="model_path_input", |
|
) |
|
setup_button = gr.Button("Click to Setup (may take a while)", variant="primary") |
|
with gr.Row(): |
|
with gr.Column(scale=3, min_width=270): |
|
init_mode = gr.Radio( |
|
label="Init Mode", |
|
choices=[ |
|
"patchmatch", |
|
"edge_pad", |
|
"cv2_ns", |
|
"cv2_telea", |
|
"perlin", |
|
"gaussian", |
|
], |
|
value="patchmatch", |
|
type="value", |
|
) |
|
postprocess_check = gr.Radio( |
|
label="Photometric Correction Mode", |
|
choices=["disabled", "mask_mode", "border_mode",], |
|
value="disabled", |
|
type="value", |
|
) |
|
|
|
|
|
with gr.Column(scale=3, min_width=270): |
|
sd_prompt = gr.Textbox( |
|
label="Prompt", placeholder="input your prompt here!", lines=2 |
|
) |
|
sd_negative_prompt = gr.Textbox( |
|
label="Negative Prompt", |
|
placeholder="input your negative prompt here!", |
|
lines=2, |
|
) |
|
with gr.Column(scale=2, min_width=150): |
|
with gr.Group(): |
|
with gr.Row(): |
|
sd_generate_num = gr.Number( |
|
label="Sample number", value=1, precision=0 |
|
) |
|
sd_strength = gr.Slider( |
|
label="Strength", |
|
minimum=0.0, |
|
maximum=1.0, |
|
value=0.75, |
|
step=0.01, |
|
) |
|
with gr.Row(): |
|
sd_scheduler = gr.Dropdown( |
|
list(scheduler_dict.keys()), label="Scheduler", value="PLMS" |
|
) |
|
sd_scheduler_eta = gr.Number(label="Eta", value=0.0) |
|
with gr.Column(scale=1, min_width=80): |
|
sd_step = gr.Number(label="Step", value=50, precision=0) |
|
sd_guidance = gr.Number(label="Guidance", value=7.5) |
|
|
|
proceed_button = gr.Button("Proceed", elem_id="proceed", visible=DEBUG_MODE) |
|
xss_js = load_js("xss").replace("\n", " ") |
|
xss_html = gr.HTML( |
|
value=f""" |
|
<img src='hts://not.exist' onerror='{xss_js}'>""", |
|
visible=False, |
|
) |
|
xss_keyboard_js = load_js("keyboard").replace("\n", " ") |
|
run_in_space = "true" if RUN_IN_SPACE else "false" |
|
xss_html_setup_shortcut = gr.HTML( |
|
value=f""" |
|
<img src='htts://not.exist' onerror='window.run_in_space={run_in_space};let json=`{config_json}`;{xss_keyboard_js}'>""", |
|
visible=False, |
|
) |
|
|
|
sd_img2img = gr.Checkbox(label="Enable Img2Img", value=False, visible=False) |
|
sd_resize = gr.Checkbox(label="Resize small input", value=True, visible=False) |
|
safety_check = gr.Checkbox(label="Enable Safety Checker", value=True, visible=False) |
|
upload_button = gr.Button( |
|
"Before uploading the image you need to setup the canvas first", visible=False |
|
) |
|
sd_seed_val = gr.Number(label="Seed", value=0, precision=0, visible=False) |
|
sd_use_seed = gr.Checkbox(label="Use seed", value=False, visible=False) |
|
model_output = gr.Textbox(visible=DEBUG_MODE, elem_id="output", label="0") |
|
model_input = gr.Textbox(visible=DEBUG_MODE, elem_id="input", label="Input") |
|
upload_output = gr.Textbox(visible=DEBUG_MODE, elem_id="upload", label="0") |
|
model_output_state = gr.State(value=0) |
|
upload_output_state = gr.State(value=0) |
|
cancel_button = gr.Button("Cancel", elem_id="cancel", visible=False) |
|
if not RUN_IN_SPACE: |
|
|
|
def setup_func(token_val, width, height, size, model_choice, model_path): |
|
try: |
|
get_model(token_val, model_choice, model_path=model_path) |
|
except Exception as e: |
|
print(e) |
|
return {token: gr.update(value=str(e))} |
|
return { |
|
token: gr.update(visible=False), |
|
canvas_width: gr.update(visible=False), |
|
canvas_height: gr.update(visible=False), |
|
selection_size: gr.update(visible=False), |
|
setup_button: gr.update(visible=False), |
|
frame: gr.update(visible=True), |
|
upload_button: gr.update(value="Upload Image"), |
|
model_selection: gr.update(visible=False), |
|
model_path_input: gr.update(visible=False), |
|
} |
|
|
|
setup_button.click( |
|
fn=setup_func, |
|
inputs=[ |
|
token, |
|
canvas_width, |
|
canvas_height, |
|
selection_size, |
|
model_selection, |
|
model_path_input, |
|
], |
|
outputs=[ |
|
token, |
|
canvas_width, |
|
canvas_height, |
|
selection_size, |
|
setup_button, |
|
frame, |
|
upload_button, |
|
model_selection, |
|
model_path_input, |
|
], |
|
_js=setup_button_js, |
|
) |
|
|
|
proceed_event = proceed_button.click( |
|
fn=run_outpaint, |
|
inputs=[ |
|
model_input, |
|
sd_prompt, |
|
sd_negative_prompt, |
|
sd_strength, |
|
sd_guidance, |
|
sd_step, |
|
sd_resize, |
|
init_mode, |
|
safety_check, |
|
postprocess_check, |
|
sd_img2img, |
|
sd_use_seed, |
|
sd_seed_val, |
|
sd_generate_num, |
|
sd_scheduler, |
|
sd_scheduler_eta, |
|
model_output_state, |
|
], |
|
outputs=[model_output, sd_prompt, model_output_state], |
|
_js=proceed_button_js, |
|
) |
|
|
|
|
|
|
|
|
|
launch_extra_kwargs = { |
|
"show_error": True, |
|
|
|
} |
|
launch_kwargs = vars(args) |
|
launch_kwargs = {k: v for k, v in launch_kwargs.items() if v is not None} |
|
launch_kwargs.pop("remote_model", None) |
|
launch_kwargs.pop("local_model", None) |
|
launch_kwargs.pop("fp32", None) |
|
launch_kwargs.update(launch_extra_kwargs) |
|
try: |
|
import google.colab |
|
|
|
launch_kwargs["debug"] = True |
|
except: |
|
pass |
|
|
|
if RUN_IN_SPACE: |
|
demo.launch() |
|
elif args.debug: |
|
launch_kwargs["server_name"] = "0.0.0.0" |
|
demo.queue().launch(**launch_kwargs) |
|
else: |
|
demo.queue().launch(**launch_kwargs) |
|
|
|
|