import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.datasets import ImageFolder

# Project-local helpers assumed to be defined or imported elsewhere in this
# repo: SimMIMTransform, collate_fn, CachedImageFolder, MultiTaskDataset,
# DistrubutedMultiTaskBatchSampler.

def build_loader_simmim(config):
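    """Build the multi-modality SimMIM pre-training dataloader.

    One CachedImageFolder is created per entry in config.DATA.TYPE_PATH[0],
    each with its own normalization stats and random-crop scale; the datasets
    are wrapped in a MultiTaskDataset and batched by a distributed multi-task
    batch sampler so every rank draws batches from the per-modality datasets.
    """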
    ############ single model #####################
    # transform = SimMIMTransform(config)
    # dataset = ImageFolder(config.DATA.DATA_PATH, transform)
    # sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True)
    # dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True, collate_fn=collate_fn)

    ############## multi model ####################
    datasets = []
    ###### data augmentation ######
    model_paths = config.DATA.TYPE_PATH[0]
    for name in model_paths.keys():
        # config.DATA.SCALE[0][name] is a "(min, max)" string; parse it into
        # a float tuple to use as the random-crop scale for this modality.
        low, high = config.DATA.SCALE[0][name].strip('()').split(',')
        scale_model = (float(low), float(high))
        transform = SimMIMTransform(config, config.DATA.NORM[0][name], scale_model)
        dataset = CachedImageFolder(model_paths[name], transform=transform, model=name)
        datasets.append(dataset)
    multi_task_train_dataset = MultiTaskDataset(datasets)
    print(f'number of per-modality datasets: {len(datasets)}')
    multi_task_batch_sampler = DistrubutedMultiTaskBatchSampler(
        datasets, batch_size=config.DATA.BATCH_SIZE,
        num_replicas=dist.get_world_size(), rank=dist.get_rank(),
        mix_opt=0, extra_task_ratio=0, drop_last=True, shuffle=True)
    dataloader = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler,
                            num_workers=config.DATA.NUM_WORKERS, pin_memory=True,
                            collate_fn=collate_fn)
    print(f'number of batches per epoch: {len(dataloader)}')

    return dataloader
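

# Usage sketch (illustrative, not part of the original file): the loop above
# assumes config.DATA.TYPE_PATH / SCALE / NORM are one-element lists holding
# per-modality dicts keyed by modality name, with SCALE stored as "(min, max)"
# strings. The modality names, paths, and norm stats below are hypothetical
# placeholders, and the real config will also carry whatever image-size and
# masking fields SimMIMTransform reads.
if __name__ == '__main__':
    from types import SimpleNamespace

    config = SimpleNamespace(DATA=SimpleNamespace(
        BATCH_SIZE=32,
        NUM_WORKERS=4,
        TYPE_PATH=[{'rgb': '/path/to/rgb', 'sar': '/path/to/sar'}],
        SCALE=[{'rgb': '(0.67, 1.0)', 'sar': '(0.67, 1.0)'}],
        NORM=[{'rgb': ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
               'sar': ((0.5,), (0.5,))}],
    ))

    # Single-process group so dist.get_world_size()/dist.get_rank() resolve
    # without a multi-GPU launch.
    dist.init_process_group(backend='gloo', init_method='tcp://127.0.0.1:29500',
                            world_size=1, rank=0)
    loader = build_loader_simmim(config)
    for batch in loader:
        break  # one batch is enough to confirm the pipeline wires up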