import project_path

import os
import re
import json
import struct
from threading import Lock
from contextlib import contextmanager

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
from torchvision.io import read_image

# assumes yolov5 on sys.path
from lib.yolov5.utils.general import xyxy2xywh
from lib.yolov5.utils.augmentations import letterbox
from lib.yolov5.utils.dataloaders import create_dataloader as create_yolo_dataloader

from pyDIDSON import pyDIDSON
from aris import ImageData

# use this flag to test the difference between direct ARIS dataloading and
# using the jpeg compressed version. very slow. not much difference observed.
TEST_JPG_COMPRESSION = False


# # # # # #
# Factory(ish) methods for DataLoader creation. Easy entry points to this module.
# # # # # #

def create_dataloader_aris(aris_filepath, beam_width_dir, annotations_file, batch_size=32, stride=64, pad=0.5, img_size=896, rank=-1, world_size=1, workers=0,
                           disable_output=False, cache_bg_frames=False):
    """
    Get a PyTorch Dataset and DataLoader for ARIS files with (optional) associated fisheye-formatted labels.
    """
    # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
    # this is a no-op for a single-gpu machine
    with torch_distributed_zero_first(rank):
        dataset = YOLOARISBatchedDataset(aris_filepath, beam_width_dir, annotations_file, stride, pad, img_size,
                                         batch_size=batch_size, disable_output=disable_output, cache_bg_frames=cache_bg_frames)

    
    batch_size = min(batch_size, len(dataset))
    nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
    
    if not disable_output:
        print("dataset size", len(dataset))
        print("dataset shape", dataset.shape)
        print("Num workers", nw)
#     sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None # if extending to multi-GPU inference, will need this
    dataloader = torch.utils.data.dataloader.DataLoader(dataset, 
                                                        batch_size=None,
                                                        sampler=OnePerBatchSampler(data_source=dataset, batch_size=batch_size),
                                                        num_workers=nw,
                                                        pin_memory=True,
                                                        collate_fn=collate_fn)
    return dataloader, dataset
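
# Example (hedged): a minimal sketch of consuming the ARIS dataloader for inference.
# The file paths below are hypothetical placeholders; each batch follows collate_fn()
# at the bottom of this module: a uint8 image tensor, a label tensor, and shape metadata.
def _example_iterate_aris_dataloader():
    dataloader, dataset = create_dataloader_aris(
        aris_filepath="clip.aris",        # hypothetical ARIS file
        beam_width_dir="beam_widths/",    # hypothetical beam-width directory
        annotations_file=None,            # no labels -> inference only
        batch_size=32)
    for imgs, targets, shapes in dataloader:
        imgs = imgs.float() / 255.0  # uint8 -> float in [0, 1] for the model
        break  # run model(imgs) here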

def create_dataloader_frames(frames_path, batch_size=32, model_stride_max=32, 
                             pad=0.5, img_size=896, rank=-1, world_size=1, workers=0, disable_output=False):
    """
    Create a DataLoader for a directory of frames without labels.
    
    Args:
        model_stride_max: use model.stride.max()
    """
    
    gs = max(int(model_stride_max), 32)  # grid size (max stride)
    return create_yolo_dataloader(frames_path, img_size, batch_size, gs, single_cls=False, augment=False,
                                       hyp=None, cache=None, rect=True, rank=rank,
                                       workers=workers, pad=pad)[0]
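
# Example (hedged): with a loaded YOLOv5 model, the intended call is roughly
#   create_dataloader_frames("frames/", model_stride_max=int(model.stride.max()))
# where "frames/" is a hypothetical directory of extracted frames.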

def create_dataloader_frames_only(frames_path, batch_size=32, img_size=896, workers=0):
    """
    Create a YOLOFrameDataset for a directory of frames without labels.
    Unlike create_dataloader_frames(), this does not go through the YOLOv5
    dataloader utilities; the returned dataset is iterated over directly.
    """
    return YOLOFrameDataset(frames_path, img_size=img_size, batch_size=batch_size)
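
# Example (hedged): the returned YOLOFrameDataset is iterated directly and yields
# (image_batch, None, shapes) tuples; "frames/" is a hypothetical directory.
def _example_iterate_frames_only():
    dataset = create_dataloader_frames_only("frames/", batch_size=16, img_size=896)
    for imgs, _, shapes in dataset:
        imgs = imgs.float() / 255.0  # uint8 -> float in [0, 1]
        break  # run model(imgs) here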

# # # # # #
# End factory(ish) methods
# # # # # #


class YOLOFrameDataset(Dataset):
    def __init__(self, img_dir, img_size=896, batch_size=32):
        self.img_dir = img_dir
        self.img_size = img_size

        # keep only .jpg frames, sorted by the numeric portion of the filename
        self.files = [f for f in os.listdir(img_dir) if f.endswith(".jpg")]
        self.files.sort(key=lambda f: int(re.sub(r'\D', '', f)))

        n = len(self.files)

        self.batches = []
        for i in range(0,n,batch_size): 
            self.batches.append((i, min(n, i+batch_size)))

    def __len__(self):
        return len(self.batches)

    def __iter__(self):
        for batch_idx in self.batches:

            batch = []
            shapes = []
            for i in range(batch_idx[0], batch_idx[1]):
                img_name = self.files[i]
                img_path = os.path.join(self.img_dir, img_name)
                img = read_image(img_path)

                shapes.append([[img.shape[1], img.shape[2]], None])
                ratio = self.img_size / max(img.shape[1], img.shape[2])
                transform = T.Resize([int(ratio*img.shape[1]), int(ratio*img.shape[2])])
                img = transform(img)
                batch.append(img)
            
            image = torch.stack(batch)

            yield (image, None, shapes)

class ARISBatchedDataset(Dataset):
    def __init__(self, aris_filepath, beam_width_dir, annotations_file, batch_size, num_frames_bg_subtract=1000, disable_output=False,
                    cache_bg_frames=False):
        """
        A PyTorch Dataset class for loading an ARIS file and (optional) associated fisheye-format labels. 
        This class handles the ARIS frame extraction and 3-channel representation generation.
        
        It is called a "BatchedDataset" because it loads contiguous frames in self.batch_size chunks. 
        ** The PyTorch sampler must be aware of this!! ** Use the OnePerBatchSampler in this module when using this Dataset.
        
        Args:
            cache_bg_frames: keep the frames used for bg subtraction stored in memory. careful of memory issues. only recommended
                            for small values of num_frames_bg_subtract
        """
        # open ARIS data stream - TODO: make sure this is one per worker
        self.data = open(aris_filepath, 'rb')
        self.data_lock = Lock()
        self.beam_width_dir = beam_width_dir
        self.disable_output = disable_output
        self.aris_filepath = aris_filepath
        self.cache_bg_frames = cache_bg_frames
        
        # get header info
        self.didson = pyDIDSON(self.aris_filepath, beam_width_dir=beam_width_dir)
        self.xdim = self.didson.info['xdim']
        self.ydim = self.didson.info['ydim']

        # disable automatic batching - do it ourselves, reading batch_size frames from
        # the ARIS file at a time
        self.batch_size = batch_size
        
        # load fisheye annotations
        if annotations_file is None:
            if not self.disable_output:
                print("Loading file with no labels.")
            self.start_frame = self.didson.info['startframe']
            self.end_frame = self.didson.info['endframe'] or self.didson.info['numframes']
            self.labels = None
        else:
            self._load_labels(annotations_file)
            
        # initialize the background subtraction
        self.num_frames_bg_subtract = num_frames_bg_subtract
        self._init_bg_frame()
        
    def _init_bg_frame(self):
        """
        Initialize the background frame for background subtraction.
        Uses min(self.num_frames_bg_subtract, total_frames) frames to do mean subtraction.
        Caches these frames in self.extracted_frames for reuse.
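        Blurred frames are normalized as ((blur - mean) / mean_normalization_value + 1) / 2,
        which maps them into roughly [0, 1] before scaling to 255.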
        """
        # ensure the number of frames used is a multiple of self.batch_size so we can cache them and retrieve full batches
        # add 1 extra frame to be used for optical flow calculation
        num_frames_bg = min(self.end_frame - self.start_frame, self.num_frames_bg_subtract // self.batch_size * self.batch_size + 1)
        
        if not self.disable_output:
            print("Initializing mean frame for background subtraction using", num_frames_bg, "frames...")
        frames_for_bg_subtract = self.didson.load_frames(start_frame=self.start_frame, end_frame=self.start_frame + num_frames_bg)

        ### NEW WAY ###
        # save memory (and time?) by computing these in a streaming fashion vs. in a big batch
        self.mean_blurred_frame = np.zeros([self.ydim, self.xdim], dtype=np.float32)
        max_blurred_frame = np.zeros([self.ydim, self.xdim], dtype=np.float32)
        for i in range(frames_for_bg_subtract.shape[0]):
            blurred = cv2.GaussianBlur(
                frames_for_bg_subtract[i],
                (5,5),
                0)
            self.mean_blurred_frame += blurred
            max_blurred_frame = np.maximum(max_blurred_frame, np.abs(blurred))
        self.mean_blurred_frame /= frames_for_bg_subtract.shape[0]
        max_blurred_frame -= self.mean_blurred_frame
        self.mean_normalization_value = np.max(max_blurred_frame)
        
        # cache these for later
        self.extracted_frames = []
        
        # Because of the optical flow computation, we only go to end_frame - 1
        next_blur = None
        for i in range(len(frames_for_bg_subtract) - 1):
            if next_blur is None:
                this_blur = ((cv2.GaussianBlur(frames_for_bg_subtract[i], (5,5), 0) - self.mean_blurred_frame) / self.mean_normalization_value + 1) / 2
            else:
                this_blur = next_blur
            next_blur = ((cv2.GaussianBlur(frames_for_bg_subtract[i+1], (5,5), 0) - self.mean_blurred_frame) / self.mean_normalization_value + 1) / 2
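            # 3-channel representation: raw frame, background-subtracted frame,
            # and frame-to-frame difference (the optical-flow channel)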
            frame_image = np.dstack([frames_for_bg_subtract[i], 
                                     this_blur * 255, 
                                     np.abs(next_blur - this_blur) * 255]).astype(np.uint8, copy=False)

            if TEST_JPG_COMPRESSION:
                from PIL import Image  # lazy import; only needed for this test path
                Image.fromarray(frame_image).save(f"tmp{i}.jpg", quality=95)
                frame_image = cv2.imread(f"tmp{i}.jpg")[:, :, ::-1] # BGR to RGB
                os.remove(f"tmp{i}.jpg")
                
            if self.cache_bg_frames:
                self.extracted_frames.append(frame_image)
                
        if not self.disable_output:
            print("Done initializing background frame.")
        
    def _load_labels(self, fisheye_json):
        """Load labels from a fisheye-formatted json file into self.labels in normalized
        xywh format.
        """
        with open(fisheye_json, 'r') as f:
            js = json.load(f)
        labels = []

        for frame in js['frames']:

            l = []
            for fish in frame['fish']:
                # xyxy2xywh already returns (x_center, y_center, w, h),
                # so the centers can be used directly
                cx, cy, w, h = xyxy2xywh(np.array(fish['bbox'], dtype=np.float32))
                # Each row is `class x_center y_center width height` format. (Normalized)
                l.append([0, cx, cy, w, h])

            l = np.array(l, dtype=np.float32)
            if len(l) == 0:
                l = np.zeros((0, 5), dtype=np.float32)

            labels.append(l)

        self.labels = labels
        self.start_frame = js['start_frame']
        self.end_frame = js['end_frame']

    def __len__(self):
        # account for optical flow - we can't do the last frame
        return self.end_frame - self.start_frame - 1

    def _postprocess(self, frame_images, frame_labels):
        raise NotImplementedError
    
    def __getitem__(self, idx):
        """
        Return a numpy array representing this batch of frames and labels according to pyARIS frame extraction logic.
        This class returns a full batch rather than just 1 example, assuming a OnePerBatchSampler is used.
        """
        final_idx = min(idx+self.batch_size, len(self))
        frame_labels = self.labels[idx:final_idx] if self.labels else None
        
        # see if we have already cached this from bg subtraction
        # assumes len(self.extracted_frames) is a multiple of self.batch_size
        if idx+1 < len(self.extracted_frames):
            return self._postprocess(self.extracted_frames[idx:final_idx], frame_labels)
        else:
            frames = self.didson.load_frames(start_frame=self.start_frame+idx, end_frame=self.start_frame + final_idx + 1)
            blurred_frames = frames.astype(np.float32)
            for i in range(frames.shape[0]):
                blurred_frames[i] = cv2.GaussianBlur(
                    blurred_frames[i],
                    (5,5),
                    0
                )
            blurred_frames -= self.mean_blurred_frame
            blurred_frames /= self.mean_normalization_value
            blurred_frames += 1
            blurred_frames /= 2
            
            frame_images = np.stack([ frames[:-1], blurred_frames[:-1] * 255, np.abs(blurred_frames[1:] - blurred_frames[:-1]) * 255 ], axis=-1).astype(np.uint8, copy=False)
            
            if TEST_JPG_COMPRESSION:
                from PIL import Image  # lazy import; only needed for this test path
                new_frame_images = []
                for image in frame_images:
                    Image.fromarray(image).save(f"tmp{idx}.jpg", quality=95)
                    image = cv2.imread(f"tmp{idx}.jpg")[:, :, ::-1] # BGR to RGB
                    os.remove(f"tmp{idx}.jpg")
                    new_frame_images.append(image)
                frame_images = new_frame_images
            
            return self._postprocess(frame_images, frame_labels)
        
class YOLOARISBatchedDataset(ARISBatchedDataset):
    """An ARIS Dataset that works with YOLOv5 inference."""
    
    def __init__(self, aris_filepath, beam_width_dir, annotations_file, stride=64, pad=0.5, img_size=896, batch_size=32, 
                 disable_output=False, cache_bg_frames=False):
        super().__init__(aris_filepath, beam_width_dir, annotations_file, batch_size, disable_output=disable_output, cache_bg_frames=cache_bg_frames)
        
        # compute shapes for letterbox
        aspect_ratio = self.ydim / self.xdim
        if aspect_ratio < 1:
            shape = [aspect_ratio, 1]
        elif aspect_ratio > 1:
            shape = [1, 1 / aspect_ratio]
        else:  # square frames
            shape = [1, 1]
        self.original_shape = (self.ydim, self.xdim)
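        # rectangular-inference shape: scale the unit aspect-ratio box to img_size,
        # then round up to a multiple of `stride` (plus padding), as in YOLOv5's rect mode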
        self.shape = np.ceil(np.array(shape) * img_size / stride + pad).astype(int) * stride

    @classmethod
    def load_image(cls, img, img_size=896):
        """Loads and resizes 1 image from dataset, returns img, original hw, resized hw.
        Modified from ScaledYOLOv4.datasets.load_image()
        """
        h0, w0 = img.shape[:2]  # orig hw
        r = img_size / max(h0, w0)  # resize image to img_size
        if r != 1:  # always resize down, only resize up if training with augmentation
            interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
        return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
        
    def _postprocess(self, frame_images, frame_labels):
        """
        Return a batch of data in the format used by ScaledYOLOv4.
        That is, a list of tuples, one tuple per image in the batch:
            [
                (img -> torch.Tensor,
                 labels -> torch.Tensor,
                 shapes -> tuple describing original and scaled/padded image dimensions
                ),
                ...
            ]
        """
        outputs = []
        frame_labels = frame_labels or [ None for _ in frame_images ]
        for image, x in zip(frame_images, frame_labels):
            img, (h0, w0), (h, w) = self.load_image(image)

            # Letterbox
            img, ratio, pad = letterbox(img, self.shape, auto=False, scaleup=False)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            img = img.transpose(2, 0, 1) # to -> C x H x W
            img = np.ascontiguousarray(img)

            # Load labels
            # Convert from normalized xywh to pixel xyxy format in order to add padding from letterbox
            labels = []
            if x is not None and x.size > 0:
                labels = x.copy()
                labels[:, 1] = ratio[0] * w * (x[:, 1] - x[:, 3] / 2) + pad[0]  # pad width
                labels[:, 2] = ratio[1] * h * (x[:, 2] - x[:, 4] / 2) + pad[1]  # pad height
                labels[:, 3] = ratio[0] * w * (x[:, 1] + x[:, 3] / 2) + pad[0]
                labels[:, 4] = ratio[1] * h * (x[:, 2] + x[:, 4] / 2) + pad[1]

            # convert back to normalized xywh with padding
            nL = len(labels)  # number of labels
            labels_out = torch.zeros((nL, 6))
            if nL:
                labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])  # convert xyxy to xywh
                labels[:, [2, 4]] /= img.shape[1]  # normalized height 0-1
                labels[:, [1, 3]] /= img.shape[2]  # normalized width 0-1
                labels_out[:, 1:] = torch.from_numpy(labels)
            
            outputs.append( (torch.from_numpy(img), labels_out, shapes) )
            
        return outputs
    
@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager to make all processes in distributed training wait for the local master to do something.
    """
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    yield
    if local_rank == 0:
        torch.distributed.barrier()
        
class OnePerBatchSampler(torch.utils.data.Sampler):
    """Yields the first index of each batch, given a batch size.
    In other words, returns multiples of self.batch_size up to the size of the Dataset.
    This is a workaround for Pytorch's standard batch creation that allows us to manually
    select contiguous segments of an ARIS clip for each batch.
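
    Example: with len(data_source) == 100 and batch_size == 32, iterating
    yields indices 0, 32, 64; the final partial batch is dropped by __len__.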
    """

    def __init__(self, data_source, batch_size):
        self.data_source = data_source
        self.batch_size = batch_size

    def __iter__(self):
        idxs = [i*self.batch_size for i in range(len(self))]
        return iter(idxs)

    def __len__(self):
        return len(self.data_source) // self.batch_size
    
def collate_fn(batch):
    """See YOLOv5.utils.datasets.collate_fn"""
    if not len(batch):
        raise ValueError("collate_fn received an empty batch")

    img, label, shapes = zip(*batch)  # transposed
    for i, l in enumerate(label):
        l[:, 0] = i  # add target image index for build_targets()
    return torch.stack(img, 0), torch.cat(label, 0), shapes