Upload 1905 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- cls_lora.png +0 -0
- cls_md_w_fbs_index.png +0 -0
- cls_md_wo_fbs.png +0 -0
- cls_online.png +0 -0
- data/__init__.py +2 -0
- data/__pycache__/__init__.cpython-38.pyc +0 -0
- data/__pycache__/dataloader.cpython-38.pyc +0 -0
- data/__pycache__/dataset.cpython-38.pyc +0 -0
- data/build/__pycache__/__init__.cpython-38.pyc +0 -0
- data/build/__pycache__/build.cpython-38.pyc +0 -0
- data/build/__pycache__/merge_alias.cpython-38.pyc +0 -0
- data/build/__pycache__/scenario.cpython-38.pyc +0 -0
- data/build_cl/__pycache__/build.cpython-38.pyc +0 -0
- data/build_cl/__pycache__/scenario.cpython-38.pyc +0 -0
- data/build_gen/__pycache__/build.cpython-38.pyc +0 -0
- data/build_gen/__pycache__/merge_alias.cpython-38.pyc +0 -0
- data/build_gen/__pycache__/scenario.cpython-38.pyc +0 -0
- data/build_gen/build.py +495 -0
- data/build_gen/merge_alias.py +106 -0
- data/build_gen/scenario.py +473 -0
- data/datasets/__init__.py +1 -0
- data/datasets/__pycache__/__init__.cpython-38.pyc +0 -0
- data/datasets/__pycache__/ab_dataset.cpython-38.pyc +0 -0
- data/datasets/__pycache__/data_aug.cpython-38.pyc +0 -0
- data/datasets/__pycache__/dataset_cache.cpython-38.pyc +0 -0
- data/datasets/__pycache__/dataset_split.cpython-38.pyc +0 -0
- data/datasets/__pycache__/registery.cpython-38.pyc +0 -0
- data/datasets/action_recognition/__pycache__/__init__.cpython-38.pyc +0 -0
- data/datasets/action_recognition/__pycache__/common_dataset.cpython-38.pyc +0 -0
- data/datasets/action_recognition/__pycache__/hmdb51.cpython-38.pyc +0 -0
- data/datasets/action_recognition/__pycache__/ixmas.cpython-38.pyc +0 -0
- data/datasets/action_recognition/__pycache__/ucf101.cpython-38.pyc +0 -0
- data/datasets/image_classification/__pycache__/__init__.cpython-38.pyc +0 -0
- data/datasets/image_classification/__pycache__/baidu_person_cls.cpython-38.pyc +0 -0
- data/datasets/image_classification/__pycache__/caltech256.cpython-38.pyc +0 -0
- data/datasets/image_classification/__pycache__/cifar10.cpython-38.pyc +0 -0
- data/datasets/image_classification/__pycache__/cifar10_single.cpython-38.pyc +0 -0
- data/datasets/image_classification/__pycache__/cityscapes_cls.cpython-38.pyc +0 -0
- data/datasets/image_classification/__pycache__/coco_cls.cpython-38.pyc +0 -0
- data/datasets/image_classification/__pycache__/domainnet_real.cpython-38.pyc +0 -0
- data/datasets/image_classification/__pycache__/emnist.cpython-38.pyc +0 -0
- data/datasets/image_classification/__pycache__/gta5_cls.cpython-38.pyc +0 -0
- data/datasets/image_classification/__pycache__/gtsrb.cpython-38.pyc +0 -0
- data/datasets/image_classification/__pycache__/imagenet.cpython-38.pyc +0 -0
- data/datasets/image_classification/__pycache__/imagenet_a.cpython-38.pyc +0 -0
- data/datasets/image_classification/__pycache__/mnist.cpython-38.pyc +0 -0
- data/datasets/image_classification/__pycache__/mnist_single.cpython-38.pyc +0 -0
- data/datasets/image_classification/__pycache__/stl10.cpython-38.pyc +0 -0
- data/datasets/image_classification/__pycache__/stl10_single.cpython-38.pyc +0 -0
- data/datasets/image_classification/__pycache__/supervisely_person_cls.cpython-38.pyc +0 -0
cls_lora.png
ADDED
cls_md_w_fbs_index.png
ADDED
cls_md_wo_fbs.png
ADDED
cls_online.png
ADDED
data/__init__.py
CHANGED
@@ -8,5 +8,7 @@ from .build.scenario import Scenario
|
|
8 |
from .build_cl.build import build_cl_scenario
|
9 |
from .build_cl.scenario import Scenario as CLScenario
|
10 |
|
|
|
|
|
11 |
|
12 |
from .datasets.dataset_split import split_dataset
|
|
|
8 |
from .build_cl.build import build_cl_scenario
|
9 |
from .build_cl.scenario import Scenario as CLScenario
|
10 |
|
11 |
+
from .build_gen.build import build_scenario_manually_v2 as build_gen_scenario
|
12 |
+
from .build_gen.scenario import Scenario as GenScenario
|
13 |
|
14 |
from .datasets.dataset_split import split_dataset
|
data/__pycache__/__init__.cpython-38.pyc
CHANGED
Binary files a/data/__pycache__/__init__.cpython-38.pyc and b/data/__pycache__/__init__.cpython-38.pyc differ
|
|
data/__pycache__/dataloader.cpython-38.pyc
CHANGED
Binary files a/data/__pycache__/dataloader.cpython-38.pyc and b/data/__pycache__/dataloader.cpython-38.pyc differ
|
|
data/__pycache__/dataset.cpython-38.pyc
CHANGED
Binary files a/data/__pycache__/dataset.cpython-38.pyc and b/data/__pycache__/dataset.cpython-38.pyc differ
|
|
data/build/__pycache__/__init__.cpython-38.pyc
CHANGED
Binary files a/data/build/__pycache__/__init__.cpython-38.pyc and b/data/build/__pycache__/__init__.cpython-38.pyc differ
|
|
data/build/__pycache__/build.cpython-38.pyc
CHANGED
Binary files a/data/build/__pycache__/build.cpython-38.pyc and b/data/build/__pycache__/build.cpython-38.pyc differ
|
|
data/build/__pycache__/merge_alias.cpython-38.pyc
CHANGED
Binary files a/data/build/__pycache__/merge_alias.cpython-38.pyc and b/data/build/__pycache__/merge_alias.cpython-38.pyc differ
|
|
data/build/__pycache__/scenario.cpython-38.pyc
CHANGED
Binary files a/data/build/__pycache__/scenario.cpython-38.pyc and b/data/build/__pycache__/scenario.cpython-38.pyc differ
|
|
data/build_cl/__pycache__/build.cpython-38.pyc
CHANGED
Binary files a/data/build_cl/__pycache__/build.cpython-38.pyc and b/data/build_cl/__pycache__/build.cpython-38.pyc differ
|
|
data/build_cl/__pycache__/scenario.cpython-38.pyc
CHANGED
Binary files a/data/build_cl/__pycache__/scenario.cpython-38.pyc and b/data/build_cl/__pycache__/scenario.cpython-38.pyc differ
|
|
data/build_gen/__pycache__/build.cpython-38.pyc
ADDED
Binary file (9.07 kB). View file
|
|
data/build_gen/__pycache__/merge_alias.cpython-38.pyc
ADDED
Binary file (2.5 kB). View file
|
|
data/build_gen/__pycache__/scenario.cpython-38.pyc
ADDED
Binary file (9.65 kB). View file
|
|
data/build_gen/build.py
ADDED
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Optional, Type, Union
|
2 |
+
from ..datasets.ab_dataset import ABDataset
|
3 |
+
# from benchmark.data.visualize import visualize_classes_in_object_detection
|
4 |
+
# from benchmark.scenario.val_domain_shift import get_val_domain_shift_transform
|
5 |
+
from ..dataset import get_dataset
|
6 |
+
import copy
|
7 |
+
from torchvision.transforms import Compose
|
8 |
+
|
9 |
+
from .merge_alias import merge_the_same_meaning_classes
|
10 |
+
from ..datasets.registery import static_dataset_registery
|
11 |
+
|
12 |
+
|
13 |
+
# some legacy aliases of variables:
|
14 |
+
# ignore_classes == discarded classes
|
15 |
+
# private_classes == unknown classes in partial / open-set / universal DA
|
16 |
+
|
17 |
+
|
18 |
+
def _merge_the_same_meaning_classes(classes_info_of_all_datasets):
|
19 |
+
final_classes_of_all_datasets, rename_map = merge_the_same_meaning_classes(classes_info_of_all_datasets)
|
20 |
+
return final_classes_of_all_datasets, rename_map
|
21 |
+
|
22 |
+
|
23 |
+
def _find_ignore_classes_when_sources_as_to_target_b(as_classes: List[List[str]], b_classes: List[str], da_mode):
|
24 |
+
thres = {'da': 3, 'partial_da': 2, 'open_set_da': 1, 'universal_da': 0}[da_mode]
|
25 |
+
|
26 |
+
from functools import reduce
|
27 |
+
a_classes = reduce(lambda res, cur: res | set(cur), as_classes, set())
|
28 |
+
|
29 |
+
if set(a_classes) == set(b_classes):
|
30 |
+
# a is equal to b, normal
|
31 |
+
# 1. no ignore classes; 2. match class idx
|
32 |
+
a_ignore_classes, b_ignore_classes = [], []
|
33 |
+
|
34 |
+
elif set(a_classes) > set(b_classes):
|
35 |
+
# a contains b, partial
|
36 |
+
a_ignore_classes, b_ignore_classes = [], []
|
37 |
+
if thres == 3 or thres == 1: # ignore extra classes in a
|
38 |
+
a_ignore_classes = set(a_classes) - set(b_classes)
|
39 |
+
|
40 |
+
elif set(a_classes) < set(b_classes):
|
41 |
+
# a is contained by b, open set
|
42 |
+
a_ignore_classes, b_ignore_classes = [], []
|
43 |
+
if thres == 3 or thres == 2: # ignore extra classes in b
|
44 |
+
b_ignore_classes = set(b_classes) - set(a_classes)
|
45 |
+
|
46 |
+
elif len(set(a_classes) & set(b_classes)) > 0:
|
47 |
+
a_ignore_classes, b_ignore_classes = [], []
|
48 |
+
if thres == 3:
|
49 |
+
a_ignore_classes = set(a_classes) - (set(a_classes) & set(b_classes))
|
50 |
+
b_ignore_classes = set(b_classes) - (set(a_classes) & set(b_classes))
|
51 |
+
elif thres == 2:
|
52 |
+
b_ignore_classes = set(b_classes) - (set(a_classes) & set(b_classes))
|
53 |
+
elif thres == 1:
|
54 |
+
a_ignore_classes = set(a_classes) - (set(a_classes) & set(b_classes))
|
55 |
+
|
56 |
+
else:
|
57 |
+
return None # a has no intersection with b, none
|
58 |
+
|
59 |
+
as_ignore_classes = [list(set(a_classes) & set(a_ignore_classes)) for a_classes in as_classes]
|
60 |
+
|
61 |
+
return as_ignore_classes, list(b_ignore_classes)
|
62 |
+
|
63 |
+
|
64 |
+
def _find_private_classes_when_sources_as_to_target_b(as_classes: List[List[str]], b_classes: List[str], da_mode):
|
65 |
+
thres = {'da': 3, 'partial_da': 2, 'open_set_da': 1, 'universal_da': 0}[da_mode]
|
66 |
+
|
67 |
+
from functools import reduce
|
68 |
+
a_classes = reduce(lambda res, cur: res | set(cur), as_classes, set())
|
69 |
+
|
70 |
+
if set(a_classes) == set(b_classes):
|
71 |
+
# a is equal to b, normal
|
72 |
+
# 1. no ignore classes; 2. match class idx
|
73 |
+
a_private_classes, b_private_classes = [], []
|
74 |
+
|
75 |
+
elif set(a_classes) > set(b_classes):
|
76 |
+
# a contains b, partial
|
77 |
+
a_private_classes, b_private_classes = [], []
|
78 |
+
# if thres == 2 or thres == 0: # ignore extra classes in a
|
79 |
+
# a_private_classes = set(a_classes) - set(b_classes)
|
80 |
+
# if thres == 0: # ignore extra classes in a
|
81 |
+
# a_private_classes = set(a_classes) - set(b_classes)
|
82 |
+
|
83 |
+
elif set(a_classes) < set(b_classes):
|
84 |
+
# a is contained by b, open set
|
85 |
+
a_private_classes, b_private_classes = [], []
|
86 |
+
if thres == 1 or thres == 0: # ignore extra classes in b
|
87 |
+
b_private_classes = set(b_classes) - set(a_classes)
|
88 |
+
|
89 |
+
elif len(set(a_classes) & set(b_classes)) > 0:
|
90 |
+
a_private_classes, b_private_classes = [], []
|
91 |
+
if thres == 0:
|
92 |
+
# a_private_classes = set(a_classes) - (set(a_classes) & set(b_classes))
|
93 |
+
|
94 |
+
b_private_classes = set(b_classes) - (set(a_classes) & set(b_classes))
|
95 |
+
elif thres == 1:
|
96 |
+
b_private_classes = set(b_classes) - (set(a_classes) & set(b_classes))
|
97 |
+
elif thres == 2:
|
98 |
+
# a_private_classes = set(a_classes) - (set(a_classes) & set(b_classes))
|
99 |
+
pass
|
100 |
+
|
101 |
+
else:
|
102 |
+
return None # a has no intersection with b, none
|
103 |
+
|
104 |
+
return list(b_private_classes)
|
105 |
+
|
106 |
+
|
107 |
+
class _ABDatasetMetaInfo:
|
108 |
+
def __init__(self, name, classes, task_type, object_type, class_aliases, shift_type):
|
109 |
+
self.name = name
|
110 |
+
self.classes = classes
|
111 |
+
self.class_aliases = class_aliases
|
112 |
+
self.shift_type = shift_type
|
113 |
+
self.task_type = task_type
|
114 |
+
self.object_type = object_type
|
115 |
+
|
116 |
+
|
117 |
+
def _get_dist_shift_type_when_source_a_to_target_b(a: _ABDatasetMetaInfo, b: _ABDatasetMetaInfo):
|
118 |
+
if b.shift_type is None:
|
119 |
+
return 'Dataset Shifts'
|
120 |
+
|
121 |
+
if a.name in b.shift_type.keys():
|
122 |
+
return b.shift_type[a.name]
|
123 |
+
|
124 |
+
mid_dataset_name = list(b.shift_type.keys())[0]
|
125 |
+
mid_dataset_meta_info = _ABDatasetMetaInfo(mid_dataset_name, *static_dataset_registery[mid_dataset_name][1:])
|
126 |
+
|
127 |
+
return _get_dist_shift_type_when_source_a_to_target_b(a, mid_dataset_meta_info) + ' + ' + list(b.shift_type.values())[0]
|
128 |
+
|
129 |
+
|
130 |
+
def _handle_all_datasets_v2(source_datasets: List[_ABDatasetMetaInfo], target_datasets: List[_ABDatasetMetaInfo], da_mode):
|
131 |
+
|
132 |
+
# 1. merge the same meaning classes
|
133 |
+
classes_info_of_all_datasets = {
|
134 |
+
d.name: (d.classes, d.class_aliases)
|
135 |
+
for d in source_datasets + target_datasets
|
136 |
+
}
|
137 |
+
final_classes_of_all_datasets, rename_map = _merge_the_same_meaning_classes(classes_info_of_all_datasets)
|
138 |
+
all_datasets_classes = copy.deepcopy(final_classes_of_all_datasets)
|
139 |
+
|
140 |
+
# print(all_datasets_known_classes)
|
141 |
+
|
142 |
+
# 2. find ignored classes according to DA mode
|
143 |
+
# source_datasets_ignore_classes, target_datasets_ignore_classes = {d.name: [] for d in source_datasets}, \
|
144 |
+
# {d.name: [] for d in target_datasets}
|
145 |
+
# source_datasets_private_classes, target_datasets_private_classes = {d.name: [] for d in source_datasets}, \
|
146 |
+
# {d.name: [] for d in target_datasets}
|
147 |
+
target_source_relationship_map = {td.name: {} for td in target_datasets}
|
148 |
+
# source_target_relationship_map = {sd.name: [] for sd in source_datasets}
|
149 |
+
|
150 |
+
# 1. construct target_source_relationship_map
|
151 |
+
for sd in source_datasets:#sd和td使列表中每一个元素(类)的实例
|
152 |
+
for td in target_datasets:
|
153 |
+
sc = all_datasets_classes[sd.name]
|
154 |
+
tc = all_datasets_classes[td.name]
|
155 |
+
|
156 |
+
if len(set(sc) & set(tc)) == 0:#只保留有相似类别的源域和目标域
|
157 |
+
continue
|
158 |
+
|
159 |
+
target_source_relationship_map[td.name][sd.name] = _get_dist_shift_type_when_source_a_to_target_b(sd, td)
|
160 |
+
|
161 |
+
# print(target_source_relationship_map)
|
162 |
+
# exit()
|
163 |
+
|
164 |
+
source_datasets_ignore_classes = {}
|
165 |
+
for td_name, v1 in target_source_relationship_map.items():
|
166 |
+
for sd_name, v2 in v1.items():
|
167 |
+
source_datasets_ignore_classes[sd_name + '|' + td_name] = []
|
168 |
+
target_datasets_ignore_classes = {d.name: [] for d in target_datasets}
|
169 |
+
target_datasets_private_classes = {d.name: [] for d in target_datasets}
|
170 |
+
# 保证对于每个目标域上的DA都符合给定的label shift
|
171 |
+
# 所以不同目标域就算对应同一个源域,该源域也可能不相同
|
172 |
+
|
173 |
+
for td_name, v1 in target_source_relationship_map.items():
|
174 |
+
sd_names = list(v1.keys())
|
175 |
+
|
176 |
+
sds_classes = [all_datasets_classes[sd_name] for sd_name in sd_names]
|
177 |
+
td_classes = all_datasets_classes[td_name]
|
178 |
+
ss_ignore_classes, t_ignore_classes = _find_ignore_classes_when_sources_as_to_target_b(sds_classes, td_classes, da_mode)#根据DA方式不同产生ignore_classes
|
179 |
+
t_private_classes = _find_private_classes_when_sources_as_to_target_b(sds_classes, td_classes, da_mode)
|
180 |
+
|
181 |
+
for sd_name, s_ignore_classes in zip(sd_names, ss_ignore_classes):
|
182 |
+
source_datasets_ignore_classes[sd_name + '|' + td_name] = s_ignore_classes
|
183 |
+
target_datasets_ignore_classes[td_name] = t_ignore_classes
|
184 |
+
target_datasets_private_classes[td_name] = t_private_classes
|
185 |
+
|
186 |
+
source_datasets_ignore_classes = {k: sorted(set(v), key=v.index) for k, v in source_datasets_ignore_classes.items()}
|
187 |
+
target_datasets_ignore_classes = {k: sorted(set(v), key=v.index) for k, v in target_datasets_ignore_classes.items()}
|
188 |
+
target_datasets_private_classes = {k: sorted(set(v), key=v.index) for k, v in target_datasets_private_classes.items()}
|
189 |
+
|
190 |
+
# for k, v in source_datasets_ignore_classes.items():
|
191 |
+
# print(k, len(v))
|
192 |
+
# print()
|
193 |
+
# for k, v in target_datasets_ignore_classes.items():
|
194 |
+
# print(k, len(v))
|
195 |
+
# print()
|
196 |
+
# for k, v in target_datasets_private_classes.items():
|
197 |
+
# print(k, len(v))
|
198 |
+
# print()
|
199 |
+
|
200 |
+
# print(source_datasets_private_classes, target_datasets_private_classes)
|
201 |
+
# 3. reparse classes idx
|
202 |
+
# 3.1. agg all used classes
|
203 |
+
# all_used_classes = []
|
204 |
+
# all_datasets_private_class_idx_map = {}
|
205 |
+
|
206 |
+
# source_datasets_classes_idx_map = {}
|
207 |
+
# for td_name, v1 in target_source_relationship_map.items():
|
208 |
+
# for sd_name, v2 in v1.items():
|
209 |
+
# source_datasets_classes_idx_map[sd_name + '|' + td_name] = []
|
210 |
+
# target_datasets_classes_idx_map = {}
|
211 |
+
|
212 |
+
global_idx = 0
|
213 |
+
all_used_classes_idx_map = {}
|
214 |
+
# all_datasets_known_classes = {d: [] for d in final_classes_of_all_datasets.keys()}
|
215 |
+
for dataset_name, classes in all_datasets_classes.items():
|
216 |
+
if dataset_name not in target_datasets_ignore_classes.keys():
|
217 |
+
ignore_classes = [0] * 100000
|
218 |
+
for sn, sic in source_datasets_ignore_classes.items():
|
219 |
+
if sn.startswith(dataset_name):
|
220 |
+
if len(sic) < len(ignore_classes):
|
221 |
+
ignore_classes = sic
|
222 |
+
else:
|
223 |
+
ignore_classes = target_datasets_ignore_classes[dataset_name]
|
224 |
+
private_classes = [] \
|
225 |
+
if dataset_name not in target_datasets_ignore_classes.keys() else target_datasets_private_classes[dataset_name]
|
226 |
+
|
227 |
+
for c in classes:
|
228 |
+
if c not in ignore_classes and c not in all_used_classes_idx_map.keys() and c not in private_classes:
|
229 |
+
all_used_classes_idx_map[c] = global_idx
|
230 |
+
global_idx += 1
|
231 |
+
|
232 |
+
# print(all_used_classes_idx_map)
|
233 |
+
|
234 |
+
# dataset_private_class_idx_offset = 0
|
235 |
+
target_private_class_idx = global_idx
|
236 |
+
target_datasets_private_class_idx = {d: None for d in target_datasets_private_classes.keys()}
|
237 |
+
|
238 |
+
for dataset_name, classes in final_classes_of_all_datasets.items():
|
239 |
+
if dataset_name not in target_datasets_private_classes.keys():
|
240 |
+
continue
|
241 |
+
|
242 |
+
# ignore_classes = target_datasets_ignore_classes[dataset_name]
|
243 |
+
private_classes = target_datasets_private_classes[dataset_name]
|
244 |
+
# private_classes = [] \
|
245 |
+
# if dataset_name in source_datasets_private_classes.keys() else target_datasets_private_classes[dataset_name]
|
246 |
+
# for c in classes:
|
247 |
+
# if c not in ignore_classes and c not in all_used_classes_idx_map.keys() and c in private_classes:
|
248 |
+
# all_used_classes_idx_map[c] = global_idx + dataset_private_class_idx_offset
|
249 |
+
|
250 |
+
if len(private_classes) > 0:
|
251 |
+
# all_datasets_private_class_idx[dataset_name] = global_idx + dataset_private_class_idx_offset
|
252 |
+
# dataset_private_class_idx_offset += 1
|
253 |
+
# if dataset_name in source_datasets_private_classes.keys():
|
254 |
+
# if source_private_class_idx is None:
|
255 |
+
# source_private_class_idx = global_idx if target_private_class_idx is None else target_private_class_idx + 1
|
256 |
+
# all_datasets_private_class_idx[dataset_name] = source_private_class_idx
|
257 |
+
# else:
|
258 |
+
# if target_private_class_idx is None:
|
259 |
+
# target_private_class_idx = global_idx if source_private_class_idx is None else source_private_class_idx + 1
|
260 |
+
# all_datasets_private_class_idx[dataset_name] = target_private_class_idx
|
261 |
+
target_datasets_private_class_idx[dataset_name] = target_private_class_idx
|
262 |
+
target_private_class_idx += 1
|
263 |
+
|
264 |
+
|
265 |
+
# all_used_classes = sorted(set(all_used_classes), key=all_used_classes.index)
|
266 |
+
# all_used_classes_idx_map = {c: i for i, c in enumerate(all_used_classes)}
|
267 |
+
|
268 |
+
# print('rename_map', rename_map)
|
269 |
+
|
270 |
+
# 3.2 raw_class -> rename_map[raw_classes] -> all_used_classes_idx_map
|
271 |
+
all_datasets_e2e_idx_map = {}
|
272 |
+
all_datasets_e2e_class_to_idx_map = {}
|
273 |
+
|
274 |
+
for td_name, v1 in target_source_relationship_map.items():
|
275 |
+
sd_names = list(v1.keys())
|
276 |
+
sds_classes = [all_datasets_classes[sd_name] for sd_name in sd_names]
|
277 |
+
td_classes = all_datasets_classes[td_name]
|
278 |
+
|
279 |
+
for sd_name, sd_classes in zip(sd_names, sds_classes):
|
280 |
+
cur_e2e_idx_map = {}
|
281 |
+
cur_e2e_class_to_idx_map = {}
|
282 |
+
|
283 |
+
for raw_ci, raw_c in enumerate(sd_classes):
|
284 |
+
renamed_c = raw_c if raw_c not in rename_map[dataset_name] else rename_map[dataset_name][raw_c]
|
285 |
+
|
286 |
+
ignore_classes = source_datasets_ignore_classes[sd_name + '|' + td_name]
|
287 |
+
if renamed_c in ignore_classes:
|
288 |
+
continue
|
289 |
+
|
290 |
+
idx = all_used_classes_idx_map[renamed_c]
|
291 |
+
|
292 |
+
cur_e2e_idx_map[raw_ci] = idx
|
293 |
+
cur_e2e_class_to_idx_map[raw_c] = idx
|
294 |
+
|
295 |
+
all_datasets_e2e_idx_map[sd_name + '|' + td_name] = cur_e2e_idx_map
|
296 |
+
all_datasets_e2e_class_to_idx_map[sd_name + '|' + td_name] = cur_e2e_class_to_idx_map
|
297 |
+
cur_e2e_idx_map = {}
|
298 |
+
cur_e2e_class_to_idx_map = {}
|
299 |
+
for raw_ci, raw_c in enumerate(td_classes):
|
300 |
+
renamed_c = raw_c if raw_c not in rename_map[dataset_name] else rename_map[dataset_name][raw_c]
|
301 |
+
|
302 |
+
ignore_classes = target_datasets_ignore_classes[td_name]
|
303 |
+
if renamed_c in ignore_classes:
|
304 |
+
continue
|
305 |
+
|
306 |
+
if renamed_c in target_datasets_private_classes[td_name]:
|
307 |
+
idx = target_datasets_private_class_idx[td_name]
|
308 |
+
else:
|
309 |
+
idx = all_used_classes_idx_map[renamed_c]
|
310 |
+
|
311 |
+
cur_e2e_idx_map[raw_ci] = idx
|
312 |
+
cur_e2e_class_to_idx_map[raw_c] = idx
|
313 |
+
|
314 |
+
all_datasets_e2e_idx_map[td_name] = cur_e2e_idx_map
|
315 |
+
all_datasets_e2e_class_to_idx_map[td_name] = cur_e2e_class_to_idx_map
|
316 |
+
|
317 |
+
all_datasets_ignore_classes = {**source_datasets_ignore_classes, **target_datasets_ignore_classes}
|
318 |
+
# all_datasets_private_classes = {**source_datasets_private_classes, **target_datasets_private_classes}
|
319 |
+
|
320 |
+
classes_idx_set = []
|
321 |
+
for d, m in all_datasets_e2e_class_to_idx_map.items():
|
322 |
+
classes_idx_set += list(m.values())
|
323 |
+
classes_idx_set = set(classes_idx_set)
|
324 |
+
num_classes = len(classes_idx_set)
|
325 |
+
|
326 |
+
return all_datasets_ignore_classes, target_datasets_private_classes, \
|
327 |
+
all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
|
328 |
+
target_source_relationship_map, rename_map, num_classes
|
329 |
+
|
330 |
+
|
331 |
+
def _build_scenario_info_v2(
|
332 |
+
source_datasets_name: List[str],
|
333 |
+
target_datasets_order: List[str],
|
334 |
+
da_mode: str
|
335 |
+
):
|
336 |
+
assert da_mode in ['close_set', 'partial', 'open_set', 'universal']
|
337 |
+
da_mode = {'close_set': 'da', 'partial': 'partial_da', 'open_set': 'open_set_da', 'universal': 'universal_da'}[da_mode]
|
338 |
+
|
339 |
+
source_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:]) for d in source_datasets_name]#获知对应的名字和对应属性,要添加数据集时,直接register就行
|
340 |
+
target_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:]) for d in list(set(target_datasets_order))]
|
341 |
+
|
342 |
+
all_datasets_ignore_classes, target_datasets_private_classes, \
|
343 |
+
all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
|
344 |
+
target_source_relationship_map, rename_map, num_classes \
|
345 |
+
= _handle_all_datasets_v2(source_datasets_meta_info, target_datasets_meta_info, da_mode)
|
346 |
+
|
347 |
+
return all_datasets_ignore_classes, target_datasets_private_classes, \
|
348 |
+
all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
|
349 |
+
target_source_relationship_map, rename_map, num_classes
|
350 |
+
|
351 |
+
|
352 |
+
def build_scenario_manually_v2(
|
353 |
+
source_datasets_name: List[str],
|
354 |
+
target_datasets_order: List[str],
|
355 |
+
da_mode: str,
|
356 |
+
data_dirs: Dict[str, str],
|
357 |
+
# transforms: Optional[Dict[str, Compose]] = None
|
358 |
+
):
|
359 |
+
configs = copy.deepcopy(locals())#返回当前局部变量
|
360 |
+
|
361 |
+
source_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:]) for d in source_datasets_name]
|
362 |
+
target_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:]) for d in list(set(target_datasets_order))]
|
363 |
+
|
364 |
+
all_datasets_ignore_classes, target_datasets_private_classes, \
|
365 |
+
all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
|
366 |
+
target_source_relationship_map, rename_map, num_classes \
|
367 |
+
= _build_scenario_info_v2(source_datasets_name, target_datasets_order, da_mode)
|
368 |
+
# from rich.console import Console
|
369 |
+
# console = Console(width=10000)
|
370 |
+
|
371 |
+
# def print_obj(_o):
|
372 |
+
# # import pprint
|
373 |
+
# # s = pprint.pformat(_o, width=140, compact=True)
|
374 |
+
# console.print(_o)
|
375 |
+
|
376 |
+
# console.print('configs:', style='bold red')
|
377 |
+
# print_obj(configs)
|
378 |
+
# console.print('renamed classes:', style='bold red')
|
379 |
+
# print_obj(rename_map)
|
380 |
+
# console.print('discarded classes:', style='bold red')
|
381 |
+
# print_obj(all_datasets_ignore_classes)
|
382 |
+
# console.print('unknown classes:', style='bold red')
|
383 |
+
# print_obj(target_datasets_private_classes)
|
384 |
+
# console.print('class to index map:', style='bold red')
|
385 |
+
# print_obj(all_datasets_e2e_class_to_idx_map)
|
386 |
+
# console.print('index map:', style='bold red')
|
387 |
+
# print_obj(all_datasets_e2e_idx_map)
|
388 |
+
# console = Console()
|
389 |
+
# # console.print('class distribution:', style='bold red')
|
390 |
+
# # class_dist = {
|
391 |
+
# # k: {
|
392 |
+
# # '#known classes': len(all_datasets_known_classes[k]),
|
393 |
+
# # '#unknown classes': len(all_datasets_private_classes[k]),
|
394 |
+
# # '#discarded classes': len(all_datasets_ignore_classes[k])
|
395 |
+
# # } for k in all_datasets_ignore_classes.keys()
|
396 |
+
# # }
|
397 |
+
# # print_obj(class_dist)
|
398 |
+
# console.print('corresponding sources of each target:', style='bold red')
|
399 |
+
# print_obj(target_source_relationship_map)
|
400 |
+
|
401 |
+
# return
|
402 |
+
|
403 |
+
# res_source_datasets_map = {d: {split: get_dataset(d, data_dirs[d], split, getattr(transforms, d, None),
|
404 |
+
# all_datasets_ignore_classes[d], all_datasets_e2e_idx_map[d])
|
405 |
+
# for split in ['train', 'val', 'test']}
|
406 |
+
# for d in source_datasets_name}
|
407 |
+
# res_target_datasets_map = {d: {'train': get_num_limited_dataset(get_dataset(d, data_dirs[d], 'test', getattr(transforms, d, None),
|
408 |
+
# all_datasets_ignore_classes[d], all_datasets_e2e_idx_map[d]),
|
409 |
+
# num_samples_in_each_target_domain),
|
410 |
+
# 'test': get_dataset(d, data_dirs[d], 'test', getattr(transforms, d, None),
|
411 |
+
# all_datasets_ignore_classes[d], all_datasets_e2e_idx_map[d])
|
412 |
+
# }
|
413 |
+
# for d in list(set(target_datasets_order))}
|
414 |
+
|
415 |
+
# res_source_datasets_map = {d: {split: get_dataset(d.split('|')[0], data_dirs[d.split('|')[0]], split,
|
416 |
+
# getattr(transforms, d.split('|')[0], None),
|
417 |
+
# all_datasets_ignore_classes[d], all_datasets_e2e_idx_map[d])
|
418 |
+
# for split in ['train', 'val', 'test']}
|
419 |
+
# for d in all_datasets_ignore_classes.keys() if d.split('|')[0] in source_datasets_name}
|
420 |
+
|
421 |
+
# from functools import reduce
|
422 |
+
# res_offline_train_source_datasets_map = {}
|
423 |
+
# res_offline_train_source_datasets_map_names = {}
|
424 |
+
|
425 |
+
# for d in source_datasets_name:
|
426 |
+
# source_dataset_with_max_num_classes = None
|
427 |
+
|
428 |
+
# for ed_name, ed in res_source_datasets_map.items():
|
429 |
+
# if not ed_name.startswith(d):
|
430 |
+
# continue
|
431 |
+
|
432 |
+
# if source_dataset_with_max_num_classes is None:
|
433 |
+
# source_dataset_with_max_num_classes = ed
|
434 |
+
# res_offline_train_source_datasets_map_names[d] = ed_name
|
435 |
+
|
436 |
+
# if len(ed['train'].ignore_classes) < len(source_dataset_with_max_num_classes['train'].ignore_classes):
|
437 |
+
# source_dataset_with_max_num_classes = ed
|
438 |
+
# res_offline_train_source_datasets_map_names[d] = ed_name
|
439 |
+
|
440 |
+
# res_offline_train_source_datasets_map[d] = source_dataset_with_max_num_classes
|
441 |
+
|
442 |
+
# res_target_datasets_map = {d: {split: get_dataset(d, data_dirs[d], split, getattr(transforms, d, None),
|
443 |
+
# all_datasets_ignore_classes[d], all_datasets_e2e_idx_map[d])
|
444 |
+
# for split in ['train', 'val', 'test']}
|
445 |
+
# for d in list(set(target_datasets_order))}
|
446 |
+
|
447 |
+
from .scenario import Scenario, DatasetMetaInfo
|
448 |
+
|
449 |
+
# test_scenario = Scenario(
|
450 |
+
# config=configs,
|
451 |
+
# offline_source_datasets_meta_info={
|
452 |
+
# d: DatasetMetaInfo(d,
|
453 |
+
# {k: v for k, v in all_datasets_e2e_class_to_idx_map[res_offline_train_source_datasets_map_names[d]].items()},
|
454 |
+
# None)
|
455 |
+
# for d in source_datasets_name
|
456 |
+
# },
|
457 |
+
# offline_source_datasets={d: res_offline_train_source_datasets_map[d] for d in source_datasets_name},
|
458 |
+
|
459 |
+
# online_datasets_meta_info=[
|
460 |
+
# (
|
461 |
+
# {sd + '|' + d: DatasetMetaInfo(d,
|
462 |
+
# {k: v for k, v in all_datasets_e2e_class_to_idx_map[sd + '|' + d].items()},
|
463 |
+
# None)
|
464 |
+
# for sd in target_source_relationship_map[d].keys()},
|
465 |
+
# DatasetMetaInfo(d,
|
466 |
+
# {k: v for k, v in all_datasets_e2e_class_to_idx_map[d].items() if k not in target_datasets_private_classes[d]},
|
467 |
+
# target_datasets_private_class_idx[d])
|
468 |
+
# )
|
469 |
+
# for d in target_datasets_order
|
470 |
+
# ],
|
471 |
+
# online_datasets={**res_source_datasets_map, **res_target_datasets_map},
|
472 |
+
# target_domains_order=target_datasets_order,
|
473 |
+
# target_source_map=target_source_relationship_map,
|
474 |
+
# num_classes=num_classes
|
475 |
+
# )
|
476 |
+
import os
|
477 |
+
os.environ['_ZQL_NUMC'] = str(num_classes)
|
478 |
+
|
479 |
+
test_scenario = Scenario(config=configs, all_datasets_ignore_classes_map=all_datasets_ignore_classes,
|
480 |
+
all_datasets_idx_map=all_datasets_e2e_idx_map,
|
481 |
+
target_domains_order=target_datasets_order,
|
482 |
+
target_source_map=target_source_relationship_map,
|
483 |
+
all_datasets_e2e_class_to_idx_map=all_datasets_e2e_class_to_idx_map,
|
484 |
+
num_classes=num_classes)
|
485 |
+
|
486 |
+
|
487 |
+
return test_scenario
|
488 |
+
|
489 |
+
|
490 |
+
if __name__ == '__main__':
    # Smoke test: build a scenario from 2 source datasets and a 6-domain
    # target stream ('close_set' is presumably the DA mode — confirm against
    # build_scenario_manually_v2's signature), then print the unioned class count.
    test_scenario = build_scenario_manually_v2(['CIFAR10', 'SVHN'],
                                               ['STL10', 'MNIST', 'STL10', 'USPS', 'MNIST', 'STL10'],
                                               'close_set')
    print(test_scenario.num_classes)
