|
from typing import Dict, List, Optional, Type, Union |
|
from ..datasets.ab_dataset import ABDataset |
|
|
|
|
|
from ..dataset import get_dataset |
|
import copy |
|
from torchvision.transforms import Compose |
|
|
|
from .merge_alias import merge_the_same_meaning_classes |
|
from ..datasets.registery import static_dataset_registery |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _merge_the_same_meaning_classes(classes_info_of_all_datasets):
    """Collapse alias classes across datasets into shared canonical names.

    Thin wrapper around :func:`merge_the_same_meaning_classes`; returns the
    merged per-dataset class lists and the {dataset: {old: new}} rename map.
    """
    return merge_the_same_meaning_classes(classes_info_of_all_datasets)
|
|
|
|
|
def _find_ignore_classes_when_sources_as_to_target_b(as_classes: List[List[str]], b_classes: List[str], da_mode): |
|
thres = {'da': 3, 'partial_da': 2, 'open_set_da': 1, 'universal_da': 0}[da_mode] |
|
|
|
from functools import reduce |
|
a_classes = reduce(lambda res, cur: res | set(cur), as_classes, set()) |
|
|
|
if set(a_classes) == set(b_classes): |
|
|
|
|
|
a_ignore_classes, b_ignore_classes = [], [] |
|
|
|
elif set(a_classes) > set(b_classes): |
|
|
|
a_ignore_classes, b_ignore_classes = [], [] |
|
if thres == 3 or thres == 1: |
|
a_ignore_classes = set(a_classes) - set(b_classes) |
|
|
|
elif set(a_classes) < set(b_classes): |
|
|
|
a_ignore_classes, b_ignore_classes = [], [] |
|
if thres == 3 or thres == 2: |
|
b_ignore_classes = set(b_classes) - set(a_classes) |
|
|
|
elif len(set(a_classes) & set(b_classes)) > 0: |
|
a_ignore_classes, b_ignore_classes = [], [] |
|
if thres == 3: |
|
a_ignore_classes = set(a_classes) - (set(a_classes) & set(b_classes)) |
|
b_ignore_classes = set(b_classes) - (set(a_classes) & set(b_classes)) |
|
elif thres == 2: |
|
b_ignore_classes = set(b_classes) - (set(a_classes) & set(b_classes)) |
|
elif thres == 1: |
|
a_ignore_classes = set(a_classes) - (set(a_classes) & set(b_classes)) |
|
|
|
else: |
|
return None |
|
|
|
as_ignore_classes = [list(set(a_classes) & set(a_ignore_classes)) for a_classes in as_classes] |
|
|
|
return as_ignore_classes, list(b_ignore_classes) |
|
|
|
|
|
def _find_private_classes_when_sources_as_to_target_b(as_classes: List[List[str]], b_classes: List[str], da_mode): |
|
thres = {'da': 3, 'partial_da': 2, 'open_set_da': 1, 'universal_da': 0}[da_mode] |
|
|
|
from functools import reduce |
|
a_classes = reduce(lambda res, cur: res | set(cur), as_classes, set()) |
|
|
|
if set(a_classes) == set(b_classes): |
|
|
|
|
|
a_private_classes, b_private_classes = [], [] |
|
|
|
elif set(a_classes) > set(b_classes): |
|
|
|
a_private_classes, b_private_classes = [], [] |
|
|
|
|
|
|
|
|
|
|
|
elif set(a_classes) < set(b_classes): |
|
|
|
a_private_classes, b_private_classes = [], [] |
|
if thres == 1 or thres == 0: |
|
b_private_classes = set(b_classes) - set(a_classes) |
|
|
|
elif len(set(a_classes) & set(b_classes)) > 0: |
|
a_private_classes, b_private_classes = [], [] |
|
if thres == 0: |
|
|
|
|
|
b_private_classes = set(b_classes) - (set(a_classes) & set(b_classes)) |
|
elif thres == 1: |
|
b_private_classes = set(b_classes) - (set(a_classes) & set(b_classes)) |
|
elif thres == 2: |
|
|
|
pass |
|
|
|
else: |
|
return None |
|
|
|
return list(b_private_classes) |
|
|
|
|
|
class _ABDatasetMetaInfo: |
|
def __init__(self, name, classes, task_type, object_type, class_aliases, shift_type): |
|
self.name = name |
|
self.classes = classes |
|
self.class_aliases = class_aliases |
|
self.shift_type = shift_type |
|
self.task_type = task_type |
|
self.object_type = object_type |
|
|
|
|
|
def _get_dist_shift_type_when_source_a_to_target_b(a: _ABDatasetMetaInfo, b: _ABDatasetMetaInfo):
    """Describe the distribution shift from source ``a`` to target ``b``.

    Resolution order: no shift metadata on ``b`` → generic 'Dataset Shifts';
    a direct entry for ``a`` → that label; otherwise hop through the first
    dataset listed in ``b``'s shift map (looked up in the static registry)
    and chain the labels with ' + '.
    """
    shift_map = b.shift_type

    # No recorded shift information at all: fall back to a generic label.
    if shift_map is None:
        return 'Dataset Shifts'

    # Direct edge: the target knows its shift relative to this source.
    if a.name in shift_map:
        return shift_map[a.name]

    # Indirect: recurse through the first intermediate dataset and append
    # the edge label from that intermediate to the target.
    mid_name = next(iter(shift_map))
    mid_info = _ABDatasetMetaInfo(mid_name, *static_dataset_registery[mid_name][1:])
    chained = _get_dist_shift_type_when_source_a_to_target_b(a, mid_info)
    return chained + ' + ' + next(iter(shift_map.values()))
