import time
from datetime import datetime

import gradio as gr
import numpy as np
import torch
from diffusers import LCMScheduler, StableDiffusionInpaintPipeline
from PIL import Image, ImageDraw, ImageEnhance, ImageFilter
from tqdm import tqdm
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation

# Ideally a single device would serve both models:
#   preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
# but SegFormer does not work on MPS, so segmentation only uses CUDA or CPU.
seg_preferred_device = "cuda" if torch.cuda.is_available() else "cpu"
# Alternative worth trying: torch.float16 when seg_preferred_device == "cuda".
seg_preferred_dtype = torch.float32
inpaint_preferred_device = (
    "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
)
# Half precision on accelerators; full precision on CPU.
inpaint_preferred_dtype = torch.float32 if inpaint_preferred_device == "cpu" else torch.float16

torch.backends.cuda.matmul.allow_tf32 = True
print(f"backends: {torch._dynamo.list_backends()}")
# "inductor" is used everywhere except MPS, where "aot_eager" is used instead.
preferred_backend = "aot_eager" if inpaint_preferred_device == "mps" else "inductor"

seg_model_img_size = 768
seg_model_size = 0  # selects the nvidia/segformer-b<N> checkpoint below
seg_checkpoint = (
    f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}"
)
seg_feature_extractor = SegformerFeatureExtractor.from_pretrained(seg_checkpoint)
seg_model = (
    SegformerForSemanticSegmentation.from_pretrained(seg_checkpoint)
    .to(seg_preferred_device)
    .to(seg_preferred_dtype)
)

inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7",
    torch_dtype=inpaint_preferred_dtype,
    safety_checker=None,
).to(inpaint_preferred_device)

inpainting_pipeline.text_encoder = torch.compile(inpainting_pipeline.text_encoder, backend=preferred_backend)
inpainting_pipeline.unet = torch.compile(inpainting_pipeline.unet, backend=preferred_backend)
inpainting_pipeline.vae = torch.compile(inpainting_pipeline.vae, backend=preferred_backend)
seg_model = torch.compile(seg_model, backend=preferred_backend)

# Input images are resized to this square before segmentation.
seg_working_size = (seg_model_img_size, seg_model_img_size)

default_inpainting_prompt = "award-winning photo of a leafy pedestrian mall full of people, with multiracial genderqueer joggers and bicyclists and wheelchair users talking and laughing"

seg_vocabulary = seg_model.config.label2id
print(f"vocab: {seg_vocabulary}")

# Lookup table indexed by predicted class id: 1 for classes to paint over, 0 otherwise.
# Assumes label ids are 0..len(vocab)-1 (as the original list-based construction did).
banned_classes = ["car", "road", "sidewalk", "traffic light", "traffic sign"]
ban_cars_mask = np.zeros(len(seg_vocabulary), dtype=np.uint8)
ban_cars_mask[[seg_vocabulary[name] for name in banned_classes]] = 1

# Blurred-mask values above this become part of the mask. A low threshold grows the
# mask outward by a few pixels so the inpainting also covers object edges.
mask_widen_threshold = 32


def get_seg_mask(img, widen_threshold=None):
    """Segment ``img`` and return a slightly widened mask of the banned classes.

    Args:
        img: image to segment, in a form ``seg_feature_extractor`` accepts
            (``app`` passes an RGB numpy array resized to ``seg_working_size``).
        widen_threshold: blurred-mask values strictly above this become 255.
            Defaults to ``mask_widen_threshold``.

    Returns:
        A PIL "L" image: 255 where one of ``banned_classes`` was predicted (grown
        by the blur), 0 elsewhere. Its size is that of the model's logits, which
        may differ from the input size; callers resize it.
    """
    if widen_threshold is None:
        widen_threshold = mask_widen_threshold

    inputs = seg_feature_extractor(images=img, return_tensors="pt").to(seg_preferred_device).to(seg_preferred_dtype)
    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        logits = seg_model(**inputs).logits[0]
    class_ids = torch.argmax(logits, dim=0).cpu().numpy()
    mask = Image.fromarray(ban_cars_mask[class_ids] * 255)

    # Blur, then re-binarize at a fixed low threshold. (Boosting contrast instead would
    # re-binarize around the image's mean brightness, which shrinks rather than grows
    # the mask whenever banned classes cover most of the frame.)
    blurred = mask.filter(ImageFilter.GaussianBlur(2))
    return blurred.point(lambda value: 255 if value > widen_threshold else 0)


