thak123 committed on
Commit
5266581
1 Parent(s): 9251dbc

Create mtm.py

Files changed (1)
  1. mtm.py +214 -0
mtm.py ADDED
from typing import Dict, List, Union

import numpy as np
import torch
import torch.nn as nn
import transformers
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler
from transformers.data.data_collator import (
    DataCollatorWithPadding,
    InputDataClass,
    default_data_collator,
)


class MultitaskModel(transformers.PreTrainedModel):
    def __init__(self, encoder, taskmodels_dict):
        """
        Setting MultitaskModel up as a PreTrainedModel allows us
        to take better advantage of Trainer features.
        """
        super().__init__(transformers.PretrainedConfig())

        self.encoder = encoder
        self.taskmodels_dict = nn.ModuleDict(taskmodels_dict)

    @classmethod
    def create(cls, model_name, model_type_dict, model_config_dict):
        """
        Creates a MultitaskModel from the model classes and config objects
        of single-task models.

        We do this by creating each single-task model and having them all
        share the same encoder transformer.
        """
        shared_encoder = None
        taskmodels_dict = {}
        for task_name, model_type in model_type_dict.items():
            model = model_type.from_pretrained(
                model_name,
                config=model_config_dict[task_name],
            )
            if shared_encoder is None:
                # The first task's encoder becomes the shared encoder.
                shared_encoder = getattr(model, cls.get_encoder_attr_name(model))
            else:
                # Every other task head reuses that same encoder object.
                setattr(model, cls.get_encoder_attr_name(model), shared_encoder)
            taskmodels_dict[task_name] = model
        return cls(encoder=shared_encoder, taskmodels_dict=taskmodels_dict)

    @classmethod
    def get_encoder_attr_name(cls, model):
        """
        The encoder transformer is named differently in each model "architecture".
        This method lets us get the name of the encoder attribute.
        """
        model_class_name = model.__class__.__name__
        if model_class_name.startswith("Bert"):
            return "bert"
        elif model_class_name.startswith("Roberta"):
            return "roberta"
        elif model_class_name.startswith("Albert"):
            return "albert"
        else:
            raise KeyError(f"Add support for new model {model_class_name}")

    def forward(self, task_name, **kwargs):
        # Route the batch to the task head registered under task_name.
        return self.taskmodels_dict[task_name](**kwargs)

    def get_model(self, task_name):
        return self.taskmodels_dict[task_name]


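# Usage sketch: a minimal illustration of how MultitaskModel.create() might be
# called. The checkpoint name, task names, and label counts below are
# assumptions for illustration, not requirements of this module.
def _example_create_multitask_model():
    model_name = "bert-base-uncased"  # assumed checkpoint
    multitask_model = MultitaskModel.create(
        model_name=model_name,
        model_type_dict={
            "sts": transformers.AutoModelForSequenceClassification,
            "ner": transformers.AutoModelForTokenClassification,
        },
        model_config_dict={
            "sts": transformers.AutoConfig.from_pretrained(model_name, num_labels=1),
            "ner": transformers.AutoConfig.from_pretrained(model_name, num_labels=9),
        },
    )
    # For a BERT checkpoint the shared encoder lives on the .bert attribute,
    # so both task heads point at the very same module.
    assert (
        multitask_model.get_model("sts").bert
        is multitask_model.get_model("ner").bert
    )
    return multitask_model

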
class NLPDataCollator(DataCollatorWithPadding):
    """
    Extends the existing data collator to work with NLP dataset batches.
    """

    def collate_batch(self, features: List[Union[InputDataClass, Dict]]) -> Dict[str, torch.Tensor]:
        first = features[0]
        if isinstance(first, dict):
            # NLP datasets currently present features as lists of dictionaries
            # (one per example), so we adapt the collate logic for that.
            batch = {}
            if "labels" in first and first["labels"] is not None:
                if first["labels"].dtype == torch.int64:
                    labels = torch.tensor([f["labels"] for f in features], dtype=torch.long)
                else:
                    labels = torch.tensor([f["labels"] for f in features], dtype=torch.float)
                batch = {"labels": labels}
            for k, v in first.items():
                if k != "labels" and v is not None and not isinstance(v, str):
                    batch[k] = torch.stack([f[k] for f in features])
            return batch
        else:
            # Otherwise, fall back to the library's default collation.
            return default_data_collator(features)


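# Usage sketch: what collate_batch expects and returns. The tokenizer argument
# is assumed to be a pretrained Hugging Face tokenizer; the feature dicts below
# mimic two already-tokenised examples.
def _example_collate_batch(tokenizer):
    collator = NLPDataCollator(tokenizer=tokenizer)
    features = [
        {
            "input_ids": torch.tensor([101, 7592, 102]),
            "attention_mask": torch.tensor([1, 1, 1]),
            "labels": torch.tensor(1),
        },
        {
            "input_ids": torch.tensor([101, 2088, 102]),
            "attention_mask": torch.tensor([1, 1, 1]),
            "labels": torch.tensor(0),
        },
    ]
    batch = collator.collate_batch(features)
    # batch["input_ids"] has shape (2, 3); batch["labels"] is a LongTensor of shape (2,).
    return batch

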
class StrIgnoreDevice(str):
    """
    This is a hack. The Trainer will call .to(device) on every input
    value, but we need to pass in an additional `task_name` string.
    This prevents it from throwing an error.
    """

    def to(self, device):
        # A plain string has no device to move to, so just return ourselves.
        return self


class DataLoaderWithTaskname:
    """
    Wrapper around a DataLoader that also yields a task name with every batch.
    """

    def __init__(self, task_name, data_loader):
        self.task_name = task_name
        self.data_loader = data_loader

        self.batch_size = data_loader.batch_size
        self.dataset = data_loader.dataset

    def __len__(self):
        return len(self.data_loader)

    def __iter__(self):
        for batch in self.data_loader:
            # Tag each batch with its task so MultitaskModel.forward can route it.
            batch["task_name"] = StrIgnoreDevice(self.task_name)
            yield batch


class MultitaskDataloader:
    """
    Data loader that combines and samples from multiple single-task
    data loaders.
    """

    def __init__(self, dataloader_dict):
        self.dataloader_dict = dataloader_dict
        self.num_batches_dict = {
            task_name: len(dataloader)
            for task_name, dataloader in self.dataloader_dict.items()
        }
        self.task_name_list = list(self.dataloader_dict)
        # Expose a dummy dataset of the right total length so Trainer utilities
        # that call len() on the dataset keep working.
        self.dataset = [None] * sum(
            len(dataloader.dataset)
            for dataloader in self.dataloader_dict.values()
        )

    def __len__(self):
        return sum(self.num_batches_dict.values())

    def __iter__(self):
        """
        For each batch, sample a task, and yield a batch from the respective
        task's DataLoader.

        We use size-proportional sampling, but you could easily modify this
        to sample from some other distribution.
        """
        task_choice_list = []
        for i, task_name in enumerate(self.task_name_list):
            task_choice_list += [i] * self.num_batches_dict[task_name]
        task_choice_list = np.array(task_choice_list)
        np.random.shuffle(task_choice_list)
        dataloader_iter_dict = {
            task_name: iter(dataloader)
            for task_name, dataloader in self.dataloader_dict.items()
        }
        for task_choice in task_choice_list:
            task_name = self.task_name_list[task_choice]
            yield next(dataloader_iter_dict[task_name])


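# Usage sketch: combining two per-task DataLoaders by hand. The dataset
# arguments are assumed to be map-style torch datasets whose examples are
# dicts of tensors, and `collator` an NLPDataCollator instance.
def _example_multitask_dataloader(sts_dataset, ner_dataset, collator, batch_size=8):
    dataloader_dict = {
        "sts": DataLoaderWithTaskname(
            task_name="sts",
            data_loader=DataLoader(
                sts_dataset, batch_size=batch_size, collate_fn=collator.collate_batch
            ),
        ),
        "ner": DataLoaderWithTaskname(
            task_name="ner",
            data_loader=DataLoader(
                ner_dataset, batch_size=batch_size, collate_fn=collator.collate_batch
            ),
        ),
    }
    # Batches are drawn from each task in proportion to its number of batches,
    # and every batch carries a "task_name" entry for routing inside the model.
    return MultitaskDataloader(dataloader_dict)

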
class MultitaskTrainer(transformers.Trainer):

    def get_single_train_dataloader(self, task_name, train_dataset):
        """
        Create a single-task data loader that also yields task names.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        # Random sampling in the single-process case, distributed sampling otherwise.
        train_sampler = (
            RandomSampler(train_dataset)
            if self.args.local_rank == -1
            else DistributedSampler(train_dataset)
        )

        data_loader = DataLoaderWithTaskname(
            task_name=task_name,
            data_loader=DataLoader(
                train_dataset,
                batch_size=self.args.train_batch_size,
                sampler=train_sampler,
                collate_fn=self.data_collator.collate_batch,
            ),
        )
        return data_loader

    def get_train_dataloader(self):
        """
        Returns a MultitaskDataloader, which is not actually a DataLoader
        but an iterable that samples batches from each task's DataLoader,
        tagged with the task name.
        """
        return MultitaskDataloader({
            task_name: self.get_single_train_dataloader(task_name, task_dataset)
            for task_name, task_dataset in self.train_dataset.items()
        })


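# End-to-end usage sketch: how the pieces above are typically wired together.
# The checkpoint, task names, label counts, and training arguments are
# assumptions for illustration; `train_datasets` is assumed to map task names
# to already-tokenised torch datasets of tensor dicts.
def _example_train(train_datasets, output_dir="multitask_model"):
    multitask_model = _example_create_multitask_model()
    tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
    trainer = MultitaskTrainer(
        model=multitask_model,
        args=transformers.TrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=8,
            num_train_epochs=3,
        ),
        data_collator=NLPDataCollator(tokenizer=tokenizer),
        train_dataset=train_datasets,  # e.g. {"sts": sts_dataset, "ner": ner_dataset}
    )
    trainer.train()
    return trainer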