File size: 4,943 Bytes
76c920d
 
cd4ffa3
829ed2e
 
 
 
cd4ffa3
 
f2943ab
76c920d
829ed2e
 
 
 
79b99e0
c8e54f6
 
f84aa8c
 
 
aef61e4
cd4ffa3
829ed2e
 
f2943ab
829ed2e
 
 
4380df2
10af47d
7c3d7a0
c8e54f6
e674f0f
829ed2e
c8e54f6
e674f0f
0025e75
 
 
829ed2e
5079795
cd4ffa3
12d035e
4bcc00f
829ed2e
 
 
 
 
 
 
 
 
cd4ffa3
 
829ed2e
 
 
 
c8e54f6
829ed2e
cd4ffa3
a8bc6ca
34895c9
f2943ab
e34fa5d
5079795
829ed2e
5079795
e674f0f
829ed2e
e674f0f
 
c734fae
 
c8e54f6
 
34895c9
 
c8e54f6
e674f0f
829ed2e
f2943ab
 
e674f0f
 
 
829ed2e
 
f2943ab
a09c55b
c8e54f6
829ed2e
c8e54f6
 
34895c9
829ed2e
5533283
 
34895c9
 
 
 
 
 
 
 
76c920d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr

import torch
from PIL import Image, ImageFilter, ImageEnhance, ImageDraw
from diffusers import LCMScheduler, StableDiffusionInpaintPipeline
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation

from tqdm import tqdm
import numpy as np
from datetime import datetime

# ideally:
# preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
# but segformer does not work on mps lolololol
preferred_device = "cuda" if torch.cuda.is_available() else "cpu"
preferred_dtype = torch.float16 if preferred_device == 'cuda' else torch.float32
inpaint_preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True

print(f"backends: {torch._dynamo.list_backends()}")

preferred_backend = "aot_eager" if inpaint_preferred_device == "mps" else ("onnxrt" if inpaint_preferred_device == "cuda" else "inductor")

# The Segformer checkpoint name is derived from the backbone size index and
# the (square) image size the checkpoint was fine-tuned at.
seg_model_img_size = 768
seg_model_size = 0
seg_checkpoint = (
    f"nvidia/segformer-b{seg_model_size}"
    f"-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}"
)

# Pre-processor and segmentation model come from the same checkpoint.
seg_feature_extractor = SegformerFeatureExtractor.from_pretrained(seg_checkpoint)
seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_checkpoint)
# Only moved to the device; the cast to preferred_dtype is left disabled.
seg_model = seg_model.to(preferred_device)  # .to(preferred_dtype)

# Inpainting pipeline, loaded without a safety checker.
inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7",
    torch_dtype=preferred_dtype,
    safety_checker=None,
)
inpainting_pipeline = inpainting_pipeline.to(inpaint_preferred_device)

# torch.compile is currently disabled for all three models:
#inpainting_pipeline.unet = torch.compile(inpainting_pipeline.unet, backend=preferred_backend)
#inpainting_pipeline.vae = torch.compile(inpainting_pipeline.vae, backend=preferred_backend)
#seg_model = torch.compile(seg_model, backend=preferred_backend)

# Size every input image is resized to before segmentation.
seg_working_size = (seg_model_img_size,) * 2

# Prompt used for the warm-up runs and as the default value in the UI.
default_inpainting_prompt = "award-winning photo of a leafy pedestrian mall full of people, with multiracial genderqueer joggers and bicyclists and wheelchair users talking and laughing"

# Label name -> class id mapping taken from the segmentation model's config.
seg_vocabulary = seg_model.config.label2id
print(f"vocab: {seg_vocabulary}")

# Lookup table indexed by class id: 1 for every class that should be
# inpainted over, 0 for everything else.
banned_classes = ["car", "road", "sidewalk", "traffic light", "traffic sign"]
ban_cars_mask = np.zeros(len(seg_vocabulary), dtype=np.uint8)
ban_cars_mask[[seg_vocabulary[name] for name in banned_classes]] = 1


def get_seg_mask(img):
    """Return an inpainting mask that is white wherever a banned class is seen.

    The image is pre-processed with ``seg_feature_extractor`` and run through
    ``seg_model``; the per-pixel argmax class id is then looked up in
    ``ban_cars_mask`` (1 for banned classes, 0 otherwise) and scaled to
    0/255.  Finally the mask is blurred slightly and pushed through an
    extreme contrast boost so that it ends up (near-)binary again.

    Args:
        img: Image in any form accepted by ``seg_feature_extractor``
            (``app`` passes a numpy array).

    Returns:
        A PIL image whose size is the spatial size of the model's logits,
        not the size of ``img``; callers resize it as needed.
    """
    inputs = seg_feature_extractor(images=img, return_tensors="pt").to(preferred_device)

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = seg_model(**inputs).logits[0]

    # logits is (num_classes, H, W) after dropping the batch dimension, so
    # the argmax over dim 0 yields one class id per pixel.
    class_ids = torch.argmax(logits, dim=0).cpu().numpy()
    mask = Image.fromarray(ban_cars_mask[class_ids] * 255)

    # Blur, then apply a huge contrast factor to re-threshold the soft edges.
    blurred = mask.filter(ImageFilter.GaussianBlur(2))
    blurred_widened_mask = ImageEnhance.Contrast(blurred).enhance(9000)
    return blurred_widened_mask


