File size: 2,861 Bytes
09b4d4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import datasets
import evaluate
from typing import List
import torch


# One-paragraph summary of the metric; used as the MetricInfo description and
# passed to add_start_docstrings on the metric class.
_DESCRIPTION = """
Quantifying encoder feature distribution properties, Alignment and Uniformity on the Hypersphere.
(https://github.com/ssnl/align_uniform)
"""

# Usage documentation for `compute`; used as the MetricInfo inputs_description
# and passed to add_start_docstrings on the metric class.
_KWARGS_DESCRIPTION = """
Args:
    xs (`list` of `list` of `float`): a group of embeddings, one embedding per row
    ys (`list` of `list` of `float`): the other group of embeddings, paired row-by-row with the xs

Returns:
    align_loss (`float`): alignment loss between the paired xs and ys
    x_unif_loss (`float`): uniformity loss of the xs
    y_unif_loss (`float`): uniformity loss of the ys
    unif_loss (`float`): mean of x_unif_loss and y_unif_loss

Examples:

    Example 1-A simple example
        >>> metrics = evaluate.load("ahnyeonchan/Alignment-and-Uniformity")
        >>> results = metrics.compute(xs=[[1.0, 1.0], [0.0, 1.0]], ys=[[1.0, 1.0], [0.0, 1.0]])
        >>> print(results)
        {'align_loss': 0.0, 'x_unif_loss': -2.0, 'y_unif_loss': -2.0, 'unif_loss': -2.0}
"""

# No citation text is provided for this metric.
_CITATION = """"""


def align_loss(x, y, alpha=2):
    """Return the mean of the row-wise L2 distances between x and y, raised to alpha.

    Args:
        x: tensor whose rows are paired with the rows of ``y``
            (the norm is taken over dim 1).
        y: tensor paired with ``x``.
        alpha: exponent applied to each pairwise distance before averaging.

    Returns:
        A scalar tensor holding the averaged powered distances.
    """
    # Euclidean distance between each pair of corresponding rows.
    pair_distances = (x - y).norm(p=2, dim=1)
    powered = pair_distances.pow(alpha)
    return powered.mean()


def uniform_loss(x, t=2):
    """Return the log of the mean Gaussian potential over all row pairs of x.

    Args:
        x: 2-D tensor; ``torch.pdist`` is taken over its rows.
        t: scale applied (negated) to the squared pairwise distances.

    Returns:
        A scalar tensor: ``log(mean(exp(-t * d^2)))`` over every pair of rows.
    """
    # Squared Euclidean distance for every unordered pair of rows.
    squared_distances = torch.pdist(x, p=2).pow(2)
    potentials = squared_distances.mul(-t).exp()
    return potentials.mean().log()


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class AlignUniform(evaluate.Metric):
    """Metric computing the alignment and uniformity losses of paired embeddings.

    Args:
        align_alpha: exponent forwarded to ``align_loss``.
        unif_t: scale forwarded to ``uniform_loss``.
        *args, **kwargs: forwarded unchanged to ``evaluate.Metric``.
    """

    def __init__(self, align_alpha: float = 2.0, unif_t: float = 2.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.align_alpha = align_alpha
        self.unif_t = unif_t

    def _info(self):
        """Describe the metric: texts, and one float32 sequence per example for xs and ys."""
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "xs": datasets.Sequence(datasets.Value("float32")),
                    "ys": datasets.Sequence(datasets.Value("float32")),
                }
            ),
            reference_urls=[],
        )

    @staticmethod
    def _to_tensor(values, name: str) -> torch.Tensor:
        """Convert one input (a list or a tensor) to a float32 tensor.

        Raises:
            NotImplementedError: if ``values`` is neither a list nor a tensor.
        """
        if isinstance(values, (torch.Tensor, list)):
            return torch.as_tensor(values, dtype=torch.float32)
        raise NotImplementedError(
            f"Unsupported type for '{name}': {type(values).__name__}; "
            "expected a list or a torch.Tensor."
        )

    def _compute(self, xs: List[List], ys: List[List]) -> dict:
        """Compute alignment and uniformity losses for the paired embeddings.

        Args:
            xs: embeddings as a list of lists (or a tensor), one embedding per row.
            ys: embeddings paired row-by-row with ``xs``, in the same format.

        Returns:
            Dict with float values under the keys ``align_loss``,
            ``x_unif_loss``, ``y_unif_loss`` and ``unif_loss`` (the mean of the
            two uniformity losses).

        Raises:
            NotImplementedError: if ``xs`` or ``ys`` is neither a list nor a tensor.
        """
        # Each input is validated against its own type; previously the check
        # for xs inspected ys, so mixed list/tensor inputs were mishandled.
        xs = self._to_tensor(xs, "xs")
        ys = self._to_tensor(ys, "ys")

        align_loss_val = align_loss(xs, ys, self.align_alpha)
        x_unif_loss_v = uniform_loss(xs, t=self.unif_t)
        y_unif_loss_v = uniform_loss(ys, t=self.unif_t)
        unif_loss = (x_unif_loss_v + y_unif_loss_v) / 2

        return {
            "align_loss": float(align_loss_val),
            "x_unif_loss": float(x_unif_loss_v),
            "y_unif_loss": float(y_unif_loss_v),
            "unif_loss": float(unif_loss),
        }