|
|
|
|
|
def _handle_all_datasets_v2(source_datasets: List[_ABDatasetMetaInfo], target_datasets: List[_ABDatasetMetaInfo], da_mode):
    """Compute label-space bookkeeping for a multi-source/multi-target DA scenario.

    Steps: merge alias classes across all datasets; relate each target to the
    sources it shares classes with (plus a shift-type label); compute per-pair
    ignored classes and target-private classes for `da_mode`; assign a global
    index to every shared class and one extra index per target dataset that
    has private classes; finally build raw-label -> global-index maps.

    Returns (all_datasets_ignore_classes, target_datasets_private_classes,
    all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map,
    target_datasets_private_class_idx, target_source_relationship_map,
    rename_map, num_classes).
    """
    # 1. Merge classes with the same meaning (aliases) across datasets.
    classes_info_of_all_datasets = {
        d.name: (d.classes, d.class_aliases)
        for d in source_datasets + target_datasets
    }
    final_classes_of_all_datasets, rename_map = _merge_the_same_meaning_classes(classes_info_of_all_datasets)
    all_datasets_classes = copy.deepcopy(final_classes_of_all_datasets)

    # 2. For each target, record which sources share classes with it and the
    #    distribution-shift label between the pair.
    target_source_relationship_map = {td.name: {} for td in target_datasets}
    for sd in source_datasets:
        for td in target_datasets:
            sc = all_datasets_classes[sd.name]
            tc = all_datasets_classes[td.name]
            # Disjoint label spaces: this (source, target) pair is unrelated.
            if len(set(sc) & set(tc)) == 0:
                continue
            target_source_relationship_map[td.name][sd.name] = _get_dist_shift_type_when_source_a_to_target_b(sd, td)

    # 3. Per ('source|target') pair: which classes each side ignores, and the
    #    target-private classes, according to the DA mode.
    source_datasets_ignore_classes = {}
    for td_name, v1 in target_source_relationship_map.items():
        for sd_name, v2 in v1.items():
            source_datasets_ignore_classes[sd_name + '|' + td_name] = []
    target_datasets_ignore_classes = {d.name: [] for d in target_datasets}
    target_datasets_private_classes = {d.name: [] for d in target_datasets}

    for td_name, v1 in target_source_relationship_map.items():
        sd_names = list(v1.keys())
        sds_classes = [all_datasets_classes[sd_name] for sd_name in sd_names]
        td_classes = all_datasets_classes[td_name]
        ss_ignore_classes, t_ignore_classes = _find_ignore_classes_when_sources_as_to_target_b(sds_classes, td_classes, da_mode)
        t_private_classes = _find_private_classes_when_sources_as_to_target_b(sds_classes, td_classes, da_mode)

        for sd_name, s_ignore_classes in zip(sd_names, ss_ignore_classes):
            source_datasets_ignore_classes[sd_name + '|' + td_name] = s_ignore_classes
        target_datasets_ignore_classes[td_name] = t_ignore_classes
        target_datasets_private_classes[td_name] = t_private_classes

    # De-duplicate while preserving original list order.
    source_datasets_ignore_classes = {k: sorted(set(v), key=v.index) for k, v in source_datasets_ignore_classes.items()}
    target_datasets_ignore_classes = {k: sorted(set(v), key=v.index) for k, v in target_datasets_ignore_classes.items()}
    target_datasets_private_classes = {k: sorted(set(v), key=v.index) for k, v in target_datasets_private_classes.items()}

    # 4. Assign one global index to every shared (non-ignored, non-private)
    #    class across all datasets.
    global_idx = 0
    all_used_classes_idx_map = {}

    for dataset_name, classes in all_datasets_classes.items():
        if dataset_name not in target_datasets_ignore_classes.keys():
            # Source dataset: it may appear in several (source|target) pairs;
            # use the smallest ignore set among them. The oversized sentinel
            # list loses against any real ignore list.
            ignore_classes = [0] * 100000
            for sn, sic in source_datasets_ignore_classes.items():
                # BUGFIX: match the source name exactly. startswith() also
                # matched prefix-colliding names (e.g. 'CIFAR10' would match
                # 'CIFAR100|...') and could pick the wrong ignore list.
                if sn.split('|')[0] == dataset_name:
                    if len(sic) < len(ignore_classes):
                        ignore_classes = sic
        else:
            ignore_classes = target_datasets_ignore_classes[dataset_name]
        private_classes = [] \
            if dataset_name not in target_datasets_ignore_classes.keys() else target_datasets_private_classes[dataset_name]

        for c in classes:
            if c not in ignore_classes and c not in all_used_classes_idx_map.keys() and c not in private_classes:
                all_used_classes_idx_map[c] = global_idx
                global_idx += 1

    # 5. Target-private classes receive indices after all shared classes;
    #    each target dataset collapses all its private classes into ONE index.
    target_private_class_idx = global_idx
    target_datasets_private_class_idx = {d: None for d in target_datasets_private_classes.keys()}

    for dataset_name, classes in final_classes_of_all_datasets.items():
        if dataset_name not in target_datasets_private_classes.keys():
            continue

        private_classes = target_datasets_private_classes[dataset_name]

        if len(private_classes) > 0:
            target_datasets_private_class_idx[dataset_name] = target_private_class_idx
            target_private_class_idx += 1

    # 6. Build raw-label-index -> global-index (and raw-class -> global-index)
    #    maps for every source usage ('source|target' key) and every target.
    all_datasets_e2e_idx_map = {}
    all_datasets_e2e_class_to_idx_map = {}

    for td_name, v1 in target_source_relationship_map.items():
        sd_names = list(v1.keys())
        sds_classes = [all_datasets_classes[sd_name] for sd_name in sd_names]
        td_classes = all_datasets_classes[td_name]

        for sd_name, sd_classes in zip(sd_names, sds_classes):
            cur_e2e_idx_map = {}
            cur_e2e_class_to_idx_map = {}

            for raw_ci, raw_c in enumerate(sd_classes):
                # BUGFIX: look up THIS source dataset's rename map (sd_name).
                # The original indexed rename_map with the stale `dataset_name`
                # variable left over from the loop in step 5.
                renamed_c = raw_c if raw_c not in rename_map[sd_name] else rename_map[sd_name][raw_c]

                ignore_classes = source_datasets_ignore_classes[sd_name + '|' + td_name]
                if renamed_c in ignore_classes:
                    continue

                idx = all_used_classes_idx_map[renamed_c]

                cur_e2e_idx_map[raw_ci] = idx
                cur_e2e_class_to_idx_map[raw_c] = idx

            all_datasets_e2e_idx_map[sd_name + '|' + td_name] = cur_e2e_idx_map
            all_datasets_e2e_class_to_idx_map[sd_name + '|' + td_name] = cur_e2e_class_to_idx_map

        cur_e2e_idx_map = {}
        cur_e2e_class_to_idx_map = {}
        for raw_ci, raw_c in enumerate(td_classes):
            # BUGFIX: use the target dataset's rename map (td_name), not the
            # stale `dataset_name` loop variable.
            renamed_c = raw_c if raw_c not in rename_map[td_name] else rename_map[td_name][raw_c]

            ignore_classes = target_datasets_ignore_classes[td_name]
            if renamed_c in ignore_classes:
                continue

            # Private classes all map to the target's single private index.
            if renamed_c in target_datasets_private_classes[td_name]:
                idx = target_datasets_private_class_idx[td_name]
            else:
                idx = all_used_classes_idx_map[renamed_c]

            cur_e2e_idx_map[raw_ci] = idx
            cur_e2e_class_to_idx_map[raw_c] = idx

        all_datasets_e2e_idx_map[td_name] = cur_e2e_idx_map
        all_datasets_e2e_class_to_idx_map[td_name] = cur_e2e_class_to_idx_map

    all_datasets_ignore_classes = {**source_datasets_ignore_classes, **target_datasets_ignore_classes}

    # Count the distinct global indices actually used by any mapping.
    classes_idx_set = []
    for d, m in all_datasets_e2e_class_to_idx_map.items():
        classes_idx_set += list(m.values())
    classes_idx_set = set(classes_idx_set)
    num_classes = len(classes_idx_set)

    return all_datasets_ignore_classes, target_datasets_private_classes, \
        all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
        target_source_relationship_map, rename_map, num_classes