|
495 |
+
|
data/build_gen/merge_alias.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from re import L
|
2 |
+
from typing import Dict, List
|
3 |
+
from collections import Counter
|
4 |
+
|
5 |
+
|
6 |
+
def grouping(bondlist):
    # reference: https://blog.csdn.net/YnagShanwen/article/details/111344386
    # Merge the lists in `bondlist` into connected components: two lists end up
    # in the same group when they (transitively) share at least one element.
    # Each returned group is the concatenation of its member lists, duplicates
    # kept (callers rely on the duplicates for Counter-based voting).
    # NOTE: `bondlist` is consumed (emptied) by this function.
    groups = []
    break1 = False  # set once bondlist is emptied mid-merge, to bail out of the scan
    while bondlist:
        # seed a new group with the next unprocessed list
        pair1 = bondlist.pop(0)
        # fixed-point loop: keep absorbing overlapping lists until the group's
        # length stops changing between passes (a = previous len, b = current)
        a = 11111
        b = 10000
        while b != a:
            a = b
            for atomid in pair1:
                for i,pair2 in enumerate(bondlist):
                    if atomid in pair2:
                        pair1 = pair1 + pair2
                        # NOTE(review): popping while enumerating skips the next
                        # element in this pass; the enclosing fixed-point loop
                        # re-scans, so missed overlaps are still picked up —
                        # confirm this is intended.
                        bondlist.pop(i)
                        if not bondlist:
                            break1 = True
                            break
                if break1:
                    break
            b = len(pair1)
        groups.append(pair1)
    return groups
