File size: 5,332 Bytes
f7009b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
from .diffusion import DDIMSampler, DDPMSampler, GaussianDiffusionTrainer
from .denoiser import OneDimCNN
from torch.nn import functional as F
from abc import abstractmethod
from torch import nn
import torch
class PDiff(nn.Module):
    """Parameter-diffusion wrapper: trains a 1-D CNN denoiser with a Gaussian
    diffusion objective and samples flat parameter vectors from it.

    NOTE(review): the class-level ``config`` dict is presumably populated
    externally (e.g. by a subclass or setup code) before instantiation — it
    supplies the denoiser hyper-parameters, the beta schedule, ``T``, and the
    sampler class; confirm against the caller.
    """
    config = {}

    def __init__(self, sequence_length):
        super().__init__()
        cfg = self.config
        self.sequence_length = sequence_length
        # Shared denoiser backbone used by both the trainer and the sampler.
        self.net = OneDimCNN(
            layer_channels=cfg["layer_channels"],
            model_dim=cfg["model_dim"],
            kernel_size=cfg["kernel_size"],
        )
        self.diffusion_trainer = GaussianDiffusionTrainer(
            model=self.net,
            beta=cfg["beta"],
            T=cfg["T"],
        )
        # cfg["sample_mode"] is a sampler class (e.g. DDPMSampler / DDIMSampler).
        self.diffusion_sampler = cfg["sample_mode"](
            model=self.net,
            beta=cfg["beta"],
            T=cfg["T"],
        )

    def forward(self, x=None, c=0., **kwargs):
        """Return the diffusion training loss, or delegate to sample()
        when a truthy ``sample`` keyword is passed."""
        if kwargs.get("sample"):
            kwargs.pop("sample")
            return self.sample(x, c, **kwargs)
        # Collapse any leading batch dims: [..., D] -> [B, D].
        flat = x.view(-1, x.size(-1))
        return self.diffusion_trainer(flat, c, **kwargs)

    @torch.no_grad()
    def sample(self, x=None, c=0., **kwargs):
        """Draw a sample starting from ``x`` (or fresh Gaussian noise of shape
        [1, model_dim] when ``x`` is None); output matches the input shape."""
        if x is None:
            x = torch.randn((1, self.config["model_dim"]), device=self.device)
        original_shape = x.shape
        flat = x.view(-1, x.size(-1))
        return self.diffusion_sampler(flat, c, **kwargs).view(original_shape)

    @property
    def device(self):
        """Device of the module's parameters."""
        return next(iter(self.parameters())).device
