# Hugging Face Hub file-viewer header (page residue, kept as comments):
# thak123's picture
# Create mtm.py
# 5266581
# raw
# history blame
# 7.82 kB
import transformers
import torch
import torch.nn as nn
from torch.utils.data.sampler import RandomSampler
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.dataloader import DataLoader
from transformers.data.data_collator import DataCollator
from transformers.data.data_collator import DataCollatorWithPadding, InputDataClass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from transformers import is_torch_tpu_available
import numpy as np
class MultitaskModel(transformers.PreTrainedModel):
    """Bundle of single-task models that all share one encoder.

    Deriving from ``PreTrainedModel`` lets the bundle be handed to the
    ``Trainer`` like any other model.
    """

    def __init__(self, encoder, taskmodels_dict):
        """Store the shared encoder and register the per-task models.

        Args:
            encoder: Encoder module shared by every task model.
            taskmodels_dict: Mapping of task name to single-task model.
        """
        super().__init__(transformers.PretrainedConfig())
        self.encoder = encoder
        # ModuleDict registers every task model as a submodule.
        self.taskmodels_dict = nn.ModuleDict(taskmodels_dict)

    @classmethod
    def create(cls, model_name, model_type_dict, model_config_dict):
        """Build a MultitaskModel from single-task model classes and configs.

        Each task model is loaded with ``from_pretrained``. The encoder of the
        first task model becomes the shared encoder and is assigned onto every
        later task model, so all tasks use the same encoder object.

        Args:
            model_name: Name or path passed to ``from_pretrained``.
            model_type_dict: Mapping of task name to model class.
            model_config_dict: Mapping of task name to the config for that
                task's model.

        Returns:
            A new instance holding the shared encoder and all task models.

        Raises:
            KeyError: If a task is missing from ``model_config_dict`` or the
                model architecture is not recognised by
                ``get_encoder_attr_name``.
        """
        shared_encoder = None
        taskmodels_dict = {}
        for task_name, model_type in model_type_dict.items():
            model = model_type.from_pretrained(
                model_name,
                config=model_config_dict[task_name],
            )
            encoder_attr = cls.get_encoder_attr_name(model)
            if shared_encoder is None:
                # First task: adopt its encoder as the shared one.
                shared_encoder = getattr(model, encoder_attr)
            else:
                # Later tasks: swap in the shared encoder.
                setattr(model, encoder_attr, shared_encoder)
            taskmodels_dict[task_name] = model
        return cls(encoder=shared_encoder, taskmodels_dict=taskmodels_dict)

    @classmethod
    def get_encoder_attr_name(cls, model):
        """Return the attribute name under which ``model`` keeps its encoder.

        The name is chosen from the prefix of the model's class name
        (``Bert*`` -> ``"bert"``, ``Roberta*`` -> ``"roberta"``,
        ``Albert*`` -> ``"albert"``).

        Raises:
            KeyError: If the class name matches none of the known prefixes.
        """
        model_class_name = model.__class__.__name__
        if model_class_name.startswith("Bert"):
            return "bert"
        elif model_class_name.startswith("Roberta"):
            return "roberta"
        elif model_class_name.startswith("Albert"):
            return "albert"
        else:
            raise KeyError(f"Add support for new model {model_class_name}")

    def forward(self, task_name, **kwargs):
        """Run the model registered for ``task_name`` with ``kwargs``."""
        return self.taskmodels_dict[task_name](**kwargs)

    def get_model(self, task_name):
        """Return the single-task model registered for ``task_name``."""
        return self.taskmodels_dict[task_name]
class NLPDataCollator(DataCollatorWithPadding):
    """Data collator that also handles batches given as lists of dicts."""

    def collate_batch(self, features: List[Union[InputDataClass, Dict]]) -> Dict[str, torch.Tensor]:
        """Collate a list of examples into a single batch dict.

        When the examples are dicts, ``labels`` (if present and not ``None``)
        are gathered into one tensor: ``torch.long`` if the first label is
        int64, ``torch.float`` otherwise. Every other non-``None``,
        non-string value is combined with ``torch.stack``, so those values
        are expected to be tensors of matching shape.

        Examples that are not dicts are handed to the inherited
        ``DataCollatorWithPadding.__call__``.

        Args:
            features: Non-empty list of examples for one batch.

        Returns:
            Mapping of field name to batched tensor.
        """
        first = features[0]
        if not isinstance(first, dict):
            # Reuse this collator's own configuration for the fallback
            # instead of constructing an unconfigured parent instance.
            return super().__call__(features)

        # Start from an empty dict so examples without labels still collate.
        batch = {}
        if "labels" in first and first["labels"] is not None:
            if first["labels"].dtype == torch.int64:
                label_dtype = torch.long
            else:
                label_dtype = torch.float
            batch["labels"] = torch.tensor(
                [f["labels"] for f in features], dtype=label_dtype)
        for k, v in first.items():
            if k != "labels" and v is not None and not isinstance(v, str):
                batch[k] = torch.stack([f[k] for f in features])
        return batch