|
29 |
+
|
30 |
+
|
31 |
+
def build_semantic_class_info(classes: List[str], aliases: List[List[str]]):
    """Wrap every class name in its alias group.

    For each name in `classes`, return the first alias group (a list of
    synonymous names) that contains it; names with no alias group become
    singleton groups ``[name]``. The returned list is parallel to `classes`
    and re-uses the alias list objects themselves.
    """
    def _group_of(class_name):
        # first alias group containing the name wins; fall back to a singleton
        for alias_group in aliases:
            if class_name in alias_group:
                return alias_group
        return [class_name]

    return [_group_of(c) for c in classes]
|
48 |
+
|
49 |
+
|
50 |
+
def merge_the_same_meaning_classes(classes_info_of_all_datasets):
    """Unify class names that mean the same thing across datasets.

    `classes_info_of_all_datasets` maps dataset name -> (classes, aliases),
    where `aliases` is a list of synonym groups. Returns a pair:
    - dict: dataset name -> its class list rewritten with canonical names
      (de-duplicated, first-occurrence order preserved);
    - dict: dataset name -> {original name: canonical name} for every class
      that was actually renamed.
    """
    # 1. pool the alias tables of every dataset
    all_aliases = []
    for _, aliases in classes_info_of_all_datasets.values():
        all_aliases += aliases

    # 2. wrap every class of every dataset in its alias group, then merge
    #    overlapping groups into connected components
    semantic_groups = []
    for classes, _ in classes_info_of_all_datasets.values():
        semantic_groups += build_semantic_class_info(classes, all_aliases)
    merged_groups = grouping(semantic_groups)

    # 3. pick one canonical name per component: the most frequent member,
    #    ties broken by the shortest name
    canonical_names = []
    for group in merged_groups:
        ranked = Counter(group).most_common()
        top_count = ranked[0][1]
        tied = [name for name, count in ranked if count == top_count]
        canonical_names.append(min(tied, key=len))

    # 4. rewrite each dataset's class list with the canonical names
    res = {}
    res_map = {d: {} for d in classes_info_of_all_datasets.keys()}
    for dataset_name, (classes, _) in classes_info_of_all_datasets.items():
        renamed = []
        for class_name in classes:
            for members, canonical in zip(merged_groups, canonical_names):
                if class_name in members:
                    renamed.append(canonical)
                    if canonical != class_name:
                        res_map[dataset_name][class_name] = canonical
                    break
        # de-duplicate while keeping first-occurrence order
        res[dataset_name] = sorted(set(renamed), key=renamed.index)
    return res, res_map