class OneDimVAE(nn.Module):
    """1-D convolutional VAE over flat vectors of length ``sequence_length``.

    The encoder is a stack of stride-2 Conv1d blocks (each halves the length);
    the decoder mirrors it with ConvTranspose1d blocks.
    """

    def __init__(self, d_model, d_latent, sequence_length, kernel_size=7, divide_slice_length=64):
        """
        Args:
            d_model: per-stage channel counts for the encoder, shallow-to-deep
                (the decoder mirrors them). The sequence is NOT mutated.
            d_latent: dimensionality of the latent space.
            sequence_length: nominal input length; rounded up to the next
                multiple of ``divide_slice_length``.
            kernel_size: conv kernel size (odd, so ``kernel_size // 2`` padding
                preserves length at stride 1).
            divide_slice_length: granularity the padded length must divide into.
        """
        super(OneDimVAE, self).__init__()
        # Keep our own copy in encoder order; decode() reads d_model[-1] as the
        # bottleneck channel count. list() also accepts tuples, not just lists.
        self.d_model = list(d_model)
        self.d_latent = d_latent
        # Round sequence_length up to a multiple of divide_slice_length.
        sequence_length = (sequence_length // divide_slice_length + 1) * divide_slice_length \
            if sequence_length % divide_slice_length != 0 else sequence_length
        # Each encoder stage halves the length, so 2**len(d_model) must divide it.
        assert sequence_length % int(2 ** len(self.d_model)) == 0, \
            f"Please set divide_slice_length to {int(2 ** len(self.d_model))}."
        self.last_length = sequence_length // int(2 ** len(self.d_model))

        # Build Encoder: stride-2 Conv1d -> BatchNorm -> LeakyReLU per stage.
        modules = []
        in_dim = 1
        for h_dim in self.d_model:
            modules.append(nn.Sequential(
                nn.Conv1d(in_dim, h_dim, kernel_size, 2, kernel_size // 2),
                nn.BatchNorm1d(h_dim),
                nn.LeakyReLU(),
            ))
            in_dim = h_dim
        self.encoder = nn.Sequential(*modules)
        self.to_latent = nn.Linear(self.last_length * self.d_model[-1], d_latent)
        self.fc_mu = nn.Linear(d_latent, d_latent)
        self.fc_var = nn.Linear(d_latent, d_latent)

        # Build Decoder (mirror of the encoder).
        # BUGFIX: the original called d_model.reverse(), mutating the CALLER'S
        # list in place; iterate a reversed copy instead. Layer shapes are
        # identical to the original construction.
        rev = list(reversed(self.d_model))
        self.to_decode = nn.Linear(d_latent, self.last_length * rev[0])
        modules = []
        for i in range(len(rev) - 1):
            modules.append(nn.Sequential(
                nn.ConvTranspose1d(rev[i], rev[i + 1], kernel_size, 2, kernel_size // 2, output_padding=1),
                nn.BatchNorm1d(rev[i + 1]),
                nn.ELU(),
            ))
        self.decoder = nn.Sequential(*modules)
        # Final upsample back to full length, then project to a single channel.
        self.final_layer = nn.Sequential(
            nn.ConvTranspose1d(rev[-1], rev[-1], kernel_size, 2, kernel_size // 2, output_padding=1),
            nn.BatchNorm1d(rev[-1]),
            nn.ELU(),
            nn.Conv1d(rev[-1], 1, kernel_size, 1, kernel_size // 2),
        )

    def encode(self, input, **kwargs):
        """Map flat vectors [B, L] to (mu, log_var), each of shape [B, d_latent]."""
        # Add a channel axis: [B, L] -> [B, 1, L].
        input = input[:, None, :]
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)
        result = self.to_latent(result)
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return mu, log_var

    def decode(self, z, **kwargs):
        """Map latents [B, d_latent] back to flat vectors [B, L]."""
        result = self.to_decode(z)
        # self.d_model[-1] (encoder order) is the bottleneck channel count.
        result = result.view(-1, self.d_model[-1], self.last_length)
        result = self.decoder(result)
        result = self.final_layer(result)
        assert result.shape[1] == 1, f"{result.shape}"
        return result[:, 0, :]

    def reparameterize(self, mu, log_var, **kwargs):
        """Sample z ~ N(mu, sigma^2) when ``use_var`` is truthy; otherwise
        return ``mu`` deterministically. ``manual_std`` (if given) overrides
        the learned std; eps is still drawn per-element from log_var's shape."""
        if kwargs.get("use_var"):
            std = torch.exp(0.5 * log_var)
            eps = torch.randn_like(std)
            if kwargs.get("manual_std") is not None:
                std = kwargs.get("manual_std")
            return eps * std + mu
        else:  # not use var
            return mu

    def encode_decode(self, input, **kwargs):
        """Full pass: encode -> reparameterize -> decode.

        Returns (recons, input, mu, log_var)."""
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var, **kwargs)
        recons = self.decode(z)
        return recons, input, mu, log_var

    def forward(self, x, **kwargs):
        """Return the training loss: MSE reconstruction, plus KLD weighted by
        kwargs['kld_weight'] (required) when ``use_var`` is truthy."""
        recons, input, mu, log_var = self.encode_decode(input=x, **kwargs)
        recons_loss = F.mse_loss(recons, input)
        if kwargs.get("use_var"):
            kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
            loss = recons_loss + kwargs['kld_weight'] * kld_loss
        else:  # not use var
            loss = recons_loss
        return loss
|