RamziBm's picture
init
bdb955e
import pickle
import torch
from torch.utils.data import Dataset
import torchvision
import torchvision.transforms as T
from PIL import Image
import cv2
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import h5py
from tqdm.auto import tqdm
from collections import defaultdict
from pathlib import Path
import json
import os
from argparse import ArgumentParser
lines_classes = [
"Big rect. left bottom",
"Big rect. left main",
"Big rect. left top",
"Big rect. right bottom",
"Big rect. right main",
"Big rect. right top",
"Circle central",
"Circle left",
"Circle right",
# "Goal left crossbar",
# "Goal left post left ",
# "Goal left post right",
# "Goal right crossbar",
# "Goal right post left",
# "Goal right post right",
"Goal unknown",
"Line unknown",
"Middle line",
"Side line bottom",
"Side line left",
"Side line right",
"Side line top",
"Small rect. left bottom",
"Small rect. left main",
"Small rect. left top",
"Small rect. right bottom",
"Small rect. right main",
"Small rect. right top",
]
# RGB values
palette = {
"Big rect. left bottom": (127, 0, 0),
"Big rect. left main": (102, 102, 102),
"Big rect. left top": (0, 0, 127),
"Big rect. right bottom": (86, 32, 39),
"Big rect. right main": (48, 77, 0),
"Big rect. right top": (14, 97, 100),
"Circle central": (0, 0, 255),
"Circle left": (255, 127, 0),
"Circle right": (0, 255, 255),
# "Goal left crossbar": (255, 255, 200),
# "Goal left post left ": (165, 255, 0),
# "Goal left post right": (155, 119, 45),
# "Goal right crossbar": (86, 32, 139),
# "Goal right post left": (196, 120, 153),
# "Goal right post right": (166, 36, 52),
"Goal unknown": (0, 0, 0),
"Line unknown": (0, 0, 0),
"Middle line": (255, 255, 0),
"Side line bottom": (255, 0, 255),
"Side line left": (0, 255, 150),
"Side line right": (0, 230, 0),
"Side line top": (230, 0, 0),
"Small rect. left bottom": (0, 150, 255),
"Small rect. left main": (254, 173, 225),
"Small rect. left top": (87, 72, 39),
"Small rect. right bottom": (122, 0, 255),
"Small rect. right main": (255, 255, 255),
"Small rect. right top": (153, 23, 153),
}
def create_target_from_annotation(width, height, annotation, classes, linewidth=4):
"""Draw one-hot encoded segments according to the annotation.
Creates target that matches image size ([C+1]xHxW).
"""
annotation_abs = defaultdict(list)
# unnormalize every point in every class k
for k in annotation.keys():
if k not in lines_classes:
continue
start = annotation[k][0].copy()
end = annotation[k][-1].copy()
for annotation_point in annotation[k]:
tup = annotation_point.copy()
tup["x"] *= width
tup["x"] = int(tup["x"])
tup["y"] *= height
tup["y"] = int(tup["y"])
annotation_abs[k].append(tup)
# draw lines between annotated points for each segment
# offset class +1 such that no classes detected will end in argmax 0
# otherwise argmax 0 will be another class
classes_segments = np.zeros(shape=(len(classes) + 1, height, width))
for cls, points in annotation_abs.items():
class_segments = np.zeros(shape=(height, width, 3))
for start, end in zip(points, points[1:]):
startxy = (start["x"], start["y"])
endxy = [end["x"], end["y"]]
class_segments = cv2.line(
class_segments, startxy, endxy, (1, 1, 1), linewidth
)
classes_segments[classes.index(cls) + 1] = class_segments[:, :, 1]
classes_segments = torch.Tensor(classes_segments)
return classes_segments
class ExtremitiesDataset(Dataset):
def __init__(
self, root, split, annotations, filter_cam=None, extremities_prefix="", classes=lines_classes, palette=palette
):
self.data_root = Path(root)
self.split = split
self.annotations_path = annotations
if filter_cam is None:
files = os.listdir(self.data_root / self.split)
self.annotations = sorted([fn for fn in files if fn.endswith("json")])
self.images = sorted([fn for fn in files if fn.endswith("jpg")])
else:
df = pd.read_json(self.data_root / self.split / "match_info_cam_gt.json").T
df = df.loc[df.camera == filter_cam]
assert len(df.index) > 0
df["image_file"] = df.index
df = df.sort_values(by=["image_file"])
df["annotation_file"] = df["image_file"].apply(
lambda s: extremities_prefix + s.split(".jpg")[0] + ".json"
)
self.annotations = df["annotation_file"].tolist()
self.images = df["image_file"].tolist()
self.classes = classes
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
# see https://learnopencv.com/pytorch-for-beginners-semantic-segmentation-using-torchvision/
impath = self.data_root / self.split / self.images[idx]
annotation_path = self.annotations_path / self.annotations[idx]
with open(annotation_path, "r") as f:
annotation = json.load(f)
img = Image.open(impath) # .resize((1280, 720))
trf = T.Compose(
[
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
# prepare batches
img = trf(img)
# see https://git.tib.eu/vid2pos/sccvsd/-/blob/master/utils/synthetic_util.py
# draw lines (linewidth=4 for 720p) -> hence we rescale first
target = create_target_from_annotation(1280, 720, annotation, self.classes)
target = target.long().argmax(dim=0).unsqueeze(0)
# to binary mask
target = target.bool().float()
# rescale target equivalent to cv2.resize() with default args (interpolation bilinear) -> same as in torchvision
# bilinear -> [0, 0.25, 0.5, 0.7, 1.0] are entries
target = torchvision.transforms.Resize((180, 320))(target)
# print(torch.unique(target))
# to uint8
target = (target * 255.0).to(torch.uint8)
return img, target, impath.name
if __name__ == "__main__":
args = ArgumentParser()
args.add_argument("--data_dir", type=Path)
args.add_argument("--annotations", type=Path)
args.add_argument("--output_dir", type=Path)
args.add_argument("--extremities_prefix", type=str, default="")
args = args.parse_args()
data_dir = args.data_dir.parent
split = args.data_dir.name
output_dir = args.output_dir
if not output_dir.exists():
raise FileNotFoundError
dataset = ExtremitiesDataset(data_dir, split, args.annotations, filter_cam="Main camera center", extremities_prefix=args.extremities_prefix)
# img, edge_map, img_id = dataset[0]
# edge_map = edge_map.squeeze(0).numpy()
# Image.fromarray(edge_map).show()
image_src = []
edge_maps = np.zeros((len(dataset), 1, 180, 320), dtype=np.uint8)
for i, (_, edge_map, img_id) in enumerate(tqdm(dataset)):
edge_map = edge_map.numpy()
# Image.fromarray(edge_map).show()
edge_maps[i] = edge_map
image_src.append(img_id)
with h5py.File(output_dir / "seg_edge_maps.h5", "w") as f:
f.create_dataset("edge_map", data=edge_maps)
with open(output_dir / "seg_image_paths.pkl", "wb") as f:
pickle.dump(image_src, f)