Spaces: Running on Zero

ohayonguy committed
Commit 1b8b226 · Parent(s): 2ef4159
first commit

Files changed:
- app.py +170 -0
- arch/__init__.py +2 -0
- lightning_models/mmse_rectified_flow.py +317 -0
app.py
ADDED
@@ -0,0 +1,170 @@
import os

import cv2
import gradio as gr
import torch
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from basicsr.utils import img2tensor, tensor2img
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from realesrgan.utils import RealESRGANer
import spaces

from lightning_models.mmse_rectified_flow import MMSERectifiedFlow

torch.set_grad_enabled(False)

if os.getenv('SPACES_ZERO_GPU') == "true":
    os.environ['SPACES_ZERO_GPU'] = "1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if not os.path.exists('pretrained_models'):
    os.makedirs('pretrained_models')
realesr_model_path = 'pretrained_models/RealESRGAN_x4plus.pth'
if not os.path.exists(realesr_model_path):
    # Note: the downloaded checkpoint is realesr-general-x4v3 (an SRVGGNetCompact
    # model), stored under the RealESRGAN_x4plus.pth filename. Download to the path
    # the code actually checks (the original wrote to experiments/pretrained_models/,
    # which is never read and may not exist).
    os.system(
        f"wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -O {realesr_model_path}")

pmrf_model_path = 'blind_face_restoration_pmrf.ckpt'

# background enhancer with RealESRGAN
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
half = torch.cuda.is_available()  # use fp16 on GPU
upsampler = RealESRGANer(scale=4, model_path=realesr_model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)

pmrf = MMSERectifiedFlow.load_from_checkpoint(pmrf_model_path,
                                              mmse_model_arch='swinir_L',
                                              mmse_model_ckpt_path=None,
                                              map_location='cpu').to(device)

os.makedirs('output', exist_ok=True)

@torch.inference_mode()
@spaces.GPU()
def enhance_face(img, face_helper, has_aligned, only_center_face=False, paste_back=True, scale=2):
    face_helper.clean_all()

    if has_aligned:  # the inputs are already aligned
        img = cv2.resize(img, (512, 512))
        face_helper.cropped_faces = [img]
    else:
        face_helper.read_image(img)
        face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
        # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
        # TODO: even with eye_dist_threshold, wrong detections and restorations can still occur.
        # align and warp each face
        face_helper.align_warp_face()

    # face restoration
    for cropped_face in face_helper.cropped_faces:
        # prepare data
        cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
        cropped_face_t = cropped_face_t.unsqueeze(0).to(device)

        try:
            dummy_x = torch.zeros_like(cropped_face_t)
            # generate_reconstructions returns (xhat, x_t_seq, source_dist_samples);
            # only the reconstruction itself is needed here.
            output, _, _ = pmrf.generate_reconstructions(dummy_x, cropped_face_t, None, 25, device)
            restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(0, 1))
        except RuntimeError as error:
            print(f'\tFailed inference for PMRF: {error}.')
            restored_face = cropped_face

        restored_face = restored_face.astype('uint8')
        face_helper.add_restored_face(restored_face)

    if not has_aligned and paste_back:
        # upsample the background
        if upsampler is not None:
            # Now only support RealESRGAN for upsampling background
            bg_img = upsampler.enhance(img, outscale=scale)[0]
        else:
            bg_img = None

        face_helper.get_inverse_affine(None)
        # paste each restored face to the input image
        restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img)
        return face_helper.cropped_faces, face_helper.restored_faces, restored_img
    else:
        return face_helper.cropped_faces, face_helper.restored_faces, None
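
# Usage sketch (hypothetical, mirrors how inference() below drives this function):
# restoring a single pre-aligned 512x512 BGR crop read with cv2.imread:
#   helper = FaceRestoreHelper(2, face_size=512, crop_ratio=(1, 1),
#                              det_model='retinaface_resnet50', save_ext='png',
#                              use_parse=True, device=device, model_rootpath=None)
#   cropped, restored, _ = enhance_face(face_img, helper, has_aligned=True)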


@torch.inference_mode()
@spaces.GPU()
def inference(img, aligned, scale):
    if scale > 4:
        scale = 4  # avoid too large scale value
    try:
        extension = os.path.splitext(os.path.basename(str(img)))[1]
        img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
        if len(img.shape) == 3 and img.shape[2] == 4:
            img_mode = 'RGBA'
        elif len(img.shape) == 2:  # for gray inputs
            img_mode = None
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        else:
            img_mode = None

        h, w = img.shape[0:2]
        if h > 3500 or w > 3500:
            print('Image size too large.')
            return None, None

        if h < 300:
            img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

        face_helper = FaceRestoreHelper(
            scale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model='retinaface_resnet50',
            save_ext='png',
            use_parse=True,
            device=device,
            model_rootpath=None)

        try:
            has_aligned = aligned == 'aligned'
            _, restored_aligned, restored_img = enhance_face(img, face_helper, has_aligned, only_center_face=False,
                                                             paste_back=True)
            if has_aligned:
                output = restored_aligned[0]
            else:
                output = restored_img
        except RuntimeError as error:
            print('Error', error)
            # without a restoration result there is nothing to post-process or save
            # (the original fell through with `output` undefined)
            return None, None

        try:
            if scale != 2:
                interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
                h, w = img.shape[0:2]
                output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
        except Exception as error:
            print('wrong scale input.', error)
        if img_mode == 'RGBA':  # RGBA images should be saved in png format
            extension = 'png'
        else:
            extension = 'jpg'
        save_path = f'output/out.{extension}'
        cv2.imwrite(save_path, output)

        output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
        return output, save_path
    except Exception as error:
        print('global exception', error)
        return None, None


css = r"""
"""

demo = gr.Interface(
    inference, [
        gr.Image(type="filepath", label="Input"),
        gr.Radio(['aligned', 'unaligned'], type="value", value='unaligned', label='Image Alignment'),
        gr.Number(label="Rescaling factor", value=2),
    ], [
        gr.Image(type="numpy", label="Output (The whole image)"),
        gr.File(label="Download the output image")
    ],
)
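
As committed, app.py builds `demo` but never starts the Gradio server (and `css` is defined but not passed anywhere). A minimal sketch of the missing entry point, assuming default queue settings:

demo.queue().launch()
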
arch/__init__.py
ADDED
@@ -0,0 +1,2 @@
from arch.hourglass.image_transformer_v2 import ImageTransformerDenoiserModelV2
from arch.swinir.swinir import SwinIR

lightning_models/mmse_rectified_flow.py
ADDED
@@ -0,0 +1,317 @@
import os
from contextlib import contextmanager, nullcontext

import torch
import wandb
from pytorch_lightning import LightningModule
from torch.nn.functional import mse_loss
from torch.nn.functional import sigmoid
from torch.optim import AdamW
from torch_ema import ExponentialMovingAverage as EMA
from torchmetrics.image import FrechetInceptionDistance, InceptionScore
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import save_image

from utils.create_arch import create_arch
from utils.img_utils import create_grid
from huggingface_hub import PyTorchModelHubMixin


class MMSERectifiedFlow(LightningModule,
                        PyTorchModelHubMixin,
                        pipeline_tag="image-to-image",
                        license="mit",
                        ):
    def __init__(self,
                 stage,
                 arch,
                 conditional=False,
                 mmse_model_ckpt_path=None,
                 mmse_model_arch=None,
                 lr=5e-4,
                 weight_decay=1e-3,
                 betas=(0.9, 0.95),
                 mmse_noise_std=0.1,
                 num_flow_steps=50,
                 ema_decay=0.9999,
                 eps=0.0,
                 t_schedule='stratified_uniform',
                 *args,
                 **kwargs
                 ):
        super().__init__()
        self.save_hyperparameters(logger=False)

        if stage == 'flow':
            if conditional:
                condition_channels = 3
            else:
                condition_channels = 0
            if mmse_model_arch is None and 'colorization' in kwargs and kwargs['colorization']:
                condition_channels //= 3
            self.model = create_arch(arch, condition_channels)
            self.mmse_model = create_arch(mmse_model_arch, 0) if mmse_model_arch is not None else None
            if mmse_model_ckpt_path is not None:
                ckpt = torch.load(mmse_model_ckpt_path, map_location="cpu")
                if mmse_model_arch is None:
                    mmse_model_arch = ckpt['hyper_parameters']['arch']
                    self.mmse_model = create_arch(mmse_model_arch, 0)
                if 'ema' in ckpt:
                    # ema_decay doesn't affect anything here, because we are doing load_state_dict
                    mmse_ema = EMA(self.mmse_model.parameters(), decay=ema_decay)
                    mmse_ema.load_state_dict(ckpt['ema'])
                    mmse_ema.copy_to()
                elif 'params_ema' in ckpt:
                    self.mmse_model.load_state_dict(ckpt['params_ema'])
                else:
                    state_dict = ckpt['state_dict']
                    state_dict = {layer_name.replace('model.', ''): weights for layer_name, weights in
                                  state_dict.items()}
                    state_dict = {layer_name.replace('module.', ''): weights for layer_name, weights in
                                  state_dict.items()}
                    self.mmse_model.load_state_dict(state_dict)
                for param in self.mmse_model.parameters():
                    param.requires_grad = False
                self.mmse_model.eval()
        else:
            assert stage == 'mmse' or stage == 'naive_flow'
            assert not conditional
            self.model = create_arch(arch, 0)
            self.mmse_model = None
        if 'flow' in stage:
            self.fid = FrechetInceptionDistance(reset_real_features=True, normalize=True)
            self.inception_score = InceptionScore(normalize=True)

        self.ema = EMA(self.model.parameters(), decay=ema_decay) if self.ema_wanted else None
        self.test_results_path = None
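
    # Construction sketch (values are illustrative; valid `arch` strings are whatever
    # utils.create_arch accepts, e.g. 'swinir_L' as used for the MMSE stage in app.py):
    #   mmse = MMSERectifiedFlow(stage='mmse', arch='swinir_L')
    #   flow = MMSERectifiedFlow(stage='flow', arch='<flow_arch>', conditional=False,
    #                            mmse_model_arch='swinir_L',
    #                            mmse_model_ckpt_path='<mmse.ckpt>')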

    @property
    def ema_wanted(self):
        return self.hparams.ema_decay != -1

    def on_save_checkpoint(self, checkpoint: dict) -> None:
        if self.ema_wanted:
            checkpoint['ema'] = self.ema.state_dict()
        return super().on_save_checkpoint(checkpoint)

    def on_load_checkpoint(self, checkpoint: dict) -> None:
        if self.ema_wanted:
            self.ema.load_state_dict(checkpoint['ema'])
        return super().on_load_checkpoint(checkpoint)

    def on_before_zero_grad(self, optimizer) -> None:
        if self.ema_wanted:
            self.ema.update(self.model.parameters())
        return super().on_before_zero_grad(optimizer)

    def to(self, *args, **kwargs):
        if self.ema_wanted:
            self.ema.to(*args, **kwargs)
        return super().to(*args, **kwargs)

    # Uses the EMA context manager to copy the EMA weights into the flow model during
    # validation/inference, and to restore the training weights afterwards.
    @contextmanager
    def maybe_ema(self):
        ema = self.ema
        ctx = nullcontext() if ema is None else ema.average_parameters()
        # The original yielded the unentered context manager, so `with self.maybe_ema():`
        # never actually swapped in the EMA weights; entering it here fixes that.
        with ctx:
            yield

    def forward_mmse(self, y):
        return self.model(y).clip(0, 1)

    def forward_flow(self, x_t, t, y=None):
        if self.hparams.conditional:
            if self.mmse_model is not None:
                with torch.no_grad():
                    self.mmse_model.eval()
                    condition = self.mmse_model(y).clip(0, 1)
            else:
                condition = y
            x_t = torch.cat((x_t, condition), dim=1)
        return self.model(x_t, t)

    def forward(self, x_t, t, y):
        if 'flow' in self.hparams.stage:
            return self.forward_flow(x_t, t, y)
        else:
            return self.forward_mmse(y)

    @torch.no_grad()
    def create_source_distribution_samples(self, x, y, non_noisy_z0):
        with torch.no_grad():
            if self.hparams.conditional:
                source_dist_samples = torch.randn_like(x)
            else:
                if self.hparams.stage == 'flow':
                    if non_noisy_z0 is None:
                        self.mmse_model.eval()
                        non_noisy_z0 = self.mmse_model(y).clip(0, 1)
                    source_dist_samples = non_noisy_z0 + torch.randn_like(non_noisy_z0) * self.hparams.mmse_noise_std
                else:
                    assert self.hparams.stage == 'naive_flow'
                    if non_noisy_z0 is not None:
                        source_dist_samples = non_noisy_z0
                    else:
                        source_dist_samples = y
                    if source_dist_samples.shape[1] != x.shape[1]:
                        assert source_dist_samples.shape[1] == 1  # Colorization
                        source_dist_samples = source_dist_samples.expand(-1, x.shape[1], -1, -1)
                    if self.hparams.mmse_noise_std is not None:
                        source_dist_samples = source_dist_samples + torch.randn_like(source_dist_samples) * self.hparams.mmse_noise_std
        return source_dist_samples

    @staticmethod
    def stratified_uniform(bs, group=0, groups=1, dtype=None, device=None):
        if groups <= 0:
            raise ValueError(f"groups must be positive, got {groups}")
        if group < 0 or group >= groups:
            raise ValueError(f"group must be in [0, {groups})")
        n = bs * groups
        offsets = torch.arange(group, n, groups, dtype=dtype, device=device)
        u = torch.rand(bs, dtype=dtype, device=device)
        return ((offsets + u) / n).view(bs, 1, 1, 1)
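
    # Worked example: with bs=4 on one process (group=0, groups=1), the four samples
    # fall one in each bin [0, 1/4), [1/4, 2/4), [2/4, 3/4), [3/4, 1) rather than
    # being i.i.d. uniform; with multiple processes, group/groups interleave the bins
    # so all ranks together still tile [0, 1) evenly.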

    def generate_random_t(self, bs, dtype=None):
        # dtype is now forwarded in every branch (the original only honored it in the
        # stratified_uniform branch).
        if self.hparams.t_schedule == 'logit-normal':
            return sigmoid(torch.randn(bs, 1, 1, 1, dtype=dtype, device=self.device)) * (1.0 - self.hparams.eps) + self.hparams.eps
        elif self.hparams.t_schedule == 'uniform':
            return torch.rand(bs, 1, 1, 1, dtype=dtype, device=self.device) * (1.0 - self.hparams.eps) + self.hparams.eps
        elif self.hparams.t_schedule == 'stratified_uniform':
            return self.stratified_uniform(bs, self.trainer.global_rank, self.trainer.world_size, dtype=dtype,
                                           device=self.device) * (1.0 - self.hparams.eps) + self.hparams.eps
        else:
            raise NotImplementedError()

    def training_step(self, batch, batch_idx):
        x = batch['x']
        y = batch['y']
        non_noisy_z0 = batch['non_noisy_z0'] if 'non_noisy_z0' in batch else None
        if 'flow' in self.hparams.stage:
            with torch.no_grad():
                t = self.generate_random_t(x.shape[0], dtype=x.dtype)
                source_dist_samples = self.create_source_distribution_samples(x, y, non_noisy_z0)
                x_t = t * x + (1.0 - t) * source_dist_samples
            v_t = self(x_t, t.squeeze(), y)
            loss = mse_loss(v_t, x - source_dist_samples)
        else:
            xhat = self(x_t=None, t=None, y=y)
            loss = mse_loss(xhat, x)
        self.log("train/loss", loss)
        return loss
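
    # For reference, the 'flow' branch above is the standard rectified-flow /
    # flow-matching objective (in the code's notation):
    #   x_t  = t * x + (1 - t) * z,   z = source_dist_samples,
    #   loss = E || v(x_t, t) - (x - z) ||^2,
    # i.e. the model learns the constant velocity that moves z (at t=0) to x (at t=1).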

    @torch.no_grad()
    def generate_reconstructions(self, x, y, non_noisy_z0, num_flow_steps, result_device):
        with self.maybe_ema():
            if 'flow' in self.hparams.stage:
                source_dist_samples = self.create_source_distribution_samples(x, y, non_noisy_z0)

                dt = (1.0 / num_flow_steps) * (1.0 - self.hparams.eps)
                x_t_next = source_dist_samples.clone()
                x_t_seq = [x_t_next]
                t_one = torch.ones(x.shape[0], device=self.device)
                for i in range(num_flow_steps):
                    num_t = (i / num_flow_steps) * (1.0 - self.hparams.eps) + self.hparams.eps
                    v_t_next = self(x_t=x_t_next, t=t_one * num_t, y=y).to(x_t_next.dtype)
                    x_t_next = x_t_next.clone() + v_t_next * dt
                    x_t_seq.append(x_t_next.to(result_device))

                xhat = x_t_seq[-1].clip(0, 1).to(torch.float32)
                source_dist_samples = source_dist_samples.to(result_device)
            else:
                xhat = self(x_t=None, t=None, y=y).to(torch.float32)
                x_t_seq = None
                source_dist_samples = None
        return xhat.to(result_device), x_t_seq, source_dist_samples
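
    # Usage sketch (mirrors how app.py calls this method): for blind restoration,
    # x only provides the tensor shape, so zeros work; y is the degraded image in [0, 1]:
    #   xhat, x_t_seq, z0 = model.generate_reconstructions(
    #       torch.zeros_like(y), y, non_noisy_z0=None, num_flow_steps=25,
    #       result_device=torch.device('cpu'))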

    def validation_step(self, batch, batch_idx):
        x = batch['x']
        y = batch['y']
        non_noisy_z0 = batch['non_noisy_z0'] if 'non_noisy_z0' in batch else None
        xhat, x_t_seq, source_dist_samples = self.generate_reconstructions(x, y, non_noisy_z0, self.hparams.num_flow_steps,
                                                                           self.device)
        x = x.to(torch.float32)
        y = y.to(torch.float32)
        self.log_dict({"val_metrics/mse": ((x - xhat) ** 2).mean()}, on_step=False, on_epoch=True, sync_dist=True,
                      batch_size=x.shape[0])

        if 'flow' in self.hparams.stage:
            self.fid.update(x, real=True)
            self.fid.update(xhat, real=False)
            self.inception_score.update(xhat)

        if batch_idx == 0:
            wandb_logger = self.logger.experiment
            wandb_logger.log({'val_images/x': [wandb.Image(to_pil_image(create_grid(x)))],
                              'val_images/y': [wandb.Image(to_pil_image(create_grid(y.clip(0, 1))))],
                              'val_images/xhat': [wandb.Image(to_pil_image(create_grid(xhat)))], })
            if 'flow' in self.hparams.stage:
                wandb_logger.log({'val_images/x_t_seq': [wandb.Image(to_pil_image(create_grid(
                    torch.cat([elem[0].unsqueeze(0).to(torch.float32) for elem in x_t_seq], dim=0).clip(0, 1),
                    num_images=len(x_t_seq))))], 'val_images/source_distribution_samples': [
                    wandb.Image(to_pil_image(create_grid(source_dist_samples.clip(0, 1).to(torch.float32))))]})
            if self.mmse_model is not None:
                xhat_mmse = self.mmse_model(y).clip(0, 1)
                wandb_logger.log({'val_images/xhat_mmse': [
                    wandb.Image(to_pil_image(create_grid(xhat_mmse.to(torch.float32))))]})

    def on_validation_epoch_end(self):
        if 'flow' in self.hparams.stage:
            inception_score_mean, inception_score_std = self.inception_score.compute()
            self.log_dict(
                {'val_metrics/fid': self.fid.compute(),
                 'val_metrics/inception_score_mean': inception_score_mean,
                 'val_metrics/inception_score_std': inception_score_std},
                on_epoch=True, on_step=False, sync_dist=True,
                batch_size=1)
            self.fid.reset()
            self.inception_score.reset()

    def test_step(self, batch, batch_idx):
        # Note: self.test_results_path and (for flow stages) self.num_test_flow_steps
        # are expected to be set externally before trainer.test() is called.
        assert self.test_results_path is not None, "Please set test_results_path before testing."
        assert os.path.isdir(self.test_results_path), 'Please make sure the test_results_path dir exists.'

        def save_image_batch(images, folder, image_file_names):
            os.makedirs(folder, exist_ok=True)
            for i, img in enumerate(images):
                save_image(img.clip(0, 1), os.path.join(folder, image_file_names[i]))

        os.makedirs(self.test_results_path, exist_ok=True)
        x = batch['x']
        y = batch['y']
        non_noisy_z0 = batch['non_noisy_z0'] if 'non_noisy_z0' in batch else None
        y_path = os.path.join(self.test_results_path, 'y')
        save_image_batch(y, y_path, batch['img_file_name'])

        if 'flow' in self.hparams.stage:
            source_dist_samples_to_save = None

            for num_flow_steps in self.num_test_flow_steps:
                xhat, x_t_seq, source_dist_samples = self.generate_reconstructions(x, y, non_noisy_z0, num_flow_steps,
                                                                                   torch.device("cpu"))
                xhat_path = os.path.join(self.test_results_path, f"num_flow_steps={num_flow_steps}", 'xhat')
                save_image_batch(xhat, xhat_path, batch['img_file_name'])
                if source_dist_samples_to_save is None:
                    source_dist_samples_to_save = source_dist_samples

            source_distribution_samples_path = os.path.join(self.test_results_path, 'source_distribution_samples')
            save_image_batch(source_dist_samples_to_save, source_distribution_samples_path, batch['img_file_name'])
            if self.mmse_model is not None:
                mmse_estimates = self.mmse_model(y).clip(0, 1)
                mmse_samples_path = os.path.join(self.test_results_path, 'mmse_samples')
                save_image_batch(mmse_estimates, mmse_samples_path, batch['img_file_name'])

        else:
            xhat, _, _ = self.generate_reconstructions(x, y, non_noisy_z0, None, torch.device('cpu'))
            xhat_path = os.path.join(self.test_results_path, 'xhat')
            save_image_batch(xhat, xhat_path, batch['img_file_name'])

    def configure_optimizers(self):
        # Add here a learning rate scheduler if you wish to do so.
        optimizer = AdamW(self.model.parameters(),
                          betas=self.hparams.betas,
                          eps=1e-8,
                          lr=self.hparams.lr,
                          weight_decay=self.hparams.weight_decay)
        return optimizer
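
Because the class also mixes in PyTorchModelHubMixin, a checkpoint pushed to the Hugging Face Hub could be loaded without going through Lightning checkpoints. A minimal sketch (the repo id is a placeholder, not confirmed by this commit):

from lightning_models.mmse_rectified_flow import MMSERectifiedFlow

model = MMSERectifiedFlow.from_pretrained('<user>/<pmrf-repo>')  # placeholder repo id
model = model.to('cuda').eval()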