import numpy as np
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler

import transformers
from transformers.data.data_collator import DataCollatorWithPadding, InputDataClass

from typing import Dict, List, Union

class MultitaskModel(transformers.PreTrainedModel):
    def __init__(self, encoder, taskmodels_dict):
        """
        Setting MultitaskModel up as a PreTrainedModel allows us
        to take better advantage of Trainer features.
        """
        super().__init__(transformers.PretrainedConfig())
        self.encoder = encoder
        self.taskmodels_dict = nn.ModuleDict(taskmodels_dict)

    @classmethod
    def create(cls, model_name, model_type_dict, model_config_dict):
        """
        This creates a MultitaskModel using the model class and config objects
        from single-task models.

        We do this by creating each single-task model, and having them share
        the same encoder transformer.
        """
        shared_encoder = None
        taskmodels_dict = {}
        do = nn.Dropout(p=0.2)  # note: defined here but not applied anywhere below
        for task_name, model_type in model_type_dict.items():
            model = model_type.from_pretrained(
                model_name,
                config=model_config_dict[task_name],
            )
            if shared_encoder is None:
                # The first task model's encoder becomes the shared encoder.
                shared_encoder = getattr(model, cls.get_encoder_attr_name(model))
            else:
                # Every subsequent task model reuses that same encoder module.
                setattr(model, cls.get_encoder_attr_name(model), shared_encoder)
            taskmodels_dict[task_name] = model
        return cls(encoder=shared_encoder, taskmodels_dict=taskmodels_dict)

    @classmethod
    def get_encoder_attr_name(cls, model):
        """
        The encoder transformer is named differently in each model "architecture".
        This method lets us get the name of the encoder attribute.
        """
        model_class_name = model.__class__.__name__
        if model_class_name.startswith("Bert"):
            return "bert"
        elif model_class_name.startswith("Roberta"):
            return "roberta"
        elif model_class_name.startswith("Albert"):
            return "albert"
        else:
            raise KeyError(f"Add support for new model {model_class_name}")

    def forward(self, task_name, **kwargs):
        # Run the single-task model for the given task; all task models share
        # the same encoder, so only the task-specific head differs.
        return self.taskmodels_dict[task_name](**kwargs)

    def get_model(self, task_name):
        return self.taskmodels_dict[task_name]
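
# A minimal sketch of how MultitaskModel.create could be used, assuming two
# hypothetical classification tasks ("task_a" with 3 labels, "task_b" with 2).
# The checkpoint and task names are illustrative, not part of the classes above.
def _example_build_multitask_model():
    model_name = "roberta-base"
    multitask_model = MultitaskModel.create(
        model_name=model_name,
        model_type_dict={
            "task_a": transformers.AutoModelForSequenceClassification,
            "task_b": transformers.AutoModelForSequenceClassification,
        },
        model_config_dict={
            "task_a": transformers.AutoConfig.from_pretrained(model_name, num_labels=3),
            "task_b": transformers.AutoConfig.from_pretrained(model_name, num_labels=2),
        },
    )
    # Both task models now point at the very same RoBERTa encoder module.
    assert multitask_model.get_model("task_a").roberta is multitask_model.get_model("task_b").roberta
    return multitask_model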

class NLPDataCollator(DataCollatorWithPadding):
    """
    Extending the existing DataCollatorWithPadding to work with NLP dataset batches.
    """
    def collate_batch(self, features: List[Union[InputDataClass, Dict]]) -> Dict[str, torch.Tensor]:
        first = features[0]
        batch = {}
        if isinstance(first, dict):
            # The NLP datasets library currently presents features as lists of
            # dictionaries (one per example), so we adapt the collate logic for that.
            if "labels" in first and first["labels"] is not None:
                if first["labels"].dtype == torch.int64:
                    labels = torch.tensor([f["labels"] for f in features], dtype=torch.long)
                else:
                    labels = torch.tensor([f["labels"] for f in features], dtype=torch.float)
                batch = {"labels": labels}
            for k, v in first.items():
                if k != "labels" and v is not None and not isinstance(v, str):
                    batch[k] = torch.stack([f[k] for f in features])
            return batch
        else:
            # Otherwise, fall back to the parent collator's padding logic.
            return super().__call__(features)
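
# A quick sketch of the dict branch above: the collator stacks tensor-valued
# features and builds a "labels" tensor whose dtype follows the label dtype.
# The tokenizer checkpoint and feature values are assumptions for illustration.
def _example_collate_dict_features():
    tokenizer = transformers.AutoTokenizer.from_pretrained("roberta-base")
    collator = NLPDataCollator(tokenizer=tokenizer)
    features = [
        {"input_ids": torch.tensor([0, 8, 2]), "attention_mask": torch.tensor([1, 1, 1]), "labels": torch.tensor(1)},
        {"input_ids": torch.tensor([0, 9, 2]), "attention_mask": torch.tensor([1, 1, 1]), "labels": torch.tensor(0)},
    ]
    batch = collator.collate_batch(features)
    # batch["input_ids"] has shape (2, 3) and batch["labels"] is a LongTensor.
    return batch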

class StrIgnoreDevice(str):
    """
    This is a hack. The Trainer is going to call .to(device) on every input
    value, but we need to pass in an additional `task_name` string.
    Subclassing str with a no-op .to() prevents that call from throwing an error.
    """
    def to(self, device):
        return self
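
# A tiny sanity check of the hack above: .to(device) is a no-op, so the task
# name string survives the Trainer's automatic device placement. The task name
# and device string below are just placeholders.
def _example_str_ignore_device():
    task_name = StrIgnoreDevice("task_a")
    assert task_name.to("cuda:0") == "task_a"
    assert isinstance(task_name, str)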

class DataLoaderWithTaskname:
    """
    Wrapper around a DataLoader to also yield a task name.
    """
    def __init__(self, task_name, data_loader):
        self.task_name = task_name
        self.data_loader = data_loader
        self.batch_size = data_loader.batch_size
        self.dataset = data_loader.dataset

    def __len__(self):
        return len(self.data_loader)

    def __iter__(self):
        for batch in self.data_loader:
            batch["task_name"] = StrIgnoreDevice(self.task_name)
            yield batch

class MultitaskDataloader:
    """
    Data loader that combines and samples from multiple single-task
    data loaders.
    """
    def __init__(self, dataloader_dict):
        self.dataloader_dict = dataloader_dict
        self.num_batches_dict = {
            task_name: len(dataloader)
            for task_name, dataloader in self.dataloader_dict.items()
        }
        self.task_name_list = list(self.dataloader_dict)
        self.dataset = [None] * sum(
            len(dataloader.dataset)
            for dataloader in self.dataloader_dict.values()
        )

    def __len__(self):
        return sum(self.num_batches_dict.values())

    def __iter__(self):
        """
        For each batch, sample a task, and yield a batch from the respective
        task DataLoader.

        We use size-proportional sampling, but you could easily modify this
        to sample from some other distribution.
        """
        # Build a list with one entry per batch, where each entry is the index
        # of the task that batch will come from, then shuffle it.
        task_choice_list = []
        for i, task_name in enumerate(self.task_name_list):
            task_choice_list += [i] * self.num_batches_dict[task_name]
        task_choice_list = np.array(task_choice_list)
        np.random.shuffle(task_choice_list)
        dataloader_iter_dict = {
            task_name: iter(dataloader)
            for task_name, dataloader in self.dataloader_dict.items()
        }
        for task_choice in task_choice_list:
            task_name = self.task_name_list[task_choice]
            yield next(dataloader_iter_dict[task_name])
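
# A small sketch of size-proportional sampling with two toy task datasets,
# assuming simple lists of feature dicts just for illustration; real usage
# wraps tokenized NLP datasets via MultitaskTrainer below.
def _example_multitask_dataloader():
    def toy_loader(num_examples):
        data = [{"input_ids": torch.tensor([i, i + 1]), "labels": torch.tensor(0)}
                for i in range(num_examples)]
        return DataLoader(data, batch_size=2)

    loaders = {
        "task_a": DataLoaderWithTaskname("task_a", toy_loader(8)),
        "task_b": DataLoaderWithTaskname("task_b", toy_loader(4)),
    }
    multitask_loader = MultitaskDataloader(loaders)
    # 4 batches come from task_a and 2 from task_b, interleaved in shuffled order.
    assert len(multitask_loader) == 6
    return [batch["task_name"] for batch in multitask_loader]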

class MultitaskTrainer(transformers.Trainer):
    def get_single_train_dataloader(self, task_name, train_dataset):
        """
        Create a single-task data loader that also yields task names.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_sampler = (
            RandomSampler(train_dataset)
            if self.args.local_rank == -1
            else DistributedSampler(train_dataset)
        )
        data_loader = DataLoaderWithTaskname(
            task_name=task_name,
            data_loader=DataLoader(
                train_dataset,
                batch_size=self.args.train_batch_size,
                sampler=train_sampler,
                collate_fn=self.data_collator.collate_batch,
            ),
        )
        return data_loader

    def get_train_dataloader(self):
        """
        Returns a MultitaskDataloader, which is not a true DataLoader but an
        iterable that samples batches from each task's DataLoader.
        """
        return MultitaskDataloader({
            task_name: self.get_single_train_dataloader(task_name, task_dataset)
            for task_name, task_dataset in self.train_dataset.items()
        })
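
# A hedged end-to-end sketch of how the pieces above could be wired together.
# The tokenizer checkpoint, training arguments, and both function parameters
# are assumptions for illustration: `multitask_model` would come from
# MultitaskModel.create, and `train_dataset_dict` must map task names to
# datasets yielding feature dicts compatible with NLPDataCollator.
def _example_train(multitask_model, train_dataset_dict):
    tokenizer = transformers.AutoTokenizer.from_pretrained("roberta-base")
    trainer = MultitaskTrainer(
        model=multitask_model,
        args=transformers.TrainingArguments(
            output_dir="./multitask_model",
            learning_rate=1e-5,
            num_train_epochs=3,
            per_device_train_batch_size=8,
        ),
        data_collator=NLPDataCollator(tokenizer=tokenizer),
        train_dataset=train_dataset_dict,
    )
    trainer.train()
    return trainer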