import gradio as gr

import torch
from PIL import Image, ImageFilter, ImageEnhance, ImageDraw
from diffusers import StableDiffusionInpaintPipeline
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation

from tqdm import tqdm
import numpy as np
from datetime import datetime

# Ideally a single device would serve both models:
# preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
# but SegFormer does not run on MPS, so segmentation falls back to CPU there.
seg_preferred_device = "cuda" if torch.cuda.is_available() else "cpu"
seg_preferred_dtype = torch.float32  # could be torch.float16 if seg_preferred_device == 'cuda' else torch.float32
inpaint_preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
inpaint_preferred_dtype = torch.float32 if inpaint_preferred_device == 'cpu' else torch.float16  # half precision everywhere except CPU
torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 matmuls on Ampere+ GPUs for extra speed

print(f"backends: {torch._dynamo.list_backends()}")

# inductor does not support the MPS backend here, so fall back to aot_eager on Apple Silicon
preferred_backend = "aot_eager" if inpaint_preferred_device == "mps" else "inductor"

seg_model_img_size = 768  # resolution the cityscapes checkpoint was finetuned at
seg_model_size = 0  # b0, the smallest SegFormer variant

seg_feature_extractor = SegformerFeatureExtractor.from_pretrained(f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}")
seg_model = SegformerForSemanticSegmentation.from_pretrained(
    f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}"
).to(seg_preferred_device).to(seg_preferred_dtype)

inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7",
    torch_dtype=inpaint_preferred_dtype,
    safety_checker=None,
).to(inpaint_preferred_device)

inpainting_pipeline.text_encoder = torch.compile(inpainting_pipeline.text_encoder, backend=preferred_backend)
inpainting_pipeline.unet = torch.compile(inpainting_pipeline.unet, backend=preferred_backend)
inpainting_pipeline.vae = torch.compile(inpainting_pipeline.vae, backend=preferred_backend)
seg_model = torch.compile(seg_model, backend=preferred_backend)
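
# torch.compile is lazy: nothing actually compiles until the first forward pass,
# which is what the warmup loop at the bottom of this file is for.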

seg_working_size = (seg_model_img_size, seg_model_img_size)

default_inpainting_prompt = "award-winning photo of a leafy pedestrian mall full of people, with multiracial genderqueer joggers and bicyclists and wheelchair users talking and laughing"

seg_vocabulary = seg_model.config.label2id
print(f"vocab: {seg_vocabulary}")

# 1 for every class we want painted over ("banned"), 0 for classes to keep
ban_cars_mask = [0] * len(seg_vocabulary)
banned_classes = ["car", "road", "sidewalk", "traffic light", "traffic sign"]
for c in banned_classes:
    ban_cars_mask[seg_vocabulary[c]] = 1
ban_cars_mask = np.array(ban_cars_mask, dtype=np.uint8)
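
# Optional sanity check (added here, not part of the original pipeline): every
# banned class name should resolve to a distinct Cityscapes label id.
assert len({seg_vocabulary[c] for c in banned_classes}) == len(banned_classes)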


def get_seg_mask(img):
    inputs = seg_feature_extractor(images=img, return_tensors="pt").to(seg_preferred_device).to(seg_preferred_dtype)
    outputs = seg_model(**inputs)
    logits = outputs.logits[0]
    # per-pixel argmax over classes, then look up whether each winning class is banned
    mask = Image.fromarray(ban_cars_mask[torch.argmax(logits, dim=0).cpu().numpy()] * 255)
    # blur to widen the mask slightly, then a large contrast boost to re-binarize its edges
    blurred_widened_mask = ImageEnhance.Contrast(mask.filter(ImageFilter.GaussianBlur(2))).enhance(9000)
    return blurred_widened_mask
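
# Hypothetical standalone use of get_seg_mask, e.g. to eyeball the mask in isolation
# ("street.jpg" is a placeholder filename, not shipped with this app):
#   frame = np.array(Image.open("street.jpg").resize(seg_working_size))
#   get_seg_mask(frame).save("street_mask.png")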


def app(img, prompt, num_inference_steps, seed, inpaint_size):
    start_time = datetime.now().timestamp()
    # normalize the input to the segmentation working size before masking
    img = np.array(Image.fromarray(img).resize(seg_working_size))
    mask = get_seg_mask(img)
    # mask.save("mask.jpg")
    mask_time = datetime.now().timestamp()
    #print(prompt.__class__, img.__class__, mask.__class__, img.shape, mask.shape, mask.dtype, img.dtype)
    overlay_img = inpainting_pipeline(
        prompt=prompt,
        image=Image.fromarray(img).resize((inpaint_size, inpaint_size)),
        mask_image=mask.resize((inpaint_size, inpaint_size)),
        strength=1,
        num_inference_steps=num_inference_steps,
        height=inpaint_size,
        width=inpaint_size,
        generator=torch.manual_seed(int(seed)),
    ).images[0]
    #overlay_img.save("overlay_raw.jpg")
    end_time = datetime.now().timestamp()
    draw = ImageDraw.Draw(overlay_img)
    # insert a newline after every fifth word so the prompt wraps on the overlay
    words = prompt.split(" ")
    prompt = " ".join(w if (i + 1) % 5 else w + "\n" for i, w in enumerate(words))

    draw.text((10, 50), "\n".join([
            f"Total duration: {int(1000 * (end_time - start_time))}ms",
            f"Inference steps: {num_inference_steps}",
            f"Segmentation {int(1000 * (mask_time - start_time))}ms / inpainting {int(1000 * (end_time - mask_time))}",
            f"<{prompt}>"
        ]), fill=(0, 255, 0))
    #overlay_img.save("overlay_with_text.jpg")
    return overlay_img

# warm up each configuration: the first pass triggers torch.compile, the later passes give a timing baseline
warmup_img = np.zeros((1024, 1024, 3), dtype=np.uint8)
for size in [384, 512]:
    for i in range(2):
        for j in tqdm(range(3 ** i)):
            app(warmup_img, default_inpainting_prompt, 4, 42, size).save("zeros_inpainting_oneshot.jpg")

# Ideally this would stream straight from the webcam:
# iface = gr.Interface(app, gr.Image(sources=["webcam"], streaming=True), "image", live=True)
iface = gr.Interface(app, [
        gr.Image(label="input image"),
        gr.Textbox(value=default_inpainting_prompt, label="inpainting prompt"),
        gr.Number(minimum=1, maximum=8, value=4, label="inference steps"),
        gr.Number(value=42, label="seed"),
        gr.Number(value=512, maximum=seg_model_img_size, label="inpainting size"),
    ],
    "image")
iface.launch(share=True)
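
# Note: share=True additionally exposes a temporary public *.gradio.live URL on
# top of the local server; drop it to keep the demo local-only.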