ohayonguy committed on
Commit
1b8b226
1 Parent(s): 2ef4159

first commit

Files changed (3)
  1. app.py +170 -0
  2. arch/__init__.py +2 -0
  3. lightning_models/mmse_rectified_flow.py +317 -0
app.py ADDED
@@ -0,0 +1,170 @@
+ import os
+
+ import cv2
+ import gradio as gr
+ import torch
+ from basicsr.archs.srvgg_arch import SRVGGNetCompact
+ from basicsr.utils import img2tensor, tensor2img
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+ from realesrgan.utils import RealESRGANer
+ import spaces
+
+ from lightning_models.mmse_rectified_flow import MMSERectifiedFlow
+
+ torch.set_grad_enabled(False)
+
+ if os.getenv('SPACES_ZERO_GPU') == "true":
+     os.environ['SPACES_ZERO_GPU'] = "1"
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ if not os.path.exists('pretrained_models'):
+     os.makedirs('pretrained_models')
+ realesr_model_path = 'pretrained_models/RealESRGAN_x4plus.pth'
+ if not os.path.exists(realesr_model_path):
+     os.system(
+         "wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -O pretrained_models/RealESRGAN_x4plus.pth")
+
+ pmrf_model_path = 'blind_face_restoration_pmrf.ckpt'
+
+ # background enhancer with RealESRGAN
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
+ half = True if torch.cuda.is_available() else False
+ upsampler = RealESRGANer(scale=4, model_path=realesr_model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
+
+ pmrf = MMSERectifiedFlow.load_from_checkpoint(pmrf_model_path,
+                                               mmse_model_arch='swinir_L',
+                                               mmse_model_ckpt_path=None,
+                                               map_location='cpu').to(device)
+
+ os.makedirs('output', exist_ok=True)
+
+
+ @torch.inference_mode()
+ @spaces.GPU()
+ def enhance_face(img, face_helper, has_aligned, only_center_face=False, paste_back=True, scale=2):
+     face_helper.clean_all()
+
+     if has_aligned:  # the inputs are already aligned
+         img = cv2.resize(img, (512, 512))
+         face_helper.cropped_faces = [img]
+     else:
+         face_helper.read_image(img)
+         face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
+         # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
+         # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
+         # align and warp each face
+         face_helper.align_warp_face()
+
+     # face restoration
+     for cropped_face in face_helper.cropped_faces:
+         # prepare data
+         cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
+         cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
+
+         try:
+             dummy_x = torch.zeros_like(cropped_face_t)
+             output = pmrf.generate_reconstructions(dummy_x, cropped_face_t, None, 25, device)[0]  # returns (xhat, x_t_seq, source_dist_samples); keep only xhat
+             restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(0, 1))
+         except RuntimeError as error:
+             print(f'\tFailed inference for PMRF: {error}.')
+             restored_face = cropped_face
+
+         restored_face = restored_face.astype('uint8')
+         face_helper.add_restored_face(restored_face)
+
+     if not has_aligned and paste_back:
+         # upsample the background
+         if upsampler is not None:
+             # Now only support RealESRGAN for upsampling background
+             bg_img = upsampler.enhance(img, outscale=scale)[0]
+         else:
+             bg_img = None
+
+         face_helper.get_inverse_affine(None)
+         # paste each restored face to the input image
+         restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img)
+         return face_helper.cropped_faces, face_helper.restored_faces, restored_img
+     else:
+         return face_helper.cropped_faces, face_helper.restored_faces, None
+
+
+ @torch.inference_mode()
+ @spaces.GPU()
+ def inference(img, aligned, scale):
+     if scale > 4:
+         scale = 4  # avoid too large scale value
+     try:
+
+         extension = os.path.splitext(os.path.basename(str(img)))[1]
+         img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
+         if len(img.shape) == 3 and img.shape[2] == 4:
+             img_mode = 'RGBA'
+         elif len(img.shape) == 2:  # for gray inputs
+             img_mode = None
+             img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+         else:
+             img_mode = None
+
+         h, w = img.shape[0:2]
+         if h > 3500 or w > 3500:
+             print('Image size too large.')
+             return None, None
+
+         if h < 300:
+             img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
+
+         face_helper = FaceRestoreHelper(
+             scale,
+             face_size=512,
+             crop_ratio=(1, 1),
+             det_model='retinaface_resnet50',
+             save_ext='png',
+             use_parse=True,
+             device=device,
+             model_rootpath=None)
+
+         try:
+             has_aligned = True if aligned == 'aligned' else False
+             _, restored_aligned, restored_img = enhance_face(img, face_helper, has_aligned, only_center_face=False,
+                                                              paste_back=True)
+             if has_aligned:
+                 output = restored_aligned[0]
+             else:
+                 output = restored_img
+         except RuntimeError as error:
+             print('Error', error)
+
+         try:
+             if scale != 2:
+                 interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
+                 h, w = img.shape[0:2]
+                 output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
+         except Exception as error:
+             print('wrong scale input.', error)
+         if img_mode == 'RGBA':  # RGBA images should be saved in png format
+             extension = 'png'
+         else:
+             extension = 'jpg'
+         save_path = f'output/out.{extension}'
+         cv2.imwrite(save_path, output)
+
+         output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
+         return output, save_path
+     except Exception as error:
+         print('global exception', error)
+         return None, None
+
+
+ css = r"""
+ """
+
+ demo = gr.Interface(
+     inference, [
+         gr.Image(type="filepath", label="Input"),
+         gr.Radio(['aligned', 'unaligned'], type="value", value='unaligned', label='Image Alignment'),
+         gr.Number(label="Rescaling factor", value=2),
+     ], [
+         gr.Image(type="numpy", label="Output (The whole image)"),
+         gr.File(label="Download the output image")
+     ],
+ )
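Note that app.py as committed defines the `demo` Interface above but never launches it. On a Hugging Face Space the app is typically started with an explicit entry point; a minimal sketch of that missing step (assuming the `demo` object defined above) would be:

    if __name__ == '__main__':
        # Queue incoming requests and start the Gradio server.
        demo.queue().launch()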
arch/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from arch.hourglass.image_transformer_v2 import ImageTransformerDenoiserModelV2
+ from arch.swinir.swinir import SwinIR
lightning_models/mmse_rectified_flow.py ADDED
@@ -0,0 +1,317 @@
+ import os
+ from contextlib import contextmanager, nullcontext
+
+ import torch
+ import wandb
+ from pytorch_lightning import LightningModule
+ from torch.nn.functional import mse_loss
+ from torch.nn.functional import sigmoid
+ from torch.optim import AdamW
+ from torch_ema import ExponentialMovingAverage as EMA
+ from torchmetrics.image import FrechetInceptionDistance, InceptionScore
+ from torchvision.transforms.functional import to_pil_image
+ from torchvision.utils import save_image
+
+ from utils.create_arch import create_arch
+ from utils.img_utils import create_grid
+ from huggingface_hub import PyTorchModelHubMixin
+
+
+
+ class MMSERectifiedFlow(LightningModule,
+                         PyTorchModelHubMixin,
+                         pipeline_tag="image-to-image",
+                         license="mit",
+                         ):
+     def __init__(self,
+                  stage,
+                  arch,
+                  conditional=False,
+                  mmse_model_ckpt_path=None,
+                  mmse_model_arch=None,
+                  lr=5e-4,
+                  weight_decay=1e-3,
+                  betas=(0.9, 0.95),
+                  mmse_noise_std=0.1,
+                  num_flow_steps=50,
+                  ema_decay=0.9999,
+                  eps=0.0,
+                  t_schedule='stratified_uniform',
+                  *args,
+                  **kwargs
+                  ):
+         super().__init__()
+         self.save_hyperparameters(logger=False)
+
+         if stage == 'flow':
+             if conditional:
+                 condition_channels = 3
+             else:
+                 condition_channels = 0
+             if mmse_model_arch is None and 'colorization' in kwargs and kwargs['colorization']:
+                 condition_channels //= 3
+             self.model = create_arch(arch, condition_channels)
+             self.mmse_model = create_arch(mmse_model_arch, 0) if mmse_model_arch is not None else None
+             if mmse_model_ckpt_path is not None:
+                 ckpt = torch.load(mmse_model_ckpt_path, map_location="cpu")
+                 if mmse_model_arch is None:
+                     mmse_model_arch = ckpt['hyper_parameters']['arch']
+                     self.mmse_model = create_arch(mmse_model_arch, 0)
+                 if 'ema' in ckpt:
+                     # ema_decay doesn't affect anything here, because we are doing load_state_dict
+                     mmse_ema = EMA(self.mmse_model.parameters(), decay=ema_decay)
+                     mmse_ema.load_state_dict(ckpt['ema'])
+                     mmse_ema.copy_to()
+                 elif 'params_ema' in ckpt:
+                     self.mmse_model.load_state_dict(ckpt['params_ema'])
+                 else:
+                     state_dict = ckpt['state_dict']
+                     state_dict = {layer_name.replace('model.', ''): weights for layer_name, weights in
+                                   state_dict.items()}
+                     state_dict = {layer_name.replace('module.', ''): weights for layer_name, weights in
+                                   state_dict.items()}
+                     self.mmse_model.load_state_dict(state_dict)
+                 for param in self.mmse_model.parameters():
+                     param.requires_grad = False
+                 self.mmse_model.eval()
+         else:
+             assert stage == 'mmse' or stage == 'naive_flow'
+             assert not conditional
+             self.model = create_arch(arch, 0)
+             self.mmse_model = None
+         if 'flow' in stage:
+             self.fid = FrechetInceptionDistance(reset_real_features=True, normalize=True)
+             self.inception_score = InceptionScore(normalize=True)
+
+         self.ema = EMA(self.model.parameters(), decay=ema_decay) if self.ema_wanted else None
+         self.test_results_path = None
+
+     @property
+     def ema_wanted(self):
+         return self.hparams.ema_decay != -1
+
+     def on_save_checkpoint(self, checkpoint: dict) -> None:
+         if self.ema_wanted:
+             checkpoint['ema'] = self.ema.state_dict()
+         return super().on_save_checkpoint(checkpoint)
+
+     def on_load_checkpoint(self, checkpoint: dict) -> None:
+         if self.ema_wanted:
+             self.ema.load_state_dict(checkpoint['ema'])
+         return super().on_load_checkpoint(checkpoint)
+
+     def on_before_zero_grad(self, optimizer) -> None:
+         if self.ema_wanted:
+             self.ema.update(self.model.parameters())
+         return super().on_before_zero_grad(optimizer)
+
+     def to(self, *args, **kwargs):
+         if self.ema_wanted:
+             self.ema.to(*args, **kwargs)
+         return super().to(*args, **kwargs)
+
+     # Uses the EMA context manager to copy the EMA weights into the model during validation, then restore the training weights afterwards.
+     @contextmanager
+     def maybe_ema(self):
+         ema = self.ema
+         ctx = nullcontext if ema is None else ema.average_parameters
+         with ctx():
+             yield
+
+     def forward_mmse(self, y):
+         return self.model(y).clip(0, 1)
+
+     def forward_flow(self, x_t, t, y=None):
+         if self.hparams.conditional:
+             if self.mmse_model is not None:
+                 with torch.no_grad():
+                     self.mmse_model.eval()
+                     condition = self.mmse_model(y).clip(0, 1)
+             else:
+                 condition = y
+             x_t = torch.cat((x_t, condition), dim=1)
+         return self.model(x_t, t)
+
+     def forward(self, x_t, t, y):
+         if 'flow' in self.hparams.stage:
+             return self.forward_flow(x_t, t, y)
+         else:
+             return self.forward_mmse(y)
+
+     @torch.no_grad()
+     def create_source_distribution_samples(self, x, y, non_noisy_z0):
+         with torch.no_grad():
+             if self.hparams.conditional:
+                 source_dist_samples = torch.randn_like(x)
+             else:
+                 if self.hparams.stage == 'flow':
+                     if non_noisy_z0 is None:
+                         self.mmse_model.eval()
+                         non_noisy_z0 = self.mmse_model(y).clip(0, 1)
+                     source_dist_samples = non_noisy_z0 + torch.randn_like(non_noisy_z0) * self.hparams.mmse_noise_std
+                 else:
+                     assert self.hparams.stage == 'naive_flow'
+                     if non_noisy_z0 is not None:
+                         source_dist_samples = non_noisy_z0
+                     else:
+                         source_dist_samples = y
+                     if source_dist_samples.shape[1] != x.shape[1]:
+                         assert source_dist_samples.shape[1] == 1  # Colorization
+                         source_dist_samples = source_dist_samples.expand(-1, x.shape[1], -1, -1)
+                     if self.hparams.mmse_noise_std is not None:
+                         source_dist_samples = source_dist_samples + torch.randn_like(source_dist_samples) * self.hparams.mmse_noise_std
+         return source_dist_samples
+
+     @staticmethod
+     def stratified_uniform(bs, group=0, groups=1, dtype=None, device=None):
+         if groups <= 0:
+             raise ValueError(f"groups must be positive, got {groups}")
+         if group < 0 or group >= groups:
+             raise ValueError(f"group must be in [0, {groups})")
+         n = bs * groups
+         offsets = torch.arange(group, n, groups, dtype=dtype, device=device)
+         u = torch.rand(bs, dtype=dtype, device=device)
+         return ((offsets + u) / n).view(bs, 1, 1, 1)
+
+     def generate_random_t(self, bs, dtype=None):
+         if self.hparams.t_schedule == 'logit-normal':
+             return sigmoid(torch.randn(bs, 1, 1, 1, device=self.device)) * (1.0 - self.hparams.eps) + self.hparams.eps
+         elif self.hparams.t_schedule == 'uniform':
+             return torch.rand(bs, 1, 1, 1, device=self.device) * (1.0 - self.hparams.eps) + self.hparams.eps
+         elif self.hparams.t_schedule == 'stratified_uniform':
+             return self.stratified_uniform(bs, self.trainer.global_rank, self.trainer.world_size, dtype=dtype,
+                                            device=self.device) * (1.0 - self.hparams.eps) + self.hparams.eps
+         else:
+             raise NotImplementedError()
+
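The 'stratified_uniform' schedule above splits [0, 1) into bs * world_size equal bins and draws one t per bin, with each rank owning an interleaved subset of bins, so the flow timesteps sampled across a distributed batch cover the unit interval evenly instead of clustering by chance. A minimal standalone sketch of the same computation (this helper simply mirrors the staticmethod, without the broadcast view):

    import torch

    def stratified_uniform(bs, group=0, groups=1):
        # Rank `group` owns bins group, group + groups, group + 2 * groups, ...
        n = bs * groups
        offsets = torch.arange(group, n, groups)
        u = torch.rand(bs)
        return (offsets + u) / n

    # With bs=4 and groups=2: rank 0 draws one t from each of [0, 1/8), [2/8, 3/8), [4/8, 5/8), [6/8, 7/8),
    # and rank 1 from the remaining four bins, so the 8 samples together stratify [0, 1).
    print(stratified_uniform(4, group=0, groups=2))
    print(stratified_uniform(4, group=1, groups=2))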
+     def training_step(self, batch, batch_idx):
+         x = batch['x']
+         y = batch['y']
+         non_noisy_z0 = batch['non_noisy_z0'] if 'non_noisy_z0' in batch else None
+         if 'flow' in self.hparams.stage:
+             with torch.no_grad():
+                 t = self.generate_random_t(x.shape[0], dtype=x.dtype)
+                 source_dist_samples = self.create_source_distribution_samples(x, y, non_noisy_z0)
+                 x_t = t * x + (1.0 - t) * source_dist_samples
+             v_t = self(x_t, t.squeeze(), y)
+             loss = mse_loss(v_t, x - source_dist_samples)
+         else:
+             xhat = self(x_t=None, t=None, y=y)
+             loss = mse_loss(xhat, x)
+         self.log("train/loss", loss)
+         return loss
+
+     @torch.no_grad()
+     def generate_reconstructions(self, x, y, non_noisy_z0, num_flow_steps, result_device):
+         with self.maybe_ema():
+             if 'flow' in self.hparams.stage:
+                 source_dist_samples = self.create_source_distribution_samples(x, y, non_noisy_z0)
+
+                 dt = (1.0 / num_flow_steps) * (1.0 - self.hparams.eps)
+                 x_t_next = source_dist_samples.clone()
+                 x_t_seq = [x_t_next]
+                 t_one = torch.ones(x.shape[0], device=self.device)
+                 for i in range(num_flow_steps):
+                     num_t = (i / num_flow_steps) * (1.0 - self.hparams.eps) + self.hparams.eps
+                     v_t_next = self(x_t=x_t_next, t=t_one * num_t, y=y).to(x_t_next.dtype)
+                     x_t_next = x_t_next.clone() + v_t_next * dt
+                     x_t_seq.append(x_t_next.to(result_device))
+
+                 xhat = x_t_seq[-1].clip(0, 1).to(torch.float32)
+                 source_dist_samples = source_dist_samples.to(result_device)
+             else:
+                 xhat = self(x_t=None, t=None, y=y).to(torch.float32)
+                 x_t_seq = None
+                 source_dist_samples = None
+             return xhat.to(result_device), x_t_seq, source_dist_samples
+
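generate_reconstructions integrates the learned velocity field with fixed-step Euler updates, x_{t+dt} = x_t + v(x_t, t) * dt, starting from the source sample (the noisy MMSE estimate) at t = eps and ending at t = 1. A toy, self-contained sketch of that loop (the helper name and tensor shapes are illustrative), using the exact straight-line velocity v = x1 - x0 in place of the network, shows that the update reaches the target for any step count:

    import torch

    def euler_flow_sample(x0, velocity_fn, num_flow_steps=25, eps=0.0):
        # Same discretization as generate_reconstructions: uniform steps of size dt over [eps, 1].
        dt = (1.0 / num_flow_steps) * (1.0 - eps)
        x_t = x0.clone()
        for i in range(num_flow_steps):
            t = (i / num_flow_steps) * (1.0 - eps) + eps
            x_t = x_t + velocity_fn(x_t, t) * dt
        return x_t

    x0 = torch.randn(1, 3, 8, 8)  # stands in for the source sample (noisy MMSE estimate)
    x1 = torch.rand(1, 3, 8, 8)   # stands in for the clean target image
    xhat = euler_flow_sample(x0, lambda x, t: x1 - x0, num_flow_steps=25)
    print(torch.allclose(xhat, x1, atol=1e-5))  # True: straight paths are integrated exactly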
+     def validation_step(self, batch, batch_idx):
+         x = batch['x']
+         y = batch['y']
+         non_noisy_z0 = batch['non_noisy_z0'] if 'non_noisy_z0' in batch else None
+         xhat, x_t_seq, source_dist_samples = self.generate_reconstructions(x, y, non_noisy_z0, self.hparams.num_flow_steps,
+                                                                            self.device)
+         x = x.to(torch.float32)
+         y = y.to(torch.float32)
+         self.log_dict({"val_metrics/mse": ((x - xhat) ** 2).mean()}, on_step=False, on_epoch=True, sync_dist=True,
+                       batch_size=x.shape[0])
+
+         if 'flow' in self.hparams.stage:
+             self.fid.update(x, real=True)
+             self.fid.update(xhat, real=False)
+             self.inception_score.update(xhat)
+
+         if batch_idx == 0:
+             wandb_logger = self.logger.experiment
+             wandb_logger.log({'val_images/x': [wandb.Image(to_pil_image(create_grid(x)))],
+                               'val_images/y': [wandb.Image(to_pil_image(create_grid(y.clip(0, 1))))],
+                               'val_images/xhat': [wandb.Image(to_pil_image(create_grid(xhat)))], })
+             if 'flow' in self.hparams.stage:
+                 wandb_logger.log({'val_images/x_t_seq': [wandb.Image(to_pil_image(create_grid(
+                     torch.cat([elem[0].unsqueeze(0).to(torch.float32) for elem in x_t_seq], dim=0).clip(0, 1),
+                     num_images=len(x_t_seq))))], 'val_images/source_distribution_samples': [
+                     wandb.Image(to_pil_image(create_grid(source_dist_samples.clip(0, 1).to(torch.float32))))]})
+                 if self.mmse_model is not None:
+                     xhat_mmse = self.mmse_model(y).clip(0, 1)
+                     wandb_logger.log({'val_images/xhat_mmse': [
+                         wandb.Image(to_pil_image(create_grid(xhat_mmse.to(torch.float32))))]})
+
+     def on_validation_epoch_end(self):
+         if 'flow' in self.hparams.stage:
+             inception_score_mean, inception_score_std = self.inception_score.compute()
+             self.log_dict(
+                 {'val_metrics/fid': self.fid.compute(),
+                  'val_metrics/inception_score_mean': inception_score_mean,
+                  'val_metrics/inception_score_std': inception_score_std},
+                 on_epoch=True, on_step=False, sync_dist=True,
+                 batch_size=1)
+             self.fid.reset()
+             self.inception_score.reset()
+
+     def test_step(self, batch, batch_idx):
+         assert self.test_results_path is not None, "Please set test_results_path before testing."
+         assert os.path.isdir(self.test_results_path), 'Please make sure the test_results_path dir exists.'
+
+         def save_image_batch(images, folder, image_file_names):
+             os.makedirs(folder, exist_ok=True)
+             for i, img in enumerate(images):
+                 save_image(img.clip(0, 1), os.path.join(folder, image_file_names[i]))
+
+         os.makedirs(self.test_results_path, exist_ok=True)
+         x = batch['x']
+         y = batch['y']
+         non_noisy_z0 = batch['non_noisy_z0'] if 'non_noisy_z0' in batch else None
+         y_path = os.path.join(self.test_results_path, 'y')
+         save_image_batch(y, y_path, batch['img_file_name'])
+
+         if 'flow' in self.hparams.stage:
+             source_dist_samples_to_save = None
+
+             for num_flow_steps in self.num_test_flow_steps:  # expected to be set on the module before testing, like test_results_path
+                 xhat, x_t_seq, source_dist_samples = self.generate_reconstructions(x, y, non_noisy_z0, num_flow_steps,
+                                                                                    torch.device("cpu"))
+                 xhat_path = os.path.join(self.test_results_path, f"num_flow_steps={num_flow_steps}", 'xhat')
+                 save_image_batch(xhat, xhat_path, batch['img_file_name'])
+                 if source_dist_samples_to_save is None:
+                     source_dist_samples_to_save = source_dist_samples
+
+             source_distribution_samples_path = os.path.join(self.test_results_path, 'source_distribution_samples')
+             save_image_batch(source_dist_samples_to_save, source_distribution_samples_path, batch['img_file_name'])
+             if self.mmse_model is not None:
+                 mmse_estimates = self.mmse_model(y).clip(0, 1)
+                 mmse_samples_path = os.path.join(self.test_results_path, 'mmse_samples')
+                 save_image_batch(mmse_estimates, mmse_samples_path, batch['img_file_name'])
+
+         else:
+             xhat, _, _ = self.generate_reconstructions(x, y, non_noisy_z0, None, torch.device('cpu'))
+             xhat_path = os.path.join(self.test_results_path, 'xhat')
+             save_image_batch(xhat, xhat_path, batch['img_file_name'])
+
+     def configure_optimizers(self):
+         # Add a learning rate scheduler here if you wish to do so.
+         optimizer = AdamW(self.model.parameters(),
+                           betas=self.hparams.betas,
+                           eps=1e-8,
+                           lr=self.hparams.lr,
+                           weight_decay=self.hparams.weight_decay)
+         return optimizer
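For reference, app.py above loads this class and calls generate_reconstructions directly; the same offline usage, stripped of the Gradio and face-cropping machinery, looks roughly like the sketch below (the checkpoint path, mmse_model_arch and the 25-step setting are taken from app.py; the 512x512 input shape is illustrative):

    import torch
    from lightning_models.mmse_rectified_flow import MMSERectifiedFlow

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pmrf = MMSERectifiedFlow.load_from_checkpoint('./blind_face_restoration_pmrf.ckpt',
                                                  mmse_model_arch='swinir_L',
                                                  mmse_model_ckpt_path=None,
                                                  map_location='cpu').to(device)
    pmrf.eval()

    y = torch.rand(1, 3, 512, 512, device=device)  # degraded face crop, values in [0, 1]
    dummy_x = torch.zeros_like(y)                  # the clean image is unused at inference time
    xhat, x_t_seq, source_samples = pmrf.generate_reconstructions(dummy_x, y, None, 25, device)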