File size: 9,543 Bytes
bcc611b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import os
import csv
import glob
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from avsbench_utils import get_v2_pallete
# Category table: one record per class, with ids 1..70 assigned in listing order.
# Each spec below is (name, frequency code). The single-letter codes are kept
# verbatim -- NOTE(review): the meaning of 'f' / 'c' / 'r' / 'n' is not defined
# anywhere in this file; confirm against the dataset documentation.
_CATEGORY_SPECS = (
    ("accordion", "n"),
    ("airplane", "f"),
    ("axe", "r"),
    ("baby", "f"),
    ("bassoon", "n"),
    ("bell", "f"),
    ("bird", "f"),
    ("boat", "f"),
    ("boy", "f"),
    ("bus", "f"),
    ("car", "f"),
    ("cat", "f"),
    ("cello", "n"),
    ("clarinet", "r"),
    ("clipper", "r"),
    ("clock", "f"),
    ("dog", "f"),
    ("donkey", "c"),
    ("drum", "c"),
    ("duck", "f"),
    ("elephant", "f"),
    ("emergency-car", "n"),
    ("erhu", "n"),
    ("flute", "n"),
    ("frying-food", "n"),
    ("girl", "f"),
    ("goose", "c"),
    ("guitar", "f"),
    ("gun", "c"),
    ("guzheng", "n"),
    ("hair-dryer", "f"),
    ("handpan", "n"),
    ("harmonica", "n"),
    ("harp", "n"),
    ("helicopter", "c"),
    ("hen", "c"),
    ("horse", "f"),
    ("keyboard", "f"),
    ("leopard", "n"),
    ("lion", "c"),
    ("man", "f"),
    ("marimba", "n"),
    ("missile-rocket", "c"),
    ("motorcycle", "f"),
    ("mower", "r"),
    ("parrot", "c"),
    ("piano", "f"),
    ("pig", "c"),
    ("pipa", "n"),
    ("saw", "r"),
    ("saxophone", "r"),
    ("sheep", "f"),
    ("sitar", "n"),
    ("sorna", "n"),
    ("squirrel", "c"),
    ("tabla", "n"),
    ("tank", "n"),
    ("tiger", "c"),
    ("tractor", "c"),
    ("train", "f"),
    ("trombone", "n"),
    ("truck", "f"),
    ("trumpet", "c"),
    ("tuba", "r"),
    ("ukulele", "n"),
    ("utv", "n"),
    ("vacuum-cleaner", "c"),
    ("violin", "r"),
    ("wolf", "r"),
    ("woman", "f"),
)

# Expand the compact specs into the dict records the rest of the module reads
# (keys: 'supercategory', 'id', 'name', 'frequency').
categories = [
    {'supercategory': None, 'id': cat_id, 'name': cat_name, 'frequency': cat_freq}
    for cat_id, (cat_name, cat_freq) in enumerate(_CATEGORY_SPECS, start=1)
]
def color_mask_to_label(mask, v_pallete):
    """Convert a colour-coded mask into a map of palette indices.

    Args:
        mask: Anything ``np.array`` can turn into an array whose last axis
            holds the colour channels (e.g. a PIL image).
        v_pallete: Iterable of colours; the position of a colour in this
            iterable is the class index assigned to pixels of that colour.

    Returns:
        Integer ``np.ndarray`` with the shape of ``mask`` minus its last axis.
        A pixel whose colour matches no palette entry gets index 0, because
        ``argmax`` over an all-False row returns the first position -- such
        pixels are therefore indistinguishable from palette entry 0.

    Raises:
        ValueError: If ``v_pallete`` is empty (``np.stack`` needs at least one
            array).
    """
    mask_array = np.array(mask).astype('int32')
    # One boolean plane per palette colour: True where every channel matches.
    class_maps = [np.all(np.equal(mask_array, colour), axis=-1) for colour in v_pallete]
    # argmax works directly on the boolean stack and returns the first True
    # along the class axis, so no float copy of the stack is needed.
    semantic_map = np.stack(class_maps, axis=-1)
    return np.argmax(semantic_map, axis=-1)
def load_color_mask_in_PIL_to_Tensor(path, v_pallete, size, mode='RGB'):
    """Load a colour-coded mask image and return it as a class-index tensor.

    Args:
        path: Image file to open with PIL.
        v_pallete: Colour palette forwarded to ``color_mask_to_label``; a
            colour's position in it is the class index.
        size: Target size handed to ``PIL.Image.Image.resize``
            (PIL expects ``(width, height)``).
        mode: PIL mode the image is converted to before resizing.

    Returns:
        ``torch.Tensor`` of shape ``[1, H, W]`` holding class indices.
    """
    # The context manager releases the file handle deterministically;
    # ``convert`` returns a new image, so it stays usable after the file closes.
    with Image.open(path) as raw_mask:
        color_mask_PIL = raw_mask.convert(mode)
    # Nearest-neighbour resampling only copies existing pixel values, so no
    # blended colours (absent from the palette) are introduced by the resize.
    color_mask_PIL = color_mask_PIL.resize(size, Image.NEAREST)
    # obtain semantic label
    color_label = color_mask_to_label(color_mask_PIL, v_pallete)
    color_label = torch.from_numpy(color_label)  # [H, W]
    return color_label.unsqueeze(0)  # [1, H, W]
