EdgeFM committed on
Commit
882e048
1 Parent(s): cd2ed5b

Upload 1804 files

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +6 -0
  2. .gitignore +5 -0
  3. data/README.md +94 -0
  4. data/__init__.py +12 -0
  5. data/__pycache__/__init__.cpython-38.pyc +0 -0
  6. data/__pycache__/dataloader.cpython-38.pyc +0 -0
  7. data/__pycache__/dataset.cpython-38.pyc +0 -0
  8. data/build/__init__.py +0 -0
  9. data/build/__pycache__/__init__.cpython-38.pyc +0 -0
  10. data/build/__pycache__/build.cpython-38.pyc +0 -0
  11. data/build/__pycache__/merge_alias.cpython-38.pyc +0 -0
  12. data/build/__pycache__/scenario.cpython-38.pyc +0 -0
  13. data/build/build.py +495 -0
  14. data/build/merge_alias.py +106 -0
  15. data/build/scenario.py +466 -0
  16. data/build_cl/__pycache__/build.cpython-38.pyc +0 -0
  17. data/build_cl/__pycache__/scenario.cpython-38.pyc +0 -0
  18. data/build_cl/build.py +161 -0
  19. data/build_cl/scenario.py +146 -0
  20. data/convert_all_load_to_single_load.py +56 -0
  21. data/convert_det_dataset_to_cls.py +55 -0
  22. data/convert_seg_dataset_to_cls.py +324 -0
  23. data/convert_seg_dataset_to_det.py +399 -0
  24. data/dataloader.py +131 -0
  25. data/dataset.py +43 -0
  26. data/datasets/__init__.py +11 -0
  27. data/datasets/__pycache__/__init__.cpython-38.pyc +0 -0
  28. data/datasets/__pycache__/ab_dataset.cpython-38.pyc +0 -0
  29. data/datasets/__pycache__/data_aug.cpython-38.pyc +0 -0
  30. data/datasets/__pycache__/dataset_cache.cpython-38.pyc +0 -0
  31. data/datasets/__pycache__/dataset_split.cpython-38.pyc +0 -0
  32. data/datasets/__pycache__/registery.cpython-38.pyc +0 -0
  33. data/datasets/ab_dataset.py +48 -0
  34. data/datasets/action_recognition/__init__.py +4 -0
  35. data/datasets/action_recognition/__pycache__/__init__.cpython-38.pyc +0 -0
  36. data/datasets/action_recognition/__pycache__/common_dataset.cpython-38.pyc +0 -0
  37. data/datasets/action_recognition/__pycache__/hmdb51.cpython-38.pyc +0 -0
  38. data/datasets/action_recognition/__pycache__/ixmas.cpython-38.pyc +0 -0
  39. data/datasets/action_recognition/__pycache__/ucf101.cpython-38.pyc +0 -0
  40. data/datasets/action_recognition/common_dataset.py +152 -0
  41. data/datasets/action_recognition/hmdb51.py +45 -0
  42. data/datasets/action_recognition/ixmas.py +45 -0
  43. data/datasets/action_recognition/kinetics400.py +51 -0
  44. data/datasets/action_recognition/ucf101.py +45 -0
  45. data/datasets/data_aug.py +93 -0
  46. data/datasets/dataset_cache.py +40 -0
  47. data/datasets/dataset_split.py +81 -0
  48. data/datasets/image_classification/__init__.py +24 -0
  49. data/datasets/image_classification/__pycache__/__init__.cpython-38.pyc +0 -0
  50. data/datasets/image_classification/__pycache__/baidu_person_cls.cpython-38.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ data/datasets/visual_question_answering/generate_c_image/imagenet_c/frost/frost1.png filter=lfs diff=lfs merge=lfs -text
+ data/datasets/visual_question_answering/generate_c_image/robustness-master/assets/spatter.gif filter=lfs diff=lfs merge=lfs -text
+ data/datasets/visual_question_answering/generate_c_image/robustness-master/assets/tilt.gif filter=lfs diff=lfs merge=lfs -text
+ data/datasets/visual_question_answering/generate_c_image/robustness-master/assets/translate.gif filter=lfs diff=lfs merge=lfs -text
+ data/datasets/visual_question_answering/generate_c_image/robustness-master/ImageNet-C/create_c/frost1.png filter=lfs diff=lfs merge=lfs -text
+ data/datasets/visual_question_answering/generate_c_image/robustness-master/ImageNet-C/imagenet_c/imagenet_c/frost/frost1.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
+ results
+ logs
+ entry_model
+ __pycache__
+ backup_codes
data/README.md ADDED
@@ -0,0 +1,94 @@
+ ## How to implement a dataset?
+
+ Suppose we want to implement an image classification dataset.
+
+ 1. Create a file in the corresponding directory, i.e. `benchmark/data/datasets/image_classification`.
+
+ 2. Create a class that inherits from `benchmark.data.datasets.ab_dataset.ABDataset`, e.g. `class YourDataset(ABDataset)`.
+
+ 3. Register your dataset with `benchmark.data.datasets.registry.dataset_register(name, classes, class_aliases)`, passing the name of your dataset, its classes, and the possible aliases of those classes. For examples, refer to `benchmark/data/datasets/image_classification/cifar10.py` or other files.
+
+ Note that the order of `classes` must match the class indexes. For example, `classes` of MNIST must be `['0', '1', '2', ..., '9']`: the 0-th class is '0', the 1-st class is '1', the 2-nd class is '2', and so on. `['1', '2', '0', ...]` is incorrect because the 0-th class is not '1' and the 1-st class is not '2'.
+
+ How do you get the `classes` of a dataset? For PyTorch built-in datasets (CIFAR10, MNIST, ...) and general datasets built by `ImageFolder`, you can initialize one (e.g. `dataset = CIFAR10(...)`) and read its classes from `dataset.classes`.
+
+ ```python
+ # How to get classes in CIFAR10?
+ from torchvision.datasets import CIFAR10
+ dataset = CIFAR10(...)
+ print(dataset.classes)
+ # copy this output to @dataset_register(classes=<what you copied>)
+
+ # it's not recommended to get classes dynamically, e.g.:
+ # this works but runs slowly!
+ from torchvision.datasets import CIFAR10 as RawCIFAR10
+ dataset = RawCIFAR10(...)
+
+ @dataset_register(
+     name='CIFAR10',
+     classes=dataset.classes
+ )
+ class CIFAR10(ABDataset):
+     # ...
+ ```
+
+ For an object detection dataset, you can read the annotation JSON file and find the `categories` information in it.
+
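+ For instance, for a COCO-style annotation file (a sketch; the file path is a placeholder):
+
+ ```python
+ import json
+
+ # hypothetical annotation path, for illustration only
+ with open('annotations/instances_train.json') as f:
+     ann = json.load(f)
+
+ # each entry in ann['categories'] looks like {'id': 1, 'name': 'person', ...}
+ classes = [c['name'] for c in sorted(ann['categories'], key=lambda c: c['id'])]
+ print(classes)  # copy this output to @dataset_register(classes=<what you copied>)
+ ```
+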
+ 4. Implement the abstract function `create_dataset(self, root_dir: str, split: str, transform: Optional[Compose], classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]])`.
+
+ Arguments:
+
+ - `root_dir`: the location of the data
+ - `split`: `train / val / test`
+ - `transform`: preprocessing function from `torchvision.transforms`
+ - `classes`: the same value as `dataset_register.classes`
+ - `ignore_classes`: **classes that should be discarded. You should remove the images that belong to these ignored classes.**
+ - `idx_map`: **maps an original class index to a new class index. For example, `{0: 2}` means the index of the 0-th class will be 2 instead of 0. You should implement this by modifying the stored labels in the original dataset.**
+
+ You should do five things in this function (see the example below):
+
+ 1. If no user-defined transform is passed, apply the default transform.
+ 2. Create the original dataset.
+ 3. Remove ignored classes from the original dataset, if any.
+ 4. Map the original class indexes to new class indexes if an index map is given.
+ 5. Split the original dataset into train / val / test datasets. If the original dataset has no val split (e.g. DomainNetReal), split it into train / val / test yourself. If it already has a val split (e.g. CIFAR10 and ImageNet), treat the original val split as the test split and split the original train split into train / val. For details, refer to the existing files.
+
+ Example (`benchmark/data/datasets/image_classification/cifar10.py`):
+
+ ```python
+ @dataset_register(
+     name='CIFAR10',
+     # means that in the original CIFAR10, the 0-th class is airplane, the 1-st class is automobile, ...
+     classes=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
+     # means 'automobile' and 'car' are actually the same thing
+     class_aliases=[['automobile', 'car']]
+ )
+ class CIFAR10(ABDataset):
+     def create_dataset(self, root_dir: str, split: str, transform: Optional[Compose],
+                        classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
+         # 1. if no user-defined transform is passed, apply the default transform
+         if transform is None:
+             transform = cifar_like_image_train_aug() if split == 'train' else cifar_like_image_test_aug()
+         # 2. create the original dataset
+         dataset = RawCIFAR10(root_dir, split != 'test', transform=transform, download=True)
+
+         # 3. remove ignored classes from the original dataset, if any
+         dataset.targets = np.asarray(dataset.targets)
+         if len(ignore_classes) > 0:
+             for ignore_class in ignore_classes:
+                 dataset.data = dataset.data[dataset.targets != classes.index(ignore_class)]
+                 dataset.targets = dataset.targets[dataset.targets != classes.index(ignore_class)]
+
+         # 4. map the original class indexes to new class indexes if an index map is given
+         if idx_map is not None:
+             for ti, t in enumerate(dataset.targets):
+                 dataset.targets[ti] = idx_map[t]
+
+         # 5. split the original dataset into train / val / test datasets.
+         # CIFAR10 has no val split, so we split the val dataset from the train dataset.
+         if split != 'test':
+             dataset = train_val_split(dataset, split)
+         return dataset
+ ```
+
+ After implementing a new dataset, you can create a test file in `example` and load the dataset via `benchmark.data.dataset.get_dataset()`. Try using the dataset to make sure it works (example: `example/1.py`).
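+
+ A minimal smoke test might look like this (a sketch: the data directory is a placeholder, and the keyword arguments follow how `get_dataset` is called elsewhere in this repo):
+
+ ```python
+ from benchmark.data.dataset import get_dataset
+
+ # '/data/cifar10' is a hypothetical path, for illustration only
+ dataset = get_dataset('CIFAR10', '/data/cifar10', 'train',
+                       transform=None, ignore_classes=[], idx_map=None)
+ print(len(dataset))
+ x, y = dataset[0]  # (image, label)
+ ```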
data/__init__.py ADDED
@@ -0,0 +1,12 @@
+ from .dataset import get_dataset
+ from .build.build import build_scenario_manually_v2 as build_scenario
+ from .dataloader import build_dataloader
+ from .build.scenario import IndexReturnedDataset, MergedDataset
+ from .datasets.ab_dataset import ABDataset
+ from .build.scenario import Scenario
+
+ from .build_cl.build import build_cl_scenario
+ from .build_cl.scenario import Scenario as CLScenario
+
+
+ from .datasets.dataset_split import split_dataset
data/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (682 Bytes).
 
data/__pycache__/dataloader.cpython-38.pyc ADDED
Binary file (3.53 kB).
 
data/__pycache__/dataset.cpython-38.pyc ADDED
Binary file (1.29 kB).
 
data/build/__init__.py ADDED
File without changes
data/build/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (188 Bytes).
 
data/build/__pycache__/build.cpython-38.pyc ADDED
Binary file (9.12 kB).
 
data/build/__pycache__/merge_alias.cpython-38.pyc ADDED
Binary file (2.55 kB).
 
data/build/__pycache__/scenario.cpython-38.pyc ADDED
Binary file (10.7 kB).
 
data/build/build.py ADDED
@@ -0,0 +1,495 @@
+ from typing import Dict, List, Optional, Type, Union
+ from ..datasets.ab_dataset import ABDataset
+ # from benchmark.data.visualize import visualize_classes_in_object_detection
+ # from benchmark.scenario.val_domain_shift import get_val_domain_shift_transform
+ from ..dataset import get_dataset
+ import copy
+ from torchvision.transforms import Compose
+
+ from .merge_alias import merge_the_same_meaning_classes
+ from ..datasets.registery import static_dataset_registery
+
+
+ # some legacy aliases of variables:
+ # ignore_classes == discarded classes
+ # private_classes == unknown classes in partial / open-set / universal DA
+
+
+ def _merge_the_same_meaning_classes(classes_info_of_all_datasets):
+     final_classes_of_all_datasets, rename_map = merge_the_same_meaning_classes(classes_info_of_all_datasets)
+     return final_classes_of_all_datasets, rename_map
+
+
+ def _find_ignore_classes_when_sources_as_to_target_b(as_classes: List[List[str]], b_classes: List[str], da_mode):
+     thres = {'da': 3, 'partial_da': 2, 'open_set_da': 1, 'universal_da': 0}[da_mode]
+
+     from functools import reduce
+     a_classes = reduce(lambda res, cur: res | set(cur), as_classes, set())
+
+     if set(a_classes) == set(b_classes):
+         # a is equal to b, normal
+         # 1. no ignore classes; 2. match class idx
+         a_ignore_classes, b_ignore_classes = [], []
+
+     elif set(a_classes) > set(b_classes):
+         # a contains b, partial
+         a_ignore_classes, b_ignore_classes = [], []
+         if thres == 3 or thres == 1:  # ignore extra classes in a
+             a_ignore_classes = set(a_classes) - set(b_classes)
+
+     elif set(a_classes) < set(b_classes):
+         # a is contained by b, open set
+         a_ignore_classes, b_ignore_classes = [], []
+         if thres == 3 or thres == 2:  # ignore extra classes in b
+             b_ignore_classes = set(b_classes) - set(a_classes)
+
+     elif len(set(a_classes) & set(b_classes)) > 0:
+         a_ignore_classes, b_ignore_classes = [], []
+         if thres == 3:
+             a_ignore_classes = set(a_classes) - (set(a_classes) & set(b_classes))
+             b_ignore_classes = set(b_classes) - (set(a_classes) & set(b_classes))
+         elif thres == 2:
+             b_ignore_classes = set(b_classes) - (set(a_classes) & set(b_classes))
+         elif thres == 1:
+             a_ignore_classes = set(a_classes) - (set(a_classes) & set(b_classes))
+
+     else:
+         return None  # a has no intersection with b, none
+
+     as_ignore_classes = [list(set(a_classes) & set(a_ignore_classes)) for a_classes in as_classes]
+
+     return as_ignore_classes, list(b_ignore_classes)
+
+
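+ # Example of the behavior above (a sketch):
+ # _find_ignore_classes_when_sources_as_to_target_b([['cat', 'dog']], ['cat', 'bird'], 'partial_da')
+ # returns ([[]], ['bird']): partial DA only trims the target side, so the extra
+ # target class 'bird' is ignored while the extra source class 'dog' is kept.
+
+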
+ def _find_private_classes_when_sources_as_to_target_b(as_classes: List[List[str]], b_classes: List[str], da_mode):
+     thres = {'da': 3, 'partial_da': 2, 'open_set_da': 1, 'universal_da': 0}[da_mode]
+
+     from functools import reduce
+     a_classes = reduce(lambda res, cur: res | set(cur), as_classes, set())
+
+     if set(a_classes) == set(b_classes):
+         # a is equal to b, normal
+         # 1. no ignore classes; 2. match class idx
+         a_private_classes, b_private_classes = [], []
+
+     elif set(a_classes) > set(b_classes):
+         # a contains b, partial
+         a_private_classes, b_private_classes = [], []
+         # if thres == 2 or thres == 0:  # ignore extra classes in a
+         #     a_private_classes = set(a_classes) - set(b_classes)
+         # if thres == 0:  # ignore extra classes in a
+         #     a_private_classes = set(a_classes) - set(b_classes)
+
+     elif set(a_classes) < set(b_classes):
+         # a is contained by b, open set
+         a_private_classes, b_private_classes = [], []
+         if thres == 1 or thres == 0:  # ignore extra classes in b
+             b_private_classes = set(b_classes) - set(a_classes)
+
+     elif len(set(a_classes) & set(b_classes)) > 0:
+         a_private_classes, b_private_classes = [], []
+         if thres == 0:
+             # a_private_classes = set(a_classes) - (set(a_classes) & set(b_classes))
+             b_private_classes = set(b_classes) - (set(a_classes) & set(b_classes))
+         elif thres == 1:
+             b_private_classes = set(b_classes) - (set(a_classes) & set(b_classes))
+         elif thres == 2:
+             # a_private_classes = set(a_classes) - (set(a_classes) & set(b_classes))
+             pass
+
+     else:
+         return None  # a has no intersection with b, none
+
+     return list(b_private_classes)
+
+
+ class _ABDatasetMetaInfo:
+     def __init__(self, name, classes, task_type, object_type, class_aliases, shift_type):
+         self.name = name
+         self.classes = classes
+         self.class_aliases = class_aliases
+         self.shift_type = shift_type
+         self.task_type = task_type
+         self.object_type = object_type
+
+
+ def _get_dist_shift_type_when_source_a_to_target_b(a: _ABDatasetMetaInfo, b: _ABDatasetMetaInfo):
+     if b.shift_type is None:
+         return 'Dataset Shifts'
+
+     if a.name in b.shift_type.keys():
+         return b.shift_type[a.name]
+
+     mid_dataset_name = list(b.shift_type.keys())[0]
+     mid_dataset_meta_info = _ABDatasetMetaInfo(mid_dataset_name, *static_dataset_registery[mid_dataset_name][1:])
+
+     return _get_dist_shift_type_when_source_a_to_target_b(a, mid_dataset_meta_info) + ' + ' + list(b.shift_type.values())[0]
+
+
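+ # Example of the recursion above (a sketch): if b.shift_type == {'CIFAR10': 'Gaussian Noise'}
+ # and a.name == 'CIFAR10', this returns 'Gaussian Noise'; for another source a whose
+ # relation to CIFAR10 is unregistered, it returns 'Dataset Shifts + Gaussian Noise'.
+
+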
+ def _handle_all_datasets_v2(source_datasets: List[_ABDatasetMetaInfo], target_datasets: List[_ABDatasetMetaInfo], da_mode):
+
+     # 1. merge the same meaning classes
+     classes_info_of_all_datasets = {
+         d.name: (d.classes, d.class_aliases)
+         for d in source_datasets + target_datasets
+     }
+     final_classes_of_all_datasets, rename_map = _merge_the_same_meaning_classes(classes_info_of_all_datasets)
+     all_datasets_classes = copy.deepcopy(final_classes_of_all_datasets)
+
+     # print(all_datasets_known_classes)
+
+     # 2. find ignored classes according to the DA mode
+     # source_datasets_ignore_classes, target_datasets_ignore_classes = {d.name: [] for d in source_datasets}, \
+     #     {d.name: [] for d in target_datasets}
+     # source_datasets_private_classes, target_datasets_private_classes = {d.name: [] for d in source_datasets}, \
+     #     {d.name: [] for d in target_datasets}
+     target_source_relationship_map = {td.name: {} for td in target_datasets}
+     # source_target_relationship_map = {sd.name: [] for sd in source_datasets}
+
+     # 2.1. construct target_source_relationship_map
+     for sd in source_datasets:  # sd and td are the meta-info instances of each dataset in the two lists
+         for td in target_datasets:
+             sc = all_datasets_classes[sd.name]
+             tc = all_datasets_classes[td.name]
+
+             if len(set(sc) & set(tc)) == 0:  # keep only source/target domain pairs that share classes
+                 continue
+
+             target_source_relationship_map[td.name][sd.name] = _get_dist_shift_type_when_source_a_to_target_b(sd, td)
+
+     # print(target_source_relationship_map)
+     # exit()
+
+     source_datasets_ignore_classes = {}
+     for td_name, v1 in target_source_relationship_map.items():
+         for sd_name, v2 in v1.items():
+             source_datasets_ignore_classes[sd_name + '|' + td_name] = []
+     target_datasets_ignore_classes = {d.name: [] for d in target_datasets}
+     target_datasets_private_classes = {d.name: [] for d in target_datasets}
+     # ensure that DA on each target domain follows the given label shift;
+     # so even if different target domains correspond to the same source dataset,
+     # that source may be configured differently for each target
+
+     for td_name, v1 in target_source_relationship_map.items():
+         sd_names = list(v1.keys())
+
+         sds_classes = [all_datasets_classes[sd_name] for sd_name in sd_names]
+         td_classes = all_datasets_classes[td_name]
+         ss_ignore_classes, t_ignore_classes = _find_ignore_classes_when_sources_as_to_target_b(sds_classes, td_classes, da_mode)  # the ignore classes depend on the DA mode
+         t_private_classes = _find_private_classes_when_sources_as_to_target_b(sds_classes, td_classes, da_mode)
+
+         for sd_name, s_ignore_classes in zip(sd_names, ss_ignore_classes):
+             source_datasets_ignore_classes[sd_name + '|' + td_name] = s_ignore_classes
+         target_datasets_ignore_classes[td_name] = t_ignore_classes
+         target_datasets_private_classes[td_name] = t_private_classes
+
+     source_datasets_ignore_classes = {k: sorted(set(v), key=v.index) for k, v in source_datasets_ignore_classes.items()}
+     target_datasets_ignore_classes = {k: sorted(set(v), key=v.index) for k, v in target_datasets_ignore_classes.items()}
+     target_datasets_private_classes = {k: sorted(set(v), key=v.index) for k, v in target_datasets_private_classes.items()}
+
+     # for k, v in source_datasets_ignore_classes.items():
+     #     print(k, len(v))
+     # print()
+     # for k, v in target_datasets_ignore_classes.items():
+     #     print(k, len(v))
+     # print()
+     # for k, v in target_datasets_private_classes.items():
+     #     print(k, len(v))
+     # print()
+
+     # print(source_datasets_private_classes, target_datasets_private_classes)
+     # 3. reparse class indexes
+     # 3.1. aggregate all used classes
+     # all_used_classes = []
+     # all_datasets_private_class_idx_map = {}
+
+     # source_datasets_classes_idx_map = {}
+     # for td_name, v1 in target_source_relationship_map.items():
+     #     for sd_name, v2 in v1.items():
+     #         source_datasets_classes_idx_map[sd_name + '|' + td_name] = []
+     # target_datasets_classes_idx_map = {}
+
+     global_idx = 0
+     all_used_classes_idx_map = {}
+     # all_datasets_known_classes = {d: [] for d in final_classes_of_all_datasets.keys()}
+     for dataset_name, classes in all_datasets_classes.items():
+         if dataset_name not in target_datasets_ignore_classes.keys():
+             ignore_classes = [0] * 100000
+             for sn, sic in source_datasets_ignore_classes.items():
+                 if sn.startswith(dataset_name):
+                     if len(sic) < len(ignore_classes):
+                         ignore_classes = sic
+         else:
+             ignore_classes = target_datasets_ignore_classes[dataset_name]
+         private_classes = [] \
+             if dataset_name not in target_datasets_ignore_classes.keys() else target_datasets_private_classes[dataset_name]
+
+         for c in classes:
+             if c not in ignore_classes and c not in all_used_classes_idx_map.keys() and c not in private_classes:
+                 all_used_classes_idx_map[c] = global_idx
+                 global_idx += 1
+
+     # print(all_used_classes_idx_map)
+
+     # dataset_private_class_idx_offset = 0
+     target_private_class_idx = global_idx
+     target_datasets_private_class_idx = {d: None for d in target_datasets_private_classes.keys()}
+
+     for dataset_name, classes in final_classes_of_all_datasets.items():
+         if dataset_name not in target_datasets_private_classes.keys():
+             continue
+
+         # ignore_classes = target_datasets_ignore_classes[dataset_name]
+         private_classes = target_datasets_private_classes[dataset_name]
+         # private_classes = [] \
+         #     if dataset_name in source_datasets_private_classes.keys() else target_datasets_private_classes[dataset_name]
+         # for c in classes:
+         #     if c not in ignore_classes and c not in all_used_classes_idx_map.keys() and c in private_classes:
+         #         all_used_classes_idx_map[c] = global_idx + dataset_private_class_idx_offset
+
+         if len(private_classes) > 0:
+             # all_datasets_private_class_idx[dataset_name] = global_idx + dataset_private_class_idx_offset
+             # dataset_private_class_idx_offset += 1
+             # if dataset_name in source_datasets_private_classes.keys():
+             #     if source_private_class_idx is None:
+             #         source_private_class_idx = global_idx if target_private_class_idx is None else target_private_class_idx + 1
+             #     all_datasets_private_class_idx[dataset_name] = source_private_class_idx
+             # else:
+             #     if target_private_class_idx is None:
+             #         target_private_class_idx = global_idx if source_private_class_idx is None else source_private_class_idx + 1
+             #     all_datasets_private_class_idx[dataset_name] = target_private_class_idx
+             target_datasets_private_class_idx[dataset_name] = target_private_class_idx
+             target_private_class_idx += 1
+
+     # all_used_classes = sorted(set(all_used_classes), key=all_used_classes.index)
+     # all_used_classes_idx_map = {c: i for i, c in enumerate(all_used_classes)}
+
+     # print('rename_map', rename_map)
+
+     # 3.2. raw_class -> rename_map[raw_class] -> all_used_classes_idx_map
+     all_datasets_e2e_idx_map = {}
+     all_datasets_e2e_class_to_idx_map = {}
+
+     for td_name, v1 in target_source_relationship_map.items():
+         sd_names = list(v1.keys())
+         sds_classes = [all_datasets_classes[sd_name] for sd_name in sd_names]
+         td_classes = all_datasets_classes[td_name]
+
+         for sd_name, sd_classes in zip(sd_names, sds_classes):
+             cur_e2e_idx_map = {}
+             cur_e2e_class_to_idx_map = {}
+
+             for raw_ci, raw_c in enumerate(sd_classes):
+                 renamed_c = raw_c if raw_c not in rename_map[sd_name] else rename_map[sd_name][raw_c]
+
+                 ignore_classes = source_datasets_ignore_classes[sd_name + '|' + td_name]
+                 if renamed_c in ignore_classes:
+                     continue
+
+                 idx = all_used_classes_idx_map[renamed_c]
+
+                 cur_e2e_idx_map[raw_ci] = idx
+                 cur_e2e_class_to_idx_map[raw_c] = idx
+
+             all_datasets_e2e_idx_map[sd_name + '|' + td_name] = cur_e2e_idx_map
+             all_datasets_e2e_class_to_idx_map[sd_name + '|' + td_name] = cur_e2e_class_to_idx_map
+
+         cur_e2e_idx_map = {}
+         cur_e2e_class_to_idx_map = {}
+         for raw_ci, raw_c in enumerate(td_classes):
+             renamed_c = raw_c if raw_c not in rename_map[td_name] else rename_map[td_name][raw_c]
+
+             ignore_classes = target_datasets_ignore_classes[td_name]
+             if renamed_c in ignore_classes:
+                 continue
+
+             if renamed_c in target_datasets_private_classes[td_name]:
+                 idx = target_datasets_private_class_idx[td_name]
+             else:
+                 idx = all_used_classes_idx_map[renamed_c]
+
+             cur_e2e_idx_map[raw_ci] = idx
+             cur_e2e_class_to_idx_map[raw_c] = idx
+
+         all_datasets_e2e_idx_map[td_name] = cur_e2e_idx_map
+         all_datasets_e2e_class_to_idx_map[td_name] = cur_e2e_class_to_idx_map
+
+     all_datasets_ignore_classes = {**source_datasets_ignore_classes, **target_datasets_ignore_classes}
+     # all_datasets_private_classes = {**source_datasets_private_classes, **target_datasets_private_classes}
+
+     classes_idx_set = []
+     for d, m in all_datasets_e2e_class_to_idx_map.items():
+         classes_idx_set += list(m.values())
+     classes_idx_set = set(classes_idx_set)
+     num_classes = len(classes_idx_set)
+
+     return all_datasets_ignore_classes, target_datasets_private_classes, \
+         all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
+         target_source_relationship_map, rename_map, num_classes
+
+
+ def _build_scenario_info_v2(
+     source_datasets_name: List[str],
+     target_datasets_order: List[str],
+     da_mode: str
+ ):
+     assert da_mode in ['close_set', 'partial', 'open_set', 'universal']
+     da_mode = {'close_set': 'da', 'partial': 'partial_da', 'open_set': 'open_set_da', 'universal': 'universal_da'}[da_mode]
+
+     source_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:]) for d in source_datasets_name]  # look up each dataset's name and attributes; to add a dataset, just register it
+     target_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:]) for d in list(set(target_datasets_order))]
+
+     all_datasets_ignore_classes, target_datasets_private_classes, \
+         all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
+         target_source_relationship_map, rename_map, num_classes \
+         = _handle_all_datasets_v2(source_datasets_meta_info, target_datasets_meta_info, da_mode)
+
+     return all_datasets_ignore_classes, target_datasets_private_classes, \
+         all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
+         target_source_relationship_map, rename_map, num_classes
+
+
+ def build_scenario_manually_v2(
+     source_datasets_name: List[str],
+     target_datasets_order: List[str],
+     da_mode: str,
+     data_dirs: Dict[str, str],
+     # transforms: Optional[Dict[str, Compose]] = None
+ ):
+     configs = copy.deepcopy(locals())  # snapshot of the current local variables (i.e. the call arguments)
+
+     source_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:]) for d in source_datasets_name]
+     target_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:]) for d in list(set(target_datasets_order))]
+
+     all_datasets_ignore_classes, target_datasets_private_classes, \
+         all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
+         target_source_relationship_map, rename_map, num_classes \
+         = _build_scenario_info_v2(source_datasets_name, target_datasets_order, da_mode)
+     # from rich.console import Console
+     # console = Console(width=10000)
+
+     # def print_obj(_o):
+     #     # import pprint
+     #     # s = pprint.pformat(_o, width=140, compact=True)
+     #     console.print(_o)
+
+     # console.print('configs:', style='bold red')
+     # print_obj(configs)
+     # console.print('renamed classes:', style='bold red')
+     # print_obj(rename_map)
+     # console.print('discarded classes:', style='bold red')
+     # print_obj(all_datasets_ignore_classes)
+     # console.print('unknown classes:', style='bold red')
+     # print_obj(target_datasets_private_classes)
+     # console.print('class to index map:', style='bold red')
+     # print_obj(all_datasets_e2e_class_to_idx_map)
+     # console.print('index map:', style='bold red')
+     # print_obj(all_datasets_e2e_idx_map)
+     # console = Console()
+     # # console.print('class distribution:', style='bold red')
+     # # class_dist = {
+     # #     k: {
+     # #         '#known classes': len(all_datasets_known_classes[k]),
+     # #         '#unknown classes': len(all_datasets_private_classes[k]),
+     # #         '#discarded classes': len(all_datasets_ignore_classes[k])
+     # #     } for k in all_datasets_ignore_classes.keys()
+     # # }
+     # # print_obj(class_dist)
+     # console.print('corresponding sources of each target:', style='bold red')
+     # print_obj(target_source_relationship_map)
+
+     # return
+
+     # res_source_datasets_map = {d: {split: get_dataset(d, data_dirs[d], split, getattr(transforms, d, None),
+     #                                                   all_datasets_ignore_classes[d], all_datasets_e2e_idx_map[d])
+     #                                for split in ['train', 'val', 'test']}
+     #                            for d in source_datasets_name}
+     # res_target_datasets_map = {d: {'train': get_num_limited_dataset(get_dataset(d, data_dirs[d], 'test', getattr(transforms, d, None),
+     #                                                                             all_datasets_ignore_classes[d], all_datasets_e2e_idx_map[d]),
+     #                                                                 num_samples_in_each_target_domain),
+     #                                'test': get_dataset(d, data_dirs[d], 'test', getattr(transforms, d, None),
+     #                                                    all_datasets_ignore_classes[d], all_datasets_e2e_idx_map[d])
+     #                                }
+     #                            for d in list(set(target_datasets_order))}
+
+     # res_source_datasets_map = {d: {split: get_dataset(d.split('|')[0], data_dirs[d.split('|')[0]], split,
+     #                                                   getattr(transforms, d.split('|')[0], None),
+     #                                                   all_datasets_ignore_classes[d], all_datasets_e2e_idx_map[d])
+     #                                for split in ['train', 'val', 'test']}
+     #                            for d in all_datasets_ignore_classes.keys() if d.split('|')[0] in source_datasets_name}
+
+     # from functools import reduce
+     # res_offline_train_source_datasets_map = {}
+     # res_offline_train_source_datasets_map_names = {}
+
+     # for d in source_datasets_name:
+     #     source_dataset_with_max_num_classes = None
+
+     #     for ed_name, ed in res_source_datasets_map.items():
+     #         if not ed_name.startswith(d):
+     #             continue
+
+     #         if source_dataset_with_max_num_classes is None:
+     #             source_dataset_with_max_num_classes = ed
+     #             res_offline_train_source_datasets_map_names[d] = ed_name
+
+     #         if len(ed['train'].ignore_classes) < len(source_dataset_with_max_num_classes['train'].ignore_classes):
+     #             source_dataset_with_max_num_classes = ed
+     #             res_offline_train_source_datasets_map_names[d] = ed_name
+
+     #     res_offline_train_source_datasets_map[d] = source_dataset_with_max_num_classes
+
+     # res_target_datasets_map = {d: {split: get_dataset(d, data_dirs[d], split, getattr(transforms, d, None),
+     #                                                   all_datasets_ignore_classes[d], all_datasets_e2e_idx_map[d])
+     #                                for split in ['train', 'val', 'test']}
+     #                            for d in list(set(target_datasets_order))}
+
+     from .scenario import Scenario, DatasetMetaInfo
+
+     # test_scenario = Scenario(
+     #     config=configs,
+     #     offline_source_datasets_meta_info={
+     #         d: DatasetMetaInfo(d,
+     #                            {k: v for k, v in all_datasets_e2e_class_to_idx_map[res_offline_train_source_datasets_map_names[d]].items()},
+     #                            None)
+     #         for d in source_datasets_name
+     #     },
+     #     offline_source_datasets={d: res_offline_train_source_datasets_map[d] for d in source_datasets_name},
+
+     #     online_datasets_meta_info=[
+     #         (
+     #             {sd + '|' + d: DatasetMetaInfo(d,
+     #                                            {k: v for k, v in all_datasets_e2e_class_to_idx_map[sd + '|' + d].items()},
+     #                                            None)
+     #              for sd in target_source_relationship_map[d].keys()},
+     #             DatasetMetaInfo(d,
+     #                             {k: v for k, v in all_datasets_e2e_class_to_idx_map[d].items() if k not in target_datasets_private_classes[d]},
+     #                             target_datasets_private_class_idx[d])
+     #         )
+     #         for d in target_datasets_order
+     #     ],
+     #     online_datasets={**res_source_datasets_map, **res_target_datasets_map},
+     #     target_domains_order=target_datasets_order,
+     #     target_source_map=target_source_relationship_map,
+     #     num_classes=num_classes
+     # )
+     import os
+     os.environ['_ZQL_NUMC'] = str(num_classes)
+
+     test_scenario = Scenario(config=configs, all_datasets_ignore_classes_map=all_datasets_ignore_classes,
+                              all_datasets_idx_map=all_datasets_e2e_idx_map,
+                              target_domains_order=target_datasets_order,
+                              target_source_map=target_source_relationship_map,
+                              all_datasets_e2e_class_to_idx_map=all_datasets_e2e_class_to_idx_map,
+                              num_classes=num_classes)
+
+     return test_scenario
+
+
+ if __name__ == '__main__':
+     # data_dirs is a required argument; the paths below are placeholders for illustration
+     test_scenario = build_scenario_manually_v2(['CIFAR10', 'SVHN'],
+                                                ['STL10', 'MNIST', 'STL10', 'USPS', 'MNIST', 'STL10'],
+                                                'close_set',
+                                                data_dirs={'CIFAR10': '/data/cifar10', 'SVHN': '/data/svhn',
+                                                           'STL10': '/data/stl10', 'MNIST': '/data/mnist',
+                                                           'USPS': '/data/usps'})
+     print(test_scenario.num_classes)
+
data/build/merge_alias.py ADDED
@@ -0,0 +1,106 @@
+ from typing import Dict, List
+ from collections import Counter
+
+
+ def grouping(bondlist):
+     # reference: https://blog.csdn.net/YnagShanwen/article/details/111344386
+     # merge all pairs that share an element into connected groups
+     groups = []
+     break1 = False
+     while bondlist:
+         pair1 = bondlist.pop(0)
+         a = 11111
+         b = 10000
+         while b != a:
+             a = b
+             for atomid in pair1:
+                 for i, pair2 in enumerate(bondlist):
+                     if atomid in pair2:
+                         pair1 = pair1 + pair2
+                         bondlist.pop(i)
+                         if not bondlist:
+                             break1 = True
+                             break
+                 if break1:
+                     break
+             b = len(pair1)
+         groups.append(pair1)
+     return groups
+
+
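+ # Example of the grouping above (a sketch):
+ # grouping([['car', 'automobile'], ['automobile', 'auto'], ['bird']])
+ # -> [['car', 'automobile', 'automobile', 'auto'], ['bird']]
+ # (duplicates are kept on purpose; merge_the_same_meaning_classes below uses
+ # their counts to pick the most common name)
+
+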
+ def build_semantic_class_info(classes: List[str], aliases: List[List[str]]):
+     res = []
+     for c in classes:
+         # print(res)
+         if len(aliases) == 0:
+             res += [[c]]
+         else:
+             find_alias = False
+             for alias in aliases:
+                 if c in alias:
+                     res += [alias]
+                     find_alias = True
+                     break
+             if not find_alias:
+                 res += [[c]]
+     # print(classes, res)
+     return res
+
+
+ def merge_the_same_meaning_classes(classes_info_of_all_datasets):
+     # print(classes_info_of_all_datasets)
+
+     semantic_classes_of_all_datasets = []
+     all_aliases = []
+     for classes, aliases in classes_info_of_all_datasets.values():
+         all_aliases += aliases
+     for classes, aliases in classes_info_of_all_datasets.values():
+         semantic_classes_of_all_datasets += build_semantic_class_info(classes, all_aliases)
+
+     # print(semantic_classes_of_all_datasets)
+
+     grouped_classes_of_all_datasets = grouping(semantic_classes_of_all_datasets)  # classes grouped by shared aliases
+
+     # print(grouped_classes_of_all_datasets)
+
+     # final_grouped_classes_of_all_datasets = [Counter(c).most_common()[0][0] for c in grouped_classes_of_all_datasets]
+     # use the most common class name; on a tie, use the shortest class name!
+     final_grouped_classes_of_all_datasets = []
+     for c in grouped_classes_of_all_datasets:
+         counter = Counter(c).most_common()
+         max_times = counter[0][1]
+         candidate_class_names = []
+         for item, times in counter:
+             if times < max_times:
+                 break
+             candidate_class_names += [item]
+         candidate_class_names.sort(key=lambda x: len(x))
+
+         final_grouped_classes_of_all_datasets += [candidate_class_names[0]]
+
+     res = {}
+     res_map = {d: {} for d in classes_info_of_all_datasets.keys()}
+
+     for dataset_name, (classes, _) in classes_info_of_all_datasets.items():
+         final_classes = []
+         for c in classes:
+             for grouped_names, final_name in zip(grouped_classes_of_all_datasets, final_grouped_classes_of_all_datasets):
+                 if c in grouped_names:
+                     final_classes += [final_name]
+                     if final_name != c:
+                         res_map[dataset_name][c] = final_name
+                     break
+         res[dataset_name] = sorted(set(final_classes), key=final_classes.index)
+     return res, res_map
+
+
+ if __name__ == '__main__':
+     cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
+     cifar10_aliases = [['automobile', 'car']]
+     stl10_classes = ['airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey', 'ship', 'truck']
+
+     final_classes_of_all_datasets, rename_map = merge_the_same_meaning_classes({
+         'CIFAR10': (cifar10_classes, cifar10_aliases),
+         'STL10': (stl10_classes, [])
+     })
+
+     print(final_classes_of_all_datasets, rename_map)
data/build/scenario.py ADDED
@@ -0,0 +1,466 @@
+ import enum
+ from functools import reduce
+ from typing import Dict, List, Tuple
+ import numpy as np
+ import copy
+ from utils.common.log import logger
+ from ..datasets.ab_dataset import ABDataset
+ from ..dataloader import FastDataLoader, InfiniteDataLoader, build_dataloader
+ from data import get_dataset
+
+
+ class DatasetMetaInfo:
+     def __init__(self, name,
+                  known_classes_name_idx_map, unknown_class_idx):
+
+         assert unknown_class_idx not in known_classes_name_idx_map.values()
+
+         self.name = name
+         self.unknown_class_idx = unknown_class_idx
+         self.known_classes_name_idx_map = known_classes_name_idx_map
+
+     @property
+     def num_classes(self):
+         # known classes plus one unknown class
+         return len(self.known_classes_name_idx_map) + 1
+
+
+ class MergedDataset:
+     def __init__(self, datasets: List[ABDataset]):
+         self.datasets = datasets
+         self.datasets_len = [len(i) for i in self.datasets]
+         logger.info(f'create MergedDataset: len of datasets {self.datasets_len}')
+         self.datasets_cum_len = np.cumsum(self.datasets_len)
+
+     def __getitem__(self, idx):
+         for i, cum_len in enumerate(self.datasets_cum_len):
+             if idx < cum_len:
+                 return self.datasets[i][idx - sum(self.datasets_len[0: i])]
+
+     def __len__(self):
+         return sum(self.datasets_len)
+
+
+ class IndexReturnedDataset:
+     def __init__(self, dataset: ABDataset):
+         self.dataset = dataset
+
+     def __getitem__(self, idx):
+         res = self.dataset[idx]
+
+         if isinstance(res, (tuple, list)):
+             return (*res, idx)
+         else:
+             return res, idx
+
+     def __len__(self):
+         return len(self.dataset)
+
+
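+ # Example of MergedDataset indexing (a sketch): for two merged datasets with
+ # lengths [10, 20], index 12 resolves to datasets[1][2], and len() is 30.
+
+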
+ # class Scenario:
+ #     def __init__(self, config,
+ #                  source_datasets_meta_info: Dict[str, DatasetMetaInfo], target_datasets_meta_info: Dict[str, DatasetMetaInfo],
+ #                  target_source_map: Dict[str, Dict[str, str]],
+ #                  target_domains_order: List[str],
+ #                  source_datasets: Dict[str, Dict[str, ABDataset]], target_datasets: Dict[str, Dict[str, ABDataset]]):
+
+ #         self.__config = config
+ #         self.__source_datasets_meta_info = source_datasets_meta_info
+ #         self.__target_datasets_meta_info = target_datasets_meta_info
+ #         self.__target_source_map = target_source_map
+ #         self.__target_domains_order = target_domains_order
+ #         self.__source_datasets = source_datasets
+ #         self.__target_datasets = target_datasets
+
+ #     # 1. basic
+ #     def get_config(self):
+ #         return copy.deepcopy(self.__config)
+
+ #     def get_task_type(self):
+ #         return list(self.__source_datasets.values())[0]['train'].task_type
+
+ #     def get_num_classes(self):
+ #         known_classes_idx = []
+ #         unknown_classes_idx = []
+ #         for v in self.__source_datasets_meta_info.values():
+ #             known_classes_idx += list(v.known_classes_name_idx_map.values())
+ #             unknown_classes_idx += [v.unknown_class_idx]
+ #         for v in self.__target_datasets_meta_info.values():
+ #             known_classes_idx += list(v.known_classes_name_idx_map.values())
+ #             unknown_classes_idx += [v.unknown_class_idx]
+ #         unknown_classes_idx = [i for i in unknown_classes_idx if i is not None]
+ #         # print(known_classes_idx, unknown_classes_idx)
+ #         res = len(set(known_classes_idx)), len(set(unknown_classes_idx)), len(set(known_classes_idx + unknown_classes_idx))
+ #         # print(res)
+ #         assert res[0] + res[1] == res[2]
+ #         return res
+
+ #     def build_dataloader(self, dataset: ABDataset, batch_size: int, num_workers: int, infinite: bool, shuffle_when_finite: bool):
+ #         if infinite:
+ #             dataloader = InfiniteDataLoader(
+ #                 dataset, None, batch_size, num_workers=num_workers)
+ #         else:
+ #             dataloader = FastDataLoader(
+ #                 dataset, batch_size, num_workers, shuffle=shuffle_when_finite)
+
+ #         return dataloader
+
+ #     def build_sub_dataset(self, dataset: ABDataset, indexes: List[int]):
+ #         from ..data.datasets.dataset_split import _SplitDataset
+ #         dataset.dataset = _SplitDataset(dataset.dataset, indexes)
+ #         return dataset
+
+ #     def build_index_returned_dataset(self, dataset: ABDataset):
+ #         return IndexReturnedDataset(dataset)
+
+ #     # 2. source
+ #     def get_source_datasets_meta_info(self):
+ #         return self.__source_datasets_meta_info
+
+ #     def get_source_datasets_name(self):
+ #         return list(self.__source_datasets.keys())
+
+ #     def get_merged_source_dataset(self, split):
+ #         source_train_datasets = {n: d[split] for n, d in self.__source_datasets.items()}
+ #         return MergedDataset(list(source_train_datasets.values()))
+
+ #     def get_source_datasets(self, split):
+ #         source_train_datasets = {n: d[split] for n, d in self.__source_datasets.items()}
+ #         return source_train_datasets
+
+ #     # 3. target **domain**
+ #     # (do we need such API `get_ith_target_domain()`?)
+ #     def get_target_domains_meta_info(self):
+ #         return self.__source_datasets_meta_info
+
+ #     def get_target_domains_order(self):
+ #         return self.__target_domains_order
+
+ #     def get_corr_source_datasets_name_of_target_domain(self, target_domain_name):
+ #         return self.__target_source_map[target_domain_name]
+
+ #     def get_limited_target_train_dataset(self):
+ #         if len(self.__target_domains_order) > 1:
+ #             raise RuntimeError('this API is only for pass-in scenario in user-defined online DA algorithm')
+ #         return list(self.__target_datasets.values())[0]['train']
+
+ #     def get_target_domains_iterator(self, split):
+ #         for target_domain_index, target_domain_name in enumerate(self.__target_domains_order):
+ #             target_dataset = self.__target_datasets[target_domain_name]
+ #             target_domain_meta_info = self.__target_datasets_meta_info[target_domain_name]
+
+ #             yield target_domain_index, target_domain_name, target_dataset[split], target_domain_meta_info
+
+ #     # 4. permission management
+ #     def get_sub_scenario(self, source_datasets_name, source_splits, target_domains_order, target_splits):
+ #         def get_split(dataset, splits):
+ #             res = {}
+ #             for s, d in dataset.items():
+ #                 if s in splits:
+ #                     res[s] = d
+ #             return res
+
+ #         return Scenario(
+ #             config=self.__config,
+ #             source_datasets_meta_info={k: v for k, v in self.__source_datasets_meta_info.items() if k in source_datasets_name},
+ #             target_datasets_meta_info={k: v for k, v in self.__target_datasets_meta_info.items() if k in target_domains_order},
+ #             target_source_map={k: v for k, v in self.__target_source_map.items() if k in target_domains_order},
+ #             target_domains_order=target_domains_order,
+ #             source_datasets={k: get_split(v, source_splits) for k, v in self.__source_datasets.items() if k in source_datasets_name},
+ #             target_datasets={k: get_split(v, target_splits) for k, v in self.__target_datasets.items() if k in target_domains_order}
+ #         )
+
+ #     def get_only_source_sub_scenario_for_exp_tracker(self):
+ #         return self.get_sub_scenario(self.get_source_datasets_name(), ['train', 'val', 'test'], [], [])
+
+ #     def get_only_source_sub_scenario_for_alg(self):
+ #         return self.get_sub_scenario(self.get_source_datasets_name(), ['train'], [], [])
+
+ #     def get_one_da_sub_scenario_for_alg(self, target_domain_name):
+ #         return self.get_sub_scenario(self.get_corr_source_datasets_name_of_target_domain(target_domain_name),
+ #                                      ['train', 'val'], [target_domain_name], ['train'])
+
+
+ # class Scenario:
+ #     def __init__(self, config,
+
+ #                  offline_source_datasets_meta_info: Dict[str, DatasetMetaInfo],
+ #                  offline_source_datasets: Dict[str, ABDataset],
+
+ #                  online_datasets_meta_info: List[Tuple[Dict[str, DatasetMetaInfo], DatasetMetaInfo]],
+ #                  online_datasets: Dict[str, ABDataset],
+ #                  target_domains_order: List[str],
+ #                  target_source_map: Dict[str, Dict[str, str]],
+
+ #                  num_classes: int):
+
+ #         self.config = config
+
+ #         self.offline_source_datasets_meta_info = offline_source_datasets_meta_info
+ #         self.offline_source_datasets = offline_source_datasets
+
+ #         self.online_datasets_meta_info = online_datasets_meta_info
+ #         self.online_datasets = online_datasets
+
+ #         self.target_domains_order = target_domains_order
+ #         self.target_source_map = target_source_map
+
+ #         self.num_classes = num_classes
+
+ #     def get_offline_source_datasets(self, split):
+ #         return {n: d[split] for n, d in self.offline_source_datasets.items()}
+
+ #     def get_offline_source_merged_dataset(self, split):
+ #         return MergedDataset([d[split] for d in self.offline_source_datasets.values()])
+
+ #     def get_online_current_corresponding_source_datasets(self, domain_index, split):
+ #         cur_target_domain_name = self.target_domains_order[domain_index]
+ #         cur_source_datasets_name = list(self.target_source_map[cur_target_domain_name].keys())
+ #         cur_source_datasets = {n: self.online_datasets[n + '|' + cur_target_domain_name][split] for n in cur_source_datasets_name}
+ #         return cur_source_datasets
+
+ #     def get_online_current_corresponding_merged_source_dataset(self, domain_index, split):
+ #         cur_target_domain_name = self.target_domains_order[domain_index]
+ #         cur_source_datasets_name = list(self.target_source_map[cur_target_domain_name].keys())
+ #         cur_source_datasets = {n: self.online_datasets[n + '|' + cur_target_domain_name][split] for n in cur_source_datasets_name}
+ #         return MergedDataset([d for d in cur_source_datasets.values()])
+
+ #     def get_online_current_target_dataset(self, domain_index, split):
+ #         cur_target_domain_name = self.target_domains_order[domain_index]
+ #         return self.online_datasets[cur_target_domain_name][split]
+
+ #     def build_dataloader(self, dataset: ABDataset, batch_size: int, num_workers: int,
+ #                          infinite: bool, shuffle_when_finite: bool, to_iterator: bool):
+ #         if infinite:
+ #             dataloader = InfiniteDataLoader(
+ #                 dataset, None, batch_size, num_workers=num_workers)
+ #         else:
+ #             dataloader = FastDataLoader(
+ #                 dataset, batch_size, num_workers, shuffle=shuffle_when_finite)
+
+ #         if to_iterator:
+ #             dataloader = iter(dataloader)
+
+ #         return dataloader
+
+ #     def build_sub_dataset(self, dataset: ABDataset, indexes: List[int]):
+ #         from data.datasets.dataset_split import _SplitDataset
+ #         dataset.dataset = _SplitDataset(dataset.dataset, indexes)
+ #         return dataset
+
+ #     def build_index_returned_dataset(self, dataset: ABDataset):
+ #         return IndexReturnedDataset(dataset)
+
+ #     def get_config(self):
+ #         return copy.deepcopy(self.config)
+
+ #     def get_task_type(self):
+ #         return list(self.online_datasets.values())[0]['train'].task_type
+
+ #     def get_num_classes(self):
+ #         return self.num_classes
+
+
+ class Scenario:
+     def __init__(self, config, all_datasets_ignore_classes_map, all_datasets_idx_map, target_domains_order, target_source_map,
+                  all_datasets_e2e_class_to_idx_map,
+                  num_classes):
+         self.config = config
+         self.all_datasets_ignore_classes_map = all_datasets_ignore_classes_map
+         self.all_datasets_idx_map = all_datasets_idx_map
+         self.target_domains_order = target_domains_order
+         self.target_source_map = target_source_map
+         self.all_datasets_e2e_class_to_idx_map = all_datasets_e2e_class_to_idx_map
+         self.num_classes = num_classes
+         self.cur_domain_index = 0
+
+         logger.info(f'[scenario build] # classes: {num_classes}')
+         logger.debug(f'[scenario build] idx map: {all_datasets_idx_map}')
+
+     def to_json(self):
+         return dict(
+             config=self.config, all_datasets_ignore_classes_map=self.all_datasets_ignore_classes_map,
+             all_datasets_idx_map=self.all_datasets_idx_map, target_domains_order=self.target_domains_order,
+             target_source_map=self.target_source_map,
+             all_datasets_e2e_class_to_idx_map=self.all_datasets_e2e_class_to_idx_map,
+             num_classes=self.num_classes
+         )
+
+     def __str__(self):
+         return f'Scenario({self.to_json()})'
+
+     def get_offline_datasets(self, transform=None):
+         # make source datasets that contain all the union classes
+         res_offline_train_source_datasets_map = {}
+
+         from .. import get_dataset
+         data_dirs = self.config['data_dirs']
+
+         source_datasets_name = self.config['source_datasets_name']
+         res_source_datasets_map = {d: {split: get_dataset(d.split('|')[0], data_dirs[d.split('|')[0]], split,
+                                                           transform,
+                                                           self.all_datasets_ignore_classes_map[d], self.all_datasets_idx_map[d])
+                                        for split in ['train', 'val', 'test']}
+                                    for d in self.all_datasets_ignore_classes_map.keys() if d.split('|')[0] in source_datasets_name}
+
+         for source_dataset_name in self.config['source_datasets_name']:
+             source_datasets = [v for k, v in res_source_datasets_map.items() if source_dataset_name in k]
+
+             # how to merge idx map?
+             # 35 79 97
+             idx_maps = [d['train'].idx_map for d in source_datasets]
+             ignore_classes_list = [d['train'].ignore_classes for d in source_datasets]
+
+             union_idx_map = {}
+             for idx_map in idx_maps:
+                 for k, v in idx_map.items():
+                     if k not in union_idx_map:
+                         union_idx_map[k] = v
+                     else:
+                         assert union_idx_map[k] == v
+
+             # a class stays ignored only if every pairing of this source ignores it (set intersection)
+             union_ignore_classes = reduce(lambda res, cur: res & set(cur), ignore_classes_list, set(ignore_classes_list[0]))
+             assert len(union_ignore_classes) + len(union_idx_map) == len(source_datasets[0]['train'].raw_classes)
+
+             logger.info(f'[scenario build] {source_dataset_name} has {len(union_idx_map)} classes in offline training')
+
+             d = source_dataset_name
+             res_offline_train_source_datasets_map[d] = {split: get_dataset(d, data_dirs[d], split,
+                                                                            transform,
+                                                                            union_ignore_classes, union_idx_map)
+                                                         for split in ['train', 'val', 'test']}
+
+         return res_offline_train_source_datasets_map
+
+     def get_offline_datasets_args(self):
+         # make source datasets that contain all the union classes
+         res_offline_train_source_datasets_map = {}
+
+         from .. import get_dataset
+         data_dirs = self.config['data_dirs']
+
+         source_datasets_name = self.config['source_datasets_name']
+         res_source_datasets_map = {d: {split: get_dataset(d.split('|')[0], data_dirs[d.split('|')[0]], split,
+                                                           None,
+                                                           self.all_datasets_ignore_classes_map[d], self.all_datasets_idx_map[d])
+                                        for split in ['train', 'val', 'test']}
+                                    for d in self.all_datasets_ignore_classes_map.keys() if d.split('|')[0] in source_datasets_name}
+
+         for source_dataset_name in self.config['source_datasets_name']:
+             source_datasets = [v for k, v in res_source_datasets_map.items() if source_dataset_name in k]
+
+             # how to merge idx map?
+             # 35 79 97
+             idx_maps = [d['train'].idx_map for d in source_datasets]
+             ignore_classes_list = [d['train'].ignore_classes for d in source_datasets]
+
+             union_idx_map = {}
+             for idx_map in idx_maps:
+                 for k, v in idx_map.items():
+                     if k not in union_idx_map:
+                         union_idx_map[k] = v
+                     else:
+                         assert union_idx_map[k] == v
+
+             union_ignore_classes = reduce(lambda res, cur: res & set(cur), ignore_classes_list, set(ignore_classes_list[0]))
+             assert len(union_ignore_classes) + len(union_idx_map) == len(source_datasets[0]['train'].raw_classes)
+
+             logger.info(f'[scenario build] {source_dataset_name} has {len(union_idx_map)} classes in offline training')
+
+             d = source_dataset_name
+             res_offline_train_source_datasets_map[d] = {split: dict(dataset_name=d, root_dir=data_dirs[d], split=split,
+                                                                     transform=None,
+                                                                     ignore_classes=union_ignore_classes, idx_map=union_idx_map)
+                                                         for split in ['train', 'val', 'test']}
+
+         return res_offline_train_source_datasets_map
+
+         # for d in source_datasets_name:
+         #     source_dataset_with_max_num_classes = None
+
+         #     for ed_name, ed in res_source_datasets_map.items():
+         #         if not ed_name.startswith(d):
+         #             continue
+
+         #         if source_dataset_with_max_num_classes is None:
+         #             source_dataset_with_max_num_classes = ed
+         #             res_offline_train_source_datasets_map_names[d] = ed_name
+
+         #         if len(ed['train'].ignore_classes) < len(source_dataset_with_max_num_classes['train'].ignore_classes):
+         #             source_dataset_with_max_num_classes = ed
+         #             res_offline_train_source_datasets_map_names[d] = ed_name
+
+         #     res_offline_train_source_datasets_map[d] = source_dataset_with_max_num_classes
+
+         # return res_offline_train_source_datasets_map
+
+     def get_online_ith_domain_datasets_args_for_inference(self, domain_index):
+         target_dataset_name = self.target_domains_order[domain_index]
+         # dataset_name: Any, root_dir: Any, split: Any, transform: Any | None = None, ignore_classes: Any = [], idx_map: Any | None = None
+
+         if 'MM-CityscapesDet' in self.target_domains_order or 'CityscapesDet' in self.target_domains_order or 'BaiduPersonDet' in self.target_domains_order:
+             logger.info(f'use test split for inference (Det workloads only)')
+             split = 'test'
+         else:
+             split = 'train'
+
+         return dict(dataset_name=target_dataset_name,
+                     root_dir=self.config['data_dirs'][target_dataset_name],
+                     split=split,
+                     transform=None,
+                     ignore_classes=self.all_datasets_ignore_classes_map[target_dataset_name],
+                     idx_map=self.all_datasets_idx_map[target_dataset_name])
+
+     def get_online_ith_domain_datasets_args_for_training(self, domain_index):
+         target_dataset_name = self.target_domains_order[domain_index]
+         source_datasets_name = list(self.target_source_map[target_dataset_name].keys())
+
+         res = {}
+         # dataset_name: Any, root_dir: Any, split: Any, transform: Any | None = None, ignore_classes: Any = [], idx_map: Any | None = None
+         res[target_dataset_name] = {split: dict(dataset_name=target_dataset_name,
+                                                 root_dir=self.config['data_dirs'][target_dataset_name],
+                                                 split=split,
+                                                 transform=None,
+                                                 ignore_classes=self.all_datasets_ignore_classes_map[target_dataset_name],
+                                                 idx_map=self.all_datasets_idx_map[target_dataset_name]) for split in ['train', 'val']}
+         for d in source_datasets_name:
+             res[d] = {split: dict(dataset_name=d,
+                                   root_dir=self.config['data_dirs'][d],
+                                   split=split,
+                                   transform=None,
+                                   ignore_classes=self.all_datasets_ignore_classes_map[d + '|' + target_dataset_name],
+                                   idx_map=self.all_datasets_idx_map[d + '|' + target_dataset_name]) for split in ['train', 'val']}
+
+         return res
+
+     def get_online_cur_domain_datasets_args_for_inference(self):
+         return self.get_online_ith_domain_datasets_args_for_inference(self.cur_domain_index)
+
+     def get_online_cur_domain_datasets_args_for_training(self):
+         return self.get_online_ith_domain_datasets_args_for_training(self.cur_domain_index)
+
+     def get_online_cur_domain_datasets_for_training(self, transform=None):
+         res = {}
+         datasets_args = self.get_online_ith_domain_datasets_args_for_training(self.cur_domain_index)
+         for dataset_name, dataset_args in datasets_args.items():
+             res[dataset_name] = {}
+             for split, args in dataset_args.items():
+                 if transform is not None:
+                     args['transform'] = transform
+                 dataset = get_dataset(**args)
+                 res[dataset_name][split] = dataset
+         return res
+
+     def get_online_cur_domain_datasets_for_inference(self, transform=None):
+         datasets_args = self.get_online_ith_domain_datasets_args_for_inference(self.cur_domain_index)
+         if transform is not None:
+             datasets_args['transform'] = transform
+         return get_dataset(**datasets_args)
+
+     def get_online_cur_domain_samples_for_training(self, num_samples, transform=None, collate_fn=None):
+         dataset = self.get_online_cur_domain_datasets_for_training(transform=transform)
+         dataset = dataset[self.target_domains_order[self.cur_domain_index]]['train']
+         return next(iter(build_dataloader(dataset, num_samples, 0, True, None, collate_fn=collate_fn)))[0]
+
+     def next_domain(self):
+         self.cur_domain_index += 1
+
+
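
For orientation, a minimal sketch of how a driver loop might consume the online domain sequence through the Scenario API above (the `scenario` object and the adaptation step are placeholders, not code from this commit):

    # hedged usage sketch: `scenario` is a built DA Scenario (see data/build/build.py)
    for _ in range(len(scenario.target_domains_order)):
        infer_set = scenario.get_online_cur_domain_datasets_for_inference()
        train_sets = scenario.get_online_cur_domain_datasets_for_training()
        # ... run inference / adaptation on the current domain here ...
        scenario.next_domain()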
data/build_cl/__pycache__/build.cpython-38.pyc ADDED
Binary file (4.36 kB).

data/build_cl/__pycache__/scenario.cpython-38.pyc ADDED
Binary file (5.48 kB).

data/build_cl/build.py ADDED
@@ -0,0 +1,161 @@
+from typing import Dict, List, Optional, Type, Union
+from ..datasets.ab_dataset import ABDataset
+# from benchmark.data.visualize import visualize_classes_in_object_detection
+# from benchmark.scenario.val_domain_shift import get_val_domain_shift_transform
+from ..dataset import get_dataset
+import copy
+from torchvision.transforms import Compose
+from ..datasets.registery import static_dataset_registery
+from ..build.scenario import Scenario as DAScenario
+from copy import deepcopy
+from utils.common.log import logger
+import random
+from .scenario import _ABDatasetMetaInfo, Scenario
+
+
+def _check(source_datasets_meta_info: List[_ABDatasetMetaInfo], target_datasets_meta_info: List[_ABDatasetMetaInfo]):
+    # requirements for simplicity:
+    # 1. no shared class between source datasets
+
+    source_datasets_class = [i.classes for i in source_datasets_meta_info]
+    for ci1, c1 in enumerate(source_datasets_class):
+        for ci2, c2 in enumerate(source_datasets_class):
+            if ci1 == ci2:
+                continue
+
+            c1_name = source_datasets_meta_info[ci1].name
+            c2_name = source_datasets_meta_info[ci2].name
+            intersection = set(c1).intersection(set(c2))
+            assert len(intersection) == 0, f'{c1_name} has intersection with {c2_name}: {intersection}'
+
+
+def build_cl_scenario(
+    da_scenario: DAScenario,
+    target_datasets_name: List[str],
+    num_classes_per_task: int,
+    max_num_tasks: int,
+    data_dirs,
+    sanity_check=False
+):
+    config = deepcopy(locals())
+
+    source_datasets_idx_map = {}
+    source_class_idx_max = 0
+
+    for sd in da_scenario.config['source_datasets_name']:
+        da_scenario_idx_map = None
+        for k, v in da_scenario.all_datasets_idx_map.items():
+            if k.startswith(sd):
+                da_scenario_idx_map = v
+                break
+
+        source_datasets_idx_map[sd] = da_scenario_idx_map
+        source_class_idx_max = max(source_class_idx_max, max(list(da_scenario_idx_map.values())))
+
+    # target class indices start right after the source class indices
+    target_class_idx_start = source_class_idx_max + 1
+
+    target_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:], None, None) for d in target_datasets_name]
+
+    task_datasets_seq = []
+
+    num_tasks_per_dataset = {}
+
+    for td_info_i, td_info in enumerate(target_datasets_meta_info):
+
+        if td_info_i >= 1:
+            # reuse the tasks of the earlier occurrence of this dataset
+            # (target entries after the first are expected to repeat earlier names)
+            for _td_info_i, _td_info in enumerate(target_datasets_meta_info[0: td_info_i]):
+                if _td_info.name == td_info.name:
+                    # class_idx_offset = sum([len(t.classes) for t in target_datasets_meta_info[0: td_info_i]])
+                    print(len(task_datasets_seq))
+
+                    task_index_offset = sum([v if __i < _td_info_i else 0 for __i, v in enumerate(num_tasks_per_dataset.values())])
+
+                    task_datasets_seq += task_datasets_seq[task_index_offset: task_index_offset + num_tasks_per_dataset[_td_info_i]]
+                    print(len(task_datasets_seq))
+                    break
+            continue
+
+        td_classes = td_info.classes
+        num_tasks_per_dataset[td_info_i] = 0
+
+        # chunk this dataset's classes into tasks of num_classes_per_task classes each
+        for ci in range(0, len(td_classes), num_classes_per_task):
+            task_i = ci // num_classes_per_task
+            task_datasets_seq += [_ABDatasetMetaInfo(
+                f'{td_info.name}|task-{task_i}|ci-{ci}-{ci + num_classes_per_task - 1}',
+                td_classes[ci: ci + num_classes_per_task],
+                td_info.task_type,
+                td_info.object_type,
+                td_info.class_aliases,
+                td_info.shift_type,
+
+                td_classes[:ci] + td_classes[ci + num_classes_per_task:],
+                {cii: cii + target_class_idx_start for cii in range(ci, ci + num_classes_per_task)}
+            )]
+            num_tasks_per_dataset[td_info_i] += 1
+
+            if ci + num_classes_per_task < len(td_classes) - 1:
+                task_datasets_seq += [_ABDatasetMetaInfo(
+                    f'{td_info.name}-task-{task_i + 1}|ci-{ci}-{ci + num_classes_per_task - 1}',
+                    td_classes[ci: len(td_classes)],
+                    td_info.task_type,
+                    td_info.object_type,
+                    td_info.class_aliases,
+                    td_info.shift_type,
+
+                    td_classes[:ci],
+                    {cii: cii + target_class_idx_start for cii in range(ci, len(td_classes))}
+                )]
+                num_tasks_per_dataset[td_info_i] += 1
+
+        target_class_idx_start += len(td_classes)
+
+    if len(task_datasets_seq) < max_num_tasks:
+        print(len(task_datasets_seq), max_num_tasks)
+        raise RuntimeError()
+
+    task_datasets_seq = task_datasets_seq[0: max_num_tasks]
+    target_class_idx_start = max([max(list(td.idx_map.values())) + 1 for td in task_datasets_seq])
+
+    scenario = Scenario(config, task_datasets_seq, target_class_idx_start, source_class_idx_max + 1, data_dirs)
+
+    if sanity_check:
+        selected_tasks_index = []
+        for task_index, _ in enumerate(scenario.target_tasks_order):
+            cur_datasets = scenario.get_cur_task_train_datasets()
+
+            if len(cur_datasets) < 300:
+                # empty_tasks_index += [task_index]
+                # while True:
+                #     replaced_task_index = random.randint(0, task_index - 1)  # ensure no random
+                replaced_task_index = task_index // 2
+                assert replaced_task_index != task_index
+                while replaced_task_index in selected_tasks_index:
+                    replaced_task_index += 1
+
+                task_datasets_seq[task_index] = deepcopy(task_datasets_seq[replaced_task_index])
+                selected_tasks_index += [replaced_task_index]
+
+                logger.warning(f'replace {task_index}-th task with {replaced_task_index}-th task')
+
+            # print(task_index, [t.name for t in task_datasets_seq])
+
+            scenario.next_task()
+
+        # print([t.name for t in task_datasets_seq])
+
+        if len(selected_tasks_index) > 0:
+            target_class_idx_start = max([max(list(td.idx_map.values())) + 1 for td in task_datasets_seq])
+            scenario = Scenario(config, task_datasets_seq, target_class_idx_start, source_class_idx_max + 1, data_dirs)
+
+            for task_index, _ in enumerate(scenario.target_tasks_order):
+                cur_datasets = scenario.get_cur_task_train_datasets()
+                logger.info(f'task {task_index}, len {len(cur_datasets)}')
+                assert len(cur_datasets) > 0
+
+                scenario.next_task()
+
+        # rebuild a fresh scenario since the checks above consumed the task cursor
+        scenario = Scenario(config, task_datasets_seq, target_class_idx_start, source_class_idx_max + 1, data_dirs)
+
+    return scenario
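
To make the class-chunking above concrete, a small self-contained sketch of how per-task `idx_map`s are laid out after the source classes (the numbers are illustrative only):

    num_classes_per_task = 3
    source_class_idx_max = 9                  # e.g. 10 source classes occupy global indices 0..9
    target_class_idx_start = source_class_idx_max + 1
    td_classes = [f'c{i}' for i in range(7)]  # a hypothetical 7-class target dataset
    for ci in range(0, len(td_classes), num_classes_per_task):
        idx_map = {cii: cii + target_class_idx_start
                   for cii in range(ci, min(ci + num_classes_per_task, len(td_classes)))}
        print(td_classes[ci: ci + num_classes_per_task], idx_map)
    # ['c0', 'c1', 'c2'] {0: 10, 1: 11, 2: 12}
    # ['c3', 'c4', 'c5'] {3: 13, 4: 14, 5: 15}
    # ['c6'] {6: 16}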
data/build_cl/scenario.py ADDED
@@ -0,0 +1,146 @@
+import enum
+from functools import reduce
+from typing import Dict, List, Tuple
+import numpy as np
+import copy
+from utils.common.log import logger
+from ..datasets.ab_dataset import ABDataset
+from ..dataloader import FastDataLoader, InfiniteDataLoader, build_dataloader
+from data import get_dataset, MergedDataset, Scenario as DAScenario
+
+
+class _ABDatasetMetaInfo:
+    def __init__(self, name, classes, task_type, object_type, class_aliases, shift_type, ignore_classes, idx_map):
+        self.name = name
+        self.classes = classes
+        self.class_aliases = class_aliases
+        self.shift_type = shift_type
+        self.task_type = task_type
+        self.object_type = object_type
+
+        self.ignore_classes = ignore_classes
+        self.idx_map = idx_map
+
+    def __repr__(self) -> str:
+        return f'({self.name}, {self.classes}, {self.idx_map})'
+
+
+class Scenario:
+    def __init__(self, config, target_datasets_info: List[_ABDatasetMetaInfo], num_classes: int, num_source_classes: int, data_dirs):
+        self.config = config
+        self.target_datasets_info = target_datasets_info
+        self.num_classes = num_classes
+        self.cur_task_index = 0
+        self.num_source_classes = num_source_classes
+        self.cur_class_offset = num_source_classes
+        self.data_dirs = data_dirs
+
+        self.target_tasks_order = [i.name for i in self.target_datasets_info]
+        # NOTE: this counts classes across all tasks, not the number of tasks
+        self.num_tasks_to_be_learn = sum([len(i.classes) for i in target_datasets_info])
+
+        logger.info(f'[scenario build] # classes: {num_classes}, # tasks to be learnt: {len(target_datasets_info)}, '
+                    f'# classes per task: {config["num_classes_per_task"]}')
+
+    def to_json(self):
+        config = copy.deepcopy(self.config)
+        config['da_scenario'] = config['da_scenario'].to_json()
+        target_datasets_info = [str(i) for i in self.target_datasets_info]
+        return dict(
+            config=config, target_datasets_info=target_datasets_info,
+            num_classes=self.num_classes
+        )
+
+    def __str__(self):
+        return f'Scenario({self.to_json()})'
+
+    def get_cur_class_offset(self):
+        return self.cur_class_offset
+
+    def get_cur_num_class(self):
+        return len(self.target_datasets_info[self.cur_task_index].classes)
+
+    def get_nc_per_task(self):
+        return len(self.target_datasets_info[0].classes)
+
+    def next_task(self):
+        self.cur_class_offset += len(self.target_datasets_info[self.cur_task_index].classes)
+        self.cur_task_index += 1
+
+        print(f'now, cur task: {self.cur_task_index}, cur_class_offset: {self.cur_class_offset}')
+
+    def get_cur_task_datasets(self):
+        dataset_info = self.target_datasets_info[self.cur_task_index]
+        dataset_name = dataset_info.name.split('|')[0]
+
+        # train: the current task only; val/test: all tasks seen so far, merged
+        res = {**{split: get_dataset(dataset_name=dataset_name,
+                                     root_dir=self.data_dirs[dataset_name],
+                                     split=split,
+                                     transform=None,
+                                     ignore_classes=dataset_info.ignore_classes,
+                                     idx_map=dataset_info.idx_map) for split in ['train']},
+
+               **{split: MergedDataset([get_dataset(dataset_name=dataset_name,
+                                                    root_dir=self.data_dirs[dataset_name],
+                                                    split=split,
+                                                    transform=None,
+                                                    ignore_classes=di.ignore_classes,
+                                                    idx_map=di.idx_map) for di in self.target_datasets_info[0: self.cur_task_index + 1]])
+                  for split in ['val', 'test']}
+               }
+
+        # if len(res['train']) < 200 or len(res['val']) < 200 or len(res['test']) < 200:
+        #     return None
+
+        if len(res['train']) < 1000:
+            res['train'] = MergedDataset([res['train']] * 5)
+            logger.info('aug train dataset')
+        if len(res['val']) < 1000:
+            res['val'] = MergedDataset(res['val'].datasets * 5)
+            logger.info('aug val dataset')
+        if len(res['test']) < 1000:
+            res['test'] = MergedDataset(res['test'].datasets * 5)
+            logger.info('aug test dataset')
+        # da_scenario: DAScenario = self.config['da_scenario']
+        # offline_datasets = da_scenario.get_offline_datasets()
+
+        for k, v in res.items():
+            logger.info(f'{k} dataset: {len(v)}')
+
+        # new_val_datasets = [
+        #     *[d['val'] for d in offline_datasets.values()],
+        #     res['val']
+        # ]
+        # res['val'] = MergedDataset(new_val_datasets)
+
+        # new_test_datasets = [
+        #     *[d['test'] for d in offline_datasets.values()],
+        #     res['test']
+        # ]
+        # res['test'] = MergedDataset(new_test_datasets)
+
+        return res
+
+    def get_cur_task_train_datasets(self):
+        dataset_info = self.target_datasets_info[self.cur_task_index]
+        dataset_name = dataset_info.name.split('|')[0]
+
+        res = get_dataset(dataset_name=dataset_name,
+                          root_dir=self.data_dirs[dataset_name],
+                          split='train',
+                          transform=None,
+                          ignore_classes=dataset_info.ignore_classes,
+                          idx_map=dataset_info.idx_map)
+
+        return res
+
+    def get_online_cur_task_samples_for_training(self, num_samples):
+        dataset = self.get_cur_task_datasets()
+        dataset = dataset['train']
+        return next(iter(build_dataloader(dataset, num_samples, 0, True, None)))[0]
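
A minimal sketch of driving this CL Scenario task by task (the training step is a placeholder; `scenario` is assumed to come from `build_cl_scenario` above):

    for task_index, _ in enumerate(scenario.target_tasks_order):
        datasets = scenario.get_cur_task_datasets()   # {'train': ..., 'val': ..., 'test': ...}
        offset = scenario.get_cur_class_offset()      # first global class index of this task
        num_cls = scenario.get_cur_num_class()
        # ... train on datasets['train'], slicing logits to [offset, offset + num_cls) ...
        scenario.next_task()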
data/convert_all_load_to_single_load.py ADDED
@@ -0,0 +1,56 @@
+"""
+Convert a dataset that loads all images into memory before training
+into one that loads each image from disk on demand during training.
+"""
+
+
+from torchvision.datasets import CIFAR10, STL10, MNIST, USPS, SVHN
+import os
+import tqdm
+
+
+def convert(datasets_of_split, new_dir):
+    img_idx = {}
+
+    for d in datasets_of_split:
+        for x, y in tqdm.tqdm(d, total=len(d), dynamic_ncols=True):
+            # keep a running per-class counter to name the image files
+            if y not in img_idx:
+                img_idx[y] = -1
+            img_idx[y] += 1
+
+            p = os.path.join(new_dir, f'{y:06d}', f'{img_idx[y]:06d}' + '.png')
+            os.makedirs(os.path.dirname(p), exist_ok=True)
+
+            x.save(p)
+
+
+if __name__ == '__main__':
+    # convert(
+    #     [CIFAR10('/data/zql/datasets/CIFAR10', True, download=True), CIFAR10('/data/zql/datasets/CIFAR10', False, download=True)],
+    #     '/data/zql/datasets/CIFAR10-single'
+    # )
+
+    # convert(
+    #     [STL10('/data/zql/datasets/STL10', 'train', download=False), STL10('/data/zql/datasets/STL10', 'test', download=False)],
+    #     '/data/zql/datasets/STL10-single'
+    # )
+
+    # convert(
+    #     [MNIST('/data/zql/datasets/MNIST', True, download=True), MNIST('/data/zql/datasets/MNIST', False, download=True)],
+    #     '/data/zql/datasets/MNIST-single'
+    # )
+
+    convert(
+        [SVHN('/data/zql/datasets/SVHN', 'train', download=True), SVHN('/data/zql/datasets/SVHN', 'test', download=True)],
+        '/data/zql/datasets/SVHN-single'
+    )
+
+    # convert(
+    #     [USPS('/data/zql/datasets/USPS', True, download=False), USPS('/data/zql/datasets/USPS', False, download=False)],
+    #     '/data/zql/datasets/USPS-single'
+    # )
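
The converter writes an ImageFolder-style tree, one sub-directory per label, e.g. (illustrative):

    /data/zql/datasets/SVHN-single/
        000000/000000.png
        000000/000001.png
        000001/000000.png
        ...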
data/convert_det_dataset_to_cls.py ADDED
@@ -0,0 +1,55 @@
+from data import ABDataset
+from utils.common.data_record import read_json, write_json
+from PIL import Image
+import os
+from utils.common.file import ensure_dir
+import numpy as np
+from itertools import groupby
+from skimage import morphology, measure
+from scipy import misc
+import tqdm
+from PIL import ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+import shutil
+
+
+def convert_det_dataset_to_cls(coco_ann_json_path, data_dir, target_data_dir, min_img_size=224):
+    # (renamed from `convert_det_dataset_to_det`: the function crops detection boxes
+    # out of a COCO-style dataset to build a classification dataset)
+
+    coco_ann = read_json(coco_ann_json_path)
+
+    img_id_to_path = {}
+    for img in coco_ann['images']:
+        img_id_to_path[img['id']] = os.path.join(data_dir, img['file_name'])
+
+    classes_imgs_id_map = {}
+    for ann in tqdm.tqdm(coco_ann['annotations'], total=len(coco_ann['annotations']), dynamic_ncols=True):
+        img_id = ann['image_id']
+        img_path = img_id_to_path[img_id]
+        img = Image.open(img_path)
+
+        # COCO bbox is (x, y, w, h); skip boxes smaller than min_img_size
+        bbox = ann['bbox']
+        if bbox[2] < min_img_size or bbox[3] < min_img_size:
+            continue
+
+        # convert to (left, top, right, bottom) for PIL's crop()
+        bbox = [bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]
+
+        class_idx = str(ann['category_id'])
+        if class_idx not in classes_imgs_id_map.keys():
+            classes_imgs_id_map[class_idx] = 0
+        target_cropped_img_path = os.path.join(target_data_dir, class_idx,
+                                               f'{classes_imgs_id_map[class_idx]}.{img_path.split(".")[-1]}')
+        classes_imgs_id_map[class_idx] += 1
+
+        ensure_dir(target_cropped_img_path)
+        img.crop(bbox).save(target_cropped_img_path)
+
+
+if __name__ == '__main__':
+    convert_det_dataset_to_cls(
+        coco_ann_json_path='/data/zql/datasets/coco2017/train2017/coco_ann.json',
+        data_dir='/data/zql/datasets/coco2017/train2017',
+        target_data_dir='/data/zql/datasets/coco2017_for_cls_task',
+        min_img_size=224
+    )
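
Since COCO and PIL use different box conventions, a one-line sanity check of the conversion performed above (values arbitrary):

    x, y, w, h = 10, 20, 100, 50                    # COCO bbox: top-left corner plus size
    left, top, right, bottom = x, y, x + w, y + h   # PIL crop box
    assert (left, top, right, bottom) == (10, 20, 110, 70)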
data/convert_seg_dataset_to_cls.py ADDED
@@ -0,0 +1,324 @@
+from data import ABDataset
+from utils.common.data_record import read_json
+from PIL import Image
+import os
+from utils.common.file import ensure_dir
+import numpy as np
+from itertools import groupby
+from skimage import morphology, measure
+from scipy import misc
+import tqdm
+from PIL import ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+import shutil
+
+
+def convert_seg_dataset_to_cls(seg_imgs_path, seg_labels_path, target_cls_data_dir, ignore_classes_idx, thread_i, min_img_size=224, label_after_hook=lambda x: x):
+    """
+    Reference: https://blog.csdn.net/lizaijinsheng/article/details/119889946
+
+    NOTE:
+    The background class should not be considered.
+    However, if a seg dataset has only one valid class, the generated cls dataset also has only
+    one class and the cls accuracy will trivially be 100%. We never use the generated cls dataset
+    alone, so this is acceptable.
+    """
+    assert len(seg_imgs_path) == len(seg_labels_path)
+
+    classes_imgs_id_map = {}
+
+    for seg_img_path, seg_label_path in tqdm.tqdm(zip(seg_imgs_path, seg_labels_path), total=len(seg_imgs_path),
+                                                  dynamic_ncols=True, leave=False, desc=f'thread {thread_i}'):
+
+        try:
+            seg_img = Image.open(seg_img_path)
+            seg_label = Image.open(seg_label_path).convert('L')
+            seg_label = np.array(seg_label)
+            seg_label = label_after_hook(seg_label)
+        except Exception as e:
+            print(e)
+            print(f'file {seg_img_path} error, abort')
+            exit()
+
+        this_img_classes = set(seg_label.reshape(-1).tolist())
+
+        for class_idx in this_img_classes:
+            if class_idx in ignore_classes_idx:
+                continue
+
+            if class_idx not in classes_imgs_id_map.keys():
+                classes_imgs_id_map[class_idx] = 0
+
+            # binary mask of the current class; drop tiny blobs, then crop each connected region
+            mask = np.zeros((seg_label.shape[0], seg_label.shape[1]), dtype=np.uint8)
+            mask[seg_label == class_idx] = 1
+            mask_without_small = morphology.remove_small_objects(mask, min_size=10, connectivity=2)
+            label_image = measure.label(mask_without_small)
+
+            for region in measure.regionprops(label_image):
+                bbox = region.bbox  # (top, left, bottom, right)
+                bbox = [bbox[1], bbox[0], bbox[3], bbox[2]]  # (left, top, right, bottom)
+
+                width, height = bbox[2] - bbox[0], bbox[3] - bbox[1]
+                if width < min_img_size or height < min_img_size:
+                    continue
+
+                target_cropped_img_path = os.path.join(target_cls_data_dir, str(class_idx),
+                                                       f'{classes_imgs_id_map[class_idx]}.{seg_img_path.split(".")[-1]}')
+                ensure_dir(target_cropped_img_path)
+                seg_img.crop(bbox).save(target_cropped_img_path)
+
+                classes_imgs_id_map[class_idx] += 1
+
+    num_cls_imgs = 0
+    for k, v in classes_imgs_id_map.items():
+        # print(f'# class {k}: {v + 1}')
+        num_cls_imgs += v
+    # print(f'total: {num_cls_imgs}')
+
+    return classes_imgs_id_map
+
+
+from concurrent.futures import ThreadPoolExecutor
+
+
+# def convert_seg_dataset_to_cls_multi_thread(seg_imgs_path, seg_labels_path, target_cls_data_dir, ignore_classes_idx, num_threads):
+#     if os.path.exists(target_cls_data_dir):
+#         shutil.rmtree(target_cls_data_dir)
+
+#     assert len(seg_imgs_path) == len(seg_labels_path)
+#     n = len(seg_imgs_path) // num_threads
+
+#     pool = ThreadPoolExecutor(max_workers=num_threads)
+#     futures = []
+#     for thread_i in range(num_threads):
+#         future = pool.submit(convert_seg_dataset_to_cls, *(seg_imgs_path[thread_i * n: (thread_i + 1) * n],
+#                                                            seg_labels_path[thread_i * n: (thread_i + 1) * n],
+#                                                            target_cls_data_dir, ignore_classes_idx, thread_i))
+#         futures += [future]
+
+#     futures += [
+#         pool.submit(convert_seg_dataset_to_cls, *(seg_imgs_path[(thread_i + 1) * n:],
+#                                                   seg_labels_path[(thread_i + 1) * n:],
+#                                                   target_cls_data_dir, ignore_classes_idx, thread_i))
+#     ]
+
+#     for f in futures:
+#         f.done()
+
+#     res = []
+#     for f in futures:
+#         res += [f.result()]
+#         print(res[-1])
+
+#     res_dist = {}
+#     for r in res:
+#         for k, v in r.items():
+#             if k in res_dist.keys():
+#                 res_dist[k] += v
+#             else:
+#                 res_dist[k] = v
+
+#     print('results:')
+#     print(res_dist)
+
+#     pool.shutdown()
+
+
+import random
+
+
+def random_crop_aug(target_dir):
+    for class_dir in os.listdir(target_dir):
+        class_dir = os.path.join(target_dir, class_dir)
+
+        for img_path in os.listdir(class_dir):
+            img_path = os.path.join(class_dir, img_path)
+
+            img = Image.open(img_path)
+
+            w, h = img.width, img.height
+
+            # save 5 random crops that keep at least the central 3/5 of each side
+            for ri in range(5):
+                img.crop(
+                    [
+                        random.randint(0, w // 5),
+                        random.randint(0, h // 5),
+                        random.randint(w // 5 * 4, w),
+                        random.randint(h // 5 * 4, h)
+                    ]
+                ).save(
+                    os.path.join(os.path.dirname(img_path), f'randaug_{ri}_' + os.path.basename(img_path))
+                )
+
+
+if __name__ == '__main__':
+    # SuperviselyPerson
+    # root_dir = '/data/zql/datasets/supervisely_person/Supervisely Person Dataset'
+
+    # images_path, labels_path = [], []
+    # for p in os.listdir(root_dir):
+    #     if p.startswith('ds'):
+    #         p1 = os.path.join(root_dir, p, 'img')
+    #         images_path += [(p, os.path.join(p1, n)) for n in os.listdir(p1)]
+    # for dsi, img_p in images_path:
+    #     target_p = os.path.join(root_dir, p, dsi, img_p.split('/')[-1])
+    #     labels_path += [target_p]
+    # images_path = [i[1] for i in images_path]
+
+    # target_dir = '/data/zql/datasets/supervisely_person_for_cls_task'
+    # if os.path.exists(target_dir):
+    #     shutil.rmtree(target_dir)
+    # convert_seg_dataset_to_cls(
+    #     seg_imgs_path=images_path,
+    #     seg_labels_path=labels_path,
+    #     target_cls_data_dir=target_dir,
+    #     ignore_classes_idx=[0, 2],
+    #     thread_i=0
+    # )
+
+    # random_crop_aug('/data/zql/datasets/supervisely_person_for_cls_task')
+
+    # GTA5
+    # root_dir = '/data/zql/datasets/GTA-ls-copy/GTA5'
+    # images_path, labels_path = [], []
+    # for p in os.listdir(os.path.join(root_dir, 'images')):
+    #     p = os.path.join(root_dir, 'images', p)
+    #     if not p.endswith('png'):
+    #         continue
+    #     images_path += [p]
+    #     labels_path += [p.replace('images', 'labels_gt')]
+
+    # target_dir = '/data/zql/datasets/gta5_for_cls_task'
+    # if os.path.exists(target_dir):
+    #     shutil.rmtree(target_dir)
+
+    # convert_seg_dataset_to_cls(
+    #     seg_imgs_path=images_path,
+    #     seg_labels_path=labels_path,
+    #     target_cls_data_dir=target_dir,
+    #     ignore_classes_idx=[],
+    #     thread_i=0
+    # )
+
+    # cityscapes
+    # root_dir = '/data/zql/datasets/cityscape/'
+
+    # def _get_target_suffix(mode: str, target_type: str) -> str:
+    #     if target_type == 'instance':
+    #         return '{}_instanceIds.png'.format(mode)
+    #     elif target_type == 'semantic':
+    #         return '{}_labelIds.png'.format(mode)
+    #     elif target_type == 'color':
+    #         return '{}_color.png'.format(mode)
+    #     else:
+    #         return '{}_polygons.json'.format(mode)
+
+    # images_path, labels_path = [], []
+    # split = 'train'
+    # images_dir = os.path.join(root_dir, 'leftImg8bit', split)
+    # targets_dir = os.path.join(root_dir, 'gtFine', split)
+    # for city in os.listdir(images_dir):
+    #     img_dir = os.path.join(images_dir, city)
+    #     target_dir = os.path.join(targets_dir, city)
+    #     for file_name in os.listdir(img_dir):
+    #         target_types = []
+    #         for t in ['semantic']:
+    #             target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
+    #                                          _get_target_suffix('gtFine', t))
+    #             target_types.append(os.path.join(target_dir, target_name))
+
+    #         images_path.append(os.path.join(img_dir, file_name))
+    #         labels_path.append(target_types[0])
+
+    # print(images_path[0: 5], '\n', labels_path[0: 5])
+
+    # target_dir = '/data/zql/datasets/cityscapes_for_cls_task'
+    # if os.path.exists(target_dir):
+    #     shutil.rmtree(target_dir)
+    # convert_seg_dataset_to_cls(
+    #     seg_imgs_path=images_path,
+    #     seg_labels_path=labels_path,
+    #     target_cls_data_dir=target_dir,
+    #     ignore_classes_idx=[],
+    #     thread_i=0
+    # )
+
+    # ignore_target_dir = '/data/zql/datasets/cityscapes_for_cls_task_ignored'
+
+    # ignore_label = 255
+    # raw_idx_map_in_y_transform = {-1: ignore_label, 0: ignore_label, 1: ignore_label, 2: ignore_label,
+    #                               3: ignore_label, 4: ignore_label, 5: ignore_label, 6: ignore_label,
+    #                               7: 0, 8: 1, 9: ignore_label, 10: ignore_label, 11: 2, 12: 3, 13: 4,
+    #                               14: ignore_label, 15: ignore_label, 16: ignore_label, 17: 5,
+    #                               18: ignore_label, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14,
+    #                               28: 15, 29: ignore_label, 30: ignore_label, 31: 16, 32: 17, 33: 18}
+    # ignore_classes_idx = [k for k, v in raw_idx_map_in_y_transform.items() if v == ignore_label]
+    # ignore_classes_idx = sorted(ignore_classes_idx)
+
+    # # move ignored class dirs aside; rename the kept ones to their mapped indices
+    # for class_dir in os.listdir(target_dir):
+    #     if int(class_dir) in ignore_classes_idx:
+    #         shutil.move(
+    #             os.path.join(target_dir, class_dir),
+    #             os.path.join(ignore_target_dir, class_dir)
+    #         )
+    #     else:
+    #         shutil.move(
+    #             os.path.join(target_dir, class_dir),
+    #             os.path.join(target_dir, str(raw_idx_map_in_y_transform[int(class_dir)]))
+    #         )
+
+    # baidu person
+    # root_dir = '/data/zql/datasets/baidu_person/clean_images/'
+
+    # images_path, labels_path = [], []
+    # for p in os.listdir(os.path.join(root_dir, 'images')):
+    #     images_path += [os.path.join(root_dir, 'images', p)]
+    #     labels_path += [os.path.join(root_dir, 'profiles', p.split('.')[0] + '-profile.jpg')]
+
+    # target_dir = '/data/zql/datasets/baiduperson_for_cls_task'
+    # if os.path.exists(target_dir):
+    #     shutil.rmtree(target_dir)
+
+    # def label_after_hook(x):
+    #     x[x > 1] = 1
+    #     return x
+
+    # convert_seg_dataset_to_cls(
+    #     seg_imgs_path=images_path,
+    #     seg_labels_path=labels_path,
+    #     target_cls_data_dir=target_dir,
+    #     ignore_classes_idx=[1],
+    #     thread_i=1,
+    #     min_img_size=224,
+    #     label_after_hook=label_after_hook
+    # )
+
+    pass  # all conversion recipes above are kept commented out; nothing runs by default
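
To see the region-cropping step in isolation, a tiny self-contained sketch on a synthetic mask (sizes and threshold chosen arbitrarily):

    import numpy as np
    from skimage import morphology, measure

    mask = np.zeros((300, 300), dtype=bool)
    mask[10:250, 20:260] = True        # one large blob -> one region
    mask[0:3, 0:3] = True              # tiny blob, removed as a small object
    mask = morphology.remove_small_objects(mask, min_size=10, connectivity=2)
    for region in measure.regionprops(measure.label(mask)):
        top, left, bottom, right = region.bbox
        print((left, top, right, bottom))  # -> (20, 10, 260, 250)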
data/convert_seg_dataset_to_det.py ADDED
@@ -0,0 +1,399 @@
+from data import ABDataset
+from utils.common.data_record import read_json, write_json
+from PIL import Image
+import os
+from utils.common.file import ensure_dir
+import numpy as np
+from itertools import groupby
+from skimage import morphology, measure
+from scipy import misc
+import tqdm
+from PIL import ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+import shutil
+
+
+def convert_seg_dataset_to_det(seg_imgs_path, seg_labels_path, root_dir, target_coco_ann_path, ignore_classes_idx, thread_i, min_img_size=224, label_after_hook=lambda x: x):
+    """
+    Reference: https://blog.csdn.net/lizaijinsheng/article/details/119889946
+
+    NOTE:
+    The background class should not be considered.
+    However, if a seg dataset has only one valid class, the generated dataset also has only
+    one class. We never use the generated dataset alone, so this is acceptable.
+    """
+    assert len(seg_imgs_path) == len(seg_labels_path)
+
+    classes_imgs_id_map = {}
+
+    coco_ann = {
+        'categories': [],
+        'type': 'instances',
+        'images': [],
+        'annotations': []
+    }
+
+    image_id = 0
+    ann_id = 0
+
+    pbar = tqdm.tqdm(zip(seg_imgs_path, seg_labels_path), total=len(seg_imgs_path),
+                     dynamic_ncols=True, leave=False, desc=f'thread {thread_i}')
+    for seg_img_path, seg_label_path in pbar:
+
+        try:
+            seg_img = Image.open(seg_img_path)
+            seg_label = Image.open(seg_label_path).convert('L')
+            seg_label = np.array(seg_label)
+            seg_label = label_after_hook(seg_label)
+        except Exception as e:
+            print(e)
+            print(f'file {seg_img_path} error, abort')
+            exit()
+
+        image_coco_info = {'file_name': os.path.relpath(seg_img_path, root_dir), 'height': seg_img.height, 'width': seg_img.width,
+                           'id': image_id}
+        image_id += 1
+        coco_ann['images'] += [image_coco_info]
+
+        this_img_classes = set(seg_label.reshape(-1).tolist())
+
+        for class_idx in this_img_classes:
+            if class_idx in ignore_classes_idx:
+                continue
+
+            if class_idx not in classes_imgs_id_map.keys():
+                classes_imgs_id_map[class_idx] = 0
+
+            # binary mask of the current class; drop tiny blobs, then box each connected region
+            mask = np.zeros((seg_label.shape[0], seg_label.shape[1]), dtype=np.uint8)
+            mask[seg_label == class_idx] = 1
+            mask_without_small = morphology.remove_small_objects(mask, min_size=10, connectivity=2)
+            label_image = measure.label(mask_without_small)
+
+            for region in measure.regionprops(label_image):
+                bbox = region.bbox  # (top, left, bottom, right)
+                bbox = [bbox[1], bbox[0], bbox[3], bbox[2]]  # (left, top, right, bottom)
+
+                width, height = bbox[2] - bbox[0], bbox[3] - bbox[1]
+                if width < min_img_size or height < min_img_size:
+                    continue
+
+                ann_coco_info = {'area': width * height, 'iscrowd': 0, 'image_id':
+                                 image_id - 1, 'bbox': [bbox[0], bbox[1], width, height],
+                                 'category_id': class_idx,
+                                 'id': ann_id, 'ignore': 0,
+                                 'segmentation': []}
+                ann_id += 1
+
+                coco_ann['annotations'] += [ann_coco_info]
+
+                classes_imgs_id_map[class_idx] += 1
+
+        pbar.set_description(f'# ann: {ann_id}')
+
+    # remap raw seg label values to contiguous COCO category ids
+    coco_ann['categories'] = [
+        {'id': ci, 'name': f'class_{c}_in_seg'} for ci, c in enumerate(classes_imgs_id_map.keys())
+    ]
+    c_to_ci_map = {c: ci for ci, c in enumerate(classes_imgs_id_map.keys())}
+    for ann in coco_ann['annotations']:
+        ann['category_id'] = c_to_ci_map[ann['category_id']]
+
+    write_json(target_coco_ann_path, coco_ann, indent=0, backup=True)
+    write_json(os.path.join(root_dir, 'coco_ann.json'), coco_ann, indent=0, backup=True)
+
+    num_cls_imgs = 0
+    for k, v in classes_imgs_id_map.items():
+        # print(f'# class {k}: {v + 1}')
+        num_cls_imgs += v
+    # print(f'total: {num_cls_imgs}')
+
+    return classes_imgs_id_map
+
+
+from concurrent.futures import ThreadPoolExecutor
+
+
+# def convert_seg_dataset_to_cls_multi_thread(seg_imgs_path, seg_labels_path, target_cls_data_dir, ignore_classes_idx, num_threads):
+#     if os.path.exists(target_cls_data_dir):
+#         shutil.rmtree(target_cls_data_dir)
+
+#     assert len(seg_imgs_path) == len(seg_labels_path)
+#     n = len(seg_imgs_path) // num_threads
+
+#     pool = ThreadPoolExecutor(max_workers=num_threads)
+#     futures = []
+#     for thread_i in range(num_threads):
+#         future = pool.submit(convert_seg_dataset_to_cls, *(seg_imgs_path[thread_i * n: (thread_i + 1) * n],
+#                                                            seg_labels_path[thread_i * n: (thread_i + 1) * n],
+#                                                            target_cls_data_dir, ignore_classes_idx, thread_i))
+#         futures += [future]
+
+#     futures += [
+#         pool.submit(convert_seg_dataset_to_cls, *(seg_imgs_path[(thread_i + 1) * n:],
+#                                                   seg_labels_path[(thread_i + 1) * n:],
+#                                                   target_cls_data_dir, ignore_classes_idx, thread_i))
+#     ]
+
+#     for f in futures:
+#         f.done()
+
+#     res = []
+#     for f in futures:
+#         res += [f.result()]
+#         print(res[-1])
+
+#     res_dist = {}
+#     for r in res:
+#         for k, v in r.items():
+#             if k in res_dist.keys():
+#                 res_dist[k] += v
+#             else:
+#                 res_dist[k] = v
+
+#     print('results:')
+#     print(res_dist)
+
+#     pool.shutdown()
+
+
+# import random
+# def random_crop_aug(target_dir):
+#     for class_dir in os.listdir(target_dir):
+#         class_dir = os.path.join(target_dir, class_dir)
+
+#         for img_path in os.listdir(class_dir):
+#             img_path = os.path.join(class_dir, img_path)
+
+#             img = Image.open(img_path)
+
+#             w, h = img.width, img.height
+
+#             for ri in range(5):
+#                 img.crop(
+#                     [
+#                         random.randint(0, w // 5),
+#                         random.randint(0, h // 5),
+#                         random.randint(w // 5 * 4, w),
+#                         random.randint(h // 5 * 4, h)
+#                     ]
+#                 ).save(
+#                     os.path.join(os.path.dirname(img_path), f'randaug_{ri}_' + os.path.basename(img_path))
+#                 )
+
+
+def post_ignore_classes(coco_ann_json_path):
+    # from data.datasets.object_detection.yolox_data_util.api import remap_dataset
+    # remap_dataset(coco_ann_json_path, [], {})
+    pass
+
+
+if __name__ == '__main__':
+    # SuperviselyPerson
+    # root_dir = '/data/zql/datasets/supervisely_person_full_20230635/Supervisely Person Dataset'
+
+    # images_path, labels_path = [], []
+    # for p in os.listdir(root_dir):
+    #     if p.startswith('ds'):
+    #         p1 = os.path.join(root_dir, p, 'img')
+    #         images_path += [(p, os.path.join(p1, n)) for n in os.listdir(p1)]
+    # for dsi, img_p in images_path:
+    #     target_p = os.path.join(root_dir, p, dsi, img_p.split('/')[-1])
+    #     labels_path += [target_p]
+    # images_path = [i[1] for i in images_path]
+
+    # target_coco_ann_path = '/data/zql/datasets/supervisely_person_for_det_task/coco_ann.json'
+    # if os.path.exists(target_coco_ann_path):
+    #     os.remove(target_coco_ann_path)
+    # convert_seg_dataset_to_det(
+    #     seg_imgs_path=images_path,
+    #     seg_labels_path=labels_path,
+    #     root_dir=root_dir,
+    #     target_coco_ann_path=target_coco_ann_path,
+    #     ignore_classes_idx=[0, 2],
+    #     thread_i=0
+    # )
+
+    # random_crop_aug('/data/zql/datasets/supervisely_person_for_cls_task')
+
+    # GTA5
+    # root_dir = '/data/zql/datasets/GTA-ls-copy/GTA5'
+    # images_path, labels_path = [], []
+    # for p in os.listdir(os.path.join(root_dir, 'images')):
+    #     p = os.path.join(root_dir, 'images', p)
+    #     if not p.endswith('png'):
+    #         continue
+    #     images_path += [p]
+    #     labels_path += [p.replace('images', 'labels_gt')]
+
+    # target_coco_ann_path = '/data/zql/datasets/gta5_for_det_task/coco_ann.json'
+    # if os.path.exists(target_coco_ann_path):
+    #     os.remove(target_coco_ann_path)
+
+    # """
+    # [
+    #     'road', 'sidewalk', 'building', 'wall',
+    #     'fence', 'pole', 'light', 'sign',
+    #     'vegetation', 'terrain', 'sky', 'people',  # person
+    #     'rider', 'car', 'truck', 'bus', 'train',
+    #     'motocycle', 'bicycle'
+    # ]
+    # """
+    # need_classes_idx = [13, 15]
+    # convert_seg_dataset_to_det(
+    #     seg_imgs_path=images_path,
+    #     seg_labels_path=labels_path,
+    #     root_dir=root_dir,
+    #     target_coco_ann_path=target_coco_ann_path,
+    #     ignore_classes_idx=[i for i in range(20) if i not in need_classes_idx],
+    #     thread_i=0
+    # )
+
+    # from data.datasets.object_detection.yolox_data_util.api import remap_dataset
+    # new_coco_ann_json_path = remap_dataset('/data/zql/datasets/GTA-ls-copy/GTA5/coco_ann.json', [-1], {0: 0, 1: -1, 2: -1, 3: 1, 4: -1, 5: -1})
+    # print(new_coco_ann_json_path)
+
+    # cityscapes
+    # root_dir = '/data/zql/datasets/cityscape/'
+
+    # def _get_target_suffix(mode: str, target_type: str) -> str:
+    #     if target_type == 'instance':
+    #         return '{}_instanceIds.png'.format(mode)
+    #     elif target_type == 'semantic':
+    #         return '{}_labelIds.png'.format(mode)
+    #     elif target_type == 'color':
+    #         return '{}_color.png'.format(mode)
+    #     else:
+    #         return '{}_polygons.json'.format(mode)
+
+    # images_path, labels_path = [], []
+    # split = 'train'
+    # images_dir = os.path.join(root_dir, 'leftImg8bit', split)
+    # targets_dir = os.path.join(root_dir, 'gtFine', split)
+    # for city in os.listdir(images_dir):
+    #     img_dir = os.path.join(images_dir, city)
+    #     target_dir = os.path.join(targets_dir, city)
+    #     for file_name in os.listdir(img_dir):
+    #         target_types = []
+    #         for t in ['semantic']:
+    #             target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
+    #                                          _get_target_suffix('gtFine', t))
+    #             target_types.append(os.path.join(target_dir, target_name))
+
+    #         images_path.append(os.path.join(img_dir, file_name))
+    #         labels_path.append(target_types[0])
+
+    # print(images_path[0: 5], '\n', labels_path[0: 5])
+
+    # target_coco_ann_path = '/data/zql/datasets/cityscape/coco_ann.json'
+
+    # need_classes_idx = [26, 28]
+    # convert_seg_dataset_to_det(
+    #     seg_imgs_path=images_path,
+    #     seg_labels_path=labels_path,
+    #     root_dir=root_dir,
+    #     target_coco_ann_path=target_coco_ann_path,
+    #     ignore_classes_idx=[i for i in range(80) if i not in need_classes_idx],
+    #     thread_i=0
+    # )
+
+    # ignore_target_dir = '/data/zql/datasets/cityscapes_for_cls_task_ignored'
+
+    # ignore_label = 255
+    # raw_idx_map_in_y_transform = {-1: ignore_label, 0: ignore_label, 1: ignore_label, 2: ignore_label,
+    #                               3: ignore_label, 4: ignore_label, 5: ignore_label, 6: ignore_label,
+    #                               7: 0, 8: 1, 9: ignore_label, 10: ignore_label, 11: 2, 12: 3, 13: 4,
+    #                               14: ignore_label, 15: ignore_label, 16: ignore_label, 17: 5,
+    #                               18: ignore_label, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14,
+    #                               28: 15, 29: ignore_label, 30: ignore_label, 31: 16, 32: 17, 33: 18}
+    # ignore_classes_idx = [k for k, v in raw_idx_map_in_y_transform.items() if v == ignore_label]
+    # ignore_classes_idx = sorted(ignore_classes_idx)
+
+    # # move ignored class dirs aside; rename the kept ones to their mapped indices
+    # for class_dir in os.listdir(target_dir):
+    #     if int(class_dir) in ignore_classes_idx:
+    #         shutil.move(
+    #             os.path.join(target_dir, class_dir),
+    #             os.path.join(ignore_target_dir, class_dir)
+    #         )
+    #     else:
+    #         shutil.move(
+    #             os.path.join(target_dir, class_dir),
+    #             os.path.join(target_dir, str(raw_idx_map_in_y_transform[int(class_dir)]))
+    #         )
+
+    # baidu person
+    # root_dir = '/data/zql/datasets/baidu_person/clean_images/'
+
+    # images_path, labels_path = [], []
+    # for p in os.listdir(os.path.join(root_dir, 'images')):
+    #     images_path += [os.path.join(root_dir, 'images', p)]
+    #     labels_path += [os.path.join(root_dir, 'profiles', p.split('.')[0] + '-profile.jpg')]
+
+    # target_dir = '/data/zql/datasets/baiduperson_for_cls_task'
+    # if os.path.exists(target_dir):
+    #     shutil.rmtree(target_dir)
+
+    # def label_after_hook(x):
+    #     x[x > 1] = 1
+    #     return x
+
+    # convert_seg_dataset_to_det(
+    #     seg_imgs_path=images_path,
+    #     seg_labels_path=labels_path,
+    #     root_dir=root_dir,
+    #     target_coco_ann_path='/data/zql/datasets/baidu_person/clean_images/coco_ann_zql.json',
+    #     ignore_classes_idx=[1],
+    #     thread_i=1,
+    #     min_img_size=224,
+    #     label_after_hook=label_after_hook
+    # )
+
+    # from data.visualize import visualize_classes_in_object_detection
+    # from data import get_dataset
+    # d = get_dataset('CityscapesDet', '/data/zql/datasets/cityscape/', 'val', None, [], None)
+    # visualize_classes_in_object_detection(d, {'car': 0, 'bus': 1}, {}, 'debug.png')
+
+    # d = get_dataset('GTA5Det', '/data/zql/datasets/GTA-ls-copy/GTA5', 'val', None, [], None)
+    # visualize_classes_in_object_detection(d, {'car': 0, 'bus': 1}, {}, 'debug.png')
+
+    # d = get_dataset('BaiduPersonDet', '/data/zql/datasets/baidu_person/clean_images/', 'val', None, [], None)
+    # visualize_classes_in_object_detection(d, {'person': 0}, {}, 'debug.png')
+
+    pass  # all conversion recipes above are kept commented out; nothing runs by default
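
For reference, the emitted annotation file follows the standard COCO detection layout; a minimal instance with one image and one box looks like this (values illustrative):

    coco_ann = {
        'categories': [{'id': 0, 'name': 'class_13_in_seg'}],
        'type': 'instances',
        'images': [{'file_name': 'images/00001.png', 'height': 1052, 'width': 1914, 'id': 0}],
        'annotations': [{'area': 224 * 224, 'iscrowd': 0, 'image_id': 0,
                         'bbox': [100, 120, 224, 224],  # (x, y, w, h)
                         'category_id': 0, 'id': 0, 'ignore': 0, 'segmentation': []}],
    }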
data/dataloader.py ADDED
@@ -0,0 +1,131 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# domainbed/lib/fast_data_loader.py
+
+import torch
+from .datasets.ab_dataset import ABDataset
+
+
+class _InfiniteSampler(torch.utils.data.Sampler):
+    """Wraps another Sampler to yield an infinite stream."""
+
+    def __init__(self, sampler):
+        self.sampler = sampler
+
+    def __iter__(self):
+        while True:
+            for batch in self.sampler:
+                yield batch
+
+
+class InfiniteDataLoader:
+    def __init__(self, dataset, weights, batch_size, num_workers, collate_fn=None):
+        super().__init__()
+
+        if weights:
+            sampler = torch.utils.data.WeightedRandomSampler(
+                weights, replacement=True, num_samples=batch_size
+            )
+        else:
+            sampler = torch.utils.data.RandomSampler(dataset, replacement=True)
+
+        batch_sampler = torch.utils.data.BatchSampler(
+            sampler, batch_size=batch_size, drop_last=True
+        )
+
+        # collate_fn=None falls back to PyTorch's default collation, so a single
+        # DataLoader construction covers both cases
+        self._infinite_iterator = iter(
+            torch.utils.data.DataLoader(
+                dataset,
+                num_workers=num_workers,
+                batch_sampler=_InfiniteSampler(batch_sampler),
+                pin_memory=False,
+                collate_fn=collate_fn
+            )
+        )
+        self.dataset = dataset
+
+    def __iter__(self):
+        while True:
+            yield next(self._infinite_iterator)
+
+    def __len__(self):
+        # an infinite loader has no meaningful length
+        raise ValueError
+
+
+class FastDataLoader:
+    """
+    DataLoader wrapper with slightly improved speed by not respawning worker
+    processes at every epoch.
+    """
+
+    def __init__(self, dataset, batch_size, num_workers, shuffle=False, collate_fn=None):
+        super().__init__()
+
+        self.num_workers = num_workers
+
+        if shuffle:
+            sampler = torch.utils.data.RandomSampler(dataset, replacement=False)
+        else:
+            sampler = torch.utils.data.SequentialSampler(dataset)
+
+        batch_sampler = torch.utils.data.BatchSampler(
+            sampler,
+            batch_size=batch_size,
+            drop_last=False,
+        )
+        self._infinite_iterator = iter(
+            torch.utils.data.DataLoader(
+                dataset,
+                num_workers=num_workers,
+                batch_sampler=_InfiniteSampler(batch_sampler),
+                pin_memory=False,
+                collate_fn=collate_fn
+            )
+        )
+
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self._length = len(batch_sampler)
+
+    def __iter__(self):
+        for _ in range(len(self)):
+            yield next(self._infinite_iterator)
+
+    def __len__(self):
+        return self._length
+
+
+def build_dataloader(dataset: ABDataset, batch_size: int, num_workers: int, infinite: bool, shuffle_when_finite: bool, collate_fn=None):
+    assert batch_size <= len(dataset), len(dataset)
+    if infinite:
+        dataloader = InfiniteDataLoader(
+            dataset, None, batch_size, num_workers=num_workers, collate_fn=collate_fn)
+    else:
+        dataloader = FastDataLoader(
+            dataset, batch_size, num_workers, shuffle=shuffle_when_finite, collate_fn=collate_fn)
+
+    return dataloader
+
+
+def get_a_batch_dataloader(dataset: ABDataset, batch_size: int, num_workers: int, infinite: bool, shuffle_when_finite: bool):
+    pass
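
A short usage sketch contrasting the two loader modes (the `dataset` variable stands for any built ABDataset):

    # finite, epoch-style iteration
    loader = build_dataloader(dataset, batch_size=64, num_workers=4,
                              infinite=False, shuffle_when_finite=True)
    for x, y in loader:
        pass

    # infinite stream, e.g. for a fixed number of adaptation steps
    stream = iter(build_dataloader(dataset, 64, 4, True, None))
    for _ in range(100):
        x, y = next(stream)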
data/dataset.py ADDED
@@ -0,0 +1,43 @@
+import importlib
+from typing import Type
+import torch
+from torch.utils.data import TensorDataset
+from torch.utils.data.dataloader import DataLoader
+
+from .datasets.ab_dataset import ABDataset
+
+from .datasets import *  # import all datasets
+from .datasets.registery import static_dataset_registery
+
+
+def get_dataset(dataset_name, root_dir, split, transform=None, ignore_classes=[], idx_map=None) -> ABDataset:
+    dataset_cls = static_dataset_registery[dataset_name][0]
+    dataset = dataset_cls(root_dir, split, transform, ignore_classes, idx_map)
+
+    return dataset
+
+
+def get_num_limited_dataset(dataset: ABDataset, num_samples: int, discard_label=True):
+    dataloader = iter(DataLoader(dataset, num_samples // 2, shuffle=True))
+    x, y = [], []
+    cur_num_samples = 0
+    while True:
+        batch = next(dataloader)
+        cur_x, cur_y = batch[0], batch[1]
+
+        x += [cur_x]
+        y += [cur_y]
+        cur_num_samples += cur_x.size(0)
+
+        if cur_num_samples >= num_samples:
+            break
+
+    x, y = torch.cat(x)[0: num_samples], torch.cat(y)[0: num_samples]
+    if discard_label:
+        new_dataset = TensorDataset(x)
+    else:
+        new_dataset = TensorDataset(x, y)
+
+    dataset.dataset = new_dataset
+
+    return dataset
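
A sketch of fetching a registered dataset through this entry point (the dataset name and path are illustrative, not guaranteed to exist in a given checkout):

    train_set = get_dataset('CIFAR10', '/path/to/CIFAR10-single', 'train',
                            transform=None, ignore_classes=[], idx_map=None)
    print(len(train_set), train_set.classes[:3])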
data/datasets/__init__.py ADDED
@@ -0,0 +1,11 @@
+from .image_classification import *
+from .object_detection import *
+from .semantic_segmentation import *
+from .action_recognition import *
+
+from .sentiment_classification import *
+from .machine_translation import *
+from .pos_tagging import *
+
+from .mm_image_classification import *
+from .visual_question_answering import *
data/datasets/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (481 Bytes).

data/datasets/__pycache__/ab_dataset.cpython-38.pyc ADDED
Binary file (2.2 kB).

data/datasets/__pycache__/data_aug.cpython-38.pyc ADDED
Binary file (3.24 kB).

data/datasets/__pycache__/dataset_cache.cpython-38.pyc ADDED
Binary file (1.68 kB).

data/datasets/__pycache__/dataset_split.cpython-38.pyc ADDED
Binary file (3.11 kB).

data/datasets/__pycache__/registery.cpython-38.pyc ADDED
Binary file (1.6 kB).

data/datasets/ab_dataset.py ADDED
@@ -0,0 +1,48 @@
+from abc import ABC, abstractmethod
+from typing import Dict, List, Optional
+from torchvision.transforms import Compose
+
+
+class ABDataset(ABC):
+    def __init__(self, root_dir, split, transform=None, ignore_classes=[], idx_map=None):
+
+        self.root_dir = root_dir
+        self.split = split
+        self.transform = transform
+        self.ignore_classes = ignore_classes
+        self.idx_map = idx_map
+
+        self.dataset = None
+
+        # injected by @dataset_register
+        self.name = None
+        self.classes = None
+        self.raw_classes = None
+        self.class_aliases = None
+        self.shift_type = None
+        self.task_type = None  # ['Image Classification', 'Object Detection', ...]
+        self.object_type = None  # ['generic object', 'digit and letter', ...]
+
+    @abstractmethod
+    def create_dataset(self, root_dir: str, split: str, transform: Optional[Compose],
+                       classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
+        raise NotImplementedError
+
+    def build(self):
+        if not hasattr(self, 'classes'):
+            raise AttributeError('attr `classes` is injected by `@dataset_register()`. '
+                                 'Your dataset class should be wrapped with @dataset_register().')
+        self.dataset = self.create_dataset(self.root_dir, self.split, self.transform,
+                                           self.classes, self.ignore_classes, self.idx_map)
+        self.raw_classes = self.classes
+        self.classes = [i for i in self.classes if i not in self.ignore_classes]
+
+    def __getitem__(self, idx):
+        if self.dataset is None:
+            raise AttributeError('The real dataset is built in `@dataset_register()`. '
+                                 'Your dataset class should be wrapped with @dataset_register().')
+        return self.dataset[idx]
+
+    def __len__(self):
+        return len(self.dataset)
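
A minimal sketch of a concrete subclass; the exact keyword set of `@dataset_register` is an assumption inferred from the attributes it injects above:

    # hypothetical registration; see data/datasets/registery.py for the real decorator
    @dataset_register(name='MyDigits', classes=['0', '1', '2'],
                      task_type='Image Classification', object_type='digit and letter',
                      class_aliases=[], shift_type=None)
    class MyDigits(ABDataset):
        def create_dataset(self, root_dir, split, transform, classes, ignore_classes, idx_map):
            # return any indexable dataset honoring `ignore_classes` and `idx_map`
            ...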
data/datasets/action_recognition/__init__.py ADDED
@@ -0,0 +1,4 @@
+from .ucf101 import UCF101
+from .hmdb51 import HMDB51
+# from .kinetics400 import Kinetics400
+from .ixmas import IXMAS
data/datasets/action_recognition/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (319 Bytes).

data/datasets/action_recognition/__pycache__/common_dataset.cpython-38.pyc ADDED
Binary file (4.02 kB).

data/datasets/action_recognition/__pycache__/hmdb51.cpython-38.pyc ADDED
Binary file (2.49 kB).

data/datasets/action_recognition/__pycache__/ixmas.cpython-38.pyc ADDED
Binary file (2.09 kB).

data/datasets/action_recognition/__pycache__/ucf101.cpython-38.pyc ADDED
Binary file (3.54 kB).

data/datasets/action_recognition/common_dataset.py ADDED
@@ -0,0 +1,152 @@
+import os
+from pathlib import Path
+
+import random
+
+import numpy as np
+import pickle as pk
+import cv2
+from tqdm import tqdm
+from PIL import Image
+
+import torchvision.transforms as transforms
+import torch
+
+# from prefetch_generator import BackgroundGenerator
+from torch.utils.data import DataLoader, Dataset
+
+
+class VideoDataset(Dataset):
+
+    def __init__(self, directory_list, local_rank=0, enable_GPUs_num=0, distributed_load=False, resize_shape=[224, 224], mode='train', clip_len=32, crop_size=168):
+
+        self.clip_len, self.crop_size, self.resize_shape = clip_len, crop_size, resize_shape
+        self.mode = mode
+
+        self.fnames, labels = [], []
+        # get the directory of the specified split; outside 'train' mode only keep 10 videos per class
+        for directory in directory_list:
+            folder = Path(directory)
+            print("Load dataset from folder : ", folder)
+            for label in sorted(os.listdir(folder)):
+                for fname in os.listdir(os.path.join(folder, label)) if mode == "train" else os.listdir(os.path.join(folder, label))[:10]:
+                    self.fnames.append(os.path.join(folder, label, fname))
+                    labels.append(label)
+
+        random_list = list(zip(self.fnames, labels))
+        random.shuffle(random_list)
+        self.fnames[:], labels[:] = zip(*random_list)
+        self.labels = labels
+
+        # self.fnames = self.fnames[:240]
+
+        if mode == 'train' and distributed_load:
+            single_num_ = len(self.fnames) // enable_GPUs_num
+            self.fnames = self.fnames[local_rank * single_num_: (local_rank + 1) * single_num_]
+            labels = labels[local_rank * single_num_: (local_rank + 1) * single_num_]
+
+        # prepare a mapping between the label names (strings) and indices (ints)
+        self.label2index = {label: index for index, label in enumerate(sorted(set(labels)))}
+        # convert the list of label names into an array of label indices
+        self.label_array = np.array([self.label2index[label] for label in labels], dtype=int)
+
+    def __getitem__(self, index):
+        # loading and preprocessing. TODO: move them to transform classes
+        buffer = self.loadvideo(self.fnames[index])
+
+        # random spatial crop of size crop_size x crop_size
+        height_index = np.random.randint(buffer.shape[2] - self.crop_size)
+        width_index = np.random.randint(buffer.shape[3] - self.crop_size)
+
+        return buffer[:, :, height_index:height_index + self.crop_size, width_index:width_index + self.crop_size], self.label_array[index]
+
+    def __len__(self):
+        return len(self.fnames)
+
+    def loadvideo(self, fname):
+        # initialize a VideoCapture object to read video data into a numpy array
+        self.transform = transforms.Compose([
+            transforms.Resize([self.resize_shape[0], self.resize_shape[1]]),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
+        ])
+
+        flip, flipCode = 1, random.choice([-1, 0, 1]) if np.random.random() < 0.5 and self.mode == "train" else 0
+
+        try:
+            video_stream = cv2.VideoCapture(fname)
+            frame_count = int(video_stream.get(cv2.CAP_PROP_FRAME_COUNT))
+        except RuntimeError:
+            # fall back to a random other video if this one cannot be opened
+            index = np.random.randint(self.__len__())
+            video_stream = cv2.VideoCapture(self.fnames[index])
+            frame_count = int(video_stream.get(cv2.CAP_PROP_FRAME_COUNT))
+
+        # resample until a video long enough for one clip is found
+        while frame_count < self.clip_len + 2:
+            index = np.random.randint(self.__len__())
+            video_stream = cv2.VideoCapture(self.fnames[index])
+            frame_count = int(video_stream.get(cv2.CAP_PROP_FRAME_COUNT))
+
+        speed_rate = np.random.randint(1, 3) if frame_count > self.clip_len * 2 + 2 else 1
+        time_index = np.random.randint(frame_count - self.clip_len * speed_rate)
+
+        start_idx, end_idx, final_idx = time_index, time_index + (self.clip_len * speed_rate), frame_count - 1
+        count, sample_count, retaining = 0, 0, True
+
+        # create a buffer. Must have dtype float, so it gets converted to a FloatTensor by PyTorch later
+        buffer = np.empty((self.clip_len, 3, self.resize_shape[0], self.resize_shape[1]), np.dtype('float32'))
+
+        while (count <= end_idx and retaining):
+            retaining, frame = video_stream.read()
+            if count < start_idx:
+                count += 1
+                continue
+            if count % speed_rate == speed_rate - 1 and count >= start_idx and sample_count < self.clip_len:
105
+ if flip:
106
+ frame = cv2.flip(frame, flipCode=flipCode)
107
+ try:
108
+ buffer[sample_count] = self.transform(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
109
+ except cv2.error as err:
110
+ continue
111
+ sample_count += 1
112
+ count += 1
113
+ video_stream.release()
114
+
115
+ return buffer.transpose((1, 0, 2, 3))
116
+
117
+
118
+ if __name__ == '__main__':
119
+
120
+ datapath = ['/data/datasets/ucf101/videos']
121
+
122
+ dataset = VideoDataset(datapath,
123
+ resize_shape=[224, 224],
124
+ mode='validation')
125
+ x, y = dataset[0]
126
+ # x: (3, num_frames, w, h)
127
+ print(x.shape, y.shape, y)
128
+
129
+ # dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=24, pin_memory=True)
130
+
131
+ # bar = tqdm(total=len(dataloader), ncols=80)
132
+
133
+ # prefetcher = DataPrefetcher(BackgroundGenerator(dataloader), 0)
134
+ # batch = prefetcher.next()
135
+ # iter_id = 0
136
+ # while batch is not None:
137
+ # iter_id += 1
138
+ # bar.update(1)
139
+ # if iter_id >= len(dataloader):
140
+ # break
141
+
142
+ # batch = prefetcher.next()
143
+ # print(batch[0].shape)
144
+ # print("label: ", batch[1])
145
+
146
+ # '''
147
+ # for step, (buffer, labels) in enumerate(BackgroundGenerator(dataloader)):
148
+ # print(buffer.shape)
149
+ # print("label: ", labels)
150
+ # bar.update(1)
151
+ # '''
152
+
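A minimal sketch of feeding `VideoDataset` to a PyTorch `DataLoader`; the dataset path is a made-up placeholder, and the shapes follow from `clip_len=32` and `crop_size=168` above:

    from torch.utils.data import DataLoader

    train_set = VideoDataset(['/data/datasets/hmdb51/videos'], mode='train',
                             clip_len=32, crop_size=168, resize_shape=[224, 224])
    loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=4, pin_memory=True)

    clips, labels = next(iter(loader))
    # clips:  torch.Size([8, 3, 32, 168, 168])  -- (batch, channels, frames, crop_h, crop_w)
    # labels: torch.Size([8])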
data/datasets/action_recognition/hmdb51.py ADDED
@@ -0,0 +1,45 @@
+ from ..data_aug import cityscapes_like_image_train_aug, cityscapes_like_image_test_aug, cityscapes_like_label_aug
+ # from torchvision.datasets import Cityscapes as RawCityscapes
+ from ..ab_dataset import ABDataset
+ from ..dataset_split import train_val_test_split
+ import numpy as np
+ from typing import Dict, List, Optional
+ from torchvision.transforms import Compose, Lambda
+ import os
+
+ from .common_dataset import VideoDataset
+ from ..registery import dataset_register
+
+
+ @dataset_register(
+     name='HMDB51',
+     classes=['brush_hair', 'cartwheel', 'catch', 'chew', 'clap', 'climb', 'climb_stairs', 'dive', 'draw_sword', 'dribble', 'drink', 'eat', 'fall_floor', 'fencing', 'flic_flac', 'golf', 'handstand', 'hit', 'hug', 'jump', 'kick', 'kick_ball', 'kiss', 'laugh', 'pick', 'pour', 'pullup', 'punch', 'push', 'pushup', 'ride_bike', 'ride_horse', 'run', 'shake_hands', 'shoot_ball', 'shoot_bow', 'shoot_gun', 'sit', 'situp', 'smile', 'smoke', 'somersault', 'stand', 'swing_baseball', 'sword', 'sword_exercise', 'talk', 'throw', 'turn', 'walk', 'wave'],
+     task_type='Action Recognition',
+     object_type='Web Video',
+     # class_aliases=[['automobile', 'car']],
+     class_aliases=[],
+     shift_type=None
+ )
+ class HMDB51(ABDataset):  # just for demo now
+     def create_dataset(self, root_dir: str, split: str, transform: Optional[Compose],
+                        classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
+         # if transform is None:
+         #     x_transform = cityscapes_like_image_train_aug() if split == 'train' else cityscapes_like_image_test_aug()
+         #     y_transform = cityscapes_like_label_aug()
+         #     self.transform = x_transform
+         # else:
+         #     x_transform, y_transform = transform
+
+         dataset = VideoDataset([root_dir], mode='train')
+
+         if len(ignore_classes) > 0:
+             for ignore_class in ignore_classes:
+                 ci = classes.index(ignore_class)
+                 # drop every video whose label index matches an ignored class
+                 dataset.fnames = [fname for fname, label in zip(dataset.fnames, dataset.label_array) if label != ci]
+                 dataset.label_array = [label for label in dataset.label_array if label != ci]
+
+         if idx_map is not None:
+             dataset.label_array = [idx_map[label] for label in dataset.label_array]
+
+         dataset = train_val_test_split(dataset, split)
+         return dataset
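The `ignore_classes` / `idx_map` handling above boils down to a filter followed by a relabeling; the same pattern repeats in IXMAS, Kinetics400, and UCF101 below. A worked example with made-up label values:

    labels = [0, 3, 1, 4, 3]
    ci = 3                                   # index of the ignored class
    labels = [l for l in labels if l != ci]  # -> [0, 1, 4]
    idx_map = {0: 0, 1: 1, 2: 2, 4: 3}       # remap to contiguous indices
    labels = [idx_map[l] for l in labels]    # -> [0, 1, 3]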
data/datasets/action_recognition/ixmas.py ADDED
@@ -0,0 +1,45 @@
+ from ..data_aug import cityscapes_like_image_train_aug, cityscapes_like_image_test_aug, cityscapes_like_label_aug
+ # from torchvision.datasets import Cityscapes as RawCityscapes
+ from ..ab_dataset import ABDataset
+ from ..dataset_split import train_val_test_split
+ import numpy as np
+ from typing import Dict, List, Optional
+ from torchvision.transforms import Compose, Lambda
+ import os
+
+ from .common_dataset import VideoDataset
+ from ..registery import dataset_register
+
+
+ @dataset_register(
+     name='IXMAS',
+     classes=['check_watch', 'cross_arms', 'get_up', 'kick', 'pick_up', 'point', 'punch', 'scratch_head', 'sit_down', 'turn_around', 'walk', 'wave'],
+     task_type='Action Recognition',
+     object_type='Web Video',
+     # class_aliases=[['automobile', 'car']],
+     class_aliases=[],
+     shift_type=None
+ )
+ class IXMAS(ABDataset):  # just for demo now
+     def create_dataset(self, root_dir: str, split: str, transform: Optional[Compose],
+                        classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
+         # if transform is None:
+         #     x_transform = cityscapes_like_image_train_aug() if split == 'train' else cityscapes_like_image_test_aug()
+         #     y_transform = cityscapes_like_label_aug()
+         #     self.transform = x_transform
+         # else:
+         #     x_transform, y_transform = transform
+
+         dataset = VideoDataset([root_dir], mode='train')
+
+         if len(ignore_classes) > 0:
+             for ignore_class in ignore_classes:
+                 ci = classes.index(ignore_class)
+                 # drop every video whose label index matches an ignored class
+                 dataset.fnames = [fname for fname, label in zip(dataset.fnames, dataset.label_array) if label != ci]
+                 dataset.label_array = [label for label in dataset.label_array if label != ci]
+
+         if idx_map is not None:
+             dataset.label_array = [idx_map[label] for label in dataset.label_array]
+
+         dataset = train_val_test_split(dataset, split)
+         return dataset
data/datasets/action_recognition/kinetics400.py ADDED
@@ -0,0 +1,51 @@
+ from ..data_aug import cityscapes_like_image_train_aug, cityscapes_like_image_test_aug, cityscapes_like_label_aug
+ # from torchvision.datasets import Cityscapes as RawCityscapes
+ from ..ab_dataset import ABDataset
+ from ..dataset_split import train_val_split, train_val_test_split
+ import numpy as np
+ from typing import Dict, List, Optional
+ from torchvision.transforms import Compose, Lambda
+ import os
+
+ from .common_dataset import VideoDataset
+ from ..registery import dataset_register
+
+
+ @dataset_register(
+     name='Kinetics400',
+     classes=['abseiling', 'air drumming', 'answering questions', 'applauding', 'applying cream', 'archery', 'arm wrestling', 'arranging flowers', 'assembling computer', 'auctioning', 'baby waking up', 'baking cookies', 'balloon blowing', 'bandaging', 'barbequing', 'bartending', 'beatboxing', 'bee keeping', 'belly dancing', 'bench pressing', 'bending back', 'bending metal', 'biking through snow', 'blasting sand', 'blowing glass', 'blowing leaves', 'blowing nose', 'blowing out candles', 'bobsledding', 'bookbinding', 'bouncing on trampoline', 'bowling', 'braiding hair', 'breading or breadcrumbing', 'breakdancing', 'brush painting', 'brushing hair', 'brushing teeth', 'building cabinet', 'building shed', 'bungee jumping', 'busking', 'canoeing or kayaking', 'capoeira', 'carrying baby', 'cartwheeling', 'carving pumpkin', 'catching fish', 'catching or throwing baseball', 'catching or throwing frisbee', 'catching or throwing softball', 'celebrating', 'changing oil', 'changing wheel', 'checking tires', 'cheerleading', 'chopping wood', 'clapping', 'clay pottery making', 'clean and jerk', 'cleaning floor', 'cleaning gutters', 'cleaning pool', 'cleaning shoes', 'cleaning toilet', 'cleaning windows', 'climbing a rope', 'climbing ladder', 'climbing tree', 'contact juggling', 'cooking chicken', 'cooking egg', 'cooking on campfire', 'cooking sausages', 'counting money', 'country line dancing', 'cracking neck', 'crawling baby', 'crossing river', 'crying', 'curling hair', 'cutting nails', 'cutting pineapple', 'cutting watermelon', 'dancing ballet', 'dancing charleston', 'dancing gangnam style', 'dancing macarena', 'deadlifting', 'decorating the christmas tree', 'digging', 'dining', 'disc golfing', 'diving cliff', 'dodgeball', 'doing aerobics', 'doing laundry', 'doing nails', 'drawing', 'dribbling basketball', 'drinking', 'drinking beer', 'drinking shots', 'driving car', 'driving tractor', 'drop kicking', 'drumming fingers', 'dunking basketball', 'dying hair', 'eating burger', 'eating cake', 'eating carrots', 'eating chips', 'eating doughnuts', 'eating hotdog', 'eating ice cream', 'eating spaghetti', 'eating watermelon', 'egg hunting', 'exercising arm', 'exercising with an exercise ball', 'extinguishing fire', 'faceplanting', 'feeding birds', 'feeding fish', 'feeding goats', 'filling eyebrows', 'finger snapping', 'fixing hair', 'flipping pancake', 'flying kite', 'folding clothes', 'folding napkins', 'folding paper', 'front raises', 'frying vegetables', 'garbage collecting', 'gargling', 'getting a haircut', 'getting a tattoo', 'giving or receiving award', 'golf chipping', 'golf driving', 'golf putting', 'grinding meat', 'grooming dog', 'grooming horse', 'gymnastics tumbling', 'hammer throw', 'headbanging', 'headbutting', 'high jump', 'high kick', 'hitting baseball', 'hockey stop', 'holding snake', 'hopscotch', 'hoverboarding', 'hugging', 'hula hooping', 'hurdling', 'hurling (sport)', 'ice climbing', 'ice fishing', 'ice skating', 'ironing', 'javelin throw', 'jetskiing', 'jogging', 'juggling balls', 'juggling fire', 'juggling soccer ball', 'jumping into pool', 'jumpstyle dancing', 'kicking field goal', 'kicking soccer ball', 'kissing', 'kitesurfing', 'knitting', 'krumping', 'laughing', 'laying bricks', 'long jump', 'lunge', 'making a cake', 'making a sandwich', 'making bed', 'making jewelry', 'making pizza', 'making snowman', 'making sushi', 'making tea', 'marching', 'massaging back', 'massaging feet', 'massaging legs', "massaging person's head", 'milking cow', 'mopping floor', 'motorcycling', 'moving furniture', 'mowing lawn', 'news anchoring', 'opening bottle', 'opening present', 'paragliding', 'parasailing', 'parkour', 'passing American football (in game)', 'passing American football (not in game)', 'peeling apples', 'peeling potatoes', 'petting animal (not cat)', 'petting cat', 'picking fruit', 'planting trees', 'plastering', 'playing accordion', 'playing badminton', 'playing bagpipes', 'playing basketball', 'playing bass guitar', 'playing cards', 'playing cello', 'playing chess', 'playing clarinet', 'playing controller', 'playing cricket', 'playing cymbals', 'playing didgeridoo', 'playing drums', 'playing flute', 'playing guitar', 'playing harmonica', 'playing harp', 'playing ice hockey', 'playing keyboard', 'playing kickball', 'playing monopoly', 'playing organ', 'playing paintball', 'playing piano', 'playing poker', 'playing recorder', 'playing saxophone', 'playing squash or racquetball', 'playing tennis', 'playing trombone', 'playing trumpet', 'playing ukulele', 'playing violin', 'playing volleyball', 'playing xylophone', 'pole vault', 'presenting weather forecast', 'pull ups', 'pumping fist', 'pumping gas', 'punching bag', 'punching person (boxing)', 'push up', 'pushing car', 'pushing cart', 'pushing wheelchair', 'reading book', 'reading newspaper', 'recording music', 'riding a bike', 'riding camel', 'riding elephant', 'riding mechanical bull', 'riding mountain bike', 'riding mule', 'riding or walking with horse', 'riding scooter', 'riding unicycle', 'ripping paper', 'robot dancing', 'rock climbing', 'rock scissors paper', 'roller skating', 'running on treadmill', 'sailing', 'salsa dancing', 'sanding floor', 'scrambling eggs', 'scuba diving', 'setting table', 'shaking hands', 'shaking head', 'sharpening knives', 'sharpening pencil', 'shaving head', 'shaving legs', 'shearing sheep', 'shining shoes', 'shooting basketball', 'shooting goal (soccer)', 'shot put', 'shoveling snow', 'shredding paper', 'shuffling cards', 'side kick', 'sign language interpreting', 'singing', 'situp', 'skateboarding', 'ski jumping', 'skiing (not slalom or crosscountry)', 'skiing crosscountry', 'skiing slalom', 'skipping rope', 'skydiving', 'slacklining', 'slapping', 'sled dog racing', 'smoking', 'smoking hookah', 'snatch weight lifting', 'sneezing', 'sniffing', 'snorkeling', 'snowboarding', 'snowkiting', 'snowmobiling', 'somersaulting', 'spinning poi', 'spray painting', 'spraying', 'springboard diving', 'squat', 'sticking tongue out', 'stomping grapes', 'stretching arm', 'stretching leg', 'strumming guitar', 'surfing crowd', 'surfing water', 'sweeping floor', 'swimming backstroke', 'swimming breast stroke', 'swimming butterfly stroke', 'swing dancing', 'swinging legs', 'swinging on something', 'sword fighting', 'tai chi', 'taking a shower', 'tango dancing', 'tap dancing', 'tapping guitar', 'tapping pen', 'tasting beer', 'tasting food', 'testifying', 'texting', 'throwing axe', 'throwing ball', 'throwing discus', 'tickling', 'tobogganing', 'tossing coin', 'tossing salad', 'training dog', 'trapezing', 'trimming or shaving beard', 'trimming trees', 'triple jump', 'tying bow tie', 'tying knot (not on a tie)', 'tying tie', 'unboxing', 'unloading truck', 'using computer', 'using remote controller (not gaming)', 'using segway', 'vault', 'waiting in line', 'walking the dog', 'washing dishes', 'washing feet', 'washing hair', 'washing hands', 'water skiing', 'water sliding', 'watering plants', 'waxing back', 'waxing chest', 'waxing eyebrows', 'waxing legs', 'weaving basket', 'welding', 'whistling', 'windsurfing', 'wrapping present', 'wrestling', 'writing', 'yawning', 'yoga', 'zumba'],
+     task_type='Action Recognition',
+     object_type='Web Video',
+     # class_aliases=[['automobile', 'car']],
+     class_aliases=[],
+     shift_type=None
+ )
+ class Kinetics400(ABDataset):  # just for demo now
+     def create_dataset(self, root_dir: str, split: str, transform: Optional[Compose],
+                        classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
+         # if transform is None:
+         #     x_transform = cityscapes_like_image_train_aug() if split == 'train' else cityscapes_like_image_test_aug()
+         #     y_transform = cityscapes_like_label_aug()
+         #     self.transform = x_transform
+         # else:
+         #     x_transform, y_transform = transform
+
+         # the official validation videos serve as the test split here;
+         # train/val are carved out of the training videos below
+         if split == 'test':
+             root_dir = os.path.join(root_dir, 'videos_val')
+         else:
+             root_dir = os.path.join(root_dir, 'videos_train')
+         # print(root_dir)
+         dataset = VideoDataset([root_dir], mode='train')
+
+         if len(ignore_classes) > 0:
+             for ignore_class in ignore_classes:
+                 ci = classes.index(ignore_class)
+                 # drop every video whose label index matches an ignored class
+                 dataset.fnames = [fname for fname, label in zip(dataset.fnames, dataset.label_array) if label != ci]
+                 dataset.label_array = [label for label in dataset.label_array if label != ci]
+
+         if idx_map is not None:
+             dataset.label_array = [idx_map[label] for label in dataset.label_array]
+
+         if split != 'test':
+             dataset = train_val_split(dataset, split)
+         return dataset
data/datasets/action_recognition/ucf101.py ADDED
@@ -0,0 +1,45 @@
+ from ..data_aug import cityscapes_like_image_train_aug, cityscapes_like_image_test_aug, cityscapes_like_label_aug
+ # from torchvision.datasets import Cityscapes as RawCityscapes
+ from ..ab_dataset import ABDataset
+ from ..dataset_split import train_val_test_split
+ import numpy as np
+ from typing import Dict, List, Optional
+ from torchvision.transforms import Compose, Lambda
+ import os
+
+ from .common_dataset import VideoDataset
+ from ..registery import dataset_register
+
+
+ @dataset_register(
+     name='UCF101',
+     classes=['apply_eye_makeup', 'apply_lipstick', 'archery', 'baby_crawling', 'balance_beam', 'band_marching', 'baseball_pitch', 'basketball', 'basketball_dunk', 'bench_press', 'biking', 'billiards', 'blow_dry_hair', 'blowing_candles', 'body_weight_squats', 'bowling', 'boxing_punching_bag', 'boxing_speed_bag', 'breast_stroke', 'brushing_teeth', 'clean_and_jerk', 'cliff_diving', 'cricket_bowling', 'cricket_shot', 'cutting_in_kitchen', 'diving', 'drumming', 'fencing', 'field_hockey_penalty', 'floor_gymnastics', 'frisbee_catch', 'front_crawl', 'golf_swing', 'haircut', 'hammer_throw', 'hammering', 'handstand_pushups', 'handstand_walking', 'head_massage', 'high_jump', 'horse_race', 'horse_riding', 'hula_hoop', 'ice_dancing', 'javelin_throw', 'juggling_balls', 'jump_rope', 'jumping_jack', 'kayaking', 'knitting', 'long_jump', 'lunges', 'military_parade', 'mixing', 'mopping_floor', 'nunchucks', 'parallel_bars', 'pizza_tossing', 'playing_cello', 'playing_daf', 'playing_dhol', 'playing_flute', 'playing_guitar', 'playing_piano', 'playing_sitar', 'playing_tabla', 'playing_violin', 'pole_vault', 'pommel_horse', 'pull_ups', 'punch', 'push_ups', 'rafting', 'rock_climbing_indoor', 'rope_climbing', 'rowing', 'salsa_spin', 'shaving_beard', 'shotput', 'skate_boarding', 'skiing', 'skijet', 'sky_diving', 'soccer_juggling', 'soccer_penalty', 'still_rings', 'sumo_wrestling', 'surfing', 'swing', 'table_tennis_shot', 'tai_chi', 'tennis_swing', 'throw_discus', 'trampoline_jumping', 'typing', 'uneven_bars', 'volleyball_spiking', 'walking_with_dog', 'wall_pushups', 'writing_on_board', 'yo_yo'],
+     task_type='Action Recognition',
+     object_type='Web Video',
+     # class_aliases=[['automobile', 'car']],
+     class_aliases=[],
+     shift_type=None
+ )
+ class UCF101(ABDataset):  # just for demo now
+     def create_dataset(self, root_dir: str, split: str, transform: Optional[Compose],
+                        classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
+         # if transform is None:
+         #     x_transform = cityscapes_like_image_train_aug() if split == 'train' else cityscapes_like_image_test_aug()
+         #     y_transform = cityscapes_like_label_aug()
+         #     self.transform = x_transform
+         # else:
+         #     x_transform, y_transform = transform
+
+         dataset = VideoDataset([root_dir], mode='train')
+
+         if len(ignore_classes) > 0:
+             for ignore_class in ignore_classes:
+                 ci = classes.index(ignore_class)
+                 # drop every video whose label index matches an ignored class
+                 dataset.fnames = [fname for fname, label in zip(dataset.fnames, dataset.label_array) if label != ci]
+                 dataset.label_array = [label for label in dataset.label_array if label != ci]
+
+         if idx_map is not None:
+             dataset.label_array = [idx_map[label] for label in dataset.label_array]
+
+         dataset = train_val_test_split(dataset, split)
+         return dataset
data/datasets/data_aug.py ADDED
@@ -0,0 +1,93 @@
+ from torchvision import transforms
+ import torch
+
+
+ def one_d_image_train_aug(to_3_channels=False):
+     mean, std = (0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081)
+     return transforms.Compose([
+         transforms.Resize(32),
+         # transforms.RandomCrop(32, padding=4),
+         transforms.ToTensor(),
+         transforms.Lambda((lambda x: torch.cat([x] * 3)) if to_3_channels else (lambda x: x)),
+         transforms.Normalize(mean, std)
+     ])
+
+
+ def one_d_image_test_aug(to_3_channels=False):
+     mean, std = (0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081)
+     return transforms.Compose([
+         transforms.Resize(32),
+         transforms.ToTensor(),
+         transforms.Lambda((lambda x: torch.cat([x] * 3)) if to_3_channels else (lambda x: x)),
+         transforms.Normalize(mean, std)
+     ])
+
+
+ def cifar_like_image_train_aug(mean=None, std=None):
+     if mean is None:
+         mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
+     return transforms.Compose([
+         transforms.Resize(40),  # NOTE: this is critical! Without it, RandomCrop may keep only a small part of a large image
+         transforms.RandomCrop(32, padding=4),
+         transforms.RandomHorizontalFlip(),
+         transforms.ToTensor(),
+         transforms.Normalize(mean, std)
+     ])
+
+
+ def cifar_like_image_test_aug(mean=None, std=None):
+     if mean is None:
+         mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
+     return transforms.Compose([
+         transforms.Resize(32),
+         transforms.ToTensor(),
+         transforms.Normalize(mean, std)
+     ])
+
+
+ def imagenet_like_image_train_aug():
+     mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+     return transforms.Compose([
+         transforms.Resize((256, 256)),
+         transforms.RandomCrop((224, 224), padding=4),
+         transforms.RandomHorizontalFlip(),
+         transforms.ToTensor(),
+         transforms.Normalize(mean, std)
+     ])
+
+
+ def imagenet_like_image_test_aug():
+     mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+     return transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean, std)
+     ])
+
+
+ def cityscapes_like_image_train_aug():
+     return transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     ])
+
+
+ def cityscapes_like_image_test_aug():
+     return transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     ])
+
+
+ def cityscapes_like_label_aug():
+     import numpy as np
+     return transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.Lambda(lambda x: torch.from_numpy(np.array(x)).long())
+     ])
+
+
+ def pil_image_to_tensor(img_size=224):
+     return transforms.Compose([
+         transforms.Resize((img_size, img_size)),
+         transforms.ToTensor()
+     ])
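A quick smoke test for these pipelines (the all-zero image is a made-up input); each factory returns a `transforms.Compose` that maps a PIL image to a normalized tensor:

    from PIL import Image
    import numpy as np

    img = Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8))
    print(imagenet_like_image_train_aug()(img).shape)  # torch.Size([3, 224, 224])
    print(cifar_like_image_test_aug()(img).shape)      # torch.Size([3, 32, 32])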
data/datasets/dataset_cache.py ADDED
@@ -0,0 +1,40 @@
+ from typing import List, Optional, Dict
+ import os
+ import torch
+ from utils.common.log import logger
+ import hashlib
+
+
+ def get_dataset_cache_path(root_dir: str,
+                            classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]]):
+
+     def _hash(o):
+         # canonicalize containers so the digest is independent of element order
+         if isinstance(o, list):
+             o = sorted(o)
+         elif isinstance(o, dict):
+             o = {k: o[k] for k in sorted(o)}
+         elif isinstance(o, set):
+             o = sorted(list(o))
+         # else:
+         #     print(type(o))
+
+         obj = hashlib.md5()
+         obj.update(str(o).encode('utf-8'))
+         return obj.hexdigest()
+
+     cache_key = _hash(f'zql_data_{_hash(root_dir)}_{_hash(classes)}_{_hash(ignore_classes)}_{_hash(idx_map)}.cache')
+
+     # print(root_dir, classes, ignore_classes, idx_map)
+     # print('cache key', cache_key)
+
+     cache_file_path = os.path.join('/tmp', f'zql_data_cache_{cache_key}.cache')
+     return cache_file_path
+
+
+ def cache_dataset_status(status, cache_file_path, dataset_name):
+     logger.info(f'cache dataset status: {dataset_name}')
+     torch.save(status, cache_file_path)
+
+
+ def read_cached_dataset_status(cache_file_path, dataset_name):
+     logger.info(f'read dataset cache: {dataset_name}')
+     return torch.load(cache_file_path)
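A sketch of the intended round trip, assuming `status` can be any picklable object (it goes through `torch.save` / `torch.load`); the dataset name, class list, and status dict below are made up:

    cache_path = get_dataset_cache_path('/data/cifar10', classes=['cat', 'dog'],
                                        ignore_classes=[], idx_map=None)
    if os.path.exists(cache_path):
        status = read_cached_dataset_status(cache_path, 'CIFAR10')
    else:
        status = {'num_samples': 50000}
        cache_dataset_status(status, cache_path, 'CIFAR10')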
data/datasets/dataset_split.py ADDED
@@ -0,0 +1,81 @@
+ import torch
+ import os
+ import numpy as np
+ from .ab_dataset import ABDataset
+
+
+ class _SplitDataset(torch.utils.data.Dataset):
+     """Used by split_dataset"""
+
+     def __init__(self, underlying_dataset, keys):
+         super(_SplitDataset, self).__init__()
+         self.underlying_dataset = underlying_dataset
+         self.keys = keys
+
+     def __getitem__(self, key):
+         return self.underlying_dataset[self.keys[key]]
+
+     def __len__(self):
+         return len(self.keys)
+
+
+ def split_dataset(dataset, n, seed=0, transform=None):
+     """
+     Return a pair of datasets corresponding to a random split of the given
+     dataset, with n datapoints in the first dataset and the rest in the last,
+     using the given random seed.
+     """
+     if isinstance(dataset, ABDataset):
+         if dataset.task_type == 'Object Detection':
+             return split_dataset_det(dataset, n, seed)
+         if dataset.task_type == 'MM Object Detection':
+             return split_dataset_det_mm(dataset, n, seed, transform=transform)
+
+     assert n <= len(dataset), f'{n}_{len(dataset)}'
+
+     # the shuffled index split is cached on disk, keyed by (n, seed, dataset size)
+     cache_p = f'{n}_{seed}_{len(dataset)}'
+     cache_p = os.path.join(os.path.expanduser(
+         '~'), '.domain_benchmark_split_dataset_cache_' + str(cache_p))
+     if os.path.exists(cache_p):
+         keys_1, keys_2 = torch.load(cache_p)
+     else:
+         keys = list(range(len(dataset)))
+         np.random.RandomState(seed).shuffle(keys)
+         keys_1 = keys[:n]
+         keys_2 = keys[n:]
+         torch.save((keys_1, keys_2), cache_p)
+
+     return _SplitDataset(dataset, keys_1), _SplitDataset(dataset, keys_2)
+
+
+ def train_val_split(dataset, split):
+     assert split in ['train', 'val']
+     if split == 'train':
+         return split_dataset(dataset, int(len(dataset) * 0.8))[0]
+     else:
+         return split_dataset(dataset, int(len(dataset) * 0.8))[1]
+
+
+ def train_val_test_split(dataset, split):
+     assert split in ['train', 'val', 'test']
+
+     train_set, test_set = split_dataset(dataset, int(len(dataset) * 0.8))
+     train_set, val_set = split_dataset(train_set, int(len(train_set) * 0.8))
+
+     return {'train': train_set, 'val': val_set, 'test': test_set}[split]
+
+
+ def split_dataset_det(dataset: ABDataset, n, seed=0):
+     coco_ann_json_path = dataset.ann_json_file_path_for_split
+     from .object_detection.yolox_data_util.api import coco_split, get_default_yolox_coco_dataset
+     split_coco_ann_json_path = coco_split(coco_ann_json_path, ratio=n / len(dataset))[0]
+     # print(n, len(dataset))
+     return get_default_yolox_coco_dataset(dataset.root_dir, split_coco_ann_json_path, train=dataset.split == 'train'), None
+
+
+ def split_dataset_det_mm(dataset: ABDataset, n, seed=0, transform=None):
+     coco_ann_json_path = dataset.ann_json_file_path_for_split
+     from .object_detection.yolox_data_util.api import coco_split, get_yolox_coco_dataset_with_caption
+     split_coco_ann_json_path = coco_split(coco_ann_json_path, ratio=n / len(dataset))[0]
+     # print(n, len(dataset))
+     return get_yolox_coco_dataset_with_caption(dataset.root_dir, split_coco_ann_json_path, transform=transform, train=dataset.split == 'train', classes=dataset.classes), None
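An illustration of the split ratios with a plain indexable dataset (any sequence works for the non-detection path); note that `train_val_test_split` yields a 64% / 16% / 20% partition, since the 80% train slice is split 80/20 again:

    full = list(range(100))                      # stand-in for a real dataset
    first, rest = split_dataset(full, n=30, seed=0)
    assert len(first) == 30 and len(rest) == 70

    train = train_val_test_split(full, 'train')  # 64 samples
    val = train_val_test_split(full, 'val')      # 16 samples
    test = train_val_test_split(full, 'test')    # 20 samples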
data/datasets/image_classification/__init__.py ADDED
@@ -0,0 +1,24 @@
+ from .mnist import MNIST
+ from .usps import USPS
+ from .svhn import SVHN
+ from .emnist import EMNIST
+ from .cifar10 import CIFAR10
+ from .stl10 import STL10
+ from .imagenet import ImageNet
+ from .imagenet_a import ImageNetA
+ from .caltech256 import Caltech256
+ from .domainnet_real import DomainNetReal
+ from .synsigns import SYNSIGNS
+ from .gtsrb import GTSRB
+
+ from .cifar10_single import CIFAR10Single
+ from .stl10_single import STL10Single
+ from .mnist_single import MNISTSingle
+ from .usps_single import USPSSingle
+ from .svhn_single import SVHNSingle
+
+ from .baidu_person_cls import BaiduPersonCls
+ from .cityscapes_cls import CityscapesCls
+ from .gta5_cls import GTA5Cls
+ from .supervisely_person_cls import SuperviselyPersonCls
+ from .coco_cls import COCOCls
data/datasets/image_classification/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.18 kB)
 
data/datasets/image_classification/__pycache__/baidu_person_cls.cpython-38.pyc ADDED
Binary file (1.99 kB)