# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This test is meant to be run on an instance with TPUs like this:
#
#   python examples/pytorch/xla_spawn.py --num_cores=8 tests/test_trainer_tpu.py
#
# Replace 8 with the number of TPU cores you have.
#

import sys
from typing import Dict

from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available
from transformers.utils import logging


logger = logging.get_logger(__name__)


if is_torch_available():
    import torch
    from torch import nn
    from torch.utils.data import Dataset

    from transformers import Trainer

    class DummyDataset(Dataset):
        """Dataset of ``length`` items in which item ``i`` is simply the integer ``i``."""

        def __init__(self, length: int = 101):
            self.length = length

        def __len__(self):
            return self.length

        def __getitem__(self, i) -> int:
            return i

    class DummyDataCollator:
        """Collate a list of integer samples into a batch whose inputs and labels are equal."""

        def __call__(self, features):
            return {"input_ids": torch.tensor(features), "labels": torch.tensor(features)}

    class DummyModel(nn.Module):
        """Model that echoes its ``input_ids`` back so the gathered outputs can be checked."""

        def __init__(self):
            super().__init__()
            # Add some (unused) params otherwise DDP will complain.
            self.fc = nn.Linear(120, 80)

        def forward(self, input_ids, labels=None):
            """Return ``input_ids`` unchanged, preceded by a constant zero loss when labels are given."""
            if labels is None:
                return input_ids
            return torch.tensor(0.0, device=input_ids.device), input_ids


def _make_compute_metrics(num_samples):
    """Build a ``compute_metrics`` callback expecting exactly ``0 .. num_samples - 1``, in order."""
    expected = list(range(num_samples))

    def compute_metrics(p: EvalPrediction) -> Dict:
        success = p.predictions.tolist() == expected and p.label_ids.tolist() == expected
        return {"success": success}

    return compute_metrics


def _exit_unless_successful(metrics, key):
    """Log ``metrics``; terminate the process with status 1 unless ``metrics[key]`` is ``True``."""
    logger.info(metrics)
    if metrics[key] is not True:
        logger.error(metrics)
        # `sys.exit` rather than the `exit` helper: the latter is only injected by the `site`
        # module for interactive use and is not guaranteed to exist.
        sys.exit(1)


def _evaluate_and_predict(trainer, dataset):
    """Run one evaluation pass and one prediction pass, exiting on the first failed check."""
    _exit_unless_successful(trainer.evaluate(), "eval_success")
    _exit_unless_successful(trainer.predict(dataset).metrics, "test_success")


def main():
    """Check that evaluation and prediction return every sample, in order.

    The check is repeated for several dataset sizes, each time with the default
    ``eval_accumulation_steps`` and then with ``eval_accumulation_steps=2``.
    The process exits with status 1 as soon as one check fails.
    """
    parser = HfArgumentParser((TrainingArguments,))
    # The parser is invoked without explicit arguments, so the output directory is injected
    # through `sys.argv`.
    sys.argv += ["--output_dir", "./examples"]
    training_args = parser.parse_args_into_dataclasses()[0]

    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
        f"tpu_num_cores: {training_args.tpu_num_cores}",
    )

    # Essentially, what we want to verify in the distributed case is
    # that we get all samples back, in the right order.
    # (this is crucial for prediction for instance)
    for dataset_length in [1001, 256, 15]:
        dataset = DummyDataset(dataset_length)

        trainer = Trainer(
            model=DummyModel(),
            args=training_args,
            data_collator=DummyDataCollator(),
            eval_dataset=dataset,
            compute_metrics=_make_compute_metrics(len(dataset)),
        )

        _evaluate_and_predict(trainer, dataset)

        # Same checks again, this time accumulating outputs over several steps.
        trainer.args.eval_accumulation_steps = 2
        _evaluate_and_predict(trainer, dataset)
        trainer.args.eval_accumulation_steps = None

    logger.info("🔥 All distributed tests successful")


def _mp_fn(index):
    """Per-process entry point; ``index`` is accepted but not used."""
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()