class AVSBenchTest(Dataset):
    """Test-split dataset: one item per video directory under ``test_img_dir``.

    Each item bundles the frame paths of a video, its ground-truth masks as
    class-index arrays, and the path of its audio file. Per-video metadata
    comes from ``avsbench_csv``, whose rows are read positionally:
    column 1 is matched against the video directory name, column 4 holds the
    underscore-joined class names, and column 6 names the sub-directory of
    ``test_mask_dir`` that contains the video's annotations.
    """

    def __init__(self, test_img_dir="./datasets/AVSBench-openvoc/test/JPEGImages/",
                 test_mask_dir="./datasets/AVSBench-semantic/",
                 pallete_dir="./datasets/AVSBench-semantic/label2idx.json",
                 avsbench_csv="./datasets/AVSBench-semantic/metadata.csv",
                 mask_size=(224, 224)):
        """Record the dataset locations and enumerate the test videos.

        Args:
            test_img_dir: Directory containing one sub-directory per video.
            test_mask_dir: Root directory of the annotations and audio files.
            pallete_dir: Path forwarded to ``get_v2_pallete``.
            avsbench_csv: Metadata CSV (see the class docstring for the
                columns that are read).
            mask_size: ``size`` argument used when loading each mask; defaults
                to the previously hard-coded ``(224, 224)``.
        """
        super().__init__()
        self.test_img_dir = test_img_dir
        self.test_mask_dir = test_mask_dir
        self.pallete_dir = pallete_dir
        self.avsbench_csv = avsbench_csv
        self.frame_num = 5  # stored only; not read anywhere in this class
        self.mask_size = mask_size
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> video mapping reproducible across runs and machines.
        self.videos = sorted(os.listdir(test_img_dir))
        print("{} videos are used for test.".format(len(self.videos)))
        self.v2_pallete = get_v2_pallete(pallete_dir, num_cls=71)
        # Filled lazily on first item access, so constructing the dataset does
        # not touch the CSV.
        self._metadata_index = None

    def _load_metadata_index(self):
        """Return ``{video name: (class string, label dir)}``, parsed once."""
        index = getattr(self, "_metadata_index", None)
        if index is None:
            index = {}
            with open(self.avsbench_csv, "r", newline="") as csvfile:
                for info in csv.reader(csvfile):
                    # Blank or truncated rows cannot describe a video.
                    if len(info) > 6:
                        # Later rows overwrite earlier ones: the last match wins.
                        index[info[1]] = (info[4], info[6])
            self._metadata_index = index
        return index

    def __getitem__(self, index):
        """Return the data of the ``index``-th video.

        Returns:
            Tuple ``(imgs_pth, annos, video, imgs_names, audios)``:
            sorted frame paths, list of mask arrays (one per ``*.png`` in the
            video's ``labels_rgb`` directory), video name, sorted frame file
            names, and the path of ``audio.wav``.

        Raises:
            KeyError: If the video has no row in the metadata CSV.
        """
        video = self.videos[index]
        try:
            video_class, video_label = self._load_metadata_index()[video]
        except KeyError:
            raise KeyError(
                "video {!r} not found in {}".format(video, self.avsbench_csv)
            ) from None

        # Class ids allowed in this video's masks; 0 is always allowed.
        name_to_id = {cate["name"]: cate["id"] for cate in categories}
        video_class_id = [0]
        video_class_id.extend(
            name_to_id[name] for name in video_class.split("_") if name in name_to_id
        )

        # img:
        img_pth = os.path.join(self.test_img_dir, video)
        imgs_names = sorted(os.listdir(img_pth))
        imgs_pth = [os.path.join(img_pth, img_name) for img_name in imgs_names]

        # gt:
        annos = []
        anno_pth = os.path.join(self.test_mask_dir, video_label, video, "labels_rgb")
        for mask_image in sorted(glob.glob(anno_pth + "/*.png")):
            color_label = np.array(
                load_color_mask_in_PIL_to_Tensor(mask_image, self.v2_pallete, self.mask_size)
            )
            # filter error annotations: any id not listed for this video -> 0
            color_label_mask = np.isin(color_label, video_class_id)
            color_label[~color_label_mask] = 0
            annos.append(color_label)

        # audios:
        audios = os.path.join(self.test_mask_dir, video_label, video, "audio.wav")
        return imgs_pth, annos, video, imgs_names, audios

    def __len__(self):
        """Return the number of video directories found at construction."""
        return len(self.videos)
if __name__ == "__main__":
    # Smoke test: build the dataset with its default paths and print the
    # frame paths of every video, one batch (= one video) at a time.
    dataset = AVSBenchTest()
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
    for batch in loader:
        frame_paths, _annos, _video_name, _frame_names, _audio_path = batch
        print(frame_paths)
|