File size: 4,643 Bytes
76c920d cd4ffa3 829ed2e cd4ffa3 f2943ab 76c920d 829ed2e 79b99e0 cd4ffa3 829ed2e f2943ab 829ed2e 79b99e0 10af47d 7c3d7a0 b80c840 4cc3150 e674f0f 829ed2e e674f0f 926718d dffcbe0 46253a6 829ed2e 5079795 926718d cd4ffa3 12d035e 4bcc00f 829ed2e cd4ffa3 829ed2e cd4ffa3 a8bc6ca 4bcc00f f2943ab e34fa5d 5079795 829ed2e 5079795 e674f0f 829ed2e e674f0f 5079795 e674f0f 46253a6 a5540bb e674f0f 829ed2e f2943ab e674f0f 829ed2e f2943ab a09c55b 829ed2e 5533283 4bcc00f 76c920d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import gradio as gr
import torch
from PIL import Image, ImageFilter, ImageEnhance, ImageDraw
from diffusers import LCMScheduler, StableDiffusionInpaintPipeline
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
from tqdm import tqdm
import numpy as np
from datetime import datetime
# Device / dtype selection.
# Ideally the Apple "mps" backend would be considered as well:
#   preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
# but SegFormer does not work on mps, so the choice is CUDA or CPU only.
if torch.cuda.is_available():
    # Half precision on the GPU.
    preferred_device = "cuda"
    preferred_dtype = torch.float16
else:
    # Full precision on the CPU.
    preferred_device = "cpu"
    preferred_dtype = torch.float32
# SegFormer checkpoint selection: the "b<N>" model variant and the square
# image size both appear in the checkpoint name.
seg_model_img_size = 768
seg_model_size = 0
seg_checkpoint = (
    f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-"
    f"{seg_model_img_size}-{seg_model_img_size}"
)

# Pre-processor and model are loaded from the same checkpoint.
seg_feature_extractor = SegformerFeatureExtractor.from_pretrained(seg_checkpoint)
seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_checkpoint)
# Only the device is changed; casting the model to preferred_dtype is
# intentionally left disabled.
seg_model = seg_model.to(preferred_device)
# Stable Diffusion inpainting pipeline: loaded from the "fp16" weight variant,
# with the dtype chosen for the target device and the safety checker disabled.
inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    variant="fp16",
    torch_dtype=preferred_dtype,
    safety_checker=None,
)
inpainting_pipeline = inpainting_pipeline.to(preferred_device)

# Enable DeepCache on the pipeline.
from DeepCache import DeepCacheSDHelper

helper = DeepCacheSDHelper(pipe=inpainting_pipeline)
helper.set_params(cache_interval=3, cache_branch_id=0)
helper.enable()

# torch.compile is applied to the UNet and the VAE only when running on CUDA.
if preferred_device == "cuda":
    for _component in ("unet", "vae"):
        setattr(
            inpainting_pipeline,
            _component,
            torch.compile(getattr(inpainting_pipeline, _component)),
        )

# Disabled alternative: LCM scheduler + LCM LoRA weights.
# inpainting_pipeline.scheduler = LCMScheduler.from_config(inpainting_pipeline.scheduler.config)
# inpainting_pipeline.load_lora_weights("latent-consistency/lcm-lora-sdv1-5", torch_dtype=preferred_dtype)
# inpainting_pipeline.fuse_lora()
# Working resolutions: the segmentation input is a square of the model's
# image size; inpainting runs at 512x512.
seg_working_size = (seg_model_img_size,) * 2
repaint_working_size = (512, 512)

# Prompt pre-filled in the UI and used for the warm-up run.
default_inpainting_prompt = "award-winning photo of a leafy pedestrian mall full of people, with multiracial genderqueer joggers and bicyclists and wheelchair users talking and laughing"

# Class-name -> class-id mapping taken from the segmentation model's config.
seg_vocabulary = seg_model.config.label2id
print(f"vocab: {seg_vocabulary}")

# Lookup table indexed by class id: 1 for classes that should be painted
# over, 0 for everything else.
banned_classes = ["car", "road", "sidewalk", "traffic light", "traffic sign"]
ban_cars_mask = np.zeros(len(seg_vocabulary), dtype=np.uint8)
for c in banned_classes:
    ban_cars_mask[seg_vocabulary[c]] = 1
