|
import random |
|
|
|
import pytest |
|
import torch |
|
|
|
from ding.torch_utils.metric import levenshtein_distance, hamming_distance |
|
|
|
|
|
@pytest.mark.unittest
class TestMetric:
    r'''
    Overview:
        Unit tests for ``levenshtein_distance`` and ``hamming_distance``
        imported from ``ding.torch_utils.metric``.
    '''

    def test_levenshtein_distance(self):
        r'''
        Overview:
            Compare a fixed 5-element sequence against several targets and
            check the scalar returned by ``levenshtein_distance``, first in
            its plain two-argument form and then with extra sequences plus
            an ``extra_fn`` callback.
        '''
        source = torch.LongTensor([1, 4, 6, 4, 1])

        # (target, expected distance) pairs for the plain two-argument call.
        plain_cases = [
            (torch.LongTensor([1, 6, 4, 4, 1]), 2),  # same length, two mismatches
            (torch.LongTensor([]), 5),  # empty target
            (torch.LongTensor([6, 4, 1]), 2),  # shorter target
        ]
        for target, expected in plain_cases:
            result = levenshtein_distance(source, target)
            assert result.item() == expected

        def add_extras(x, y):
            # Callback handed to levenshtein_distance as ``extra_fn``.
            return x + y

        # The sequences are passed a second time as the extra arguments.
        # NOTE(review): the expected totals below are the values pinned by
        # this test; how extra_fn contributes to them is defined inside
        # levenshtein_distance, which is not visible here.
        extra_cases = [
            (torch.LongTensor([6, 4, 1]), 13),
            (torch.LongTensor([1, 4, 1]), 14),
        ]
        for target, expected in extra_cases:
            result = levenshtein_distance(source, target, source, target, extra_fn=add_extras)
            assert result.item() == expected

    def test_hamming_distance(self):
        r'''
        Overview:
            Build two random 8-element 0/1 vectors with four ones each and
            check that ``hamming_distance`` equals the number of positions
            set in exactly one of the two vectors.
        '''
        zeros = torch.zeros(8).long()
        positions = list(range(8))
        for _ in range(2):
            # Draw order (pred first, then target) is kept so the random
            # stream is consumed exactly as before.
            pred_idx = random.sample(positions, 4)
            target_idx = random.sample(positions, 4)

            pred = zeros.clone()
            pred[pred_idx] = 1
            target = zeros.clone()
            target[target_idx] = 1

            # A leading dimension of size 1 is added before the call.
            distance = hamming_distance(pred.unsqueeze(0), target.unsqueeze(0))

            # Positions set in exactly one of the two index sets.
            mismatched = set(pred_idx) ^ set(target_idx)
            assert distance.item() == len(mismatched)
|
|