import numpy as np
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler
from typing import Dict, List, Union

import transformers
from transformers.data.data_collator import (
    DataCollatorWithPadding,
    InputDataClass,
    default_data_collator,
)


class MultitaskModel(transformers.PreTrainedModel):
    def __init__(self, encoder, taskmodels_dict):
        """
        Setting MultitaskModel up as a PreTrainedModel allows us to take
        better advantage of Trainer features.
        """
        super().__init__(transformers.PretrainedConfig())

        self.encoder = encoder
        self.taskmodels_dict = nn.ModuleDict(taskmodels_dict)

    @classmethod
    def create(cls, model_name, model_type_dict, model_config_dict):
        """
        Creates a MultitaskModel using the model classes and config objects
        from single-task models.

        We do this by creating each single-task model and having them share
        the same encoder transformer.
        """
        shared_encoder = None
        taskmodels_dict = {}
        for task_name, model_type in model_type_dict.items():
            model = model_type.from_pretrained(
                model_name,
                config=model_config_dict[task_name],
            )
            if shared_encoder is None:
                # The first task's encoder becomes the shared encoder.
                shared_encoder = getattr(model, cls.get_encoder_attr_name(model))
            else:
                # Every subsequent task model points at the same encoder module.
                setattr(model, cls.get_encoder_attr_name(model), shared_encoder)
            taskmodels_dict[task_name] = model
        return cls(encoder=shared_encoder, taskmodels_dict=taskmodels_dict)

    @classmethod
    def get_encoder_attr_name(cls, model):
        """
        The encoder transformer is named differently in each model
        "architecture". This method lets us get the name of the encoder
        attribute.
        """
        model_class_name = model.__class__.__name__
        if model_class_name.startswith("Bert"):
            return "bert"
        elif model_class_name.startswith("Roberta"):
            return "roberta"
        elif model_class_name.startswith("Albert"):
            return "albert"
        else:
            raise KeyError(f"Add support for new model {model_class_name}")

    def forward(self, task_name, **kwargs):
        # Dispatch to the task-specific model; the encoder is shared underneath.
        return self.taskmodels_dict[task_name](**kwargs)

    def get_model(self, task_name):
        return self.taskmodels_dict[task_name]


class NLPDataCollator(DataCollatorWithPadding):
    """
    Extends the existing data collator to work with NLP dataset batches.
    """

    def collate_batch(self, features: List[Union[InputDataClass, Dict]]) -> Dict[str, torch.Tensor]:
        first = features[0]
        if isinstance(first, dict):
            # NLP datasets currently present features as lists of dictionaries
            # (one per example), so we adapt the collate_batch logic for that.
            batch = {}
            if "labels" in first and first["labels"] is not None:
                if first["labels"].dtype == torch.int64:
                    labels = torch.tensor([f["labels"] for f in features], dtype=torch.long)
                else:
                    labels = torch.tensor([f["labels"] for f in features], dtype=torch.float)
                batch["labels"] = labels
            for k, v in first.items():
                if k != "labels" and v is not None and not isinstance(v, str):
                    batch[k] = torch.stack([f[k] for f in features])
            return batch
        else:
            # Otherwise, fall back to the library's default collation.
            return default_data_collator(features)


class StrIgnoreDevice(str):
    """
    This is a hack. The Trainer is going to call .to(device) on every input
    value, but we need to pass in an additional `task_name` string.
    This prevents it from throwing an error.
    """

    def to(self, device):
        return self


class DataLoaderWithTaskname:
    """
    Wrapper around a DataLoader that also yields a task name.
    """

    def __init__(self, task_name, data_loader):
        self.task_name = task_name
        self.data_loader = data_loader

        self.batch_size = data_loader.batch_size
        self.dataset = data_loader.dataset

    def __len__(self):
        return len(self.data_loader)

    def __iter__(self):
        for batch in self.data_loader:
            batch["task_name"] = StrIgnoreDevice(self.task_name)
            yield batch


class MultitaskDataloader:
    """
    Data loader that combines and samples from multiple single-task
    data loaders.
    """

    def __init__(self, dataloader_dict):
        self.dataloader_dict = dataloader_dict
        self.num_batches_dict = {
            task_name: len(dataloader)
            for task_name, dataloader in self.dataloader_dict.items()
        }
        self.task_name_list = list(self.dataloader_dict)
        # Dummy dataset of the combined length, so code that only needs
        # len(dataloader.dataset) still works.
        self.dataset = [None] * sum(
            len(dataloader.dataset)
            for dataloader in self.dataloader_dict.values()
        )

    def __len__(self):
        return sum(self.num_batches_dict.values())

    def __iter__(self):
        """
        For each batch, sample a task, and yield a batch from the respective
        task DataLoader.

        We use size-proportional sampling, but you could easily modify this
        to sample from some other distribution.
        """
        task_choice_list = []
        for i, task_name in enumerate(self.task_name_list):
            task_choice_list += [i] * self.num_batches_dict[task_name]
        task_choice_list = np.array(task_choice_list)
        np.random.shuffle(task_choice_list)
        dataloader_iter_dict = {
            task_name: iter(dataloader)
            for task_name, dataloader in self.dataloader_dict.items()
        }
        for task_choice in task_choice_list:
            task_name = self.task_name_list[task_choice]
            yield next(dataloader_iter_dict[task_name])


class MultitaskTrainer(transformers.Trainer):

    def get_single_train_dataloader(self, task_name, train_dataset):
        """
        Create a single-task data loader that also yields task names.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        # Note: TPU-specific sampling is omitted here; we only distinguish
        # between single-process and distributed training.
        train_sampler = (
            RandomSampler(train_dataset)
            if self.args.local_rank == -1
            else DistributedSampler(train_dataset)
        )

        data_loader = DataLoaderWithTaskname(
            task_name=task_name,
            data_loader=DataLoader(
                train_dataset,
                batch_size=self.args.train_batch_size,
                sampler=train_sampler,
                collate_fn=self.data_collator.collate_batch,
            ),
        )
        return data_loader

    def get_train_dataloader(self):
        """
        Returns a MultitaskDataloader, which is not actually a DataLoader but
        an iterable that samples a batch from each task's DataLoader.
        """
        return MultitaskDataloader({
            task_name: self.get_single_train_dataloader(task_name, task_dataset)
            for task_name, task_dataset in self.train_dataset.items()
        })
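

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the classes above). It shows
# how the pieces are intended to fit together: build one shared-encoder model
# for two hypothetical tasks, run a quick forward pass, and wire up
# MultitaskTrainer. The task names ("stsb", "rte"), the dataset dict
# `features_dict`, and the training hyperparameters are assumptions made for
# the example, not requirements of this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model_name = "roberta-base"
    multitask_model = MultitaskModel.create(
        model_name=model_name,
        model_type_dict={
            "stsb": transformers.AutoModelForSequenceClassification,
            "rte": transformers.AutoModelForSequenceClassification,
        },
        model_config_dict={
            "stsb": transformers.AutoConfig.from_pretrained(model_name, num_labels=1),
            "rte": transformers.AutoConfig.from_pretrained(model_name, num_labels=2),
        },
    )
    # Both task heads should report the very same underlying encoder object.
    print(multitask_model.get_model("stsb").roberta is multitask_model.get_model("rte").roberta)

    # Quick single-example forward pass through one task head.
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    inputs = tokenizer("A quick smoke test.", return_tensors="pt")
    outputs = multitask_model("rte", **inputs)
    print(outputs.logits.shape)

    # Training would then follow the usual Trainer recipe, e.g. (assuming
    # `features_dict` maps each task name to a tokenized torch Dataset, and
    # noting that this multitask setup targets Trainer versions whose
    # dataloader handling matches the overrides above):
    #
    # trainer = MultitaskTrainer(
    #     model=multitask_model,
    #     args=transformers.TrainingArguments(
    #         output_dir="./multitask_model",
    #         learning_rate=1e-5,
    #         num_train_epochs=3,
    #         per_device_train_batch_size=8,
    #     ),
    #     data_collator=NLPDataCollator(tokenizer=tokenizer),
    #     train_dataset=features_dict,
    # )
    # trainer.train()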