from typing import Any, Dict, Optional
import torch
from torch import Tensor
from torchmetrics import Metric
class NumTokens(Metric):
"""Keep track of how many tokens we've seen.
"""
# TODO: how do we prevent the reset between the epochs? The reset happens on the 1st batch
# of the next epoch.
# Right now the hack is that we override reset(), which would mess up the forward method.
# We then override forward to do the right thing.
is_differentiable = False
higher_is_better = False
full_state_update = False
count: Tensor
def __init__(self, **kwargs: Dict[str, Any]):
super().__init__(**kwargs)
self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum",
persistent=True) # We want the count to be saved to state-dict
def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None: # type: ignore
self.count += target.numel()
def compute(self) -> Tensor:
return self.count
def reset(self):
count = self.count
super().reset()
self.count = count
# Adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/metric.py
def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:
"""forward computation using single call to `update` to calculate the metric value on the current batch and
accumulate global state.
This can be done when the global metric state is a sinple reduction of batch states.
"""
self.update(*args, **kwargs)
return self.compute()
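

# Minimal usage sketch: the batch shapes below are made up for illustration. It shows that
# calling the metric (forward) accumulates the token count across batches via `update`, and
# that the overridden `reset` keeps the count instead of zeroing it, e.g. when a training
# framework resets metrics at an epoch boundary.
if __name__ == "__main__":
    metric = NumTokens()

    # Two fake batches of (preds, target); only target.numel() matters for the count.
    preds1, target1 = torch.zeros(2, 512), torch.zeros(2, 512)
    preds2, target2 = torch.zeros(4, 512), torch.zeros(4, 512)

    metric(preds1, target1)  # forward -> single update + compute on the running state
    metric(preds2, target2)
    assert metric.compute().item() == 2 * 512 + 4 * 512

    metric.reset()  # would normally zero the state ...
    assert metric.compute().item() == 2 * 512 + 4 * 512  # ... but the count survives

    # Because the state was added with persistent=True, the count is also saved in the state dict.
    print(metric.state_dict())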