CV-Agent / utils.py
Samarth991's picture
adding CV agent file
0e78cbf
raw
history blame
979 Bytes
from collections import defaultdict
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib import cm
import torch
def draw_panoptic_segmentation(model, segmentation, segments_info, *,
                               output_path='final_mask.png', show_legend=False):
    """Render a panoptic segmentation map to an image file and return its path.

    Each segment id in ``segmentation`` is drawn with its own color from a
    discrete viridis colormap. A legend patch is prepared for every entry of
    ``segments_info`` using exactly the same color, labelled
    ``"<class name>-<instance index>"``.

    Args:
        model: Object exposing ``model.config.id2label``, a mapping from a
            segment's ``label_id`` to a human-readable class name.
        segmentation: ``torch.Tensor`` of integer segment ids, on any device.
            Presumably a 2-D (height, width) map such as the ``"segmentation"``
            output of a Hugging Face ``post_process_panoptic_segmentation``
            call — TODO confirm against the caller.
        segments_info: Iterable of dicts, each with at least an ``"id"`` key
            (the value used for that segment in ``segmentation``) and a
            ``"label_id"`` key (a key of ``model.config.id2label``).
        output_path: File to write the rendered figure to. Defaults to
            ``'final_mask.png'`` in the current working directory.
        show_legend: If True, draw the per-segment legend on the figure. Off by
            default so the default output image has no legend.

    Returns:
        ``output_path``, the path of the written image.
    """
    # Integer ids on the CPU: an integer array passed to a Colormap indexes its
    # lookup table directly (a float array would be read as values in [0, 1]).
    seg_ids = segmentation.detach().cpu().long()
    # One LUT entry per id in 0..max_id so every id gets a distinct color
    # (with only max_id entries the highest id would share the last color).
    # int() avoids handing a tensor to get_cmap; the floor of 1 guards maps
    # whose ids are all <= 0.
    num_colors = max(int(seg_ids.max()) + 1, 1)
    # plt.get_cmap replaces cm.get_cmap, deprecated in Matplotlib 3.7 and
    # removed in 3.9.
    cmap = plt.get_cmap('viridis', num_colors)

    fig, ax = plt.subplots()
    try:
        # Color pixels through the same LUT the legend patches use so the image
        # and legend agree; negative ids (if any) get the colormap's "under" color.
        ax.imshow(cmap(seg_ids.numpy()))

        instances_counter = defaultdict(int)
        handles = []
        for segment in segments_info:
            label_id = segment['label_id']
            class_name = model.config.id2label[label_id]
            # Number instances of the same class: "<name>-0", "<name>-1", ...
            label = f"{class_name}-{instances_counter[label_id]}"
            instances_counter[label_id] += 1
            handles.append(mpatches.Patch(color=cmap(segment['id']), label=label))

        if show_legend and handles:
            ax.legend(handles=handles)

        fig.savefig(output_path)
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # repeated calls accumulate figures and memory.
        plt.close(fig)
    return output_path