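"""
Checks for `Accelerator.gather_for_metrics`: predictions gathered across processes should match the
dataset length, distributed MRPC metrics should match a single-process baseline, and iterable
datasets and `drop_last` dataloaders should be handled correctly.
"""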
import logging
import math
import os
from copy import deepcopy

import datasets
import evaluate
import torch
import transformers
from datasets import load_dataset
from torch.utils.data import DataLoader, IterableDataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from accelerate import Accelerator, DataLoaderConfiguration, DistributedType
from accelerate.data_loader import DataLoaderDispatcher
from accelerate.test_utils import RegressionDataset, RegressionModel, torch_device
from accelerate.utils import is_torch_xla_available, set_seed


os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
|
|
class ListHandler(logging.Handler):
    "Logging handler that collects emitted records in a list so tests can assert on them"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.logs = []

    def emit(self, record):
        self.logs.append(record)
|
|
def get_basic_setup(accelerator, num_samples=82, batch_size=16):
    "Returns everything needed to perform basic training"
    set_seed(42)
    model = RegressionModel()
    ddp_model = deepcopy(model)
    dset = RegressionDataset(length=num_samples)
    dataloader = DataLoader(dset, batch_size=batch_size)
    model.to(accelerator.device)
    ddp_model, dataloader = accelerator.prepare(ddp_model, dataloader)
    return model, ddp_model, dataloader
|
|
def get_dataloader(accelerator: Accelerator, use_longest=False):
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/mrpc-bert-base-cased")
    dataset = load_dataset("glue", "mrpc", split="validation")

    def tokenize_function(examples):
        outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
        return outputs

    with accelerator.main_process_first():
        tokenized_datasets = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=["idx", "sentence1", "sentence2"],
        )

    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    def collate_fn(examples):
        if use_longest:
            return tokenizer.pad(examples, padding="longest", return_tensors="pt")
        return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt")

    return DataLoader(tokenized_datasets, shuffle=False, collate_fn=collate_fn, batch_size=16)
|
|
def get_mrpc_setup(dispatch_batches, split_batches):
    dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, split_batches=split_batches)
    accelerator = Accelerator(dataloader_config=dataloader_config)
    dataloader = get_dataloader(accelerator, not dispatch_batches)
    model = AutoModelForSequenceClassification.from_pretrained(
        "hf-internal-testing/mrpc-bert-base-cased", return_dict=True
    )
    ddp_model, ddp_dataloader = accelerator.prepare(model, dataloader)
    return {
        "ddp": [ddp_model, ddp_dataloader, torch_device],
        "no": [model, dataloader, accelerator.device],
    }, accelerator
|
|
def generate_predictions(model, dataloader, accelerator):
    logits_and_targets = []
    for batch in dataloader:
        input, target = batch.values()
        with torch.no_grad():
            logit = model(input)
        logit, target = accelerator.gather_for_metrics((logit, target))
        logits_and_targets.append((logit, target))
    logits, targs = [], []
    for logit, targ in logits_and_targets:
        logits.append(logit)
        targs.append(targ)
    logits, targs = torch.cat(logits), torch.cat(targs)
    return logits, targs
|
|
def test_torch_metrics(
    accelerator: Accelerator, num_samples=82, dispatch_batches=False, split_batches=False, batch_size=16
):
    _, ddp_model, dataloader = get_basic_setup(accelerator, num_samples, batch_size)
    logits, _ = generate_predictions(ddp_model, dataloader, accelerator)
    assert len(logits) == num_samples, (
        f"Unexpected number of inputs:\n Expected: {num_samples}\n Actual: {len(logits)}"
    )
|
|
def test_mrpc(dispatch_batches: bool = False, split_batches: bool = False):
    metric = evaluate.load("glue", "mrpc")
    setup, accelerator = get_mrpc_setup(dispatch_batches, split_batches)
    # First compute a baseline on a single process
    model, dataloader, device = setup["no"]
    model.to(device)
    model.eval()
    for batch in dataloader:
        batch.to(device)
        with torch.inference_mode():
            outputs = model(**batch)
        preds = outputs.logits.argmax(dim=-1)
        metric.add_batch(predictions=preds, references=batch["labels"])
    baseline = metric.compute()

    # Then compute the metric with the distributed (prepared) model and dataloader
    model, dataloader, device = setup["ddp"]
    model.eval()
    for batch in dataloader:
        with torch.inference_mode():
            outputs = model(**batch)
        preds = outputs.logits.argmax(dim=-1)
        references = batch["labels"]
        preds, references = accelerator.gather_for_metrics((preds, references))
        metric.add_batch(predictions=preds, references=references)
    distributed = metric.compute()

    for key in "accuracy f1".split():
        assert math.isclose(baseline[key], distributed[key]), (
            f"Baseline and Distributed are not the same for key {key}:\n\tBaseline: {baseline[key]}\n\tDistributed: {distributed[key]}\n"
        )
|
|
def test_gather_for_metrics_with_non_tensor_objects_iterable_dataset():
    class DummyIterableDataset(IterableDataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __iter__(self):
            yield from self.data

    iterable_dataset = DummyIterableDataset([n for n in range(30)])
    dataloader = DataLoader(iterable_dataset, batch_size=4)
    accelerator = Accelerator()
    prepared_dataloader = accelerator.prepare(dataloader)

    if accelerator.is_main_process:
        logger = logging.root.manager.loggerDict["accelerate.accelerator"]
        list_handler = ListHandler()
        logger.addHandler(list_handler)

    batches_for_metrics = []
    for batch in prepared_dataloader:
        batches_for_metrics.append(accelerator.gather_for_metrics(batch))

    assert torch.cat(batches_for_metrics).size(0) == 30

    if accelerator.is_main_process:
        assert len(list_handler.logs) == 0
        logger.removeHandler(list_handler)
|
|
def test_gather_for_metrics_with_iterable_dataset():
    class DummyIterableDataset(IterableDataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __iter__(self):
            yield from self.data

    iterable_dataset = DummyIterableDataset(torch.as_tensor(range(30)))
    dataloader = DataLoader(iterable_dataset, batch_size=4)

    accelerator = Accelerator()
    prepared_dataloader = accelerator.prepare(dataloader)

    assert isinstance(prepared_dataloader, DataLoaderDispatcher)

    if accelerator.is_main_process:
        logger = logging.root.manager.loggerDict["accelerate.accelerator"]
        list_handler = ListHandler()
        logger.addHandler(list_handler)

    batches_for_metrics = []
    for batch in prepared_dataloader:
        batches_for_metrics.append(accelerator.gather_for_metrics(batch))

    assert torch.cat(batches_for_metrics).size(0) == 30

    if accelerator.is_main_process:
        assert len(list_handler.logs) == 0

        logger.removeHandler(list_handler)
|
|
def test_gather_for_metrics_drop_last():
    accelerator = Accelerator()
    per_device_batch_size = 5
    num_items = (10 * accelerator.num_processes) + 1
    dataloader = DataLoader(range(num_items), batch_size=per_device_batch_size, drop_last=True)
    dataloader = accelerator.prepare(dataloader)

    iterator = iter(dataloader)
    next(iterator)  # Skip first batch
    batch = next(iterator)
    gathered_items = accelerator.gather_for_metrics(batch)

    # With `drop_last=True` the trailing, incomplete batch is discarded, so a full batch
    # per process is expected back from `gather_for_metrics`
    num_expected_items = per_device_batch_size * accelerator.num_processes
    assert gathered_items.size(0) == (num_expected_items), (
        f"Expected number of items: {num_expected_items}, Actual: {gathered_items.size(0)}"
    )
|
|
def main():
    dataloader_config = DataLoaderConfiguration(split_batches=False, dispatch_batches=False)
    accelerator = Accelerator(dataloader_config=dataloader_config)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
    # Batch dispatching is not exercised on TorchXLA, so only `dispatch_batches=False`
    # is tested there
    dispatch_batches_options = [False] if accelerator.state.distributed_type == DistributedType.XLA else [True, False]

    # These tests are slower, so they only run on GPU/accelerator setups,
    # and they are skipped when TorchXLA is available
    if accelerator.device.type != "cpu" and not is_torch_xla_available():
        if accelerator.is_local_main_process:
            print("**Testing gather_for_metrics**")
        for split_batches in [True, False]:
            for dispatch_batches in dispatch_batches_options:
                if accelerator.is_local_main_process:
                    print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`")
                test_mrpc(dispatch_batches, split_batches)
                accelerator.state._reset_state()
        print("test_gather_for_metrics_with_iterable_dataset")
        test_gather_for_metrics_with_iterable_dataset()
        print("test gather_for_metrics_with_non_tensor_objects_iterable_dataset")
        test_gather_for_metrics_with_non_tensor_objects_iterable_dataset()

    # The torch metrics check relies on end-of-dataloader tracking, which TorchXLA's
    # asynchronous MpDeviceLoader can interfere with, so it is skipped on XLA
    if accelerator.state.distributed_type != DistributedType.XLA:
        if accelerator.is_local_main_process:
            print("**Test torch metrics**")
        for split_batches in [True, False]:
            for dispatch_batches in dispatch_batches_options:
                dataloader_config = DataLoaderConfiguration(
                    split_batches=split_batches, dispatch_batches=dispatch_batches
                )
                accelerator = Accelerator(dataloader_config=dataloader_config)
                if accelerator.is_local_main_process:
                    print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`, length=99")
                test_torch_metrics(accelerator, 99)
                accelerator.state._reset_state()
    if accelerator.is_local_main_process:
        print("**Test last batch is not dropped when perfectly divisible**")
    accelerator = Accelerator()
    test_torch_metrics(accelerator, 512)
    accelerator.state._reset_state()
    if accelerator.is_local_main_process:
        print("**Test that `drop_last` is taken into account**")
    test_gather_for_metrics_drop_last()
    accelerator.end_training()
    accelerator.state._reset_state()
|
|
def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()