File size: 4,806 Bytes
76c920d
 
cd4ffa3
829ed2e
 
 
 
cd4ffa3
 
f2943ab
76c920d
829ed2e
 
 
 
79b99e0
c8e54f6
 
 
cd4ffa3
829ed2e
 
f2943ab
829ed2e
 
 
c8e54f6
10af47d
7c3d7a0
c8e54f6
e674f0f
829ed2e
c8e54f6
e674f0f
c8e54f6
 
 
829ed2e
5079795
c8e54f6
cd4ffa3
12d035e
4bcc00f
829ed2e
 
 
 
 
 
 
 
 
cd4ffa3
 
829ed2e
 
 
 
c8e54f6
829ed2e
cd4ffa3
a8bc6ca
c8e54f6
f2943ab
e34fa5d
5079795
829ed2e
5079795
e674f0f
829ed2e
e674f0f
 
5079795
 
c8e54f6
 
a5540bb
 
c8e54f6
e674f0f
829ed2e
f2943ab
 
e674f0f
 
 
829ed2e
 
f2943ab
a09c55b
c8e54f6
829ed2e
c8e54f6
 
 
829ed2e
5533283
 
c8e54f6
76c920d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gradio as gr

import torch
from PIL import Image, ImageFilter, ImageEnhance, ImageDraw
from diffusers import LCMScheduler, StableDiffusionInpaintPipeline
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation

from tqdm import tqdm
import numpy as np
from datetime import datetime

# Device selection.
# Ideally every model would share one device, chosen as
#   "cuda" if available, else "mps" if available, else "cpu"
# but the SegFormer model does not work on MPS, so segmentation only ever runs
# on CUDA or CPU while the inpainting pipeline may additionally use MPS.
if torch.cuda.is_available():
    preferred_device = "cuda"
    preferred_dtype = torch.float16
    inpaint_preferred_device = "cuda"
else:
    preferred_device = "cpu"
    preferred_dtype = torch.float32
    inpaint_preferred_device = "mps" if torch.backends.mps.is_available() else "cpu"

torch.backends.cuda.matmul.allow_tf32 = True

# torch.compile backend, keyed by the inpainting device; anything not listed
# (i.e. "cpu") falls back to "inductor".
preferred_backend = {"mps": "aot_eager", "cuda": "tensorrt"}.get(
    inpaint_preferred_device, "inductor"
)

seg_model_img_size = 768
seg_model_size = 0

# Both the feature extractor and the model are loaded from the same checkpoint.
_seg_checkpoint = (
    f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-"
    f"{seg_model_img_size}-{seg_model_img_size}"
)

seg_feature_extractor = SegformerFeatureExtractor.from_pretrained(_seg_checkpoint)
seg_model = SegformerForSemanticSegmentation.from_pretrained(_seg_checkpoint)
seg_model = seg_model.to(preferred_device).to(preferred_dtype)

inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7",
    torch_dtype=preferred_dtype,
    safety_checker=None,
)
inpainting_pipeline = inpainting_pipeline.to(inpaint_preferred_device)

# Wrap the heavy modules with torch.compile, all using the same backend.
inpainting_pipeline.unet = torch.compile(inpainting_pipeline.unet, backend=preferred_backend)
inpainting_pipeline.vae = torch.compile(inpainting_pipeline.vae, backend=preferred_backend)
seg_model = torch.compile(seg_model, backend=preferred_backend)

# Image sizes (used as PIL resize targets) for the two stages.
seg_working_size = (seg_model_img_size, seg_model_img_size)
repaint_working_size = (768, 768)

default_inpainting_prompt = "award-winning photo of a leafy pedestrian mall full of people, with multiracial genderqueer joggers and bicyclists and wheelchair users talking and laughing"

# Mapping from class label to class id, taken from the model config.
seg_vocabulary = seg_model.config.label2id
print(f"vocab: {seg_vocabulary}")

# Lookup table indexed by class id: 1 for classes that should be painted over,
# 0 for everything else.
banned_classes = ["car", "road", "sidewalk", "traffic light", "traffic sign"]
ban_cars_mask = np.zeros(len(seg_vocabulary), dtype=np.uint8)
ban_cars_mask[[seg_vocabulary[label] for label in banned_classes]] = 1


