File size: 3,829 Bytes
749745d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import os
import os.path
import math
from PIL import Image

import random
import numpy as np

import torch
import torchvision
import torch.utils.data as data

import omnilabeltools as olt
from maskrcnn_benchmark.structures.bounding_box import BoxList
# from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask
# from maskrcnn_benchmark.structures.keypoint import PersonKeypoints
# from maskrcnn_benchmark.config import cfg
import pdb


def pil_loader(path, retry=5):
    """Open the image at ``path`` and return it converted to RGB.

    The read is attempted up to ``retry`` times (always at least once); only
    ``OSError`` failures are retried, anything else propagates immediately.

    Args:
        path (str): Path of the image file.
        retry (int): Maximum number of read attempts.

    Returns:
        PIL.Image.Image: The image in ``"RGB"`` mode.

    Raises:
        OSError: The error from the last attempt if every attempt failed
            (e.g. ``FileNotFoundError``, or PIL's ``UnidentifiedImageError``).
    """
    last_err = None
    for _ in range(max(retry, 1)):
        try:
            # Open path as a file object to avoid a ResourceWarning
            # (https://github.com/python-pillow/Pillow/issues/835).
            with open(path, "rb") as f:
                img = Image.open(f)
                # convert() loads the pixel data while f is still open.
                return img.convert("RGB")
        except OSError as err:
            last_err = err
    # Previously this fell through and returned None, which only failed later
    # with an unrelated AttributeError; surface the real cause instead.
    raise last_err

def load_omnilabel_json(path_json: str, path_imgs: str):
    """Load an OmniLabel annotation file into a list of per-image dicts.

    Args:
        path_json: Path to the OmniLabel JSON annotation file, as a ``str``
            or an ``os.PathLike`` that resolves to ``str``.
        path_imgs: Directory that each sample's ``file_name`` is joined onto.

    Returns:
        list[dict]: One dict per id in ``ol.image_ids``, with keys
        ``"image_id"``, ``"file_name"`` (``path_imgs`` joined with the
        sample's file name), and three parallel lists built from the
        sample's ``"labelspace"`` entries: ``"inference_obj_descriptions"``
        (each entry's ``"text"``), ``"inference_obj_description_ids"``
        (each entry's ``"id"``) and ``"tokens_positive"`` (each entry's
        ``anno_info["tokens_positive"]``, or ``None`` when that key is absent).

    Raises:
        TypeError: If ``path_json`` is not a string path.
    """
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if isinstance(path_json, os.PathLike):
        path_json = os.fspath(path_json)
    if not isinstance(path_json, str):
        raise TypeError(
            "path_json must be a str path, got {}".format(type(path_json).__name__)
        )

    ol = olt.OmniLabel(path_json)
    dataset_dicts = []
    for img_id in ol.image_ids:
        img_sample = ol.get_image_sample(img_id)
        labelspace = img_sample["labelspace"]
        dataset_dicts.append({
            "image_id": img_sample["id"],
            "file_name": os.path.join(path_imgs, img_sample["file_name"]),
            "inference_obj_descriptions": [od["text"] for od in labelspace],
            "inference_obj_description_ids": [od["id"] for od in labelspace],
            "tokens_positive": [
                od["anno_info"].get("tokens_positive", None) for od in labelspace
            ],
        })
    return dataset_dicts

class OmniLabelDataset(data.Dataset):
    """Inference-only dataset over an OmniLabel annotation file.

    Each item is an RGB image paired with an empty ``BoxList`` (no ground-truth
    boxes) that carries the image's object descriptions as extra fields.

    Args:
        img_folder (string): Directory that image file names are joined onto.
        ann_file (string): Path to the OmniLabel json annotation file.
        transforms (callable, optional): Called as ``transforms(img)`` on the
            PIL image only; its return value replaces the image. The target
            is not passed to it.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, img_folder, ann_file, transforms=None, **kwargs):
        self.img_folder = img_folder
        self.transforms = transforms
        self.dataset_dicts = load_omnilabel_json(ann_file, img_folder)

    def __getitem__(self, index):
        """Load one image and build its (box-free) target.

        Args:
            index (int): Index into the loaded annotation list.

        Returns:
            tuple: ``(image, target, img_id)``. ``target`` is a ``BoxList``
            with zero boxes in ``"xyxy"`` mode, sized to the image BEFORE
            ``transforms`` is applied, with the fields
            ``inference_obj_descriptions``, ``inference_obj_description_ids``
            and ``tokens_positive`` copied from the annotation entry.
        """
        data_dict = self.dataset_dicts[index]
        img_id = data_dict["image_id"]
        img = pil_loader(data_dict["file_name"])

        # Only inference is supported, so there are no boxes: an empty (0, 4)
        # tensor built directly in "xyxy" mode (converting an empty xywh set
        # would yield the same thing). Its size is the original image size.
        target = BoxList(torch.zeros((0, 4)), img.size, mode="xyxy")
        for field in (
            "inference_obj_descriptions",
            "inference_obj_description_ids",
            "tokens_positive",
        ):
            target.add_field(field, data_dict[field])

        if self.transforms is not None:
            # NOTE(review): only the image is transformed; the target keeps
            # the original image size — confirm downstream relies on this.
            img = self.transforms(img)

        return img, target, img_id

    def __len__(self):
        return len(self.dataset_dicts)

    def __repr__(self):
        fmt_str = "Dataset " + type(self).__name__ + "\n"
        fmt_str += "    Number of datapoints: {}\n".format(len(self))
        fmt_str += "    Root Location: {}\n".format(self.img_folder)
        return fmt_str

    # def get_img_info(self, index):
    #     img_id = self.id_to_img_map[index]
    #     img_data = self.coco.imgs[img_id]
    #     return img_data