Update segment.py
Browse files- segment.py +8 -0
segment.py
CHANGED
@@ -11,6 +11,7 @@ import numpy as np
|
|
11 |
import argparse
|
12 |
import matplotlib
|
13 |
import gradio as gr
|
|
|
14 |
|
15 |
def load_image(image_path, left=0, right=0, top=0, bottom=0, size = 512):
|
16 |
if type(image_path) is str:
|
@@ -52,6 +53,8 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
|
|
52 |
if torch.min(segmentation) == 0:
|
53 |
mask = segmentation==0
|
54 |
mask = mask.cpu().detach().numpy() # [512,512] bool
|
|
|
|
|
55 |
segment_label = "rest"
|
56 |
color = viridis(0)
|
57 |
label = f"{segment_label}-{0}"
|
@@ -65,6 +68,8 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
|
|
65 |
if torch.min(segmentation) != 0:
|
66 |
segment_id -= 1
|
67 |
mask = mask.cpu().detach().numpy() # [512,512] bool
|
|
|
|
|
68 |
mask_np_list.append(mask)
|
69 |
segment_label = model.config.id2label[segment['label_id']]
|
70 |
instances_counter[segment['label_id']] += 1
|
@@ -76,6 +81,9 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
|
|
76 |
label_list.append(label)
|
77 |
else:
|
78 |
mask = np.full(segmentation.shape, True)
|
|
|
|
|
|
|
79 |
segment_label = "all"
|
80 |
mask_np_list.append(mask)
|
81 |
color = viridis(0)
|
|
|
11 |
import argparse
|
12 |
import matplotlib
|
13 |
import gradio as gr
|
14 |
+
import cv2
|
15 |
|
16 |
def load_image(image_path, left=0, right=0, top=0, bottom=0, size = 512):
|
17 |
if type(image_path) is str:
|
|
|
53 |
if torch.min(segmentation) == 0:
|
54 |
mask = segmentation==0
|
55 |
mask = mask.cpu().detach().numpy() # [512,512] bool
|
56 |
+
print(mask.shape)
|
57 |
+
mask = cv2.resize(mask,(512,512))
|
58 |
segment_label = "rest"
|
59 |
color = viridis(0)
|
60 |
label = f"{segment_label}-{0}"
|
|
|
68 |
if torch.min(segmentation) != 0:
|
69 |
segment_id -= 1
|
70 |
mask = mask.cpu().detach().numpy() # [512,512] bool
|
71 |
+
print(mask.shape)
|
72 |
+
mask = cv2.resize(mask,(512,512))
|
73 |
mask_np_list.append(mask)
|
74 |
segment_label = model.config.id2label[segment['label_id']]
|
75 |
instances_counter[segment['label_id']] += 1
|
|
|
81 |
label_list.append(label)
|
82 |
else:
|
83 |
mask = np.full(segmentation.shape, True)
|
84 |
+
print(mask.shape)
|
85 |
+
mask = cv2.resize(mask,(512,512))
|
86 |
+
|
87 |
segment_label = "all"
|
88 |
mask_np_list.append(mask)
|
89 |
color = viridis(0)
|