class StrIgnoreDevice(str):
    """A ``str`` that tolerates being moved to a device.

    Workaround: the Trainer calls ``.to(device)`` on every input value, yet
    the batch also carries a plain ``task_name`` string. Giving the string a
    no-op ``to`` keeps that call from raising.
    """

    def to(self, device):
        """Ignore ``device`` and return this same string object."""
        return self
class DataLoaderWithTaskname:
    """Wrap a data loader so each batch it yields is tagged with a task name."""

    def __init__(self, task_name, data_loader):
        """Keep the task name and loader, mirroring the loader's attributes.

        Args:
            task_name: Name written into every yielded batch.
            data_loader: Underlying loader; must expose ``batch_size`` and
                ``dataset`` and yield dict-like batches.
        """
        self.task_name = task_name
        self.data_loader = data_loader

        # Re-exposed so the wrapper offers the same attributes as the loader.
        self.batch_size = data_loader.batch_size
        self.dataset = data_loader.dataset

    def __len__(self):
        """Return the number of batches of the wrapped loader."""
        return len(self.data_loader)

    def __iter__(self):
        """Yield each batch of the wrapped loader with ``"task_name"`` set."""
        for tagged_batch in self.data_loader:
            # StrIgnoreDevice lets the string survive a later ``.to(device)``.
            tagged_batch["task_name"] = StrIgnoreDevice(self.task_name)
            yield tagged_batch
class MultitaskDataloader:
    """Iterable that interleaves batches from several single-task loaders."""

    def __init__(self, dataloader_dict):
        """Record the per-task loaders and their sizes.

        Args:
            dataloader_dict: Mapping of task name to data loader. Each loader
                must support ``len()`` and expose a sized ``dataset``.
        """
        self.dataloader_dict = dataloader_dict
        self.num_batches_dict = {
            name: len(loader)
            for name, loader in self.dataloader_dict.items()
        }
        self.task_name_list = list(self.dataloader_dict)

        # Placeholder list whose length equals the total number of examples
        # across all tasks (presumably so callers can take len(x.dataset)).
        total_examples = sum(
            len(loader.dataset)
            for loader in self.dataloader_dict.values()
        )
        self.dataset = [None] * total_examples

    def __len__(self):
        """Return the total number of batches over all tasks."""
        return sum(self.num_batches_dict.values())

    def __iter__(self):
        """Yield batches, picking the task for each step at random.

        Sampling is size-proportional: every task index appears in the
        schedule once per batch it owns, and the schedule is shuffled. A
        different distribution could be substituted here.
        """
        schedule = np.array([
            task_index
            for task_index, name in enumerate(self.task_name_list)
            for _ in range(self.num_batches_dict[name])
        ])
        np.random.shuffle(schedule)

        iterators = {
            name: iter(loader)
            for name, loader in self.dataloader_dict.items()
        }
        for task_index in schedule:
            chosen_task = self.task_name_list[task_index]
            yield next(iterators[chosen_task])
class MultitaskTrainer(transformers.Trainer):
    """Trainer whose training data is a mapping of task name to dataset."""

    def get_single_train_dataloader(self, task_name, train_dataset):
        """Create a single-task data loader that also yields the task name.

        Args:
            task_name: Name attached to every batch of this loader.
            train_dataset: Dataset for this task.

        Returns:
            A ``DataLoaderWithTaskname`` wrapping a ``DataLoader`` over
            ``train_dataset``.

        Raises:
            ValueError: If the trainer has no ``train_dataset``.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        # local_rank == -1 selects plain random sampling; any other value
        # selects the distributed sampler. No TPU-specific sampler is used.
        if self.args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            train_sampler = DistributedSampler(train_dataset)

        return DataLoaderWithTaskname(
            task_name=task_name,
            data_loader=DataLoader(
                train_dataset,
                batch_size=self.args.train_batch_size,
                sampler=train_sampler,
                collate_fn=self.data_collator.collate_batch,
            ),
        )

    def get_train_dataloader(self):
        """Return a MultitaskDataloader over every task in ``train_dataset``.

        The result is not a real ``DataLoader``; it is an iterable whose
        iterator samples batches from the per-task loaders. ``train_dataset``
        is iterated with ``.items()``, so it must be a mapping of task name
        to dataset.
        """
        return MultitaskDataloader({
            task_name: self.get_single_train_dataloader(task_name, task_dataset)
            for task_name, task_dataset in self.train_dataset.items()
        })