import gradio as gr
import torch
from PIL import Image, ImageFilter, ImageEnhance, ImageDraw
from diffusers import LCMScheduler, StableDiffusionInpaintPipeline
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
from tqdm import tqdm
import numpy as np
from datetime import datetime
# ideally:
# preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
# but segformer does not work on mps lolololol
# --- Device / dtype selection ---
# Segmentation: CUDA or CPU only (SegFormer broken on MPS, per note above).
seg_preferred_device = "cuda" if torch.cuda.is_available() else "cpu"
# fp32 kept even on CUDA; the fp16 variant was tried and left commented out.
seg_preferred_dtype = torch.float32 # torch.float16 if seg_preferred_device == 'cuda' else torch.float32
# Inpainting works on MPS, so prefer cuda > mps > cpu; fp16 everywhere but CPU.
inpaint_preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
inpaint_preferred_dtype = torch.float32 if inpaint_preferred_device == 'cpu' else torch.float16
torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 matmuls on CUDA for speed
print(f"backends: {torch._dynamo.list_backends()}")
# "inductor" backend is used except on MPS, where "aot_eager" is chosen instead.
preferred_backend = "aot_eager" if inpaint_preferred_device == "mps" else "inductor"
# --- Model loading ---
# SegFormer-b{seg_model_size} fine-tuned on Cityscapes at 768x768.
seg_model_img_size = 768
seg_model_size = 0  # b0 variant
seg_feature_extractor = SegformerFeatureExtractor.from_pretrained(f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}")
seg_model = SegformerForSemanticSegmentation.from_pretrained(
    f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}"
).to(seg_preferred_device).to(seg_preferred_dtype)
inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7",
    torch_dtype=inpaint_preferred_dtype,
    safety_checker=None,  # NOTE: safety checker deliberately disabled
).to(inpaint_preferred_device)
# torch.compile each submodule (and the seg model) for faster repeated inference.
inpainting_pipeline.text_encoder = torch.compile(inpainting_pipeline.text_encoder, backend=preferred_backend)
inpainting_pipeline.unet = torch.compile(inpainting_pipeline.unet, backend=preferred_backend)
inpainting_pipeline.vae = torch.compile(inpainting_pipeline.vae, backend=preferred_backend)
seg_model = torch.compile(seg_model, backend=preferred_backend)
seg_working_size = (seg_model_img_size, seg_model_img_size)
default_inpainting_prompt = "award-winning photo of a leafy pedestrian mall full of people, with multiracial genderqueer joggers and bicyclists and wheelchair users talking and laughing"
# label2id maps class names -> class indices for this checkpoint.
seg_vocabulary = seg_model.config.label2id
print(f"vocab: {seg_vocabulary}")
# Per-class 0/1 lookup table: 1 for classes to be masked out and inpainted over.
ban_cars_mask = [0] * len(seg_vocabulary)
banned_classes = ["car", "road", "sidewalk", "traffic light", "traffic sign"]
for c in banned_classes:
    ban_cars_mask[seg_vocabulary[c]] = 1
ban_cars_mask = np.array(ban_cars_mask, dtype=np.uint8)
def get_seg_mask(img):
    """Segment ``img`` with SegFormer and return a PIL mask of banned classes.

    The returned mask is 255 wherever the per-pixel argmax class is one of
    ``banned_classes`` (car, road, sidewalk, ...) and 0 elsewhere, then
    Gaussian-blurred and re-binarized with an extreme contrast boost so the
    inpainting region is slightly widened with softened edges.

    Args:
        img: RGB image (PIL image or numpy array) — passed straight to the
            SegFormer feature extractor.

    Returns:
        Single-channel ``PIL.Image`` at the model's logits resolution.
    """
    inputs = seg_feature_extractor(images=img, return_tensors="pt").to(seg_preferred_device).to(seg_preferred_dtype)
    # Inference only: no_grad avoids building an autograd graph (saves memory
    # and time; the original ran the forward pass with grad tracking on).
    with torch.no_grad():
        outputs = seg_model(**inputs)
    logits = outputs.logits[0]  # (num_classes, H, W)
    # argmax class index per pixel -> 0/255 via the banned-class lookup table.
    mask = Image.fromarray(ban_cars_mask[torch.argmax(logits, dim=0).cpu().numpy()] * 255)
    # Blur widens the mask by a couple of pixels; the huge contrast factor
    # snaps the grey blur ramp back to (near-)binary.
    blurred_widened_mask = ImageEnhance.Contrast(mask.filter(ImageFilter.GaussianBlur(2))).enhance(9000)
    return blurred_widened_mask
