Spaces:
Running
Running
from collections import defaultdict | |
import matplotlib.pyplot as plt | |
import matplotlib.patches as mpatches | |
from matplotlib import cm | |
import torch | |
def draw_panoptic_segmentation(model, segmentation, segments_info, *,
                               output_path='final_mask.png', show_legend=False):
    """Render a panoptic segmentation map to an image file and return its path.

    Every pixel value of ``segmentation`` is used as a segment id and coloured
    with a discretised ``viridis`` colour map holding one entry per id in
    ``[0, max_id]``. The same colour map colours the optional legend, so
    legend patches match the image exactly.

    Args:
        model: Object exposing ``model.config.id2label``; used only when
            ``show_legend`` is True, to turn each segment's ``label_id`` into
            a class name.
        segmentation: Tensor of segment ids (presumably a 2-D map from a
            panoptic post-processor -- TODO confirm against caller). It is
            copied to CPU before drawing, so any device is accepted.
        segments_info: Iterable of dicts with at least ``'id'`` and
            ``'label_id'`` keys, one per segment.
        output_path: File the figure is saved to. Defaults to the previously
            hard-coded ``'final_mask.png'``.
        show_legend: If True, draw a legend with one ``"<label>-<n>"`` entry
            per segment, where ``n`` counts instances of that label. Defaults
            to False, matching the earlier output where the legend call was
            commented out.

    Returns:
        ``output_path``.
    """
    # Ids are used directly as lut indices below, so force an integer dtype
    # (a float array would instead be read as normalised [0, 1] values).
    seg = segmentation.cpu().numpy().astype(int)
    # One lut entry per id in [0, max_id]: the "+1" gives the largest id its
    # own colour (otherwise it lands in the "over" slot and duplicates the
    # colour below it). initial=0 keeps the lut non-empty for all-background
    # or all-negative maps; negative ids are drawn with the "under" colour.
    max_id = int(seg.max(initial=0))
    cmap = plt.get_cmap('viridis', max_id + 1)

    fig, ax = plt.subplots()
    try:
        # Colour the ids through the same discrete lut as the legend patches;
        # a bare ax.imshow(seg) would min/max-normalise onto a 256-entry
        # viridis, so its colours would not line up with the legend.
        ax.imshow(cmap(seg))
        if show_legend:
            ax.legend(handles=_segment_legend_handles(model, segments_info, cmap))
        fig.savefig(output_path)
    finally:
        # pyplot keeps every figure alive until explicitly closed; without
        # this each call leaks a figure in a long-running process.
        plt.close(fig)
    return output_path


def _segment_legend_handles(model, segments_info, cmap):
    """Build one legend patch per segment, labelled ``"<class>-<instance index>"``."""
    instances_counter = defaultdict(int)
    handles = []
    for segment in segments_info:
        label_id = segment['label_id']
        # Number repeated classes 0, 1, 2, ... in the order they appear.
        label = f"{model.config.id2label[label_id]}-{instances_counter[label_id]}"
        instances_counter[label_id] += 1
        handles.append(mpatches.Patch(color=cmap(segment['id']), label=label))
    return handles