File size: 4,663 Bytes
8c6b5ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import torch
import torch.nn as nn
from torch.nn import functional as F


class OptimalTransport(nn.Module):
    """Base module for optimal-transport losses.

    Provides the pairwise cost matrix shared by the subclasses.
    """

    @staticmethod
    def distance(batch1, batch2, dist_metric="cosine"):
        """Compute the pairwise cost matrix between two batches.

        Args:
            batch1: tensor of shape (m, dim).
            batch2: tensor of shape (n, dim).
            dist_metric: one of
                - "cosine": 1 - cosine similarity of the L2-normalised rows,
                - "euclidean": squared Euclidean distance computed via
                  ||a||^2 + ||b||^2 - 2 * a.b,
                - "fast_euclidean": squared Euclidean distance computed by
                  broadcasting the element-wise differences.

        Returns:
            Tensor of shape (m, n) whose entry (i, j) is the cost between
            ``batch1[i]`` and ``batch2[j]``.

        Raises:
            ValueError: if ``dist_metric`` is not a supported name.
        """
        if dist_metric == "cosine":
            batch1 = F.normalize(batch1, p=2, dim=1)
            batch2 = F.normalize(batch2, p=2, dim=1)
            dist_mat = 1 - torch.mm(batch1, batch2.t())
        elif dist_metric == "euclidean":
            m, n = batch1.size(0), batch2.size(0)
            # ||a||^2 + ||b||^2, broadcast to an (m, n) matrix.
            dist_mat = (
                torch.pow(batch1, 2).sum(dim=1, keepdim=True).expand(m, n) +
                torch.pow(batch2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
            )
            # Subtract 2 * a.b in place -> squared euclidean distance.
            # Keyword beta/alpha are used because the positional
            # (beta, alpha, mat1, mat2) overload is no longer accepted by
            # current PyTorch releases.
            dist_mat.addmm_(batch1, batch2.t(), beta=1, alpha=-2)
        elif dist_metric == "fast_euclidean":
            # (m, 1, dim) - (1, n, dim) -> (m, n, dim), reduced over dim.
            batch1 = batch1.unsqueeze(-2)
            batch2 = batch2.unsqueeze(-3)
            dist_mat = torch.sum((torch.abs(batch1 - batch2))**2, -1)
        else:
            raise ValueError(
                "Unknown cost function: {}. Expected to "
                "be one of [cosine | euclidean | fast_euclidean]".format(
                    dist_metric
                )
            )
        return dist_mat


class SinkhornDivergence(OptimalTransport):
    """Entropy-regularised optimal-transport loss between two batches.

    ``forward`` combines three transport costs as
    ``2 * W(x, y) - W(x, x) - W(y, y)``, where each ``W`` is the cost of the
    transport plan produced by log-domain Sinkhorn iterations.
    """

    # Early-stopping threshold on the summed absolute change of the dual
    # variable ``u`` between two consecutive Sinkhorn iterations.
    thre = 1e-3

    def __init__(
        self,
        dist_metric="cosine",
        eps=0.01,
        max_iter=5,
        bp_to_sinkhorn=False
    ):
        """Store the loss configuration.

        Args:
            dist_metric: name of the cost function passed to ``distance``.
            eps: entropic regularisation strength used by the iterations.
            max_iter: maximum number of Sinkhorn iterations.
            bp_to_sinkhorn: if False, the transport plan is detached so
                gradients flow only through the cost matrix.
        """
        super().__init__()
        self.dist_metric = dist_metric
        self.eps = eps
        self.max_iter = max_iter
        self.bp_to_sinkhorn = bp_to_sinkhorn

    def forward(self, x, y):
        """Return ``2 * W(x, y) - W(x, x) - W(y, y)`` as a scalar tensor.

        Args:
            x: batch of data with shape (batch, dim).
            y: batch of data with shape (batch, dim).
        """
        W_xy = self.transport_cost(x, y)
        W_xx = self.transport_cost(x, x)
        W_yy = self.transport_cost(y, y)
        return 2*W_xy - W_xx - W_yy

    def transport_cost(self, x, y, return_pi=False):
        """Compute the transport cost ``sum(pi * C)`` between ``x`` and ``y``.

        Args:
            x: first batch, shape (nx, dim).
            y: second batch, shape (ny, dim).
            return_pi: if True, also return the transport plan.

        Returns:
            The scalar cost tensor, or the tuple ``(cost, pi)`` when
            ``return_pi`` is True; ``pi`` has shape (nx, ny).
        """
        C = self.distance(x, y, dist_metric=self.dist_metric)
        pi = self.sinkhorn_iterate(C, self.eps, self.max_iter, self.thre)
        if not self.bp_to_sinkhorn:
            # Treat the plan as a constant: no backprop through Sinkhorn.
            pi = pi.detach()
        cost = torch.sum(pi * C)
        if return_pi:
            return cost, pi
        return cost

    @staticmethod
    def sinkhorn_iterate(C, eps, max_iter, thre):
        """Run log-domain Sinkhorn iterations on a cost matrix.

        Both marginals are uniform: each of the ``nx`` rows gets weight
        ``1 / nx`` and each of the ``ny`` columns gets weight ``1 / ny``.

        Args:
            C: cost matrix of shape (nx, ny).
            eps: entropic regularisation strength.
            max_iter: maximum number of iterations.
            thre: stop early once the summed absolute change of ``u``
                between two iterations falls below this value.

        Returns:
            Transport plan of shape (nx, ny), ``exp((-C + u_i + v_j) / eps)``.
        """
        nx, ny = C.shape
        mu = torch.ones(nx, dtype=C.dtype, device=C.device) * (1.0/nx)
        nu = torch.ones(ny, dtype=C.dtype, device=C.device) * (1.0/ny)
        u = torch.zeros_like(mu)
        v = torch.zeros_like(nu)

        # The log-marginals never change, so compute them once up front.
        # The small constant guards the logarithm against a zero argument.
        log_mu = torch.log(mu + 1e-8)
        log_nu = torch.log(nu + 1e-8)

        def M(_C, _u, _v):
            """Modified cost for logarithmic updates.
            Eq: M_{ij} = (-c_{ij} + u_i + v_j) / epsilon
            """
            return (-_C + _u.unsqueeze(-1) + _v.unsqueeze(-2)) / eps

        # Sinkhorn iterations: alternately update the row dual u and the
        # column dual v (v uses the freshly updated u).
        for _ in range(max_iter):
            u0 = u
            u = eps * (log_mu - torch.logsumexp(M(C, u, v), dim=1)) + u
            v = (
                eps * (
                    log_nu -
                    torch.logsumexp(M(C, u, v).permute(1, 0), dim=1)
                ) + v
            )
            err = (u - u0).abs().sum()
            if err.item() < thre:
                break
        # Transport plan pi = diag(a)*K*diag(b)
        return torch.exp(M(C, u, v))


class MinibatchEnergyDistance(SinkhornDivergence):
    """Energy-distance style loss built from half-batch transport costs.

    Each input batch is split into two halves and the loss is the sum of
    the four cross-domain transport costs minus twice the two within-domain
    transport costs.
    """

    def __init__(
        self,
        dist_metric="cosine",
        eps=0.01,
        max_iter=5,
        bp_to_sinkhorn=False
    ):
        """Forward the configuration unchanged to ``SinkhornDivergence``."""
        super().__init__(
            dist_metric=dist_metric,
            eps=eps,
            max_iter=max_iter,
            bp_to_sinkhorn=bp_to_sinkhorn,
        )

    @staticmethod
    def _split_halves(batch):
        """Split ``batch`` along dim 0 into a first and a second half.

        For an odd number of rows the second half holds the extra row.
        """
        half = batch.size(0) // 2
        return batch[:half], batch[half:]

    def forward(self, x, y):
        """Compute the minibatch energy distance between ``x`` and ``y``.

        Args:
            x: batch of data with shape (batch_x, dim), batch_x >= 2.
            y: batch of data with shape (batch_y, dim), batch_y >= 2.

        Returns:
            Scalar tensor
            ``W(x1,y1) + W(x1,y2) + W(x2,y1) + W(x2,y2)
            - 2*W(x1,x2) - 2*W(y1,y2)``.

        Raises:
            ValueError: if either batch has fewer than two samples, since
                it cannot be split into two non-empty halves.
        """
        if x.size(0) < 2 or y.size(0) < 2:
            raise ValueError(
                "MinibatchEnergyDistance needs at least 2 samples per "
                "batch, got {} and {}".format(x.size(0), y.size(0))
            )
        # Explicit slicing always yields exactly two chunks; splitting by a
        # chunk size of batch // 2 produces three chunks for odd batches.
        x1, x2 = self._split_halves(x)
        y1, y2 = self._split_halves(y)
        cost = 0
        cost += self.transport_cost(x1, y1)
        cost += self.transport_cost(x1, y2)
        cost += self.transport_cost(x2, y1)
        cost += self.transport_cost(x2, y2)
        cost -= 2 * self.transport_cost(x1, x2)
        cost -= 2 * self.transport_cost(y1, y2)
        return cost


if __name__ == "__main__":
    # Small demo, following https://dfdazac.github.io/sinkhorn.html:
    # two rows of n_points points, one at height 0 and one at height 1.
    n_points = 5
    x = torch.tensor([[i, 0] for i in range(n_points)], dtype=torch.float)
    y = torch.tensor([[i, 1] for i in range(n_points)], dtype=torch.float)
    sinkhorn = SinkhornDivergence(
        dist_metric="euclidean", eps=0.01, max_iter=5
    )
    dist, pi = sinkhorn.transport_cost(x, y, True)
    # Print the results instead of dropping into a debugger so the demo
    # also runs to completion in non-interactive environments.
    print("transport cost: {}".format(dist.item()))
    print("transport plan:\n{}".format(pi))