from datetime import datetime

import gradio as gr
import numpy as np
import torch
from diffusers import AutoPipelineForInpainting
from fastai.vision.all import *
from PIL import Image, ImageDraw
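
# Prefer the GPU with fp16 weights when available; otherwise fall back to CPU with fp32.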
preferred_device = "cuda" if torch.cuda.is_available() else "cpu"
preferred_dtype = torch.float32 if preferred_device == "cpu" else torch.float16
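

# load_learner needs label_func defined in the loading module to unpickle the exported learner;
# it is not called at inference time, so the otherwise-undefined `path` is never evaluated.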
def label_func(fn): return path/"labels"/f"{fn.stem}_P{fn.suffix}"


segmodel = load_learner("camvid-512.pkl")
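
# Stable Diffusion inpainting pipeline; the model id is from_pretrained's first positional argument.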
inpainting_pipeline = AutoPipelineForInpainting.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",  # newer diffusers releases prefer variant="fp16" for half-precision weights
    torch_dtype=preferred_dtype,
).to(preferred_device)

working_size = (512, 512)

default_inpainting_prompt = "watercolor of a leafy pedestrian mall at golden hour with multiracial genderqueer joggers and bicyclists and wheelchair users talking and laughing"
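
# CamVid class vocabulary, in prediction-index order; ban_cars_mask below is aligned with it.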
seg_vocabulary = ['Animal', 'Archway', 'Bicyclist', 'Bridge', 'Building', 'Car',
                  'CartLuggagePram', 'Child', 'Column_Pole', 'Fence', 'LaneMkgsDriv',
                  'LaneMkgsNonDriv', 'Misc_Text', 'MotorcycleScooter', 'OtherMoving',
                  'ParkingBlock', 'Pedestrian', 'Road', 'RoadShoulder', 'Sidewalk',
                  'SignSymbol', 'Sky', 'SUVPickupTruck', 'TrafficCone',
                  'TrafficLight', 'Train', 'Tree', 'Truck_Bus', 'Tunnel',
                  'VegetationMisc', 'Void', 'Wall']
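
# One flag per class in seg_vocabulary: 1 marks classes (cars, roads, poles, ...) that are
# masked out and repainted by the inpainting step, 0 marks classes whose pixels are kept.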
ban_cars_mask = np.array([0, 0, 0, 0, 0, 1,
                          0, 0, 1, 0, 1,
                          1, 1, 0, 0,
                          1, 0, 1, 1, 1,
                          1, 0, 1, 1,
                          1, 0, 0, 0, 1,
                          0, 1, 0], dtype=np.uint8)


def get_seg_mask(img):
    # fastai returns a TensorMask of per-pixel class indices; convert it to a plain
    # uint8 array so it can be used directly to index ban_cars_mask.
    mask = np.array(segmodel.predict(img)[0], dtype=np.uint8)
    return mask
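

# Gradio handler: segment the frame, mask the banned classes, inpaint them with the prompt,
# and stamp sizing/timing info plus the prompt onto the output image.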
def app(img, prompt):
    start_time = datetime.now().timestamp()
    old_size = Image.fromarray(img).size
    img = np.array(Image.fromarray(img).resize(working_size))
    # Look up each pixel's class in the ban list, then scale the 0/1 flags to a 0/255 inpainting mask.
    mask = ban_cars_mask[get_seg_mask(img)]
    mask = mask * 255
    mask_time = datetime.now().timestamp()
    overlay_img = inpainting_pipeline(
        prompt=prompt,
        image=Image.fromarray(img),
        mask_image=Image.fromarray(mask),
        strength=0.95,
        num_inference_steps=13,
    ).images[0]
    end_time = datetime.now().timestamp()
    draw = ImageDraw.Draw(overlay_img)

    # Re-wrap the prompt to roughly five words per line so it fits in the overlay caption.
    words = prompt.split(" ")
    prompt = " ".join(word if (i + 1) % 5 else word + "\n" for i, word in enumerate(words))

    draw.text((50, 10), f"Old size: {old_size}\nTotal duration: {int(1000 * (end_time - start_time))}ms\nSegmentation {int(1000 * (mask_time - start_time))}ms / inpainting {int(1000 * (end_time - mask_time))}ms\n<{prompt}>", fill=(255, 255, 255))
    return overlay_img
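
# Minimal Gradio UI: an image upload and an editable prompt in, a single image out.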
iface = gr.Interface(app, [gr.Image(), gr.Textbox(value=default_inpainting_prompt)], "image")
iface.launch()