|
94 |
+
|
95 |
+
|
96 |
+
if __name__ == '__main__':
    # Smoke test: CIFAR10's 'automobile' and STL10's 'car' should be merged
    # into one canonical class via the alias table.
    cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    cifar10_aliases = [['automobile', 'car']]
    stl10_classes = ['airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey', 'ship', 'truck']

    final_classes_of_all_datasets, rename_map = merge_the_same_meaning_classes({
        'CIFAR10': (cifar10_classes, cifar10_aliases),
        'STL10': (stl10_classes, [])
    })

    print(final_classes_of_all_datasets, rename_map)
|
data/build_gen/scenario.py
ADDED
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import enum
|
2 |
+
from functools import reduce
|
3 |
+
from typing import Dict, List, Tuple
|
4 |
+
import numpy as np
|
5 |
+
import copy
|
6 |
+
from utils.common.log import logger
|
7 |
+
from ..datasets.ab_dataset import ABDataset
|
8 |
+
from ..datasets.dataset_split import train_val_split
|
9 |
+
from ..dataloader import FastDataLoader, InfiniteDataLoader, build_dataloader
|
10 |
+
from data import get_dataset
|
11 |
+
|
12 |
+
|
13 |
+
class DatasetMetaInfo:
    """Metadata of one dataset inside a scenario.

    Holds the dataset name, a map from known class names to their
    end-to-end class indices, and the (optional) index reserved for the
    "unknown" class.
    """

    def __init__(self, name,
                 known_classes_name_idx_map, unknown_class_idx):
        # The unknown-class index must not collide with any known class index.
        # FIX: the original checked `not in ....keys()`, i.e. compared an
        # integer index against class *names*, which made the assert vacuous.
        assert unknown_class_idx is None or \
            unknown_class_idx not in known_classes_name_idx_map.values()

        self.name = name
        self.unknown_class_idx = unknown_class_idx
        self.known_classes_name_idx_map = known_classes_name_idx_map

    @property
    def num_classes(self):
        """Number of known classes plus one slot for the unknown class.

        FIX: the original read the non-existent attribute
        `self.known_classes_idx` and raised AttributeError.
        """
        return len(self.known_classes_name_idx_map) + 1