def _wrap_prompt(prompt, words_per_line=5):
    """Return ``prompt`` with a newline appended to every ``words_per_line``-th word."""
    return " ".join(
        word if position % words_per_line else word + "\n"
        for position, word in enumerate(prompt.split(" "), start=1)
    )


def app(img, prompt, num_inference_steps, seed, inpaint_size):
    """Segment ``img``, inpaint the banned regions and annotate the result.

    The input is resized to ``seg_working_size`` for segmentation, then the
    image and its mask are resized to ``inpaint_size`` x ``inpaint_size`` and
    passed to ``inpainting_pipeline``.  Timing statistics and the prompt are
    drawn onto the generated image in green text.

    Args:
        img: Input image as a numpy array (as delivered by ``gr.Image``).
        prompt: Text prompt for the inpainting pipeline.
        num_inference_steps: Number of diffusion steps; coerced to ``int``.
        seed: Seed for the torch random generator; coerced to ``int``.
        inpaint_size: Side length in pixels of the square inpainting
            resolution; coerced to ``int``.

    Returns:
        The first image produced by the pipeline, with the statistics
        overlay drawn onto it.

    Raises:
        gr.Error: If ``img`` is ``None`` (no image was supplied).
    """
    if img is None:
        raise gr.Error("Please provide an input image.")

    # gr.Number may deliver floats; PIL sizes and step counts must be ints.
    num_inference_steps = int(num_inference_steps)
    inpaint_size = int(inpaint_size)
    inpaint_dims = (inpaint_size, inpaint_size)

    start_time = datetime.now().timestamp()
    source = Image.fromarray(img)
    old_size = source.size
    seg_input = source.resize(seg_working_size)
    mask = get_seg_mask(np.array(seg_input))
    mask_time = datetime.now().timestamp()

    overlay_img = inpainting_pipeline(
        prompt=prompt,
        image=seg_input.resize(inpaint_dims),
        mask_image=mask.resize(inpaint_dims),
        strength=1,
        num_inference_steps=num_inference_steps,
        height=inpaint_size,
        width=inpaint_size,
        # Seeds torch's global RNG and passes the default generator along.
        generator=torch.manual_seed(int(seed)),
    ).images[0]
    end_time = datetime.now().timestamp()

    total_ms = int(1000 * (end_time - start_time))
    seg_ms = int(1000 * (mask_time - start_time))
    inpaint_ms = int(1000 * (end_time - mask_time))

    # Break the prompt every few words so it fits on the overlay.
    wrapped_prompt = _wrap_prompt(prompt)

    draw = ImageDraw.Draw(overlay_img)
    draw.text(
        (10, 50),
        f"Old size: {old_size}\n"
        f"Total duration: {total_ms}ms\n"
        f"Segmentation {seg_ms}ms / inpainting {inpaint_ms}ms\n"
        f"<{wrapped_prompt}>",
        fill=(0, 255, 0),
    )
    return overlay_img

# Warm-up on an all-black 1024x1024 RGB frame: one pass with a single call
# (for compiling), then one pass with three calls (for timing).  Every call
# overwrites the same output file.
blank_frame = np.zeros((1024, 1024, 3), dtype=np.uint8)
for warmup_runs in (1, 3):
    for _ in tqdm(range(warmup_runs)):
        app(blank_frame, default_inpainting_prompt, 4, 42, 512).save("zeros_inpainting_oneshot.jpg")

# Ideally this would be a live webcam stream:
# iface = gr.Interface(app, gr.Image(sources=["webcam"], streaming=True), "image", live=True)
#
# Inputs are listed in the order of app()'s parameters:
# img, prompt, num_inference_steps, seed, inpaint_size.
# precision=0 makes Gradio round the numeric inputs and deliver ints, since
# step count, seed and image size are all integer quantities.
iface = gr.Interface(
    app,
    [
        gr.Image(),
        gr.Textbox(value=default_inpainting_prompt),
        gr.Number(minimum=1, maximum=8, value=4, precision=0),
        gr.Number(value=42, precision=0),
        gr.Number(value=512, maximum=seg_model_img_size, precision=0),
    ],
    "image",
)
iface.launch()