File size: 4,643 Bytes
76c920d
 
cd4ffa3
829ed2e
 
 
 
cd4ffa3
 
f2943ab
76c920d
829ed2e
 
 
 
79b99e0
cd4ffa3
829ed2e
 
f2943ab
829ed2e
 
 
79b99e0
10af47d
7c3d7a0
b80c840
4cc3150
e674f0f
829ed2e
e674f0f
 
926718d
 
 
 
 
 
 
 
dffcbe0
46253a6
 
 
829ed2e
5079795
926718d
cd4ffa3
12d035e
4bcc00f
829ed2e
 
 
 
 
 
 
 
 
cd4ffa3
 
829ed2e
 
 
 
 
 
cd4ffa3
a8bc6ca
4bcc00f
f2943ab
e34fa5d
5079795
829ed2e
5079795
e674f0f
829ed2e
e674f0f
 
5079795
 
e674f0f
46253a6
a5540bb
 
e674f0f
829ed2e
f2943ab
 
e674f0f
 
 
829ed2e
 
f2943ab
a09c55b
829ed2e
 
 
 
 
5533283
 
4bcc00f
76c920d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import time
from datetime import datetime

import gradio as gr
import numpy as np
import torch
from diffusers import LCMScheduler, StableDiffusionInpaintPipeline
from PIL import Image, ImageDraw, ImageEnhance, ImageFilter
from tqdm import tqdm
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation

# Pick the compute device. Ideally Apple's MPS backend would be used too:
# preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
# but SegFormer does not work on MPS, so non-CUDA machines fall back to CPU.
preferred_device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on CUDA; float32 everywhere else.
preferred_dtype = torch.float16 if preferred_device == 'cuda' else torch.float32

# These two values select the SegFormer checkpoint loaded below
# ("nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{size}-{size}"):
# seg_model_size picks the variant (0 -> b0) and seg_model_img_size picks the
# checkpoint resolution, which is also reused as the segmentation working size.
seg_model_img_size = 768
seg_model_size = 0

seg_feature_extractor = SegformerFeatureExtractor.from_pretrained(f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}")
seg_model = SegformerForSemanticSegmentation.from_pretrained(
    f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}"
).to(preferred_device) #.to(preferred_dtype)
# NOTE: the dtype cast above is commented out, so the segmentation model stays
# at whatever dtype from_pretrained loads (presumably float32); only the
# inpainting pipeline below uses preferred_dtype.

# Stable Diffusion inpainting pipeline, loaded from the checkpoint's "fp16"
# weight variant and cast to preferred_dtype. The safety checker is disabled.
inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    variant="fp16",
    torch_dtype=preferred_dtype,
    safety_checker=None,
).to(preferred_device)

# DeepCache: reuse cached UNet features across denoising steps to speed up
# inference. NOTE(review): per the DeepCache docs, cache_interval=3 should mean a
# full UNet pass every 3rd step and cache_branch_id=0 selects the shallowest
# skip branch -- confirm against the installed DeepCache version.
from DeepCache import DeepCacheSDHelper
helper = DeepCacheSDHelper(pipe=inpainting_pipeline)
helper.set_params(cache_interval=3, cache_branch_id=0)
helper.enable()

# Compile the UNet and VAE on CUDA only. This happens after DeepCache has been
# enabled; keep that order unless the alternative has been verified to work.
if preferred_device == "cuda":
    inpainting_pipeline.unet = torch.compile(inpainting_pipeline.unet)
    inpainting_pipeline.vae = torch.compile(inpainting_pipeline.vae)

# Disabled alternative: LCM scheduler + LCM-LoRA for few-step sampling.
# NOTE(review): enabling it would presumably also require lowering
# num_inference_steps in app() -- confirm before switching.
# inpainting_pipeline.scheduler = LCMScheduler.from_config(inpainting_pipeline.scheduler.config)
# inpainting_pipeline.load_lora_weights("latent-consistency/lcm-lora-sdv1-5", torch_dtype=preferred_dtype)
# inpainting_pipeline.fuse_lora()

# Square working resolutions (width, height) for the two model stages.
seg_working_size = (seg_model_img_size,) * 2
repaint_working_size = (512,) * 2

default_inpainting_prompt = "award-winning photo of a leafy pedestrian mall full of people, with multiracial genderqueer joggers and bicyclists and wheelchair users talking and laughing"

# Class-name -> class-id mapping taken from the segmentation model's config.
seg_vocabulary = seg_model.config.label2id
print(f"vocab: {seg_vocabulary}")

# Per-class lookup table indexed by class id: 1 marks a class that should be
# painted over, 0 marks a class to keep. Indexing this table with a per-pixel
# class-id map yields a per-pixel 0/1 mask.
banned_classes = ["car", "road", "sidewalk", "traffic light", "traffic sign"]
ban_cars_mask = np.zeros(len(seg_vocabulary), dtype=np.uint8)
ban_cars_mask[[seg_vocabulary[name] for name in banned_classes]] = 1


