Upload 3 files
Browse files- app.py +75 -0
- diffusion.py +139 -0
- modules.py +243 -0
app.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from pathlib import Path
|
3 |
+
import os
|
4 |
+
from PIL import Image
|
5 |
+
import torch
|
6 |
+
import torchvision.transforms as transforms
|
7 |
+
import requests
|
8 |
+
|
9 |
+
# Function to download the model from Google Drive
|
10 |
+
def download_file_from_google_drive(id, destination):
|
11 |
+
URL = "https://drive.google.com/uc?export=download"
|
12 |
+
session = requests.Session()
|
13 |
+
response = session.get(URL, params={'id': id}, stream=True)
|
14 |
+
token = get_confirm_token(response)
|
15 |
+
|
16 |
+
if token:
|
17 |
+
params = {'id': id, 'confirm': token}
|
18 |
+
response = session.get(URL, params=params, stream=True)
|
19 |
+
|
20 |
+
save_response_content(response, destination)
|
21 |
+
|
22 |
+
def get_confirm_token(response):
|
23 |
+
for key, value in response.cookies.items():
|
24 |
+
if key.startswith('download_warning'):
|
25 |
+
return value
|
26 |
+
return None
|
27 |
+
|
28 |
+
def save_response_content(response, destination):
|
29 |
+
CHUNK_SIZE = 32768
|
30 |
+
with open(destination, "wb") as f:
|
31 |
+
for chunk in response.iter_content(CHUNK_SIZE):
|
32 |
+
if chunk: # filter out keep-alive new chunks
|
33 |
+
f.write(chunk)
|
34 |
+
|
35 |
+
# Replace 'YOUR_FILE_ID' with your actual file ID from Google Drive
|
36 |
+
file_id = '1WJ33nys02XpPDsMO5uIZFiLqTuAT_iuV'
|
37 |
+
destination = 'ema_ckpt_cond.pt'
|
38 |
+
download_file_from_google_drive(file_id, destination)
|
39 |
+
|
40 |
+
# Preprocessing
|
41 |
+
from modules import PaletteModelV2
|
42 |
+
from diffusion import Diffusion_cond
|
43 |
+
|
44 |
+
device = 'cuda'
|
45 |
+
|
46 |
+
model = PaletteModelV2(c_in=2, c_out=1, num_classes=5, image_size=256, true_img_size=64).to(device)
|
47 |
+
ckpt = torch.load(destination, map_location=device)
|
48 |
+
model.load_state_dict(ckpt)
|
49 |
+
|
50 |
+
diffusion = Diffusion_cond(noise_steps=1000, img_size=256, device=device)
|
51 |
+
model.eval()
|
52 |
+
|
53 |
+
transform_hmi = transforms.Compose([
|
54 |
+
transforms.ToTensor(),
|
55 |
+
transforms.Resize((256, 256)),
|
56 |
+
transforms.RandomVerticalFlip(p=1.0),
|
57 |
+
transforms.Normalize(mean=(0.5,), std=(0.5,))
|
58 |
+
])
|
59 |
+
|
60 |
+
def generate_image(seed_image):
|
61 |
+
seed_image_tensor = transform_hmi(Image.open(seed_image)).reshape(1, 1, 256, 256).to(device)
|
62 |
+
generated_image = diffusion.sample(model, y=seed_image_tensor, labels=None, n=1)
|
63 |
+
generated_image_pil = transforms.ToPILImage()(generated_image.squeeze().cpu())
|
64 |
+
return generated_image_pil
|
65 |
+
|
66 |
+
# Create Gradio interface
|
67 |
+
iface = gr.Interface(
|
68 |
+
fn=generate_image,
|
69 |
+
inputs="file",
|
70 |
+
outputs="image",
|
71 |
+
title="Magnetogram-to-Magnetogram: Generative Forecasting of Solar Evolution",
|
72 |
+
description="Upload a LoS magnetogram and predict how it is going to be in 24 hours."
|
73 |
+
)
|
74 |
+
|
75 |
+
iface.launch()
|
diffusion.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Created on Tue Apr 25 14:45:59 2023
|
4 |
+
|
5 |
+
@author: pio-r
|
6 |
+
"""
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from tqdm import tqdm
|
10 |
+
import torch.nn as nn
|
11 |
+
import logging
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")
|
15 |
+
|
16 |
+
|
17 |
+
class Diffusion_cond:
|
18 |
+
def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, img_channel=1, device="cuda"):
|
19 |
+
self.noise_steps = noise_steps # timestesps
|
20 |
+
self.beta_start = beta_start
|
21 |
+
self.beta_end = beta_end
|
22 |
+
self.img_channel = img_channel
|
23 |
+
self.img_size = img_size
|
24 |
+
self.device = device
|
25 |
+
|
26 |
+
self.beta = self.prepare_noise_schedule().to(device)
|
27 |
+
self.alpha = 1. - self.beta
|
28 |
+
self.alphas_prev = torch.cat([torch.tensor([1.0]).to(device), self.alpha[:-1]], dim=0)
|
29 |
+
self.alpha_hat = torch.cumprod(self.alpha, dim=0)
|
30 |
+
self.alphas_cumprod_prev = torch.cat([torch.tensor([1.0]).to(device), self.alpha_hat[:-1]], dim=0)
|
31 |
+
# self.alphas_cumprod_prev = torch.from_numpy(np.append(1, self.alpha_hat[:-1].cpu().numpy())).to(device)
|
32 |
+
def prepare_noise_schedule(self):
|
33 |
+
return torch.linspace(self.beta_start, self.beta_end, self.noise_steps) # linear variance schedule as proposed by Ho et al 2020
|
34 |
+
|
35 |
+
def noise_images(self, x, t):
|
36 |
+
sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
|
37 |
+
sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
|
38 |
+
Ɛ = torch.randn_like(x)
|
39 |
+
return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ, Ɛ # equation in the paper from Ho et al that describes the noise processs
|
40 |
+
|
41 |
+
def sample_timesteps(self, n):
|
42 |
+
return torch.randint(low=1, high=self.noise_steps, size=(n,))
|
43 |
+
|
44 |
+
def sample(self, model, n, y, labels, cfg_scale=3, eta=1, sampling_mode='ddpm'):
|
45 |
+
logging.info(f"Sampling {n} new images....")
|
46 |
+
model.eval() # evaluation mode
|
47 |
+
with torch.no_grad(): # algorithm 2 from DDPM
|
48 |
+
x = torch.randn((n, self.img_channel, self.img_size, self.img_size)).to(self.device)
|
49 |
+
for i in tqdm(reversed(range(1, self.noise_steps)), position=0): # reverse loop from T to 1
|
50 |
+
t = (torch.ones(n) * i).long().to(self.device) # create timesteps tensor of length n
|
51 |
+
predicted_noise = model(x, y, labels, t)
|
52 |
+
if cfg_scale > 0:
|
53 |
+
uncond_predicted_noise = model(x, y, None, t)
|
54 |
+
predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)
|
55 |
+
|
56 |
+
|
57 |
+
alpha = self.alpha[t][:, None, None, None]
|
58 |
+
alpha_hat = self.alpha_hat[t][:, None, None, None] # this is noise, created in one
|
59 |
+
alpha_prev = self.alphas_cumprod_prev[t][:, None, None, None]
|
60 |
+
beta = self.beta[t][:, None, None, None]
|
61 |
+
# SAMPLING adjusted from Stable diffusion
|
62 |
+
sigma = (
|
63 |
+
eta
|
64 |
+
* torch.sqrt((1 - alpha_prev) / (1 - alpha_hat)
|
65 |
+
* (1 - alpha_hat / alpha_prev))
|
66 |
+
)
|
67 |
+
if i > 1:
|
68 |
+
noise = torch.randn_like(x)
|
69 |
+
else:
|
70 |
+
noise = torch.zeros_like(x)
|
71 |
+
# pred_x0 = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise)
|
72 |
+
pred_x0 = (x - torch.sqrt(1 - alpha_hat) * predicted_noise) / torch.sqrt(alpha_hat)
|
73 |
+
if sampling_mode == 'ddpm':
|
74 |
+
x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
|
75 |
+
elif sampling_mode == 'ddim':
|
76 |
+
noise = torch.randn_like(x)
|
77 |
+
nonzero_mask = (
|
78 |
+
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
79 |
+
)
|
80 |
+
x = (
|
81 |
+
torch.sqrt(alpha_prev) * pred_x0 +
|
82 |
+
torch.sqrt(1 - alpha_prev - sigma ** 2) * predicted_noise +
|
83 |
+
nonzero_mask * sigma * noise
|
84 |
+
)
|
85 |
+
else:
|
86 |
+
print('The sampler {} is not implemented'.format(sampling_mode))
|
87 |
+
break
|
88 |
+
model.train() # it goes back to training mode
|
89 |
+
# x = (x.clamp(-1, 1) + 1) / 2 # to be in [-1, 1], the plus 1 and the division by 2 is to bring back values to [0, 1]
|
90 |
+
# x = (x * 255).type(torch.uint8) # to bring in valid pixel range
|
91 |
+
return x
|
92 |
+
|
93 |
+
mse = nn.MSELoss()
|
94 |
+
|
95 |
+
def psnr(input: torch.Tensor, target: torch.Tensor, max_val: float) -> torch.Tensor:
|
96 |
+
r"""Create a function that calculates the PSNR between 2 images.
|
97 |
+
|
98 |
+
PSNR is Peek Signal to Noise Ratio, which is similar to mean squared error.
|
99 |
+
Given an m x n image, the PSNR is:
|
100 |
+
|
101 |
+
.. math::
|
102 |
+
|
103 |
+
\text{PSNR} = 10 \log_{10} \bigg(\frac{\text{MAX}_I^2}{MSE(I,T)}\bigg)
|
104 |
+
|
105 |
+
where
|
106 |
+
|
107 |
+
.. math::
|
108 |
+
|
109 |
+
\text{MSE}(I,T) = \frac{1}{mn}\sum_{i=0}^{m-1}\sum_{j=0}^{n-1} [I(i,j) - T(i,j)]^2
|
110 |
+
|
111 |
+
and :math:`\text{MAX}_I` is the maximum possible input value
|
112 |
+
(e.g for floating point images :math:`\text{MAX}_I=1`).
|
113 |
+
|
114 |
+
Args:
|
115 |
+
input: the input image with arbitrary shape :math:`(*)`.
|
116 |
+
labels: the labels image with arbitrary shape :math:`(*)`.
|
117 |
+
max_val: The maximum value in the input tensor.
|
118 |
+
|
119 |
+
Return:
|
120 |
+
the computed loss as a scalar.
|
121 |
+
|
122 |
+
Examples:
|
123 |
+
>>> ones = torch.ones(1)
|
124 |
+
>>> psnr(ones, 1.2 * ones, 2.) # 10 * log(4/((1.2-1)**2)) / log(10)
|
125 |
+
tensor(20.0000)
|
126 |
+
|
127 |
+
Reference:
|
128 |
+
https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio#Definition
|
129 |
+
"""
|
130 |
+
if not isinstance(input, torch.Tensor):
|
131 |
+
raise TypeError(f"Expected torch.Tensor but got {type(target)}.")
|
132 |
+
|
133 |
+
if not isinstance(target, torch.Tensor):
|
134 |
+
raise TypeError(f"Expected torch.Tensor but got {type(input)}.")
|
135 |
+
|
136 |
+
if input.shape != target.shape:
|
137 |
+
raise TypeError(f"Expected tensors of equal shapes, but got {input.shape} and {target.shape}")
|
138 |
+
|
139 |
+
return 10.0 * torch.log10(max_val**2 / mse(input, target))
|
modules.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Created on Tue Apr 25 14:28:21 2023
|
4 |
+
|
5 |
+
@author: pio-r
|
6 |
+
"""
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from torch.nn import functional as F
|
11 |
+
from torch.utils.checkpoint import checkpoint
|
12 |
+
|
13 |
+
class EMA:
|
14 |
+
def __init__(self, beta):
|
15 |
+
super().__init__()
|
16 |
+
self.beta = beta
|
17 |
+
self.step = 0
|
18 |
+
|
19 |
+
def update_model_average(self, ma_model, current_model):
|
20 |
+
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
21 |
+
old_weight, up_weight = ma_params.data, current_params.data
|
22 |
+
ma_params.data = self.update_average(old_weight, up_weight)
|
23 |
+
|
24 |
+
def update_average(self, old, new):
|
25 |
+
if old is None:
|
26 |
+
return new
|
27 |
+
return old * self.beta + (1 - self.beta) * new
|
28 |
+
|
29 |
+
def step_ema(self, ema_model, model, step_start_ema=2000):
|
30 |
+
if self.step < step_start_ema:
|
31 |
+
self.reset_parameters(ema_model, model)
|
32 |
+
self.step += 1
|
33 |
+
return
|
34 |
+
self.update_model_average(ema_model, model)
|
35 |
+
self.step += 1
|
36 |
+
|
37 |
+
def reset_parameters(self, ema_model, model):
|
38 |
+
ema_model.load_state_dict(model.state_dict())
|
39 |
+
|
40 |
+
class SelfAttention(nn.Module):
|
41 |
+
"""
|
42 |
+
Pre Layer norm -> multi-headed tension -> skip connections -> pass it to
|
43 |
+
the feed forward layer (layer-norm -> 2 multiheadattention)
|
44 |
+
"""
|
45 |
+
def __init__(self, channels, size):
|
46 |
+
super(SelfAttention, self).__init__()
|
47 |
+
self.channels = channels
|
48 |
+
self.size = size
|
49 |
+
self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
|
50 |
+
self.ln = nn.LayerNorm([channels])
|
51 |
+
self.ff_self = nn.Sequential(
|
52 |
+
nn.LayerNorm([channels]),
|
53 |
+
nn.Linear(channels, channels),
|
54 |
+
nn.GELU(),
|
55 |
+
nn.Linear(channels, channels),
|
56 |
+
)
|
57 |
+
|
58 |
+
def forward(self, x):
|
59 |
+
x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
|
60 |
+
x_ln = self.ln(x)
|
61 |
+
attention_value, _ = self.mha(x_ln, x_ln, x_ln)
|
62 |
+
attention_value = attention_value + x
|
63 |
+
attention_value = self.ff_self(attention_value) + attention_value
|
64 |
+
return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)
|
65 |
+
|
66 |
+
|
67 |
+
class DoubleConv(nn.Module):
|
68 |
+
"""
|
69 |
+
Normal convolution block, with 2d convolution -> Group Norm -> GeLU -> convolution -> Group Norm
|
70 |
+
Possibility to add residual connection providing residual=True
|
71 |
+
"""
|
72 |
+
def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
|
73 |
+
super().__init__()
|
74 |
+
self.residual = residual
|
75 |
+
if not mid_channels:
|
76 |
+
mid_channels = out_channels
|
77 |
+
self.double_conv = nn.Sequential(
|
78 |
+
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
|
79 |
+
nn.GroupNorm(1, mid_channels),
|
80 |
+
nn.GELU(),
|
81 |
+
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
|
82 |
+
nn.GroupNorm(1, out_channels),
|
83 |
+
)
|
84 |
+
|
85 |
+
def forward(self, x):
|
86 |
+
if self.residual:
|
87 |
+
return F.gelu(x + self.double_conv(x))
|
88 |
+
else:
|
89 |
+
return self.double_conv(x)
|
90 |
+
|
91 |
+
|
92 |
+
class Down(nn.Module):
|
93 |
+
"""
|
94 |
+
maxpool reduce size by half -> 2*DoubleConv -> Embedding layer
|
95 |
+
|
96 |
+
"""
|
97 |
+
def __init__(self, in_channels, out_channels, emb_dim=256):
|
98 |
+
super().__init__()
|
99 |
+
self.maxpool_conv = nn.Sequential(
|
100 |
+
nn.MaxPool2d(2),
|
101 |
+
DoubleConv(in_channels, in_channels, residual=True),
|
102 |
+
DoubleConv(in_channels, out_channels),
|
103 |
+
)
|
104 |
+
|
105 |
+
self.emb_layer = nn.Sequential(
|
106 |
+
nn.SiLU(),
|
107 |
+
nn.Linear( # linear projection to bring the time embedding to the proper dimension
|
108 |
+
emb_dim,
|
109 |
+
out_channels
|
110 |
+
),
|
111 |
+
)
|
112 |
+
|
113 |
+
def forward(self, x, t):
|
114 |
+
x = self.maxpool_conv(x)
|
115 |
+
emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1]) # projection
|
116 |
+
return x + emb
|
117 |
+
|
118 |
+
|
119 |
+
class Up(nn.Module):
|
120 |
+
"""
|
121 |
+
We take the skip connection which comes from the encoder
|
122 |
+
"""
|
123 |
+
def __init__(self, in_channels, out_channels, emb_dim=256):
|
124 |
+
super().__init__()
|
125 |
+
|
126 |
+
self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
|
127 |
+
self.conv = nn.Sequential(
|
128 |
+
DoubleConv(in_channels, in_channels, residual=True),
|
129 |
+
DoubleConv(in_channels, out_channels, in_channels // 2),
|
130 |
+
)
|
131 |
+
|
132 |
+
self.emb_layer = nn.Sequential(
|
133 |
+
nn.SiLU(),
|
134 |
+
nn.Linear(
|
135 |
+
emb_dim,
|
136 |
+
out_channels
|
137 |
+
),
|
138 |
+
)
|
139 |
+
|
140 |
+
def forward(self, x, skip_x, t):
|
141 |
+
x = self.up(x)
|
142 |
+
x = torch.cat([skip_x, x], dim=1)
|
143 |
+
x = self.conv(x)
|
144 |
+
emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
|
145 |
+
return x + emb
|
146 |
+
|
147 |
+
class PaletteModelV2(nn.Module):
|
148 |
+
def __init__(self, c_in=1, c_out=1, image_size=64, time_dim=256, device='cuda', latent=False, true_img_size=64, num_classes=None):
|
149 |
+
super(PaletteModelV2, self).__init__()
|
150 |
+
|
151 |
+
# Encoder
|
152 |
+
self.true_img_size = true_img_size
|
153 |
+
self.image_size = image_size
|
154 |
+
self.time_dim = time_dim
|
155 |
+
self.device = device
|
156 |
+
self.inc = DoubleConv(c_in, self.image_size) # Wrap-up for 2 Conv Layers
|
157 |
+
self.down1 = Down(self.image_size, self.image_size*2) # input and output channels
|
158 |
+
# self.sa1 = SelfAttention(self.image_size*2,int( self.true_img_size/2)) # 1st is channel dim, 2nd current image resolution
|
159 |
+
self.down2 = Down(self.image_size*2, self.image_size*4)
|
160 |
+
# self.sa2 = SelfAttention(self.image_size*4, int(self.true_img_size/4))
|
161 |
+
self.down3 = Down(self.image_size*4, self.image_size*4)
|
162 |
+
# self.sa3 = SelfAttention(self.image_size*4, int(self.true_img_size/8))
|
163 |
+
|
164 |
+
# Bootleneck
|
165 |
+
self.bot1 = DoubleConv(self.image_size*4, self.image_size*8)
|
166 |
+
self.bot2 = DoubleConv(self.image_size*8, self.image_size*8)
|
167 |
+
self.bot3 = DoubleConv(self.image_size*8, self.image_size*4)
|
168 |
+
|
169 |
+
# Decoder: reverse of encoder
|
170 |
+
self.up1 = Up(self.image_size*8, self.image_size*2)
|
171 |
+
# self.sa4 = SelfAttention(self.image_size*2, int(self.true_img_size/4))
|
172 |
+
self.up2 = Up(self.image_size*4, self.image_size)
|
173 |
+
# self.sa5 = SelfAttention(self.image_size, int(self.true_img_size/2))
|
174 |
+
self.up3 = Up(self.image_size*2, self.image_size)
|
175 |
+
# self.sa6 = SelfAttention(self.image_size, self.true_img_size)
|
176 |
+
self.outc = nn.Conv2d(self.image_size, c_out, kernel_size=1) # projecting back to the output channel dimensions
|
177 |
+
|
178 |
+
if num_classes is not None:
|
179 |
+
self.label_emb = nn.Embedding(num_classes, time_dim)
|
180 |
+
|
181 |
+
if latent == True:
|
182 |
+
self.latent = nn.Sequential(
|
183 |
+
nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
|
184 |
+
nn.LeakyReLU(0.2),
|
185 |
+
nn.MaxPool2d(kernel_size=2, stride=2),
|
186 |
+
nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
|
187 |
+
nn.LeakyReLU(0.2),
|
188 |
+
nn.MaxPool2d(kernel_size=2, stride=2),
|
189 |
+
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
|
190 |
+
nn.LeakyReLU(0.2),
|
191 |
+
nn.MaxPool2d(kernel_size=2, stride=2),
|
192 |
+
nn.Flatten(),
|
193 |
+
nn.Linear(64 * 8 * 8, 256)).to(device)
|
194 |
+
|
195 |
+
def pos_encoding(self, t, channels):
|
196 |
+
"""
|
197 |
+
Input noised images and the timesteps. The timesteps will only be
|
198 |
+
a tensor with the integer timesteps values in it
|
199 |
+
"""
|
200 |
+
inv_freq = 1.0 / (
|
201 |
+
10000
|
202 |
+
** (torch.arange(0, channels, 2, device=self.device).float() / channels)
|
203 |
+
)
|
204 |
+
pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
|
205 |
+
pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
|
206 |
+
pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
|
207 |
+
return pos_enc
|
208 |
+
|
209 |
+
def forward(self, x, y, lab, t):
|
210 |
+
# Pass the source image through the encoder network
|
211 |
+
t = t.unsqueeze(-1).type(torch.float)
|
212 |
+
t = self.pos_encoding(t, self.time_dim) # Encoding timesteps is HERE, we provide the dimension we want to encode
|
213 |
+
|
214 |
+
|
215 |
+
if lab is not None:
|
216 |
+
t += self.label_emb(lab)
|
217 |
+
|
218 |
+
# t += self.latent(y)
|
219 |
+
|
220 |
+
# Concatenate the source image and reference image
|
221 |
+
x = torch.cat([x, y], dim=1)
|
222 |
+
|
223 |
+
x1 = self.inc(x)
|
224 |
+
x2 = self.down1(x1, t)
|
225 |
+
# x2 = self.sa1(x2)
|
226 |
+
x3 = self.down2(x2, t)
|
227 |
+
# x3 = self.sa2(x3)
|
228 |
+
x4 = self.down3(x3, t)
|
229 |
+
# x4 = self.sa3(x4)
|
230 |
+
|
231 |
+
x4 = self.bot1(x4)
|
232 |
+
x4 = self.bot2(x4)
|
233 |
+
x4 = self.bot3(x4)
|
234 |
+
|
235 |
+
x = self.up1(x4, x3, t) # We note that upsampling box that in the skip connections from encoder
|
236 |
+
# x = self.sa4(x)
|
237 |
+
x = self.up2(x, x2, t)
|
238 |
+
# x = self.sa5(x)
|
239 |
+
x = self.up3(x, x1, t)
|
240 |
+
# x = self.sa6(x)
|
241 |
+
output = self.outc(x)
|
242 |
+
|
243 |
+
return output
|