| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import pickle |
| | import sys |
| | from functools import partial |
| | from typing import Callable, Optional |
| |
|
| | import numpy as np |
| | import pytest |
| | import torch |
| | from scipy.stats import entropy |
| | from torch.distributions.utils import logits_to_probs |
| | from torch.multiprocessing import Pool, set_start_method |
| | from torchmetrics import Metric |
| |
|
| | from nemo.collections.common.metrics import GlobalAverageLossMetric, Perplexity |
| |
|
# Shared configuration for the metric tests in this file.
NUM_PROCESSES = 2  # world size of the ddp worker pool (see MetricTester.setup_class)
NUM_BATCHES = 10  # number of batches consumed per test; interleaved across ranks in ddp mode
# NOTE(review): the constants below are not referenced in this section of the
# file — presumably they parametrize test-data construction elsewhere; verify.
BATCH_SIZE = 16
NUM_CLASSES = 5
EXTRA_DIM = 3
THRESHOLD = 0.5
| |
|
| |
|
def setup_ddp(rank, world_size):
    """Set up the ddp environment: export the master address/port and join a gloo group.

    Args:
        rank: rank of the current process
        world_size: total number of processes in the group
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "8088"

    # The gloo backend is not usable on Windows/cygwin builds of torch.
    platform_unsupported = sys.platform in ["win32", "cygwin"]
    if torch.distributed.is_available() and not platform_unsupported:
        torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
| |
|
| |
|
def _class_test(
    rank: int,
    worldsize: int,
    preds: torch.Tensor,
    target: torch.Tensor,
    metric_class: Metric,
    sk_metric: Callable,
    dist_sync_on_step: bool,
    metric_args: Optional[dict] = None,
    check_dist_sync_on_step: bool = True,
    check_batch: bool = True,
    atol: float = 1e-8,
):
    """ Utility function doing the actual comparison between lightning class metric
    and reference metric.
    Args:
        rank: rank of current process
        worldsize: number of processes
        preds: torch tensor with predictions
        target: torch tensor with targets
        metric_class: lightning metric class that should be tested
        sk_metric: callable function that is used for comparison
        dist_sync_on_step: bool, if true will synchronize metric state across
            processes at each ``forward()``
        metric_args: dict with additional arguments used for class initialization
        check_dist_sync_on_step: bool, if true will check if the metric is also correctly
            calculated per batch per device (and not just at the end)
        check_batch: bool, if true will check if the metric is also correctly
            calculated across devices for each batch (and not just at the end)
        atol: absolute tolerance used when comparing results
    """
    # Guard against the mutable-default-argument pitfall: a shared ``{}`` default
    # object would be reused across calls if it were ever mutated.
    if metric_args is None:
        metric_args = {}
    # Instantiate the metric under test.
    metric = metric_class(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args)

    # The metric must survive a pickle round-trip, as required for ddp spawn workers.
    pickled_metric = pickle.dumps(metric)
    metric = pickle.loads(pickled_metric)

    # Each process consumes every ``worldsize``-th batch, starting at its rank.
    for i in range(rank, NUM_BATCHES, worldsize):
        batch_result = metric(preds[i], target[i])

        if metric.dist_sync_on_step:
            if rank == 0:
                # With dist_sync_on_step the per-step result aggregates the batches
                # seen by *all* ranks, so the reference is computed on stacked data.
                ddp_preds = torch.stack([preds[i + r] for r in range(worldsize)])
                ddp_target = torch.stack([target[i + r] for r in range(worldsize)])
                sk_batch_result = sk_metric(ddp_preds, ddp_target)
                if check_dist_sync_on_step:
                    assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol)
        else:
            # Without sync the per-step result only covers this rank's batch.
            sk_batch_result = sk_metric(preds[i], target[i])
            if check_batch:
                assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol)

    # The final accumulated value must match the reference over all batches.
    result = metric.compute()
    assert isinstance(result, torch.Tensor)

    total_preds = torch.stack([preds[i] for i in range(NUM_BATCHES)])
    total_target = torch.stack([target[i] for i in range(NUM_BATCHES)])
    sk_result = sk_metric(total_preds, total_target)

    assert np.allclose(result.numpy(), sk_result, atol=atol)
| |
|
| |
|
def _functional_test(
    preds: torch.Tensor,
    target: torch.Tensor,
    metric_functional: Callable,
    sk_metric: Callable,
    metric_args: Optional[dict] = None,
    atol: float = 1e-8,
):
    """ Utility function doing the actual comparison between lightning functional metric
    and reference metric.
    Args:
        preds: torch tensor with predictions
        target: torch tensor with targets
        metric_functional: lightning metric functional that should be tested
        sk_metric: callable function that is used for comparison
        metric_args: dict with additional arguments used for class initialization
        atol: absolute tolerance used when comparing results
    """
    # Avoid the mutable-default-argument pitfall of ``metric_args={}``.
    if metric_args is None:
        metric_args = {}
    # Bind the extra arguments once, so the loop only supplies data.
    metric = partial(metric_functional, **metric_args)

    for i in range(NUM_BATCHES):
        lightning_result = metric(preds[i], target[i])
        sk_result = sk_metric(preds[i], target[i])

        # Each batch is compared independently against the reference.
        assert np.allclose(lightning_result.numpy(), sk_result, atol=atol)
| |
|
| |
|
class MetricTester:
    """ Class used to efficiently run a lot of parametrized tests in ddp mode.
    Makes sure that ddp is only setup once and that the pool of processes is
    used for all tests.
    All tests should subclass from this and implement a new method called
        `test_metric_name`
    where the method `self.run_metric_test` is called inside.
    """

    # Default absolute tolerance for numeric comparisons; subclasses may override.
    atol = 1e-8

    def setup_class(self):
        """ Setup the metric class. This will spawn the pool of workers that are
        used for metric testing and setup_ddp
        """
        try:
            # 'spawn' can only be set once per interpreter; ignore later attempts.
            set_start_method('spawn')
        except RuntimeError:
            pass
        self.poolSize = NUM_PROCESSES
        self.pool = Pool(processes=self.poolSize)
        # Initialize the process group once in every worker; the pool is reused
        # by all tests of the subclass.
        self.pool.starmap(setup_ddp, [(rank, self.poolSize) for rank in range(self.poolSize)])

    def teardown_class(self):
        """ Close pool of workers """
        self.pool.close()
        self.pool.join()

    def run_functional_metric_test(
        self,
        preds: torch.Tensor,
        target: torch.Tensor,
        metric_functional: Callable,
        sk_metric: Callable,
        metric_args: Optional[dict] = None,
    ):
        """ Main method that should be used for testing functions. Call this inside
        testing method
        Args:
            preds: torch tensor with predictions
            target: torch tensor with targets
            metric_functional: lightning metric class that should be tested
            sk_metric: callable function that is used for comparison
            metric_args: dict with additional arguments used for class initialization
        """
        # ``None`` default avoids the shared-mutable-default pitfall of ``{}``.
        _functional_test(
            preds=preds,
            target=target,
            metric_functional=metric_functional,
            sk_metric=sk_metric,
            metric_args={} if metric_args is None else metric_args,
            atol=self.atol,
        )

    def run_class_metric_test(
        self,
        ddp: bool,
        preds: torch.Tensor,
        target: torch.Tensor,
        metric_class: Metric,
        sk_metric: Callable,
        dist_sync_on_step: bool,
        metric_args: Optional[dict] = None,
        check_dist_sync_on_step: bool = True,
        check_batch: bool = True,
    ):
        """ Main method that should be used for testing class. Call this inside testing
        methods.
        Args:
            ddp: bool, if running in ddp mode or not
            preds: torch tensor with predictions
            target: torch tensor with targets
            metric_class: lightning metric class that should be tested
            sk_metric: callable function that is used for comparison
            dist_sync_on_step: bool, if true will synchronize metric state across
                processes at each ``forward()``
            metric_args: dict with additional arguments used for class initialization
            check_dist_sync_on_step: bool, if true will check if the metric is also correctly
                calculated per batch per device (and not just at the end)
            check_batch: bool, if true will check if the metric is also correctly
                calculated across devices for each batch (and not just at the end)
        """
        # ``None`` default avoids the shared-mutable-default pitfall of ``{}``.
        if metric_args is None:
            metric_args = {}
        if ddp:
            if sys.platform == "win32":
                pytest.skip("DDP not supported on windows")

            # Fan the same comparison out to every worker; each worker handles
            # its own rank-interleaved slice of the batches.
            self.pool.starmap(
                partial(
                    _class_test,
                    preds=preds,
                    target=target,
                    metric_class=metric_class,
                    sk_metric=sk_metric,
                    dist_sync_on_step=dist_sync_on_step,
                    metric_args=metric_args,
                    check_dist_sync_on_step=check_dist_sync_on_step,
                    check_batch=check_batch,
                    atol=self.atol,
                ),
                [(rank, self.poolSize) for rank in range(self.poolSize)],
            )
        else:
            # Single-process run: rank 0 of a world of size 1.
            _class_test(
                0,
                1,
                preds=preds,
                target=target,
                metric_class=metric_class,
                sk_metric=sk_metric,
                dist_sync_on_step=dist_sync_on_step,
                metric_args=metric_args,
                check_dist_sync_on_step=check_dist_sync_on_step,
                check_batch=check_batch,
                atol=self.atol,
            )
| |
|
| |
|
def reference_perplexity_func(probs):
    """Return the mean perplexity of ``probs`` over the last axis.

    Perplexity of a distribution is ``exp(entropy)``; the per-distribution
    values are averaged over all leading dimensions.
    """
    return np.exp(entropy(probs, axis=-1)).mean()
| |
|
| |
|
def _perplexity_class_test(
    rank: int,
    worldsize: int,
    probs: Optional[torch.Tensor],
    logits: Optional[torch.Tensor],
    dist_sync_on_step: bool,
    metric_args: Optional[dict] = None,
    check_dist_sync_on_step: bool = True,
    check_batch: bool = True,
    atol: float = 1e-8,
):
    """ Utility function doing the actual comparison between lightning class metric
    and reference metric.
    Args:
        rank: rank of current process
        worldsize: number of processes
        probs: torch tensor with probabilities
        logits: torch tensor with logits. The function checks ``probs`` and ``logits`` are mutually exclusive for
            ``Perplexity`` metric.
        dist_sync_on_step: bool, if true will synchronize metric state across
            processes at each ``forward()``
        metric_args: dict with additional arguments used for class initialization
        check_dist_sync_on_step: bool, if true will check if the metric is also correctly
            calculated per batch per device (and not just at the end)
        check_batch: bool, if true will check if the metric is also correctly
            calculated across devices for each batch (and not just at the end)
        atol: absolute tolerance used when comparing results
    """
    # Avoid the mutable-default-argument pitfall of ``metric_args={}``.
    if metric_args is None:
        metric_args = {}
    perplexity = Perplexity(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args)
    # ``Perplexity`` accepts exactly one of ``probs``/``logits``: if both or
    # neither are provided it must raise, and there is nothing more to check.
    if (probs is None) == (logits is None):
        with pytest.raises(ValueError):
            perplexity(probs, logits)
        return

    # The metric must survive a pickle round-trip, as required for ddp spawn workers.
    pickled_metric = pickle.dumps(perplexity)
    perplexity = pickle.loads(pickled_metric)

    # Each process consumes every ``worldsize``-th batch, starting at its rank.
    for i in range(rank, NUM_BATCHES, worldsize):
        batch_result = perplexity(None if probs is None else probs[i], None if logits is None else logits[i])

        if perplexity.dist_sync_on_step:
            if rank == 0:
                # With dist_sync_on_step the per-step result aggregates the batches
                # seen by *all* ranks, so the reference is computed on stacked data.
                if probs is not None:
                    ddp_probs = torch.stack([probs[i + r] for r in range(worldsize)])
                else:
                    ddp_logits = torch.stack([logits[i + r] for r in range(worldsize)])
                    ddp_probs = logits_to_probs(ddp_logits, is_binary=False)
                sk_batch_result = reference_perplexity_func(ddp_probs)
                if check_dist_sync_on_step:
                    assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol)
        else:
            # Without sync the per-step result only covers this rank's batch.
            if probs is None:
                p = logits_to_probs(logits[i], is_binary=False)
            else:
                p = probs[i]
            sk_batch_result = reference_perplexity_func(p)
            if check_batch:
                assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol)

    # Past this point exactly one of probs/logits is set (the exclusive case
    # returned early above).
    assert (probs is None) != (logits is None)

    result = perplexity.compute()
    assert isinstance(result, torch.Tensor)

    if probs is None:
        probs = logits_to_probs(logits, is_binary=False)
    sk_result = reference_perplexity_func(probs)

    assert np.allclose(result.numpy(), sk_result, atol=atol)
| |
|
| |
|
class PerplexityTester(MetricTester):
    def run_class_perplexity_test(
        self,
        ddp: bool,
        probs: Optional[torch.Tensor],
        logits: Optional[torch.Tensor],
        dist_sync_on_step: bool,
        metric_args: Optional[dict] = None,
        check_dist_sync_on_step: bool = True,
        check_batch: bool = True,
    ):
        """ Main method that should be used for testing class. Call this inside testing
        methods.
        Args:
            ddp: bool, if running in ddp mode or not
            probs: torch tensor with probabilities.
            logits: torch tensor with logits. This test checks that probs and logits are mutually exclusive for
                ``Perplexity`` metric.
            dist_sync_on_step: bool, if true will synchronize metric state across
                processes at each ``forward()``
            metric_args: dict with additional arguments used for class initialization
            check_dist_sync_on_step: bool, if true will check if the metric is also correctly
                calculated per batch per device (and not just at the end)
            check_batch: bool, if true will check if the metric is also correctly
                calculated across devices for each batch (and not just at the end)
        """
        # ``None`` default avoids the shared-mutable-default pitfall of ``{}``.
        if metric_args is None:
            metric_args = {}
        if ddp:
            if sys.platform == "win32":
                pytest.skip("DDP not supported on windows")

            # Fan the comparison out to every worker in the shared pool.
            self.pool.starmap(
                partial(
                    _perplexity_class_test,
                    probs=probs,
                    logits=logits,
                    dist_sync_on_step=dist_sync_on_step,
                    metric_args=metric_args,
                    check_dist_sync_on_step=check_dist_sync_on_step,
                    check_batch=check_batch,
                    atol=self.atol,
                ),
                [(rank, self.poolSize) for rank in range(self.poolSize)],
            )
        else:
            # Single-process run: rank 0 of a world of size 1.
            _perplexity_class_test(
                0,
                1,
                probs=probs,
                logits=logits,
                dist_sync_on_step=dist_sync_on_step,
                metric_args=metric_args,
                check_dist_sync_on_step=check_dist_sync_on_step,
                check_batch=check_batch,
                atol=self.atol,
            )
| |
|
| |
|
def reference_loss_func(loss_sum_or_avg: torch.Tensor, num_measurements: torch.Tensor, take_avg_loss: bool):
    """
    Reference implementation of a global average loss: ``sum(losses) / sum(num_measurements)``.

    When ``take_avg_loss`` is ``True``, each element ``loss_sum_or_avg[i]`` is a mean over
    ``num_measurements[i]`` losses and is first scaled back to a sum before totalling.
    When ``num_measurements`` sums to zero, a NaN tensor is returned.

    Used for testing ``nemo.collections.common.metrics.GlobalAverageLossMetric``.

    Args:
        loss_sum_or_avg: a one dimensional float ``torch.Tensor`` of loss sums or means.
        num_measurements: a one dimensional integer ``torch.Tensor`` of counts matching
            ``loss_sum_or_avg``.
        take_avg_loss: whether ``loss_sum_or_avg`` holds means (``True``) or sums (``False``).
    """
    total_count = num_measurements.sum()
    if total_count.eq(0):
        # Zero measurements: average is undefined.
        return torch.tensor(float('nan'))
    # Work on a detached copy so the caller's tensor is never modified.
    sums = loss_sum_or_avg.clone().detach()
    if take_avg_loss:
        sums = sums * num_measurements
    return sums.sum() / total_count
| |
|
| |
|
def _loss_class_test(
    rank: int,
    worldsize: int,
    loss_sum_or_avg: Optional[torch.Tensor],
    num_measurements: Optional[torch.Tensor],
    dist_sync_on_step: bool,
    take_avg_loss: bool,
    check_dist_sync_on_step: bool = True,
    check_batch: bool = True,
    atol: float = 1e-8,
):
    """ Utility function doing the actual comparison between lightning class metric
    and reference metric.
    Args:
        rank: rank of current process
        worldsize: number of processes
        loss_sum_or_avg: a one dimensional float torch tensor with loss sums or means.
        num_measurements: a one dimensional integer torch tensor with number of values on which sums or means from
            ``loss_sum_or_avg`` were computed.
        dist_sync_on_step: bool, if true will synchronize metric state across processes at each call of the
            method :meth:`forward()`
        take_avg_loss: bool, if true ``loss_sum_or_avg`` contains mean losses, otherwise sums of losses.
        check_dist_sync_on_step: bool, if true will check if the metric is also correctly
            calculated per batch per device (and not just at the end)
        check_batch: bool, if true will check if the metric is also correctly
            calculated across devices for each batch (and not just at the end)
        atol: absolute tolerance used when comparing results
    """
    loss_metric = GlobalAverageLossMetric(
        compute_on_step=True, dist_sync_on_step=dist_sync_on_step, take_avg_loss=take_avg_loss
    )

    # The metric must survive a pickle round-trip, as required for ddp spawn workers.
    pickled_metric = pickle.dumps(loss_metric)
    loss_metric = pickle.loads(pickled_metric)
    # Each process consumes every ``worldsize``-th batch, starting at its rank.
    for i in range(rank, NUM_BATCHES, worldsize):
        batch_result = loss_metric(loss_sum_or_avg[i], num_measurements[i])
        if loss_metric.dist_sync_on_step:
            if rank == 0:
                # With dist_sync_on_step the per-step result aggregates the batches
                # seen by *all* ranks, so the reference is computed on stacked data.
                ddp_loss_sum_or_avg = torch.stack([loss_sum_or_avg[i + r] for r in range(worldsize)])
                ddp_num_measurements = torch.stack([num_measurements[i + r] for r in range(worldsize)])
                sk_batch_result = reference_loss_func(ddp_loss_sum_or_avg, ddp_num_measurements, take_avg_loss)

                if check_dist_sync_on_step:
                    # NaN never compares equal, so the zero-measurements case is
                    # verified separately from the numeric comparison.
                    if sk_batch_result.isnan():
                        assert batch_result.isnan()
                    else:
                        assert np.allclose(
                            batch_result.numpy(), sk_batch_result, atol=atol
                        ), f"batch_result = {batch_result.numpy()}, sk_batch_result = {sk_batch_result}, i = {i}"
        else:
            # Without sync the per-step result only covers this rank's batch;
            # slicing keeps the one-dimensional shape expected by the reference.
            ls = loss_sum_or_avg[i : i + 1]
            nm = num_measurements[i : i + 1]
            sk_batch_result = reference_loss_func(ls, nm, take_avg_loss)

            if check_batch:
                if sk_batch_result.isnan():
                    assert batch_result.isnan()
                else:
                    assert np.allclose(
                        batch_result.numpy(), sk_batch_result, atol=atol
                    ), f"batch_result = {batch_result.numpy()}, sk_batch_result = {sk_batch_result}, i = {i}"

    # The final accumulated value must match the reference over all batches.
    result = loss_metric.compute()
    assert isinstance(result, torch.Tensor)
    sk_result = reference_loss_func(loss_sum_or_avg, num_measurements, take_avg_loss)

    if sk_result.isnan():
        assert result.isnan()
    else:
        assert np.allclose(result.numpy(), sk_result, atol=atol), f"result = {result.numpy()}, sk_result = {sk_result}"
| |
|
| |
|
class LossTester(MetricTester):
    """Tester for the global average loss metric; delegates to ``_loss_class_test``."""

    def run_class_loss_test(
        self,
        ddp: bool,
        loss_sum_or_avg: torch.Tensor,
        num_measurements: torch.Tensor,
        dist_sync_on_step: bool,
        take_avg_loss: bool,
        check_dist_sync_on_step: bool = True,
        check_batch: bool = True,
    ):
        """Run ``_loss_class_test`` either across the ddp worker pool or in-process.

        The arguments mirror ``_loss_class_test``; ``ddp`` selects whether the
        comparison runs across ``self.poolSize`` processes or as a single rank-0
        process with world size 1.
        """
        shared_kwargs = dict(
            loss_sum_or_avg=loss_sum_or_avg,
            num_measurements=num_measurements,
            dist_sync_on_step=dist_sync_on_step,
            take_avg_loss=take_avg_loss,
            check_dist_sync_on_step=check_dist_sync_on_step,
            check_batch=check_batch,
            atol=self.atol,
        )
        if not ddp:
            # Single-process run: rank 0 of a world of size 1.
            _loss_class_test(0, 1, **shared_kwargs)
            return
        if sys.platform == "win32":
            pytest.skip("DDP not supported on windows")
        worker_ranks = [(rank, self.poolSize) for rank in range(self.poolSize)]
        self.pool.starmap(partial(_loss_class_test, **shared_kwargs), worker_ranks)
| |
|