"""Enable curriculum learning by resuming with a different dataset. This callback is currently experimental. The API may change without warning in the future. """ import logging from typing import Any, Dict from streaming import StreamingDataset from torch.utils.data import DataLoader from .interfaces import CallbackWithConfig from .warnings import experimental_class log = logging.getLogger(__name__) @experimental_class('CurriculumLearning callback') class CurriculumLearning(CallbackWithConfig): """Starts an epoch with a different dataset when resuming from a checkpoint. Args: dataset_index (int): The index of the dataset currently being used. current_dataset_config (Dict): The configuration of the dataset currently being used. """ def __init__(self, dataset_index: int, train_config: Dict): self.dataset_index = dataset_index self.saved_dataset_index = 0 self.all_dataset_configs = [] self.current_dataset_state = {} self.current_dataset_config = train_config['dataloader'] def before_load(self, state: State, logger: Logger): del logger train_loader = state.train_dataloader if not isinstance(train_loader, DataLoader): raise ValueError(f'CurriculumLearning callback can only be used with a train ', f'dataloader of type DataLoader, but got {type(train_loader)}.') dataset = train_loader.dataset if not isinstance(dataset, StreamingDataset): raise ValueError(f'CurriculumLearning callback only supports StreamingDataset ', f'because it requires loading and saving dataset state. ', f'Instead, got a dataset of type {type(dataset)}') assert isinstance(dataset, StreamingDataset) self.current_dataset_state = dataset.state_dict(num_samples=0, from_beginning=False) def after_load(self, state: State, logger: Logger): del logger train_loader = state._train_dataloader assert isinstance(train_loader, DataLoader), 'CurriculumLearning callback requires a DataLoader.' dataset = train_loader.dataset assert isinstance(dataset, StreamingDataset), 'CurriculumLearning callback requires a StreamingDataset.' if self.saved_dataset_index < self.dataset_index: if self.current_dataset_state['epoch'] < 0: self.current_dataset_state['epoch'] = 0 dataset.load_state_dict(self.current_dataset_state) state.timestamp = state.timestamp.to_next_epoch() self.all_dataset_configs.append(self.current_dataset_config) elif self.dataset_index == 0 and len(self.all_dataset_configs) == 0: self.all_dataset_configs.append(self.current_dataset_config) def state_dict(self): return {'dataset_index': self.dataset_index, 'all_dataset_configs': self.all_dataset_configs} def load_state_dict(self, state: Dict[str, Any]): self.saved_dataset_index = state.get('dataset_index', 0) self.all_dataset_configs = state.get('all_dataset_configs', [])