def app(img, prompt, num_inference_steps, seed, inpaint_size):
    """Gradio handler: mask cars/roads via segmentation, then inpaint over them.

    Args:
        img: HxWx3 uint8 numpy array (Gradio image input).
        prompt: text prompt for the inpainting pipeline.
        num_inference_steps: LCM sampler steps (small, e.g. 4).
        seed: RNG seed so inpainting is reproducible.
        inpaint_size: square resolution the pipeline runs at
            (should be <= seg_model_img_size).

    Returns:
        The inpainted PIL image with timing/debug text drawn onto it.
    """
    start_time = datetime.now().timestamp()
    # Segmentation model expects its native working resolution.
    img = np.array(Image.fromarray(img).resize(seg_working_size))
    mask = get_seg_mask(img)
    mask_time = datetime.now().timestamp()
    overlay_img = inpainting_pipeline(
        prompt=prompt,
        image=Image.fromarray(img).resize((inpaint_size, inpaint_size)),
        mask_image=mask.resize((inpaint_size, inpaint_size)),
        strength=1,
        num_inference_steps=num_inference_steps,
        height=inpaint_size,
        width=inpaint_size,
        generator=torch.manual_seed(int(seed)),
    ).images[0]
    end_time = datetime.now().timestamp()
    draw = ImageDraw.Draw(overlay_img)
    # Append a newline after every 5th word so the prompt line-breaks on the
    # image. (Fix: the original recomputed prompt.split(" ") three times per
    # word inside the comprehension.)
    words = prompt.split(" ")
    prompt = " ".join(
        word if (i + 1) % 5 else word + "\n" for i, word in enumerate(words)
    )
    draw.text((10, 50), "\n".join([
        f"Total duration: {int(1000 * (end_time - start_time))}ms",
        f"Inference steps: {num_inference_steps}",
        f"Segmentation {int(1000 * (mask_time - start_time))}ms / inpainting {int(1000 * (end_time - mask_time))}",
        f"<{prompt}>"
    ]), fill=(0, 255, 0))
    return overlay_img
# warmup, for compiling and then for timing: one call per size triggers
# torch.compile, then 3 more calls give representative timings.
# Fix: the original rebuilt the zeros array via a pointless
# np.array(Image.fromarray(...)) round-trip on every single iteration;
# build the black test frame once instead.
warmup_img = np.zeros((1024, 1024, 3), dtype=np.uint8)
for size in [384, 512]:
    for i in range(2):
        for j in tqdm(range(3 ** i)):
            app(warmup_img, default_inpainting_prompt, 4, 42, size).save("zeros_inpainting_oneshot.jpg")
# ideally:
# iface = gr.Interface(app, gr.Image(sources=["webcam"], streaming=True), "image", live=True)
# Five inputs, matching app()'s signature: image, prompt, steps, seed, size.
input_widgets = [
    gr.Image(),
    gr.Textbox(value=default_inpainting_prompt),
    gr.Number(minimum=1, maximum=8, value=4),
    gr.Number(value=42),
    gr.Number(value=512, maximum=seg_model_img_size),
]
iface = gr.Interface(app, input_widgets, "image")
iface.launch(share=True)