linguask / tests /test_metric.py
GitHub Action
refs/heads/ci-cd/hugging-face
8b414b0
raw
history blame contribute delete
No virus
452 Bytes
import numpy as np
import torch
from sklearn.metrics import mean_squared_error
from src.model_finetuning.losses import MCRMSELoss
def test_sklearn_metric_matches_torch():
    """Check that MCRMSELoss matches a column-wise RMSE computed with scikit-learn.

    The reference value is built one column at a time: the RMSE of each
    column of ``a`` against the same column of ``b`` (via
    ``sklearn.metrics.mean_squared_error``), averaged over all columns.
    That average is compared with what ``MCRMSELoss().forward(a, b)`` returns.
    """
    # A local, fixed-seed generator makes any tolerance failure reproducible
    # without altering torch's global RNG state used by other tests.
    gen = torch.Generator().manual_seed(0)
    a = torch.randn(10, 6, generator=gen)
    b = torch.randn(10, 6, generator=gen)

    # Hand scikit-learn plain NumPy arrays instead of relying on implicit
    # tensor conversion inside its input validation.
    a_np, b_np = a.numpy(), b.numpy()

    # `mean_squared_error(..., squared=False)` was deprecated in scikit-learn
    # 1.4 and removed in 1.6; the square root of the MSE is the same RMSE and
    # works on every version.
    sklearn_loss = [
        np.sqrt(mean_squared_error(a_np[:, col], b_np[:, col]))
        for col in range(a_np.shape[1])
    ]

    # float() accepts a 0-d tensor, a NumPy scalar, or a Python float, so the
    # comparison does not depend on the exact return type of forward().
    torch_loss = float(MCRMSELoss().forward(a, b))
    assert np.isclose(np.mean(sklearn_loss), torch_loss)