|
|
|
|
|
def _build_scenario_info_v2( |
|
source_datasets_name: List[str], |
|
target_datasets_order: List[str], |
|
da_mode: str |
|
): |
|
assert da_mode in ['close_set', 'partial', 'open_set', 'universal'] |
|
da_mode = {'close_set': 'da', 'partial': 'partial_da', 'open_set': 'open_set_da', 'universal': 'universal_da'}[da_mode] |
|
|
|
source_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:]) for d in source_datasets_name] |
|
target_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:]) for d in list(set(target_datasets_order))] |
|
|
|
all_datasets_ignore_classes, target_datasets_private_classes, \ |
|
all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \ |
|
target_source_relationship_map, rename_map, num_classes \ |
|
= _handle_all_datasets_v2(source_datasets_meta_info, target_datasets_meta_info, da_mode) |
|
|
|
return all_datasets_ignore_classes, target_datasets_private_classes, \ |
|
all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \ |
|
target_source_relationship_map, rename_map, num_classes |
|
|
|
|
|
def build_scenario_manually_v2(
    source_datasets_name: List[str],
    target_datasets_order: List[str],
    da_mode: str,
    data_dirs: Dict[str, str],
):
    """Build a test `Scenario` from dataset names and a DA mode.

    Args:
        source_datasets_name: names of the source datasets.
        target_datasets_order: target dataset names in visiting order
            (duplicates allowed; the order is kept in the scenario).
        da_mode: 'close_set' / 'partial' / 'open_set' / 'universal'.
        data_dirs: dataset name -> data directory; recorded in the scenario
            config (not read directly in this function).

    Returns:
        The constructed Scenario.
    """
    # Snapshot the raw arguments before any other local is bound — this
    # becomes the scenario's config record.
    configs = copy.deepcopy(locals())

    # NOTE: the original version also built source/target meta-info lists here,
    # but they were never used — _build_scenario_info_v2 constructs its own
    # from the same names — so that dead work has been removed.

    all_datasets_ignore_classes, target_datasets_private_classes, \
        all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
        target_source_relationship_map, rename_map, num_classes \
        = _build_scenario_info_v2(source_datasets_name, target_datasets_order, da_mode)

    from .scenario import Scenario, DatasetMetaInfo

    # Expose the global class count to downstream code via environment.
    import os
    os.environ['_ZQL_NUMC'] = str(num_classes)

    test_scenario = Scenario(config=configs, all_datasets_ignore_classes_map=all_datasets_ignore_classes,
                             all_datasets_idx_map=all_datasets_e2e_idx_map,
                             target_domains_order=target_datasets_order,
                             target_source_map=target_source_relationship_map,
                             all_datasets_e2e_class_to_idx_map=all_datasets_e2e_class_to_idx_map,
                             num_classes=num_classes)

    return test_scenario
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: two source datasets, an interleaved target-domain stream.
    # BUGFIX: build_scenario_manually_v2 requires a `data_dirs` argument; the
    # original call omitted it and raised TypeError. An empty mapping suffices
    # here since the visible code only records data_dirs in the config —
    # TODO confirm downstream Scenario usage does not need real paths.
    test_scenario = build_scenario_manually_v2(['CIFAR10', 'SVHN'],
                                               ['STL10', 'MNIST', 'STL10', 'USPS', 'MNIST', 'STL10'],
                                               'close_set',
                                               {})
    print(test_scenario.num_classes)
|
|