|
26 |
+
|
27 |
+
|
28 |
+
class MergedDataset:
    """A concatenation view over several datasets.

    Index 0..len-1 walks the member datasets in order; the per-dataset
    lengths are fixed at construction time.
    """

    def __init__(self, datasets: List[ABDataset]):
        self.datasets = datasets
        self.datasets_len = [len(d) for d in datasets]
        logger.info(f'create MergedDataset: len of datasets {self.datasets_len}')
        self.datasets_cum_len = np.cumsum(self.datasets_len)

    def __getitem__(self, idx):
        # Locate the member dataset whose cumulative range contains idx and
        # index into it with the local offset. An out-of-range idx falls
        # through and yields None, as before.
        offset = 0
        for dataset, cum_len in zip(self.datasets, self.datasets_cum_len):
            if idx < cum_len:
                return dataset[idx - offset]
            offset = cum_len

    def __len__(self):
        return sum(self.datasets_len)
|
42 |
+
|
43 |
+
|
44 |
+
class IndexReturnedDataset:
    """Wrapper that appends the sample index to each item of a dataset."""

    def __init__(self, dataset: ABDataset):
        self.dataset = dataset

    def __getitem__(self, idx):
        item = self.dataset[idx]

        # tuple/list samples are flattened: (x, y) -> (x, y, idx);
        # any other sample becomes a pair (sample, idx)
        if not isinstance(item, (tuple, list)):
            return item, idx
        return (*item, idx)

    def __len__(self):
        return len(self.dataset)