def get_seg_mask(img):
    """Build an inpainting mask covering the banned segmentation classes.

    Args:
        img: Image passed straight to ``seg_feature_extractor`` (``app`` passes
            a numpy array).

    Returns:
        A PIL image in which pixels whose predicted class is listed in
        ``banned_classes`` are bright and all others dark, after a Gaussian
        blur and an extreme contrast boost. The mask has the resolution of the
        model's logits; callers resize it as needed.
    """
    inputs = seg_feature_extractor(images=img, return_tensors="pt")
    # Move every tensor to the model's device and cast floating-point tensors
    # to the model's dtype; otherwise a half-precision model is fed float32
    # pixel values.
    inputs = {
        name: (
            tensor.to(preferred_device, dtype=preferred_dtype)
            if torch.is_floating_point(tensor)
            else tensor.to(preferred_device)
        )
        for name, tensor in inputs.items()
    }
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = seg_model(**inputs).logits[0]
    # Per-pixel class id, then a table lookup turning ids into 0 / 255.
    class_ids = torch.argmax(logits, dim=0).cpu().numpy()
    mask = Image.fromarray(ban_cars_mask[class_ids] * 255)
    # Blur, then push the result back towards pure black/white. Judging by the
    # variable name this is meant to soften and widen the masked regions.
    blurred_widened_mask = ImageEnhance.Contrast(mask.filter(ImageFilter.GaussianBlur(2))).enhance(9000)
    return blurred_widened_mask


def app(img, prompt, num_inference_steps, seed):
    """Segment an image, inpaint the banned regions, and annotate the result.

    Args:
        img: Input image as a numpy array accepted by ``Image.fromarray``.
        prompt: Text prompt for the inpainting pipeline.
        num_inference_steps: Number of denoising steps; converted with
            ``int()`` because UI number widgets may deliver floats.
        seed: Random seed; converted with ``int()`` for the same reason.

    Returns:
        The inpainted PIL image with the original size, timing information
        and the (line-wrapped) prompt drawn onto it.
    """
    start_time = datetime.now().timestamp()
    source = Image.fromarray(img)
    old_size = source.size
    img = np.array(source.resize(seg_working_size))
    mask = get_seg_mask(img)
    mask_time = datetime.now().timestamp()

    # repaint_working_size is used as a PIL size, i.e. (width, height).
    repaint_width, repaint_height = repaint_working_size
    overlay_img = inpainting_pipeline(
        prompt=prompt,
        image=Image.fromarray(img).resize(repaint_working_size),
        mask_image=mask.resize(repaint_working_size),
        strength=1,
        num_inference_steps=int(num_inference_steps),
        height=repaint_height,
        width=repaint_width,
        generator=torch.manual_seed(int(seed)),
    ).images[0]
    end_time = datetime.now().timestamp()

    draw = ImageDraw.Draw(overlay_img)
    # Append a newline to every fifth word so the prompt wraps when drawn.
    wrapped_prompt = " ".join(
        word if (index + 1) % 5 else word + "\n"
        for index, word in enumerate(prompt.split(" "))
    )

    total_ms = int(1000 * (end_time - start_time))
    segmentation_ms = int(1000 * (mask_time - start_time))
    inpainting_ms = int(1000 * (end_time - mask_time))
    caption = (
        f"Old size: {old_size}\n"
        f"Total duration: {total_ms}ms\n"
        f"Segmentation {segmentation_ms}ms / inpainting {inpainting_ms} \n"
        f"<{wrapped_prompt}>"
    )
    draw.text((10, 50), caption, fill=(0, 255, 0))
    return overlay_img

# Warm-up: one pass first (for compiling), then three more under tqdm (for
# timing). Each pass overwrites the same output file.
warmup_frame = np.zeros((1024, 1024, 3), dtype=np.uint8)
for i in range(2):
    for j in tqdm(range(3 ** i)):
        app(warmup_frame, default_inpainting_prompt, 4, 42).save("zeros_inpainting_oneshot.jpg")

# ideally:
# iface = gr.Interface(app, gr.Image(sources=["webcam"], streaming=True), "image", live=True)
iface = gr.Interface(
    app,
    [
        gr.Image(),
        gr.Textbox(value=default_inpainting_prompt),
        # precision=0 makes the widget deliver integers: the step count and
        # the seed are both whole numbers.
        gr.Number(minimum=1, maximum=8, value=4, precision=0),
        gr.Number(value=42, precision=0),
    ],
    "image",
)
iface.launch()