import torch.distributed as dist
from torch.utils.data import DataLoader

# SimMIMTransform, CachedImageFolder, MultiTaskDataset,
# DistrubutedMultiTaskBatchSampler and collate_fn are defined elsewhere
# in this repository.


def build_loader_simmim(config):
    ############ single model (kept for reference) ############
    # transform = SimMIMTransform(config)
    # dataset = ImageFolder(config.DATA.DATA_PATH, transform)
    # sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(),
    #                              rank=dist.get_rank(), shuffle=True)
    # dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler,
    #                         num_workers=config.DATA.NUM_WORKERS, pin_memory=True,
    #                         drop_last=True, collate_fn=collate_fn)

    ############## multi model ####################
    # Build one dataset per model type, each with its own data augmentation.
    datasets = []
    model_paths = config.DATA.TYPE_PATH[0]
    for name in model_paths.keys():
        # SCALE entries are stored as "(low, high)" strings; recover the tuple.
        parts = config.DATA.SCALE[0][name].split(',')
        scale_model = (float(parts[0].split('(')[1]), float(parts[1].split(')')[0]))
        transform = SimMIMTransform(config, config.DATA.NORM[0][name], scale_model)
        dataset = CachedImageFolder(model_paths[name], transform=transform, model=name)
        datasets.append(dataset)
    print(f'number of datasets: {len(datasets)}')

    # Wrap the per-type datasets so every batch is drawn from a single task.
    multi_task_train_dataset = MultiTaskDataset(datasets)
    multi_task_batch_sampler = DistrubutedMultiTaskBatchSampler(
        datasets, batch_size=config.DATA.BATCH_SIZE,
        num_replicas=dist.get_world_size(), rank=dist.get_rank(),
        mix_opt=0, extra_task_ratio=0, drop_last=True, shuffle=True)
    dataloader = DataLoader(multi_task_train_dataset,
                            batch_sampler=multi_task_batch_sampler,
                            num_workers=config.DATA.NUM_WORKERS,
                            pin_memory=True, collate_fn=collate_fn)
    print(f'number of batches per epoch: {len(dataloader)}')
    return dataloader
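
# ---------------------------------------------------------------------------
# Sketch (an assumption, not part of the original loader): the loop above
# hand-parses SCALE entries of the form "(low, high)". The helper below is a
# self-contained equivalent using ast.literal_eval, included only to document
# the expected config string format; the example value and config key are
# illustrative, not taken from the repository.
# ---------------------------------------------------------------------------
import ast


def parse_scale_range(scale_str):
    """Parse a "(low, high)" string such as "(0.67, 1.0)" into a float tuple."""
    low, high = ast.literal_eval(scale_str)
    return float(low), float(high)


if __name__ == '__main__':
    # e.g. config.DATA.SCALE[0]['rgb'] might hold '(0.67, 1.0)' (hypothetical key)
    assert parse_scale_range('(0.67, 1.0)') == (0.67, 1.0)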