|
58 |
+
|
59 |
+
|
60 |
+
# class Scenario:
|
61 |
+
# def __init__(self, config,
|
62 |
+
# source_datasets_meta_info: Dict[str, DatasetMetaInfo], target_datasets_meta_info: Dict[str, DatasetMetaInfo],
|
63 |
+
# target_source_map: Dict[str, Dict[str, str]],
|
64 |
+
# target_domains_order: List[str],
|
65 |
+
# source_datasets: Dict[str, Dict[str, ABDataset]], target_datasets: Dict[str, Dict[str, ABDataset]]):
|
66 |
+
|
67 |
+
# self.__config = config
|
68 |
+
# self.__source_datasets_meta_info = source_datasets_meta_info
|
69 |
+
# self.__target_datasets_meta_info = target_datasets_meta_info
|
70 |
+
# self.__target_source_map = target_source_map
|
71 |
+
# self.__target_domains_order = target_domains_order
|
72 |
+
# self.__source_datasets = source_datasets
|
73 |
+
# self.__target_datasets = target_datasets
|
74 |
+
|
75 |
+
# # 1. basic
|
76 |
+
# def get_config(self):
|
77 |
+
# return copy.deepcopy(self.__config)
|
78 |
+
|
79 |
+
# def get_task_type(self):
|
80 |
+
# return list(self.__source_datasets.values())[0]['train'].task_type
|
81 |
+
|
82 |
+
# def get_num_classes(self):
|
83 |
+
# known_classes_idx = []
|
84 |
+
# unknown_classes_idx = []
|
85 |
+
# for v in self.__source_datasets_meta_info.values():
|
86 |
+
# known_classes_idx += list(v.known_classes_name_idx_map.values())
|
87 |
+
# unknown_classes_idx += [v.unknown_class_idx]
|
88 |
+
# for v in self.__target_datasets_meta_info.values():
|
89 |
+
# known_classes_idx += list(v.known_classes_name_idx_map.values())
|
90 |
+
# unknown_classes_idx += [v.unknown_class_idx]
|
91 |
+
# unknown_classes_idx = [i for i in unknown_classes_idx if i is not None]
|
92 |
+
# # print(known_classes_idx, unknown_classes_idx)
|
93 |
+
# res = len(set(known_classes_idx)), len(set(unknown_classes_idx)), len(set(known_classes_idx + unknown_classes_idx))
|
94 |
+
# # print(res)
|
95 |
+
# assert res[0] + res[1] == res[2]
|
96 |
+
# return res
|
97 |
+
|
98 |
+
# def build_dataloader(self, dataset: ABDataset, batch_size: int, num_workers: int, infinite: bool, shuffle_when_finite: bool):
|
99 |
+
# if infinite:
|
100 |
+
# dataloader = InfiniteDataLoader(
|
101 |
+
# dataset, None, batch_size, num_workers=num_workers)
|
102 |
+
# else:
|
103 |
+
# dataloader = FastDataLoader(
|
104 |
+
# dataset, batch_size, num_workers, shuffle=shuffle_when_finite)
|
105 |
+
|
106 |
+
# return dataloader
|
107 |
+
|
108 |
+
# def build_sub_dataset(self, dataset: ABDataset, indexes: List[int]):
|
109 |
+
# from ..data.datasets.dataset_split import _SplitDataset
|
110 |
+
# dataset.dataset = _SplitDataset(dataset.dataset, indexes)
|
111 |
+
# return dataset
|
112 |
+
|
113 |
+
# def build_index_returned_dataset(self, dataset: ABDataset):
|
114 |
+
# return IndexReturnedDataset(dataset)
|
115 |
+
|
116 |
+
# # 2. source
|
117 |
+
# def get_source_datasets_meta_info(self):
|
118 |
+
# return self.__source_datasets_meta_info
|
119 |
+
|
120 |
+
# def get_source_datasets_name(self):
|
121 |
+
# return list(self.__source_datasets.keys())
|
122 |
+
|
123 |
+
# def get_merged_source_dataset(self, split):
|
124 |
+
# source_train_datasets = {n: d[split] for n, d in self.__source_datasets.items()}
|
125 |
+
# return MergedDataset(list(source_train_datasets.values()))
|
126 |
+
|
127 |
+
# def get_source_datasets(self, split):
|
128 |
+
# source_train_datasets = {n: d[split] for n, d in self.__source_datasets.items()}
|
129 |
+
# return source_train_datasets
|
130 |
+
|
131 |
+
# # 3. target **domain**
|
132 |
+
# # (do we need such API `get_ith_target_domain()`?)
|
133 |
+
# def get_target_domains_meta_info(self):
|
134 |
+
# return self.__source_datasets_meta_info
|
135 |
+
|
136 |
+
# def get_target_domains_order(self):
|
137 |
+
# return self.__target_domains_order
|
138 |
+
|
139 |
+
# def get_corr_source_datasets_name_of_target_domain(self, target_domain_name):
|
140 |
+
# return self.__target_source_map[target_domain_name]
|
141 |
+
|
142 |
+
# def get_limited_target_train_dataset(self):
|
143 |
+
# if len(self.__target_domains_order) > 1:
|
144 |
+
# raise RuntimeError('this API is only for pass-in scenario in user-defined online DA algorithm')
|
145 |
+
# return list(self.__target_datasets.values())[0]['train']
|
146 |
+
|
147 |
+
# def get_target_domains_iterator(self, split):
|
148 |
+
# for target_domain_index, target_domain_name in enumerate(self.__target_domains_order):
|
149 |
+
# target_dataset = self.__target_datasets[target_domain_name]
|
150 |
+
# target_domain_meta_info = self.__target_datasets_meta_info[target_domain_name]
|
151 |
+
|
152 |
+
# yield target_domain_index, target_domain_name, target_dataset[split], target_domain_meta_info
|
153 |
+
|
154 |
+
# # 4. permission management
|
155 |
+
# def get_sub_scenario(self, source_datasets_name, source_splits, target_domains_order, target_splits):
|
156 |
+
# def get_split(dataset, splits):
|
157 |
+
# res = {}
|
158 |
+
# for s, d in dataset.items():
|
159 |
+
# if s in splits:
|
160 |
+
# res[s] = d
|
161 |
+
# return res
|
162 |
+
|
163 |
+
# return Scenario(
|
164 |
+
# config=self.__config,
|
165 |
+
# source_datasets_meta_info={k: v for k, v in self.__source_datasets_meta_info.items() if k in source_datasets_name},
|
166 |
+
# target_datasets_meta_info={k: v for k, v in self.__target_datasets_meta_info.items() if k in target_domains_order},
|
167 |
+
# target_source_map={k: v for k, v in self.__target_source_map.items() if k in target_domains_order},
|
168 |
+
# target_domains_order=target_domains_order,
|
169 |
+
# source_datasets={k: get_split(v, source_splits) for k, v in self.__source_datasets.items() if k in source_datasets_name},
|
170 |
+
# target_datasets={k: get_split(v, target_splits) for k, v in self.__target_datasets.items() if k in target_domains_order}
|
171 |
+
# )
|
172 |
+
|
173 |
+
# def get_only_source_sub_scenario_for_exp_tracker(self):
|
174 |
+
# return self.get_sub_scenario(self.get_source_datasets_name(), ['train', 'val', 'test'], [], [])
|
175 |
+
|
176 |
+
# def get_only_source_sub_scenario_for_alg(self):
|
177 |
+
# return self.get_sub_scenario(self.get_source_datasets_name(), ['train'], [], [])
|
178 |
+
|
179 |
+
# def get_one_da_sub_scenario_for_alg(self, target_domain_name):
|
180 |
+
# return self.get_sub_scenario(self.get_corr_source_datasets_name_of_target_domain(target_domain_name),
|
181 |
+
# ['train', 'val'], [target_domain_name], ['train'])
|
182 |
+
|
183 |
+
|
184 |
+
# class Scenario:
|
185 |
+
# def __init__(self, config,
|
186 |
+
|
187 |
+
# offline_source_datasets_meta_info: Dict[str, DatasetMetaInfo],
|
188 |
+
# offline_source_datasets: Dict[str, ABDataset],
|
189 |
+
|
190 |
+
# online_datasets_meta_info: List[Tuple[Dict[str, DatasetMetaInfo], DatasetMetaInfo]],
|
191 |
+
# online_datasets: Dict[str, ABDataset],
|
192 |
+
# target_domains_order: List[str],
|
193 |
+
# target_source_map: Dict[str, Dict[str, str]],
|
194 |
+
|
195 |
+
# num_classes: int):
|
196 |
+
|
197 |
+
# self.config = config
|
198 |
+
|
199 |
+
# self.offline_source_datasets_meta_info = offline_source_datasets_meta_info
|
200 |
+
# self.offline_source_datasets = offline_source_datasets
|
201 |
+
|
202 |
+
# self.online_datasets_meta_info = online_datasets_meta_info
|
203 |
+
# self.online_datasets = online_datasets
|
204 |
+
|
205 |
+
# self.target_domains_order = target_domains_order
|
206 |
+
# self.target_source_map = target_source_map
|
207 |
+
|
208 |
+
# self.num_classes = num_classes
|
209 |
+
|
210 |
+
# def get_offline_source_datasets(self, split):
|
211 |
+
# return {n: d[split] for n, d in self.offline_source_datasets.items()}
|
212 |
+
|
213 |
+
# def get_offline_source_merged_dataset(self, split):
|
214 |
+
# return MergedDataset([d[split] for d in self.offline_source_datasets.values()])
|
215 |
+
|
216 |
+
# def get_online_current_corresponding_source_datasets(self, domain_index, split):
|
217 |
+
# cur_target_domain_name = self.target_domains_order[domain_index]
|
218 |
+
# cur_source_datasets_name = list(self.target_source_map[cur_target_domain_name].keys())
|
219 |
+
# cur_source_datasets = {n: self.online_datasets[n + '|' + cur_target_domain_name][split] for n in cur_source_datasets_name}
|
220 |
+
# return cur_source_datasets
|
221 |
+
|
222 |
+
# def get_online_current_corresponding_merged_source_dataset(self, domain_index, split):
|
223 |
+
# cur_target_domain_name = self.target_domains_order[domain_index]
|
224 |
+
# cur_source_datasets_name = list(self.target_source_map[cur_target_domain_name].keys())
|
225 |
+
# cur_source_datasets = {n: self.online_datasets[n + '|' + cur_target_domain_name][split] for n in cur_source_datasets_name}
|
226 |
+
# return MergedDataset([d for d in cur_source_datasets.values()])
|
227 |
+
|
228 |
+
# def get_online_current_target_dataset(self, domain_index, split):
|
229 |
+
# cur_target_domain_name = self.target_domains_order[domain_index]
|
230 |
+
# return self.online_datasets[cur_target_domain_name][split]
|
231 |
+
|
232 |
+
# def build_dataloader(self, dataset: ABDataset, batch_size: int, num_workers: int,
|
233 |
+
# infinite: bool, shuffle_when_finite: bool, to_iterator: bool):
|
234 |
+
# if infinite:
|
235 |
+
# dataloader = InfiniteDataLoader(
|
236 |
+
# dataset, None, batch_size, num_workers=num_workers)
|
237 |
+
# else:
|
238 |
+
# dataloader = FastDataLoader(
|
239 |
+
# dataset, batch_size, num_workers, shuffle=shuffle_when_finite)
|
240 |
+
|
241 |
+
# if to_iterator:
|
242 |
+
# dataloader = iter(dataloader)
|
243 |
+
|
244 |
+
# return dataloader
|
245 |
+
|
246 |
+
# def build_sub_dataset(self, dataset: ABDataset, indexes: List[int]):
|
247 |
+
# from data.datasets.dataset_split import _SplitDataset
|
248 |
+
# dataset.dataset = _SplitDataset(dataset.dataset, indexes)
|
249 |
+
# return dataset
|
250 |
+
|
251 |
+
# def build_index_returned_dataset(self, dataset: ABDataset):
|
252 |
+
# return IndexReturnedDataset(dataset)
|
253 |
+
|
254 |
+
# def get_config(self):
|
255 |
+
# return copy.deepcopy(self.config)
|
256 |
+
|
257 |
+
# def get_task_type(self):
|
258 |
+
# return list(self.online_datasets.values())[0]['train'].task_type
|
259 |
+
|
260 |
+
# def get_num_classes(self):
|
261 |
+
# return self.num_classes
|
262 |
+
|
263 |
+
|
264 |
+
class Scenario:
    """Description of one online domain-adaptation run.

    Holds the per-dataset class bookkeeping (ignored classes, original-index
    -> end-to-end-index maps, class-name -> index maps), the ordered stream of
    target domains, and the target -> sources relationship. `cur_domain_index`
    tracks the position in the target stream and is advanced by
    `next_domain()`.
    """

    def __init__(self, config, all_datasets_ignore_classes_map, all_datasets_idx_map, target_domains_order, target_source_map,
                 all_datasets_e2e_class_to_idx_map,
                 num_classes):
        # config: scenario build configuration; the helpers below read
        # config['data_dirs'] and config['source_datasets_name'] from it
        self.config = config
        # dataset name (or 'source|target' pair name) -> classes ignored in it
        self.all_datasets_ignore_classes_map = all_datasets_ignore_classes_map
        # dataset name (or 'source|target' pair name) -> {orig idx -> e2e idx}
        self.all_datasets_idx_map = all_datasets_idx_map
        self.target_domains_order = target_domains_order
        self.target_source_map = target_source_map
        self.all_datasets_e2e_class_to_idx_map = all_datasets_e2e_class_to_idx_map
        self.num_classes = num_classes
        self.cur_domain_index = 0

        logger.info(f'[scenario build] # classes: {num_classes}')
        logger.debug(f'[scenario build] idx map: {all_datasets_idx_map}')

    def to_json(self):
        """Return a plain-dict snapshot of the scenario (for logging/serialization)."""
        return dict(
            config=self.config, all_datasets_ignore_classes_map=self.all_datasets_ignore_classes_map,
            all_datasets_idx_map=self.all_datasets_idx_map, target_domains_order=self.target_domains_order,
            target_source_map=self.target_source_map,
            all_datasets_e2e_class_to_idx_map=self.all_datasets_e2e_class_to_idx_map,
            num_classes=self.num_classes
        )

    def __str__(self):
        return f'Scenario({self.to_json()})'

    def get_offline_datasets(self, transform=None):
        """Build the offline source datasets (train/val/test splits each).

        Only entries of `all_datasets_ignore_classes_map` whose key is exactly
        a source dataset name are built; 'source|target' pair entries are
        skipped by the `d in source_datasets_name` filter.
        """
        from .. import get_dataset
        data_dirs = self.config['data_dirs']
        source_datasets_name = self.config['source_datasets_name']

        res_source_datasets_map = {d: {split: get_dataset(d, data_dirs[d], split,
                                                          transform,
                                                          self.all_datasets_ignore_classes_map[d], self.all_datasets_idx_map[d])
                                       for split in ['train', 'val', 'test']}
                                   for d in self.all_datasets_ignore_classes_map.keys() if d in source_datasets_name}

        return res_source_datasets_map

    def get_offline_datasets_args(self):
        """Return get_dataset() keyword-argument dicts for offline training.

        For each source dataset, the idx maps of all its 'source|target'
        variants are unioned, and a class is ignored only if every variant
        ignores it, so the offline dataset covers the union of classes.
        """
        res_offline_train_source_datasets_map = {}

        from .. import get_dataset
        data_dirs = self.config['data_dirs']

        source_datasets_name = self.config['source_datasets_name']
        # build every 'source|target' variant of each source dataset
        res_source_datasets_map = {d: {split: get_dataset(d.split('|')[0], data_dirs[d.split('|')[0]], split,
                                                          None,
                                                          self.all_datasets_ignore_classes_map[d], self.all_datasets_idx_map[d])
                                       for split in ['train', 'val', 'test']}
                                   for d in self.all_datasets_ignore_classes_map.keys() if d.split('|')[0] in source_datasets_name}

        for source_dataset_name in self.config['source_datasets_name']:
            source_datasets = [v for k, v in res_source_datasets_map.items() if source_dataset_name in k]

            # union the idx maps of all variants; a class must map to the same
            # end-to-end index everywhere
            idx_maps = [d['train'].idx_map for d in source_datasets]
            ignore_classes_list = [d['train'].ignore_classes for d in source_datasets]

            union_idx_map = {}
            for idx_map in idx_maps:
                for k, v in idx_map.items():
                    if k not in union_idx_map:
                        union_idx_map[k] = v
                    else:
                        assert union_idx_map[k] == v

            # intersection of the ignore lists: ignored offline only if ignored
            # in every variant
            union_ignore_classes = reduce(lambda res, cur: res & set(cur), ignore_classes_list, set(ignore_classes_list[0]))
            assert len(union_ignore_classes) + len(union_idx_map) == len(source_datasets[0]['train'].raw_classes)

            logger.info(f'[scenario build] {source_dataset_name} has {len(union_idx_map)} classes in offline training')

            d = source_dataset_name
            # FIX: the original called dict() with six positional arguments
            # (`dict(d, data_dirs[d], split, None, ...)`), which raises
            # TypeError at runtime. Build the same keyword-args dict shape the
            # online *_args helpers return.
            res_offline_train_source_datasets_map[d] = {split: dict(dataset_name=d,
                                                                    root_dir=data_dirs[d],
                                                                    split=split,
                                                                    transform=None,
                                                                    ignore_classes=union_ignore_classes,
                                                                    idx_map=union_idx_map)
                                                        for split in ['train', 'val', 'test']}

        return res_offline_train_source_datasets_map

    def get_online_ith_domain_datasets_args_for_inference(self, domain_index):
        """Return get_dataset() kwargs for the i-th target domain at inference time."""
        target_dataset_name = self.target_domains_order[domain_index]

        # NOTE(review): the log message says 'val split' but 'test' is
        # selected — confirm which split the Det workloads should use.
        if 'MM-CityscapesDet' in self.target_domains_order or 'CityscapesDet' in self.target_domains_order or 'BaiduPersonDet' in self.target_domains_order:
            logger.info(f'use val split for inference test (only Det workload)')
            split = 'test'
        else:
            split = 'train'

        return dict(dataset_name=target_dataset_name,
                    root_dir=self.config['data_dirs'][target_dataset_name],
                    split=split,
                    transform=None,
                    ignore_classes=self.all_datasets_ignore_classes_map[target_dataset_name],
                    idx_map=self.all_datasets_idx_map[target_dataset_name])

    def get_online_ith_domain_datasets_args_for_training(self, domain_index):
        """Return get_dataset() kwargs (train/val) for the i-th target domain
        and each of its corresponding source datasets.

        Source entries use the 'source|target'-keyed ignore/idx maps so the
        source data is relabeled consistently with this particular target.
        """
        target_dataset_name = self.target_domains_order[domain_index]
        source_datasets_name = list(self.target_source_map[target_dataset_name].keys())

        res = {}
        res[target_dataset_name] = {split: dict(dataset_name=target_dataset_name,
                                                root_dir=self.config['data_dirs'][target_dataset_name],
                                                split=split,
                                                transform=None,
                                                ignore_classes=self.all_datasets_ignore_classes_map[target_dataset_name],
                                                idx_map=self.all_datasets_idx_map[target_dataset_name]) for split in ['train', 'val']}
        for d in source_datasets_name:
            res[d] = {split: dict(dataset_name=d,
                                  root_dir=self.config['data_dirs'][d],
                                  split=split,
                                  transform=None,
                                  ignore_classes=self.all_datasets_ignore_classes_map[d + '|' + target_dataset_name],
                                  idx_map=self.all_datasets_idx_map[d + '|' + target_dataset_name]) for split in ['train', 'val']}

        return res

    def get_online_cur_domain_datasets_args_for_inference(self):
        """Inference kwargs for the current target domain."""
        return self.get_online_ith_domain_datasets_args_for_inference(self.cur_domain_index)

    def get_online_cur_domain_datasets_args_for_training(self):
        """Training kwargs for the current target domain and its sources."""
        return self.get_online_ith_domain_datasets_args_for_training(self.cur_domain_index)

    def get_online_cur_domain_datasets_for_training(self, transform=None):
        """Materialize the training datasets for the current target domain.

        `transform`, when given, overrides the default (None) transform of
        every built split.
        """
        res = {}
        datasets_args = self.get_online_ith_domain_datasets_args_for_training(self.cur_domain_index)
        for dataset_name, dataset_args in datasets_args.items():
            res[dataset_name] = {}
            for split, args in dataset_args.items():
                if transform is not None:
                    args['transform'] = transform
                res[dataset_name][split] = get_dataset(**args)
        return res

    def get_online_cur_domain_datasets_for_inference(self, transform=None):
        """Materialize the inference dataset for the current target domain."""
        datasets_args = self.get_online_ith_domain_datasets_args_for_inference(self.cur_domain_index)
        if transform is not None:
            datasets_args['transform'] = transform
        return get_dataset(**datasets_args)

    def get_online_cur_domain_samples_for_training(self, num_samples, transform=None, collate_fn=None):
        """Return one batch of `num_samples` inputs from the current target
        domain's train split (labels are dropped: `[0]` keeps inputs only)."""
        dataset = self.get_online_cur_domain_datasets_for_training(transform=transform)
        dataset = dataset[self.target_domains_order[self.cur_domain_index]]['train']
        return next(iter(build_dataloader(dataset, num_samples, 0, True, None, collate_fn=collate_fn)))[0]

    def next_domain(self):
        """Advance the scenario to the next target domain in the stream."""
        self.cur_domain_index += 1
