File size: 4,103 Bytes
76c920d
 
cd4ffa3
829ed2e
 
 
 
cd4ffa3
 
f2943ab
76c920d
829ed2e
 
 
 
182dec8
cd4ffa3
829ed2e
 
f2943ab
829ed2e
 
 
 
10af47d
7c3d7a0
b80c840
4cc3150
e674f0f
829ed2e
e674f0f
 
829ed2e
 
 
 
 
cd4ffa3
12d035e
4bcc00f
829ed2e
 
 
 
 
 
 
 
 
cd4ffa3
 
829ed2e
 
 
 
 
 
cd4ffa3
a8bc6ca
4bcc00f
f2943ab
e34fa5d
2d106b2
829ed2e
 
e674f0f
829ed2e
e674f0f
 
829ed2e
 
e674f0f
829ed2e
e674f0f
829ed2e
f2943ab
 
e674f0f
 
 
829ed2e
 
f2943ab
a09c55b
829ed2e
 
 
 
 
5533283
 
4bcc00f
76c920d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import time
from datetime import datetime

import gradio as gr
import numpy as np
import torch
from diffusers import LCMScheduler, StableDiffusionInpaintPipeline
from PIL import Image, ImageFilter, ImageEnhance, ImageDraw
from tqdm import tqdm
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation

# Device selection. Ideally MPS would be a fallback too:
# preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
# but Segformer does not work on MPS, so only CUDA or CPU are considered.
if torch.cuda.is_available():
    preferred_device = "cuda"
else:
    preferred_device = "cpu"
# Alternative kept for reference: torch.float16 if preferred_device == 'cuda' else torch.float32
preferred_dtype = torch.float32

# Segformer checkpoint "nvidia/segformer-b<size>-finetuned-cityscapes-<px>-<px>".
seg_model_img_size = 768
seg_model_size = 0
_seg_checkpoint = f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}"

seg_feature_extractor = SegformerFeatureExtractor.from_pretrained(_seg_checkpoint)
seg_model = SegformerForSemanticSegmentation.from_pretrained(_seg_checkpoint)
seg_model = seg_model.to(preferred_device).to(preferred_dtype)

# Stable Diffusion inpainting; the scheduler is swapped for LCMScheduler and the
# LCM-LoRA weights are fused in (app() later runs only 4 inference steps).
inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    variant="fp16",
    torch_dtype=preferred_dtype,
    safety_checker=None,
).to(preferred_device)

inpainting_pipeline.scheduler = LCMScheduler.from_config(inpainting_pipeline.scheduler.config)
inpainting_pipeline.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
inpainting_pipeline.fuse_lora()

# Square working resolution, equal to the checkpoint's image size.
working_size = (seg_model_img_size,) * 2

default_inpainting_prompt = "award-winning photo of a leafy pedestrian mall full of people, with multiracial genderqueer joggers and bicyclists and wheelchair users talking and laughing"

# Class-name -> class-id mapping from the Segformer config.
seg_vocabulary = seg_model.config.label2id
print(f"vocab: {seg_vocabulary}")

# Lookup table indexed by class id: 1 for classes to paint over, 0 otherwise.
banned_classes = ["car", "road", "sidewalk", "traffic light", "traffic sign"]
ban_cars_mask = np.zeros(len(seg_vocabulary), dtype=np.uint8)
ban_cars_mask[[seg_vocabulary[name] for name in banned_classes]] = 1


