MiaLiu222 committed
Commit a7addcb
1 Parent(s): 0157bcd

Create app.py

Files changed (1)
  1. app.py +4 -1025
app.py CHANGED
@@ -1,1028 +1,7 @@
-import io
-import base64
-import os
-import sys
-
-import numpy as np
-import torch
-from torch import autocast
-import diffusers
-from diffusers.configuration_utils import FrozenDict
-from diffusers import (
-    StableDiffusionPipeline,
-    StableDiffusionInpaintPipeline,
-    StableDiffusionImg2ImgPipeline,
-    StableDiffusionInpaintPipelineLegacy,
-    DDIMScheduler,
-    LMSDiscreteScheduler,
-)
-from PIL import Image
-from PIL import ImageOps
 import gradio as gr
-import base64
-import skimage
-import skimage.measure
-import yaml
-import json
-from enum import Enum
-
-try:
-    abspath = os.path.abspath(__file__)
-    dirname = os.path.dirname(abspath)
-    os.chdir(dirname)
-except:
-    pass
-
-from utils import *
-
-assert diffusers.__version__ >= "0.6.0", "Please upgrade diffusers to 0.6.0"
-
-USE_NEW_DIFFUSERS = True
-RUN_IN_SPACE = "RUN_IN_HG_SPACE" in os.environ
-
-
-class ModelChoice(Enum):
-    INPAINTING = "stablediffusion-inpainting"
-    INPAINTING_IMG2IMG = "stablediffusion-inpainting+img2img-v1.5"
-    MODEL_1_5 = "stablediffusion-v1.5"
-    MODEL_1_4 = "stablediffusion-v1.4"
-
-
-try:
-    from sd_grpcserver.pipeline.unified_pipeline import UnifiedPipeline
-except:
-    UnifiedPipeline = StableDiffusionInpaintPipeline
-
-# sys.path.append("./glid_3_xl_stable")
-
-USE_GLID = False
-# try:
-#     from glid3xlmodel import GlidModel
-# except:
-#     USE_GLID = False
-
-try:
-    cuda_available = torch.cuda.is_available()
-except:
-    cuda_available = False
-finally:
-    if sys.platform == "darwin":
-        device = "mps" if torch.backends.mps.is_available() else "cpu"
-    elif cuda_available:
-        device = "cuda"
-    else:
-        device = "cpu"
-
-if device != "cuda":
-    import contextlib
-
-    autocast = contextlib.nullcontext
-
-with open("config.yaml", "r") as yaml_in:
-    yaml_object = yaml.safe_load(yaml_in)
-    config_json = json.dumps(yaml_object)
-
-
-def load_html():
-    body, canvaspy = "", ""
-    with open("index.html", encoding="utf8") as f:
-        body = f.read()
-    with open("canvas.py", encoding="utf8") as f:
-        canvaspy = f.read()
-    body = body.replace("- paths:\n", "")
-    body = body.replace(" - ./canvas.py\n", "")
-    body = body.replace("from canvas import InfCanvas", canvaspy)
-    return body
-
-
-def test(x):
-    x = load_html()
-    return f"""<iframe id="sdinfframe" style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
-    display-capture; encrypted-media; vertical-scroll 'none'" sandbox="allow-modals allow-forms
-    allow-scripts allow-same-origin allow-popups
-    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
-    allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
-
-
-DEBUG_MODE = False
-
-try:
-    SAMPLING_MODE = Image.Resampling.LANCZOS
-except Exception as e:
-    SAMPLING_MODE = Image.LANCZOS
-
-try:
-    contain_func = ImageOps.contain
-except Exception as e:
-
-    def contain_func(image, size, method=SAMPLING_MODE):
-        # from PIL: https://pillow.readthedocs.io/en/stable/reference/ImageOps.html#PIL.ImageOps.contain
-        im_ratio = image.width / image.height
-        dest_ratio = size[0] / size[1]
-        if im_ratio != dest_ratio:
-            if im_ratio > dest_ratio:
-                new_height = int(image.height / image.width * size[0])
-                if new_height != size[1]:
-                    size = (size[0], new_height)
-            else:
-                new_width = int(image.width / image.height * size[1])
-                if new_width != size[0]:
-                    size = (new_width, size[1])
-        return image.resize(size, resample=method)
-
-
-import argparse
-
-parser = argparse.ArgumentParser(description="stablediffusion-infinity")
-parser.add_argument("--port", type=int, help="listen port", dest="server_port")
-parser.add_argument("--host", type=str, help="host", dest="server_name")
-parser.add_argument("--share", action="store_true", help="share this app?")
-parser.add_argument("--debug", action="store_true", help="debug mode")
-parser.add_argument("--fp32", action="store_true", help="using full precision")
-parser.add_argument("--encrypt", action="store_true", help="using https?")
-parser.add_argument("--ssl_keyfile", type=str, help="path to ssl_keyfile")
-parser.add_argument("--ssl_certfile", type=str, help="path to ssl_certfile")
-parser.add_argument("--ssl_keyfile_password", type=str, help="ssl_keyfile_password")
-parser.add_argument(
-    "--auth", nargs=2, metavar=("username", "password"), help="use username password"
-)
-parser.add_argument(
-    "--remote_model",
-    type=str,
-    help="use a model (e.g. dreambooth fined) from huggingface hub",
-    default="",
-)
-parser.add_argument(
-    "--local_model", type=str, help="use a model stored on your PC", default=""
-)
-
-if __name__ == "__main__":
-    args = parser.parse_args()
-else:
-    args = parser.parse_args(["--debug"])
-# args = parser.parse_args(["--debug"])
-if args.auth is not None:
-    args.auth = tuple(args.auth)
-
-model = {}
-
-
-def get_token():
-    token = ""
-    if os.path.exists(".token"):
-        with open(".token", "r") as f:
-            token = f.read()
-    token = os.environ.get("hftoken", token)
-    return token
-
-
-def save_token(token):
-    with open(".token", "w") as f:
-        f.write(token)
-
-
-def prepare_scheduler(scheduler):
-    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
-        new_config = dict(scheduler.config)
-        new_config["steps_offset"] = 1
-        scheduler._internal_dict = FrozenDict(new_config)
-    return scheduler
-
-
-def my_resize(width, height):
-    if width >= 512 and height >= 512:
-        return width, height
-    if width == height:
-        return 512, 512
-    smaller = min(width, height)
-    larger = max(width, height)
-    if larger >= 608:
-        return width, height
-    factor = 1
-    if smaller < 290:
-        factor = 2
-    elif smaller < 330:
-        factor = 1.75
-    elif smaller < 384:
-        factor = 1.375
-    elif smaller < 400:
-        factor = 1.25
-    elif smaller < 450:
-        factor = 1.125
-    return int(factor * width)//8*8, int(factor * height)//8*8
-
-
-def load_learned_embed_in_clip(
-    learned_embeds_path, text_encoder, tokenizer, token=None
-):
-    # https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb
-    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
-
-    # separate token and the embeds
-    trained_token = list(loaded_learned_embeds.keys())[0]
-    embeds = loaded_learned_embeds[trained_token]
-
-    # cast to dtype of text_encoder
-    dtype = text_encoder.get_input_embeddings().weight.dtype
-    embeds.to(dtype)
-
-    # add the token in tokenizer
-    token = token if token is not None else trained_token
-    num_added_tokens = tokenizer.add_tokens(token)
-    if num_added_tokens == 0:
-        raise ValueError(
-            f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
-        )
-
-    # resize the token embeddings
-    text_encoder.resize_token_embeddings(len(tokenizer))
-
-    # get the id for the token and assign the embeds
-    token_id = tokenizer.convert_tokens_to_ids(token)
-    text_encoder.get_input_embeddings().weight.data[token_id] = embeds
-
-
-scheduler_dict = {"PLMS": None, "DDIM": None, "K-LMS": None}
-
-
-class StableDiffusionInpaint:
-    def __init__(
-        self, token: str = "", model_name: str = "", model_path: str = "", **kwargs,
-    ):
-        self.token = token
-        original_checkpoint = False
-        if model_path and os.path.exists(model_path):
-            if model_path.endswith(".ckpt"):
-                original_checkpoint = True
-            elif model_path.endswith(".json"):
-                model_name = os.path.dirname(model_path)
-            else:
-                model_name = model_path
-        if original_checkpoint:
-            print(f"Converting & Loading {model_path}")
-            from convert_checkpoint import convert_checkpoint
-
-            pipe = convert_checkpoint(model_path, inpainting=True)
-            if device == "cuda" and not args.fp32:
-                pipe.to(torch.float16)
-            inpaint = StableDiffusionInpaintPipeline(
-                vae=pipe.vae,
-                text_encoder=pipe.text_encoder,
-                tokenizer=pipe.tokenizer,
-                unet=pipe.unet,
-                scheduler=pipe.scheduler,
-                safety_checker=pipe.safety_checker,
-                feature_extractor=pipe.feature_extractor,
-            )
-        else:
-            print(f"Loading {model_name}")
-            if device == "cuda" and not args.fp32:
-                inpaint = StableDiffusionInpaintPipeline.from_pretrained(
-                    model_name,
-                    revision="fp16",
-                    torch_dtype=torch.float16,
-                    use_auth_token=token,
-                )
-            else:
-                inpaint = StableDiffusionInpaintPipeline.from_pretrained(
-                    model_name, use_auth_token=token,
-                )
-        if os.path.exists("./embeddings"):
-            print("Note that StableDiffusionInpaintPipeline + embeddings is untested")
-            for item in os.listdir("./embeddings"):
-                if item.endswith(".bin"):
-                    load_learned_embed_in_clip(
-                        os.path.join("./embeddings", item),
-                        inpaint.text_encoder,
-                        inpaint.tokenizer,
-                    )
-        inpaint.to(device)
-        # if device == "mps":
-        #     _ = text2img("", num_inference_steps=1)
-        scheduler_dict["PLMS"] = inpaint.scheduler
-        scheduler_dict["DDIM"] = prepare_scheduler(
-            DDIMScheduler(
-                beta_start=0.00085,
-                beta_end=0.012,
-                beta_schedule="scaled_linear",
-                clip_sample=False,
-                set_alpha_to_one=False,
-            )
-        )
-        scheduler_dict["K-LMS"] = prepare_scheduler(
-            LMSDiscreteScheduler(
-                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-            )
-        )
-        self.safety_checker = inpaint.safety_checker
-        save_token(token)
-        try:
-            total_memory = torch.cuda.get_device_properties(0).total_memory // (
-                1024 ** 3
-            )
-            if total_memory <= 5:
-                inpaint.enable_attention_slicing()
-        except:
-            pass
-        self.inpaint = inpaint
-
-    def run(
-        self,
-        image_pil,
-        prompt="",
-        negative_prompt="",
-        guidance_scale=7.5,
-        resize_check=True,
-        enable_safety=True,
-        fill_mode="patchmatch",
-        strength=0.75,
-        step=50,
-        enable_img2img=False,
-        use_seed=False,
-        seed_val=-1,
-        generate_num=1,
-        scheduler="",
-        scheduler_eta=0.0,
-        **kwargs,
-    ):
-        inpaint = self.inpaint
-        selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"])
-        for item in [inpaint]:
-            item.scheduler = selected_scheduler
-            if enable_safety:
-                item.safety_checker = self.safety_checker
-            else:
-                item.safety_checker = lambda images, **kwargs: (images, False)
-        width, height = image_pil.size
-        sel_buffer = np.array(image_pil)
-        img = sel_buffer[:, :, 0:3]
-        mask = sel_buffer[:, :, -1]
-        nmask = 255 - mask
-        process_width = width
-        process_height = height
-        if resize_check:
-            process_width, process_height = my_resize(width, height)
-        extra_kwargs = {
-            "num_inference_steps": step,
-            "guidance_scale": guidance_scale,
-            "eta": scheduler_eta,
-        }
-        if USE_NEW_DIFFUSERS:
-            extra_kwargs["negative_prompt"] = negative_prompt
-            extra_kwargs["num_images_per_prompt"] = generate_num
-        if use_seed:
-            generator = torch.Generator(inpaint.device).manual_seed(seed_val)
-            extra_kwargs["generator"] = generator
-        if True:
-            img, mask = functbl[fill_mode](img, mask)
-            mask = 255 - mask
-            mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
-            mask = mask.repeat(8, axis=0).repeat(8, axis=1)
-            extra_kwargs["strength"] = strength
-            inpaint_func = inpaint
-            init_image = Image.fromarray(img)
-            mask_image = Image.fromarray(mask)
-            # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
-            with autocast("cuda"):
-                images = inpaint_func(
-                    prompt=prompt,
-                    image=init_image.resize(
-                        (process_width, process_height), resample=SAMPLING_MODE
-                    ),
-                    mask_image=mask_image.resize((process_width, process_height)),
-                    width=process_width,
-                    height=process_height,
-                    **extra_kwargs,
-                )["images"]
-        return images
-
-
-class StableDiffusion:
-    def __init__(
-        self,
-        token: str = "",
-        model_name: str = "runwayml/stable-diffusion-v1-5",
-        model_path: str = None,
-        inpainting_model: bool = False,
-        **kwargs,
-    ):
-        self.token = token
-        original_checkpoint = False
-        if model_path and os.path.exists(model_path):
-            if model_path.endswith(".ckpt"):
-                original_checkpoint = True
-            elif model_path.endswith(".json"):
-                model_name = os.path.dirname(model_path)
-            else:
-                model_name = model_path
-        if original_checkpoint:
-            print(f"Converting & Loading {model_path}")
-            from convert_checkpoint import convert_checkpoint
-
-            text2img = convert_checkpoint(model_path)
-            if device == "cuda" and not args.fp32:
-                text2img.to(torch.float16)
-        else:
-            print(f"Loading {model_name}")
-            if device == "cuda" and not args.fp32:
-                text2img = StableDiffusionPipeline.from_pretrained(
-                    "runwayml/stable-diffusion-v1-5",
-                    revision="fp16",
-                    torch_dtype=torch.float16,
-                    use_auth_token=token,
-                )
-            else:
-                text2img = StableDiffusionPipeline.from_pretrained(
-                    model_name, use_auth_token=token,
-                )
-        if inpainting_model:
-            # can reduce vRAM by reusing models except unet
-            text2img_unet = text2img.unet
-            del text2img.vae
-            del text2img.text_encoder
-            del text2img.tokenizer
-            del text2img.scheduler
-            del text2img.safety_checker
-            del text2img.feature_extractor
-            import gc
-
-            gc.collect()
-            if device == "cuda" and not args.fp32:
-                inpaint = StableDiffusionInpaintPipeline.from_pretrained(
-                    "runwayml/stable-diffusion-inpainting",
-                    revision="fp16",
-                    torch_dtype=torch.float16,
-                    use_auth_token=token,
-                ).to(device)
-            else:
-                inpaint = StableDiffusionInpaintPipeline.from_pretrained(
-                    "runwayml/stable-diffusion-inpainting", use_auth_token=token,
-                ).to(device)
-            text2img_unet.to(device)
-            text2img = StableDiffusionPipeline(
-                vae=inpaint.vae,
-                text_encoder=inpaint.text_encoder,
-                tokenizer=inpaint.tokenizer,
-                unet=text2img_unet,
-                scheduler=inpaint.scheduler,
-                safety_checker=inpaint.safety_checker,
-                feature_extractor=inpaint.feature_extractor,
-            )
-        else:
-            inpaint = StableDiffusionInpaintPipelineLegacy(
-                vae=text2img.vae,
-                text_encoder=text2img.text_encoder,
-                tokenizer=text2img.tokenizer,
-                unet=text2img.unet,
-                scheduler=text2img.scheduler,
-                safety_checker=text2img.safety_checker,
-                feature_extractor=text2img.feature_extractor,
-            ).to(device)
-        text_encoder = text2img.text_encoder
-        tokenizer = text2img.tokenizer
-        if os.path.exists("./embeddings"):
-            for item in os.listdir("./embeddings"):
-                if item.endswith(".bin"):
-                    load_learned_embed_in_clip(
-                        os.path.join("./embeddings", item),
-                        text2img.text_encoder,
-                        text2img.tokenizer,
-                    )
-        text2img.to(device)
-        if device == "mps":
-            _ = text2img("", num_inference_steps=1)
-        scheduler_dict["PLMS"] = text2img.scheduler
-        scheduler_dict["DDIM"] = prepare_scheduler(
-            DDIMScheduler(
-                beta_start=0.00085,
-                beta_end=0.012,
-                beta_schedule="scaled_linear",
-                clip_sample=False,
-                set_alpha_to_one=False,
-            )
-        )
-        scheduler_dict["K-LMS"] = prepare_scheduler(
-            LMSDiscreteScheduler(
-                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-            )
-        )
-        self.safety_checker = text2img.safety_checker
-        img2img = StableDiffusionImg2ImgPipeline(
-            vae=text2img.vae,
-            text_encoder=text2img.text_encoder,
-            tokenizer=text2img.tokenizer,
-            unet=text2img.unet,
-            scheduler=text2img.scheduler,
-            safety_checker=text2img.safety_checker,
-            feature_extractor=text2img.feature_extractor,
-        ).to(device)
-        save_token(token)
-        try:
-            total_memory = torch.cuda.get_device_properties(0).total_memory // (
-                1024 ** 3
-            )
-            if total_memory <= 5:
-                inpaint.enable_attention_slicing()
-        except:
-            pass
-        self.text2img = text2img
-        self.inpaint = inpaint
-        self.img2img = img2img
-        self.unified = UnifiedPipeline(
-            vae=text2img.vae,
-            text_encoder=text2img.text_encoder,
-            tokenizer=text2img.tokenizer,
-            unet=text2img.unet,
-            scheduler=text2img.scheduler,
-            safety_checker=text2img.safety_checker,
-            feature_extractor=text2img.feature_extractor,
-        ).to(device)
-        self.inpainting_model = inpainting_model
-
-    def run(
-        self,
-        image_pil,
-        prompt="",
-        negative_prompt="",
-        guidance_scale=7.5,
-        resize_check=True,
-        enable_safety=True,
-        fill_mode="patchmatch",
-        strength=0.75,
-        step=50,
-        enable_img2img=False,
-        use_seed=False,
-        seed_val=-1,
-        generate_num=1,
-        scheduler="",
-        scheduler_eta=0.0,
-        **kwargs,
-    ):
-        text2img, inpaint, img2img, unified = (
-            self.text2img,
-            self.inpaint,
-            self.img2img,
-            self.unified,
-        )
-        selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"])
-        for item in [text2img, inpaint, img2img, unified]:
-            item.scheduler = selected_scheduler
-            if enable_safety:
-                item.safety_checker = self.safety_checker
-            else:
-                item.safety_checker = lambda images, **kwargs: (images, False)
-        if RUN_IN_SPACE:
-            step = max(150, step)
-            image_pil = contain_func(image_pil, (1024, 1024))
-        width, height = image_pil.size
-        sel_buffer = np.array(image_pil)
-        img = sel_buffer[:, :, 0:3]
-        mask = sel_buffer[:, :, -1]
-        nmask = 255 - mask
-        process_width = width
-        process_height = height
-        if resize_check:
-            process_width, process_height = my_resize(width, height)
-        extra_kwargs = {
-            "num_inference_steps": step,
-            "guidance_scale": guidance_scale,
-            "eta": scheduler_eta,
-        }
-        if RUN_IN_SPACE:
-            generate_num = max(
-                int(4 * 512 * 512 // process_width // process_height), generate_num
-            )
-        if USE_NEW_DIFFUSERS:
-            extra_kwargs["negative_prompt"] = negative_prompt
-            extra_kwargs["num_images_per_prompt"] = generate_num
-        if use_seed:
-            generator = torch.Generator(text2img.device).manual_seed(seed_val)
-            extra_kwargs["generator"] = generator
-        if nmask.sum() < 1 and enable_img2img:
-            init_image = Image.fromarray(img)
-            with autocast("cuda"):
-                images = img2img(
-                    prompt=prompt,
-                    init_image=init_image.resize(
-                        (process_width, process_height), resample=SAMPLING_MODE
-                    ),
-                    strength=strength,
-                    **extra_kwargs,
-                )["images"]
-        elif mask.sum() > 0:
-            if fill_mode == "g_diffuser" and not self.inpainting_model:
-                mask = 255 - mask
-                mask = mask[:, :, np.newaxis].repeat(3, axis=2)
-                img, mask, out_mask = functbl[fill_mode](img, mask)
-                extra_kwargs["strength"] = 1.0
-                extra_kwargs["out_mask"] = Image.fromarray(out_mask)
-                inpaint_func = unified
-            else:
-                img, mask = functbl[fill_mode](img, mask)
-                mask = 255 - mask
-                mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
-                mask = mask.repeat(8, axis=0).repeat(8, axis=1)
-                extra_kwargs["strength"] = strength
-                inpaint_func = inpaint
-            init_image = Image.fromarray(img)
-            mask_image = Image.fromarray(mask)
-            # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
-            with autocast("cuda"):
-                input_image = init_image.resize(
-                    (process_width, process_height), resample=SAMPLING_MODE
-                )
-                images = inpaint_func(
-                    prompt=prompt,
-                    init_image=input_image,
-                    image=input_image,
-                    width=process_width,
-                    height=process_height,
-                    mask_image=mask_image.resize((process_width, process_height)),
-                    **extra_kwargs,
-                )["images"]
-        else:
-            with autocast("cuda"):
-                images = text2img(
-                    prompt=prompt,
-                    height=process_width,
-                    width=process_height,
-                    **extra_kwargs,
-                )["images"]
-        return images
-
-
-def get_model(token="", model_choice="", model_path=""):
-    if "model" not in model:
-        model_name = ""
-        if model_choice == ModelChoice.INPAINTING.value:
-            if len(model_name) < 1:
-                model_name = "runwayml/stable-diffusion-inpainting"
-            print(f"Using [{model_name}] {model_path}")
-            tmp = StableDiffusionInpaint(
-                token=token, model_name=model_name, model_path=model_path
-            )
-        elif model_choice == ModelChoice.INPAINTING_IMG2IMG.value:
-            print(
-                f"Note that {ModelChoice.INPAINTING_IMG2IMG.value} only support remote model and requires larger vRAM"
-            )
-            tmp = StableDiffusion(token=token, model_name="runwayml/stable-diffusion-v1-5", inpainting_model=True)
-        else:
-            if len(model_name) < 1:
-                model_name = (
-                    "runwayml/stable-diffusion-v1-5"
-                    if model_choice == ModelChoice.MODEL_1_5.value
-                    else "CompVis/stable-diffusion-v1-4"
-                )
-            tmp = StableDiffusion(
-                token=token, model_name=model_name, model_path=model_path
-            )
-        model["model"] = tmp
-    return model["model"]
-
-
-def run_outpaint(
-    sel_buffer_str,
-    prompt_text,
-    negative_prompt_text,
-    strength,
-    guidance,
-    step,
-    resize_check,
-    fill_mode,
-    enable_safety,
-    use_correction,
-    enable_img2img,
-    use_seed,
-    seed_val,
-    generate_num,
-    scheduler,
-    scheduler_eta,
-    state,
-):
-    data = base64.b64decode(str(sel_buffer_str))
-    pil = Image.open(io.BytesIO(data))
-    width, height = pil.size
-    sel_buffer = np.array(pil)
-    cur_model = get_model()
-    images = cur_model.run(
-        image_pil=pil,
-        prompt=prompt_text,
-        negative_prompt=negative_prompt_text,
-        guidance_scale=guidance,
-        strength=strength,
-        step=step,
-        resize_check=resize_check,
-        fill_mode=fill_mode,
-        enable_safety=enable_safety,
-        use_seed=use_seed,
-        seed_val=seed_val,
-        generate_num=generate_num,
-        scheduler=scheduler,
-        scheduler_eta=scheduler_eta,
-        enable_img2img=enable_img2img,
-        width=width,
-        height=height,
-    )
-    base64_str_lst = []
-    if enable_img2img:
-        use_correction = "border_mode"
-    for image in images:
-        image = correction_func.run(pil.resize(image.size), image, mode=use_correction)
-        resized_img = image.resize((width, height), resample=SAMPLING_MODE,)
-        out = sel_buffer.copy()
-        out[:, :, 0:3] = np.array(resized_img)
-        out[:, :, -1] = 255
-        out_pil = Image.fromarray(out)
-        out_buffer = io.BytesIO()
-        out_pil.save(out_buffer, format="PNG")
-        out_buffer.seek(0)
-        base64_bytes = base64.b64encode(out_buffer.read())
-        base64_str = base64_bytes.decode("ascii")
-        base64_str_lst.append(base64_str)
-    return (
-        gr.update(label=str(state + 1), value=",".join(base64_str_lst),),
-        gr.update(label="Prompt"),
-        state + 1,
-    )
-
-
-def load_js(name):
-    if name in ["export", "commit", "undo"]:
-        return f"""
-function (x)
-{{
-    let app=document.querySelector("gradio-app");
-    app=app.shadowRoot??app;
-    let frame=app.querySelector("#sdinfframe").contentWindow.document;
-    let button=frame.querySelector("#{name}");
-    button.click();
-    return x;
-}}
-"""
-    ret = ""
-    with open(f"./js/{name}.js", "r") as f:
-        ret = f.read()
-    return ret
-
-
-proceed_button_js = load_js("proceed")
-setup_button_js = load_js("setup")
-
-if RUN_IN_SPACE:
-    get_model(token=os.environ.get("hftoken", ""), model_choice=ModelChoice.INPAINTING_IMG2IMG.value)
-
-blocks = gr.Blocks(
-    title="StableDiffusion-Infinity",
-    css="""
-.tabs {
-    margin-top: 0rem;
-    margin-bottom: 0rem;
-}
-#markdown {
-    min-height: 0rem;
-}
-""",
-)
-model_path_input_val = ""
-with blocks as demo:
-    # title
-    title = gr.Markdown(
-        """
-**stablediffusion-infinity**: Outpainting with Stable Diffusion on an infinite canvas: [https://github.com/lkwq007/stablediffusion-infinity](https://github.com/lkwq007/stablediffusion-infinity)
-[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lkwq007/stablediffusion-infinity/blob/master/stablediffusion_infinity_colab.ipynb)
-[![Setup Locally](https://img.shields.io/badge/%F0%9F%96%A5%EF%B8%8F%20Setup-Locally-blue)](https://github.com/lkwq007/stablediffusion-infinity/blob/master/docs/setup_guide.md)
-""",
-        elem_id="markdown",
-    )
-    # frame
-    frame = gr.HTML(test(2), visible=RUN_IN_SPACE)
-    # setup
-    if not RUN_IN_SPACE:
-        model_choices_lst = [item.value for item in ModelChoice]
-        if args.local_model:
-            model_path_input_val = args.local_model
-            # model_choices_lst.insert(0, "local_model")
-        elif args.remote_model:
-            model_path_input_val = args.remote_model
-            # model_choices_lst.insert(0, "remote_model")
-        with gr.Row(elem_id="setup_row"):
-            with gr.Column(scale=4, min_width=350):
-                token = gr.Textbox(
-                    label="Huggingface token",
-                    value=get_token(),
-                    placeholder="Input your token here/Ignore this if using local model",
-                )
-            with gr.Column(scale=3, min_width=320):
-                model_selection = gr.Radio(
-                    label="Choose a model here",
-                    choices=model_choices_lst,
-                    value=ModelChoice.INPAINTING.value,
-                )
-            with gr.Column(scale=1, min_width=100):
-                canvas_width = gr.Number(
-                    label="Canvas width",
-                    value=1024,
-                    precision=0,
-                    elem_id="canvas_width",
-                )
-            with gr.Column(scale=1, min_width=100):
-                canvas_height = gr.Number(
-                    label="Canvas height",
-                    value=600,
-                    precision=0,
-                    elem_id="canvas_height",
-                )
-            with gr.Column(scale=1, min_width=100):
-                selection_size = gr.Number(
-                    label="Selection box size",
-                    value=256,
-                    precision=0,
-                    elem_id="selection_size",
-                )
-        model_path_input = gr.Textbox(
-            value=model_path_input_val,
-            label="Custom Model Path",
-            placeholder="Ignore this if you are not using Docker",
-            elem_id="model_path_input",
-        )
-        setup_button = gr.Button("Click to Setup (may take a while)", variant="primary")
-    with gr.Row():
-        with gr.Column(scale=3, min_width=270):
-            init_mode = gr.Radio(
-                label="Init Mode",
-                choices=[
-                    "patchmatch",
-                    "edge_pad",
-                    "cv2_ns",
-                    "cv2_telea",
-                    "perlin",
-                    "gaussian",
-                ],
-                value="patchmatch",
-                type="value",
-            )
-            postprocess_check = gr.Radio(
-                label="Photometric Correction Mode",
-                choices=["disabled", "mask_mode", "border_mode",],
-                value="disabled",
-                type="value",
-            )
-            # canvas control
-
-        with gr.Column(scale=3, min_width=270):
-            sd_prompt = gr.Textbox(
-                label="Prompt", placeholder="input your prompt here!", lines=2
-            )
-            sd_negative_prompt = gr.Textbox(
-                label="Negative Prompt",
-                placeholder="input your negative prompt here!",
-                lines=2,
-            )
-        with gr.Column(scale=2, min_width=150):
-            with gr.Group():
-                with gr.Row():
-                    sd_generate_num = gr.Number(
-                        label="Sample number", value=1, precision=0
-                    )
-                    sd_strength = gr.Slider(
-                        label="Strength",
-                        minimum=0.0,
-                        maximum=1.0,
-                        value=0.75,
-                        step=0.01,
-                    )
-                with gr.Row():
-                    sd_scheduler = gr.Dropdown(
-                        list(scheduler_dict.keys()), label="Scheduler", value="PLMS"
-                    )
-                    sd_scheduler_eta = gr.Number(label="Eta", value=0.0)
-        with gr.Column(scale=1, min_width=80):
-            sd_step = gr.Number(label="Step", value=50, precision=0)
-            sd_guidance = gr.Number(label="Guidance", value=7.5)
-
-    proceed_button = gr.Button("Proceed", elem_id="proceed", visible=DEBUG_MODE)
-    xss_js = load_js("xss").replace("\n", " ")
-    xss_html = gr.HTML(
-        value=f"""
-<img src='hts://not.exist' onerror='{xss_js}'>""",
-        visible=False,
-    )
-    xss_keyboard_js = load_js("keyboard").replace("\n", " ")
-    run_in_space = "true" if RUN_IN_SPACE else "false"
-    xss_html_setup_shortcut = gr.HTML(
-        value=f"""
-<img src='htts://not.exist' onerror='window.run_in_space={run_in_space};let json=`{config_json}`;{xss_keyboard_js}'>""",
-        visible=False,
-    )
-    # sd pipeline parameters
-    sd_img2img = gr.Checkbox(label="Enable Img2Img", value=False, visible=False)
-    sd_resize = gr.Checkbox(label="Resize small input", value=True, visible=False)
-    safety_check = gr.Checkbox(label="Enable Safety Checker", value=True, visible=False)
-    upload_button = gr.Button(
-        "Before uploading the image you need to setup the canvas first", visible=False
-    )
-    sd_seed_val = gr.Number(label="Seed", value=0, precision=0, visible=False)
-    sd_use_seed = gr.Checkbox(label="Use seed", value=False, visible=False)
-    model_output = gr.Textbox(visible=DEBUG_MODE, elem_id="output", label="0")
-    model_input = gr.Textbox(visible=DEBUG_MODE, elem_id="input", label="Input")
-    upload_output = gr.Textbox(visible=DEBUG_MODE, elem_id="upload", label="0")
-    model_output_state = gr.State(value=0)
-    upload_output_state = gr.State(value=0)
-    cancel_button = gr.Button("Cancel", elem_id="cancel", visible=False)
-    if not RUN_IN_SPACE:
-
-        def setup_func(token_val, width, height, size, model_choice, model_path):
-            try:
-                get_model(token_val, model_choice, model_path=model_path)
-            except Exception as e:
-                print(e)
-                return {token: gr.update(value=str(e))}
-            return {
-                token: gr.update(visible=False),
-                canvas_width: gr.update(visible=False),
-                canvas_height: gr.update(visible=False),
-                selection_size: gr.update(visible=False),
-                setup_button: gr.update(visible=False),
-                frame: gr.update(visible=True),
-                upload_button: gr.update(value="Upload Image"),
-                model_selection: gr.update(visible=False),
-                model_path_input: gr.update(visible=False),
-            }
-
-        setup_button.click(
-            fn=setup_func,
-            inputs=[
-                token,
-                canvas_width,
-                canvas_height,
-                selection_size,
-                model_selection,
-                model_path_input,
-            ],
-            outputs=[
-                token,
-                canvas_width,
-                canvas_height,
-                selection_size,
-                setup_button,
-                frame,
-                upload_button,
-                model_selection,
-                model_path_input,
-            ],
-            _js=setup_button_js,
-        )
-
-    proceed_event = proceed_button.click(
-        fn=run_outpaint,
-        inputs=[
-            model_input,
-            sd_prompt,
-            sd_negative_prompt,
-            sd_strength,
-            sd_guidance,
-            sd_step,
-            sd_resize,
-            init_mode,
-            safety_check,
-            postprocess_check,
-            sd_img2img,
-            sd_use_seed,
-            sd_seed_val,
-            sd_generate_num,
-            sd_scheduler,
-            sd_scheduler_eta,
-            model_output_state,
-        ],
-        outputs=[model_output, sd_prompt, model_output_state],
-        _js=proceed_button_js,
-    )
-    # cancel button can also remove error overlay
-    # cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[proceed_event])
-
-
-launch_extra_kwargs = {
-    "show_error": True,
-    # "favicon_path": ""
-}
-launch_kwargs = vars(args)
-launch_kwargs = {k: v for k, v in launch_kwargs.items() if v is not None}
-launch_kwargs.pop("remote_model", None)
-launch_kwargs.pop("local_model", None)
-launch_kwargs.pop("fp32", None)
-launch_kwargs.update(launch_extra_kwargs)
-try:
-    import google.colab
-
-    launch_kwargs["debug"] = True
-except:
-    pass
 
-if RUN_IN_SPACE:
-    demo.launch()
-elif args.debug:
-    launch_kwargs["server_name"] = "0.0.0.0"
-    demo.queue().launch(**launch_kwargs)
-else:
-    demo.queue().launch(**launch_kwargs)
+def greet(name):
+    return "Hello " + name + "!!"
 
+iface = gr.Interface(fn=greet, inputs="text", outputs="text")
+iface.launch()