import gradio as gr
import torch
from PIL import Image, ImageFilter, ImageEnhance, ImageDraw
from diffusers import StableDiffusionInpaintPipeline
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
from tqdm import tqdm
import numpy as np
from datetime import datetime

# ideally:
# preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
# but segformer does not work on mps, so fall back to cpu there
preferred_device = "cuda" if torch.cuda.is_available() else "cpu"
preferred_dtype = torch.float16 if preferred_device == "cuda" else torch.float32

inpaint_preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")

torch.backends.cuda.matmul.allow_tf32 = True

preferred_backend = "aot_eager" if inpaint_preferred_device == "mps" else ("tensorrt" if inpaint_preferred_device == "cuda" else "inductor")

seg_model_img_size = 768
seg_model_size = 0  # b0, the smallest SegFormer variant

seg_feature_extractor = SegformerFeatureExtractor.from_pretrained(
    f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}"
)
seg_model = SegformerForSemanticSegmentation.from_pretrained(
    f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}"
).to(preferred_device).to(preferred_dtype)

inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7",
    torch_dtype=preferred_dtype,
    safety_checker=None,
).to(inpaint_preferred_device)

inpainting_pipeline.unet = torch.compile(inpainting_pipeline.unet, backend=preferred_backend)
inpainting_pipeline.vae = torch.compile(inpainting_pipeline.vae, backend=preferred_backend)
seg_model = torch.compile(seg_model, backend=preferred_backend)

seg_working_size = (seg_model_img_size, seg_model_img_size)
repaint_working_size = (768, 768)

default_inpainting_prompt = "award-winning photo of a leafy pedestrian mall full of people, with multiracial genderqueer joggers and bicyclists and wheelchair users talking and laughing"

# build a binary lookup table over the Cityscapes label ids: 1 for classes
# that should be masked out and repainted, 0 for everything else
seg_vocabulary = seg_model.config.label2id
print(f"vocab: {seg_vocabulary}")
ban_cars_mask = [0] * len(seg_vocabulary)
banned_classes = ["car", "road", "sidewalk", "traffic light", "traffic sign"]
for c in banned_classes:
    ban_cars_mask[seg_vocabulary[c]] = 1
ban_cars_mask = np.array(ban_cars_mask, dtype=np.uint8)


def get_seg_mask(img):
    inputs = seg_feature_extractor(images=img, return_tensors="pt").to(preferred_device)
    # cast inputs to the model's dtype (fp16 on cuda), otherwise the forward pass errors
    inputs["pixel_values"] = inputs["pixel_values"].to(preferred_dtype)
    with torch.no_grad():
        outputs = seg_model(**inputs)
    logits = outputs.logits[0]
    # per-pixel argmax gives a label map; indexing the lookup table with it
    # yields a 0/255 binary mask of the banned classes
    mask = Image.fromarray(ban_cars_mask[torch.argmax(logits, dim=0).cpu().numpy()] * 255)
    # blur to widen the mask, then an extreme contrast boost re-binarizes it
    blurred_widened_mask = ImageEnhance.Contrast(mask.filter(ImageFilter.GaussianBlur(2))).enhance(9000)
    return blurred_widened_mask
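
# Hypothetical debug helper, not wired into the app flow: preview the raw
# segmentation mask for a local photo (the name and path are illustrative).
def debug_mask(path="street.jpg"):
    img = np.array(Image.open(path).convert("RGB").resize(seg_working_size))
    get_seg_mask(img).save("mask_debug.png")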
def app(img, prompt, num_inference_steps, seed):
    start_time = datetime.now().timestamp()
    old_size = Image.fromarray(img).size
    img = np.array(Image.fromarray(img).resize(seg_working_size))
    mask = get_seg_mask(img)
    # mask.save("mask.jpg")
    mask_time = datetime.now().timestamp()
    overlay_img = inpainting_pipeline(
        prompt=prompt,
        image=Image.fromarray(img).resize(repaint_working_size),
        mask_image=mask.resize(repaint_working_size),
        strength=1,
        num_inference_steps=num_inference_steps,
        # PIL sizes are (width, height)
        width=repaint_working_size[0],
        height=repaint_working_size[1],
        generator=torch.manual_seed(int(seed)),
    ).images[0]
    # overlay_img.save("overlay_raw.jpg")
    end_time = datetime.now().timestamp()
    draw = ImageDraw.Draw(overlay_img)
    # wrap the prompt every five words so the caption fits in the overlay
    words = prompt.split(" ")
    prompt = "\n".join(" ".join(words[i:i + 5]) for i in range(0, len(words), 5))
    draw.text(
        (10, 50),
        f"Old size: {old_size}\n"
        f"Total duration: {int(1000 * (end_time - start_time))}ms\n"
        f"Segmentation {int(1000 * (mask_time - start_time))}ms / inpainting {int(1000 * (end_time - mask_time))}ms\n"
        f"<{prompt}>",
        fill=(0, 255, 0),
    )
    # overlay_img.save("overlay_with_text.jpg")
    return overlay_img


# warmup: one pass to trigger compilation, then three more for timing
for i in range(2):
    for j in tqdm(range(3 ** i)):
        app(np.zeros((1024, 1024, 3), dtype=np.uint8), default_inpainting_prompt, 4, 42).save("zeros_inpainting_oneshot.jpg")

# ideally:
# iface = gr.Interface(app, gr.Image(sources=["webcam"], streaming=True), "image", live=True)
iface = gr.Interface(
    app,
    [
        gr.Image(),
        gr.Textbox(value=default_inpainting_prompt),
        gr.Number(minimum=1, maximum=8, value=4),
        gr.Number(value=42),
    ],
    "image",
)
iface.launch()
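
# Programmatic usage sketch (illustrative, bypasses the UI; launch() above
# blocks, so run this in a separate session or before launching):
# result = app(np.array(Image.open("street.jpg").convert("RGB")),
#              default_inpainting_prompt, 4, 42)
# result.save("street_carfree.jpg")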