import gradio as gr

import torch
from PIL import Image, ImageFilter, ImageEnhance, ImageDraw
from diffusers import StableDiffusionInpaintPipeline
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation

from tqdm import tqdm
import numpy as np
from datetime import datetime
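
# Device and precision selection: fp16 on CUDA, fp32 otherwise. The
# segmentation model runs on CUDA or CPU; the inpainting pipeline can also
# fall back to Apple's MPS. TF32 matmuls are enabled for extra CUDA throughput.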
preferred_device = "cuda" if torch.cuda.is_available() else "cpu"
preferred_dtype = torch.float16 if preferred_device == "cuda" else torch.float32
inpaint_preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True

print(f"backends: {torch._dynamo.list_backends()}")

# torch.compile backend per device; computed here but not applied below.
preferred_backend = "aot_eager" if inpaint_preferred_device == "mps" else ("onnxrt" if inpaint_preferred_device == "cuda" else "inductor")
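
# SegFormer (B0) fine-tuned on Cityscapes at 768x768, used to find the
# street-scene classes that will be painted over.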
seg_model_img_size = 768
seg_model_size = 0

seg_feature_extractor = SegformerFeatureExtractor.from_pretrained(
    f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}"
)
seg_model = SegformerForSemanticSegmentation.from_pretrained(
    f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}"
).to(preferred_device)
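
# LCM-distilled Dreamshaper checkpoint: a latent consistency model that
# produces usable inpainting results in about 4 denoising steps.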
inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7",
    torch_dtype=preferred_dtype,
    safety_checker=None,
).to(inpaint_preferred_device)

seg_working_size = (seg_model_img_size, seg_model_img_size)

default_inpainting_prompt = "award-winning photo of a leafy pedestrian mall full of people, with multiracial genderqueer joggers and bicyclists and wheelchair users talking and laughing"

seg_vocabulary = seg_model.config.label2id
print(f"vocab: {seg_vocabulary}")

# Lookup table over class ids: 1 for classes to repaint, 0 for everything else.
ban_cars_mask = [0] * len(seg_vocabulary)
banned_classes = ["car", "road", "sidewalk", "traffic light", "traffic sign"]
for c in banned_classes:
    ban_cars_mask[seg_vocabulary[c]] = 1
ban_cars_mask = np.array(ban_cars_mask, dtype=np.uint8)
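

# Segment an image and return a binary PIL mask (255 = repaint) covering the
# banned classes.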
def get_seg_mask(img):
    inputs = seg_feature_extractor(images=img, return_tensors="pt").to(preferred_device)
    with torch.no_grad():
        outputs = seg_model(**inputs)
    logits = outputs.logits[0]
    # Per-pixel argmax over class logits, mapped through the ban table and
    # scaled to a 0/255 grayscale image.
    mask = Image.fromarray(ban_cars_mask[torch.argmax(logits, dim=0).cpu().numpy()] * 255)
    # Blur to widen the mask edges, then raise contrast to re-binarize it.
    blurred_widened_mask = ImageEnhance.Contrast(mask.filter(ImageFilter.GaussianBlur(2))).enhance(9000)
    return blurred_widened_mask
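

# Gradio handler: segment the input image, inpaint the masked regions with the
# prompt, and stamp timing information onto the result.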
def app(img, prompt, num_inference_steps, seed, inpaint_size):
    start_time = datetime.now().timestamp()
    # gr.Number yields floats, and PIL's resize() wants a (width, height) tuple.
    inpaint_size = (int(inpaint_size), int(inpaint_size))
    old_size = Image.fromarray(img).size
    img = np.array(Image.fromarray(img).resize(seg_working_size))
    mask = get_seg_mask(img)

    mask_time = datetime.now().timestamp()

    overlay_img = inpainting_pipeline(
        prompt=prompt,
        image=Image.fromarray(img).resize(inpaint_size),
        mask_image=mask.resize(inpaint_size),
        strength=1,
        num_inference_steps=int(num_inference_steps),
        height=inpaint_size[1],
        width=inpaint_size[0],
        generator=torch.manual_seed(int(seed)),
    ).images[0]

    end_time = datetime.now().timestamp()
    draw = ImageDraw.Draw(overlay_img)

    # Wrap the prompt every five words so the caption fits on the image.
    words = prompt.split(" ")
    prompt = " ".join(w if (i + 1) % 5 else w + "\n" for i, w in enumerate(words))

    draw.text((10, 50), f"Old size: {old_size}\nTotal duration: {int(1000 * (end_time - start_time))}ms\nSegmentation {int(1000 * (mask_time - start_time))}ms / inpainting {int(1000 * (end_time - mask_time))}ms\n<{prompt}>", fill=(0, 255, 0))

    return overlay_img
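

# Warm-up: exercise the whole segment-and-inpaint path on a black image so
# weights are loaded and kernels are warm before the UI starts serving.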
for i in range(2):
    for j in tqdm(range(3 ** i)):
        app(np.zeros((1024, 1024, 3), dtype=np.uint8), default_inpainting_prompt, 4, 42, 512).save("zeros_inpainting_oneshot.jpg")
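
# Inputs: image, prompt, LCM step count (1-8), RNG seed, and square inpainting
# resolution (capped at the segmentation model's input size).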
iface = gr.Interface(
    app,
    [
        gr.Image(),
        gr.Textbox(value=default_inpainting_prompt),
        gr.Number(minimum=1, maximum=8, value=4),
        gr.Number(value=42),
        gr.Number(value=512, maximum=seg_model_img_size),
    ],
    "image",
)
iface.launch()