def get_seg_mask(img, blur_radius=5, threshold=32):
    """Return an inpainting mask covering the banned classes found in ``img``.

    Runs the SegFormer model on ``img``, takes the per-pixel argmax class and
    maps every class flagged in ``ban_cars_mask`` to white (255), everything
    else to black (0). The binary mask is then Gaussian-blurred and
    re-thresholded at a fixed level so that it grows past the detected regions.

    Args:
        img: Image to segment, in any form ``seg_feature_extractor`` accepts
            (``app`` passes a ``numpy`` array).
        blur_radius: Radius of the Gaussian blur used to grow the mask.
        threshold: Blurred values (0-255) strictly above this become 255 and
            the rest become 0. Lower values widen the mask further. The
            default is close to what the previous contrast trick produced
            for a sparse mask.

    Returns:
        A single-channel ``PIL.Image`` (mode ``"L"``) holding only 0 and 255,
        at the resolution of the model's logits (the caller resizes it).
    """
    inputs = seg_feature_extractor(images=img, return_tensors="pt").to(preferred_device)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.inference_mode():
        logits = seg_model(**inputs).logits[0]
        class_ids = torch.argmax(logits, dim=0).cpu().numpy()
    # ban_cars_mask is a per-class 0/1 lookup table; indexing it with the
    # class-id map gives a per-pixel 0/1 mask, scaled to 0/255 for PIL.
    mask = Image.fromarray(ban_cars_mask[class_ids] * 255)
    # Blur, then hard-threshold at a FIXED level. ImageEnhance.Contrast with a
    # huge factor thresholds around the image's mean brightness instead, so
    # when banned classes covered most of the frame (e.g. a wide road) the
    # mask shrank rather than widened.
    blurred = mask.filter(ImageFilter.GaussianBlur(blur_radius))
    return blurred.point(lambda value: 255 if value > threshold else 0)


def app(img, prompt):
    """Repaint the banned street classes in ``img`` according to ``prompt``.

    Segments ``img`` to find the classes flagged in ``ban_cars_mask``,
    inpaints those regions with the Stable Diffusion inpainting pipeline, and
    draws timing information and the prompt onto the result.

    Args:
        img: Input image as a ``numpy`` array (as delivered by ``gr.Image()``),
            or ``None`` when the user submits without an image.
        prompt: Text prompt for the inpainting model.

    Returns:
        The inpainted ``PIL.Image`` at ``repaint_working_size`` with green text
        annotations drawn on it.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Gradio passes None for an empty image input; report it in the UI
        # instead of crashing inside Image.fromarray.
        raise gr.Error("Please provide an image.")

    start_time = time.perf_counter()  # monotonic clock, suited to durations
    source = Image.fromarray(img)
    old_size = source.size
    mask = get_seg_mask(np.array(source.resize(seg_working_size)))
    mask_time = time.perf_counter()

    overlay_img = inpainting_pipeline(
        prompt=prompt,
        # Resize straight from the source image rather than from the
        # segmentation-sized copy, avoiding a second resampling pass.
        image=source.resize(repaint_working_size),
        mask_image=mask.resize(repaint_working_size),
        strength=0.95,
        num_inference_steps=16,
        # PIL sizes are (width, height), so index 0 is the width.
        width=repaint_working_size[0],
        height=repaint_working_size[1],
    ).images[0]
    end_time = time.perf_counter()

    def to_ms(seconds):
        return int(1000 * seconds)

    # Break the prompt onto a new line every 5 words so it fits on the image.
    words = prompt.split(" ")
    wrapped_prompt = "\n".join(
        " ".join(words[i:i + 5]) for i in range(0, len(words), 5)
    )

    draw = ImageDraw.Draw(overlay_img)
    draw.text(
        (10, 50),
        f"Old size: {old_size}\n"
        f"Total duration: {to_ms(end_time - start_time)}ms\n"
        f"Segmentation {to_ms(mask_time - start_time)}ms / "
        f"inpainting {to_ms(end_time - mask_time)}ms\n"
        f"<{wrapped_prompt}>",
        fill=(0, 255, 0),
    )
    return overlay_img

### kick the tires before we start
# Warm up by running the full pipeline twice on a blank frame, so that one-time
# costs (e.g. torch.compile compilation on CUDA) are paid before the UI opens.
# The blank frame is built once; it is already a uint8 ndarray, so no PIL
# round-trip is needed.
blank_frame = np.zeros((1024, 1024, 3), dtype=np.uint8)
for _ in tqdm(range(2)):
    app(blank_frame, default_inpainting_prompt).save("zeros_inpainting_oneshot.jpg")

# Ideally this would stream live from the webcam:
# iface = gr.Interface(app, gr.Image(sources=["webcam"], streaming=True), "image", live=True)
iface = gr.Interface(app, [gr.Image(), gr.Textbox(value=default_inpainting_prompt)], "image")
iface.launch()