def _wrap_words(text, words_per_line=5):
    """Return ``text`` with its whitespace-separated words regrouped ``words_per_line`` per line."""
    words = text.split()
    return "\n".join(
        " ".join(words[start:start + words_per_line]) for start in range(0, len(words), words_per_line)
    )


def app(img, prompt, num_inference_steps, seed, inpaint_size):
    """Paint over cars/roads in ``img`` with a scene described by ``prompt``.

    Args:
        img: input image as a numpy array (what ``gr.Image()`` supplies here), or None
            if the user submitted without an image.
        prompt: text prompt for the inpainting pipeline.
        num_inference_steps: number of diffusion steps; coerced to int because
            Gradio Number inputs may deliver floats.
        seed: RNG seed; coerced to int.
        inpaint_size: side length of the square output; coerced to int and rounded
            down to a multiple of 8 (minimum 8).

    Returns:
        The inpainted PIL image (inpaint_size x inpaint_size) with timing and the
        prompt drawn on it.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        raise gr.Error("Please provide an input image.")
    num_inference_steps = int(num_inference_steps)
    # The SD inpainting pipeline requires height/width divisible by 8.
    inpaint_size = max(8, int(inpaint_size) // 8 * 8)

    start_time = time.perf_counter()
    # convert("RGB") normalizes grayscale/RGBA uploads to the 3 channels the models expect.
    img = np.array(Image.fromarray(img).convert("RGB").resize(seg_working_size))
    mask = get_seg_mask(img)
    mask_time = time.perf_counter()

    overlay_img = inpainting_pipeline(
        prompt=prompt,
        image=Image.fromarray(img).resize((inpaint_size, inpaint_size)),
        mask_image=mask.resize((inpaint_size, inpaint_size)),
        strength=1,
        num_inference_steps=num_inference_steps,
        height=inpaint_size,
        width=inpaint_size,
        generator=torch.manual_seed(int(seed)),
    ).images[0]
    end_time = time.perf_counter()

    draw = ImageDraw.Draw(overlay_img)
    draw.text(
        (10, 50),
        "\n".join([
            f"Total duration: {int(1000 * (end_time - start_time))}ms",
            f"Inference steps: {num_inference_steps}",
            f"Segmentation {int(1000 * (mask_time - start_time))}ms / inpainting {int(1000 * (end_time - mask_time))}ms",
            # Break the prompt into short lines so it fits on the image.
            f"<{_wrap_words(prompt)}>",
        ]),
        fill=(0, 255, 0),
    )
    return overlay_img


# Warmup: the first call per size is for compiling, the following three for timing.
blank_image = np.zeros((1024, 1024, 3), dtype=np.uint8)
for size in [384, 512]:
    for i in range(2):
        for _ in tqdm(range(3 ** i)):
            app(blank_image, default_inpainting_prompt, 4, 42, size).save("zeros_inpainting_oneshot.jpg")

# Ideally:
#   iface = gr.Interface(app, gr.Image(sources=["webcam"], streaming=True), "image", live=True)
iface = gr.Interface(
    app,
    [
        gr.Image(),
        gr.Textbox(value=default_inpainting_prompt),
        # precision=0 makes Gradio return ints rather than floats.
        gr.Number(minimum=1, maximum=8, value=4, precision=0),
        gr.Number(value=42, precision=0),
        gr.Number(value=512, maximum=seg_model_img_size, precision=0),
    ],
    "image",
)
iface.launch(share=True)