def get_seg_mask(img):
    """Build the inpainting mask for *img* from Segformer's predicted classes.

    Each pixel's argmax class id is looked up in ``ban_cars_mask`` (1 for the
    classes listed in ``banned_classes``, 0 otherwise) and scaled to 0/255.
    The mask is then Gaussian-blurred and pushed back towards black/white
    with an extreme contrast boost.

    Args:
        img: image accepted by ``seg_feature_extractor``; ``app`` passes an
            array already resized to ``working_size``.

    Returns:
        A single-channel PIL image whose pixels are mostly 0 or 255. It is
        built on the logits' spatial grid, which may be coarser than *img*
        (TODO confirm for this checkpoint); the inpainting pipeline is left
        to rescale it.
    """
    inputs = seg_feature_extractor(images=img, return_tensors="pt")
    # Move to the model's device *and* dtype so that switching
    # preferred_dtype to float16 does not feed float32 pixels into a
    # half-precision model.
    pixel_values = inputs["pixel_values"].to(device=preferred_device, dtype=preferred_dtype)

    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        logits = seg_model(pixel_values=pixel_values).logits[0]

    # argmax over the label axis -> per-pixel class id -> 0/1 lookup -> 0/255.
    class_ids = torch.argmax(logits, dim=0).cpu().numpy()
    mask = Image.fromarray(ban_cars_mask[class_ids] * 255)

    # NOTE(review): ImageEnhance.Contrast pivots around the image's mean
    # brightness, so a factor of 9000 effectively thresholds at that mean.
    # This widens the blurred mask only while roughly less than half the
    # frame is masked; for larger masks the threshold rises and the edges
    # shrink instead. A fixed low threshold (e.g. via Image.point) would
    # widen consistently -- left as-is to preserve the current tuning.
    blurred = mask.filter(ImageFilter.GaussianBlur(5))
    return ImageEnhance.Contrast(blurred).enhance(9000)


def _wrap_prompt(prompt, words_per_line=5):
    """Return *prompt* with a line break after every *words_per_line* space-separated words."""
    words = prompt.split(" ")
    lines = (" ".join(words[i:i + words_per_line]) for i in range(0, len(words), words_per_line))
    return "\n".join(lines)


def app(img, prompt):
    """Inpaint the banned street classes in *img* according to *prompt*.

    The frame is resized to ``working_size``, masked with ``get_seg_mask``,
    and the masked region is regenerated by ``inpainting_pipeline``. The
    original frame size, stage timings and the prompt are drawn in green
    onto the result.

    Args:
        img: input frame as an array accepted by ``Image.fromarray``, or
            ``None`` when the user submitted without an image.
        prompt: text prompt for the inpainting model.

    Returns:
        The annotated PIL image returned by the inpainting pipeline.

    Raises:
        gr.Error: if *img* is ``None``, so Gradio shows a readable message
            instead of an opaque AttributeError from PIL.

    Side effects:
        Overwrites ``mask.jpg`` in the working directory on every call.
    """
    if img is None:
        raise gr.Error("Please provide an input image.")

    # perf_counter is monotonic; wall-clock datetime.now() can jump when the
    # system clock is adjusted, corrupting the reported durations.
    start_time = time.perf_counter()
    source = Image.fromarray(img)
    old_size = source.size
    resized = source.resize(working_size)

    mask = get_seg_mask(np.array(resized))
    mask.save("mask.jpg")
    mask_time = time.perf_counter()

    overlay_img = inpainting_pipeline(
        prompt=prompt,
        image=resized,
        mask_image=mask,
        strength=0.95,
        num_inference_steps=4,
    ).images[0]
    end_time = time.perf_counter()

    total_ms = int(1000 * (end_time - start_time))
    seg_ms = int(1000 * (mask_time - start_time))
    inpaint_ms = int(1000 * (end_time - mask_time))
    caption = (
        f"Old size: {old_size}\n"
        f"Total duration: {total_ms}ms\n"
        f"Segmentation {seg_ms}ms / inpainting {inpaint_ms}ms\n"
        f"<{_wrap_prompt(prompt)}>"
    )
    ImageDraw.Draw(overlay_img).text((10, 50), caption, fill=(0, 255, 0))
    return overlay_img

# Kick the tires before serving: run the whole pipeline twice on an all-black
# 1024x1024 RGB frame (presumably to absorb one-time warm-up cost before the
# first real request). Each run overwrites the same output file.
for _ in tqdm(range(2)):
    app(np.zeros((1024, 1024, 3), dtype=np.uint8), default_inpainting_prompt).save("zeros_inpainting_oneshot.jpg")

# Ideally this would be a live webcam stream:
# iface = gr.Interface(app, gr.Image(sources=["webcam"], streaming=True), "image", live=True)
iface = gr.Interface(app, [gr.Image(), gr.Textbox(value=default_inpainting_prompt)], "image")
iface.launch()