# fisheye-experimental / dataloader.py
import project_path
import os
import cv2
import numpy as np
import json
from threading import Lock
import struct
from contextlib import contextmanager
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
from PIL import Image
# assumes yolov5 on sys.path
from lib.yolov5.utils.general import xyxy2xywh
from lib.yolov5.utils.augmentations import letterbox
from lib.yolov5.utils.dataloaders import create_dataloader as create_yolo_dataloader
from pyDIDSON import pyDIDSON
from aris import ImageData
# Use this flag to test the difference between direct ARIS dataloading and the
# JPEG-compressed version. Very slow; not much difference was observed.
TEST_JPG_COMPRESSION = False
# # # # # #
# Factory(ish) methods for DataLoader creation. Easy entry points to this module.
# # # # # #
def create_dataloader_aris(aris_filepath, beam_width_dir, annotations_file, batch_size=32, stride=64, pad=0.5, img_size=896, rank=-1, world_size=1, workers=0,
disable_output=False, cache_bg_frames=False):
"""
Get a PyTorch Dataset and DataLoader for ARIS files with (optional) associated fisheye-formatted labels.
"""
    # Make sure only the first process in DDP processes the dataset first, so the
    # following processes can use the cache. This is a no-op on a single-GPU machine.
with torch_distributed_zero_first(rank):
dataset = YOLOARISBatchedDataset(aris_filepath, beam_width_dir, annotations_file, stride, pad, img_size,
disable_output=disable_output, cache_bg_frames=cache_bg_frames)
batch_size = min(batch_size, len(dataset))
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
if not disable_output:
print("dataset size", len(dataset))
print("dataset shape", dataset.shape)
print("Num workers", nw)
# sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None # if extending to multi-GPU inference, will need this
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=None,  # automatic batching disabled; the Dataset returns full batches
                                             sampler=OnePerBatchSampler(data_source=dataset, batch_size=batch_size),
                                             num_workers=nw,
                                             pin_memory=True,
                                             collate_fn=collate_fn)
return dataloader, dataset
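# Example usage of create_dataloader_aris (a minimal sketch; the paths below are
# hypothetical, not part of this module):
#
#   dataloader, dataset = create_dataloader_aris(
#       aris_filepath="clips/example.aris",       # hypothetical ARIS clip
#       beam_width_dir="beam_widths/",            # hypothetical beam-width metadata dir
#       annotations_file=None,                    # or a fisheye-format JSON with labels
#       batch_size=32)
#   for imgs, labels, shapes in dataloader:       # imgs: stacked [B, 3, H, W] uint8 tensor
#       ...                                       # run YOLOv5 inference on imgs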
def create_dataloader_frames(frames_path, batch_size=32, model_stride_max=32,
pad=0.5, img_size=896, rank=-1, world_size=1, workers=0, disable_output=False):
"""
Create a DataLoader for a directory of frames without labels.
Args:
model_stride_max: use model.stride.max()
"""
gs = max(int(model_stride_max), 32) # grid size (max stride)
return create_yolo_dataloader(frames_path, img_size, batch_size, gs, single_cls=False, augment=False,
hyp=None, cache=None, rect=True, rank=rank,
workers=workers, pad=pad)[0]
def create_dataloader_frames_only(frames_path, batch_size=32, img_size=896, workers=0):
    """
    Create a batched, iterable YOLOFrameDataset for a directory of frames without labels.
    Unlike create_dataloader_frames(), this bypasses the YOLOv5 dataloader machinery.
    """
    return YOLOFrameDataset(frames_path, img_size=img_size, batch_size=batch_size)
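# Example usage (a minimal sketch; "frames/" is a hypothetical directory of .jpg frames):
#
#   dataset = create_dataloader_frames_only("frames/", batch_size=32)
#   for imgs, labels, shapes in dataset:   # labels is always None here
#       ...                                # imgs: [B, 3, H, W] tensor of letterboxed frames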
# # # # # #
# End factory(ish) methods
# # # # # #
import re
from torchvision.io import read_image
class YOLOFrameDataset(Dataset):
def __init__(self, img_dir, img_size=896, batch_size=32, stride=64, pad=0.5):
self.img_dir = img_dir
self.img_size = img_size
        self.files = [f for f in os.listdir(img_dir) if f.endswith(".jpg")]
        self.files.sort(key=lambda f: int(re.sub(r'\D', '', f)))  # sort numerically by the digits in the filename
temp_img = read_image(os.path.join(self.img_dir, self.files[0]))
size = temp_img.shape
self.ydim = size[1]
self.xdim = size[2]
n = len(self.files)
        aspect_ratio = self.ydim / self.xdim
        if aspect_ratio < 1:
            shape = [aspect_ratio, 1]
        elif aspect_ratio > 1:
            shape = [1, 1 / aspect_ratio]
        else:
            shape = [1, 1]  # square frames
self.original_shape = (self.ydim, self.xdim)
self.shape = np.ceil(np.array(shape) * img_size / stride + pad).astype(int) * stride
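        # e.g. ydim=800, xdim=400 -> shape=[1, 0.5]; with img_size=896, stride=64, pad=0.5:
        # ceil([1 * 896/64 + 0.5, 0.5 * 896/64 + 0.5]) * 64 = [15, 8] * 64 = (960, 512)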
self.batch_indices = []
for i in range(0,n,batch_size):
self.batch_indices.append((i, min(n, i+batch_size)))
@classmethod
def load_image(cls, img, img_size=896):
"""Loads and resizes 1 image from dataset, returns img, original hw, resized hw.
Modified from ScaledYOLOv4.datasets.load_image()
"""
h0, w0 = img.shape[:2]
h1, w1 = h0, w0
r = img_size / max(h0, w0)
if r != 1: # always resize down, only resize up if training with augmentation
interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
h1, w1 = img.shape[:2]
return img, (h0, w0), (h1, w1) # img, hw_original, hw_resized
    def __len__(self):
        return len(self.batch_indices)  # number of batches
def __iter__(self):
for batch_idx in self.batch_indices:
batch = []
labels = None
shapes = []
for i in range(batch_idx[0], batch_idx[1]):
img_name = self.files[i]
img_path = os.path.join(self.img_dir, img_name)
image = Image.open(img_path)
image = np.asarray(image)
img, (h0, w0), (h, w) = self.load_image(image, img_size=self.img_size)
# Letterbox
img, ratio, pad = letterbox(img, self.shape, auto=False, scaleup=False)
shape = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
img = img.transpose(2, 0, 1) # to -> C x H x W
img = np.ascontiguousarray(img)
img = torch.from_numpy(img)
shapes.append(shape)
batch.append(img)
            images = torch.stack(batch)
            yield (images, labels, shapes)
class ARISBatchedDataset(Dataset):
def __init__(self, aris_filepath, beam_width_dir, annotations_file, batch_size, num_frames_bg_subtract=1000, disable_output=False,
cache_bg_frames=False):
"""
A PyTorch Dataset class for loading an ARIS file and (optional) associated fisheye-format labels.
This class handles the ARIS frame extraction and 3-channel representation generation.
It is called a "BatchedDataset" because it loads contiguous frames in self.batch_size chunks.
** The PyTorch sampler must be aware of this!! ** Use the OnePerBatchSampler in this module when using this Dataset.
Args:
cache_bg_frames: keep the frames used for bg subtraction stored in memory. careful of memory issues. only recommended
for small values of num_frames_bg_subtract
"""
# open ARIS data stream - TODO: make sure this is one per worker
self.data = open(aris_filepath, 'rb')
self.data_lock = Lock()
self.beam_width_dir = beam_width_dir
self.disable_output = disable_output
self.aris_filepath = aris_filepath
self.cache_bg_frames = cache_bg_frames
# get header info
self.didson = pyDIDSON(self.aris_filepath, beam_width_dir=beam_width_dir)
self.xdim = self.didson.info['xdim']
self.ydim = self.didson.info['ydim']
# disable automatic batching - do it ourselves, reading batch_size frames from
# the ARIS file at a time
self.batch_size = batch_size
# load fisheye annotations
if annotations_file is None:
if not self.disable_output:
print("Loading file with no labels.")
self.start_frame = self.didson.info['startframe']
self.end_frame = self.didson.info['endframe'] or self.didson.info['numframes']
self.labels = None
else:
self._load_labels(annotations_file)
        # initialize the background subtraction
self.num_frames_bg_subtract = num_frames_bg_subtract
self._init_bg_frame()
def _init_bg_frame(self):
"""
        Initialize the background frame for background subtraction.
        Uses min(self.num_frames_bg_subtract, total_frames) frames to do mean subtraction.
        If cache_bg_frames is set, caches these frames in self.extracted_frames for reuse.
"""
# ensure the number of frames used is a multiple of self.batch_size so we can cache them and retrieve full batches
# add 1 extra frame to be used for optical flow calculation
num_frames_bg = min(self.end_frame - self.start_frame, self.num_frames_bg_subtract // self.batch_size * self.batch_size + 1)
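        # e.g. with the defaults num_frames_bg_subtract=1000 and batch_size=32:
        # min(total_frames, 1000 // 32 * 32 + 1) = min(total_frames, 993); the +1
        # supplies the following frame needed for the optical-flow channel of the
        # last cached frame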
if not self.disable_output:
print("Initializing mean frame for background subtraction using", num_frames_bg, "frames...")
frames_for_bg_subtract = self.didson.load_frames(start_frame=self.start_frame, end_frame=self.start_frame + num_frames_bg)
        # Compute the mean and max blurred frames in a streaming fashion to save
        # memory (vs. materializing one big batch).
self.mean_blurred_frame = np.zeros([self.ydim, self.xdim], dtype=np.float32)
max_blurred_frame = np.zeros([self.ydim, self.xdim], dtype=np.float32)
for i in range(frames_for_bg_subtract.shape[0]):
blurred = cv2.GaussianBlur(
frames_for_bg_subtract[i],
(5,5),
0)
self.mean_blurred_frame += blurred
max_blurred_frame = np.maximum(max_blurred_frame, np.abs(blurred))
self.mean_blurred_frame /= frames_for_bg_subtract.shape[0]
max_blurred_frame -= self.mean_blurred_frame
self.mean_normalization_value = np.max(max_blurred_frame)
# cache these for later
self.extracted_frames = []
# Because of the optical flow computation, we only go to end_frame - 1
next_blur = None
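        # Each extracted frame is a 3-channel uint8 image:
        #   channel 0: raw acoustic frame
        #   channel 1: blurred, mean-subtracted frame rescaled to [0, 255]
        #   channel 2: absolute difference between consecutive blurred frames (a cheap motion cue)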
for i in range(len(frames_for_bg_subtract) - 1):
if next_blur is None:
this_blur = ((cv2.GaussianBlur(frames_for_bg_subtract[i], (5,5), 0) - self.mean_blurred_frame) / self.mean_normalization_value + 1) / 2
else:
this_blur = next_blur
next_blur = ((cv2.GaussianBlur(frames_for_bg_subtract[i+1], (5,5), 0) - self.mean_blurred_frame) / self.mean_normalization_value + 1) / 2
frame_image = np.dstack([frames_for_bg_subtract[i],
this_blur * 255,
np.abs(next_blur - this_blur) * 255]).astype(np.uint8, copy=False)
            if TEST_JPG_COMPRESSION:
                # round-trip the frame through JPEG to check the effect of compression
                Image.fromarray(frame_image).save(f"tmp{i}.jpg", quality=95)
                frame_image = cv2.imread(f"tmp{i}.jpg")[:, :, ::-1]  # BGR to RGB
                os.remove(f"tmp{i}.jpg")
if self.cache_bg_frames:
self.extracted_frames.append(frame_image)
if not self.disable_output:
print("Done initializing background frame.")
def _load_labels(self, fisheye_json):
"""Load labels from a fisheye-formatted json file into self.labels in normalized
xywh format.
"""
        with open(fisheye_json, 'r') as f:
            js = json.load(f)
        labels = []
        for frame in js['frames']:
            l = []
            for fish in frame['fish']:
                # fisheye bboxes are normalized [xmin, ymin, xmax, ymax]
                x1, y1, x2, y2 = fish['bbox']
                w = x2 - x1
                h = y2 - y1
                cx = x1 + w / 2.0
                cy = y1 + h / 2.0
                # Each row is `class x_center y_center width height` format (normalized).
                l.append([0, cx, cy, w, h])
l = np.array(l, dtype=np.float32)
if len(l) == 0:
l = np.zeros((0, 5), dtype=np.float32)
labels.append(l)
self.labels = labels
self.start_frame = js['start_frame']
self.end_frame = js['end_frame']
def __len__(self):
# account for optical flow - we can't do the last frame
return self.end_frame - self.start_frame - 1
def _postprocess(self, frame_images, frame_labels):
raise NotImplementedError
def __getitem__(self, idx):
"""
Return a numpy array representing this batch of frames and labels according to pyARIS frame extraction logic.
This class returns a full batch rather than just 1 example, assuming a OnePerBatchSampler is used.
"""
final_idx = min(idx+self.batch_size, len(self))
frame_labels = self.labels[idx:final_idx] if self.labels else None
        # see if we have already cached these frames during bg subtraction;
        # extracted frames already include the optical-flow channel, so the whole
        # batch can be served from the cache when it fits
        if final_idx <= len(self.extracted_frames):
return self._postprocess(self.extracted_frames[idx:final_idx], frame_labels)
else:
frames = self.didson.load_frames(start_frame=self.start_frame+idx, end_frame=self.start_frame + final_idx + 1)
blurred_frames = frames.astype(np.float32)
for i in range(frames.shape[0]):
blurred_frames[i] = cv2.GaussianBlur(
blurred_frames[i],
(5,5),
0
)
blurred_frames -= self.mean_blurred_frame
blurred_frames /= self.mean_normalization_value
blurred_frames += 1
blurred_frames /= 2
frame_images = np.stack([ frames[:-1], blurred_frames[:-1] * 255, np.abs(blurred_frames[1:] - blurred_frames[:-1]) * 255 ], axis=-1).astype(np.uint8, copy=False)
            if TEST_JPG_COMPRESSION:
                # round-trip each frame through JPEG to check the effect of compression
                new_frame_images = []
                for image in frame_images:
                    Image.fromarray(image).save(f"tmp{idx}.jpg", quality=95)
                    image = cv2.imread(f"tmp{idx}.jpg")[:, :, ::-1]  # BGR to RGB
                    os.remove(f"tmp{idx}.jpg")
                    new_frame_images.append(image)
                frame_images = new_frame_images
return self._postprocess(frame_images, frame_labels)
class YOLOARISBatchedDataset(ARISBatchedDataset):
"""An ARIS Dataset that works with YOLOv5 inference."""
def __init__(self, aris_filepath, beam_width_dir, annotations_file, stride=64, pad=0.5, img_size=896, batch_size=32,
disable_output=False, cache_bg_frames=False):
super().__init__(aris_filepath, beam_width_dir, annotations_file, batch_size, disable_output=disable_output, cache_bg_frames=cache_bg_frames)
# compute shapes for letterbox
        aspect_ratio = self.ydim / self.xdim
        if aspect_ratio < 1:
            shape = [aspect_ratio, 1]
        elif aspect_ratio > 1:
            shape = [1, 1 / aspect_ratio]
        else:
            shape = [1, 1]  # square frames
self.original_shape = (self.ydim, self.xdim)
self.shape = np.ceil(np.array(shape) * img_size / stride + pad).astype(int) * stride
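        # same letterbox target-shape computation as in YOLOFrameDataset.__init__ above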
@classmethod
def load_image(cls, img, img_size=896):
"""Loads and resizes 1 image from dataset, returns img, original hw, resized hw.
Modified from ScaledYOLOv4.datasets.load_image()
"""
h0, w0 = img.shape[:2] # orig hw
r = img_size / max(h0, w0) # resize image to img_size
if r != 1: # always resize down, only resize up if training with augmentation
interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
return img, (h0, w0), img.shape[:2] # img, hw_original, hw_resized
def _postprocess(self, frame_images, frame_labels):
"""
Return a batch of data in the format used by ScaledYOLOv4.
        That is, a list of tuples, one tuple per image in the batch:
[
(img ->torch.Tensor,
labels ->torch.Tensor,
shapes ->tuple describing image original dimensions and scaled/padded dimensions
),
...
]
"""
outputs = []
frame_labels = frame_labels or [ None for _ in frame_images ]
for image, x in zip(frame_images, frame_labels):
img, (h0, w0), (h, w) = self.load_image(image)
# Letterbox
img, ratio, pad = letterbox(img, self.shape, auto=False, scaleup=False)
shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
img = img.transpose(2, 0, 1) # to -> C x H x W
img = np.ascontiguousarray(img)
# Load labels
# Convert from normalized xywh to pixel xyxy format in order to add padding from letterbox
labels = []
if x is not None and x.size > 0:
labels = x.copy()
labels[:, 1] = ratio[0] * w * (x[:, 1] - x[:, 3] / 2) + pad[0] # pad width
labels[:, 2] = ratio[1] * h * (x[:, 2] - x[:, 4] / 2) + pad[1] # pad height
labels[:, 3] = ratio[0] * w * (x[:, 1] + x[:, 3] / 2) + pad[0]
labels[:, 4] = ratio[1] * h * (x[:, 2] + x[:, 4] / 2) + pad[1]
# convert back to normalized xywh with padding
nL = len(labels) # number of labels
labels_out = torch.zeros((nL, 6))
if nL:
labels[:, 1:5] = xyxy2xywh(labels[:, 1:5]) # convert xyxy to xywh
labels[:, [2, 4]] /= img.shape[1] # normalized height 0-1
labels[:, [1, 3]] /= img.shape[2] # normalized width 0-1
labels_out[:, 1:] = torch.from_numpy(labels)
outputs.append( (torch.from_numpy(img), labels_out, shapes) )
return outputs
@contextmanager
def torch_distributed_zero_first(local_rank: int):
"""
Decorator to make all processes in distributed training wait for each local_master to do something.
"""
if local_rank not in [-1, 0]:
torch.distributed.barrier()
yield
if local_rank == 0:
torch.distributed.barrier()
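# Typical use (see create_dataloader_aris above): rank 0 builds and caches the dataset
# while the other ranks wait at the barrier, then the other ranks read from the cache.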
class OnePerBatchSampler(torch.utils.data.Sampler):
"""Yields the first index of each batch, given a batch size.
In other words, returns multiples of self.batch_size up to the size of the Dataset.
This is a workaround for Pytorch's standard batch creation that allows us to manually
select contiguous segments of an ARIS clip for each batch.
"""
def __init__(self, data_source, batch_size):
self.data_source = data_source
self.batch_size = batch_size
def __iter__(self):
idxs = [i*self.batch_size for i in range(len(self))]
return iter(idxs)
def __len__(self):
return len(self.data_source) // self.batch_size
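# For example, with len(data_source) == 100 and batch_size == 32, this sampler yields
# [0, 32, 64]; each index is passed to ARISBatchedDataset.__getitem__, which returns the
# full 32-frame batch starting at that index (trailing frames that do not fill a batch
# are dropped).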
def collate_fn(batch):
"""See YOLOv5.utils.datasets.collate_fn"""
    assert len(batch) > 0, "collate_fn received an empty batch"
img, label, shapes = zip(*batch) # transposed
for i, l in enumerate(label):
l[:, 0] = i # add target image index for build_targets()
return torch.stack(img, 0), torch.cat(label, 0), shapes