from torch import Tensor, nn
from os.path import join as pjoin
from .mr import MRMetrics
from .t2m import TM2TMetrics
from .mm import MMMetrics
from .m2t import M2TMetrics
from .m2m import PredMetrics


class BaseMetrics(nn.Module):
    """Holder module that builds the evaluation metric objects.

    Each metric is instantiated in ``__init__`` and stored as an attribute
    named after its class (``self.MRMetrics``, ``self.PredMetrics``, ...).
    """

    def __init__(self, cfg, datamodule, debug, **kwargs) -> None:
        """Instantiate the metric objects from the config and datamodule.

        Args:
            cfg: Configuration object. This constructor reads
                ``cfg.METRIC.DIST_SYNC_ON_STEP``, ``cfg.DATASET.JOINT_TYPE``
                and ``cfg.model.params.task``; for the "humanml3d" / "kit"
                datasets it also reads ``cfg.METRIC.MM_NUM_TIMES`` and,
                unless ``debug`` is truthy, ``cfg.METRIC.DIVERSITY_TIMES``.
            datamodule: Data module providing ``njoints`` and ``name``; for
                the "humanml3d" / "kit" datasets ``hparams.w_vectorizer`` is
                read as well.
            debug: When truthy, the diversity repeat count is fixed to 30
                instead of being taken from ``cfg.METRIC.DIVERSITY_TIMES``.
            **kwargs: Accepted and ignored by this constructor.
        """
        super().__init__()

        njoints = datamodule.njoints
        data_name = datamodule.name

        metric_cfg = cfg.METRIC
        sync_on_step = metric_cfg.DIST_SYNC_ON_STEP

        # NOTE(review): the original formatting did not preserve indentation;
        # this branch is assumed to cover exactly the three metrics below,
        # with MRMetrics and PredMetrics built unconditionally -- confirm.
        if data_name in ("humanml3d", "kit"):
            # DIVERSITY_TIMES is only looked up when not debugging.
            if debug:
                diversity_times = 30
            else:
                diversity_times = metric_cfg.DIVERSITY_TIMES

            self.TM2TMetrics = TM2TMetrics(
                cfg=cfg,
                dataname=data_name,
                diversity_times=diversity_times,
                dist_sync_on_step=sync_on_step,
            )
            self.M2TMetrics = M2TMetrics(
                cfg=cfg,
                w_vectorizer=datamodule.hparams.w_vectorizer,
                diversity_times=diversity_times,
                dist_sync_on_step=sync_on_step,
            )
            self.MMMetrics = MMMetrics(
                cfg=cfg,
                mm_num_times=metric_cfg.MM_NUM_TIMES,
                dist_sync_on_step=sync_on_step,
            )

        joints_type = cfg.DATASET.JOINT_TYPE

        self.MRMetrics = MRMetrics(
            njoints=njoints,
            jointstype=joints_type,
            dist_sync_on_step=sync_on_step,
        )
        self.PredMetrics = PredMetrics(
            cfg=cfg,
            njoints=njoints,
            jointstype=joints_type,
            dist_sync_on_step=sync_on_step,
            task=cfg.model.params.task,
        )