|
473 |
+
|
data/datasets/__init__.py
CHANGED
@@ -4,6 +4,7 @@ from .semantic_segmentation import *
|
|
4 |
from .action_recognition import *
|
5 |
|
6 |
from .sentiment_classification import *
|
|
|
7 |
from .machine_translation import *
|
8 |
from .pos_tagging import *
|
9 |
|
|
|
4 |
from .action_recognition import *
|
5 |
|
6 |
from .sentiment_classification import *
|
7 |
+
from .text_generation import *
|
8 |
from .machine_translation import *
|
9 |
from .pos_tagging import *
|
10 |
|
data/datasets/__pycache__/__init__.cpython-38.pyc
CHANGED
Binary files a/data/datasets/__pycache__/__init__.cpython-38.pyc and b/data/datasets/__pycache__/__init__.cpython-38.pyc differ
|
|
data/datasets/__pycache__/ab_dataset.cpython-38.pyc
CHANGED
Binary files a/data/datasets/__pycache__/ab_dataset.cpython-38.pyc and b/data/datasets/__pycache__/ab_dataset.cpython-38.pyc differ
|
|
data/datasets/__pycache__/data_aug.cpython-38.pyc
CHANGED
Binary files a/data/datasets/__pycache__/data_aug.cpython-38.pyc and b/data/datasets/__pycache__/data_aug.cpython-38.pyc differ
|
|
data/datasets/__pycache__/dataset_cache.cpython-38.pyc
CHANGED
Binary files a/data/datasets/__pycache__/dataset_cache.cpython-38.pyc and b/data/datasets/__pycache__/dataset_cache.cpython-38.pyc differ
|
|
data/datasets/__pycache__/dataset_split.cpython-38.pyc
CHANGED
Binary files a/data/datasets/__pycache__/dataset_split.cpython-38.pyc and b/data/datasets/__pycache__/dataset_split.cpython-38.pyc differ
|
|
data/datasets/__pycache__/registery.cpython-38.pyc
CHANGED
Binary files a/data/datasets/__pycache__/registery.cpython-38.pyc and b/data/datasets/__pycache__/registery.cpython-38.pyc differ
|
|
data/datasets/action_recognition/__pycache__/__init__.cpython-38.pyc
CHANGED
Binary files a/data/datasets/action_recognition/__pycache__/__init__.cpython-38.pyc and b/data/datasets/action_recognition/__pycache__/__init__.cpython-38.pyc differ
|
|
data/datasets/action_recognition/__pycache__/common_dataset.cpython-38.pyc
CHANGED
Binary files a/data/datasets/action_recognition/__pycache__/common_dataset.cpython-38.pyc and b/data/datasets/action_recognition/__pycache__/common_dataset.cpython-38.pyc differ
|
|
data/datasets/action_recognition/__pycache__/hmdb51.cpython-38.pyc
CHANGED
Binary files a/data/datasets/action_recognition/__pycache__/hmdb51.cpython-38.pyc and b/data/datasets/action_recognition/__pycache__/hmdb51.cpython-38.pyc differ
|
|
data/datasets/action_recognition/__pycache__/ixmas.cpython-38.pyc
CHANGED
Binary files a/data/datasets/action_recognition/__pycache__/ixmas.cpython-38.pyc and b/data/datasets/action_recognition/__pycache__/ixmas.cpython-38.pyc differ
|
|
data/datasets/action_recognition/__pycache__/ucf101.cpython-38.pyc
CHANGED
Binary files a/data/datasets/action_recognition/__pycache__/ucf101.cpython-38.pyc and b/data/datasets/action_recognition/__pycache__/ucf101.cpython-38.pyc differ
|
|
data/datasets/image_classification/__pycache__/__init__.cpython-38.pyc
CHANGED
Binary files a/data/datasets/image_classification/__pycache__/__init__.cpython-38.pyc and b/data/datasets/image_classification/__pycache__/__init__.cpython-38.pyc differ
|
|
data/datasets/image_classification/__pycache__/baidu_person_cls.cpython-38.pyc
CHANGED
Binary files a/data/datasets/image_classification/__pycache__/baidu_person_cls.cpython-38.pyc and b/data/datasets/image_classification/__pycache__/baidu_person_cls.cpython-38.pyc differ
|
|
data/datasets/image_classification/__pycache__/caltech256.cpython-38.pyc
CHANGED
Binary files a/data/datasets/image_classification/__pycache__/caltech256.cpython-38.pyc and b/data/datasets/image_classification/__pycache__/caltech256.cpython-38.pyc differ
|
|
data/datasets/image_classification/__pycache__/cifar10.cpython-38.pyc
CHANGED
Binary files a/data/datasets/image_classification/__pycache__/cifar10.cpython-38.pyc and b/data/datasets/image_classification/__pycache__/cifar10.cpython-38.pyc differ
|
|
data/datasets/image_classification/__pycache__/cifar10_single.cpython-38.pyc
CHANGED
Binary files a/data/datasets/image_classification/__pycache__/cifar10_single.cpython-38.pyc and b/data/datasets/image_classification/__pycache__/cifar10_single.cpython-38.pyc differ
|
|
data/datasets/image_classification/__pycache__/cityscapes_cls.cpython-38.pyc
CHANGED
Binary files a/data/datasets/image_classification/__pycache__/cityscapes_cls.cpython-38.pyc and b/data/datasets/image_classification/__pycache__/cityscapes_cls.cpython-38.pyc differ
|
|
data/datasets/image_classification/__pycache__/coco_cls.cpython-38.pyc
CHANGED
Binary files a/data/datasets/image_classification/__pycache__/coco_cls.cpython-38.pyc and b/data/datasets/image_classification/__pycache__/coco_cls.cpython-38.pyc differ
|
|
data/datasets/image_classification/__pycache__/domainnet_real.cpython-38.pyc
CHANGED
Binary files a/data/datasets/image_classification/__pycache__/domainnet_real.cpython-38.pyc and b/data/datasets/image_classification/__pycache__/domainnet_real.cpython-38.pyc differ
|
|
data/datasets/image_classification/__pycache__/emnist.cpython-38.pyc
CHANGED
Binary files a/data/datasets/image_classification/__pycache__/emnist.cpython-38.pyc and b/data/datasets/image_classification/__pycache__/emnist.cpython-38.pyc differ
|
|
data/datasets/image_classification/__pycache__/gta5_cls.cpython-38.pyc
CHANGED
Binary files a/data/datasets/image_classification/__pycache__/gta5_cls.cpython-38.pyc and b/data/datasets/image_classification/__pycache__/gta5_cls.cpython-38.pyc differ
|
|
data/datasets/image_classification/__pycache__/gtsrb.cpython-38.pyc
CHANGED
Binary files a/data/datasets/image_classification/__pycache__/gtsrb.cpython-38.pyc and b/data/datasets/image_classification/__pycache__/gtsrb.cpython-38.pyc differ
|
|
data/datasets/image_classification/__pycache__/imagenet.cpython-38.pyc
CHANGED
Binary files a/data/datasets/image_classification/__pycache__/imagenet.cpython-38.pyc and b/data/datasets/image_classification/__pycache__/imagenet.cpython-38.pyc differ
|
|
data/datasets/image_classification/__pycache__/imagenet_a.cpython-38.pyc
CHANGED
Binary files a/data/datasets/image_classification/__pycache__/imagenet_a.cpython-38.pyc and b/data/datasets/image_classification/__pycache__/imagenet_a.cpython-38.pyc differ
|
|
data/datasets/image_classification/__pycache__/mnist.cpython-38.pyc
CHANGED
Binary files a/data/datasets/image_classification/__pycache__/mnist.cpython-38.pyc and b/data/datasets/image_classification/__pycache__/mnist.cpython-38.pyc differ
|
|
data/datasets/image_classification/__pycache__/mnist_single.cpython-38.pyc
CHANGED
Binary files a/data/datasets/image_classification/__pycache__/mnist_single.cpython-38.pyc and b/data/datasets/image_classification/__pycache__/mnist_single.cpython-38.pyc differ
|
|
data/datasets/image_classification/__pycache__/stl10.cpython-38.pyc
CHANGED
Binary files a/data/datasets/image_classification/__pycache__/stl10.cpython-38.pyc and b/data/datasets/image_classification/__pycache__/stl10.cpython-38.pyc differ
|
|
data/datasets/image_classification/__pycache__/stl10_single.cpython-38.pyc
CHANGED
Binary files a/data/datasets/image_classification/__pycache__/stl10_single.cpython-38.pyc and b/data/datasets/image_classification/__pycache__/stl10_single.cpython-38.pyc differ
|
|
data/datasets/image_classification/__pycache__/supervisely_person_cls.cpython-38.pyc
CHANGED
Binary files a/data/datasets/image_classification/__pycache__/supervisely_person_cls.cpython-38.pyc and b/data/datasets/image_classification/__pycache__/supervisely_person_cls.cpython-38.pyc differ
|
|