File size: 3,068 Bytes
ce13d72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""Enable curriculum learning by resuming with a different dataset.

This callback is currently experimental. The API may change without warning in
the future.
"""
import logging
from typing import Any, Dict
from streaming import StreamingDataset
from torch.utils.data import DataLoader
from .interfaces import CallbackWithConfig
from .warnings import experimental_class
log = logging.getLogger(__name__)

@experimental_class('CurriculumLearning callback')
class CurriculumLearning(CallbackWithConfig):
    """Starts an epoch with a different dataset when resuming from a checkpoint.

    If the run is resumed with a ``dataset_index`` larger than the one recorded
    in the loaded checkpoint, the state of the freshly constructed train
    dataset (captured in ``before_load``) is restored in ``after_load`` and the
    training timestamp is advanced to the next epoch.

    Args:
        dataset_index (int): The index of the dataset currently being used.
        train_config (Dict): The training configuration. Its ``'dataloader'``
            entry is recorded as the configuration of the dataset currently
            being used.
    """

    def __init__(self, dataset_index: int, train_config: Dict):
        self.dataset_index = dataset_index
        # Index restored by ``load_state_dict``; stays 0 if nothing is loaded
        # or the loaded state has no ``'dataset_index'`` entry.
        self.saved_dataset_index = 0
        # Dataloader configs used so far; returned by ``state_dict`` and
        # restored by ``load_state_dict``.
        self.all_dataset_configs = []
        # Snapshot of the current train dataset's state, taken in
        # ``before_load`` and re-applied in ``after_load``.
        self.current_dataset_state = {}
        self.current_dataset_config = train_config['dataloader']

    def before_load(self, state: State, logger: Logger) -> None:
        """Capture the train dataset's state so ``after_load`` can restore it.

        Raises:
            ValueError: If the train dataloader is not a ``DataLoader`` or its
                dataset is not a ``StreamingDataset``.
        """
        del logger
        train_loader = state.train_dataloader
        if not isinstance(train_loader, DataLoader):
            # One message string: passing fragments as separate arguments
            # makes ``str(exc)`` render a tuple instead of a sentence.
            raise ValueError(
                'CurriculumLearning callback can only be used with a train '
                f'dataloader of type DataLoader, but got {type(train_loader)}.')
        dataset = train_loader.dataset
        if not isinstance(dataset, StreamingDataset):
            raise ValueError(
                'CurriculumLearning callback only supports StreamingDataset '
                'because it requires loading and saving dataset state. '
                f'Instead, got a dataset of type {type(dataset)}')
        self.current_dataset_state = dataset.state_dict(num_samples=0,
                                                        from_beginning=False)

    def after_load(self, state: State, logger: Logger) -> None:
        """Switch to the new dataset when resuming from an earlier one.

        If the checkpoint's dataset index is smaller than ``dataset_index``,
        re-apply the dataset state captured in ``before_load``, advance the
        timestamp to the next epoch and record the new dataloader config.
        Otherwise, on a first run with index 0, record the initial config.
        """
        del logger
        # Use the public accessor, consistent with ``before_load``.
        train_loader = state.train_dataloader
        assert isinstance(train_loader, DataLoader), 'CurriculumLearning callback requires a DataLoader.'
        dataset = train_loader.dataset
        assert isinstance(dataset, StreamingDataset), 'CurriculumLearning callback requires a StreamingDataset.'
        if self.saved_dataset_index < self.dataset_index:
            # Clamp a negative epoch to 0 before reloading the state.
            # NOTE(review): presumably a not-yet-iterated dataset reports a
            # negative epoch -- confirm against StreamingDataset.state_dict.
            if self.current_dataset_state['epoch'] < 0:
                self.current_dataset_state['epoch'] = 0
            dataset.load_state_dict(self.current_dataset_state)
            state.timestamp = state.timestamp.to_next_epoch()
            self.all_dataset_configs.append(self.current_dataset_config)
        elif self.dataset_index == 0 and not self.all_dataset_configs:
            self.all_dataset_configs.append(self.current_dataset_config)

    def state_dict(self) -> Dict[str, Any]:
        """Return the dataset index and the history of dataloader configs."""
        return {
            'dataset_index': self.dataset_index,
            'all_dataset_configs': self.all_dataset_configs,
        }

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        """Restore the saved dataset index and config history.

        Missing keys fall back to index 0 and an empty history.
        """
        self.saved_dataset_index = state.get('dataset_index', 0)
        self.all_dataset_configs = state.get('all_dataset_configs', [])