|
import gradio as gr |
|
|
|
import torch |
|
from PIL import Image, ImageFilter, ImageEnhance, ImageDraw |
|
from diffusers import LCMScheduler, StableDiffusionInpaintPipeline |
|
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation |
|
|
|
from tqdm import tqdm |
|
import numpy as np |
|
from datetime import datetime |
|
|
|
|
|
|
|
|
|
# Run on GPU in half precision when available; fall back to CPU/fp32 otherwise.
preferred_device = "cuda" if torch.cuda.is_available() else "cpu"
preferred_dtype = torch.float16 if preferred_device == "cuda" else torch.float32

# SegFormer checkpoint: b0 (smallest backbone) fine-tuned on Cityscapes at 768x768.
seg_model_img_size = 768
seg_model_size = 0
# Build the hub id once instead of duplicating the f-string for extractor and model.
seg_model_name = (
    f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes"
    f"-{seg_model_img_size}-{seg_model_img_size}"
)

seg_feature_extractor = SegformerFeatureExtractor.from_pretrained(seg_model_name)
seg_model = (
    SegformerForSemanticSegmentation.from_pretrained(seg_model_name)
    .to(preferred_device)
    .to(preferred_dtype)
)

# Inpainting backbone; safety checker disabled for this demo.
inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    variant="fp16",
    torch_dtype=preferred_dtype,
    safety_checker=None,
).to(preferred_device)

# Swap in the LCM scheduler and fuse the LCM LoRA so inference needs only a
# handful of denoising steps (see num_inference_steps=4 in app()).
inpainting_pipeline.scheduler = LCMScheduler.from_config(inpainting_pipeline.scheduler.config)
inpainting_pipeline.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
inpainting_pipeline.fuse_lora()

# All processing happens at the segmentation model's native resolution.
working_size = (seg_model_img_size, seg_model_img_size)

default_inpainting_prompt = "award-winning photo of a leafy pedestrian mall full of people, with multiracial genderqueer joggers and bicyclists and wheelchair users talking and laughing"

# Cityscapes label-name -> class-id mapping from the model config.
seg_vocabulary = seg_model.config.label2id
print(f"vocab: {seg_vocabulary}")

# Lookup table over class ids: 1 for classes that should be masked out
# (and therefore inpainted over), 0 for everything else.
ban_cars_mask = np.zeros(len(seg_vocabulary), dtype=np.uint8)
banned_classes = ["car", "road", "sidewalk", "traffic light", "traffic sign"]
for class_name in banned_classes:
    ban_cars_mask[seg_vocabulary[class_name]] = 1
|
|
|
|
|
def get_seg_mask(img):
    """Segment `img` and return a PIL mask (mode "L") covering the banned classes.

    Pixels belonging to classes flagged in `ban_cars_mask` come out white (255),
    everything else black.  The mask is Gaussian-blurred and then contrast-crushed
    back toward binary, which widens the masked regions slightly so the inpainter
    covers object borders too.

    NOTE(review): SegFormer logits are emitted at a reduced spatial resolution,
    so this mask is smaller than `img`; the inpainting pipeline appears to rely
    on diffusers resizing it — confirm if exact mask alignment matters.
    """
    inputs = seg_feature_extractor(images=img, return_tensors="pt").to(preferred_device)
    # The feature extractor produces fp32 pixel_values; cast them to the model's
    # dtype, otherwise a half-precision model on CUDA raises a dtype mismatch.
    inputs["pixel_values"] = inputs["pixel_values"].to(preferred_dtype)
    # Inference only — skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        outputs = seg_model(**inputs)
    logits = outputs.logits[0]
    # argmax -> per-pixel class id; fancy-index the 0/1 ban table, scale to 0/255.
    mask = Image.fromarray((ban_cars_mask[torch.argmax(logits, dim=0).cpu().numpy()]) * 255)
    # Blur to widen edges, then a huge contrast boost to re-binarize.
    blurred_widened_mask = ImageEnhance.Contrast(mask.filter(ImageFilter.GaussianBlur(5))).enhance(9000)
    return blurred_widened_mask
|
|
|
|
|
def app(img, prompt):
    """Inpaint the banned-class regions of `img` guided by `prompt`.

    Args:
        img: HxWx3 uint8 numpy array (as delivered by gr.Image).
        prompt: text prompt for the inpainting pipeline.

    Returns:
        A PIL image at `working_size` with timing/debug info drawn onto it.

    Side effects: writes the segmentation mask to "mask.jpg" for debugging.
    """
    start_time = datetime.now().timestamp()
    # Record the incoming resolution as (width, height) before resizing;
    # reading the array shape avoids a numpy -> PIL roundtrip.
    old_size = (img.shape[1], img.shape[0])
    img = np.array(Image.fromarray(img).resize(working_size))
    mask = get_seg_mask(img)
    mask.save("mask.jpg")  # debug artifact
    mask_time = datetime.now().timestamp()

    overlay_img = inpainting_pipeline(
        prompt=prompt,
        image=Image.fromarray(img),
        mask_image=mask,
        strength=0.95,
        num_inference_steps=4,  # few steps suffice thanks to the fused LCM LoRA
    ).images[0]

    end_time = datetime.now().timestamp()
    draw = ImageDraw.Draw(overlay_img)

    # Word-wrap the prompt for display: newline after every 5th word.
    # Split once instead of re-splitting inside the comprehension.
    words = prompt.split(" ")
    prompt = " ".join(
        word if (i + 1) % 5 else word + "\n" for i, word in enumerate(words)
    )

    draw.text((10, 50), f"Old size: {old_size}\nTotal duration: {int(1000 * (end_time - start_time))}ms\nSegmentation {int(1000 * (mask_time - start_time))}ms / inpainting {int(1000 * (end_time - mask_time))} \n<{prompt}>", fill=(0, 255, 0))

    return overlay_img
|
|
|
|
|
|
|
# Warm up the full pipeline on a blank image (two passes so the second run's
# timings reflect steady state after kernel/cache initialization).  The
# original wrapped the zeros array in an Image.fromarray/np.array roundtrip,
# which is an identity conversion — pass the array directly.
for _ in tqdm(range(2)):
    app(np.zeros((1024, 1024, 3), dtype=np.uint8), default_inpainting_prompt).save(
        "zeros_inpainting_oneshot.jpg"
    )

# Expose the app: image upload + editable prompt in, inpainted image out.
iface = gr.Interface(app, [gr.Image(), gr.Textbox(value=default_inpainting_prompt)], "image")
iface.launch()
|
|