import torch
from torch import nn
from torch.nn import functional as F

from .denoiser import OneDimCNN
from .diffusion import DDIMSampler, DDPMSampler, GaussianDiffusionTrainer




class PDiff(nn.Module):
    # Hyper-parameters are read from this class-level dict; it is expected to be
    # populated (or overridden in a subclass) before the model is instantiated.
    config = {}

    def __init__(self, sequence_length):
        super().__init__()
        self.sequence_length = sequence_length
        # 1-D CNN denoiser operating on flattened parameter vectors.
        self.net = OneDimCNN(
            layer_channels=self.config["layer_channels"],
            model_dim=self.config["model_dim"],
            kernel_size=self.config["kernel_size"],
        )
        # The trainer computes the diffusion loss; the sampler class
        # (DDPMSampler or DDIMSampler) is selected via config["sample_mode"].
        self.diffusion_trainer = GaussianDiffusionTrainer(
            model=self.net,
            beta=self.config["beta"],
            T=self.config["T"],
        )
        self.diffusion_sampler = self.config["sample_mode"](
            model=self.net,
            beta=self.config["beta"],
            T=self.config["T"],
        )

    def forward(self, x=None, c=0., **kwargs):
        # Dispatch: forward(..., sample=True) generates parameters; otherwise
        # the diffusion training loss is returned.
        if kwargs.pop("sample", False):
            return self.sample(x, c, **kwargs)
        x = x.view(-1, x.size(-1))
        loss = self.diffusion_trainer(x, c, **kwargs)
        return loss

    @torch.no_grad()
    def sample(self, x=None, c=0., **kwargs):
        # Start from pure Gaussian noise when no initial x is given.
        if x is None:
            x = torch.randn((1, self.config["model_dim"]), device=self.device)
        x_shape = x.shape
        x = x.view(-1, x.size(-1))
        return self.diffusion_sampler(x, c, **kwargs).view(x_shape)

    @property
    def device(self):
        return next(self.parameters()).device
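

# Usage sketch (comments only): PDiff reads its hyper-parameters from the
# class-level "config" dict, which callers populate before instantiation.
# The keys mirror those accessed in __init__; the concrete values below are
# purely illustrative, and the exact format expected for "beta" depends on
# GaussianDiffusionTrainer, which lives in .diffusion.
#
#     PDiff.config = {
#         "layer_channels": [1, 64, 128, 64, 1],  # hypothetical OneDimCNN channel layout
#         "model_dim": 2048,                      # length of the flattened parameter vector
#         "kernel_size": 7,
#         "beta": (1e-4, 2e-2),                   # assumed schedule endpoints; check GaussianDiffusionTrainer
#         "T": 1000,
#         "sample_mode": DDIMSampler,             # or DDPMSampler
#     }
#     model = PDiff(sequence_length=2048)
#     loss = model(flat_params_batch)             # training: returns the diffusion loss
#     generated = model(sample=True)              # sampling: starts from Gaussian noise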




class OneDimVAE(nn.Module):
    def __init__(self, d_model, d_latent, sequence_length, kernel_size=7, divide_slice_length=64):
        super().__init__()
        self.d_model = d_model.copy()
        self.d_latent = d_latent
        # Pad sequence_length up to a multiple of divide_slice_length so that the
        # encoder's repeated stride-2 downsampling (by 2 ** len(d_model)) is exact.
        sequence_length = (sequence_length // divide_slice_length + 1) * divide_slice_length \
                if sequence_length % divide_slice_length != 0 else sequence_length
        assert sequence_length % int(2 ** len(d_model)) == 0, \
                f"Please set divide_slice_length to a multiple of {int(2 ** len(d_model))}."
        self.last_length = sequence_length // int(2 ** len(d_model))

        # Build Encoder
        modules = []
        in_dim = 1
        for h_dim in d_model:
            modules.append(nn.Sequential(
                nn.Conv1d(in_dim, h_dim, kernel_size, 2, kernel_size//2),
                nn.BatchNorm1d(h_dim),
                nn.LeakyReLU()
            ))
            in_dim = h_dim
        self.encoder = nn.Sequential(*modules)
        self.to_latent = nn.Linear(self.last_length * d_model[-1], d_latent)
        self.fc_mu = nn.Linear(d_latent, d_latent)
        self.fc_var = nn.Linear(d_latent, d_latent)

        # Build Decoder: mirror the encoder channels in reverse order. Work on a
        # reversed copy so the caller's list and self.d_model keep the encoder order.
        modules = []
        self.to_decode = nn.Linear(d_latent, self.last_length * d_model[-1])
        d_model = list(reversed(d_model))
        for i in range(len(d_model) - 1):
            modules.append(nn.Sequential(
                nn.ConvTranspose1d(d_model[i], d_model[i+1], kernel_size, 2, kernel_size//2, output_padding=1),
                nn.BatchNorm1d(d_model[i + 1]),
                nn.ELU(),
            ))
        self.decoder = nn.Sequential(*modules)
        self.final_layer = nn.Sequential(
            nn.ConvTranspose1d(d_model[-1], d_model[-1], kernel_size, 2, kernel_size//2, output_padding=1),
            nn.BatchNorm1d(d_model[-1]),
            nn.ELU(),
            nn.Conv1d(d_model[-1], 1, kernel_size, 1, kernel_size//2),
        )

    def encode(self, input, **kwargs):
        # input.shape == [batch_size, num_parameters]; add a channel axis for Conv1d.
        input = input[:, None, :]
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)
        result = self.to_latent(result)
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return mu, log_var

    def decode(self, z, **kwargs):
        # z.shape == [batch_size, d_latent]
        result = self.to_decode(z)
        # Reshape to [batch_size, channels, length] for the ConvTranspose1d stack.
        result = result.view(-1, self.d_model[-1], self.last_length)
        result = self.decoder(result)
        result = self.final_layer(result)
        assert result.shape[1] == 1, f"{result.shape}"
        return result[:, 0, :]

    def reparameterize(self, mu, log_var, **kwargs):
        if kwargs.get("use_var"):
            std = torch.exp(0.5 * log_var)
            eps = torch.randn_like(std)
            # Optionally override the learned std with a caller-supplied value.
            if kwargs.get("manual_std") is not None:
                std = kwargs.get("manual_std")
            return eps * std + mu
        else:  # deterministic: return the mean directly
            return mu

    def encode_decode(self, input, **kwargs):
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var, **kwargs)
        recons = self.decode(z)
        return recons, input, mu, log_var

    def forward(self, x, **kwargs):
        recons, input, mu, log_var = self.encode_decode(input=x, **kwargs)
        recons_loss = F.mse_loss(recons, input)
        if kwargs.get("use_var"):
            # Standard VAE KL divergence between N(mu, var) and N(0, I).
            kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
            loss = recons_loss + kwargs["kld_weight"] * kld_loss
        else:  # reconstruction-only (plain autoencoder) mode
            loss = recons_loss
        return loss
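

# Minimal smoke test for OneDimVAE (a sketch with illustrative hyper-parameters;
# d_model, d_latent, and sequence_length are not prescribed anywhere in this file).
# Because of the relative imports above, it only runs in a package context,
# e.g. ``python -m <package>.<this_module>``.
if __name__ == "__main__":
    vae = OneDimVAE(d_model=[64, 128, 256], d_latent=128, sequence_length=2048)
    flat_params = torch.randn(4, 2048)  # batch of 4 flattened parameter vectors
    recons, original, mu, log_var = vae.encode_decode(flat_params, use_var=True)
    print(recons.shape, mu.shape, log_var.shape)  # [4, 2048], [4, 128], [4, 128]
    loss = vae(flat_params, use_var=True, kld_weight=1e-4)
    print(float(loss))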