import gradio as gr import torch from PIL import Image, ImageFilter, ImageEnhance, ImageDraw from diffusers import LCMScheduler, StableDiffusionInpaintPipeline from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation from tqdm import tqdm import numpy as np from datetime import datetime # ideally: # preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu") # but segformer does not work on mps lolololol preferred_device = "cuda" if torch.cuda.is_available() else "cpu" preferred_dtype = torch.float16 if preferred_device == 'cuda' else torch.float32 seg_model_img_size = 768 seg_model_size = 0 seg_feature_extractor = SegformerFeatureExtractor.from_pretrained(f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}") seg_model = SegformerForSemanticSegmentation.from_pretrained( f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}" ).to(preferred_device) #.to(preferred_dtype) inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained( "runwayml/stable-diffusion-inpainting", variant="fp16", torch_dtype=preferred_dtype, safety_checker=None, ).to(preferred_device) inpainting_pipeline.unet = torch.compile(inpainting_pipeline.unet) inpainting_pipeline.vae = torch.compile(inpainting_pipeline.vae) # inpainting_pipeline.scheduler = LCMScheduler.from_config(inpainting_pipeline.scheduler.config) # inpainting_pipeline.load_lora_weights("latent-consistency/lcm-lora-sdv1-5", torch_dtype=preferred_dtype) # inpainting_pipeline.fuse_lora() seg_working_size = (seg_model_img_size, seg_model_img_size) repaint_working_size = (384, 384) default_inpainting_prompt = "award-winning photo of a leafy pedestrian mall full of people, with multiracial genderqueer joggers and bicyclists and wheelchair users talking and laughing" seg_vocabulary = seg_model.config.label2id print(f"vocab: {seg_vocabulary}") ban_cars_mask = [0] * len(seg_vocabulary) banned_classes = ["car", "road", "sidewalk", "traffic light", "traffic sign"] for c in banned_classes: ban_cars_mask[seg_vocabulary[c]] = 1 ban_cars_mask = np.array(ban_cars_mask, dtype=np.uint8) def get_seg_mask(img): inputs = seg_feature_extractor(images=img, return_tensors="pt").to(preferred_device) outputs = seg_model(**inputs) logits = outputs.logits[0] mask = Image.fromarray((ban_cars_mask[ torch.argmax(logits, dim=0).cpu().numpy() ]) * 255) blurred_widened_mask = ImageEnhance.Contrast(mask.filter(ImageFilter.GaussianBlur(5))).enhance(9000) return blurred_widened_mask def app(img, prompt): start_time = datetime.now().timestamp() old_size = Image.fromarray(img).size img = np.array(Image.fromarray(img).resize(seg_working_size)) mask = get_seg_mask(img) # mask.save("mask.jpg") mask_time = datetime.now().timestamp() #print(prompt.__class__, img.__class__, mask.__class__, img.shape, mask.shape, mask.dtype, img.dtype) overlay_img = inpainting_pipeline( prompt=prompt, image=Image.fromarray(img).resize(repaint_working_size), mask_image=(mask).resize(repaint_working_size), strength=0.95, num_inference_steps=16, height=repaint_working_size[0], width=repaint_working_size[1], ).images[0] #overlay_img.save("overlay_raw.jpg") end_time = datetime.now().timestamp() draw = ImageDraw.Draw(overlay_img) # replace spaces with newlines after many words to line break prompt prompt = " ".join([prompt.split(" ")[i] if (i+1) % 5 else prompt.split(" ")[i] + "\n" for i in range(len(prompt.split(" ")))]) draw.text((10, 50), f"Old size: {old_size}\nTotal duration: {int(1000 * (end_time - start_time))}ms\nSegmentation {int(1000 * (mask_time - start_time))}ms / inpainting {int(1000 * (end_time - mask_time))} \n<{prompt}>", fill=(0, 255, 0)) #overlay_img.save("overlay_with_text.jpg") return overlay_img ### kick the tires before we start for i in tqdm(range(2)): app(np.array(Image.fromarray(np.zeros((1024,1024,3), dtype=np.uint8))), default_inpainting_prompt).save("zeros_inpainting_oneshot.jpg") #ideally: #iface = gr.Interface(app, gr.Image(sources=["webcam"], streaming=True), "image", live=True) iface = gr.Interface(app, [gr.Image(), gr.Textbox(value=default_inpainting_prompt)], "image") iface.launch()