def get_seg_mask(img):
    """Return a mask image marking the banned classes found in ``img``.

    The image is run through the segmentation model, each pixel is assigned
    its highest-scoring class, and ``ban_cars_mask`` is used as a lookup
    table to turn class ids into 0 / 255 mask values.

    Args:
        img: An image in any form accepted by ``seg_feature_extractor``
            (``app`` passes a numpy array).

    Returns:
        A PIL image at the resolution of the model's logits, after a
        Gaussian blur and an extreme contrast boost.
    """
    inputs = seg_feature_extractor(images=img, return_tensors="pt").to(preferred_device)

    # Inference only: without no_grad() every call would record an autograd
    # graph and keep activations alive for a backward pass that never runs.
    with torch.no_grad():
        logits = seg_model(**inputs).logits[0]

    # Per-pixel class ids (argmax over the class dimension), then the
    # banned-class lookup table maps them to 0 or 255.
    class_ids = torch.argmax(logits, dim=0).cpu().numpy()
    mask = Image.fromarray(ban_cars_mask[class_ids] * 255)

    # Blur the hard mask, then apply a very large contrast factor so the
    # blurred values are pushed back towards pure black / pure white.
    # NOTE(review): the variable name suggests the goal is to widen the
    # masked region; confirm against actual output.
    blurred = mask.filter(ImageFilter.GaussianBlur(5))
    blurred_widened_mask = ImageEnhance.Contrast(blurred).enhance(9000)
    return blurred_widened_mask
def app(img, prompt):
    """Segment ``img``, inpaint the banned classes, and return the result.

    The input is resized to ``seg_working_size`` for segmentation, then image
    and mask are resized to ``repaint_working_size`` and handed to the
    inpainting pipeline together with ``prompt``. Timing information and the
    prompt are drawn onto the returned image in green.

    Args:
        img: Input image as a numpy array (as delivered by ``gr.Image()``).
        prompt: Text prompt for the inpainting pipeline.

    Returns:
        The inpainted PIL image with the timing / prompt overlay.

    Raises:
        gr.Error: If no image was supplied.
    """
    # Function-scope import keeps this block self-contained; perf_counter is
    # monotonic, unlike wall-clock timestamps, so durations cannot be skewed
    # by clock adjustments.
    import time

    if img is None:
        # Submitting the form without an image would otherwise fail deep
        # inside PIL with an opaque error.
        raise gr.Error("Please provide an image.")

    start_time = time.perf_counter()
    source = Image.fromarray(img)
    old_size = source.size
    seg_input = source.resize(seg_working_size)
    mask = get_seg_mask(np.array(seg_input))
    mask_time = time.perf_counter()

    # repaint_working_size is used as a PIL size, i.e. (width, height).
    repaint_width, repaint_height = repaint_working_size
    overlay_img = inpainting_pipeline(
        prompt=prompt,
        image=seg_input.resize(repaint_working_size),
        mask_image=mask.resize(repaint_working_size),
        strength=0.95,
        num_inference_steps=16,
        height=repaint_height,
        width=repaint_width,
    ).images[0]
    end_time = time.perf_counter()

    total_ms = int(1000 * (end_time - start_time))
    seg_ms = int(1000 * (mask_time - start_time))
    inpaint_ms = int(1000 * (end_time - mask_time))

    draw = ImageDraw.Draw(overlay_img)
    draw.text(
        (10, 50),
        f"Old size: {old_size}\n"
        f"Total duration: {total_ms}ms\n"
        f"Segmentation {seg_ms}ms / inpainting {inpaint_ms}ms\n"
        f"<{_wrap_prompt(prompt)}>",
        fill=(0, 255, 0),
    )
    return overlay_img


def _wrap_prompt(prompt, words_per_line=5):
    """Break ``prompt`` into lines of at most ``words_per_line`` words."""
    words = prompt.split(" ")
    lines = [
        " ".join(words[start:start + words_per_line])
        for start in range(0, len(words), words_per_line)
    ]
    return "\n".join(lines)
### Kick the tires before we start: run the whole pipeline twice on an
### all-black 1024x1024 image and save the result each time.
warmup_input = np.zeros((1024, 1024, 3), dtype=np.uint8)
for _ in tqdm(range(2)):
    warmup_output = app(warmup_input, default_inpainting_prompt)
    warmup_output.save("zeros_inpainting_oneshot.jpg")

# Ideally this would be a live webcam stream:
#   iface = gr.Interface(app, gr.Image(sources=["webcam"], streaming=True), "image", live=True)
# For now: an image upload plus a prompt textbox in, one image out.
iface = gr.Interface(
    fn=app,
    inputs=[gr.Image(), gr.Textbox(value=default_inpainting_prompt)],
    outputs="image",
)
iface.launch()
|