File size: 2,806 Bytes
f7c8faa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17defce
e7309a2
f7c8faa
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gradio as gr
import glob
import torch
import pickle
from PIL import Image, ImageDraw
import numpy as np
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

import numpy as np
from scipy.ndimage import center_of_mass


def combine_ims(im1, im2, val=128):
    """Blend two images through a uniform mask.

    A single-channel ("L") mask with the size of ``im1`` is created, filled
    with ``val``, and passed to ``Image.composite`` together with both images.

    Args:
        im1: First image given to ``Image.composite``; its size is used for
            the mask.
        im2: Second image given to ``Image.composite``.
        val: Fill value of the uniform mask (defaults to 128).

    Returns:
        The image produced by ``Image.composite``.
    """
    uniform_mask = Image.new("L", im1.size, val)
    return Image.composite(im1, im2, uniform_mask)

def get_class_centers(segmentation_mask, class_dict):
    """Locate the centre of mass of every class present in a segmentation mask.

    The mask values are shifted by one before being compared against the keys
    of ``class_dict`` (presumably the model emits 0-based ids while the keys
    are 1-based -- confirm against the class metadata).

    Classes that do not occur in the mask are skipped up front, so
    ``center_of_mass`` is never asked to divide by an empty mask (which would
    yield NaN coordinates and a RuntimeWarning for each absent class).

    Args:
        segmentation_mask: Per-pixel class ids. Either an object exposing
            ``.numpy()`` (e.g. a CPU torch tensor) or anything
            ``np.asarray`` accepts.
        class_dict: Mapping whose keys are the class ids to look for; the
            values are not used here.

    Returns:
        dict mapping each class id found in the mask to its centre of mass as
        a list of ints, one per mask axis (``[row, col]`` for a 2-D mask),
        in ``class_dict`` key order.
    """
    if hasattr(segmentation_mask, "numpy"):
        labels = segmentation_mask.numpy()
    else:
        labels = np.asarray(segmentation_mask)
    labels = labels + 1

    class_centers = {}
    for class_index in class_dict:
        class_mask = labels == class_index
        if not class_mask.any():
            # Absent class: nothing to label, and no 0/0 inside center_of_mass.
            continue
        center = center_of_mass(class_mask.astype(int))
        # Truncate to integer pixel coordinates.
        class_centers[class_index] = [int(coord) for coord in center]
    return class_centers

def visualize_mask(predicted_semantic_map, class_ids, class_colors):
    """Turn a 2-D semantic map into a colour image.

    Each pixel value is cast to ``uint8`` and used as an index into
    ``class_ids``; the result in turn indexes ``class_colors``. The looked-up
    colours are reshaped to ``(height, width, 3)`` and wrapped in a PIL image.

    Args:
        predicted_semantic_map: 2-D object exposing ``.shape`` and
            ``.numpy()`` (presumably a torch tensor of class indexes).
        class_ids: numpy array indexed by the per-pixel values.
        class_colors: numpy array indexed by the entries of ``class_ids``;
            assumed to hold one 3-component colour per row.

    Returns:
        The image built by ``Image.fromarray`` from the uint8 colour array.
    """
    height, width = predicted_semantic_map.shape

    # One uint8 index per pixel, flattened so the lookups below are 1-D.
    flat_indexes = predicted_semantic_map.numpy().astype(np.uint8).flatten()

    # Two-step lookup: pixel value -> class_ids entry -> class_colors row.
    looked_up = class_colors[class_ids[flat_indexes]]
    rgb = looked_up.reshape(height, width, 3).astype(np.uint8)
    return Image.fromarray(rgb)


def get_out_image(image, predicted_semantic_map):
    """Overlay the coloured segmentation and class labels on ``image``.

    Uses the module-level ``class_dict``, ``class_ids``, ``class_colors`` and
    ``class_names``. The colour mask is blended with the input via
    ``combine_ims`` (mask value 128), then the name of each detected class is
    drawn in black at that class's centre of mass.

    Args:
        image: The input image passed to ``combine_ims``.
        predicted_semantic_map: Per-pixel prediction handed to
            ``get_class_centers`` and ``visualize_mask``.

    Returns:
        The blended image with the class names drawn on it.
    """
    centers = get_class_centers(predicted_semantic_map, class_dict)
    color_mask = visualize_mask(predicted_semantic_map, class_ids, class_colors)
    blended = combine_ims(image, color_mask, val=128)

    painter = ImageDraw.Draw(blended)
    for class_id, center in centers.items():
        # Centres come back as (row, col); draw.text wants (x, y).
        row, col = center
        label = str(class_names[class_id - 1])
        painter.text((col, row), label, fill='black')

    return blended

def gradio_process(image):
    """Gradio callback: segment ``image`` and return the annotated overlay.

    Runs the module-level ``processor`` and ``model`` with gradients
    disabled, post-processes the outputs into a semantic map at the size of
    the input, and hands the result to ``get_out_image``.

    Args:
        image: The input image (the Interface in this file is configured
            with ``type="pil"``).

    Returns:
        The image returned by ``get_out_image``.
    """
    model_inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        model_outputs = model(**model_inputs)

    # The post-processor is given image.size reversed; for a PIL image,
    # whose size is (width, height), that is (height, width).
    target_size = image.size[::-1]
    semantic_maps = processor.post_process_semantic_segmentation(
        model_outputs, target_sizes=[target_size]
    )
    return get_out_image(image, semantic_maps[0])

# Class metadata used by the overlay: a pickled (names, ids, colors) triple.
# NOTE(security): pickle.load can execute arbitrary code. This is acceptable
# only for a trusted file (presumably bundled with the app); never point it
# at user-supplied data.
with open('ade20k_classes.pickle', 'rb') as f:
    class_names, class_ids, class_colors = pickle.load(f)
class_names = np.array(class_names)
class_ids = np.array(class_ids)
class_colors = np.array(class_colors)
# class id -> class name; get_out_image passes this to get_class_centers.
class_dict = dict(zip(class_ids, class_names))

# Single source of truth for the checkpoint shared by processor and model.
MODEL_ID = "facebook/mask2former-swin-large-ade-semantic"

device = torch.device("cpu")
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = Mask2FormerForUniversalSegmentation.from_pretrained(MODEL_ID).to(device)
model.eval()

demo = gr.Interface(
    gradio_process,
    # gr.inputs.* / gr.outputs.* are the deprecated component namespaces
    # (removed in Gradio 4); gr.Image serves as both input and output.
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Semantic Interior Segmentation",
    # Sorted so the example gallery order does not depend on the filesystem.
    examples=sorted(glob.glob('./examples/*.jpg')),
    allow_flagging="never",
)

demo.launch()