Spaces:
Runtime error
Runtime error
import random | |
import tempfile | |
import time | |
import gradio as gr | |
import numpy as np | |
import torch | |
import math | |
import re | |
from gradio import inputs | |
from diffusers import ( | |
AutoencoderKL, | |
DDIMScheduler, | |
UNet2DConditionModel, | |
) | |
from modules.model import ( | |
CrossAttnProcessor, | |
StableDiffusionPipeline, | |
) | |
from torchvision import transforms | |
from transformers import CLIPTokenizer, CLIPTextModel | |
from PIL import Image | |
from pathlib import Path | |
from safetensors.torch import load_file | |
import modules.safe as _ | |
from modules.lora import LoRANetwork | |
models = [ | |
("LSPV1", "Jeffsun/LSP", 2), | |
] | |
keep_vram = ["Jeffsun/LSP"] | |
base_name, base_model, clip_skip = models[0] | |
samplers_k_diffusion = [ | |
("Euler a", "sample_euler_ancestral", {}), | |
("Euler", "sample_euler", {}), | |
("LMS", "sample_lms", {}), | |
("Heun", "sample_heun", {}), | |
("DPM2", "sample_dpm_2", {"discard_next_to_last_sigma": True}), | |
("DPM2 a", "sample_dpm_2_ancestral", {"discard_next_to_last_sigma": True}), | |
("DPM++ 2S a", "sample_dpmpp_2s_ancestral", {}), | |
("DPM++ 2M", "sample_dpmpp_2m", {}), | |
("DPM++ SDE", "sample_dpmpp_sde", {}), | |
("LMS Karras", "sample_lms", {"scheduler": "karras"}), | |
("DPM2 Karras", "sample_dpm_2", {"scheduler": "karras", "discard_next_to_last_sigma": True}), | |
("DPM2 a Karras", "sample_dpm_2_ancestral", {"scheduler": "karras", "discard_next_to_last_sigma": True}), | |
("DPM++ 2S a Karras", "sample_dpmpp_2s_ancestral", {"scheduler": "karras"}), | |
("DPM++ 2M Karras", "sample_dpmpp_2m", {"scheduler": "karras"}), | |
("DPM++ SDE Karras", "sample_dpmpp_sde", {"scheduler": "karras"}), | |
] | |
# samplers_diffusers = [ | |
# ("DDIMScheduler", "diffusers.schedulers.DDIMScheduler", {}) | |
# ("DDPMScheduler", "diffusers.schedulers.DDPMScheduler", {}) | |
# ("DEISMultistepScheduler", "diffusers.schedulers.DEISMultistepScheduler", {}) | |
# ] | |
start_time = time.time() | |
timeout = 90 | |
scheduler = DDIMScheduler.from_pretrained( | |
base_model, | |
subfolder="scheduler", | |
) | |
vae = AutoencoderKL.from_pretrained( | |
"stabilityai/sd-vae-ft-ema", | |
torch_dtype=torch.float16 | |
) | |
text_encoder = CLIPTextModel.from_pretrained( | |
base_model, | |
subfolder="text_encoder", | |
torch_dtype=torch.float16, | |
) | |
tokenizer = CLIPTokenizer.from_pretrained( | |
base_model, | |
subfolder="tokenizer", | |
torch_dtype=torch.float16, | |
) | |
unet = UNet2DConditionModel.from_pretrained( | |
base_model, | |
subfolder="unet", | |
torch_dtype=torch.float16, | |
) | |
pipe = StableDiffusionPipeline( | |
text_encoder=text_encoder, | |
tokenizer=tokenizer, | |
unet=unet, | |
vae=vae, | |
scheduler=scheduler, | |
) | |
unet.set_attn_processor(CrossAttnProcessor) | |
pipe.setup_text_encoder(clip_skip, text_encoder) | |
if torch.cuda.is_available(): | |
pipe = pipe.to("cuda") | |
def get_model_list(): | |
return models | |
te_cache = { | |
base_model: text_encoder | |
} | |
unet_cache = { | |
base_model: unet | |
} | |
lora_cache = { | |
base_model: LoRANetwork(text_encoder, unet) | |
} | |
te_base_weight_length = text_encoder.get_input_embeddings().weight.data.shape[0] | |
original_prepare_for_tokenization = tokenizer.prepare_for_tokenization | |
current_model = base_model | |
def setup_model(name, lora_state=None, lora_scale=1.0): | |
global pipe, current_model | |
keys = [k[0] for k in models] | |
model = models[keys.index(name)][1] | |
if model not in unet_cache: | |
unet = UNet2DConditionModel.from_pretrained(model, subfolder="unet", torch_dtype=torch.float16) | |
text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder", torch_dtype=torch.float16) | |
unet_cache[model] = unet | |
te_cache[model] = text_encoder | |
lora_cache[model] = LoRANetwork(text_encoder, unet) | |
if current_model != model: | |
if current_model not in keep_vram: | |
# offload current model | |
unet_cache[current_model].to("cpu") | |
te_cache[current_model].to("cpu") | |
lora_cache[current_model].to("cpu") | |
current_model = model | |
local_te, local_unet, local_lora, = te_cache[model], unet_cache[model], lora_cache[model] | |
local_unet.set_attn_processor(CrossAttnProcessor()) | |
local_lora.reset() | |
clip_skip = models[keys.index(name)][2] | |
if torch.cuda.is_available(): | |
local_unet.to("cuda") | |
local_te.to("cuda") | |
if lora_state is not None and lora_state != "": | |
local_lora.load(lora_state, lora_scale) | |
local_lora.to(local_unet.device, dtype=local_unet.dtype) | |
pipe.text_encoder, pipe.unet = local_te, local_unet | |
pipe.setup_unet(local_unet) | |
pipe.tokenizer.prepare_for_tokenization = original_prepare_for_tokenization | |
pipe.tokenizer.added_tokens_encoder = {} | |
pipe.tokenizer.added_tokens_decoder = {} | |
pipe.setup_text_encoder(clip_skip, local_te) | |
return pipe | |
def error_str(error, title="Error"): | |
return ( | |
f"""#### {title} | |
{error}""" | |
if error | |
else "" | |
) | |
def make_token_names(embs): | |
all_tokens = [] | |
for name, vec in embs.items(): | |
tokens = [f'emb-{name}-{i}' for i in range(len(vec))] | |
all_tokens.append(tokens) | |
return all_tokens | |
def setup_tokenizer(tokenizer, embs): | |
reg_match = [re.compile(fr"(?:^|(?<=\s|,)){k}(?=,|\s|$)") for k in embs.keys()] | |
clip_keywords = [' '.join(s) for s in make_token_names(embs)] | |
def parse_prompt(prompt: str): | |
for m, v in zip(reg_match, clip_keywords): | |
prompt = m.sub(v, prompt) | |
return prompt | |
def prepare_for_tokenization(self, text: str, is_split_into_words: bool = False, **kwargs): | |
text = parse_prompt(text) | |
r = original_prepare_for_tokenization(text, is_split_into_words, **kwargs) | |
return r | |
tokenizer.prepare_for_tokenization = prepare_for_tokenization.__get__(tokenizer, CLIPTokenizer) | |
return [t for sublist in make_token_names(embs) for t in sublist] | |
def convert_size(size_bytes): | |
if size_bytes == 0: | |
return "0B" | |
size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB") | |
i = int(math.floor(math.log(size_bytes, 1024))) | |
p = math.pow(1024, i) | |
s = round(size_bytes / p, 2) | |
return "%s %s" % (s, size_name[i]) | |
def inference( | |
prompt, | |
guidance, | |
steps, | |
width=512, | |
height=512, | |
seed=0, | |
neg_prompt="", | |
state=None, | |
g_strength=0.4, | |
img_input=None, | |
i2i_scale=0.5, | |
hr_enabled=False, | |
hr_method="Latent", | |
hr_scale=1.5, | |
hr_denoise=0.8, | |
sampler="DPM++ 2M Karras", | |
embs=None, | |
model=None, | |
lora_state=None, | |
lora_scale=None, | |
): | |
if seed is None or seed == 0: | |
seed = random.randint(0, 2147483647) | |
pipe = setup_model(model, lora_state, lora_scale) | |
generator = torch.Generator("cuda").manual_seed(int(seed)) | |
start_time = time.time() | |
sampler_name, sampler_opt = None, None | |
for label, funcname, options in samplers_k_diffusion: | |
if label == sampler: | |
sampler_name, sampler_opt = funcname, options | |
tokenizer, text_encoder = pipe.tokenizer, pipe.text_encoder | |
if embs is not None and len(embs) > 0: | |
ti_embs = {} | |
for name, file in embs.items(): | |
if str(file).endswith(".pt"): | |
loaded_learned_embeds = torch.load(file, map_location="cpu") | |
else: | |
loaded_learned_embeds = load_file(file, device="cpu") | |
loaded_learned_embeds = loaded_learned_embeds["string_to_param"]["*"] if "string_to_param" in loaded_learned_embed else loaded_learned_embed | |
ti_embs[name] = loaded_learned_embeds | |
if len(ti_embs) > 0: | |
tokens = setup_tokenizer(tokenizer, ti_embs) | |
added_tokens = tokenizer.add_tokens(tokens) | |
delta_weight = torch.cat([val for val in ti_embs.values()], dim=0) | |
assert added_tokens == delta_weight.shape[0] | |
text_encoder.resize_token_embeddings(len(tokenizer)) | |
token_embeds = text_encoder.get_input_embeddings().weight.data | |
token_embeds[-delta_weight.shape[0]:] = delta_weight | |
config = { | |
"negative_prompt": neg_prompt, | |
"num_inference_steps": int(steps), | |
"guidance_scale": guidance, | |
"generator": generator, | |
"sampler_name": sampler_name, | |
"sampler_opt": sampler_opt, | |
"pww_state": state, | |
"pww_attn_weight": g_strength, | |
"start_time": start_time, | |
"timeout": timeout, | |
} | |
if img_input is not None: | |
ratio = min(height / img_input.height, width / img_input.width) | |
img_input = img_input.resize( | |
(int(img_input.width * ratio), int(img_input.height * ratio)), Image.LANCZOS | |
) | |
result = pipe.img2img(prompt, image=img_input, strength=i2i_scale, **config) | |
elif hr_enabled: | |
result = pipe.txt2img( | |
prompt, | |
width=width, | |
height=height, | |
upscale=True, | |
upscale_x=hr_scale, | |
upscale_denoising_strength=hr_denoise, | |
**config, | |
**latent_upscale_modes[hr_method], | |
) | |
else: | |
result = pipe.txt2img(prompt, width=width, height=height, **config) | |
end_time = time.time() | |
vram_free, vram_total = torch.cuda.mem_get_info() | |
print(f"done: model={model}, res={width}x{height}, step={steps}, time={round(end_time-start_time, 2)}s, vram_alloc={convert_size(vram_total-vram_free)}/{convert_size(vram_total)}") | |
return gr.Image.update(result[0][0], label=f"Initial Seed: {seed}") | |
color_list = [] | |
def get_color(n): | |
for _ in range(n - len(color_list)): | |
color_list.append(tuple(np.random.random(size=3) * 256)) | |
return color_list | |
def create_mixed_img(current, state, w=512, h=512): | |
w, h = int(w), int(h) | |
image_np = np.full([h, w, 4], 255) | |
if state is None: | |
state = {} | |
colors = get_color(len(state)) | |
idx = 0 | |
for key, item in state.items(): | |
if item["map"] is not None: | |
m = item["map"] < 255 | |
alpha = 150 | |
if current == key: | |
alpha = 200 | |
image_np[m] = colors[idx] + (alpha,) | |
idx += 1 | |
return image_np | |
# width.change(apply_new_res, inputs=[width, height, global_stats], outputs=[global_stats, sp, rendered]) | |
def apply_new_res(w, h, state): | |
w, h = int(w), int(h) | |
for key, item in state.items(): | |
if item["map"] is not None: | |
item["map"] = resize(item["map"], w, h) | |
update_img = gr.Image.update(value=create_mixed_img("", state, w, h)) | |
return state, update_img | |
def detect_text(text, state, width, height): | |
if text is None or text == "": | |
return None, None, gr.Radio.update(value=None), None | |
t = text.split(",") | |
new_state = {} | |
for item in t: | |
item = item.strip() | |
if item == "": | |
continue | |
if state is not None and item in state: | |
new_state[item] = { | |
"map": state[item]["map"], | |
"weight": state[item]["weight"], | |
"mask_outsides": state[item]["mask_outsides"], | |
} | |
else: | |
new_state[item] = { | |
"map": None, | |
"weight": 0.5, | |
"mask_outsides": False | |
} | |
update = gr.Radio.update(choices=[key for key in new_state.keys()], value=None) | |
update_img = gr.update(value=create_mixed_img("", new_state, width, height)) | |
update_sketch = gr.update(value=None, interactive=False) | |
return new_state, update_sketch, update, update_img | |
def resize(img, w, h): | |
trs = transforms.Compose( | |
[ | |
transforms.ToPILImage(), | |
transforms.Resize(min(h, w)), | |
transforms.CenterCrop((h, w)), | |
] | |
) | |
result = np.array(trs(img), dtype=np.uint8) | |
return result | |
def switch_canvas(entry, state, width, height): | |
if entry == None: | |
return None, 0.5, False, create_mixed_img("", state, width, height) | |
return ( | |
gr.update(value=None, interactive=True), | |
gr.update(value=state[entry]["weight"] if entry in state else 0.5), | |
gr.update(value=state[entry]["mask_outsides"] if entry in state else False), | |
create_mixed_img(entry, state, width, height), | |
) | |
def apply_canvas(selected, draw, state, w, h): | |
if selected in state: | |
w, h = int(w), int(h) | |
state[selected]["map"] = resize(draw, w, h) | |
return state, gr.Image.update(value=create_mixed_img(selected, state, w, h)) | |
def apply_weight(selected, weight, state): | |
if selected in state: | |
state[selected]["weight"] = weight | |
return state | |
def apply_option(selected, mask, state): | |
if selected in state: | |
state[selected]["mask_outsides"] = mask | |
return state | |
# sp2, radio, width, height, global_stats | |
def apply_image(image, selected, w, h, strgength, mask, state): | |
if selected in state: | |
state[selected] = { | |
"map": resize(image, w, h), | |
"weight": strgength, | |
"mask_outsides": mask | |
} | |
return state, gr.Image.update(value=create_mixed_img(selected, state, w, h)) | |
# [ti_state, lora_state, ti_vals, lora_vals, uploads] | |
def add_net(files, ti_state, lora_state): | |
if files is None: | |
return ti_state, "", lora_state, None | |
for file in files: | |
item = Path(file.name) | |
stripedname = str(item.stem).strip() | |
if item.suffix == ".pt": | |
state_dict = torch.load(file.name, map_location="cpu") | |
else: | |
state_dict = load_file(file.name, device="cpu") | |
if any("lora" in k for k in state_dict.keys()): | |
lora_state = file.name | |
else: | |
ti_state[stripedname] = file.name | |
return ( | |
ti_state, | |
lora_state, | |
gr.Text.update(f"{[key for key in ti_state.keys()]}"), | |
gr.Text.update(f"{lora_state}"), | |
gr.Files.update(value=None), | |
) | |
# [ti_state, lora_state, ti_vals, lora_vals, uploads] | |
def clean_states(ti_state, lora_state): | |
return ( | |
dict(), | |
None, | |
gr.Text.update(f""), | |
gr.Text.update(f""), | |
gr.File.update(value=None), | |
) | |
latent_upscale_modes = { | |
"Latent": {"upscale_method": "bilinear", "upscale_antialias": False}, | |
"Latent (antialiased)": {"upscale_method": "bilinear", "upscale_antialias": True}, | |
"Latent (bicubic)": {"upscale_method": "bicubic", "upscale_antialias": False}, | |
"Latent (bicubic antialiased)": { | |
"upscale_method": "bicubic", | |
"upscale_antialias": True, | |
}, | |
"Latent (nearest)": {"upscale_method": "nearest", "upscale_antialias": False}, | |
"Latent (nearest-exact)": { | |
"upscale_method": "nearest-exact", | |
"upscale_antialias": False, | |
}, | |
} | |
css = """ | |
.finetuned-diffusion-div div{ | |
display:inline-flex; | |
align-items:center; | |
gap:.8rem; | |
font-size:1.75rem; | |
padding-top:2rem; | |
} | |
.finetuned-diffusion-div div h1{ | |
font-weight:900; | |
margin-bottom:7px | |
} | |
.finetuned-diffusion-div p{ | |
margin-bottom:10px; | |
font-size:94% | |
} | |
.box { | |
float: left; | |
height: 20px; | |
width: 20px; | |
margin-bottom: 15px; | |
border: 1px solid black; | |
clear: both; | |
} | |
a{ | |
text-decoration:underline | |
} | |
.tabs{ | |
margin-top:0; | |
margin-bottom:0 | |
} | |
#gallery{ | |
min-height:20rem | |
} | |
.no-border { | |
border: none !important; | |
} | |
""" | |
with gr.Blocks(css=css) as demo: | |
gr.HTML( | |
f""" | |
<div class="finetuned-diffusion-div"> | |
<div> | |
<h1>Demo for diffusion models</h1> | |
</div> | |
<p>Hso @ nyanko.sketch2img.gradio</p> | |
</div> | |
""" | |
) | |
global_stats = gr.State(value={}) | |
with gr.Row(): | |
with gr.Column(scale=55): | |
model = gr.Dropdown( | |
choices=[k[0] for k in get_model_list()], | |
label="Model", | |
value=base_name, | |
) | |
image_out = gr.Image(height=512) | |
# gallery = gr.Gallery( | |
# label="Generated images", show_label=False, elem_id="gallery" | |
# ).style(grid=[1], height="auto") | |
with gr.Column(scale=45): | |
with gr.Group(): | |
with gr.Row(): | |
with gr.Column(scale=70): | |
prompt = gr.Textbox( | |
label="Prompt", | |
value="best quality, masterpiece, highres, an extremely delicate and beautiful, original, extremely detailed wallpaper, highres , 1girl", | |
show_label=True, | |
max_lines=4, | |
placeholder="Enter prompt.", | |
) | |
neg_prompt = gr.Textbox( | |
label="Negative Prompt", | |
value="simple background,monochrome ,lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits,twisting jawline, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, lowres, bad anatomy, bad hands, text, error, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, ugly,pregnant,vore,duplicate,morbid,mut ilated,tran nsexual, hermaphrodite,long neck,mutated hands,poorly drawn hands,poorly drawn face,mutation,deformed,blurry,bad anatomy,bad proportions,malformed limbs,extra limbs,cloned face,disfigured,gross proportions, missing arms, missing legs, extra arms,extra legs,pubic hair, plump,bad legs,error legs,username,blurry,bad feet", | |
show_label=True, | |
max_lines=4, | |
placeholder="Enter negative prompt.", | |
) | |
generate = gr.Button(value="Generate").style( | |
rounded=(False, True, True, False) | |
) | |
with gr.Tab("Options"): | |
with gr.Group(): | |
# n_images = gr.Slider(label="Images", value=1, minimum=1, maximum=4, step=1) | |
with gr.Row(): | |
guidance = gr.Slider( | |
label="Guidance scale", value=7.5, maximum=15 | |
) | |
steps = gr.Slider( | |
label="Steps", value=25, minimum=2, maximum=50, step=1 | |
) | |
with gr.Row(): | |
width = gr.Slider( | |
label="Width", value=512, minimum=64, maximum=1024, step=64 | |
) | |
height = gr.Slider( | |
label="Height", value=512, minimum=64, maximum=1024, step=64 | |
) | |
sampler = gr.Dropdown( | |
value="DPM++ 2M Karras", | |
label="Sampler", | |
choices=[s[0] for s in samplers_k_diffusion], | |
) | |
seed = gr.Number(label="Seed (0 = random)", value=0) | |
with gr.Tab("Image to image"): | |
with gr.Group(): | |
inf_image = gr.Image( | |
label="Image", height=256, tool="editor", type="pil" | |
) | |
inf_strength = gr.Slider( | |
label="Transformation strength", | |
minimum=0, | |
maximum=1, | |
step=0.01, | |
value=0.5, | |
) | |
def res_cap(g, w, h, x): | |
if g: | |
return f"Enable upscaler: {w}x{h} to {int(w*x)}x{int(h*x)}" | |
else: | |
return "Enable upscaler" | |
with gr.Tab("Hires fix"): | |
with gr.Group(): | |
hr_enabled = gr.Checkbox(label="Enable upscaler", value=False) | |
hr_method = gr.Dropdown( | |
[key for key in latent_upscale_modes.keys()], | |
value="Latent", | |
label="Upscale method", | |
) | |
hr_scale = gr.Slider( | |
label="Upscale factor", | |
minimum=1.0, | |
maximum=2.0, | |
step=0.1, | |
value=1.5, | |
) | |
hr_denoise = gr.Slider( | |
label="Denoising strength", | |
minimum=0.0, | |
maximum=1.0, | |
step=0.1, | |
value=0.8, | |
) | |
hr_scale.change( | |
lambda g, x, w, h: gr.Checkbox.update( | |
label=res_cap(g, w, h, x) | |
), | |
inputs=[hr_enabled, hr_scale, width, height], | |
outputs=hr_enabled, | |
queue=False, | |
) | |
hr_enabled.change( | |
lambda g, x, w, h: gr.Checkbox.update( | |
label=res_cap(g, w, h, x) | |
), | |
inputs=[hr_enabled, hr_scale, width, height], | |
outputs=hr_enabled, | |
queue=False, | |
) | |
with gr.Tab("Embeddings/Loras"): | |
ti_state = gr.State(dict()) | |
lora_state = gr.State() | |
with gr.Group(): | |
with gr.Row(): | |
with gr.Column(scale=90): | |
ti_vals = gr.Text(label="Loaded embeddings") | |
with gr.Row(): | |
with gr.Column(scale=90): | |
lora_vals = gr.Text(label="Loaded loras") | |
with gr.Row(): | |
uploads = gr.Files(label="Upload new embeddings/lora") | |
with gr.Column(): | |
lora_scale = gr.Slider( | |
label="Lora scale", | |
minimum=0, | |
maximum=2, | |
step=0.01, | |
value=1.0, | |
) | |
btn = gr.Button(value="Upload") | |
btn_del = gr.Button(value="Reset") | |
btn.click( | |
add_net, | |
inputs=[uploads, ti_state, lora_state], | |
outputs=[ti_state, lora_state, ti_vals, lora_vals, uploads], | |
queue=False, | |
) | |
btn_del.click( | |
clean_states, | |
inputs=[ti_state, lora_state], | |
outputs=[ti_state, lora_state, ti_vals, lora_vals, uploads], | |
queue=False, | |
) | |
# error_output = gr.Markdown() | |
gr.HTML( | |
f""" | |
<div class="finetuned-diffusion-div"> | |
<div> | |
<h1>Paint with words</h1> | |
</div> | |
<p> | |
Will use the following formula: w = scale * token_weight_martix * log(1 + sigma) * max(qk). | |
</p> | |
</div> | |
""" | |
) | |
with gr.Row(): | |
with gr.Column(scale=55): | |
rendered = gr.Image( | |
invert_colors=True, | |
source="canvas", | |
interactive=False, | |
image_mode="RGBA", | |
) | |
with gr.Column(scale=45): | |
with gr.Group(): | |
with gr.Row(): | |
with gr.Column(scale=70): | |
g_strength = gr.Slider( | |
label="Weight scaling", | |
minimum=0, | |
maximum=0.8, | |
step=0.01, | |
value=0.4, | |
) | |
text = gr.Textbox( | |
lines=2, | |
interactive=True, | |
label="Token to Draw: (Separate by comma)", | |
) | |
radio = gr.Radio([], label="Tokens") | |
sk_update = gr.Button(value="Update").style( | |
rounded=(False, True, True, False) | |
) | |
# g_strength.change(lambda b: gr.update(f"Scaled additional attn: $w = {b} \log (1 + \sigma) \std (Q^T K)$."), inputs=g_strength, outputs=[g_output]) | |
with gr.Tab("SketchPad"): | |
sp = gr.Image( | |
image_mode="L", | |
tool="sketch", | |
source="canvas", | |
interactive=False, | |
) | |
mask_outsides = gr.Checkbox( | |
label="Mask other areas", | |
value=False | |
) | |
strength = gr.Slider( | |
label="Token strength", | |
minimum=0, | |
maximum=0.8, | |
step=0.01, | |
value=0.5, | |
) | |
sk_update.click( | |
detect_text, | |
inputs=[text, global_stats, width, height], | |
outputs=[global_stats, sp, radio, rendered], | |
queue=False, | |
) | |
radio.change( | |
switch_canvas, | |
inputs=[radio, global_stats, width, height], | |
outputs=[sp, strength, mask_outsides, rendered], | |
queue=False, | |
) | |
sp.edit( | |
apply_canvas, | |
inputs=[radio, sp, global_stats, width, height], | |
outputs=[global_stats, rendered], | |
queue=False, | |
) | |
strength.change( | |
apply_weight, | |
inputs=[radio, strength, global_stats], | |
outputs=[global_stats], | |
queue=False, | |
) | |
mask_outsides.change( | |
apply_option, | |
inputs=[radio, mask_outsides, global_stats], | |
outputs=[global_stats], | |
queue=False, | |
) | |
with gr.Tab("UploadFile"): | |
sp2 = gr.Image( | |
image_mode="L", | |
source="upload", | |
shape=(512, 512), | |
) | |
mask_outsides2 = gr.Checkbox( | |
label="Mask other areas", | |
value=False, | |
) | |
strength2 = gr.Slider( | |
label="Token strength", | |
minimum=0, | |
maximum=0.8, | |
step=0.01, | |
value=0.5, | |
) | |
apply_style = gr.Button(value="Apply") | |
apply_style.click( | |
apply_image, | |
inputs=[sp2, radio, width, height, strength2, mask_outsides2, global_stats], | |
outputs=[global_stats, rendered], | |
queue=False, | |
) | |
width.change( | |
apply_new_res, | |
inputs=[width, height, global_stats], | |
outputs=[global_stats, rendered], | |
queue=False, | |
) | |
height.change( | |
apply_new_res, | |
inputs=[width, height, global_stats], | |
outputs=[global_stats, rendered], | |
queue=False, | |
) | |
# color_stats = gr.State(value={}) | |
# text.change(detect_color, inputs=[sp, text, color_stats], outputs=[color_stats, rendered]) | |
# sp.change(detect_color, inputs=[sp, text, color_stats], outputs=[color_stats, rendered]) | |
inputs = [ | |
prompt, | |
guidance, | |
steps, | |
width, | |
height, | |
seed, | |
neg_prompt, | |
global_stats, | |
g_strength, | |
inf_image, | |
inf_strength, | |
hr_enabled, | |
hr_method, | |
hr_scale, | |
hr_denoise, | |
sampler, | |
ti_state, | |
model, | |
lora_state, | |
lora_scale, | |
] | |
outputs = [image_out] | |
prompt.submit(inference, inputs=inputs, outputs=outputs) | |
generate.click(inference, inputs=inputs, outputs=outputs) | |
print(f"Space built in {time.time() - start_time:.2f} seconds") | |
# demo.launch(share=True) | |
demo.launch(enable_queue=True, server_name="0.0.0.0", server_port=7860) |