Spaces:
Runtime error
Runtime error
mueller-franzes
commited on
Commit
·
f85e212
1
Parent(s):
96a28c6
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +25 -0
- README.md +62 -2
- medical_diffusion/data/augmentation/__init__.py +0 -0
- medical_diffusion/data/augmentation/augmentations_2d.py +27 -0
- medical_diffusion/data/augmentation/augmentations_3d.py +38 -0
- medical_diffusion/data/datamodules/__init__.py +1 -0
- medical_diffusion/data/datamodules/datamodule_simple.py +79 -0
- medical_diffusion/data/datasets/__init__.py +2 -0
- medical_diffusion/data/datasets/dataset_simple_2d.py +198 -0
- medical_diffusion/data/datasets/dataset_simple_3d.py +58 -0
- medical_diffusion/external/diffusers/attention.py +347 -0
- medical_diffusion/external/diffusers/embeddings.py +89 -0
- medical_diffusion/external/diffusers/resnet.py +479 -0
- medical_diffusion/external/diffusers/taming_discriminator.py +57 -0
- medical_diffusion/external/diffusers/unet.py +257 -0
- medical_diffusion/external/diffusers/unet_blocks.py +1557 -0
- medical_diffusion/external/diffusers/vae.py +857 -0
- medical_diffusion/external/stable_diffusion/attention.py +261 -0
- medical_diffusion/external/stable_diffusion/lr_schedulers.py +33 -0
- medical_diffusion/external/stable_diffusion/unet_openai.py +962 -0
- medical_diffusion/external/stable_diffusion/util.py +284 -0
- medical_diffusion/external/stable_diffusion/util_attention.py +56 -0
- medical_diffusion/external/unet_lucidrains.py +332 -0
- medical_diffusion/loss/gan_losses.py +22 -0
- medical_diffusion/loss/perceivers.py +27 -0
- medical_diffusion/metrics/__init__.py +0 -0
- medical_diffusion/metrics/torchmetrics_pr_recall.py +170 -0
- medical_diffusion/models/__init__.py +1 -0
- medical_diffusion/models/embedders/__init__.py +2 -0
- medical_diffusion/models/embedders/cond_embedders.py +27 -0
- medical_diffusion/models/embedders/latent_embedders.py +1065 -0
- medical_diffusion/models/embedders/time_embedder.py +75 -0
- medical_diffusion/models/estimators/__init__.py +1 -0
- medical_diffusion/models/estimators/unet.py +186 -0
- medical_diffusion/models/estimators/unet2.py +279 -0
- medical_diffusion/models/model_base.py +114 -0
- medical_diffusion/models/noise_schedulers/__init__.py +2 -0
- medical_diffusion/models/noise_schedulers/gaussian_scheduler.py +154 -0
- medical_diffusion/models/noise_schedulers/scheduler_base.py +49 -0
- medical_diffusion/models/pipelines/__init__.py +1 -0
- medical_diffusion/models/pipelines/diffusion_pipeline.py +348 -0
- medical_diffusion/models/utils/__init__.py +2 -0
- medical_diffusion/models/utils/attention_blocks.py +335 -0
- medical_diffusion/models/utils/conv_blocks.py +528 -0
- medical_diffusion/utils/math_utils.py +6 -0
- medical_diffusion/utils/train_utils.py +88 -0
- requirements.txt +17 -0
- scripts/evaluate_images.py +129 -0
- scripts/evaluate_latent_embedder.py +98 -0
- scripts/helpers/dump_discrimnator.py +26 -0
.gitignore
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/**/*
|
2 |
+
!/**/
|
3 |
+
/venv/
|
4 |
+
!*.ipynb
|
5 |
+
!*.gitignore
|
6 |
+
!*.md
|
7 |
+
!*.bat
|
8 |
+
!*.py
|
9 |
+
!*.yml
|
10 |
+
!*.ui
|
11 |
+
!*.yaml
|
12 |
+
|
13 |
+
!requirements.txt
|
14 |
+
!version.txt
|
15 |
+
|
16 |
+
/docs/build
|
17 |
+
!/docs/Makefile
|
18 |
+
|
19 |
+
/build/
|
20 |
+
|
21 |
+
|
22 |
+
/results
|
23 |
+
/scripts/local_trash
|
24 |
+
|
25 |
+
!/media/**
|
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
title: Medfusion App
|
3 |
-
emoji:
|
4 |
colorFrom: pink
|
5 |
colorTo: gray
|
6 |
sdk: streamlit
|
@@ -10,4 +10,64 @@ pinned: false
|
|
10 |
license: mit
|
11 |
---
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
title: Medfusion App
|
3 |
+
emoji: 🔬
|
4 |
colorFrom: pink
|
5 |
colorTo: gray
|
6 |
sdk: streamlit
|
|
|
10 |
license: mit
|
11 |
---
|
12 |
|
13 |
+
Medfusion - Medical Denoising Diffusion Probabilistic Model
|
14 |
+
=============
|
15 |
+
|
16 |
+
Paper
|
17 |
+
=======
|
18 |
+
Please see: [**Diffusion Probabilistic Models beat GANs on Medical 2D Images**]()
|
19 |
+
|
20 |
+
![](media/Medfusion.png)
|
21 |
+
*Figure: Medfusion*
|
22 |
+
|
23 |
+
![](media/animation_eye.gif) ![](media/animation_histo.gif) ![](media/animation_chest.gif)\
|
24 |
+
*Figure: Eye fundus, chest X-ray and colon histology images generated with Medfusion (Warning color quality limited by .gif)*
|
25 |
+
|
26 |
+
Demo
|
27 |
+
=============
|
28 |
+
[Link]() to streamlit app.
|
29 |
+
|
30 |
+
Install
|
31 |
+
=============
|
32 |
+
|
33 |
+
Create virtual environment and install packages: \
|
34 |
+
`python -m venv venv` \
|
35 |
+
`source venv/bin/activate`\
|
36 |
+
`pip install -e .`
|
37 |
+
|
38 |
+
|
39 |
+
Get Started
|
40 |
+
=============
|
41 |
+
|
42 |
+
1 Prepare Data
|
43 |
+
-------------
|
44 |
+
|
45 |
+
* Go to [medical_diffusion/data/datasets/dataset_simple_2d.py](medical_diffusion/data/datasets/dataset_simple_2d.py) and create a new `SimpleDataset2D` or write your own Dataset.
|
46 |
+
|
47 |
+
|
48 |
+
2 Train Autoencoder
|
49 |
+
----------------
|
50 |
+
* Go to [scripts/train_latent_embedder_2d.py](scripts/train_latent_embedder_2d.py) and import your Dataset.
|
51 |
+
* Load your dataset with eg. `SimpleDataModule`
|
52 |
+
* Customize `VAE` to your needs
|
53 |
+
* (Optional): Train a `VAEGAN` instead or load a pre-trained `VAE` and set `start_gan_train_step=-1` to start training of GAN immediately.
|
54 |
+
|
55 |
+
2.1 Evaluate Autoencoder
|
56 |
+
----------------
|
57 |
+
* Use [scripts/evaluate_latent_embedder.py](scripts/evaluate_latent_embedder.py) to evaluate the performance of the Autoencoder.
|
58 |
+
|
59 |
+
3 Train Diffusion
|
60 |
+
----------------
|
61 |
+
* Go to [scripts/train_diffusion.py](scripts/train_diffusion.py) and import/load your Dataset as before.
|
62 |
+
* Load your pre-trained VAE or VAEGAN with `latent_embedder_checkpoint=...`
|
63 |
+
* Use `cond_embedder = LabelEmbedder` for conditional training, otherwise `cond_embedder = None`
|
64 |
+
|
65 |
+
3.1 Evaluate Diffusion
|
66 |
+
----------------
|
67 |
+
* Go to [scripts/sample.py](scripts/sample.py) to sample a test image.
|
68 |
+
* Go to [scripts/helpers/sample_dataset.py](scripts/helpers/sample_dataset.py) to sample a more reprensative sample size.
|
69 |
+
* Use [scripts/evaluate_images.py](scripts/evaluate_images.py) to evaluate performance of sample (FID, Precision, Recall)
|
70 |
+
|
71 |
+
Acknowledgment
|
72 |
+
=============
|
73 |
+
* Code builds upon https://github.com/lucidrains/denoising-diffusion-pytorch
|
medical_diffusion/data/augmentation/__init__.py
ADDED
File without changes
|
medical_diffusion/data/augmentation/augmentations_2d.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
class ToTensor16bit(object):
|
6 |
+
"""PyTorch can not handle uint16 only int16. First transform to int32. Note, this function also adds a channel-dim"""
|
7 |
+
def __call__(self, image):
|
8 |
+
# return torch.as_tensor(np.array(image, dtype=np.int32)[None])
|
9 |
+
# return torch.from_numpy(np.array(image, np.int32, copy=True)[None])
|
10 |
+
image = np.array(image, np.int32, copy=True) # [H,W,C] or [H,W]
|
11 |
+
image = np.expand_dims(image, axis=-1) if image.ndim ==2 else image
|
12 |
+
return torch.from_numpy(np.moveaxis(image, -1, 0)) #[C, H, W]
|
13 |
+
|
14 |
+
class Normalize(object):
|
15 |
+
"""Rescale the image to [0,1] and ensure float32 dtype """
|
16 |
+
|
17 |
+
def __call__(self, image):
|
18 |
+
image = image.type(torch.FloatTensor)
|
19 |
+
return (image-image.min())/(image.max()-image.min())
|
20 |
+
|
21 |
+
|
22 |
+
class RandomBackground(object):
|
23 |
+
"""Fill Background (intensity ==0) with random values"""
|
24 |
+
|
25 |
+
def __call__(self, image):
|
26 |
+
image[image==0] = torch.rand(*image[image==0].shape) #(image.max()-image.min())
|
27 |
+
return image
|
medical_diffusion/data/augmentation/augmentations_3d.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torchio as tio
|
2 |
+
from typing import Union, Optional, Sequence
|
3 |
+
from torchio.typing import TypeTripletInt
|
4 |
+
from torchio import Subject, Image
|
5 |
+
from torchio.utils import to_tuple
|
6 |
+
|
7 |
+
class CropOrPad_None(tio.CropOrPad):
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
target_shape: Union[int, TypeTripletInt, None] = None,
|
11 |
+
padding_mode: Union[str, float] = 0,
|
12 |
+
mask_name: Optional[str] = None,
|
13 |
+
labels: Optional[Sequence[int]] = None,
|
14 |
+
**kwargs
|
15 |
+
):
|
16 |
+
|
17 |
+
# WARNING: Ugly workaround to allow None values
|
18 |
+
if target_shape is not None:
|
19 |
+
self.original_target_shape = to_tuple(target_shape, length=3)
|
20 |
+
target_shape = [1 if t_s is None else t_s for t_s in target_shape]
|
21 |
+
super().__init__(target_shape, padding_mode, mask_name, labels, **kwargs)
|
22 |
+
|
23 |
+
def apply_transform(self, subject: Subject):
|
24 |
+
# WARNING: This makes the transformation subject dependent - reverse transformation must be adapted
|
25 |
+
if self.target_shape is not None:
|
26 |
+
self.target_shape = [s_s if t_s is None else t_s for t_s, s_s in zip(self.original_target_shape, subject.spatial_shape)]
|
27 |
+
return super().apply_transform(subject=subject)
|
28 |
+
|
29 |
+
|
30 |
+
class SubjectToTensor(object):
|
31 |
+
"""Transforms TorchIO Subjects into a Python dict and changes axes order from TorchIO to Torch"""
|
32 |
+
def __call__(self, subject: Subject):
|
33 |
+
return {key: val.data.swapaxes(1,-1) if isinstance(val, Image) else val for key,val in subject.items()}
|
34 |
+
|
35 |
+
class ImageToTensor(object):
|
36 |
+
"""Transforms TorchIO Image into a Numpy/Torch Tensor and changes axes order from TorchIO [B, C, W, H, D] to Torch [B, C, D, H, W]"""
|
37 |
+
def __call__(self, image: Image):
|
38 |
+
return image.data.swapaxes(1,-1)
|
medical_diffusion/data/datamodules/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .datamodule_simple import SimpleDataModule
|
medical_diffusion/data/datamodules/datamodule_simple.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import pytorch_lightning as pl
|
3 |
+
import torch
|
4 |
+
from torch.utils.data.dataloader import DataLoader
|
5 |
+
import torch.multiprocessing as mp
|
6 |
+
from torch.utils.data.sampler import WeightedRandomSampler, RandomSampler
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
class SimpleDataModule(pl.LightningDataModule):
|
11 |
+
|
12 |
+
def __init__(self,
|
13 |
+
ds_train: object,
|
14 |
+
ds_val:object =None,
|
15 |
+
ds_test:object =None,
|
16 |
+
batch_size: int = 1,
|
17 |
+
num_workers: int = mp.cpu_count(),
|
18 |
+
seed: int = 0,
|
19 |
+
pin_memory: bool = False,
|
20 |
+
weights: list = None
|
21 |
+
):
|
22 |
+
super().__init__()
|
23 |
+
self.hyperparameters = {**locals()}
|
24 |
+
self.hyperparameters.pop('__class__')
|
25 |
+
self.hyperparameters.pop('self')
|
26 |
+
|
27 |
+
self.ds_train = ds_train
|
28 |
+
self.ds_val = ds_val
|
29 |
+
self.ds_test = ds_test
|
30 |
+
|
31 |
+
self.batch_size = batch_size
|
32 |
+
self.num_workers = num_workers
|
33 |
+
self.seed = seed
|
34 |
+
self.pin_memory = pin_memory
|
35 |
+
self.weights = weights
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
def train_dataloader(self):
|
40 |
+
generator = torch.Generator()
|
41 |
+
generator.manual_seed(self.seed)
|
42 |
+
|
43 |
+
if self.weights is not None:
|
44 |
+
sampler = WeightedRandomSampler(self.weights, len(self.weights), generator=generator)
|
45 |
+
else:
|
46 |
+
sampler = RandomSampler(self.ds_train, replacement=False, generator=generator)
|
47 |
+
return DataLoader(self.ds_train, batch_size=self.batch_size, num_workers=self.num_workers,
|
48 |
+
sampler=sampler, generator=generator, drop_last=True, pin_memory=self.pin_memory)
|
49 |
+
|
50 |
+
|
51 |
+
def val_dataloader(self):
|
52 |
+
generator = torch.Generator()
|
53 |
+
generator.manual_seed(self.seed)
|
54 |
+
if self.ds_val is not None:
|
55 |
+
return DataLoader(self.ds_val, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False,
|
56 |
+
generator=generator, drop_last=False, pin_memory=self.pin_memory)
|
57 |
+
else:
|
58 |
+
raise AssertionError("A validation set was not initialized.")
|
59 |
+
|
60 |
+
|
61 |
+
def test_dataloader(self):
|
62 |
+
generator = torch.Generator()
|
63 |
+
generator.manual_seed(self.seed)
|
64 |
+
if self.ds_test is not None:
|
65 |
+
return DataLoader(self.ds_test, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False,
|
66 |
+
generator = generator, drop_last=False, pin_memory=self.pin_memory)
|
67 |
+
else:
|
68 |
+
raise AssertionError("A test test set was not initialized.")
|
69 |
+
|
70 |
+
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
|
79 |
+
|
medical_diffusion/data/datasets/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .dataset_simple_2d import *
|
2 |
+
from .dataset_simple_3d import *
|
medical_diffusion/data/datasets/dataset_simple_2d.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch.utils.data as data
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from pathlib import Path
|
6 |
+
from torchvision import transforms as T
|
7 |
+
import pandas as pd
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
from medical_diffusion.data.augmentation.augmentations_2d import Normalize, ToTensor16bit
|
12 |
+
|
13 |
+
class SimpleDataset2D(data.Dataset):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
path_root,
|
17 |
+
item_pointers =[],
|
18 |
+
crawler_ext = 'tif', # other options are ['jpg', 'jpeg', 'png', 'tiff'],
|
19 |
+
transform = None,
|
20 |
+
image_resize = None,
|
21 |
+
augment_horizontal_flip = False,
|
22 |
+
augment_vertical_flip = False,
|
23 |
+
image_crop = None,
|
24 |
+
):
|
25 |
+
super().__init__()
|
26 |
+
self.path_root = Path(path_root)
|
27 |
+
self.crawler_ext = crawler_ext
|
28 |
+
if len(item_pointers):
|
29 |
+
self.item_pointers = item_pointers
|
30 |
+
else:
|
31 |
+
self.item_pointers = self.run_item_crawler(self.path_root, self.crawler_ext)
|
32 |
+
|
33 |
+
if transform is None:
|
34 |
+
self.transform = T.Compose([
|
35 |
+
T.Resize(image_resize) if image_resize is not None else nn.Identity(),
|
36 |
+
T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
|
37 |
+
T.RandomVerticalFlip() if augment_vertical_flip else nn.Identity(),
|
38 |
+
T.CenterCrop(image_crop) if image_crop is not None else nn.Identity(),
|
39 |
+
T.ToTensor(),
|
40 |
+
# T.Lambda(lambda x: torch.cat([x]*3) if x.shape[0]==1 else x),
|
41 |
+
# ToTensor16bit(),
|
42 |
+
# Normalize(), # [0, 1.0]
|
43 |
+
# T.ConvertImageDtype(torch.float),
|
44 |
+
T.Normalize(mean=0.5, std=0.5) # WARNING: mean and std are not the target values but rather the values to subtract and divide by: [0, 1] -> [0-0.5, 1-0.5]/0.5 -> [-1, 1]
|
45 |
+
])
|
46 |
+
else:
|
47 |
+
self.transform = transform
|
48 |
+
|
49 |
+
def __len__(self):
|
50 |
+
return len(self.item_pointers)
|
51 |
+
|
52 |
+
def __getitem__(self, index):
|
53 |
+
rel_path_item = self.item_pointers[index]
|
54 |
+
path_item = self.path_root/rel_path_item
|
55 |
+
# img = Image.open(path_item)
|
56 |
+
img = self.load_item(path_item)
|
57 |
+
return {'uid':rel_path_item.stem, 'source': self.transform(img)}
|
58 |
+
|
59 |
+
def load_item(self, path_item):
|
60 |
+
return Image.open(path_item).convert('RGB')
|
61 |
+
# return cv2.imread(str(path_item), cv2.IMREAD_UNCHANGED) # NOTE: Only CV2 supports 16bit RGB images
|
62 |
+
|
63 |
+
@classmethod
|
64 |
+
def run_item_crawler(cls, path_root, extension, **kwargs):
|
65 |
+
return [path.relative_to(path_root) for path in Path(path_root).rglob(f'*.{extension}')]
|
66 |
+
|
67 |
+
def get_weights(self):
|
68 |
+
"""Return list of class-weights for WeightedSampling"""
|
69 |
+
return None
|
70 |
+
|
71 |
+
|
72 |
+
class AIROGSDataset(SimpleDataset2D):
|
73 |
+
def __init__(self, *args, **kwargs):
|
74 |
+
super().__init__(*args, **kwargs)
|
75 |
+
self.labels = pd.read_csv(self.path_root.parent/'train_labels.csv', index_col='challenge_id')
|
76 |
+
|
77 |
+
def __len__(self):
|
78 |
+
return len(self.labels)
|
79 |
+
|
80 |
+
def __getitem__(self, index):
|
81 |
+
uid = self.labels.index[index]
|
82 |
+
path_item = self.path_root/f'{uid}.jpg'
|
83 |
+
img = self.load_item(path_item)
|
84 |
+
str_2_int = {'NRG':0, 'RG':1} # RG = 3270, NRG = 98172
|
85 |
+
target = str_2_int[self.labels.loc[uid, 'class']]
|
86 |
+
# return {'uid':uid, 'source': self.transform(img), 'target':target}
|
87 |
+
return {'source': self.transform(img), 'target':target}
|
88 |
+
|
89 |
+
def get_weights(self):
|
90 |
+
n_samples = len(self)
|
91 |
+
weight_per_class = 1/self.labels['class'].value_counts(normalize=True) # {'NRG': 1.03, 'RG': 31.02}
|
92 |
+
weights = [0] * n_samples
|
93 |
+
for index in range(n_samples):
|
94 |
+
target = self.labels.iloc[index]['class']
|
95 |
+
weights[index] = weight_per_class[target]
|
96 |
+
return weights
|
97 |
+
|
98 |
+
@classmethod
|
99 |
+
def run_item_crawler(cls, path_root, extension, **kwargs):
|
100 |
+
"""Overwrite to speed up as paths are determined by .csv file anyway"""
|
101 |
+
return []
|
102 |
+
|
103 |
+
class MSIvsMSS_Dataset(SimpleDataset2D):
|
104 |
+
# https://doi.org/10.5281/zenodo.2530835
|
105 |
+
def __getitem__(self, index):
|
106 |
+
rel_path_item = self.item_pointers[index]
|
107 |
+
path_item = self.path_root/rel_path_item
|
108 |
+
img = self.load_item(path_item)
|
109 |
+
uid = rel_path_item.stem
|
110 |
+
str_2_int = {'MSIMUT':0, 'MSS':1}
|
111 |
+
target = str_2_int[path_item.parent.name] #
|
112 |
+
return {'uid':uid, 'source': self.transform(img), 'target':target}
|
113 |
+
|
114 |
+
|
115 |
+
class MSIvsMSS_2_Dataset(SimpleDataset2D):
|
116 |
+
# https://doi.org/10.5281/zenodo.3832231
|
117 |
+
def __getitem__(self, index):
|
118 |
+
rel_path_item = self.item_pointers[index]
|
119 |
+
path_item = self.path_root/rel_path_item
|
120 |
+
img = self.load_item(path_item)
|
121 |
+
uid = rel_path_item.stem
|
122 |
+
str_2_int = {'MSIH':0, 'nonMSIH':1} # patients with MSI-H = MSIH; patients with MSI-L and MSS = NonMSIH)
|
123 |
+
target = str_2_int[path_item.parent.name]
|
124 |
+
# return {'uid':uid, 'source': self.transform(img), 'target':target}
|
125 |
+
return {'source': self.transform(img), 'target':target}
|
126 |
+
|
127 |
+
|
128 |
+
class CheXpert_Dataset(SimpleDataset2D):
|
129 |
+
def __init__(self, *args, **kwargs):
|
130 |
+
super().__init__(*args, **kwargs)
|
131 |
+
mode = self.path_root.name
|
132 |
+
labels = pd.read_csv(self.path_root.parent/f'{mode}.csv', index_col='Path')
|
133 |
+
self.labels = labels.loc[labels['Frontal/Lateral'] == 'Frontal'].copy()
|
134 |
+
self.labels.index = self.labels.index.str[20:]
|
135 |
+
self.labels.loc[self.labels['Sex'] == 'Unknown', 'Sex'] = 'Female' # Affects 1 case, must be "female" to match stats in publication
|
136 |
+
self.labels.fillna(2, inplace=True) # TODO: Find better solution,
|
137 |
+
str_2_int = {'Sex': {'Male':0, 'Female':1}, 'Frontal/Lateral':{'Frontal':0, 'Lateral':1}, 'AP/PA':{'AP':0, 'PA':1}}
|
138 |
+
self.labels.replace(str_2_int, inplace=True)
|
139 |
+
|
140 |
+
def __len__(self):
|
141 |
+
return len(self.labels)
|
142 |
+
|
143 |
+
def __getitem__(self, index):
|
144 |
+
rel_path_item = self.labels.index[index]
|
145 |
+
path_item = self.path_root/rel_path_item
|
146 |
+
img = self.load_item(path_item)
|
147 |
+
uid = str(rel_path_item)
|
148 |
+
target = torch.tensor(self.labels.loc[uid, 'Cardiomegaly']+1, dtype=torch.long) # Note Labels are -1=uncertain, 0=negative, 1=positive, NA=not reported -> Map to [0, 2], NA=3
|
149 |
+
return {'uid':uid, 'source': self.transform(img), 'target':target}
|
150 |
+
|
151 |
+
|
152 |
+
@classmethod
|
153 |
+
def run_item_crawler(cls, path_root, extension, **kwargs):
|
154 |
+
"""Overwrite to speed up as paths are determined by .csv file anyway"""
|
155 |
+
return []
|
156 |
+
|
157 |
+
class CheXpert_2_Dataset(SimpleDataset2D):
|
158 |
+
def __init__(self, *args, **kwargs):
|
159 |
+
super().__init__(*args, **kwargs)
|
160 |
+
labels = pd.read_csv(self.path_root/'labels/cheXPert_label.csv', index_col=['Path', 'Image Index']) # Note: 1 and -1 (uncertain) cases count as positives (1), 0 and NA count as negatives (0)
|
161 |
+
labels = labels.loc[labels['fold']=='train'].copy()
|
162 |
+
labels = labels.drop(labels='fold', axis=1)
|
163 |
+
|
164 |
+
labels2 = pd.read_csv(self.path_root/'labels/train.csv', index_col='Path')
|
165 |
+
labels2 = labels2.loc[labels2['Frontal/Lateral'] == 'Frontal'].copy()
|
166 |
+
labels2 = labels2[['Cardiomegaly',]].copy()
|
167 |
+
labels2[ (labels2 <0) | labels2.isna()] = 2 # 0 = Negative, 1 = Positive, 2 = Uncertain
|
168 |
+
labels = labels.join(labels2['Cardiomegaly'], on=["Path",], rsuffix='_true')
|
169 |
+
# labels = labels[labels['Cardiomegaly_true']!=2]
|
170 |
+
|
171 |
+
self.labels = labels
|
172 |
+
|
173 |
+
def __len__(self):
|
174 |
+
return len(self.labels)
|
175 |
+
|
176 |
+
def __getitem__(self, index):
|
177 |
+
path_index, image_index = self.labels.index[index]
|
178 |
+
path_item = self.path_root/'data'/f'{image_index:06}.png'
|
179 |
+
img = self.load_item(path_item)
|
180 |
+
uid = image_index
|
181 |
+
target = int(self.labels.loc[(path_index, image_index), 'Cardiomegaly'])
|
182 |
+
# return {'uid':uid, 'source': self.transform(img), 'target':target}
|
183 |
+
return {'source': self.transform(img), 'target':target}
|
184 |
+
|
185 |
+
@classmethod
|
186 |
+
def run_item_crawler(cls, path_root, extension, **kwargs):
|
187 |
+
"""Overwrite to speed up as paths are determined by .csv file anyway"""
|
188 |
+
return []
|
189 |
+
|
190 |
+
def get_weights(self):
|
191 |
+
n_samples = len(self)
|
192 |
+
weight_per_class = 1/self.labels['Cardiomegaly'].value_counts(normalize=True)
|
193 |
+
# weight_per_class = {2.0: 1.2, 1.0: 8.2, 0.0: 24.3}
|
194 |
+
weights = [0] * n_samples
|
195 |
+
for index in range(n_samples):
|
196 |
+
target = self.labels.loc[self.labels.index[index], 'Cardiomegaly']
|
197 |
+
weights[index] = weight_per_class[target]
|
198 |
+
return weights
|
medical_diffusion/data/datasets/dataset_simple_3d.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch.utils.data as data
|
3 |
+
from pathlib import Path
|
4 |
+
from torchvision import transforms as T
|
5 |
+
|
6 |
+
|
7 |
+
import torchio as tio
|
8 |
+
|
9 |
+
from medical_diffusion.data.augmentation.augmentations_3d import ImageToTensor
|
10 |
+
|
11 |
+
|
12 |
+
class SimpleDataset3D(data.Dataset):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
path_root,
|
16 |
+
item_pointers =[],
|
17 |
+
crawler_ext = ['nii'], # other options are ['nii.gz'],
|
18 |
+
transform = None,
|
19 |
+
image_resize = None,
|
20 |
+
flip = False,
|
21 |
+
image_crop = None,
|
22 |
+
use_znorm=True, # Use z-Norm for MRI as scale is arbitrary, otherwise scale intensity to [-1, 1]
|
23 |
+
):
|
24 |
+
super().__init__()
|
25 |
+
self.path_root = path_root
|
26 |
+
self.crawler_ext = crawler_ext
|
27 |
+
|
28 |
+
if transform is None:
|
29 |
+
self.transform = T.Compose([
|
30 |
+
tio.Resize(image_resize) if image_resize is not None else tio.Lambda(lambda x: x),
|
31 |
+
tio.RandomFlip((0,1,2)) if flip else tio.Lambda(lambda x: x),
|
32 |
+
tio.CropOrPad(image_crop) if image_crop is not None else tio.Lambda(lambda x: x),
|
33 |
+
tio.ZNormalization() if use_znorm else tio.RescaleIntensity((-1,1)),
|
34 |
+
ImageToTensor() # [C, W, H, D] -> [C, D, H, W]
|
35 |
+
])
|
36 |
+
else:
|
37 |
+
self.transform = transform
|
38 |
+
|
39 |
+
if len(item_pointers):
|
40 |
+
self.item_pointers = item_pointers
|
41 |
+
else:
|
42 |
+
self.item_pointers = self.run_item_crawler(self.path_root, self.crawler_ext)
|
43 |
+
|
44 |
+
def __len__(self):
|
45 |
+
return len(self.item_pointers)
|
46 |
+
|
47 |
+
def __getitem__(self, index):
|
48 |
+
rel_path_item = self.item_pointers[index]
|
49 |
+
path_item = self.path_root/rel_path_item
|
50 |
+
img = self.load_item(path_item)
|
51 |
+
return {'uid':rel_path_item.stem, 'source': self.transform(img)}
|
52 |
+
|
53 |
+
def load_item(self, path_item):
|
54 |
+
return tio.ScalarImage(path_item) # Consider to use this or tio.ScalarLabel over SimpleITK (sitk.ReadImage(str(path_item)))
|
55 |
+
|
56 |
+
@classmethod
|
57 |
+
def run_item_crawler(cls, path_root, extension, **kwargs):
|
58 |
+
return [path.relative_to(path_root) for path in Path(path_root).rglob(f'*.{extension}')]
|
medical_diffusion/external/diffusers/attention.py
ADDED
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
|
9 |
+
class AttentionBlock(nn.Module):
|
10 |
+
"""
|
11 |
+
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
|
12 |
+
to the N-d case.
|
13 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
14 |
+
Uses three q, k, v linear layers to compute attention.
|
15 |
+
|
16 |
+
Parameters:
|
17 |
+
channels (:obj:`int`): The number of channels in the input and output.
|
18 |
+
num_head_channels (:obj:`int`, *optional*):
|
19 |
+
The number of channels in each head. If None, then `num_heads` = 1.
|
20 |
+
num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
|
21 |
+
rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
|
22 |
+
eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
channels: int,
|
28 |
+
num_head_channels: Optional[int] = None,
|
29 |
+
num_groups: int = 32,
|
30 |
+
rescale_output_factor: float = 1.0,
|
31 |
+
eps: float = 1e-5,
|
32 |
+
):
|
33 |
+
super().__init__()
|
34 |
+
self.channels = channels
|
35 |
+
|
36 |
+
self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
|
37 |
+
self.num_head_size = num_head_channels
|
38 |
+
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
|
39 |
+
|
40 |
+
# define q,k,v as linear layers
|
41 |
+
self.query = nn.Linear(channels, channels)
|
42 |
+
self.key = nn.Linear(channels, channels)
|
43 |
+
self.value = nn.Linear(channels, channels)
|
44 |
+
|
45 |
+
self.rescale_output_factor = rescale_output_factor
|
46 |
+
self.proj_attn = nn.Linear(channels, channels, 1)
|
47 |
+
|
48 |
+
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
|
49 |
+
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
|
50 |
+
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
|
51 |
+
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
|
52 |
+
return new_projection
|
53 |
+
|
54 |
+
def forward(self, hidden_states):
|
55 |
+
residual = hidden_states
|
56 |
+
batch, channel, height, width = hidden_states.shape
|
57 |
+
|
58 |
+
# norm
|
59 |
+
hidden_states = self.group_norm(hidden_states)
|
60 |
+
|
61 |
+
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
|
62 |
+
|
63 |
+
# proj to q, k, v
|
64 |
+
query_proj = self.query(hidden_states)
|
65 |
+
key_proj = self.key(hidden_states)
|
66 |
+
value_proj = self.value(hidden_states)
|
67 |
+
|
68 |
+
# transpose
|
69 |
+
query_states = self.transpose_for_scores(query_proj)
|
70 |
+
key_states = self.transpose_for_scores(key_proj)
|
71 |
+
value_states = self.transpose_for_scores(value_proj)
|
72 |
+
|
73 |
+
# get scores
|
74 |
+
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
|
75 |
+
|
76 |
+
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
|
77 |
+
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
|
78 |
+
|
79 |
+
# compute attention output
|
80 |
+
hidden_states = torch.matmul(attention_probs, value_states)
|
81 |
+
|
82 |
+
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
|
83 |
+
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
|
84 |
+
hidden_states = hidden_states.view(new_hidden_states_shape)
|
85 |
+
|
86 |
+
# compute next hidden_states
|
87 |
+
hidden_states = self.proj_attn(hidden_states)
|
88 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
|
89 |
+
|
90 |
+
# res connect and rescale
|
91 |
+
hidden_states = (hidden_states + residual) / self.rescale_output_factor
|
92 |
+
return hidden_states
|
93 |
+
|
94 |
+
|
95 |
+
class SpatialTransformer(nn.Module):
|
96 |
+
"""
|
97 |
+
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
|
98 |
+
standard transformer action. Finally, reshape to image.
|
99 |
+
|
100 |
+
Parameters:
|
101 |
+
in_channels (:obj:`int`): The number of channels in the input and output.
|
102 |
+
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
|
103 |
+
d_head (:obj:`int`): The number of channels in each head.
|
104 |
+
depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
105 |
+
dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
106 |
+
context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
|
107 |
+
"""
|
108 |
+
|
109 |
+
def __init__(
|
110 |
+
self,
|
111 |
+
in_channels: int,
|
112 |
+
n_heads: int,
|
113 |
+
d_head: int,
|
114 |
+
depth: int = 1,
|
115 |
+
dropout: float = 0.0,
|
116 |
+
num_groups: int = 32,
|
117 |
+
context_dim: Optional[int] = None,
|
118 |
+
):
|
119 |
+
super().__init__()
|
120 |
+
self.n_heads = n_heads
|
121 |
+
self.d_head = d_head
|
122 |
+
self.in_channels = in_channels
|
123 |
+
inner_dim = n_heads * d_head
|
124 |
+
self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
125 |
+
|
126 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
127 |
+
|
128 |
+
self.transformer_blocks = nn.ModuleList(
|
129 |
+
[
|
130 |
+
BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
|
131 |
+
for d in range(depth)
|
132 |
+
]
|
133 |
+
)
|
134 |
+
|
135 |
+
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
136 |
+
|
137 |
+
def _set_attention_slice(self, slice_size):
|
138 |
+
for block in self.transformer_blocks:
|
139 |
+
block._set_attention_slice(slice_size)
|
140 |
+
|
141 |
+
def forward(self, hidden_states, context=None):
|
142 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
143 |
+
batch, channel, height, weight = hidden_states.shape
|
144 |
+
residual = hidden_states
|
145 |
+
hidden_states = self.norm(hidden_states)
|
146 |
+
hidden_states = self.proj_in(hidden_states)
|
147 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
|
148 |
+
for block in self.transformer_blocks:
|
149 |
+
hidden_states = block(hidden_states, context=context)
|
150 |
+
hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
|
151 |
+
hidden_states = self.proj_out(hidden_states)
|
152 |
+
return hidden_states + residual
|
153 |
+
|
154 |
+
|
155 |
+
class BasicTransformerBlock(nn.Module):
|
156 |
+
r"""
|
157 |
+
A basic Transformer block.
|
158 |
+
|
159 |
+
Parameters:
|
160 |
+
dim (:obj:`int`): The number of channels in the input and output.
|
161 |
+
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
|
162 |
+
d_head (:obj:`int`): The number of channels in each head.
|
163 |
+
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
164 |
+
context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
|
165 |
+
gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network.
|
166 |
+
checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing.
|
167 |
+
"""
|
168 |
+
|
169 |
+
def __init__(
|
170 |
+
self,
|
171 |
+
dim: int,
|
172 |
+
n_heads: int,
|
173 |
+
d_head: int,
|
174 |
+
dropout=0.0,
|
175 |
+
context_dim: Optional[int] = None,
|
176 |
+
gated_ff: bool = True,
|
177 |
+
checkpoint: bool = True,
|
178 |
+
):
|
179 |
+
super().__init__()
|
180 |
+
self.attn1 = CrossAttention(
|
181 |
+
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
182 |
+
) # is a self-attention
|
183 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
184 |
+
self.attn2 = CrossAttention(
|
185 |
+
query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
186 |
+
) # is self-attn if context is none
|
187 |
+
self.norm1 = nn.LayerNorm(dim)
|
188 |
+
self.norm2 = nn.LayerNorm(dim)
|
189 |
+
self.norm3 = nn.LayerNorm(dim)
|
190 |
+
self.checkpoint = checkpoint
|
191 |
+
|
192 |
+
def _set_attention_slice(self, slice_size):
|
193 |
+
self.attn1._slice_size = slice_size
|
194 |
+
self.attn2._slice_size = slice_size
|
195 |
+
|
196 |
+
def forward(self, hidden_states, context=None):
|
197 |
+
hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
|
198 |
+
hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
|
199 |
+
hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
|
200 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
201 |
+
return hidden_states
|
202 |
+
|
203 |
+
|
204 |
+
class CrossAttention(nn.Module):
|
205 |
+
r"""
|
206 |
+
A cross attention layer.
|
207 |
+
|
208 |
+
Parameters:
|
209 |
+
query_dim (:obj:`int`): The number of channels in the query.
|
210 |
+
context_dim (:obj:`int`, *optional*):
|
211 |
+
The number of channels in the context. If not given, defaults to `query_dim`.
|
212 |
+
heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
213 |
+
dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
|
214 |
+
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
215 |
+
"""
|
216 |
+
|
217 |
+
def __init__(
|
218 |
+
self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0
|
219 |
+
):
|
220 |
+
super().__init__()
|
221 |
+
inner_dim = dim_head * heads
|
222 |
+
context_dim = context_dim if context_dim is not None else query_dim
|
223 |
+
|
224 |
+
self.scale = dim_head**-0.5
|
225 |
+
self.heads = heads
|
226 |
+
# for slice_size > 0 the attention score computation
|
227 |
+
# is split across the batch axis to save memory
|
228 |
+
# You can set slice_size with `set_attention_slice`
|
229 |
+
self._slice_size = None
|
230 |
+
|
231 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
232 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
233 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
234 |
+
|
235 |
+
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
236 |
+
|
237 |
+
def reshape_heads_to_batch_dim(self, tensor):
|
238 |
+
batch_size, seq_len, dim = tensor.shape
|
239 |
+
head_size = self.heads
|
240 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
241 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
242 |
+
return tensor
|
243 |
+
|
244 |
+
def reshape_batch_dim_to_heads(self, tensor):
|
245 |
+
batch_size, seq_len, dim = tensor.shape
|
246 |
+
head_size = self.heads
|
247 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
248 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
249 |
+
return tensor
|
250 |
+
|
251 |
+
def forward(self, hidden_states, context=None, mask=None):
|
252 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
253 |
+
|
254 |
+
query = self.to_q(hidden_states)
|
255 |
+
context = context if context is not None else hidden_states
|
256 |
+
key = self.to_k(context)
|
257 |
+
value = self.to_v(context)
|
258 |
+
|
259 |
+
dim = query.shape[-1]
|
260 |
+
|
261 |
+
query = self.reshape_heads_to_batch_dim(query)
|
262 |
+
key = self.reshape_heads_to_batch_dim(key)
|
263 |
+
value = self.reshape_heads_to_batch_dim(value)
|
264 |
+
|
265 |
+
# TODO(PVP) - mask is currently never used. Remember to re-implement when used
|
266 |
+
|
267 |
+
# attention, what we cannot get enough of
|
268 |
+
|
269 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
270 |
+
hidden_states = self._attention(query, key, value)
|
271 |
+
else:
|
272 |
+
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
|
273 |
+
|
274 |
+
return self.to_out(hidden_states)
|
275 |
+
|
276 |
+
def _attention(self, query, key, value):
|
277 |
+
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
|
278 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
279 |
+
# compute attention output
|
280 |
+
hidden_states = torch.matmul(attention_probs, value)
|
281 |
+
# reshape hidden_states
|
282 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
283 |
+
return hidden_states
|
284 |
+
|
285 |
+
def _sliced_attention(self, query, key, value, sequence_length, dim):
|
286 |
+
batch_size_attention = query.shape[0]
|
287 |
+
hidden_states = torch.zeros(
|
288 |
+
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
|
289 |
+
)
|
290 |
+
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
|
291 |
+
for i in range(hidden_states.shape[0] // slice_size):
|
292 |
+
start_idx = i * slice_size
|
293 |
+
end_idx = (i + 1) * slice_size
|
294 |
+
attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
|
295 |
+
attn_slice = attn_slice.softmax(dim=-1)
|
296 |
+
attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
|
297 |
+
|
298 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
299 |
+
|
300 |
+
# reshape hidden_states
|
301 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
302 |
+
return hidden_states
|
303 |
+
|
304 |
+
|
305 |
+
class FeedForward(nn.Module):
|
306 |
+
r"""
|
307 |
+
A feed-forward layer.
|
308 |
+
|
309 |
+
Parameters:
|
310 |
+
dim (:obj:`int`): The number of channels in the input.
|
311 |
+
dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
312 |
+
mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
313 |
+
glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
|
314 |
+
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
315 |
+
"""
|
316 |
+
|
317 |
+
def __init__(
|
318 |
+
self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
|
319 |
+
):
|
320 |
+
super().__init__()
|
321 |
+
inner_dim = int(dim * mult)
|
322 |
+
dim_out = dim_out if dim_out is not None else dim
|
323 |
+
project_in = GEGLU(dim, inner_dim)
|
324 |
+
|
325 |
+
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
|
326 |
+
|
327 |
+
def forward(self, hidden_states):
|
328 |
+
return self.net(hidden_states)
|
329 |
+
|
330 |
+
|
331 |
+
# feedforward
|
332 |
+
class GEGLU(nn.Module):
|
333 |
+
r"""
|
334 |
+
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
335 |
+
|
336 |
+
Parameters:
|
337 |
+
dim_in (:obj:`int`): The number of channels in the input.
|
338 |
+
dim_out (:obj:`int`): The number of channels in the output.
|
339 |
+
"""
|
340 |
+
|
341 |
+
def __init__(self, dim_in: int, dim_out: int):
|
342 |
+
super().__init__()
|
343 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
344 |
+
|
345 |
+
def forward(self, hidden_states):
|
346 |
+
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
347 |
+
return hidden_states * F.gelu(gate)
|
medical_diffusion/external/diffusers/embeddings.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from pydoc import describe
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
|
9 |
+
def get_timestep_embedding(
|
10 |
+
timesteps: torch.Tensor,
|
11 |
+
embedding_dim: int,
|
12 |
+
flip_sin_to_cos: bool = False,
|
13 |
+
downscale_freq_shift: float = 1,
|
14 |
+
scale: float = 1,
|
15 |
+
max_period: int = 10000,
|
16 |
+
):
|
17 |
+
"""
|
18 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
19 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
20 |
+
These may be fractional.
|
21 |
+
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
22 |
+
embeddings. :return: an [N x dim] Tensor of positional embeddings.
|
23 |
+
"""
|
24 |
+
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
25 |
+
|
26 |
+
half_dim = embedding_dim // 2
|
27 |
+
exponent = -math.log(max_period) * torch.arange(
|
28 |
+
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
29 |
+
)
|
30 |
+
exponent = exponent / (half_dim - downscale_freq_shift)
|
31 |
+
|
32 |
+
emb = torch.exp(exponent)
|
33 |
+
emb = timesteps[:, None].float() * emb[None, :]
|
34 |
+
|
35 |
+
# scale embeddings
|
36 |
+
emb = scale * emb
|
37 |
+
|
38 |
+
# concat sine and cosine embeddings
|
39 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
40 |
+
|
41 |
+
# flip sine and cosine embeddings
|
42 |
+
if flip_sin_to_cos:
|
43 |
+
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
44 |
+
|
45 |
+
# zero pad
|
46 |
+
if embedding_dim % 2 == 1:
|
47 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
48 |
+
return emb
|
49 |
+
|
50 |
+
class Timesteps(nn.Module):
|
51 |
+
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
|
52 |
+
super().__init__()
|
53 |
+
self.num_channels = num_channels
|
54 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
55 |
+
self.downscale_freq_shift = downscale_freq_shift
|
56 |
+
|
57 |
+
def forward(self, timesteps):
|
58 |
+
t_emb = get_timestep_embedding(
|
59 |
+
timesteps,
|
60 |
+
self.num_channels,
|
61 |
+
flip_sin_to_cos=self.flip_sin_to_cos,
|
62 |
+
downscale_freq_shift=self.downscale_freq_shift,
|
63 |
+
)
|
64 |
+
return t_emb
|
65 |
+
|
66 |
+
class TimeEmbbeding(nn.Module):
|
67 |
+
def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
|
68 |
+
super().__init__()
|
69 |
+
|
70 |
+
self.temb = Timesteps(channel, flip_sin_to_cos=True, downscale_freq_shift=0)
|
71 |
+
|
72 |
+
self.linear_1 = nn.Linear(channel, time_embed_dim)
|
73 |
+
self.act = None
|
74 |
+
if act_fn == "silu":
|
75 |
+
self.act = nn.SiLU()
|
76 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
|
77 |
+
|
78 |
+
def forward(self, sample):
|
79 |
+
sample = self.temb(sample)
|
80 |
+
sample = self.linear_1(sample)
|
81 |
+
|
82 |
+
if self.act is not None:
|
83 |
+
sample = self.act(sample)
|
84 |
+
|
85 |
+
sample = self.linear_2(sample)
|
86 |
+
return sample
|
87 |
+
|
88 |
+
|
89 |
+
|
medical_diffusion/external/diffusers/resnet.py
ADDED
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class Upsample2D(nn.Module):
|
9 |
+
"""
|
10 |
+
An upsampling layer with an optional convolution.
|
11 |
+
|
12 |
+
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
|
13 |
+
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
14 |
+
upsampling occurs in the inner-two dimensions.
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
18 |
+
super().__init__()
|
19 |
+
self.channels = channels
|
20 |
+
self.out_channels = out_channels or channels
|
21 |
+
self.use_conv = use_conv
|
22 |
+
self.use_conv_transpose = use_conv_transpose
|
23 |
+
self.name = name
|
24 |
+
|
25 |
+
conv = None
|
26 |
+
if use_conv_transpose:
|
27 |
+
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
|
28 |
+
elif use_conv:
|
29 |
+
conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
|
30 |
+
|
31 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
32 |
+
if name == "conv":
|
33 |
+
self.conv = conv
|
34 |
+
else:
|
35 |
+
self.Conv2d_0 = conv
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
assert x.shape[1] == self.channels
|
39 |
+
if self.use_conv_transpose:
|
40 |
+
return self.conv(x)
|
41 |
+
|
42 |
+
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
43 |
+
|
44 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
45 |
+
if self.use_conv:
|
46 |
+
if self.name == "conv":
|
47 |
+
x = self.conv(x)
|
48 |
+
else:
|
49 |
+
x = self.Conv2d_0(x)
|
50 |
+
|
51 |
+
return x
|
52 |
+
|
53 |
+
|
54 |
+
class Downsample2D(nn.Module):
|
55 |
+
"""
|
56 |
+
A downsampling layer with an optional convolution.
|
57 |
+
|
58 |
+
:param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
|
59 |
+
applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
60 |
+
downsampling occurs in the inner-two dimensions.
|
61 |
+
"""
|
62 |
+
|
63 |
+
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
64 |
+
super().__init__()
|
65 |
+
self.channels = channels
|
66 |
+
self.out_channels = out_channels or channels
|
67 |
+
self.use_conv = use_conv
|
68 |
+
self.padding = padding
|
69 |
+
stride = 2
|
70 |
+
self.name = name
|
71 |
+
|
72 |
+
if use_conv:
|
73 |
+
conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
74 |
+
else:
|
75 |
+
assert self.channels == self.out_channels
|
76 |
+
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
|
77 |
+
|
78 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
79 |
+
if name == "conv":
|
80 |
+
self.Conv2d_0 = conv
|
81 |
+
self.conv = conv
|
82 |
+
elif name == "Conv2d_0":
|
83 |
+
self.conv = conv
|
84 |
+
else:
|
85 |
+
self.conv = conv
|
86 |
+
|
87 |
+
def forward(self, x):
|
88 |
+
assert x.shape[1] == self.channels
|
89 |
+
if self.use_conv and self.padding == 0:
|
90 |
+
pad = (0, 1, 0, 1)
|
91 |
+
x = F.pad(x, pad, mode="constant", value=0)
|
92 |
+
|
93 |
+
assert x.shape[1] == self.channels
|
94 |
+
x = self.conv(x)
|
95 |
+
|
96 |
+
return x
|
97 |
+
|
98 |
+
|
99 |
+
class FirUpsample2D(nn.Module):
|
100 |
+
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
|
101 |
+
super().__init__()
|
102 |
+
out_channels = out_channels if out_channels else channels
|
103 |
+
if use_conv:
|
104 |
+
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
105 |
+
self.use_conv = use_conv
|
106 |
+
self.fir_kernel = fir_kernel
|
107 |
+
self.out_channels = out_channels
|
108 |
+
|
109 |
+
def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
|
110 |
+
"""Fused `upsample_2d()` followed by `Conv2d()`.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
114 |
+
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
|
115 |
+
order.
|
116 |
+
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
117 |
+
C]`.
|
118 |
+
weight: Weight tensor of the shape `[filterH, filterW, inChannels,
|
119 |
+
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
|
120 |
+
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
|
121 |
+
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
|
122 |
+
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
|
126 |
+
`x`.
|
127 |
+
"""
|
128 |
+
|
129 |
+
assert isinstance(factor, int) and factor >= 1
|
130 |
+
|
131 |
+
# Setup filter kernel.
|
132 |
+
if kernel is None:
|
133 |
+
kernel = [1] * factor
|
134 |
+
|
135 |
+
# setup kernel
|
136 |
+
kernel = torch.tensor(kernel, dtype=torch.float32)
|
137 |
+
if kernel.ndim == 1:
|
138 |
+
kernel = torch.outer(kernel, kernel)
|
139 |
+
kernel /= torch.sum(kernel)
|
140 |
+
|
141 |
+
kernel = kernel * (gain * (factor**2))
|
142 |
+
|
143 |
+
if self.use_conv:
|
144 |
+
convH = weight.shape[2]
|
145 |
+
convW = weight.shape[3]
|
146 |
+
inC = weight.shape[1]
|
147 |
+
|
148 |
+
p = (kernel.shape[0] - factor) - (convW - 1)
|
149 |
+
|
150 |
+
stride = (factor, factor)
|
151 |
+
# Determine data dimensions.
|
152 |
+
output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
|
153 |
+
output_padding = (
|
154 |
+
output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
|
155 |
+
output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
|
156 |
+
)
|
157 |
+
assert output_padding[0] >= 0 and output_padding[1] >= 0
|
158 |
+
inC = weight.shape[1]
|
159 |
+
num_groups = x.shape[1] // inC
|
160 |
+
|
161 |
+
# Transpose weights.
|
162 |
+
weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
|
163 |
+
weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
|
164 |
+
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
|
165 |
+
|
166 |
+
x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)
|
167 |
+
|
168 |
+
x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
|
169 |
+
else:
|
170 |
+
p = kernel.shape[0] - factor
|
171 |
+
x = upfirdn2d_native(
|
172 |
+
x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
|
173 |
+
)
|
174 |
+
|
175 |
+
return x
|
176 |
+
|
177 |
+
def forward(self, x):
|
178 |
+
if self.use_conv:
|
179 |
+
height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel)
|
180 |
+
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
181 |
+
else:
|
182 |
+
height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2)
|
183 |
+
|
184 |
+
return height
|
185 |
+
|
186 |
+
|
187 |
+
class FirDownsample2D(nn.Module):
|
188 |
+
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
|
189 |
+
super().__init__()
|
190 |
+
out_channels = out_channels if out_channels else channels
|
191 |
+
if use_conv:
|
192 |
+
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
193 |
+
self.fir_kernel = fir_kernel
|
194 |
+
self.use_conv = use_conv
|
195 |
+
self.out_channels = out_channels
|
196 |
+
|
197 |
+
def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
|
198 |
+
"""Fused `Conv2d()` followed by `downsample_2d()`.
|
199 |
+
|
200 |
+
Args:
|
201 |
+
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
202 |
+
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
|
203 |
+
order.
|
204 |
+
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH,
|
205 |
+
filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] //
|
206 |
+
numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
|
207 |
+
factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain:
|
208 |
+
Scaling factor for signal magnitude (default: 1.0).
|
209 |
+
|
210 |
+
Returns:
|
211 |
+
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
|
212 |
+
datatype as `x`.
|
213 |
+
"""
|
214 |
+
|
215 |
+
assert isinstance(factor, int) and factor >= 1
|
216 |
+
if kernel is None:
|
217 |
+
kernel = [1] * factor
|
218 |
+
|
219 |
+
# setup kernel
|
220 |
+
kernel = torch.tensor(kernel, dtype=torch.float32)
|
221 |
+
if kernel.ndim == 1:
|
222 |
+
kernel = torch.outer(kernel, kernel)
|
223 |
+
kernel /= torch.sum(kernel)
|
224 |
+
|
225 |
+
kernel = kernel * gain
|
226 |
+
|
227 |
+
if self.use_conv:
|
228 |
+
_, _, convH, convW = weight.shape
|
229 |
+
p = (kernel.shape[0] - factor) + (convW - 1)
|
230 |
+
s = [factor, factor]
|
231 |
+
x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2, p // 2))
|
232 |
+
x = F.conv2d(x, weight, stride=s, padding=0)
|
233 |
+
else:
|
234 |
+
p = kernel.shape[0] - factor
|
235 |
+
x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
|
236 |
+
|
237 |
+
return x
|
238 |
+
|
239 |
+
def forward(self, x):
|
240 |
+
if self.use_conv:
|
241 |
+
x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
|
242 |
+
x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
243 |
+
else:
|
244 |
+
x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2)
|
245 |
+
|
246 |
+
return x
|
247 |
+
|
248 |
+
|
249 |
+
class ResnetBlock2D(nn.Module):
|
250 |
+
def __init__(
|
251 |
+
self,
|
252 |
+
*,
|
253 |
+
in_channels,
|
254 |
+
out_channels=None,
|
255 |
+
conv_shortcut=False,
|
256 |
+
dropout=0.0,
|
257 |
+
temb_channels=512,
|
258 |
+
groups=32,
|
259 |
+
groups_out=None,
|
260 |
+
pre_norm=True,
|
261 |
+
eps=1e-6,
|
262 |
+
non_linearity="swish",
|
263 |
+
time_embedding_norm="default",
|
264 |
+
kernel=None,
|
265 |
+
output_scale_factor=1.0,
|
266 |
+
use_in_shortcut=None,
|
267 |
+
up=False,
|
268 |
+
down=False,
|
269 |
+
):
|
270 |
+
super().__init__()
|
271 |
+
self.pre_norm = pre_norm
|
272 |
+
self.pre_norm = True
|
273 |
+
self.in_channels = in_channels
|
274 |
+
out_channels = in_channels if out_channels is None else out_channels
|
275 |
+
self.out_channels = out_channels
|
276 |
+
self.use_conv_shortcut = conv_shortcut
|
277 |
+
self.time_embedding_norm = time_embedding_norm
|
278 |
+
self.up = up
|
279 |
+
self.down = down
|
280 |
+
self.output_scale_factor = output_scale_factor
|
281 |
+
|
282 |
+
if groups_out is None:
|
283 |
+
groups_out = groups
|
284 |
+
|
285 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
286 |
+
|
287 |
+
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
288 |
+
|
289 |
+
if temb_channels is not None:
|
290 |
+
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
|
291 |
+
else:
|
292 |
+
self.time_emb_proj = None
|
293 |
+
|
294 |
+
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
295 |
+
self.dropout = torch.nn.Dropout(dropout)
|
296 |
+
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
297 |
+
|
298 |
+
if non_linearity == "swish":
|
299 |
+
self.nonlinearity = lambda x: F.silu(x)
|
300 |
+
elif non_linearity == "mish":
|
301 |
+
self.nonlinearity = Mish()
|
302 |
+
elif non_linearity == "silu":
|
303 |
+
self.nonlinearity = nn.SiLU()
|
304 |
+
|
305 |
+
self.upsample = self.downsample = None
|
306 |
+
if self.up:
|
307 |
+
if kernel == "fir":
|
308 |
+
fir_kernel = (1, 3, 3, 1)
|
309 |
+
self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
|
310 |
+
elif kernel == "sde_vp":
|
311 |
+
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
|
312 |
+
else:
|
313 |
+
self.upsample = Upsample2D(in_channels, use_conv=False)
|
314 |
+
elif self.down:
|
315 |
+
if kernel == "fir":
|
316 |
+
fir_kernel = (1, 3, 3, 1)
|
317 |
+
self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
|
318 |
+
elif kernel == "sde_vp":
|
319 |
+
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
|
320 |
+
else:
|
321 |
+
self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
|
322 |
+
|
323 |
+
self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
|
324 |
+
|
325 |
+
self.conv_shortcut = None
|
326 |
+
if self.use_in_shortcut:
|
327 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
328 |
+
|
329 |
+
def forward(self, x, temb):
|
330 |
+
hidden_states = x
|
331 |
+
|
332 |
+
# make sure hidden states is in float32
|
333 |
+
# when running in half-precision
|
334 |
+
hidden_states = self.norm1(hidden_states).type(hidden_states.dtype)
|
335 |
+
hidden_states = self.nonlinearity(hidden_states)
|
336 |
+
|
337 |
+
if self.upsample is not None:
|
338 |
+
x = self.upsample(x)
|
339 |
+
hidden_states = self.upsample(hidden_states)
|
340 |
+
elif self.downsample is not None:
|
341 |
+
x = self.downsample(x)
|
342 |
+
hidden_states = self.downsample(hidden_states)
|
343 |
+
|
344 |
+
hidden_states = self.conv1(hidden_states)
|
345 |
+
|
346 |
+
if temb is not None:
|
347 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
|
348 |
+
hidden_states = hidden_states + temb
|
349 |
+
|
350 |
+
# make sure hidden states is in float32
|
351 |
+
# when running in half-precision
|
352 |
+
hidden_states = self.norm2(hidden_states).type(hidden_states.dtype)
|
353 |
+
hidden_states = self.nonlinearity(hidden_states)
|
354 |
+
|
355 |
+
hidden_states = self.dropout(hidden_states)
|
356 |
+
hidden_states = self.conv2(hidden_states)
|
357 |
+
|
358 |
+
if self.conv_shortcut is not None:
|
359 |
+
x = self.conv_shortcut(x)
|
360 |
+
|
361 |
+
out = (x + hidden_states) / self.output_scale_factor
|
362 |
+
|
363 |
+
return out
|
364 |
+
|
365 |
+
|
366 |
+
class Mish(torch.nn.Module):
|
367 |
+
def forward(self, x):
|
368 |
+
return x * torch.tanh(torch.nn.functional.softplus(x))
|
369 |
+
|
370 |
+
|
371 |
+
def upsample_2d(x, kernel=None, factor=2, gain=1):
|
372 |
+
r"""Upsample2D a batch of 2D images with the given filter.
|
373 |
+
|
374 |
+
Args:
|
375 |
+
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
|
376 |
+
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
|
377 |
+
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
|
378 |
+
multiple of the upsampling factor.
|
379 |
+
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
380 |
+
C]`.
|
381 |
+
k: FIR filter of the shape `[firH, firW]` or `[firN]`
|
382 |
+
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
|
383 |
+
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
|
384 |
+
|
385 |
+
Returns:
|
386 |
+
Tensor of the shape `[N, C, H * factor, W * factor]`
|
387 |
+
"""
|
388 |
+
assert isinstance(factor, int) and factor >= 1
|
389 |
+
if kernel is None:
|
390 |
+
kernel = [1] * factor
|
391 |
+
|
392 |
+
kernel = torch.tensor(kernel, dtype=torch.float32)
|
393 |
+
if kernel.ndim == 1:
|
394 |
+
kernel = torch.outer(kernel, kernel)
|
395 |
+
kernel /= torch.sum(kernel)
|
396 |
+
|
397 |
+
kernel = kernel * (gain * (factor**2))
|
398 |
+
p = kernel.shape[0] - factor
|
399 |
+
return upfirdn2d_native(x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
|
400 |
+
|
401 |
+
|
402 |
+
def downsample_2d(x, kernel=None, factor=2, gain=1):
|
403 |
+
r"""Downsample2D a batch of 2D images with the given filter.
|
404 |
+
|
405 |
+
Args:
|
406 |
+
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
|
407 |
+
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
|
408 |
+
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
|
409 |
+
shape is a multiple of the downsampling factor.
|
410 |
+
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
|
411 |
+
C]`.
|
412 |
+
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
|
413 |
+
(separable). The default is `[1] * factor`, which corresponds to average pooling.
|
414 |
+
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
|
415 |
+
|
416 |
+
Returns:
|
417 |
+
Tensor of the shape `[N, C, H // factor, W // factor]`
|
418 |
+
"""
|
419 |
+
|
420 |
+
assert isinstance(factor, int) and factor >= 1
|
421 |
+
if kernel is None:
|
422 |
+
kernel = [1] * factor
|
423 |
+
|
424 |
+
kernel = torch.tensor(kernel, dtype=torch.float32)
|
425 |
+
if kernel.ndim == 1:
|
426 |
+
kernel = torch.outer(kernel, kernel)
|
427 |
+
kernel /= torch.sum(kernel)
|
428 |
+
|
429 |
+
kernel = kernel * gain
|
430 |
+
p = kernel.shape[0] - factor
|
431 |
+
return upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
|
432 |
+
|
433 |
+
|
434 |
+
def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
|
435 |
+
up_x = up_y = up
|
436 |
+
down_x = down_y = down
|
437 |
+
pad_x0 = pad_y0 = pad[0]
|
438 |
+
pad_x1 = pad_y1 = pad[1]
|
439 |
+
|
440 |
+
_, channel, in_h, in_w = input.shape
|
441 |
+
input = input.reshape(-1, in_h, in_w, 1)
|
442 |
+
|
443 |
+
_, in_h, in_w, minor = input.shape
|
444 |
+
kernel_h, kernel_w = kernel.shape
|
445 |
+
|
446 |
+
out = input.view(-1, in_h, 1, in_w, 1, minor)
|
447 |
+
|
448 |
+
# Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
|
449 |
+
if input.device.type == "mps":
|
450 |
+
out = out.to("cpu")
|
451 |
+
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
452 |
+
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
453 |
+
|
454 |
+
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
|
455 |
+
out = out.to(input.device) # Move back to mps if necessary
|
456 |
+
out = out[
|
457 |
+
:,
|
458 |
+
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
|
459 |
+
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
|
460 |
+
:,
|
461 |
+
]
|
462 |
+
|
463 |
+
out = out.permute(0, 3, 1, 2)
|
464 |
+
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
|
465 |
+
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
466 |
+
out = F.conv2d(out, w)
|
467 |
+
out = out.reshape(
|
468 |
+
-1,
|
469 |
+
minor,
|
470 |
+
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
471 |
+
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
472 |
+
)
|
473 |
+
out = out.permute(0, 2, 3, 1)
|
474 |
+
out = out[:, ::down_y, ::down_x, :]
|
475 |
+
|
476 |
+
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
477 |
+
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
478 |
+
|
479 |
+
return out.view(-1, channel, out_h, out_w)
|
medical_diffusion/external/diffusers/taming_discriminator.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
class NLayerDiscriminator(nn.Module):
|
8 |
+
"""Defines a PatchGAN discriminator as in Pix2Pix
|
9 |
+
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
|
10 |
+
"""
|
11 |
+
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
|
12 |
+
"""Construct a PatchGAN discriminator
|
13 |
+
Parameters:
|
14 |
+
input_nc (int) -- the number of channels in input images
|
15 |
+
ndf (int) -- the number of filters in the last conv layer
|
16 |
+
n_layers (int) -- the number of conv layers in the discriminator
|
17 |
+
norm_layer -- normalization layer
|
18 |
+
"""
|
19 |
+
super(NLayerDiscriminator, self).__init__()
|
20 |
+
if not use_actnorm:
|
21 |
+
norm_layer = nn.BatchNorm2d
|
22 |
+
else:
|
23 |
+
raise NotImplementedError
|
24 |
+
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
|
25 |
+
use_bias = norm_layer.func != nn.BatchNorm2d
|
26 |
+
else:
|
27 |
+
use_bias = norm_layer != nn.BatchNorm2d
|
28 |
+
|
29 |
+
kw = 4
|
30 |
+
padw = 1
|
31 |
+
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
|
32 |
+
nf_mult = 1
|
33 |
+
nf_mult_prev = 1
|
34 |
+
for n in range(1, n_layers): # gradually increase the number of filters
|
35 |
+
nf_mult_prev = nf_mult
|
36 |
+
nf_mult = min(2 ** n, 8)
|
37 |
+
sequence += [
|
38 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
|
39 |
+
norm_layer(ndf * nf_mult),
|
40 |
+
nn.LeakyReLU(0.2, True)
|
41 |
+
]
|
42 |
+
|
43 |
+
nf_mult_prev = nf_mult
|
44 |
+
nf_mult = min(2 ** n_layers, 8)
|
45 |
+
sequence += [
|
46 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
|
47 |
+
norm_layer(ndf * nf_mult),
|
48 |
+
nn.LeakyReLU(0.2, True)
|
49 |
+
]
|
50 |
+
|
51 |
+
sequence += [
|
52 |
+
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
|
53 |
+
self.main = nn.Sequential(*sequence)
|
54 |
+
|
55 |
+
def forward(self, input):
|
56 |
+
"""Standard forward."""
|
57 |
+
return self.main(input)
|
medical_diffusion/external/diffusers/unet.py
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
from typing import Optional, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.utils.checkpoint
|
8 |
+
|
9 |
+
|
10 |
+
from .embeddings import TimeEmbbeding
|
11 |
+
|
12 |
+
from .unet_blocks import (
|
13 |
+
CrossAttnDownBlock2D,
|
14 |
+
CrossAttnUpBlock2D,
|
15 |
+
DownBlock2D,
|
16 |
+
UNetMidBlock2DCrossAttn,
|
17 |
+
UpBlock2D,
|
18 |
+
get_down_block,
|
19 |
+
get_up_block,
|
20 |
+
)
|
21 |
+
|
22 |
+
class TimestepEmbedding(nn.Module):
|
23 |
+
def __init__(self, channel, time_embed_dim, act_fn="silu"):
|
24 |
+
super().__init__()
|
25 |
+
|
26 |
+
self.linear_1 = nn.Linear(channel, time_embed_dim)
|
27 |
+
self.act = None
|
28 |
+
if act_fn == "silu":
|
29 |
+
self.act = nn.SiLU()
|
30 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
|
31 |
+
|
32 |
+
def forward(self, sample):
|
33 |
+
sample = self.linear_1(sample)
|
34 |
+
|
35 |
+
if self.act is not None:
|
36 |
+
sample = self.act(sample)
|
37 |
+
|
38 |
+
sample = self.linear_2(sample)
|
39 |
+
return sample
|
40 |
+
|
41 |
+
|
42 |
+
class UNet2DConditionModel(nn.Module):
|
43 |
+
r"""
|
44 |
+
UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
|
45 |
+
and returns sample shaped output.
|
46 |
+
|
47 |
+
|
48 |
+
Parameters:
|
49 |
+
sample_size (`int`, *optional*): The size of the input sample.
|
50 |
+
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
|
51 |
+
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
|
52 |
+
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
53 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
54 |
+
Whether to flip the sin to cos in the time embedding.
|
55 |
+
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
56 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
57 |
+
The tuple of downsample blocks to use.
|
58 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
|
59 |
+
The tuple of upsample blocks to use.
|
60 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
61 |
+
The tuple of output channels for each block.
|
62 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
63 |
+
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
64 |
+
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
65 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
66 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
67 |
+
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
68 |
+
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
|
69 |
+
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
70 |
+
"""
|
71 |
+
|
72 |
+
_supports_gradient_checkpointing = True
|
73 |
+
|
74 |
+
|
75 |
+
def __init__(
|
76 |
+
self,
|
77 |
+
sample_size: Optional[int] = None,
|
78 |
+
in_channels: int = 4,
|
79 |
+
out_channels: int = 4,
|
80 |
+
center_input_sample: bool = False,
|
81 |
+
flip_sin_to_cos: bool = True,
|
82 |
+
freq_shift: int = 0,
|
83 |
+
down_block_types: Tuple[str] = (
|
84 |
+
"CrossAttnDownBlock2D",
|
85 |
+
"CrossAttnDownBlock2D",
|
86 |
+
"CrossAttnDownBlock2D",
|
87 |
+
"DownBlock2D",
|
88 |
+
),
|
89 |
+
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
90 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
91 |
+
layers_per_block: int = 2,
|
92 |
+
downsample_padding: int = 1,
|
93 |
+
mid_block_scale_factor: float = 1,
|
94 |
+
act_fn: str = "silu",
|
95 |
+
norm_num_groups: int = 32,
|
96 |
+
norm_eps: float = 1e-5,
|
97 |
+
cross_attention_dim: int = 768,
|
98 |
+
attention_head_dim: int = 8,
|
99 |
+
):
|
100 |
+
super().__init__()
|
101 |
+
|
102 |
+
self.sample_size = sample_size
|
103 |
+
time_embed_dim = block_out_channels[0] * 4
|
104 |
+
|
105 |
+
self.emb = nn.Embedding(2, cross_attention_dim)
|
106 |
+
|
107 |
+
# input
|
108 |
+
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
|
109 |
+
|
110 |
+
# time
|
111 |
+
self.time_embedding = TimeEmbbeding(block_out_channels[0], time_embed_dim)
|
112 |
+
|
113 |
+
self.down_blocks = nn.ModuleList([])
|
114 |
+
self.mid_block = None
|
115 |
+
self.up_blocks = nn.ModuleList([])
|
116 |
+
|
117 |
+
# down
|
118 |
+
output_channel = block_out_channels[0]
|
119 |
+
for i, down_block_type in enumerate(down_block_types):
|
120 |
+
input_channel = output_channel
|
121 |
+
output_channel = block_out_channels[i]
|
122 |
+
is_final_block = i == len(block_out_channels) - 1
|
123 |
+
|
124 |
+
down_block = get_down_block(
|
125 |
+
down_block_type,
|
126 |
+
num_layers=layers_per_block,
|
127 |
+
in_channels=input_channel,
|
128 |
+
out_channels=output_channel,
|
129 |
+
temb_channels=time_embed_dim,
|
130 |
+
add_downsample=not is_final_block,
|
131 |
+
resnet_eps=norm_eps,
|
132 |
+
resnet_act_fn=act_fn,
|
133 |
+
resnet_groups=norm_num_groups,
|
134 |
+
cross_attention_dim=cross_attention_dim,
|
135 |
+
attn_num_head_channels=attention_head_dim,
|
136 |
+
downsample_padding=downsample_padding,
|
137 |
+
)
|
138 |
+
self.down_blocks.append(down_block)
|
139 |
+
|
140 |
+
# mid
|
141 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
142 |
+
in_channels=block_out_channels[-1],
|
143 |
+
temb_channels=time_embed_dim,
|
144 |
+
resnet_eps=norm_eps,
|
145 |
+
resnet_act_fn=act_fn,
|
146 |
+
output_scale_factor=mid_block_scale_factor,
|
147 |
+
resnet_time_scale_shift="default",
|
148 |
+
cross_attention_dim=cross_attention_dim,
|
149 |
+
attn_num_head_channels=attention_head_dim,
|
150 |
+
resnet_groups=norm_num_groups,
|
151 |
+
)
|
152 |
+
|
153 |
+
# up
|
154 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
155 |
+
output_channel = reversed_block_out_channels[0]
|
156 |
+
for i, up_block_type in enumerate(up_block_types):
|
157 |
+
prev_output_channel = output_channel
|
158 |
+
output_channel = reversed_block_out_channels[i]
|
159 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
160 |
+
|
161 |
+
is_final_block = i == len(block_out_channels) - 1
|
162 |
+
|
163 |
+
up_block = get_up_block(
|
164 |
+
up_block_type,
|
165 |
+
num_layers=layers_per_block + 1,
|
166 |
+
in_channels=input_channel,
|
167 |
+
out_channels=output_channel,
|
168 |
+
prev_output_channel=prev_output_channel,
|
169 |
+
temb_channels=time_embed_dim,
|
170 |
+
add_upsample=not is_final_block,
|
171 |
+
resnet_eps=norm_eps,
|
172 |
+
resnet_act_fn=act_fn,
|
173 |
+
resnet_groups=norm_num_groups,
|
174 |
+
cross_attention_dim=cross_attention_dim,
|
175 |
+
attn_num_head_channels=attention_head_dim,
|
176 |
+
)
|
177 |
+
self.up_blocks.append(up_block)
|
178 |
+
prev_output_channel = output_channel
|
179 |
+
|
180 |
+
# out
|
181 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
182 |
+
self.conv_act = nn.SiLU()
|
183 |
+
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
184 |
+
|
185 |
+
|
186 |
+
|
187 |
+
def forward(
|
188 |
+
self,
|
189 |
+
sample: torch.FloatTensor,
|
190 |
+
t: torch.Tensor,
|
191 |
+
encoder_hidden_states: torch.Tensor = None,
|
192 |
+
self_cond: torch.Tensor = None
|
193 |
+
):
|
194 |
+
encoder_hidden_states = self.emb(encoder_hidden_states)
|
195 |
+
# encoder_hidden_states = None # ------------------------ WARNING Disabled ---------------------
|
196 |
+
"""r
|
197 |
+
Args:
|
198 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
199 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
200 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
|
201 |
+
|
202 |
+
Returns:
|
203 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
204 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
205 |
+
returning a tuple, the first element is the sample tensor.
|
206 |
+
"""
|
207 |
+
# 0. center input if necessary
|
208 |
+
# if self.config.center_input_sample:
|
209 |
+
# sample = 2 * sample - 1.0
|
210 |
+
|
211 |
+
# 1. time
|
212 |
+
t_emb = self.time_embedding(t)
|
213 |
+
|
214 |
+
# 2. pre-process
|
215 |
+
sample = self.conv_in(sample)
|
216 |
+
|
217 |
+
# 3. down
|
218 |
+
down_block_res_samples = (sample,)
|
219 |
+
for downsample_block in self.down_blocks:
|
220 |
+
if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
|
221 |
+
sample, res_samples = downsample_block(
|
222 |
+
hidden_states=sample,
|
223 |
+
temb=t_emb,
|
224 |
+
encoder_hidden_states=encoder_hidden_states,
|
225 |
+
)
|
226 |
+
else:
|
227 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=t_emb)
|
228 |
+
|
229 |
+
down_block_res_samples += res_samples
|
230 |
+
|
231 |
+
# 4. mid
|
232 |
+
sample = self.mid_block(sample, t_emb, encoder_hidden_states=encoder_hidden_states)
|
233 |
+
|
234 |
+
# 5. up
|
235 |
+
for upsample_block in self.up_blocks:
|
236 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
237 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
238 |
+
|
239 |
+
if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
|
240 |
+
sample = upsample_block(
|
241 |
+
hidden_states=sample,
|
242 |
+
temb=t_emb,
|
243 |
+
res_hidden_states_tuple=res_samples,
|
244 |
+
encoder_hidden_states=encoder_hidden_states,
|
245 |
+
)
|
246 |
+
else:
|
247 |
+
sample = upsample_block(hidden_states=sample, temb=t_emb, res_hidden_states_tuple=res_samples)
|
248 |
+
|
249 |
+
# 6. post-process
|
250 |
+
# make sure hidden states is in float32
|
251 |
+
# when running in half-precision
|
252 |
+
sample = self.conv_norm_out(sample.float()).type(sample.dtype)
|
253 |
+
sample = self.conv_act(sample)
|
254 |
+
sample = self.conv_out(sample)
|
255 |
+
|
256 |
+
|
257 |
+
return sample, []
|
medical_diffusion/external/diffusers/unet_blocks.py
ADDED
@@ -0,0 +1,1557 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
# limitations under the License.
|
17 |
+
import torch
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
from .attention import AttentionBlock, SpatialTransformer
|
21 |
+
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
|
22 |
+
|
23 |
+
|
24 |
+
def get_down_block(
|
25 |
+
down_block_type,
|
26 |
+
num_layers,
|
27 |
+
in_channels,
|
28 |
+
out_channels,
|
29 |
+
temb_channels,
|
30 |
+
add_downsample,
|
31 |
+
resnet_eps,
|
32 |
+
resnet_act_fn,
|
33 |
+
attn_num_head_channels,
|
34 |
+
resnet_groups=None,
|
35 |
+
cross_attention_dim=None,
|
36 |
+
downsample_padding=None,
|
37 |
+
):
|
38 |
+
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
39 |
+
if down_block_type == "DownBlock2D":
|
40 |
+
return DownBlock2D(
|
41 |
+
num_layers=num_layers,
|
42 |
+
in_channels=in_channels,
|
43 |
+
out_channels=out_channels,
|
44 |
+
temb_channels=temb_channels,
|
45 |
+
add_downsample=add_downsample,
|
46 |
+
resnet_eps=resnet_eps,
|
47 |
+
resnet_act_fn=resnet_act_fn,
|
48 |
+
resnet_groups=resnet_groups,
|
49 |
+
downsample_padding=downsample_padding,
|
50 |
+
)
|
51 |
+
elif down_block_type == "AttnDownBlock2D":
|
52 |
+
return AttnDownBlock2D(
|
53 |
+
num_layers=num_layers,
|
54 |
+
in_channels=in_channels,
|
55 |
+
out_channels=out_channels,
|
56 |
+
temb_channels=temb_channels,
|
57 |
+
add_downsample=add_downsample,
|
58 |
+
resnet_eps=resnet_eps,
|
59 |
+
resnet_act_fn=resnet_act_fn,
|
60 |
+
resnet_groups=resnet_groups,
|
61 |
+
downsample_padding=downsample_padding,
|
62 |
+
attn_num_head_channels=attn_num_head_channels,
|
63 |
+
)
|
64 |
+
elif down_block_type == "CrossAttnDownBlock2D":
|
65 |
+
if cross_attention_dim is None:
|
66 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
|
67 |
+
return CrossAttnDownBlock2D(
|
68 |
+
num_layers=num_layers,
|
69 |
+
in_channels=in_channels,
|
70 |
+
out_channels=out_channels,
|
71 |
+
temb_channels=temb_channels,
|
72 |
+
add_downsample=add_downsample,
|
73 |
+
resnet_eps=resnet_eps,
|
74 |
+
resnet_act_fn=resnet_act_fn,
|
75 |
+
resnet_groups=resnet_groups,
|
76 |
+
downsample_padding=downsample_padding,
|
77 |
+
cross_attention_dim=cross_attention_dim,
|
78 |
+
attn_num_head_channels=attn_num_head_channels,
|
79 |
+
)
|
80 |
+
elif down_block_type == "SkipDownBlock2D":
|
81 |
+
return SkipDownBlock2D(
|
82 |
+
num_layers=num_layers,
|
83 |
+
in_channels=in_channels,
|
84 |
+
out_channels=out_channels,
|
85 |
+
temb_channels=temb_channels,
|
86 |
+
add_downsample=add_downsample,
|
87 |
+
resnet_eps=resnet_eps,
|
88 |
+
resnet_act_fn=resnet_act_fn,
|
89 |
+
downsample_padding=downsample_padding,
|
90 |
+
)
|
91 |
+
elif down_block_type == "AttnSkipDownBlock2D":
|
92 |
+
return AttnSkipDownBlock2D(
|
93 |
+
num_layers=num_layers,
|
94 |
+
in_channels=in_channels,
|
95 |
+
out_channels=out_channels,
|
96 |
+
temb_channels=temb_channels,
|
97 |
+
add_downsample=add_downsample,
|
98 |
+
resnet_eps=resnet_eps,
|
99 |
+
resnet_act_fn=resnet_act_fn,
|
100 |
+
downsample_padding=downsample_padding,
|
101 |
+
attn_num_head_channels=attn_num_head_channels,
|
102 |
+
)
|
103 |
+
elif down_block_type == "DownEncoderBlock2D":
|
104 |
+
return DownEncoderBlock2D(
|
105 |
+
num_layers=num_layers,
|
106 |
+
in_channels=in_channels,
|
107 |
+
out_channels=out_channels,
|
108 |
+
add_downsample=add_downsample,
|
109 |
+
resnet_eps=resnet_eps,
|
110 |
+
resnet_act_fn=resnet_act_fn,
|
111 |
+
resnet_groups=resnet_groups,
|
112 |
+
downsample_padding=downsample_padding,
|
113 |
+
)
|
114 |
+
|
115 |
+
|
116 |
+
def get_up_block(
|
117 |
+
up_block_type,
|
118 |
+
num_layers,
|
119 |
+
in_channels,
|
120 |
+
out_channels,
|
121 |
+
prev_output_channel,
|
122 |
+
temb_channels,
|
123 |
+
add_upsample,
|
124 |
+
resnet_eps,
|
125 |
+
resnet_act_fn,
|
126 |
+
attn_num_head_channels,
|
127 |
+
resnet_groups=None,
|
128 |
+
cross_attention_dim=None,
|
129 |
+
):
|
130 |
+
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
131 |
+
if up_block_type == "UpBlock2D":
|
132 |
+
return UpBlock2D(
|
133 |
+
num_layers=num_layers,
|
134 |
+
in_channels=in_channels,
|
135 |
+
out_channels=out_channels,
|
136 |
+
prev_output_channel=prev_output_channel,
|
137 |
+
temb_channels=temb_channels,
|
138 |
+
add_upsample=add_upsample,
|
139 |
+
resnet_eps=resnet_eps,
|
140 |
+
resnet_act_fn=resnet_act_fn,
|
141 |
+
resnet_groups=resnet_groups,
|
142 |
+
)
|
143 |
+
elif up_block_type == "CrossAttnUpBlock2D":
|
144 |
+
if cross_attention_dim is None:
|
145 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
|
146 |
+
return CrossAttnUpBlock2D(
|
147 |
+
num_layers=num_layers,
|
148 |
+
in_channels=in_channels,
|
149 |
+
out_channels=out_channels,
|
150 |
+
prev_output_channel=prev_output_channel,
|
151 |
+
temb_channels=temb_channels,
|
152 |
+
add_upsample=add_upsample,
|
153 |
+
resnet_eps=resnet_eps,
|
154 |
+
resnet_act_fn=resnet_act_fn,
|
155 |
+
resnet_groups=resnet_groups,
|
156 |
+
cross_attention_dim=cross_attention_dim,
|
157 |
+
attn_num_head_channels=attn_num_head_channels,
|
158 |
+
)
|
159 |
+
elif up_block_type == "AttnUpBlock2D":
|
160 |
+
return AttnUpBlock2D(
|
161 |
+
num_layers=num_layers,
|
162 |
+
in_channels=in_channels,
|
163 |
+
out_channels=out_channels,
|
164 |
+
prev_output_channel=prev_output_channel,
|
165 |
+
temb_channels=temb_channels,
|
166 |
+
add_upsample=add_upsample,
|
167 |
+
resnet_eps=resnet_eps,
|
168 |
+
resnet_act_fn=resnet_act_fn,
|
169 |
+
resnet_groups=resnet_groups,
|
170 |
+
attn_num_head_channels=attn_num_head_channels,
|
171 |
+
)
|
172 |
+
elif up_block_type == "SkipUpBlock2D":
|
173 |
+
return SkipUpBlock2D(
|
174 |
+
num_layers=num_layers,
|
175 |
+
in_channels=in_channels,
|
176 |
+
out_channels=out_channels,
|
177 |
+
prev_output_channel=prev_output_channel,
|
178 |
+
temb_channels=temb_channels,
|
179 |
+
add_upsample=add_upsample,
|
180 |
+
resnet_eps=resnet_eps,
|
181 |
+
resnet_act_fn=resnet_act_fn,
|
182 |
+
)
|
183 |
+
elif up_block_type == "AttnSkipUpBlock2D":
|
184 |
+
return AttnSkipUpBlock2D(
|
185 |
+
num_layers=num_layers,
|
186 |
+
in_channels=in_channels,
|
187 |
+
out_channels=out_channels,
|
188 |
+
prev_output_channel=prev_output_channel,
|
189 |
+
temb_channels=temb_channels,
|
190 |
+
add_upsample=add_upsample,
|
191 |
+
resnet_eps=resnet_eps,
|
192 |
+
resnet_act_fn=resnet_act_fn,
|
193 |
+
attn_num_head_channels=attn_num_head_channels,
|
194 |
+
)
|
195 |
+
elif up_block_type == "UpDecoderBlock2D":
|
196 |
+
return UpDecoderBlock2D(
|
197 |
+
num_layers=num_layers,
|
198 |
+
in_channels=in_channels,
|
199 |
+
out_channels=out_channels,
|
200 |
+
add_upsample=add_upsample,
|
201 |
+
resnet_eps=resnet_eps,
|
202 |
+
resnet_act_fn=resnet_act_fn,
|
203 |
+
resnet_groups=resnet_groups,
|
204 |
+
)
|
205 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
206 |
+
|
207 |
+
|
208 |
+
class UNetMidBlock2D(nn.Module):
|
209 |
+
def __init__(
|
210 |
+
self,
|
211 |
+
in_channels: int,
|
212 |
+
temb_channels: int,
|
213 |
+
dropout: float = 0.0,
|
214 |
+
num_layers: int = 1,
|
215 |
+
resnet_eps: float = 1e-6,
|
216 |
+
resnet_time_scale_shift: str = "default",
|
217 |
+
resnet_act_fn: str = "swish",
|
218 |
+
resnet_groups: int = 32,
|
219 |
+
resnet_pre_norm: bool = True,
|
220 |
+
attn_num_head_channels=1,
|
221 |
+
attention_type="default",
|
222 |
+
output_scale_factor=1.0,
|
223 |
+
**kwargs,
|
224 |
+
):
|
225 |
+
super().__init__()
|
226 |
+
|
227 |
+
self.attention_type = attention_type
|
228 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
229 |
+
|
230 |
+
# there is always at least one resnet
|
231 |
+
resnets = [
|
232 |
+
ResnetBlock2D(
|
233 |
+
in_channels=in_channels,
|
234 |
+
out_channels=in_channels,
|
235 |
+
temb_channels=temb_channels,
|
236 |
+
eps=resnet_eps,
|
237 |
+
groups=resnet_groups,
|
238 |
+
dropout=dropout,
|
239 |
+
time_embedding_norm=resnet_time_scale_shift,
|
240 |
+
non_linearity=resnet_act_fn,
|
241 |
+
output_scale_factor=output_scale_factor,
|
242 |
+
pre_norm=resnet_pre_norm,
|
243 |
+
)
|
244 |
+
]
|
245 |
+
attentions = []
|
246 |
+
|
247 |
+
for _ in range(num_layers):
|
248 |
+
attentions.append(
|
249 |
+
AttentionBlock(
|
250 |
+
in_channels,
|
251 |
+
num_head_channels=attn_num_head_channels,
|
252 |
+
rescale_output_factor=output_scale_factor,
|
253 |
+
eps=resnet_eps,
|
254 |
+
num_groups=resnet_groups,
|
255 |
+
)
|
256 |
+
)
|
257 |
+
resnets.append(
|
258 |
+
ResnetBlock2D(
|
259 |
+
in_channels=in_channels,
|
260 |
+
out_channels=in_channels,
|
261 |
+
temb_channels=temb_channels,
|
262 |
+
eps=resnet_eps,
|
263 |
+
groups=resnet_groups,
|
264 |
+
dropout=dropout,
|
265 |
+
time_embedding_norm=resnet_time_scale_shift,
|
266 |
+
non_linearity=resnet_act_fn,
|
267 |
+
output_scale_factor=output_scale_factor,
|
268 |
+
pre_norm=resnet_pre_norm,
|
269 |
+
)
|
270 |
+
)
|
271 |
+
|
272 |
+
self.attentions = nn.ModuleList(attentions)
|
273 |
+
self.resnets = nn.ModuleList(resnets)
|
274 |
+
|
275 |
+
def forward(self, hidden_states, temb=None, encoder_states=None):
|
276 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
277 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
278 |
+
if self.attention_type == "default":
|
279 |
+
hidden_states = attn(hidden_states)
|
280 |
+
else:
|
281 |
+
hidden_states = attn(hidden_states, encoder_states)
|
282 |
+
hidden_states = resnet(hidden_states, temb)
|
283 |
+
|
284 |
+
return hidden_states
|
285 |
+
|
286 |
+
|
287 |
+
class UNetMidBlock2DCrossAttn(nn.Module):
|
288 |
+
def __init__(
|
289 |
+
self,
|
290 |
+
in_channels: int,
|
291 |
+
temb_channels: int,
|
292 |
+
dropout: float = 0.0,
|
293 |
+
num_layers: int = 1,
|
294 |
+
resnet_eps: float = 1e-6,
|
295 |
+
resnet_time_scale_shift: str = "default",
|
296 |
+
resnet_act_fn: str = "swish",
|
297 |
+
resnet_groups: int = 32,
|
298 |
+
resnet_pre_norm: bool = True,
|
299 |
+
attn_num_head_channels=1,
|
300 |
+
attention_type="default",
|
301 |
+
output_scale_factor=1.0,
|
302 |
+
cross_attention_dim=1280,
|
303 |
+
**kwargs,
|
304 |
+
):
|
305 |
+
super().__init__()
|
306 |
+
|
307 |
+
self.attention_type = attention_type
|
308 |
+
self.attn_num_head_channels = attn_num_head_channels
|
309 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
310 |
+
|
311 |
+
# there is always at least one resnet
|
312 |
+
resnets = [
|
313 |
+
ResnetBlock2D(
|
314 |
+
in_channels=in_channels,
|
315 |
+
out_channels=in_channels,
|
316 |
+
temb_channels=temb_channels,
|
317 |
+
eps=resnet_eps,
|
318 |
+
groups=resnet_groups,
|
319 |
+
dropout=dropout,
|
320 |
+
time_embedding_norm=resnet_time_scale_shift,
|
321 |
+
non_linearity=resnet_act_fn,
|
322 |
+
output_scale_factor=output_scale_factor,
|
323 |
+
pre_norm=resnet_pre_norm,
|
324 |
+
)
|
325 |
+
]
|
326 |
+
attentions = []
|
327 |
+
|
328 |
+
for _ in range(num_layers):
|
329 |
+
attentions.append(
|
330 |
+
SpatialTransformer(
|
331 |
+
in_channels,
|
332 |
+
attn_num_head_channels,
|
333 |
+
in_channels // attn_num_head_channels,
|
334 |
+
depth=1,
|
335 |
+
context_dim=cross_attention_dim,
|
336 |
+
num_groups=resnet_groups,
|
337 |
+
)
|
338 |
+
)
|
339 |
+
resnets.append(
|
340 |
+
ResnetBlock2D(
|
341 |
+
in_channels=in_channels,
|
342 |
+
out_channels=in_channels,
|
343 |
+
temb_channels=temb_channels,
|
344 |
+
eps=resnet_eps,
|
345 |
+
groups=resnet_groups,
|
346 |
+
dropout=dropout,
|
347 |
+
time_embedding_norm=resnet_time_scale_shift,
|
348 |
+
non_linearity=resnet_act_fn,
|
349 |
+
output_scale_factor=output_scale_factor,
|
350 |
+
pre_norm=resnet_pre_norm,
|
351 |
+
)
|
352 |
+
)
|
353 |
+
|
354 |
+
self.attentions = nn.ModuleList(attentions)
|
355 |
+
self.resnets = nn.ModuleList(resnets)
|
356 |
+
|
357 |
+
def set_attention_slice(self, slice_size):
|
358 |
+
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
|
359 |
+
raise ValueError(
|
360 |
+
f"Make sure slice_size {slice_size} is a divisor of "
|
361 |
+
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
362 |
+
)
|
363 |
+
if slice_size is not None and slice_size > self.attn_num_head_channels:
|
364 |
+
raise ValueError(
|
365 |
+
f"Chunk_size {slice_size} has to be smaller or equal to "
|
366 |
+
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
367 |
+
)
|
368 |
+
|
369 |
+
for attn in self.attentions:
|
370 |
+
attn._set_attention_slice(slice_size)
|
371 |
+
|
372 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
373 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
374 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
375 |
+
hidden_states = attn(hidden_states, encoder_hidden_states)
|
376 |
+
hidden_states = resnet(hidden_states, temb)
|
377 |
+
|
378 |
+
return hidden_states
|
379 |
+
|
380 |
+
|
381 |
+
class AttnDownBlock2D(nn.Module):
|
382 |
+
def __init__(
|
383 |
+
self,
|
384 |
+
in_channels: int,
|
385 |
+
out_channels: int,
|
386 |
+
temb_channels: int,
|
387 |
+
dropout: float = 0.0,
|
388 |
+
num_layers: int = 1,
|
389 |
+
resnet_eps: float = 1e-6,
|
390 |
+
resnet_time_scale_shift: str = "default",
|
391 |
+
resnet_act_fn: str = "swish",
|
392 |
+
resnet_groups: int = 32,
|
393 |
+
resnet_pre_norm: bool = True,
|
394 |
+
attn_num_head_channels=1,
|
395 |
+
attention_type="default",
|
396 |
+
output_scale_factor=1.0,
|
397 |
+
downsample_padding=1,
|
398 |
+
add_downsample=True,
|
399 |
+
):
|
400 |
+
super().__init__()
|
401 |
+
resnets = []
|
402 |
+
attentions = []
|
403 |
+
|
404 |
+
self.attention_type = attention_type
|
405 |
+
|
406 |
+
for i in range(num_layers):
|
407 |
+
in_channels = in_channels if i == 0 else out_channels
|
408 |
+
resnets.append(
|
409 |
+
ResnetBlock2D(
|
410 |
+
in_channels=in_channels,
|
411 |
+
out_channels=out_channels,
|
412 |
+
temb_channels=temb_channels,
|
413 |
+
eps=resnet_eps,
|
414 |
+
groups=resnet_groups,
|
415 |
+
dropout=dropout,
|
416 |
+
time_embedding_norm=resnet_time_scale_shift,
|
417 |
+
non_linearity=resnet_act_fn,
|
418 |
+
output_scale_factor=output_scale_factor,
|
419 |
+
pre_norm=resnet_pre_norm,
|
420 |
+
)
|
421 |
+
)
|
422 |
+
attentions.append(
|
423 |
+
AttentionBlock(
|
424 |
+
out_channels,
|
425 |
+
num_head_channels=attn_num_head_channels,
|
426 |
+
rescale_output_factor=output_scale_factor,
|
427 |
+
eps=resnet_eps,
|
428 |
+
num_groups=resnet_groups,
|
429 |
+
)
|
430 |
+
)
|
431 |
+
|
432 |
+
self.attentions = nn.ModuleList(attentions)
|
433 |
+
self.resnets = nn.ModuleList(resnets)
|
434 |
+
|
435 |
+
if add_downsample:
|
436 |
+
self.downsamplers = nn.ModuleList(
|
437 |
+
[
|
438 |
+
Downsample2D(
|
439 |
+
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
440 |
+
)
|
441 |
+
]
|
442 |
+
)
|
443 |
+
else:
|
444 |
+
self.downsamplers = None
|
445 |
+
|
446 |
+
def forward(self, hidden_states, temb=None):
|
447 |
+
output_states = ()
|
448 |
+
|
449 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
450 |
+
hidden_states = resnet(hidden_states, temb)
|
451 |
+
hidden_states = attn(hidden_states)
|
452 |
+
output_states += (hidden_states,)
|
453 |
+
|
454 |
+
if self.downsamplers is not None:
|
455 |
+
for downsampler in self.downsamplers:
|
456 |
+
hidden_states = downsampler(hidden_states)
|
457 |
+
|
458 |
+
output_states += (hidden_states,)
|
459 |
+
|
460 |
+
return hidden_states, output_states
|
461 |
+
|
462 |
+
|
463 |
+
class CrossAttnDownBlock2D(nn.Module):
|
464 |
+
def __init__(
|
465 |
+
self,
|
466 |
+
in_channels: int,
|
467 |
+
out_channels: int,
|
468 |
+
temb_channels: int,
|
469 |
+
dropout: float = 0.0,
|
470 |
+
num_layers: int = 1,
|
471 |
+
resnet_eps: float = 1e-6,
|
472 |
+
resnet_time_scale_shift: str = "default",
|
473 |
+
resnet_act_fn: str = "swish",
|
474 |
+
resnet_groups: int = 32,
|
475 |
+
resnet_pre_norm: bool = True,
|
476 |
+
attn_num_head_channels=1,
|
477 |
+
cross_attention_dim=1280,
|
478 |
+
attention_type="default",
|
479 |
+
output_scale_factor=1.0,
|
480 |
+
downsample_padding=1,
|
481 |
+
add_downsample=True,
|
482 |
+
):
|
483 |
+
super().__init__()
|
484 |
+
resnets = []
|
485 |
+
attentions = []
|
486 |
+
|
487 |
+
self.attention_type = attention_type
|
488 |
+
self.attn_num_head_channels = attn_num_head_channels
|
489 |
+
|
490 |
+
for i in range(num_layers):
|
491 |
+
in_channels = in_channels if i == 0 else out_channels
|
492 |
+
resnets.append(
|
493 |
+
ResnetBlock2D(
|
494 |
+
in_channels=in_channels,
|
495 |
+
out_channels=out_channels,
|
496 |
+
temb_channels=temb_channels,
|
497 |
+
eps=resnet_eps,
|
498 |
+
groups=resnet_groups,
|
499 |
+
dropout=dropout,
|
500 |
+
time_embedding_norm=resnet_time_scale_shift,
|
501 |
+
non_linearity=resnet_act_fn,
|
502 |
+
output_scale_factor=output_scale_factor,
|
503 |
+
pre_norm=resnet_pre_norm,
|
504 |
+
)
|
505 |
+
)
|
506 |
+
attentions.append(
|
507 |
+
SpatialTransformer(
|
508 |
+
out_channels,
|
509 |
+
attn_num_head_channels,
|
510 |
+
out_channels // attn_num_head_channels,
|
511 |
+
depth=1,
|
512 |
+
context_dim=cross_attention_dim,
|
513 |
+
num_groups=resnet_groups,
|
514 |
+
)
|
515 |
+
)
|
516 |
+
self.attentions = nn.ModuleList(attentions)
|
517 |
+
self.resnets = nn.ModuleList(resnets)
|
518 |
+
|
519 |
+
if add_downsample:
|
520 |
+
self.downsamplers = nn.ModuleList(
|
521 |
+
[
|
522 |
+
Downsample2D(
|
523 |
+
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
524 |
+
)
|
525 |
+
]
|
526 |
+
)
|
527 |
+
else:
|
528 |
+
self.downsamplers = None
|
529 |
+
|
530 |
+
self.gradient_checkpointing = False
|
531 |
+
|
532 |
+
def set_attention_slice(self, slice_size):
|
533 |
+
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
|
534 |
+
raise ValueError(
|
535 |
+
f"Make sure slice_size {slice_size} is a divisor of "
|
536 |
+
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
537 |
+
)
|
538 |
+
if slice_size is not None and slice_size > self.attn_num_head_channels:
|
539 |
+
raise ValueError(
|
540 |
+
f"Chunk_size {slice_size} has to be smaller or equal to "
|
541 |
+
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
542 |
+
)
|
543 |
+
|
544 |
+
for attn in self.attentions:
|
545 |
+
attn._set_attention_slice(slice_size)
|
546 |
+
|
547 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
548 |
+
output_states = ()
|
549 |
+
|
550 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
551 |
+
if self.training and self.gradient_checkpointing:
|
552 |
+
|
553 |
+
def create_custom_forward(module):
|
554 |
+
def custom_forward(*inputs):
|
555 |
+
return module(*inputs)
|
556 |
+
|
557 |
+
return custom_forward
|
558 |
+
|
559 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
560 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
561 |
+
create_custom_forward(attn), hidden_states, encoder_hidden_states
|
562 |
+
)
|
563 |
+
else:
|
564 |
+
hidden_states = resnet(hidden_states, temb)
|
565 |
+
hidden_states = attn(hidden_states, context=encoder_hidden_states)
|
566 |
+
|
567 |
+
output_states += (hidden_states,)
|
568 |
+
|
569 |
+
if self.downsamplers is not None:
|
570 |
+
for downsampler in self.downsamplers:
|
571 |
+
hidden_states = downsampler(hidden_states)
|
572 |
+
|
573 |
+
output_states += (hidden_states,)
|
574 |
+
|
575 |
+
return hidden_states, output_states
|
576 |
+
|
577 |
+
|
578 |
+
class DownBlock2D(nn.Module):
|
579 |
+
def __init__(
|
580 |
+
self,
|
581 |
+
in_channels: int,
|
582 |
+
out_channels: int,
|
583 |
+
temb_channels: int,
|
584 |
+
dropout: float = 0.0,
|
585 |
+
num_layers: int = 1,
|
586 |
+
resnet_eps: float = 1e-6,
|
587 |
+
resnet_time_scale_shift: str = "default",
|
588 |
+
resnet_act_fn: str = "swish",
|
589 |
+
resnet_groups: int = 32,
|
590 |
+
resnet_pre_norm: bool = True,
|
591 |
+
output_scale_factor=1.0,
|
592 |
+
add_downsample=True,
|
593 |
+
downsample_padding=1,
|
594 |
+
):
|
595 |
+
super().__init__()
|
596 |
+
resnets = []
|
597 |
+
|
598 |
+
for i in range(num_layers):
|
599 |
+
in_channels = in_channels if i == 0 else out_channels
|
600 |
+
resnets.append(
|
601 |
+
ResnetBlock2D(
|
602 |
+
in_channels=in_channels,
|
603 |
+
out_channels=out_channels,
|
604 |
+
temb_channels=temb_channels,
|
605 |
+
eps=resnet_eps,
|
606 |
+
groups=resnet_groups,
|
607 |
+
dropout=dropout,
|
608 |
+
time_embedding_norm=resnet_time_scale_shift,
|
609 |
+
non_linearity=resnet_act_fn,
|
610 |
+
output_scale_factor=output_scale_factor,
|
611 |
+
pre_norm=resnet_pre_norm,
|
612 |
+
)
|
613 |
+
)
|
614 |
+
|
615 |
+
self.resnets = nn.ModuleList(resnets)
|
616 |
+
|
617 |
+
if add_downsample:
|
618 |
+
self.downsamplers = nn.ModuleList(
|
619 |
+
[
|
620 |
+
Downsample2D(
|
621 |
+
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
622 |
+
)
|
623 |
+
]
|
624 |
+
)
|
625 |
+
else:
|
626 |
+
self.downsamplers = None
|
627 |
+
|
628 |
+
self.gradient_checkpointing = False
|
629 |
+
|
630 |
+
def forward(self, hidden_states, temb=None):
|
631 |
+
output_states = ()
|
632 |
+
|
633 |
+
for resnet in self.resnets:
|
634 |
+
if self.training and self.gradient_checkpointing:
|
635 |
+
|
636 |
+
def create_custom_forward(module):
|
637 |
+
def custom_forward(*inputs):
|
638 |
+
return module(*inputs)
|
639 |
+
|
640 |
+
return custom_forward
|
641 |
+
|
642 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
643 |
+
else:
|
644 |
+
hidden_states = resnet(hidden_states, temb)
|
645 |
+
|
646 |
+
output_states += (hidden_states,)
|
647 |
+
|
648 |
+
if self.downsamplers is not None:
|
649 |
+
for downsampler in self.downsamplers:
|
650 |
+
hidden_states = downsampler(hidden_states)
|
651 |
+
|
652 |
+
output_states += (hidden_states,)
|
653 |
+
|
654 |
+
return hidden_states, output_states
|
655 |
+
|
656 |
+
|
657 |
+
class DownEncoderBlock2D(nn.Module):
|
658 |
+
def __init__(
|
659 |
+
self,
|
660 |
+
in_channels: int,
|
661 |
+
out_channels: int,
|
662 |
+
dropout: float = 0.0,
|
663 |
+
num_layers: int = 1,
|
664 |
+
resnet_eps: float = 1e-6,
|
665 |
+
resnet_time_scale_shift: str = "default",
|
666 |
+
resnet_act_fn: str = "swish",
|
667 |
+
resnet_groups: int = 32,
|
668 |
+
resnet_pre_norm: bool = True,
|
669 |
+
output_scale_factor=1.0,
|
670 |
+
add_downsample=True,
|
671 |
+
downsample_padding=1,
|
672 |
+
):
|
673 |
+
super().__init__()
|
674 |
+
resnets = []
|
675 |
+
|
676 |
+
for i in range(num_layers):
|
677 |
+
in_channels = in_channels if i == 0 else out_channels
|
678 |
+
resnets.append(
|
679 |
+
ResnetBlock2D(
|
680 |
+
in_channels=in_channels,
|
681 |
+
out_channels=out_channels,
|
682 |
+
temb_channels=None,
|
683 |
+
eps=resnet_eps,
|
684 |
+
groups=resnet_groups,
|
685 |
+
dropout=dropout,
|
686 |
+
time_embedding_norm=resnet_time_scale_shift,
|
687 |
+
non_linearity=resnet_act_fn,
|
688 |
+
output_scale_factor=output_scale_factor,
|
689 |
+
pre_norm=resnet_pre_norm,
|
690 |
+
)
|
691 |
+
)
|
692 |
+
|
693 |
+
self.resnets = nn.ModuleList(resnets)
|
694 |
+
|
695 |
+
if add_downsample:
|
696 |
+
self.downsamplers = nn.ModuleList(
|
697 |
+
[
|
698 |
+
Downsample2D(
|
699 |
+
out_channels if len(resnets)>0 else in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
700 |
+
)
|
701 |
+
]
|
702 |
+
)
|
703 |
+
else:
|
704 |
+
self.downsamplers = None
|
705 |
+
|
706 |
+
def forward(self, hidden_states):
|
707 |
+
for resnet in self.resnets:
|
708 |
+
hidden_states = resnet(hidden_states, temb=None)
|
709 |
+
|
710 |
+
if self.downsamplers is not None:
|
711 |
+
for downsampler in self.downsamplers:
|
712 |
+
hidden_states = downsampler(hidden_states)
|
713 |
+
|
714 |
+
return hidden_states
|
715 |
+
|
716 |
+
|
717 |
+
class AttnDownEncoderBlock2D(nn.Module):
|
718 |
+
def __init__(
|
719 |
+
self,
|
720 |
+
in_channels: int,
|
721 |
+
out_channels: int,
|
722 |
+
dropout: float = 0.0,
|
723 |
+
num_layers: int = 1,
|
724 |
+
resnet_eps: float = 1e-6,
|
725 |
+
resnet_time_scale_shift: str = "default",
|
726 |
+
resnet_act_fn: str = "swish",
|
727 |
+
resnet_groups: int = 32,
|
728 |
+
resnet_pre_norm: bool = True,
|
729 |
+
attn_num_head_channels=1,
|
730 |
+
output_scale_factor=1.0,
|
731 |
+
add_downsample=True,
|
732 |
+
downsample_padding=1,
|
733 |
+
):
|
734 |
+
super().__init__()
|
735 |
+
resnets = []
|
736 |
+
attentions = []
|
737 |
+
|
738 |
+
for i in range(num_layers):
|
739 |
+
in_channels = in_channels if i == 0 else out_channels
|
740 |
+
resnets.append(
|
741 |
+
ResnetBlock2D(
|
742 |
+
in_channels=in_channels,
|
743 |
+
out_channels=out_channels,
|
744 |
+
temb_channels=None,
|
745 |
+
eps=resnet_eps,
|
746 |
+
groups=resnet_groups,
|
747 |
+
dropout=dropout,
|
748 |
+
time_embedding_norm=resnet_time_scale_shift,
|
749 |
+
non_linearity=resnet_act_fn,
|
750 |
+
output_scale_factor=output_scale_factor,
|
751 |
+
pre_norm=resnet_pre_norm,
|
752 |
+
)
|
753 |
+
)
|
754 |
+
attentions.append(
|
755 |
+
AttentionBlock(
|
756 |
+
out_channels,
|
757 |
+
num_head_channels=attn_num_head_channels,
|
758 |
+
rescale_output_factor=output_scale_factor,
|
759 |
+
eps=resnet_eps,
|
760 |
+
num_groups=resnet_groups,
|
761 |
+
)
|
762 |
+
)
|
763 |
+
|
764 |
+
self.attentions = nn.ModuleList(attentions)
|
765 |
+
self.resnets = nn.ModuleList(resnets)
|
766 |
+
|
767 |
+
if add_downsample:
|
768 |
+
self.downsamplers = nn.ModuleList(
|
769 |
+
[
|
770 |
+
Downsample2D(
|
771 |
+
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
772 |
+
)
|
773 |
+
]
|
774 |
+
)
|
775 |
+
else:
|
776 |
+
self.downsamplers = None
|
777 |
+
|
778 |
+
def forward(self, hidden_states):
|
779 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
780 |
+
hidden_states = resnet(hidden_states, temb=None)
|
781 |
+
hidden_states = attn(hidden_states)
|
782 |
+
|
783 |
+
if self.downsamplers is not None:
|
784 |
+
for downsampler in self.downsamplers:
|
785 |
+
hidden_states = downsampler(hidden_states)
|
786 |
+
|
787 |
+
return hidden_states
|
788 |
+
|
789 |
+
|
790 |
+
class AttnSkipDownBlock2D(nn.Module):
|
791 |
+
def __init__(
|
792 |
+
self,
|
793 |
+
in_channels: int,
|
794 |
+
out_channels: int,
|
795 |
+
temb_channels: int,
|
796 |
+
dropout: float = 0.0,
|
797 |
+
num_layers: int = 1,
|
798 |
+
resnet_eps: float = 1e-6,
|
799 |
+
resnet_time_scale_shift: str = "default",
|
800 |
+
resnet_act_fn: str = "swish",
|
801 |
+
resnet_pre_norm: bool = True,
|
802 |
+
attn_num_head_channels=1,
|
803 |
+
attention_type="default",
|
804 |
+
output_scale_factor=np.sqrt(2.0),
|
805 |
+
downsample_padding=1,
|
806 |
+
add_downsample=True,
|
807 |
+
):
|
808 |
+
super().__init__()
|
809 |
+
self.attentions = nn.ModuleList([])
|
810 |
+
self.resnets = nn.ModuleList([])
|
811 |
+
|
812 |
+
self.attention_type = attention_type
|
813 |
+
|
814 |
+
for i in range(num_layers):
|
815 |
+
in_channels = in_channels if i == 0 else out_channels
|
816 |
+
self.resnets.append(
|
817 |
+
ResnetBlock2D(
|
818 |
+
in_channels=in_channels,
|
819 |
+
out_channels=out_channels,
|
820 |
+
temb_channels=temb_channels,
|
821 |
+
eps=resnet_eps,
|
822 |
+
groups=min(in_channels // 4, 32),
|
823 |
+
groups_out=min(out_channels // 4, 32),
|
824 |
+
dropout=dropout,
|
825 |
+
time_embedding_norm=resnet_time_scale_shift,
|
826 |
+
non_linearity=resnet_act_fn,
|
827 |
+
output_scale_factor=output_scale_factor,
|
828 |
+
pre_norm=resnet_pre_norm,
|
829 |
+
)
|
830 |
+
)
|
831 |
+
self.attentions.append(
|
832 |
+
AttentionBlock(
|
833 |
+
out_channels,
|
834 |
+
num_head_channels=attn_num_head_channels,
|
835 |
+
rescale_output_factor=output_scale_factor,
|
836 |
+
eps=resnet_eps,
|
837 |
+
)
|
838 |
+
)
|
839 |
+
|
840 |
+
if add_downsample:
|
841 |
+
self.resnet_down = ResnetBlock2D(
|
842 |
+
in_channels=out_channels,
|
843 |
+
out_channels=out_channels,
|
844 |
+
temb_channels=temb_channels,
|
845 |
+
eps=resnet_eps,
|
846 |
+
groups=min(out_channels // 4, 32),
|
847 |
+
dropout=dropout,
|
848 |
+
time_embedding_norm=resnet_time_scale_shift,
|
849 |
+
non_linearity=resnet_act_fn,
|
850 |
+
output_scale_factor=output_scale_factor,
|
851 |
+
pre_norm=resnet_pre_norm,
|
852 |
+
use_in_shortcut=True,
|
853 |
+
down=True,
|
854 |
+
kernel="fir",
|
855 |
+
)
|
856 |
+
self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
|
857 |
+
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
|
858 |
+
else:
|
859 |
+
self.resnet_down = None
|
860 |
+
self.downsamplers = None
|
861 |
+
self.skip_conv = None
|
862 |
+
|
863 |
+
def forward(self, hidden_states, temb=None, skip_sample=None):
|
864 |
+
output_states = ()
|
865 |
+
|
866 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
867 |
+
hidden_states = resnet(hidden_states, temb)
|
868 |
+
hidden_states = attn(hidden_states)
|
869 |
+
output_states += (hidden_states,)
|
870 |
+
|
871 |
+
if self.downsamplers is not None:
|
872 |
+
hidden_states = self.resnet_down(hidden_states, temb)
|
873 |
+
for downsampler in self.downsamplers:
|
874 |
+
skip_sample = downsampler(skip_sample)
|
875 |
+
|
876 |
+
hidden_states = self.skip_conv(skip_sample) + hidden_states
|
877 |
+
|
878 |
+
output_states += (hidden_states,)
|
879 |
+
|
880 |
+
return hidden_states, output_states, skip_sample
|
881 |
+
|
882 |
+
|
883 |
+
class SkipDownBlock2D(nn.Module):
|
884 |
+
def __init__(
|
885 |
+
self,
|
886 |
+
in_channels: int,
|
887 |
+
out_channels: int,
|
888 |
+
temb_channels: int,
|
889 |
+
dropout: float = 0.0,
|
890 |
+
num_layers: int = 1,
|
891 |
+
resnet_eps: float = 1e-6,
|
892 |
+
resnet_time_scale_shift: str = "default",
|
893 |
+
resnet_act_fn: str = "swish",
|
894 |
+
resnet_pre_norm: bool = True,
|
895 |
+
output_scale_factor=np.sqrt(2.0),
|
896 |
+
add_downsample=True,
|
897 |
+
downsample_padding=1,
|
898 |
+
):
|
899 |
+
super().__init__()
|
900 |
+
self.resnets = nn.ModuleList([])
|
901 |
+
|
902 |
+
for i in range(num_layers):
|
903 |
+
in_channels = in_channels if i == 0 else out_channels
|
904 |
+
self.resnets.append(
|
905 |
+
ResnetBlock2D(
|
906 |
+
in_channels=in_channels,
|
907 |
+
out_channels=out_channels,
|
908 |
+
temb_channels=temb_channels,
|
909 |
+
eps=resnet_eps,
|
910 |
+
groups=min(in_channels // 4, 32),
|
911 |
+
groups_out=min(out_channels // 4, 32),
|
912 |
+
dropout=dropout,
|
913 |
+
time_embedding_norm=resnet_time_scale_shift,
|
914 |
+
non_linearity=resnet_act_fn,
|
915 |
+
output_scale_factor=output_scale_factor,
|
916 |
+
pre_norm=resnet_pre_norm,
|
917 |
+
)
|
918 |
+
)
|
919 |
+
|
920 |
+
if add_downsample:
|
921 |
+
self.resnet_down = ResnetBlock2D(
|
922 |
+
in_channels=out_channels,
|
923 |
+
out_channels=out_channels,
|
924 |
+
temb_channels=temb_channels,
|
925 |
+
eps=resnet_eps,
|
926 |
+
groups=min(out_channels // 4, 32),
|
927 |
+
dropout=dropout,
|
928 |
+
time_embedding_norm=resnet_time_scale_shift,
|
929 |
+
non_linearity=resnet_act_fn,
|
930 |
+
output_scale_factor=output_scale_factor,
|
931 |
+
pre_norm=resnet_pre_norm,
|
932 |
+
use_in_shortcut=True,
|
933 |
+
down=True,
|
934 |
+
kernel="fir",
|
935 |
+
)
|
936 |
+
self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
|
937 |
+
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
|
938 |
+
else:
|
939 |
+
self.resnet_down = None
|
940 |
+
self.downsamplers = None
|
941 |
+
self.skip_conv = None
|
942 |
+
|
943 |
+
def forward(self, hidden_states, temb=None, skip_sample=None):
|
944 |
+
output_states = ()
|
945 |
+
|
946 |
+
for resnet in self.resnets:
|
947 |
+
hidden_states = resnet(hidden_states, temb)
|
948 |
+
output_states += (hidden_states,)
|
949 |
+
|
950 |
+
if self.downsamplers is not None:
|
951 |
+
hidden_states = self.resnet_down(hidden_states, temb)
|
952 |
+
for downsampler in self.downsamplers:
|
953 |
+
skip_sample = downsampler(skip_sample)
|
954 |
+
|
955 |
+
hidden_states = self.skip_conv(skip_sample) + hidden_states
|
956 |
+
|
957 |
+
output_states += (hidden_states,)
|
958 |
+
|
959 |
+
return hidden_states, output_states, skip_sample
|
960 |
+
|
961 |
+
|
962 |
+
class AttnUpBlock2D(nn.Module):
|
963 |
+
def __init__(
|
964 |
+
self,
|
965 |
+
in_channels: int,
|
966 |
+
prev_output_channel: int,
|
967 |
+
out_channels: int,
|
968 |
+
temb_channels: int,
|
969 |
+
dropout: float = 0.0,
|
970 |
+
num_layers: int = 1,
|
971 |
+
resnet_eps: float = 1e-6,
|
972 |
+
resnet_time_scale_shift: str = "default",
|
973 |
+
resnet_act_fn: str = "swish",
|
974 |
+
resnet_groups: int = 32,
|
975 |
+
resnet_pre_norm: bool = True,
|
976 |
+
attention_type="default",
|
977 |
+
attn_num_head_channels=1,
|
978 |
+
output_scale_factor=1.0,
|
979 |
+
add_upsample=True,
|
980 |
+
):
|
981 |
+
super().__init__()
|
982 |
+
resnets = []
|
983 |
+
attentions = []
|
984 |
+
|
985 |
+
self.attention_type = attention_type
|
986 |
+
|
987 |
+
for i in range(num_layers):
|
988 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
989 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
990 |
+
|
991 |
+
resnets.append(
|
992 |
+
ResnetBlock2D(
|
993 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
994 |
+
out_channels=out_channels,
|
995 |
+
temb_channels=temb_channels,
|
996 |
+
eps=resnet_eps,
|
997 |
+
groups=resnet_groups,
|
998 |
+
dropout=dropout,
|
999 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1000 |
+
non_linearity=resnet_act_fn,
|
1001 |
+
output_scale_factor=output_scale_factor,
|
1002 |
+
pre_norm=resnet_pre_norm,
|
1003 |
+
)
|
1004 |
+
)
|
1005 |
+
attentions.append(
|
1006 |
+
AttentionBlock(
|
1007 |
+
out_channels,
|
1008 |
+
num_head_channels=attn_num_head_channels,
|
1009 |
+
rescale_output_factor=output_scale_factor,
|
1010 |
+
eps=resnet_eps,
|
1011 |
+
num_groups=resnet_groups,
|
1012 |
+
)
|
1013 |
+
)
|
1014 |
+
|
1015 |
+
self.attentions = nn.ModuleList(attentions)
|
1016 |
+
self.resnets = nn.ModuleList(resnets)
|
1017 |
+
|
1018 |
+
if add_upsample:
|
1019 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
1020 |
+
else:
|
1021 |
+
self.upsamplers = None
|
1022 |
+
|
1023 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
|
1024 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
1025 |
+
# pop res hidden states
|
1026 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1027 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1028 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1029 |
+
|
1030 |
+
hidden_states = resnet(hidden_states, temb)
|
1031 |
+
hidden_states = attn(hidden_states)
|
1032 |
+
|
1033 |
+
if self.upsamplers is not None:
|
1034 |
+
for upsampler in self.upsamplers:
|
1035 |
+
hidden_states = upsampler(hidden_states)
|
1036 |
+
|
1037 |
+
return hidden_states
|
1038 |
+
|
1039 |
+
|
1040 |
+
class CrossAttnUpBlock2D(nn.Module):
|
1041 |
+
def __init__(
|
1042 |
+
self,
|
1043 |
+
in_channels: int,
|
1044 |
+
out_channels: int,
|
1045 |
+
prev_output_channel: int,
|
1046 |
+
temb_channels: int,
|
1047 |
+
dropout: float = 0.0,
|
1048 |
+
num_layers: int = 1,
|
1049 |
+
resnet_eps: float = 1e-6,
|
1050 |
+
resnet_time_scale_shift: str = "default",
|
1051 |
+
resnet_act_fn: str = "swish",
|
1052 |
+
resnet_groups: int = 32,
|
1053 |
+
resnet_pre_norm: bool = True,
|
1054 |
+
attn_num_head_channels=1,
|
1055 |
+
cross_attention_dim=1280,
|
1056 |
+
attention_type="default",
|
1057 |
+
output_scale_factor=1.0,
|
1058 |
+
downsample_padding=1,
|
1059 |
+
add_upsample=True,
|
1060 |
+
):
|
1061 |
+
super().__init__()
|
1062 |
+
resnets = []
|
1063 |
+
attentions = []
|
1064 |
+
|
1065 |
+
self.attention_type = attention_type
|
1066 |
+
self.attn_num_head_channels = attn_num_head_channels
|
1067 |
+
|
1068 |
+
for i in range(num_layers):
|
1069 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1070 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1071 |
+
|
1072 |
+
resnets.append(
|
1073 |
+
ResnetBlock2D(
|
1074 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1075 |
+
out_channels=out_channels,
|
1076 |
+
temb_channels=temb_channels,
|
1077 |
+
eps=resnet_eps,
|
1078 |
+
groups=resnet_groups,
|
1079 |
+
dropout=dropout,
|
1080 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1081 |
+
non_linearity=resnet_act_fn,
|
1082 |
+
output_scale_factor=output_scale_factor,
|
1083 |
+
pre_norm=resnet_pre_norm,
|
1084 |
+
)
|
1085 |
+
)
|
1086 |
+
attentions.append(
|
1087 |
+
SpatialTransformer(
|
1088 |
+
out_channels,
|
1089 |
+
attn_num_head_channels,
|
1090 |
+
out_channels // attn_num_head_channels,
|
1091 |
+
depth=1,
|
1092 |
+
context_dim=cross_attention_dim,
|
1093 |
+
num_groups=resnet_groups,
|
1094 |
+
)
|
1095 |
+
)
|
1096 |
+
self.attentions = nn.ModuleList(attentions)
|
1097 |
+
self.resnets = nn.ModuleList(resnets)
|
1098 |
+
|
1099 |
+
if add_upsample:
|
1100 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
1101 |
+
else:
|
1102 |
+
self.upsamplers = None
|
1103 |
+
|
1104 |
+
self.gradient_checkpointing = False
|
1105 |
+
|
1106 |
+
def set_attention_slice(self, slice_size):
|
1107 |
+
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
|
1108 |
+
raise ValueError(
|
1109 |
+
f"Make sure slice_size {slice_size} is a divisor of "
|
1110 |
+
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
1111 |
+
)
|
1112 |
+
if slice_size is not None and slice_size > self.attn_num_head_channels:
|
1113 |
+
raise ValueError(
|
1114 |
+
f"Chunk_size {slice_size} has to be smaller or equal to "
|
1115 |
+
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
1116 |
+
)
|
1117 |
+
|
1118 |
+
for attn in self.attentions:
|
1119 |
+
attn._set_attention_slice(slice_size)
|
1120 |
+
|
1121 |
+
self.gradient_checkpointing = False
|
1122 |
+
|
1123 |
+
def forward(
|
1124 |
+
self,
|
1125 |
+
hidden_states,
|
1126 |
+
res_hidden_states_tuple,
|
1127 |
+
temb=None,
|
1128 |
+
encoder_hidden_states=None,
|
1129 |
+
):
|
1130 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
1131 |
+
# pop res hidden states
|
1132 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1133 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1134 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1135 |
+
|
1136 |
+
if self.training and self.gradient_checkpointing:
|
1137 |
+
|
1138 |
+
def create_custom_forward(module):
|
1139 |
+
def custom_forward(*inputs):
|
1140 |
+
return module(*inputs)
|
1141 |
+
|
1142 |
+
return custom_forward
|
1143 |
+
|
1144 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
1145 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1146 |
+
create_custom_forward(attn), hidden_states, encoder_hidden_states
|
1147 |
+
)
|
1148 |
+
else:
|
1149 |
+
hidden_states = resnet(hidden_states, temb)
|
1150 |
+
hidden_states = attn(hidden_states, context=encoder_hidden_states)
|
1151 |
+
|
1152 |
+
if self.upsamplers is not None:
|
1153 |
+
for upsampler in self.upsamplers:
|
1154 |
+
hidden_states = upsampler(hidden_states)
|
1155 |
+
|
1156 |
+
return hidden_states
|
1157 |
+
|
1158 |
+
|
1159 |
+
class UpBlock2D(nn.Module):
|
1160 |
+
def __init__(
|
1161 |
+
self,
|
1162 |
+
in_channels: int,
|
1163 |
+
prev_output_channel: int,
|
1164 |
+
out_channels: int,
|
1165 |
+
temb_channels: int,
|
1166 |
+
dropout: float = 0.0,
|
1167 |
+
num_layers: int = 1,
|
1168 |
+
resnet_eps: float = 1e-6,
|
1169 |
+
resnet_time_scale_shift: str = "default",
|
1170 |
+
resnet_act_fn: str = "swish",
|
1171 |
+
resnet_groups: int = 32,
|
1172 |
+
resnet_pre_norm: bool = True,
|
1173 |
+
output_scale_factor=1.0,
|
1174 |
+
add_upsample=True,
|
1175 |
+
):
|
1176 |
+
super().__init__()
|
1177 |
+
resnets = []
|
1178 |
+
|
1179 |
+
for i in range(num_layers):
|
1180 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1181 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1182 |
+
|
1183 |
+
resnets.append(
|
1184 |
+
ResnetBlock2D(
|
1185 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1186 |
+
out_channels=out_channels,
|
1187 |
+
temb_channels=temb_channels,
|
1188 |
+
eps=resnet_eps,
|
1189 |
+
groups=resnet_groups,
|
1190 |
+
dropout=dropout,
|
1191 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1192 |
+
non_linearity=resnet_act_fn,
|
1193 |
+
output_scale_factor=output_scale_factor,
|
1194 |
+
pre_norm=resnet_pre_norm,
|
1195 |
+
)
|
1196 |
+
)
|
1197 |
+
|
1198 |
+
self.resnets = nn.ModuleList(resnets)
|
1199 |
+
|
1200 |
+
if add_upsample:
|
1201 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
1202 |
+
else:
|
1203 |
+
self.upsamplers = None
|
1204 |
+
|
1205 |
+
self.gradient_checkpointing = False
|
1206 |
+
|
1207 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
|
1208 |
+
for resnet in self.resnets:
|
1209 |
+
# pop res hidden states
|
1210 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1211 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1212 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1213 |
+
|
1214 |
+
if self.training and self.gradient_checkpointing:
|
1215 |
+
|
1216 |
+
def create_custom_forward(module):
|
1217 |
+
def custom_forward(*inputs):
|
1218 |
+
return module(*inputs)
|
1219 |
+
|
1220 |
+
return custom_forward
|
1221 |
+
|
1222 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
1223 |
+
else:
|
1224 |
+
hidden_states = resnet(hidden_states, temb)
|
1225 |
+
|
1226 |
+
if self.upsamplers is not None:
|
1227 |
+
for upsampler in self.upsamplers:
|
1228 |
+
hidden_states = upsampler(hidden_states)
|
1229 |
+
|
1230 |
+
return hidden_states
|
1231 |
+
|
1232 |
+
|
1233 |
+
class UpDecoderBlock2D(nn.Module):
|
1234 |
+
def __init__(
|
1235 |
+
self,
|
1236 |
+
in_channels: int,
|
1237 |
+
out_channels: int,
|
1238 |
+
dropout: float = 0.0,
|
1239 |
+
num_layers: int = 1,
|
1240 |
+
resnet_eps: float = 1e-6,
|
1241 |
+
resnet_time_scale_shift: str = "default",
|
1242 |
+
resnet_act_fn: str = "swish",
|
1243 |
+
resnet_groups: int = 32,
|
1244 |
+
resnet_pre_norm: bool = True,
|
1245 |
+
output_scale_factor=1.0,
|
1246 |
+
add_upsample=True,
|
1247 |
+
):
|
1248 |
+
super().__init__()
|
1249 |
+
resnets = []
|
1250 |
+
|
1251 |
+
for i in range(num_layers):
|
1252 |
+
input_channels = in_channels if i == 0 else out_channels
|
1253 |
+
|
1254 |
+
resnets.append(
|
1255 |
+
ResnetBlock2D(
|
1256 |
+
in_channels=input_channels,
|
1257 |
+
out_channels=out_channels,
|
1258 |
+
temb_channels=None,
|
1259 |
+
eps=resnet_eps,
|
1260 |
+
groups=resnet_groups,
|
1261 |
+
dropout=dropout,
|
1262 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1263 |
+
non_linearity=resnet_act_fn,
|
1264 |
+
output_scale_factor=output_scale_factor,
|
1265 |
+
pre_norm=resnet_pre_norm,
|
1266 |
+
)
|
1267 |
+
)
|
1268 |
+
|
1269 |
+
self.resnets = nn.ModuleList(resnets)
|
1270 |
+
|
1271 |
+
if add_upsample:
|
1272 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
1273 |
+
else:
|
1274 |
+
self.upsamplers = None
|
1275 |
+
|
1276 |
+
def forward(self, hidden_states):
|
1277 |
+
for resnet in self.resnets:
|
1278 |
+
hidden_states = resnet(hidden_states, temb=None)
|
1279 |
+
|
1280 |
+
if self.upsamplers is not None:
|
1281 |
+
for upsampler in self.upsamplers:
|
1282 |
+
hidden_states = upsampler(hidden_states)
|
1283 |
+
|
1284 |
+
return hidden_states
|
1285 |
+
|
1286 |
+
|
1287 |
+
class AttnUpDecoderBlock2D(nn.Module):
|
1288 |
+
def __init__(
|
1289 |
+
self,
|
1290 |
+
in_channels: int,
|
1291 |
+
out_channels: int,
|
1292 |
+
dropout: float = 0.0,
|
1293 |
+
num_layers: int = 1,
|
1294 |
+
resnet_eps: float = 1e-6,
|
1295 |
+
resnet_time_scale_shift: str = "default",
|
1296 |
+
resnet_act_fn: str = "swish",
|
1297 |
+
resnet_groups: int = 32,
|
1298 |
+
resnet_pre_norm: bool = True,
|
1299 |
+
attn_num_head_channels=1,
|
1300 |
+
output_scale_factor=1.0,
|
1301 |
+
add_upsample=True,
|
1302 |
+
):
|
1303 |
+
super().__init__()
|
1304 |
+
resnets = []
|
1305 |
+
attentions = []
|
1306 |
+
|
1307 |
+
for i in range(num_layers):
|
1308 |
+
input_channels = in_channels if i == 0 else out_channels
|
1309 |
+
|
1310 |
+
resnets.append(
|
1311 |
+
ResnetBlock2D(
|
1312 |
+
in_channels=input_channels,
|
1313 |
+
out_channels=out_channels,
|
1314 |
+
temb_channels=None,
|
1315 |
+
eps=resnet_eps,
|
1316 |
+
groups=resnet_groups,
|
1317 |
+
dropout=dropout,
|
1318 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1319 |
+
non_linearity=resnet_act_fn,
|
1320 |
+
output_scale_factor=output_scale_factor,
|
1321 |
+
pre_norm=resnet_pre_norm,
|
1322 |
+
)
|
1323 |
+
)
|
1324 |
+
attentions.append(
|
1325 |
+
AttentionBlock(
|
1326 |
+
out_channels,
|
1327 |
+
num_head_channels=attn_num_head_channels,
|
1328 |
+
rescale_output_factor=output_scale_factor,
|
1329 |
+
eps=resnet_eps,
|
1330 |
+
num_groups=resnet_groups,
|
1331 |
+
)
|
1332 |
+
)
|
1333 |
+
|
1334 |
+
self.attentions = nn.ModuleList(attentions)
|
1335 |
+
self.resnets = nn.ModuleList(resnets)
|
1336 |
+
|
1337 |
+
if add_upsample:
|
1338 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
1339 |
+
else:
|
1340 |
+
self.upsamplers = None
|
1341 |
+
|
1342 |
+
def forward(self, hidden_states):
|
1343 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
1344 |
+
hidden_states = resnet(hidden_states, temb=None)
|
1345 |
+
hidden_states = attn(hidden_states)
|
1346 |
+
|
1347 |
+
if self.upsamplers is not None:
|
1348 |
+
for upsampler in self.upsamplers:
|
1349 |
+
hidden_states = upsampler(hidden_states)
|
1350 |
+
|
1351 |
+
return hidden_states
|
1352 |
+
|
1353 |
+
|
1354 |
+
class AttnSkipUpBlock2D(nn.Module):
|
1355 |
+
def __init__(
|
1356 |
+
self,
|
1357 |
+
in_channels: int,
|
1358 |
+
prev_output_channel: int,
|
1359 |
+
out_channels: int,
|
1360 |
+
temb_channels: int,
|
1361 |
+
dropout: float = 0.0,
|
1362 |
+
num_layers: int = 1,
|
1363 |
+
resnet_eps: float = 1e-6,
|
1364 |
+
resnet_time_scale_shift: str = "default",
|
1365 |
+
resnet_act_fn: str = "swish",
|
1366 |
+
resnet_pre_norm: bool = True,
|
1367 |
+
attn_num_head_channels=1,
|
1368 |
+
attention_type="default",
|
1369 |
+
output_scale_factor=np.sqrt(2.0),
|
1370 |
+
upsample_padding=1,
|
1371 |
+
add_upsample=True,
|
1372 |
+
):
|
1373 |
+
super().__init__()
|
1374 |
+
self.attentions = nn.ModuleList([])
|
1375 |
+
self.resnets = nn.ModuleList([])
|
1376 |
+
|
1377 |
+
self.attention_type = attention_type
|
1378 |
+
|
1379 |
+
for i in range(num_layers):
|
1380 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1381 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1382 |
+
|
1383 |
+
self.resnets.append(
|
1384 |
+
ResnetBlock2D(
|
1385 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1386 |
+
out_channels=out_channels,
|
1387 |
+
temb_channels=temb_channels,
|
1388 |
+
eps=resnet_eps,
|
1389 |
+
groups=min(resnet_in_channels + res_skip_channels // 4, 32),
|
1390 |
+
groups_out=min(out_channels // 4, 32),
|
1391 |
+
dropout=dropout,
|
1392 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1393 |
+
non_linearity=resnet_act_fn,
|
1394 |
+
output_scale_factor=output_scale_factor,
|
1395 |
+
pre_norm=resnet_pre_norm,
|
1396 |
+
)
|
1397 |
+
)
|
1398 |
+
|
1399 |
+
self.attentions.append(
|
1400 |
+
AttentionBlock(
|
1401 |
+
out_channels,
|
1402 |
+
num_head_channels=attn_num_head_channels,
|
1403 |
+
rescale_output_factor=output_scale_factor,
|
1404 |
+
eps=resnet_eps,
|
1405 |
+
)
|
1406 |
+
)
|
1407 |
+
|
1408 |
+
self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
|
1409 |
+
if add_upsample:
|
1410 |
+
self.resnet_up = ResnetBlock2D(
|
1411 |
+
in_channels=out_channels,
|
1412 |
+
out_channels=out_channels,
|
1413 |
+
temb_channels=temb_channels,
|
1414 |
+
eps=resnet_eps,
|
1415 |
+
groups=min(out_channels // 4, 32),
|
1416 |
+
groups_out=min(out_channels // 4, 32),
|
1417 |
+
dropout=dropout,
|
1418 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1419 |
+
non_linearity=resnet_act_fn,
|
1420 |
+
output_scale_factor=output_scale_factor,
|
1421 |
+
pre_norm=resnet_pre_norm,
|
1422 |
+
use_in_shortcut=True,
|
1423 |
+
up=True,
|
1424 |
+
kernel="fir",
|
1425 |
+
)
|
1426 |
+
self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
1427 |
+
self.skip_norm = torch.nn.GroupNorm(
|
1428 |
+
num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
|
1429 |
+
)
|
1430 |
+
self.act = nn.SiLU()
|
1431 |
+
else:
|
1432 |
+
self.resnet_up = None
|
1433 |
+
self.skip_conv = None
|
1434 |
+
self.skip_norm = None
|
1435 |
+
self.act = None
|
1436 |
+
|
1437 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
|
1438 |
+
for resnet in self.resnets:
|
1439 |
+
# pop res hidden states
|
1440 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1441 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1442 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1443 |
+
|
1444 |
+
hidden_states = resnet(hidden_states, temb)
|
1445 |
+
|
1446 |
+
hidden_states = self.attentions[0](hidden_states)
|
1447 |
+
|
1448 |
+
if skip_sample is not None:
|
1449 |
+
skip_sample = self.upsampler(skip_sample)
|
1450 |
+
else:
|
1451 |
+
skip_sample = 0
|
1452 |
+
|
1453 |
+
if self.resnet_up is not None:
|
1454 |
+
skip_sample_states = self.skip_norm(hidden_states)
|
1455 |
+
skip_sample_states = self.act(skip_sample_states)
|
1456 |
+
skip_sample_states = self.skip_conv(skip_sample_states)
|
1457 |
+
|
1458 |
+
skip_sample = skip_sample + skip_sample_states
|
1459 |
+
|
1460 |
+
hidden_states = self.resnet_up(hidden_states, temb)
|
1461 |
+
|
1462 |
+
return hidden_states, skip_sample
|
1463 |
+
|
1464 |
+
|
1465 |
+
class SkipUpBlock2D(nn.Module):
|
1466 |
+
def __init__(
|
1467 |
+
self,
|
1468 |
+
in_channels: int,
|
1469 |
+
prev_output_channel: int,
|
1470 |
+
out_channels: int,
|
1471 |
+
temb_channels: int,
|
1472 |
+
dropout: float = 0.0,
|
1473 |
+
num_layers: int = 1,
|
1474 |
+
resnet_eps: float = 1e-6,
|
1475 |
+
resnet_time_scale_shift: str = "default",
|
1476 |
+
resnet_act_fn: str = "swish",
|
1477 |
+
resnet_pre_norm: bool = True,
|
1478 |
+
output_scale_factor=np.sqrt(2.0),
|
1479 |
+
add_upsample=True,
|
1480 |
+
upsample_padding=1,
|
1481 |
+
):
|
1482 |
+
super().__init__()
|
1483 |
+
self.resnets = nn.ModuleList([])
|
1484 |
+
|
1485 |
+
for i in range(num_layers):
|
1486 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
1487 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1488 |
+
|
1489 |
+
self.resnets.append(
|
1490 |
+
ResnetBlock2D(
|
1491 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1492 |
+
out_channels=out_channels,
|
1493 |
+
temb_channels=temb_channels,
|
1494 |
+
eps=resnet_eps,
|
1495 |
+
groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
|
1496 |
+
groups_out=min(out_channels // 4, 32),
|
1497 |
+
dropout=dropout,
|
1498 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1499 |
+
non_linearity=resnet_act_fn,
|
1500 |
+
output_scale_factor=output_scale_factor,
|
1501 |
+
pre_norm=resnet_pre_norm,
|
1502 |
+
)
|
1503 |
+
)
|
1504 |
+
|
1505 |
+
self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
|
1506 |
+
if add_upsample:
|
1507 |
+
self.resnet_up = ResnetBlock2D(
|
1508 |
+
in_channels=out_channels,
|
1509 |
+
out_channels=out_channels,
|
1510 |
+
temb_channels=temb_channels,
|
1511 |
+
eps=resnet_eps,
|
1512 |
+
groups=min(out_channels // 4, 32),
|
1513 |
+
groups_out=min(out_channels // 4, 32),
|
1514 |
+
dropout=dropout,
|
1515 |
+
time_embedding_norm=resnet_time_scale_shift,
|
1516 |
+
non_linearity=resnet_act_fn,
|
1517 |
+
output_scale_factor=output_scale_factor,
|
1518 |
+
pre_norm=resnet_pre_norm,
|
1519 |
+
use_in_shortcut=True,
|
1520 |
+
up=True,
|
1521 |
+
kernel="fir",
|
1522 |
+
)
|
1523 |
+
self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
1524 |
+
self.skip_norm = torch.nn.GroupNorm(
|
1525 |
+
num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
|
1526 |
+
)
|
1527 |
+
self.act = nn.SiLU()
|
1528 |
+
else:
|
1529 |
+
self.resnet_up = None
|
1530 |
+
self.skip_conv = None
|
1531 |
+
self.skip_norm = None
|
1532 |
+
self.act = None
|
1533 |
+
|
1534 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
|
1535 |
+
for resnet in self.resnets:
|
1536 |
+
# pop res hidden states
|
1537 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1538 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1539 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1540 |
+
|
1541 |
+
hidden_states = resnet(hidden_states, temb)
|
1542 |
+
|
1543 |
+
if skip_sample is not None:
|
1544 |
+
skip_sample = self.upsampler(skip_sample)
|
1545 |
+
else:
|
1546 |
+
skip_sample = 0
|
1547 |
+
|
1548 |
+
if self.resnet_up is not None:
|
1549 |
+
skip_sample_states = self.skip_norm(hidden_states)
|
1550 |
+
skip_sample_states = self.act(skip_sample_states)
|
1551 |
+
skip_sample_states = self.skip_conv(skip_sample_states)
|
1552 |
+
|
1553 |
+
skip_sample = skip_sample + skip_sample_states
|
1554 |
+
|
1555 |
+
hidden_states = self.resnet_up(hidden_states, temb)
|
1556 |
+
|
1557 |
+
return hidden_states, skip_sample
|
medical_diffusion/external/diffusers/vae.py
ADDED
@@ -0,0 +1,857 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
from typing import Optional, Tuple, Union
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from itertools import chain
|
11 |
+
|
12 |
+
from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
13 |
+
from .taming_discriminator import NLayerDiscriminator
|
14 |
+
from medical_diffusion.models import BasicModel
|
15 |
+
from torchvision.utils import save_image
|
16 |
+
|
17 |
+
from torch.distributions.normal import Normal
|
18 |
+
from torch.distributions import kl_divergence
|
19 |
+
|
20 |
+
class Encoder(nn.Module):
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
in_channels=3,
|
24 |
+
out_channels=3,
|
25 |
+
down_block_types=("DownEncoderBlock2D",),
|
26 |
+
block_out_channels=(64),
|
27 |
+
layers_per_block=2,
|
28 |
+
norm_num_groups=32,
|
29 |
+
act_fn="silu",
|
30 |
+
double_z=True,
|
31 |
+
):
|
32 |
+
super().__init__()
|
33 |
+
self.layers_per_block = layers_per_block
|
34 |
+
|
35 |
+
self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
36 |
+
|
37 |
+
self.mid_block = None
|
38 |
+
self.down_blocks = nn.ModuleList([])
|
39 |
+
|
40 |
+
# down
|
41 |
+
output_channel = block_out_channels[0]
|
42 |
+
for i, down_block_type in enumerate(down_block_types):
|
43 |
+
input_channel = output_channel
|
44 |
+
output_channel = block_out_channels[i+1]
|
45 |
+
is_final_block = False #i == len(block_out_channels) - 1
|
46 |
+
|
47 |
+
down_block = get_down_block(
|
48 |
+
down_block_type,
|
49 |
+
num_layers=self.layers_per_block,
|
50 |
+
in_channels=input_channel,
|
51 |
+
out_channels=output_channel,
|
52 |
+
add_downsample=not is_final_block,
|
53 |
+
resnet_eps=1e-6,
|
54 |
+
downsample_padding=0,
|
55 |
+
resnet_act_fn=act_fn,
|
56 |
+
resnet_groups=norm_num_groups,
|
57 |
+
attn_num_head_channels=None,
|
58 |
+
temb_channels=None,
|
59 |
+
)
|
60 |
+
self.down_blocks.append(down_block)
|
61 |
+
|
62 |
+
# mid
|
63 |
+
self.mid_block = UNetMidBlock2D(
|
64 |
+
in_channels=block_out_channels[-1],
|
65 |
+
resnet_eps=1e-6,
|
66 |
+
resnet_act_fn=act_fn,
|
67 |
+
output_scale_factor=1,
|
68 |
+
resnet_time_scale_shift="default",
|
69 |
+
attn_num_head_channels=None,
|
70 |
+
resnet_groups=norm_num_groups,
|
71 |
+
temb_channels=None,
|
72 |
+
)
|
73 |
+
|
74 |
+
# out
|
75 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
76 |
+
self.conv_act = nn.SiLU()
|
77 |
+
|
78 |
+
conv_out_channels = 2 * out_channels if double_z else out_channels
|
79 |
+
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
|
80 |
+
|
81 |
+
def forward(self, x):
|
82 |
+
sample = x
|
83 |
+
sample = self.conv_in(sample)
|
84 |
+
|
85 |
+
# down
|
86 |
+
for down_block in self.down_blocks:
|
87 |
+
sample = down_block(sample)
|
88 |
+
|
89 |
+
# middle
|
90 |
+
sample = self.mid_block(sample)
|
91 |
+
|
92 |
+
# post-process
|
93 |
+
sample = self.conv_norm_out(sample)
|
94 |
+
sample = self.conv_act(sample)
|
95 |
+
sample = self.conv_out(sample)
|
96 |
+
|
97 |
+
return sample
|
98 |
+
|
99 |
+
|
100 |
+
class Decoder(nn.Module):
|
101 |
+
def __init__(
|
102 |
+
self,
|
103 |
+
in_channels=3,
|
104 |
+
out_channels=3,
|
105 |
+
up_block_types=("UpDecoderBlock2D",),
|
106 |
+
block_out_channels=(64,),
|
107 |
+
layers_per_block=2,
|
108 |
+
norm_num_groups=32,
|
109 |
+
act_fn="silu",
|
110 |
+
):
|
111 |
+
super().__init__()
|
112 |
+
self.layers_per_block = layers_per_block
|
113 |
+
|
114 |
+
self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
|
115 |
+
|
116 |
+
self.mid_block = None
|
117 |
+
self.up_blocks = nn.ModuleList([])
|
118 |
+
|
119 |
+
# mid
|
120 |
+
self.mid_block = UNetMidBlock2D(
|
121 |
+
in_channels=block_out_channels[-1],
|
122 |
+
resnet_eps=1e-6,
|
123 |
+
resnet_act_fn=act_fn,
|
124 |
+
output_scale_factor=1,
|
125 |
+
resnet_time_scale_shift="default",
|
126 |
+
attn_num_head_channels=None,
|
127 |
+
resnet_groups=norm_num_groups,
|
128 |
+
temb_channels=None,
|
129 |
+
)
|
130 |
+
|
131 |
+
# up
|
132 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
133 |
+
output_channel = reversed_block_out_channels[0]
|
134 |
+
for i, up_block_type in enumerate(up_block_types):
|
135 |
+
prev_output_channel = output_channel
|
136 |
+
output_channel = reversed_block_out_channels[i+1]
|
137 |
+
|
138 |
+
is_final_block = False # i == len(block_out_channels) - 1
|
139 |
+
|
140 |
+
up_block = get_up_block(
|
141 |
+
up_block_type,
|
142 |
+
num_layers=self.layers_per_block + 1,
|
143 |
+
in_channels=prev_output_channel,
|
144 |
+
out_channels=output_channel,
|
145 |
+
prev_output_channel=None,
|
146 |
+
add_upsample=not is_final_block,
|
147 |
+
resnet_eps=1e-6,
|
148 |
+
resnet_act_fn=act_fn,
|
149 |
+
resnet_groups=norm_num_groups,
|
150 |
+
attn_num_head_channels=None,
|
151 |
+
temb_channels=None,
|
152 |
+
)
|
153 |
+
self.up_blocks.append(up_block)
|
154 |
+
prev_output_channel = output_channel
|
155 |
+
|
156 |
+
# out
|
157 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
158 |
+
self.conv_act = nn.SiLU()
|
159 |
+
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
160 |
+
|
161 |
+
def forward(self, z):
|
162 |
+
sample = z
|
163 |
+
sample = self.conv_in(sample)
|
164 |
+
|
165 |
+
# middle
|
166 |
+
sample = self.mid_block(sample)
|
167 |
+
|
168 |
+
# up
|
169 |
+
for up_block in self.up_blocks:
|
170 |
+
sample = up_block(sample)
|
171 |
+
|
172 |
+
# post-process
|
173 |
+
sample = self.conv_norm_out(sample)
|
174 |
+
sample = self.conv_act(sample)
|
175 |
+
sample = self.conv_out(sample)
|
176 |
+
|
177 |
+
return sample
|
178 |
+
|
179 |
+
|
180 |
+
class VectorQuantizer(nn.Module):
|
181 |
+
"""
|
182 |
+
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
|
183 |
+
multiplications and allows for post-hoc remapping of indices.
|
184 |
+
"""
|
185 |
+
|
186 |
+
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
187 |
+
# backwards compatibility we use the buggy version by default, but you can
|
188 |
+
# specify legacy=False to fix it.
|
189 |
+
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=False):
|
190 |
+
super().__init__()
|
191 |
+
self.n_e = n_e
|
192 |
+
self.e_dim = e_dim
|
193 |
+
self.beta = beta
|
194 |
+
self.legacy = legacy
|
195 |
+
|
196 |
+
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
197 |
+
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
198 |
+
|
199 |
+
self.remap = remap
|
200 |
+
if self.remap is not None:
|
201 |
+
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
202 |
+
self.re_embed = self.used.shape[0]
|
203 |
+
self.unknown_index = unknown_index # "random" or "extra" or integer
|
204 |
+
if self.unknown_index == "extra":
|
205 |
+
self.unknown_index = self.re_embed
|
206 |
+
self.re_embed = self.re_embed + 1
|
207 |
+
print(
|
208 |
+
f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
209 |
+
f"Using {self.unknown_index} for unknown indices."
|
210 |
+
)
|
211 |
+
else:
|
212 |
+
self.re_embed = n_e
|
213 |
+
|
214 |
+
self.sane_index_shape = sane_index_shape
|
215 |
+
|
216 |
+
def remap_to_used(self, inds):
|
217 |
+
ishape = inds.shape
|
218 |
+
assert len(ishape) > 1
|
219 |
+
inds = inds.reshape(ishape[0], -1)
|
220 |
+
used = self.used.to(inds)
|
221 |
+
match = (inds[:, :, None] == used[None, None, ...]).long()
|
222 |
+
new = match.argmax(-1)
|
223 |
+
unknown = match.sum(2) < 1
|
224 |
+
if self.unknown_index == "random":
|
225 |
+
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
|
226 |
+
else:
|
227 |
+
new[unknown] = self.unknown_index
|
228 |
+
return new.reshape(ishape)
|
229 |
+
|
230 |
+
def unmap_to_all(self, inds):
|
231 |
+
ishape = inds.shape
|
232 |
+
assert len(ishape) > 1
|
233 |
+
inds = inds.reshape(ishape[0], -1)
|
234 |
+
used = self.used.to(inds)
|
235 |
+
if self.re_embed > self.used.shape[0]: # extra token
|
236 |
+
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
|
237 |
+
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
238 |
+
return back.reshape(ishape)
|
239 |
+
|
240 |
+
def forward(self, z):
|
241 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
242 |
+
z = z.permute(0, 2, 3, 1).contiguous()
|
243 |
+
z_flattened = z.view(-1, self.e_dim)
|
244 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
245 |
+
|
246 |
+
d = (
|
247 |
+
torch.sum(z_flattened**2, dim=1, keepdim=True)
|
248 |
+
+ torch.sum(self.embedding.weight**2, dim=1)
|
249 |
+
- 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t())
|
250 |
+
)
|
251 |
+
|
252 |
+
min_encoding_indices = torch.argmin(d, dim=1)
|
253 |
+
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
254 |
+
perplexity = None
|
255 |
+
min_encodings = None
|
256 |
+
|
257 |
+
# compute loss for embedding
|
258 |
+
if not self.legacy:
|
259 |
+
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
|
260 |
+
else:
|
261 |
+
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
|
262 |
+
|
263 |
+
# preserve gradients
|
264 |
+
z_q = z + (z_q - z).detach()
|
265 |
+
|
266 |
+
# reshape back to match original input shape
|
267 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
268 |
+
|
269 |
+
if self.remap is not None:
|
270 |
+
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
|
271 |
+
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
272 |
+
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
273 |
+
|
274 |
+
if self.sane_index_shape:
|
275 |
+
min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
|
276 |
+
|
277 |
+
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
278 |
+
|
279 |
+
def get_codebook_entry(self, indices, shape):
|
280 |
+
# shape specifying (batch, height, width, channel)
|
281 |
+
if self.remap is not None:
|
282 |
+
indices = indices.reshape(shape[0], -1) # add batch axis
|
283 |
+
indices = self.unmap_to_all(indices)
|
284 |
+
indices = indices.reshape(-1) # flatten again
|
285 |
+
|
286 |
+
# get quantized latent vectors
|
287 |
+
z_q = self.embedding(indices)
|
288 |
+
|
289 |
+
if shape is not None:
|
290 |
+
z_q = z_q.view(shape)
|
291 |
+
# reshape back to match original input shape
|
292 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
293 |
+
|
294 |
+
return z_q
|
295 |
+
|
296 |
+
|
297 |
+
class DiagonalGaussianDistribution(object):
|
298 |
+
def __init__(self, parameters, deterministic=False):
|
299 |
+
self.batch_size = parameters.shape[0]
|
300 |
+
self.parameters = parameters
|
301 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
302 |
+
# self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
303 |
+
self.deterministic = deterministic
|
304 |
+
self.std = torch.exp(0.5 * self.logvar)
|
305 |
+
self.var = torch.exp(self.logvar)
|
306 |
+
if self.deterministic:
|
307 |
+
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
308 |
+
|
309 |
+
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
|
310 |
+
device = self.parameters.device
|
311 |
+
sample_device = "cpu" if device.type == "mps" else device
|
312 |
+
sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
|
313 |
+
x = self.mean + self.std * sample
|
314 |
+
return x
|
315 |
+
|
316 |
+
def kl(self, other=None):
|
317 |
+
if self.deterministic:
|
318 |
+
return torch.Tensor([0.0])
|
319 |
+
else:
|
320 |
+
if other is None:
|
321 |
+
return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar)/self.batch_size
|
322 |
+
else:
|
323 |
+
return 0.5 * torch.sum(
|
324 |
+
torch.pow(self.mean - other.mean, 2) / other.var
|
325 |
+
+ self.var / other.var
|
326 |
+
- 1.0
|
327 |
+
- self.logvar
|
328 |
+
+ other.logvar,
|
329 |
+
)/self.batch_size
|
330 |
+
|
331 |
+
# q_z_x = Normal(self.mean, self.logvar.mul(.5).exp())
|
332 |
+
# p_z = Normal(torch.zeros_like(self.mean), torch.ones_like(self.logvar))
|
333 |
+
# kl_div = kl_divergence(q_z_x, p_z).sum(1).mean()
|
334 |
+
# return kl_div
|
335 |
+
|
336 |
+
def nll(self, sample, dims=[1, 2, 3]):
|
337 |
+
if self.deterministic:
|
338 |
+
return torch.Tensor([0.0])
|
339 |
+
logtwopi = np.log(2.0 * np.pi)
|
340 |
+
return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
|
341 |
+
|
342 |
+
def mode(self):
|
343 |
+
return self.mean
|
344 |
+
|
345 |
+
|
346 |
+
class VQModel(nn.Module):
|
347 |
+
r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray
|
348 |
+
Kavukcuoglu.
|
349 |
+
|
350 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
351 |
+
implements for all the model (such as downloading or saving, etc.)
|
352 |
+
|
353 |
+
Parameters:
|
354 |
+
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
355 |
+
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
356 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to :
|
357 |
+
obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
|
358 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to :
|
359 |
+
obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
|
360 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to :
|
361 |
+
obj:`(64,)`): Tuple of block output channels.
|
362 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
363 |
+
latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
|
364 |
+
sample_size (`int`, *optional*, defaults to `32`): TODO
|
365 |
+
num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
|
366 |
+
"""
|
367 |
+
|
368 |
+
|
369 |
+
def __init__(
|
370 |
+
self,
|
371 |
+
in_channels: int = 3,
|
372 |
+
out_channels: int = 3,
|
373 |
+
down_block_types: Tuple[str] = ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"),
|
374 |
+
up_block_types: Tuple[str] = ("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"),
|
375 |
+
block_out_channels: Tuple[int] = (32, 64, 128, 256),
|
376 |
+
layers_per_block: int = 1,
|
377 |
+
act_fn: str = "silu",
|
378 |
+
latent_channels: int = 3,
|
379 |
+
sample_size: int = 32,
|
380 |
+
num_vq_embeddings: int = 256,
|
381 |
+
norm_num_groups: int = 32,
|
382 |
+
):
|
383 |
+
super().__init__()
|
384 |
+
|
385 |
+
# pass init params to Encoder
|
386 |
+
self.encoder = Encoder(
|
387 |
+
in_channels=in_channels,
|
388 |
+
out_channels=latent_channels,
|
389 |
+
down_block_types=down_block_types,
|
390 |
+
block_out_channels=block_out_channels,
|
391 |
+
layers_per_block=layers_per_block,
|
392 |
+
act_fn=act_fn,
|
393 |
+
norm_num_groups=norm_num_groups,
|
394 |
+
double_z=False,
|
395 |
+
)
|
396 |
+
|
397 |
+
self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
|
398 |
+
self.quantize = VectorQuantizer(
|
399 |
+
num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False
|
400 |
+
)
|
401 |
+
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
|
402 |
+
|
403 |
+
# pass init params to Decoder
|
404 |
+
self.decoder = Decoder(
|
405 |
+
in_channels=latent_channels,
|
406 |
+
out_channels=out_channels,
|
407 |
+
up_block_types=up_block_types,
|
408 |
+
block_out_channels=block_out_channels,
|
409 |
+
layers_per_block=layers_per_block,
|
410 |
+
act_fn=act_fn,
|
411 |
+
norm_num_groups=norm_num_groups,
|
412 |
+
)
|
413 |
+
|
414 |
+
# def encode(self, x: torch.FloatTensor):
|
415 |
+
# z = self.encoder(x)
|
416 |
+
# z = self.quant_conv(z)
|
417 |
+
# return z
|
418 |
+
|
419 |
+
def encode(self, x, return_loss=True, force_quantize= True):
|
420 |
+
z = self.encoder(x)
|
421 |
+
z = self.quant_conv(z)
|
422 |
+
|
423 |
+
if force_quantize:
|
424 |
+
z_q, emb_loss, _ = self.quantize(z)
|
425 |
+
else:
|
426 |
+
z_q, emb_loss = z, None
|
427 |
+
|
428 |
+
if return_loss:
|
429 |
+
return z_q, emb_loss
|
430 |
+
else:
|
431 |
+
return z_q
|
432 |
+
|
433 |
+
def decode(self, z_q) -> torch.FloatTensor:
|
434 |
+
z_q = self.post_quant_conv(z_q)
|
435 |
+
x = self.decoder(z_q)
|
436 |
+
return x
|
437 |
+
|
438 |
+
# def decode(self, z: torch.FloatTensor, return_loss=True, force_quantize: bool = True) -> torch.FloatTensor:
|
439 |
+
# if force_quantize:
|
440 |
+
# z_q, emb_loss, _ = self.quantize(z)
|
441 |
+
# else:
|
442 |
+
# z_q, emb_loss = z, None
|
443 |
+
|
444 |
+
# z_q = self.post_quant_conv(z_q)
|
445 |
+
# x = self.decoder(z_q)
|
446 |
+
|
447 |
+
# if return_loss:
|
448 |
+
# return x, emb_loss
|
449 |
+
# else:
|
450 |
+
# return x
|
451 |
+
|
452 |
+
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
453 |
+
r"""
|
454 |
+
Args:
|
455 |
+
sample (`torch.FloatTensor`): Input sample.
|
456 |
+
"""
|
457 |
+
# h = self.encode(sample)
|
458 |
+
h, emb_loss = self.encode(sample)
|
459 |
+
dec = self.decode(h)
|
460 |
+
# dec, emb_loss = self.decode(h)
|
461 |
+
|
462 |
+
return dec, emb_loss
|
463 |
+
|
464 |
+
|
465 |
+
class AutoencoderKL(nn.Module):
|
466 |
+
r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
|
467 |
+
and Max Welling.
|
468 |
+
|
469 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
470 |
+
implements for all the model (such as downloading or saving, etc.)
|
471 |
+
|
472 |
+
Parameters:
|
473 |
+
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
474 |
+
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
475 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to :
|
476 |
+
obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
|
477 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to :
|
478 |
+
obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
|
479 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to :
|
480 |
+
obj:`(64,)`): Tuple of block output channels.
|
481 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
482 |
+
latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
|
483 |
+
sample_size (`int`, *optional*, defaults to `32`): TODO
|
484 |
+
"""
|
485 |
+
|
486 |
+
|
487 |
+
def __init__(
|
488 |
+
self,
|
489 |
+
in_channels: int = 3,
|
490 |
+
out_channels: int = 3,
|
491 |
+
down_block_types: Tuple[str] = ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D","DownEncoderBlock2D",),
|
492 |
+
up_block_types: Tuple[str] = ("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D",),
|
493 |
+
block_out_channels: Tuple[int] = (32, 64, 128, 128),
|
494 |
+
layers_per_block: int = 1,
|
495 |
+
act_fn: str = "silu",
|
496 |
+
latent_channels: int = 3,
|
497 |
+
norm_num_groups: int = 32,
|
498 |
+
sample_size: int = 32,
|
499 |
+
):
|
500 |
+
super().__init__()
|
501 |
+
|
502 |
+
# pass init params to Encoder
|
503 |
+
self.encoder = Encoder(
|
504 |
+
in_channels=in_channels,
|
505 |
+
out_channels=latent_channels,
|
506 |
+
down_block_types=down_block_types,
|
507 |
+
block_out_channels=block_out_channels,
|
508 |
+
layers_per_block=layers_per_block,
|
509 |
+
act_fn=act_fn,
|
510 |
+
norm_num_groups=norm_num_groups,
|
511 |
+
double_z=True,
|
512 |
+
)
|
513 |
+
|
514 |
+
# pass init params to Decoder
|
515 |
+
self.decoder = Decoder(
|
516 |
+
in_channels=latent_channels,
|
517 |
+
out_channels=out_channels,
|
518 |
+
up_block_types=up_block_types,
|
519 |
+
block_out_channels=block_out_channels,
|
520 |
+
layers_per_block=layers_per_block,
|
521 |
+
norm_num_groups=norm_num_groups,
|
522 |
+
act_fn=act_fn,
|
523 |
+
)
|
524 |
+
|
525 |
+
self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
|
526 |
+
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
|
527 |
+
|
528 |
+
def encode(self, x: torch.FloatTensor):
|
529 |
+
h = self.encoder(x)
|
530 |
+
moments = self.quant_conv(h)
|
531 |
+
posterior = DiagonalGaussianDistribution(moments)
|
532 |
+
return posterior
|
533 |
+
|
534 |
+
def decode(self, z: torch.FloatTensor) -> torch.FloatTensor:
|
535 |
+
z = self.post_quant_conv(z)
|
536 |
+
dec = self.decoder(z)
|
537 |
+
return dec
|
538 |
+
|
539 |
+
def forward(
|
540 |
+
self,
|
541 |
+
sample: torch.FloatTensor,
|
542 |
+
sample_posterior: bool = True,
|
543 |
+
generator: Optional[torch.Generator] = None,
|
544 |
+
) -> torch.FloatTensor:
|
545 |
+
r"""
|
546 |
+
Args:
|
547 |
+
sample (`torch.FloatTensor`): Input sample.
|
548 |
+
sample_posterior (`bool`, *optional*, defaults to `False`):
|
549 |
+
Whether to sample from the posterior.
|
550 |
+
"""
|
551 |
+
x = sample
|
552 |
+
posterior = self.encode(x)
|
553 |
+
if sample_posterior:
|
554 |
+
z = posterior.sample(generator=generator)
|
555 |
+
else:
|
556 |
+
z = posterior.mode()
|
557 |
+
kl_loss = posterior.kl()
|
558 |
+
dec = self.decode(z)
|
559 |
+
return dec, kl_loss
|
560 |
+
|
561 |
+
|
562 |
+
|
563 |
+
class VQVAEWrapper(BasicModel):
|
564 |
+
def __init__(
|
565 |
+
self,
|
566 |
+
in_ch: int = 3,
|
567 |
+
out_ch: int = 3,
|
568 |
+
down_block_types: Tuple[str] = ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D",),
|
569 |
+
up_block_types: Tuple[str] = ("UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D",),
|
570 |
+
block_out_channels: Tuple[int] = (32, 64, 128, 256, ),
|
571 |
+
layers_per_block: int = 1,
|
572 |
+
act_fn: str = "silu",
|
573 |
+
latent_channels: int = 3,
|
574 |
+
sample_size: int = 32,
|
575 |
+
num_vq_embeddings: int = 64,
|
576 |
+
norm_num_groups: int = 32,
|
577 |
+
|
578 |
+
optimizer=torch.optim.AdamW,
|
579 |
+
optimizer_kwargs={},
|
580 |
+
lr_scheduler=None,
|
581 |
+
lr_scheduler_kwargs={},
|
582 |
+
loss=torch.nn.MSELoss,
|
583 |
+
loss_kwargs={}
|
584 |
+
):
|
585 |
+
super().__init__(optimizer, optimizer_kwargs, lr_scheduler, lr_scheduler_kwargs, loss, loss_kwargs)
|
586 |
+
self.model = VQModel(in_ch, out_ch, down_block_types, up_block_types, block_out_channels,
|
587 |
+
layers_per_block, act_fn, latent_channels, sample_size, num_vq_embeddings, norm_num_groups)
|
588 |
+
|
589 |
+
def forward(self, sample):
|
590 |
+
return self.model(sample)
|
591 |
+
|
592 |
+
def encode(self, x):
|
593 |
+
z = self.model.encode(x, return_loss=False)
|
594 |
+
return z
|
595 |
+
|
596 |
+
def decode(self, z):
|
597 |
+
x = self.model.decode(z)
|
598 |
+
return x
|
599 |
+
|
600 |
+
def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int):
|
601 |
+
# ------------------------- Get Source/Target ---------------------------
|
602 |
+
x = batch['source']
|
603 |
+
target = x
|
604 |
+
|
605 |
+
# ------------------------- Run Model ---------------------------
|
606 |
+
pred, vq_loss = self(x)
|
607 |
+
|
608 |
+
# ------------------------- Compute Loss ---------------------------
|
609 |
+
loss = self.loss_fct(pred, target)
|
610 |
+
loss += vq_loss
|
611 |
+
|
612 |
+
# --------------------- Compute Metrics -------------------------------
|
613 |
+
results = {'loss':loss}
|
614 |
+
with torch.no_grad():
|
615 |
+
results['L2'] = torch.nn.functional.mse_loss(pred, target)
|
616 |
+
results['L1'] = torch.nn.functional.l1_loss(pred, target)
|
617 |
+
|
618 |
+
# ----------------- Log Scalars ----------------------
|
619 |
+
for metric_name, metric_val in results.items():
|
620 |
+
self.log(f"{state}/{metric_name}", metric_val, batch_size=x.shape[0], on_step=True, on_epoch=True)
|
621 |
+
|
622 |
+
# ----------------- Save Image ------------------------------
|
623 |
+
if self.global_step != 0 and self.global_step % self.trainer.log_every_n_steps == 0:
|
624 |
+
def norm(x):
|
625 |
+
return (x-x.min())/(x.max()-x.min())
|
626 |
+
|
627 |
+
images = [x, pred]
|
628 |
+
log_step = self.global_step // self.trainer.log_every_n_steps
|
629 |
+
path_out = Path(self.logger.log_dir)/'images'
|
630 |
+
path_out.mkdir(parents=True, exist_ok=True)
|
631 |
+
images = torch.cat([norm(img) for img in images])
|
632 |
+
save_image(images, path_out/f'sample_{log_step}.png')
|
633 |
+
|
634 |
+
return loss
|
635 |
+
|
636 |
+
def hinge_d_loss(logits_real, logits_fake):
|
637 |
+
loss_real = torch.mean(F.relu(1. - logits_real))
|
638 |
+
loss_fake = torch.mean(F.relu(1. + logits_fake))
|
639 |
+
d_loss = 0.5 * (loss_real + loss_fake)
|
640 |
+
return d_loss
|
641 |
+
|
642 |
+
def vanilla_d_loss(logits_real, logits_fake):
|
643 |
+
d_loss = 0.5 * (
|
644 |
+
torch.mean(F.softplus(-logits_real)) +
|
645 |
+
torch.mean(F.softplus(logits_fake)))
|
646 |
+
return d_loss
|
647 |
+
|
648 |
+
class VQGAN(BasicModel):
|
649 |
+
def __init__(
|
650 |
+
self,
|
651 |
+
in_ch: int = 3,
|
652 |
+
out_ch: int = 3,
|
653 |
+
down_block_types: Tuple[str] = ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D",),
|
654 |
+
up_block_types: Tuple[str] = ("UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D",),
|
655 |
+
block_out_channels: Tuple[int] = (32, 64, 128, 256, ),
|
656 |
+
layers_per_block: int = 1,
|
657 |
+
act_fn: str = "silu",
|
658 |
+
latent_channels: int = 3,
|
659 |
+
sample_size: int = 32,
|
660 |
+
num_vq_embeddings: int = 64,
|
661 |
+
norm_num_groups: int = 32,
|
662 |
+
|
663 |
+
start_gan_train_step = 50000, # NOTE step increase with each optimizer
|
664 |
+
gan_loss_weight: float = 1.0, # alias discriminator
|
665 |
+
perceptual_loss_weight: float = 1.0,
|
666 |
+
embedding_loss_weight: float = 1.0,
|
667 |
+
|
668 |
+
optimizer=torch.optim.AdamW,
|
669 |
+
optimizer_kwargs={},
|
670 |
+
lr_scheduler=None,
|
671 |
+
lr_scheduler_kwargs={},
|
672 |
+
loss=torch.nn.MSELoss,
|
673 |
+
loss_kwargs={}
|
674 |
+
):
|
675 |
+
super().__init__(optimizer, optimizer_kwargs, lr_scheduler, lr_scheduler_kwargs, loss, loss_kwargs)
|
676 |
+
self.vqvae = VQModel(in_ch, out_ch, down_block_types, up_block_types, block_out_channels, layers_per_block, act_fn,
|
677 |
+
latent_channels, sample_size, num_vq_embeddings, norm_num_groups)
|
678 |
+
self.discriminator = NLayerDiscriminator(in_ch)
|
679 |
+
# self.perceiver = ... # Currently not supported, would require another trained NN
|
680 |
+
|
681 |
+
self.start_gan_train_step = start_gan_train_step
|
682 |
+
self.perceptual_loss_weight = perceptual_loss_weight
|
683 |
+
self.gan_loss_weight = gan_loss_weight
|
684 |
+
self.embedding_loss_weight = embedding_loss_weight
|
685 |
+
|
686 |
+
def forward(self, x, condition=None):
|
687 |
+
return self.vqvae(x)
|
688 |
+
|
689 |
+
def encode(self, x):
|
690 |
+
z = self.vqvae.encode(x, return_loss=False)
|
691 |
+
return z
|
692 |
+
|
693 |
+
def decode(self, z):
|
694 |
+
x = self.vqvae.decode(z)
|
695 |
+
return x
|
696 |
+
|
697 |
+
|
698 |
+
def compute_lambda(self, rec_loss, gan_loss, eps=1e-4):
|
699 |
+
"""Computes adaptive weight as proposed in eq. 7 of https://arxiv.org/abs/2012.09841"""
|
700 |
+
last_layer = self.vqvae.decoder.conv_out.weight
|
701 |
+
rec_grads = torch.autograd.grad(rec_loss, last_layer, retain_graph=True)[0]
|
702 |
+
gan_grads = torch.autograd.grad(gan_loss, last_layer, retain_graph=True)[0]
|
703 |
+
d_weight = torch.norm(rec_grads) / (torch.norm(gan_grads) + eps)
|
704 |
+
d_weight = torch.clamp(d_weight, 0.0, 1e4)
|
705 |
+
return d_weight.detach()
|
706 |
+
|
707 |
+
|
708 |
+
|
709 |
+
def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int):
|
710 |
+
x = batch['source']
|
711 |
+
# condition = batch.get('target', None)
|
712 |
+
|
713 |
+
pred, vq_emb_loss = self.vqvae(x)
|
714 |
+
|
715 |
+
if optimizer_idx == 0:
|
716 |
+
# ------ VAE -------
|
717 |
+
vq_img_loss = F.mse_loss(pred, x)
|
718 |
+
vq_per_loss = 0.0 #self.perceiver(pred, x)
|
719 |
+
rec_loss = vq_img_loss+self.perceptual_loss_weight*vq_per_loss
|
720 |
+
|
721 |
+
# ------- GAN -----
|
722 |
+
if step > self.start_gan_train_step:
|
723 |
+
gan_loss = -torch.mean(self.discriminator(pred))
|
724 |
+
lambda_weight = self.compute_lambda(rec_loss, gan_loss)
|
725 |
+
gan_loss = gan_loss*lambda_weight
|
726 |
+
else:
|
727 |
+
gan_loss = torch.tensor([0.0], requires_grad=True, device=x.device)
|
728 |
+
|
729 |
+
loss = self.gan_loss_weight*gan_loss+rec_loss+self.embedding_loss_weight*vq_emb_loss
|
730 |
+
|
731 |
+
elif optimizer_idx == 1:
|
732 |
+
if step > self.start_gan_train_step//2:
|
733 |
+
logits_real = self.discriminator(x.detach())
|
734 |
+
logits_fake = self.discriminator(pred.detach())
|
735 |
+
loss = hinge_d_loss(logits_real, logits_fake)
|
736 |
+
else:
|
737 |
+
loss = torch.tensor([0.0], requires_grad=True, device=x.device)
|
738 |
+
|
739 |
+
# --------------------- Compute Metrics -------------------------------
|
740 |
+
results = {'loss':loss.detach(), f'loss_{optimizer_idx}':loss.detach()}
|
741 |
+
with torch.no_grad():
|
742 |
+
results[f'L2'] = torch.nn.functional.mse_loss(pred, x)
|
743 |
+
results[f'L1'] = torch.nn.functional.l1_loss(pred, x)
|
744 |
+
|
745 |
+
# ----------------- Log Scalars ----------------------
|
746 |
+
for metric_name, metric_val in results.items():
|
747 |
+
self.log(f"{state}/{metric_name}", metric_val, batch_size=x.shape[0], on_step=True, on_epoch=True)
|
748 |
+
|
749 |
+
# ----------------- Save Image ------------------------------
|
750 |
+
if self.global_step != 0 and self.global_step % self.trainer.log_every_n_steps == 0: # NOTE: step 1 (opt1) , step=2 (opt2), step=3 (opt1), ...
|
751 |
+
def norm(x):
|
752 |
+
return (x-x.min())/(x.max()-x.min())
|
753 |
+
|
754 |
+
images = torch.cat([x, pred])
|
755 |
+
log_step = self.global_step // self.trainer.log_every_n_steps
|
756 |
+
path_out = Path(self.logger.log_dir)/'images'
|
757 |
+
path_out.mkdir(parents=True, exist_ok=True)
|
758 |
+
images = torch.stack([norm(img) for img in images])
|
759 |
+
save_image(images, path_out/f'sample_{log_step}.png')
|
760 |
+
|
761 |
+
return loss
|
762 |
+
|
763 |
+
def configure_optimizers(self):
|
764 |
+
opt_vae = self.optimizer(self.vqvae.parameters(), **self.optimizer_kwargs)
|
765 |
+
opt_disc = self.optimizer(self.discriminator.parameters(), **self.optimizer_kwargs)
|
766 |
+
if self.lr_scheduler is not None:
|
767 |
+
scheduler = [
|
768 |
+
{
|
769 |
+
'scheduler': self.lr_scheduler(opt_vae, **self.lr_scheduler_kwargs),
|
770 |
+
'interval': 'step',
|
771 |
+
'frequency': 1
|
772 |
+
},
|
773 |
+
{
|
774 |
+
'scheduler': self.lr_scheduler(opt_disc, **self.lr_scheduler_kwargs),
|
775 |
+
'interval': 'step',
|
776 |
+
'frequency': 1
|
777 |
+
},
|
778 |
+
]
|
779 |
+
else:
|
780 |
+
scheduler = []
|
781 |
+
|
782 |
+
return [opt_vae, opt_disc], scheduler
|
783 |
+
|
784 |
+
class VAEWrapper(BasicModel):
|
785 |
+
def __init__(
|
786 |
+
self,
|
787 |
+
in_ch: int = 3,
|
788 |
+
out_ch: int = 3,
|
789 |
+
down_block_types: Tuple[str] = ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"), # "DownEncoderBlock2D", "DownEncoderBlock2D",
|
790 |
+
up_block_types: Tuple[str] = ("UpDecoderBlock2D", "UpDecoderBlock2D","UpDecoderBlock2D" ), # "UpDecoderBlock2D", "UpDecoderBlock2D",
|
791 |
+
block_out_channels: Tuple[int] = (32, 64, 128, 256), # 128, 256
|
792 |
+
layers_per_block: int = 1,
|
793 |
+
act_fn: str = "silu",
|
794 |
+
latent_channels: int = 3,
|
795 |
+
norm_num_groups: int = 32,
|
796 |
+
sample_size: int = 32,
|
797 |
+
|
798 |
+
optimizer=torch.optim.AdamW,
|
799 |
+
optimizer_kwargs={'lr':1e-4, 'weight_decay':1e-3, 'amsgrad':True},
|
800 |
+
lr_scheduler=None,
|
801 |
+
lr_scheduler_kwargs={},
|
802 |
+
# loss=torch.nn.MSELoss, # WARNING: No Effect
|
803 |
+
# loss_kwargs={'reduction': 'mean'}
|
804 |
+
):
|
805 |
+
super().__init__(optimizer, optimizer_kwargs, lr_scheduler, lr_scheduler_kwargs ) # loss, loss_kwargs
|
806 |
+
self.model = AutoencoderKL(in_ch, out_ch, down_block_types, up_block_types, block_out_channels,
|
807 |
+
layers_per_block, act_fn, latent_channels, norm_num_groups, sample_size)
|
808 |
+
|
809 |
+
self.logvar = nn.Parameter(torch.zeros(size=())) # Better weighting between KL and MSE, see (https://arxiv.org/abs/1903.05789), also used by Taming-Transfomer/Stable Diffusion
|
810 |
+
|
811 |
+
def forward(self, sample):
|
812 |
+
return self.model(sample)
|
813 |
+
|
814 |
+
def encode(self, x):
|
815 |
+
z = self.model.encode(x) # Latent space but not yet mapped to discrete embedding vectors
|
816 |
+
return z.sample(generator=None)
|
817 |
+
|
818 |
+
def decode(self, z):
|
819 |
+
x = self.model.decode(z)
|
820 |
+
return x
|
821 |
+
|
822 |
+
def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int):
|
823 |
+
# ------------------------- Get Source/Target ---------------------------
|
824 |
+
x = batch['source']
|
825 |
+
target = x
|
826 |
+
HALF_LOG_TWO_PI = 0.91893 # log(2pi)/2
|
827 |
+
|
828 |
+
# ------------------------- Run Model ---------------------------
|
829 |
+
pred, kl_loss = self(x)
|
830 |
+
|
831 |
+
# ------------------------- Compute Loss ---------------------------
|
832 |
+
loss = torch.sum( torch.square(pred-target))/x.shape[0] #torch.sum( torch.square((pred-target)/torch.exp(self.logvar))/2 + self.logvar + HALF_LOG_TWO_PI )/x.shape[0]
|
833 |
+
loss += kl_loss
|
834 |
+
|
835 |
+
# --------------------- Compute Metrics -------------------------------
|
836 |
+
results = {'loss':loss.detach()}
|
837 |
+
with torch.no_grad():
|
838 |
+
results['L2'] = torch.nn.functional.mse_loss(pred, target)
|
839 |
+
results['L1'] = torch.nn.functional.l1_loss(pred, target)
|
840 |
+
|
841 |
+
# ----------------- Log Scalars ----------------------
|
842 |
+
for metric_name, metric_val in results.items():
|
843 |
+
self.log(f"{state}/{metric_name}", metric_val, batch_size=x.shape[0], on_step=True, on_epoch=True)
|
844 |
+
|
845 |
+
# ----------------- Save Image ------------------------------
|
846 |
+
if self.global_step != 0 and self.global_step % self.trainer.log_every_n_steps == 0:
|
847 |
+
def norm(x):
|
848 |
+
return (x-x.min())/(x.max()-x.min())
|
849 |
+
|
850 |
+
images = torch.cat([x, pred])
|
851 |
+
log_step = self.global_step // self.trainer.log_every_n_steps
|
852 |
+
path_out = Path(self.logger.log_dir)/'images'
|
853 |
+
path_out.mkdir(parents=True, exist_ok=True)
|
854 |
+
images = torch.stack([norm(img) for img in images])
|
855 |
+
save_image(images, path_out/f'sample_{log_step}.png')
|
856 |
+
|
857 |
+
return loss
|
medical_diffusion/external/stable_diffusion/attention.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from inspect import isfunction
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch import nn, einsum
|
6 |
+
from einops import rearrange, repeat
|
7 |
+
|
8 |
+
from .util_attention import checkpoint
|
9 |
+
|
10 |
+
|
11 |
+
def exists(val):
|
12 |
+
return val is not None
|
13 |
+
|
14 |
+
|
15 |
+
def uniq(arr):
|
16 |
+
return{el: True for el in arr}.keys()
|
17 |
+
|
18 |
+
|
19 |
+
def default(val, d):
|
20 |
+
if exists(val):
|
21 |
+
return val
|
22 |
+
return d() if isfunction(d) else d
|
23 |
+
|
24 |
+
|
25 |
+
def max_neg_value(t):
|
26 |
+
return -torch.finfo(t.dtype).max
|
27 |
+
|
28 |
+
|
29 |
+
def init_(tensor):
|
30 |
+
dim = tensor.shape[-1]
|
31 |
+
std = 1 / math.sqrt(dim)
|
32 |
+
tensor.uniform_(-std, std)
|
33 |
+
return tensor
|
34 |
+
|
35 |
+
|
36 |
+
# feedforward
|
37 |
+
class GEGLU(nn.Module):
|
38 |
+
def __init__(self, dim_in, dim_out):
|
39 |
+
super().__init__()
|
40 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
44 |
+
return x * F.gelu(gate)
|
45 |
+
|
46 |
+
|
47 |
+
class FeedForward(nn.Module):
|
48 |
+
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
49 |
+
super().__init__()
|
50 |
+
inner_dim = int(dim * mult)
|
51 |
+
dim_out = default(dim_out, dim)
|
52 |
+
project_in = nn.Sequential(
|
53 |
+
nn.Linear(dim, inner_dim),
|
54 |
+
nn.GELU()
|
55 |
+
) if not glu else GEGLU(dim, inner_dim)
|
56 |
+
|
57 |
+
self.net = nn.Sequential(
|
58 |
+
project_in,
|
59 |
+
nn.Dropout(dropout),
|
60 |
+
nn.Linear(inner_dim, dim_out)
|
61 |
+
)
|
62 |
+
|
63 |
+
def forward(self, x):
|
64 |
+
return self.net(x)
|
65 |
+
|
66 |
+
|
67 |
+
def zero_module(module):
|
68 |
+
"""
|
69 |
+
Zero out the parameters of a module and return it.
|
70 |
+
"""
|
71 |
+
for p in module.parameters():
|
72 |
+
p.detach().zero_()
|
73 |
+
return module
|
74 |
+
|
75 |
+
|
76 |
+
def Normalize(in_channels):
|
77 |
+
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
78 |
+
|
79 |
+
|
80 |
+
class LinearAttention(nn.Module):
|
81 |
+
def __init__(self, dim, heads=4, dim_head=32):
|
82 |
+
super().__init__()
|
83 |
+
self.heads = heads
|
84 |
+
hidden_dim = dim_head * heads
|
85 |
+
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
|
86 |
+
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
87 |
+
|
88 |
+
def forward(self, x):
|
89 |
+
b, c, h, w = x.shape
|
90 |
+
qkv = self.to_qkv(x)
|
91 |
+
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
|
92 |
+
k = k.softmax(dim=-1)
|
93 |
+
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
94 |
+
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
95 |
+
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
96 |
+
return self.to_out(out)
|
97 |
+
|
98 |
+
|
99 |
+
class SpatialSelfAttention(nn.Module):
|
100 |
+
def __init__(self, in_channels):
|
101 |
+
super().__init__()
|
102 |
+
self.in_channels = in_channels
|
103 |
+
|
104 |
+
self.norm = Normalize(in_channels)
|
105 |
+
self.q = torch.nn.Conv2d(in_channels,
|
106 |
+
in_channels,
|
107 |
+
kernel_size=1,
|
108 |
+
stride=1,
|
109 |
+
padding=0)
|
110 |
+
self.k = torch.nn.Conv2d(in_channels,
|
111 |
+
in_channels,
|
112 |
+
kernel_size=1,
|
113 |
+
stride=1,
|
114 |
+
padding=0)
|
115 |
+
self.v = torch.nn.Conv2d(in_channels,
|
116 |
+
in_channels,
|
117 |
+
kernel_size=1,
|
118 |
+
stride=1,
|
119 |
+
padding=0)
|
120 |
+
self.proj_out = torch.nn.Conv2d(in_channels,
|
121 |
+
in_channels,
|
122 |
+
kernel_size=1,
|
123 |
+
stride=1,
|
124 |
+
padding=0)
|
125 |
+
|
126 |
+
def forward(self, x):
|
127 |
+
h_ = x
|
128 |
+
h_ = self.norm(h_)
|
129 |
+
q = self.q(h_)
|
130 |
+
k = self.k(h_)
|
131 |
+
v = self.v(h_)
|
132 |
+
|
133 |
+
# compute attention
|
134 |
+
b,c,h,w = q.shape
|
135 |
+
q = rearrange(q, 'b c h w -> b (h w) c')
|
136 |
+
k = rearrange(k, 'b c h w -> b c (h w)')
|
137 |
+
w_ = torch.einsum('bij,bjk->bik', q, k)
|
138 |
+
|
139 |
+
w_ = w_ * (int(c)**(-0.5))
|
140 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
141 |
+
|
142 |
+
# attend to values
|
143 |
+
v = rearrange(v, 'b c h w -> b c (h w)')
|
144 |
+
w_ = rearrange(w_, 'b i j -> b j i')
|
145 |
+
h_ = torch.einsum('bij,bjk->bik', v, w_)
|
146 |
+
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
|
147 |
+
h_ = self.proj_out(h_)
|
148 |
+
|
149 |
+
return x+h_
|
150 |
+
|
151 |
+
|
152 |
+
class CrossAttention(nn.Module):
|
153 |
+
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
154 |
+
super().__init__()
|
155 |
+
inner_dim = dim_head * heads
|
156 |
+
context_dim = default(context_dim, query_dim)
|
157 |
+
|
158 |
+
self.scale = dim_head ** -0.5
|
159 |
+
self.heads = heads
|
160 |
+
|
161 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
162 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
163 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
164 |
+
|
165 |
+
self.to_out = nn.Sequential(
|
166 |
+
nn.Linear(inner_dim, query_dim),
|
167 |
+
nn.Dropout(dropout)
|
168 |
+
)
|
169 |
+
|
170 |
+
def forward(self, x, context=None, mask=None):
|
171 |
+
h = self.heads
|
172 |
+
|
173 |
+
q = self.to_q(x)
|
174 |
+
context = default(context, x)
|
175 |
+
k = self.to_k(context)
|
176 |
+
v = self.to_v(context)
|
177 |
+
|
178 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
179 |
+
|
180 |
+
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
181 |
+
|
182 |
+
if exists(mask):
|
183 |
+
mask = rearrange(mask, 'b ... -> b (...)')
|
184 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
185 |
+
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
186 |
+
sim.masked_fill_(~mask, max_neg_value)
|
187 |
+
|
188 |
+
# attention, what we cannot get enough of
|
189 |
+
attn = sim.softmax(dim=-1)
|
190 |
+
|
191 |
+
out = einsum('b i j, b j d -> b i d', attn, v)
|
192 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
193 |
+
return self.to_out(out)
|
194 |
+
|
195 |
+
|
196 |
+
class BasicTransformerBlock(nn.Module):
|
197 |
+
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
|
198 |
+
super().__init__()
|
199 |
+
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
|
200 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
201 |
+
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
|
202 |
+
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
203 |
+
self.norm1 = nn.LayerNorm(dim)
|
204 |
+
self.norm2 = nn.LayerNorm(dim)
|
205 |
+
self.norm3 = nn.LayerNorm(dim)
|
206 |
+
self.checkpoint = checkpoint
|
207 |
+
|
208 |
+
def forward(self, x, context=None):
|
209 |
+
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
210 |
+
|
211 |
+
def _forward(self, x, context=None):
|
212 |
+
x = self.attn1(self.norm1(x)) + x
|
213 |
+
x = self.attn2(self.norm2(x), context=context) + x
|
214 |
+
x = self.ff(self.norm3(x)) + x
|
215 |
+
return x
|
216 |
+
|
217 |
+
|
218 |
+
class SpatialTransformer(nn.Module):
|
219 |
+
"""
|
220 |
+
Transformer block for image-like data.
|
221 |
+
First, project the input (aka embedding)
|
222 |
+
and reshape to b, t, d.
|
223 |
+
Then apply standard transformer action.
|
224 |
+
Finally, reshape to image
|
225 |
+
"""
|
226 |
+
def __init__(self, in_channels, n_heads, d_head,
|
227 |
+
depth=1, dropout=0., context_dim=None):
|
228 |
+
super().__init__()
|
229 |
+
self.in_channels = in_channels
|
230 |
+
inner_dim = n_heads * d_head
|
231 |
+
self.norm = Normalize(in_channels)
|
232 |
+
|
233 |
+
self.proj_in = nn.Conv2d(in_channels,
|
234 |
+
inner_dim,
|
235 |
+
kernel_size=1,
|
236 |
+
stride=1,
|
237 |
+
padding=0)
|
238 |
+
|
239 |
+
self.transformer_blocks = nn.ModuleList(
|
240 |
+
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
|
241 |
+
for d in range(depth)]
|
242 |
+
)
|
243 |
+
|
244 |
+
self.proj_out = zero_module(nn.Conv2d(inner_dim,
|
245 |
+
in_channels,
|
246 |
+
kernel_size=1,
|
247 |
+
stride=1,
|
248 |
+
padding=0))
|
249 |
+
|
250 |
+
def forward(self, x, context=None):
|
251 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
252 |
+
b, c, h, w = x.shape
|
253 |
+
x_in = x
|
254 |
+
x = self.norm(x)
|
255 |
+
x = self.proj_in(x)
|
256 |
+
x = rearrange(x, 'b c h w -> b (h w) c')
|
257 |
+
for block in self.transformer_blocks:
|
258 |
+
x = block(x, context=context)
|
259 |
+
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
260 |
+
x = self.proj_out(x)
|
261 |
+
return x + x_in
|
medical_diffusion/external/stable_diffusion/lr_schedulers.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
class LambdaLinearScheduler:
|
4 |
+
def __init__(self, warm_up_steps=[10000,], f_min=[1.0,], f_max=[1.0,], f_start=[1.e-6], cycle_lengths=[10000000000000], verbosity_interval=0):
|
5 |
+
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
|
6 |
+
self.lr_warm_up_steps = warm_up_steps
|
7 |
+
self.f_start = f_start
|
8 |
+
self.f_min = f_min
|
9 |
+
self.f_max = f_max
|
10 |
+
self.cycle_lengths = cycle_lengths
|
11 |
+
self.cum_cycles = torch.cumsum(torch.tensor([0] + list(self.cycle_lengths)), 0)
|
12 |
+
self.last_f = 0.
|
13 |
+
self.verbosity_interval = verbosity_interval
|
14 |
+
|
15 |
+
def find_in_interval(self, n):
|
16 |
+
interval = 0
|
17 |
+
for cl in self.cum_cycles[1:]:
|
18 |
+
if n <= cl:
|
19 |
+
return interval
|
20 |
+
interval += 1
|
21 |
+
|
22 |
+
def schedule(self, n, **kwargs):
|
23 |
+
cycle = self.find_in_interval(n)
|
24 |
+
n = n - self.cum_cycles[cycle]
|
25 |
+
|
26 |
+
if n < self.lr_warm_up_steps[cycle]:
|
27 |
+
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
28 |
+
self.last_f = f
|
29 |
+
return f
|
30 |
+
else:
|
31 |
+
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
|
32 |
+
self.last_f = f
|
33 |
+
return f
|
medical_diffusion/external/stable_diffusion/unet_openai.py
ADDED
@@ -0,0 +1,962 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
from functools import partial
|
3 |
+
import math
|
4 |
+
from typing import Iterable
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch as th
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
from .util import (
|
12 |
+
checkpoint,
|
13 |
+
conv_nd,
|
14 |
+
linear,
|
15 |
+
avg_pool_nd,
|
16 |
+
zero_module,
|
17 |
+
normalization,
|
18 |
+
timestep_embedding,
|
19 |
+
)
|
20 |
+
from .attention import SpatialTransformer
|
21 |
+
|
22 |
+
|
23 |
+
# dummy replace
|
24 |
+
def convert_module_to_f16(x):
|
25 |
+
pass
|
26 |
+
|
27 |
+
def convert_module_to_f32(x):
|
28 |
+
pass
|
29 |
+
|
30 |
+
|
31 |
+
## go
|
32 |
+
class AttentionPool2d(nn.Module):
|
33 |
+
"""
|
34 |
+
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
spacial_dim: int,
|
40 |
+
embed_dim: int,
|
41 |
+
num_heads_channels: int,
|
42 |
+
output_dim: int = None,
|
43 |
+
):
|
44 |
+
super().__init__()
|
45 |
+
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
|
46 |
+
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
|
47 |
+
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
|
48 |
+
self.num_heads = embed_dim // num_heads_channels
|
49 |
+
self.attention = QKVAttention(self.num_heads)
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
b, c, *_spatial = x.shape
|
53 |
+
x = x.reshape(b, c, -1) # NC(HW)
|
54 |
+
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
|
55 |
+
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
|
56 |
+
x = self.qkv_proj(x)
|
57 |
+
x = self.attention(x)
|
58 |
+
x = self.c_proj(x)
|
59 |
+
return x[:, :, 0]
|
60 |
+
|
61 |
+
|
62 |
+
class TimestepBlock(nn.Module):
|
63 |
+
"""
|
64 |
+
Any module where forward() takes timestep embeddings as a second argument.
|
65 |
+
"""
|
66 |
+
|
67 |
+
@abstractmethod
|
68 |
+
def forward(self, x, emb):
|
69 |
+
"""
|
70 |
+
Apply the module to `x` given `emb` timestep embeddings.
|
71 |
+
"""
|
72 |
+
|
73 |
+
|
74 |
+
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
75 |
+
"""
|
76 |
+
A sequential module that passes timestep embeddings to the children that
|
77 |
+
support it as an extra input.
|
78 |
+
"""
|
79 |
+
|
80 |
+
def forward(self, x, emb, context=None):
|
81 |
+
for layer in self:
|
82 |
+
if isinstance(layer, TimestepBlock):
|
83 |
+
x = layer(x, emb)
|
84 |
+
elif isinstance(layer, SpatialTransformer):
|
85 |
+
x = layer(x, context)
|
86 |
+
else:
|
87 |
+
x = layer(x)
|
88 |
+
return x
|
89 |
+
|
90 |
+
|
91 |
+
class Upsample(nn.Module):
|
92 |
+
"""
|
93 |
+
An upsampling layer with an optional convolution.
|
94 |
+
:param channels: channels in the inputs and outputs.
|
95 |
+
:param use_conv: a bool determining if a convolution is applied.
|
96 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
97 |
+
upsampling occurs in the inner-two dimensions.
|
98 |
+
"""
|
99 |
+
|
100 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
101 |
+
super().__init__()
|
102 |
+
self.channels = channels
|
103 |
+
self.out_channels = out_channels or channels
|
104 |
+
self.use_conv = use_conv
|
105 |
+
self.dims = dims
|
106 |
+
if use_conv:
|
107 |
+
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
|
108 |
+
|
109 |
+
def forward(self, x):
|
110 |
+
assert x.shape[1] == self.channels
|
111 |
+
if self.dims == 3:
|
112 |
+
x = F.interpolate(
|
113 |
+
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
|
114 |
+
)
|
115 |
+
else:
|
116 |
+
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
117 |
+
if self.use_conv:
|
118 |
+
x = self.conv(x)
|
119 |
+
return x
|
120 |
+
|
121 |
+
class TransposedUpsample(nn.Module):
|
122 |
+
'Learned 2x upsampling without padding'
|
123 |
+
def __init__(self, channels, out_channels=None, ks=5):
|
124 |
+
super().__init__()
|
125 |
+
self.channels = channels
|
126 |
+
self.out_channels = out_channels or channels
|
127 |
+
|
128 |
+
self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
|
129 |
+
|
130 |
+
def forward(self,x):
|
131 |
+
return self.up(x)
|
132 |
+
|
133 |
+
|
134 |
+
class Downsample(nn.Module):
|
135 |
+
"""
|
136 |
+
A downsampling layer with an optional convolution.
|
137 |
+
:param channels: channels in the inputs and outputs.
|
138 |
+
:param use_conv: a bool determining if a convolution is applied.
|
139 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
140 |
+
downsampling occurs in the inner-two dimensions.
|
141 |
+
"""
|
142 |
+
|
143 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
|
144 |
+
super().__init__()
|
145 |
+
self.channels = channels
|
146 |
+
self.out_channels = out_channels or channels
|
147 |
+
self.use_conv = use_conv
|
148 |
+
self.dims = dims
|
149 |
+
stride = 2 if dims != 3 else (1, 2, 2)
|
150 |
+
if use_conv:
|
151 |
+
self.op = conv_nd(
|
152 |
+
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
|
153 |
+
)
|
154 |
+
else:
|
155 |
+
assert self.channels == self.out_channels
|
156 |
+
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
157 |
+
|
158 |
+
def forward(self, x):
|
159 |
+
assert x.shape[1] == self.channels
|
160 |
+
return self.op(x)
|
161 |
+
|
162 |
+
|
163 |
+
class ResBlock(TimestepBlock):
|
164 |
+
"""
|
165 |
+
A residual block that can optionally change the number of channels.
|
166 |
+
:param channels: the number of input channels.
|
167 |
+
:param emb_channels: the number of timestep embedding channels.
|
168 |
+
:param dropout: the rate of dropout.
|
169 |
+
:param out_channels: if specified, the number of out channels.
|
170 |
+
:param use_conv: if True and out_channels is specified, use a spatial
|
171 |
+
convolution instead of a smaller 1x1 convolution to change the
|
172 |
+
channels in the skip connection.
|
173 |
+
:param dims: determines if the signal is 1D, 2D, or 3D.
|
174 |
+
:param use_checkpoint: if True, use gradient checkpointing on this module.
|
175 |
+
:param up: if True, use this block for upsampling.
|
176 |
+
:param down: if True, use this block for downsampling.
|
177 |
+
"""
|
178 |
+
|
179 |
+
def __init__(
|
180 |
+
self,
|
181 |
+
channels,
|
182 |
+
emb_channels,
|
183 |
+
dropout,
|
184 |
+
out_channels=None,
|
185 |
+
use_conv=False,
|
186 |
+
use_scale_shift_norm=False,
|
187 |
+
dims=2,
|
188 |
+
use_checkpoint=False,
|
189 |
+
up=False,
|
190 |
+
down=False,
|
191 |
+
):
|
192 |
+
super().__init__()
|
193 |
+
self.channels = channels
|
194 |
+
self.emb_channels = emb_channels
|
195 |
+
self.dropout = dropout
|
196 |
+
self.out_channels = out_channels or channels
|
197 |
+
self.use_conv = use_conv
|
198 |
+
self.use_checkpoint = use_checkpoint
|
199 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
200 |
+
|
201 |
+
self.in_layers = nn.Sequential(
|
202 |
+
normalization(channels),
|
203 |
+
nn.SiLU(),
|
204 |
+
conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
205 |
+
)
|
206 |
+
|
207 |
+
self.updown = up or down
|
208 |
+
|
209 |
+
if up:
|
210 |
+
self.h_upd = Upsample(channels, False, dims)
|
211 |
+
self.x_upd = Upsample(channels, False, dims)
|
212 |
+
elif down:
|
213 |
+
self.h_upd = Downsample(channels, False, dims)
|
214 |
+
self.x_upd = Downsample(channels, False, dims)
|
215 |
+
else:
|
216 |
+
self.h_upd = self.x_upd = nn.Identity()
|
217 |
+
|
218 |
+
self.emb_layers = nn.Sequential(
|
219 |
+
nn.SiLU(),
|
220 |
+
linear(
|
221 |
+
emb_channels,
|
222 |
+
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
223 |
+
),
|
224 |
+
)
|
225 |
+
self.out_layers = nn.Sequential(
|
226 |
+
normalization(self.out_channels),
|
227 |
+
nn.SiLU(),
|
228 |
+
nn.Dropout(p=dropout),
|
229 |
+
zero_module(
|
230 |
+
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
|
231 |
+
),
|
232 |
+
)
|
233 |
+
|
234 |
+
if self.out_channels == channels:
|
235 |
+
self.skip_connection = nn.Identity()
|
236 |
+
elif use_conv:
|
237 |
+
self.skip_connection = conv_nd(
|
238 |
+
dims, channels, self.out_channels, 3, padding=1
|
239 |
+
)
|
240 |
+
else:
|
241 |
+
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
242 |
+
|
243 |
+
def forward(self, x, emb):
|
244 |
+
"""
|
245 |
+
Apply the block to a Tensor, conditioned on a timestep embedding.
|
246 |
+
:param x: an [N x C x ...] Tensor of features.
|
247 |
+
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
248 |
+
:return: an [N x C x ...] Tensor of outputs.
|
249 |
+
"""
|
250 |
+
return checkpoint(
|
251 |
+
self._forward, (x, emb), self.parameters(), self.use_checkpoint
|
252 |
+
)
|
253 |
+
|
254 |
+
|
255 |
+
def _forward(self, x, emb):
|
256 |
+
if self.updown:
|
257 |
+
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
258 |
+
h = in_rest(x)
|
259 |
+
h = self.h_upd(h)
|
260 |
+
x = self.x_upd(x)
|
261 |
+
h = in_conv(h)
|
262 |
+
else:
|
263 |
+
h = self.in_layers(x)
|
264 |
+
emb_out = self.emb_layers(emb).type(h.dtype)
|
265 |
+
while len(emb_out.shape) < len(h.shape):
|
266 |
+
emb_out = emb_out[..., None]
|
267 |
+
if self.use_scale_shift_norm:
|
268 |
+
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
269 |
+
scale, shift = th.chunk(emb_out, 2, dim=1)
|
270 |
+
h = out_norm(h) * (1 + scale) + shift
|
271 |
+
h = out_rest(h)
|
272 |
+
else:
|
273 |
+
h = h + emb_out
|
274 |
+
h = self.out_layers(h)
|
275 |
+
return self.skip_connection(x) + h
|
276 |
+
|
277 |
+
|
278 |
+
class AttentionBlock(nn.Module):
|
279 |
+
"""
|
280 |
+
An attention block that allows spatial positions to attend to each other.
|
281 |
+
Originally ported from here, but adapted to the N-d case.
|
282 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
283 |
+
"""
|
284 |
+
|
285 |
+
def __init__(
|
286 |
+
self,
|
287 |
+
channels,
|
288 |
+
num_heads=1,
|
289 |
+
num_head_channels=-1,
|
290 |
+
use_checkpoint=False,
|
291 |
+
use_new_attention_order=False,
|
292 |
+
):
|
293 |
+
super().__init__()
|
294 |
+
self.channels = channels
|
295 |
+
if num_head_channels == -1:
|
296 |
+
self.num_heads = num_heads
|
297 |
+
else:
|
298 |
+
assert (
|
299 |
+
channels % num_head_channels == 0
|
300 |
+
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
301 |
+
self.num_heads = channels // num_head_channels
|
302 |
+
self.use_checkpoint = use_checkpoint
|
303 |
+
self.norm = normalization(channels)
|
304 |
+
self.qkv = conv_nd(1, channels, channels * 3, 1)
|
305 |
+
if use_new_attention_order:
|
306 |
+
# split qkv before split heads
|
307 |
+
self.attention = QKVAttention(self.num_heads)
|
308 |
+
else:
|
309 |
+
# split heads before split qkv
|
310 |
+
self.attention = QKVAttentionLegacy(self.num_heads)
|
311 |
+
|
312 |
+
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
313 |
+
|
314 |
+
def forward(self, x):
|
315 |
+
return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
|
316 |
+
#return pt_checkpoint(self._forward, x) # pytorch
|
317 |
+
|
318 |
+
def _forward(self, x):
|
319 |
+
b, c, *spatial = x.shape
|
320 |
+
x = x.reshape(b, c, -1)
|
321 |
+
qkv = self.qkv(self.norm(x))
|
322 |
+
h = self.attention(qkv)
|
323 |
+
h = self.proj_out(h)
|
324 |
+
return (x + h).reshape(b, c, *spatial)
|
325 |
+
|
326 |
+
|
327 |
+
def count_flops_attn(model, _x, y):
|
328 |
+
"""
|
329 |
+
A counter for the `thop` package to count the operations in an
|
330 |
+
attention operation.
|
331 |
+
Meant to be used like:
|
332 |
+
macs, params = thop.profile(
|
333 |
+
model,
|
334 |
+
inputs=(inputs, timestamps),
|
335 |
+
custom_ops={QKVAttention: QKVAttention.count_flops},
|
336 |
+
)
|
337 |
+
"""
|
338 |
+
b, c, *spatial = y[0].shape
|
339 |
+
num_spatial = int(np.prod(spatial))
|
340 |
+
# We perform two matmuls with the same number of ops.
|
341 |
+
# The first computes the weight matrix, the second computes
|
342 |
+
# the combination of the value vectors.
|
343 |
+
matmul_ops = 2 * b * (num_spatial ** 2) * c
|
344 |
+
model.total_ops += th.DoubleTensor([matmul_ops])
|
345 |
+
|
346 |
+
|
347 |
+
class QKVAttentionLegacy(nn.Module):
|
348 |
+
"""
|
349 |
+
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
|
350 |
+
"""
|
351 |
+
|
352 |
+
def __init__(self, n_heads):
|
353 |
+
super().__init__()
|
354 |
+
self.n_heads = n_heads
|
355 |
+
|
356 |
+
def forward(self, qkv):
|
357 |
+
"""
|
358 |
+
Apply QKV attention.
|
359 |
+
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
360 |
+
:return: an [N x (H * C) x T] tensor after attention.
|
361 |
+
"""
|
362 |
+
bs, width, length = qkv.shape
|
363 |
+
assert width % (3 * self.n_heads) == 0
|
364 |
+
ch = width // (3 * self.n_heads)
|
365 |
+
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
366 |
+
scale = 1 / math.sqrt(math.sqrt(ch))
|
367 |
+
weight = th.einsum(
|
368 |
+
"bct,bcs->bts", q * scale, k * scale
|
369 |
+
) # More stable with f16 than dividing afterwards
|
370 |
+
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
371 |
+
a = th.einsum("bts,bcs->bct", weight, v)
|
372 |
+
return a.reshape(bs, -1, length)
|
373 |
+
|
374 |
+
@staticmethod
|
375 |
+
def count_flops(model, _x, y):
|
376 |
+
return count_flops_attn(model, _x, y)
|
377 |
+
|
378 |
+
|
379 |
+
class QKVAttention(nn.Module):
|
380 |
+
"""
|
381 |
+
A module which performs QKV attention and splits in a different order.
|
382 |
+
"""
|
383 |
+
|
384 |
+
def __init__(self, n_heads):
|
385 |
+
super().__init__()
|
386 |
+
self.n_heads = n_heads
|
387 |
+
|
388 |
+
def forward(self, qkv):
|
389 |
+
"""
|
390 |
+
Apply QKV attention.
|
391 |
+
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
|
392 |
+
:return: an [N x (H * C) x T] tensor after attention.
|
393 |
+
"""
|
394 |
+
bs, width, length = qkv.shape
|
395 |
+
assert width % (3 * self.n_heads) == 0
|
396 |
+
ch = width // (3 * self.n_heads)
|
397 |
+
q, k, v = qkv.chunk(3, dim=1)
|
398 |
+
scale = 1 / math.sqrt(math.sqrt(ch))
|
399 |
+
weight = th.einsum(
|
400 |
+
"bct,bcs->bts",
|
401 |
+
(q * scale).view(bs * self.n_heads, ch, length),
|
402 |
+
(k * scale).view(bs * self.n_heads, ch, length),
|
403 |
+
) # More stable with f16 than dividing afterwards
|
404 |
+
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
405 |
+
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
|
406 |
+
return a.reshape(bs, -1, length)
|
407 |
+
|
408 |
+
@staticmethod
|
409 |
+
def count_flops(model, _x, y):
|
410 |
+
return count_flops_attn(model, _x, y)
|
411 |
+
|
412 |
+
|
413 |
+
class UNetModel(nn.Module):
|
414 |
+
"""
|
415 |
+
The full UNet model with attention and timestep embedding.
|
416 |
+
:param in_channels: channels in the input Tensor.
|
417 |
+
:param model_channels: base channel count for the model.
|
418 |
+
:param out_channels: channels in the output Tensor.
|
419 |
+
:param num_res_blocks: number of residual blocks per downsample.
|
420 |
+
:param attention_resolutions: a collection of downsample rates at which
|
421 |
+
attention will take place. May be a set, list, or tuple.
|
422 |
+
For example, if this contains 4, then at 4x downsampling, attention
|
423 |
+
will be used.
|
424 |
+
:param dropout: the dropout probability.
|
425 |
+
:param channel_mult: channel multiplier for each level of the UNet.
|
426 |
+
:param conv_resample: if True, use learned convolutions for upsampling and
|
427 |
+
downsampling.
|
428 |
+
:param dims: determines if the signal is 1D, 2D, or 3D.
|
429 |
+
:param num_classes: if specified (as an int), then this model will be
|
430 |
+
class-conditional with `num_classes` classes.
|
431 |
+
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
|
432 |
+
:param num_heads: the number of attention heads in each attention layer.
|
433 |
+
:param num_heads_channels: if specified, ignore num_heads and instead use
|
434 |
+
a fixed channel width per attention head.
|
435 |
+
:param num_heads_upsample: works with num_heads to set a different number
|
436 |
+
of heads for upsampling. Deprecated.
|
437 |
+
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
|
438 |
+
:param resblock_updown: use residual blocks for up/downsampling.
|
439 |
+
:param use_new_attention_order: use a different attention pattern for potentially
|
440 |
+
increased efficiency.
|
441 |
+
"""
|
442 |
+
|
443 |
+
def __init__(
|
444 |
+
self,
|
445 |
+
image_size=32,
|
446 |
+
in_channels=4,
|
447 |
+
model_channels=256,
|
448 |
+
out_channels=4,
|
449 |
+
num_res_blocks=2,
|
450 |
+
attention_resolutions=[4,2,1],
|
451 |
+
dropout=0,
|
452 |
+
channel_mult=(1, 2, 4),
|
453 |
+
conv_resample=True,
|
454 |
+
dims=2,
|
455 |
+
num_classes=None,
|
456 |
+
use_checkpoint=False,
|
457 |
+
use_fp16=False,
|
458 |
+
num_heads=8,
|
459 |
+
num_head_channels=-1,
|
460 |
+
num_heads_upsample=-1,
|
461 |
+
use_scale_shift_norm=False,
|
462 |
+
resblock_updown=False,
|
463 |
+
use_new_attention_order=False,
|
464 |
+
use_spatial_transformer=False, # custom transformer support
|
465 |
+
transformer_depth=1, # custom transformer support
|
466 |
+
context_dim=None, # custom transformer support
|
467 |
+
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
468 |
+
legacy=True,
|
469 |
+
**kwargs
|
470 |
+
):
|
471 |
+
super().__init__()
|
472 |
+
if use_spatial_transformer:
|
473 |
+
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
474 |
+
|
475 |
+
if context_dim is not None:
|
476 |
+
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
477 |
+
# from omegaconf.listconfig import ListConfig
|
478 |
+
# if type(context_dim) == ListConfig:
|
479 |
+
# context_dim = list(context_dim)
|
480 |
+
|
481 |
+
if num_heads_upsample == -1:
|
482 |
+
num_heads_upsample = num_heads
|
483 |
+
|
484 |
+
if num_heads == -1:
|
485 |
+
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
486 |
+
|
487 |
+
if num_head_channels == -1:
|
488 |
+
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
489 |
+
|
490 |
+
self.image_size = image_size
|
491 |
+
self.in_channels = in_channels
|
492 |
+
self.model_channels = model_channels
|
493 |
+
self.out_channels = out_channels
|
494 |
+
self.num_res_blocks = num_res_blocks
|
495 |
+
self.attention_resolutions = attention_resolutions
|
496 |
+
self.dropout = dropout
|
497 |
+
self.channel_mult = channel_mult
|
498 |
+
self.conv_resample = conv_resample
|
499 |
+
self.num_classes = num_classes
|
500 |
+
self.use_checkpoint = use_checkpoint
|
501 |
+
self.dtype = th.float16 if use_fp16 else th.float32
|
502 |
+
self.num_heads = num_heads
|
503 |
+
self.num_head_channels = num_head_channels
|
504 |
+
self.num_heads_upsample = num_heads_upsample
|
505 |
+
self.predict_codebook_ids = n_embed is not None
|
506 |
+
|
507 |
+
time_embed_dim = model_channels * 4
|
508 |
+
self.time_embed = nn.Sequential(
|
509 |
+
linear(model_channels, time_embed_dim),
|
510 |
+
nn.SiLU(),
|
511 |
+
linear(time_embed_dim, time_embed_dim),
|
512 |
+
)
|
513 |
+
|
514 |
+
if self.num_classes is not None:
|
515 |
+
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
516 |
+
|
517 |
+
self.input_blocks = nn.ModuleList(
|
518 |
+
[
|
519 |
+
TimestepEmbedSequential(
|
520 |
+
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
521 |
+
)
|
522 |
+
]
|
523 |
+
)
|
524 |
+
self._feature_size = model_channels
|
525 |
+
input_block_chans = [model_channels]
|
526 |
+
ch = model_channels
|
527 |
+
ds = 1
|
528 |
+
for level, mult in enumerate(channel_mult):
|
529 |
+
for _ in range(num_res_blocks):
|
530 |
+
layers = [
|
531 |
+
ResBlock(
|
532 |
+
ch,
|
533 |
+
time_embed_dim,
|
534 |
+
dropout,
|
535 |
+
out_channels=mult * model_channels,
|
536 |
+
dims=dims,
|
537 |
+
use_checkpoint=use_checkpoint,
|
538 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
539 |
+
)
|
540 |
+
]
|
541 |
+
ch = mult * model_channels
|
542 |
+
if ds in attention_resolutions:
|
543 |
+
if num_head_channels == -1:
|
544 |
+
dim_head = ch // num_heads
|
545 |
+
else:
|
546 |
+
num_heads = ch // num_head_channels
|
547 |
+
dim_head = num_head_channels
|
548 |
+
if legacy:
|
549 |
+
#num_heads = 1
|
550 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
551 |
+
layers.append(
|
552 |
+
AttentionBlock(
|
553 |
+
ch,
|
554 |
+
use_checkpoint=use_checkpoint,
|
555 |
+
num_heads=num_heads,
|
556 |
+
num_head_channels=dim_head,
|
557 |
+
use_new_attention_order=use_new_attention_order,
|
558 |
+
) if not use_spatial_transformer else SpatialTransformer(
|
559 |
+
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
|
560 |
+
)
|
561 |
+
)
|
562 |
+
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
563 |
+
self._feature_size += ch
|
564 |
+
input_block_chans.append(ch)
|
565 |
+
if level != len(channel_mult) - 1:
|
566 |
+
out_ch = ch
|
567 |
+
self.input_blocks.append(
|
568 |
+
TimestepEmbedSequential(
|
569 |
+
ResBlock(
|
570 |
+
ch,
|
571 |
+
time_embed_dim,
|
572 |
+
dropout,
|
573 |
+
out_channels=out_ch,
|
574 |
+
dims=dims,
|
575 |
+
use_checkpoint=use_checkpoint,
|
576 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
577 |
+
down=True,
|
578 |
+
)
|
579 |
+
if resblock_updown
|
580 |
+
else Downsample(
|
581 |
+
ch, conv_resample, dims=dims, out_channels=out_ch
|
582 |
+
)
|
583 |
+
)
|
584 |
+
)
|
585 |
+
ch = out_ch
|
586 |
+
input_block_chans.append(ch)
|
587 |
+
ds *= 2
|
588 |
+
self._feature_size += ch
|
589 |
+
|
590 |
+
if num_head_channels == -1:
|
591 |
+
dim_head = ch // num_heads
|
592 |
+
else:
|
593 |
+
num_heads = ch // num_head_channels
|
594 |
+
dim_head = num_head_channels
|
595 |
+
if legacy:
|
596 |
+
#num_heads = 1
|
597 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
598 |
+
self.middle_block = TimestepEmbedSequential(
|
599 |
+
ResBlock(
|
600 |
+
ch,
|
601 |
+
time_embed_dim,
|
602 |
+
dropout,
|
603 |
+
dims=dims,
|
604 |
+
use_checkpoint=use_checkpoint,
|
605 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
606 |
+
),
|
607 |
+
AttentionBlock(
|
608 |
+
ch,
|
609 |
+
use_checkpoint=use_checkpoint,
|
610 |
+
num_heads=num_heads,
|
611 |
+
num_head_channels=dim_head,
|
612 |
+
use_new_attention_order=use_new_attention_order,
|
613 |
+
) if not use_spatial_transformer else SpatialTransformer(
|
614 |
+
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
|
615 |
+
),
|
616 |
+
ResBlock(
|
617 |
+
ch,
|
618 |
+
time_embed_dim,
|
619 |
+
dropout,
|
620 |
+
dims=dims,
|
621 |
+
use_checkpoint=use_checkpoint,
|
622 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
623 |
+
),
|
624 |
+
)
|
625 |
+
self._feature_size += ch
|
626 |
+
|
627 |
+
self.output_blocks = nn.ModuleList([])
|
628 |
+
for level, mult in list(enumerate(channel_mult))[::-1]:
|
629 |
+
for i in range(num_res_blocks + 1):
|
630 |
+
ich = input_block_chans.pop()
|
631 |
+
layers = [
|
632 |
+
ResBlock(
|
633 |
+
ch + ich,
|
634 |
+
time_embed_dim,
|
635 |
+
dropout,
|
636 |
+
out_channels=model_channels * mult,
|
637 |
+
dims=dims,
|
638 |
+
use_checkpoint=use_checkpoint,
|
639 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
640 |
+
)
|
641 |
+
]
|
642 |
+
ch = model_channels * mult
|
643 |
+
if ds in attention_resolutions:
|
644 |
+
if num_head_channels == -1:
|
645 |
+
dim_head = ch // num_heads
|
646 |
+
else:
|
647 |
+
num_heads = ch // num_head_channels
|
648 |
+
dim_head = num_head_channels
|
649 |
+
if legacy:
|
650 |
+
#num_heads = 1
|
651 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
652 |
+
layers.append(
|
653 |
+
AttentionBlock(
|
654 |
+
ch,
|
655 |
+
use_checkpoint=use_checkpoint,
|
656 |
+
num_heads=num_heads_upsample,
|
657 |
+
num_head_channels=dim_head,
|
658 |
+
use_new_attention_order=use_new_attention_order,
|
659 |
+
) if not use_spatial_transformer else SpatialTransformer(
|
660 |
+
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
|
661 |
+
)
|
662 |
+
)
|
663 |
+
if level and i == num_res_blocks:
|
664 |
+
out_ch = ch
|
665 |
+
layers.append(
|
666 |
+
ResBlock(
|
667 |
+
ch,
|
668 |
+
time_embed_dim,
|
669 |
+
dropout,
|
670 |
+
out_channels=out_ch,
|
671 |
+
dims=dims,
|
672 |
+
use_checkpoint=use_checkpoint,
|
673 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
674 |
+
up=True,
|
675 |
+
)
|
676 |
+
if resblock_updown
|
677 |
+
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
678 |
+
)
|
679 |
+
ds //= 2
|
680 |
+
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
681 |
+
self._feature_size += ch
|
682 |
+
|
683 |
+
self.out = nn.Sequential(
|
684 |
+
normalization(ch),
|
685 |
+
nn.SiLU(),
|
686 |
+
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
687 |
+
)
|
688 |
+
if self.predict_codebook_ids:
|
689 |
+
self.id_predictor = nn.Sequential(
|
690 |
+
normalization(ch),
|
691 |
+
conv_nd(dims, model_channels, n_embed, 1),
|
692 |
+
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
693 |
+
)
|
694 |
+
|
695 |
+
def convert_to_fp16(self):
|
696 |
+
"""
|
697 |
+
Convert the torso of the model to float16.
|
698 |
+
"""
|
699 |
+
self.input_blocks.apply(convert_module_to_f16)
|
700 |
+
self.middle_block.apply(convert_module_to_f16)
|
701 |
+
self.output_blocks.apply(convert_module_to_f16)
|
702 |
+
|
703 |
+
def convert_to_fp32(self):
|
704 |
+
"""
|
705 |
+
Convert the torso of the model to float32.
|
706 |
+
"""
|
707 |
+
self.input_blocks.apply(convert_module_to_f32)
|
708 |
+
self.middle_block.apply(convert_module_to_f32)
|
709 |
+
self.output_blocks.apply(convert_module_to_f32)
|
710 |
+
|
711 |
+
def forward(self, x, t=None, condition=None, context=None, **kwargs):
|
712 |
+
"""
|
713 |
+
Apply the model to an input batch.
|
714 |
+
:param x: an [N x C x ...] Tensor of inputs.
|
715 |
+
:param timesteps: a 1-D batch of timesteps.
|
716 |
+
:param context: conditioning plugged in via crossattn
|
717 |
+
:param y: an [N] Tensor of labels, if class-conditional.
|
718 |
+
:return: an [N x C x ...] Tensor of outputs.
|
719 |
+
"""
|
720 |
+
condition = None # --------------------- WANRING ONLY for Testing ---------------------
|
721 |
+
assert (condition is not None) == (
|
722 |
+
self.num_classes is not None
|
723 |
+
), "must specify y if and only if the model is class-conditional"
|
724 |
+
hs = []
|
725 |
+
t_emb = timestep_embedding(t, self.model_channels, repeat_only=False)
|
726 |
+
emb = self.time_embed(t_emb)
|
727 |
+
|
728 |
+
if self.num_classes is not None:
|
729 |
+
assert condition.shape == (x.shape[0],)
|
730 |
+
emb = emb + self.label_emb(condition)
|
731 |
+
|
732 |
+
h = x.type(self.dtype)
|
733 |
+
for module in self.input_blocks:
|
734 |
+
h = module(h, emb, context)
|
735 |
+
hs.append(h)
|
736 |
+
h = self.middle_block(h, emb, context)
|
737 |
+
for module in self.output_blocks:
|
738 |
+
h = th.cat([h, hs.pop()], dim=1)
|
739 |
+
h = module(h, emb, context)
|
740 |
+
h = h.type(x.dtype)
|
741 |
+
if self.predict_codebook_ids:
|
742 |
+
return self.id_predictor(h)
|
743 |
+
else:
|
744 |
+
return self.out(h), []
|
745 |
+
|
746 |
+
|
747 |
+
class EncoderUNetModel(nn.Module):
|
748 |
+
"""
|
749 |
+
The half UNet model with attention and timestep embedding.
|
750 |
+
For usage, see UNet.
|
751 |
+
"""
|
752 |
+
|
753 |
+
def __init__(
|
754 |
+
self,
|
755 |
+
image_size,
|
756 |
+
in_channels,
|
757 |
+
model_channels,
|
758 |
+
out_channels,
|
759 |
+
num_res_blocks,
|
760 |
+
attention_resolutions,
|
761 |
+
dropout=0,
|
762 |
+
channel_mult=(1, 2, 4, 8),
|
763 |
+
conv_resample=True,
|
764 |
+
dims=2,
|
765 |
+
use_checkpoint=False,
|
766 |
+
use_fp16=False,
|
767 |
+
num_heads=1,
|
768 |
+
num_head_channels=-1,
|
769 |
+
num_heads_upsample=-1,
|
770 |
+
use_scale_shift_norm=False,
|
771 |
+
resblock_updown=False,
|
772 |
+
use_new_attention_order=False,
|
773 |
+
pool="adaptive",
|
774 |
+
*args,
|
775 |
+
**kwargs
|
776 |
+
):
|
777 |
+
super().__init__()
|
778 |
+
|
779 |
+
if num_heads_upsample == -1:
|
780 |
+
num_heads_upsample = num_heads
|
781 |
+
|
782 |
+
self.in_channels = in_channels
|
783 |
+
self.model_channels = model_channels
|
784 |
+
self.out_channels = out_channels
|
785 |
+
self.num_res_blocks = num_res_blocks
|
786 |
+
self.attention_resolutions = attention_resolutions
|
787 |
+
self.dropout = dropout
|
788 |
+
self.channel_mult = channel_mult
|
789 |
+
self.conv_resample = conv_resample
|
790 |
+
self.use_checkpoint = use_checkpoint
|
791 |
+
self.dtype = th.float16 if use_fp16 else th.float32
|
792 |
+
self.num_heads = num_heads
|
793 |
+
self.num_head_channels = num_head_channels
|
794 |
+
self.num_heads_upsample = num_heads_upsample
|
795 |
+
|
796 |
+
time_embed_dim = model_channels * 4
|
797 |
+
self.time_embed = nn.Sequential(
|
798 |
+
linear(model_channels, time_embed_dim),
|
799 |
+
nn.SiLU(),
|
800 |
+
linear(time_embed_dim, time_embed_dim),
|
801 |
+
)
|
802 |
+
|
803 |
+
self.input_blocks = nn.ModuleList(
|
804 |
+
[
|
805 |
+
TimestepEmbedSequential(
|
806 |
+
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
807 |
+
)
|
808 |
+
]
|
809 |
+
)
|
810 |
+
self._feature_size = model_channels
|
811 |
+
input_block_chans = [model_channels]
|
812 |
+
ch = model_channels
|
813 |
+
ds = 1
|
814 |
+
for level, mult in enumerate(channel_mult):
|
815 |
+
for _ in range(num_res_blocks):
|
816 |
+
layers = [
|
817 |
+
ResBlock(
|
818 |
+
ch,
|
819 |
+
time_embed_dim,
|
820 |
+
dropout,
|
821 |
+
out_channels=mult * model_channels,
|
822 |
+
dims=dims,
|
823 |
+
use_checkpoint=use_checkpoint,
|
824 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
825 |
+
)
|
826 |
+
]
|
827 |
+
ch = mult * model_channels
|
828 |
+
if ds in attention_resolutions:
|
829 |
+
layers.append(
|
830 |
+
AttentionBlock(
|
831 |
+
ch,
|
832 |
+
use_checkpoint=use_checkpoint,
|
833 |
+
num_heads=num_heads,
|
834 |
+
num_head_channels=num_head_channels,
|
835 |
+
use_new_attention_order=use_new_attention_order,
|
836 |
+
)
|
837 |
+
)
|
838 |
+
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
839 |
+
self._feature_size += ch
|
840 |
+
input_block_chans.append(ch)
|
841 |
+
if level != len(channel_mult) - 1:
|
842 |
+
out_ch = ch
|
843 |
+
self.input_blocks.append(
|
844 |
+
TimestepEmbedSequential(
|
845 |
+
ResBlock(
|
846 |
+
ch,
|
847 |
+
time_embed_dim,
|
848 |
+
dropout,
|
849 |
+
out_channels=out_ch,
|
850 |
+
dims=dims,
|
851 |
+
use_checkpoint=use_checkpoint,
|
852 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
853 |
+
down=True,
|
854 |
+
)
|
855 |
+
if resblock_updown
|
856 |
+
else Downsample(
|
857 |
+
ch, conv_resample, dims=dims, out_channels=out_ch
|
858 |
+
)
|
859 |
+
)
|
860 |
+
)
|
861 |
+
ch = out_ch
|
862 |
+
input_block_chans.append(ch)
|
863 |
+
ds *= 2
|
864 |
+
self._feature_size += ch
|
865 |
+
|
866 |
+
self.middle_block = TimestepEmbedSequential(
|
867 |
+
ResBlock(
|
868 |
+
ch,
|
869 |
+
time_embed_dim,
|
870 |
+
dropout,
|
871 |
+
dims=dims,
|
872 |
+
use_checkpoint=use_checkpoint,
|
873 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
874 |
+
),
|
875 |
+
AttentionBlock(
|
876 |
+
ch,
|
877 |
+
use_checkpoint=use_checkpoint,
|
878 |
+
num_heads=num_heads,
|
879 |
+
num_head_channels=num_head_channels,
|
880 |
+
use_new_attention_order=use_new_attention_order,
|
881 |
+
),
|
882 |
+
ResBlock(
|
883 |
+
ch,
|
884 |
+
time_embed_dim,
|
885 |
+
dropout,
|
886 |
+
dims=dims,
|
887 |
+
use_checkpoint=use_checkpoint,
|
888 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
889 |
+
),
|
890 |
+
)
|
891 |
+
self._feature_size += ch
|
892 |
+
self.pool = pool
|
893 |
+
if pool == "adaptive":
|
894 |
+
self.out = nn.Sequential(
|
895 |
+
normalization(ch),
|
896 |
+
nn.SiLU(),
|
897 |
+
nn.AdaptiveAvgPool2d((1, 1)),
|
898 |
+
zero_module(conv_nd(dims, ch, out_channels, 1)),
|
899 |
+
nn.Flatten(),
|
900 |
+
)
|
901 |
+
elif pool == "attention":
|
902 |
+
assert num_head_channels != -1
|
903 |
+
self.out = nn.Sequential(
|
904 |
+
normalization(ch),
|
905 |
+
nn.SiLU(),
|
906 |
+
AttentionPool2d(
|
907 |
+
(image_size // ds), ch, num_head_channels, out_channels
|
908 |
+
),
|
909 |
+
)
|
910 |
+
elif pool == "spatial":
|
911 |
+
self.out = nn.Sequential(
|
912 |
+
nn.Linear(self._feature_size, 2048),
|
913 |
+
nn.ReLU(),
|
914 |
+
nn.Linear(2048, self.out_channels),
|
915 |
+
)
|
916 |
+
elif pool == "spatial_v2":
|
917 |
+
self.out = nn.Sequential(
|
918 |
+
nn.Linear(self._feature_size, 2048),
|
919 |
+
normalization(2048),
|
920 |
+
nn.SiLU(),
|
921 |
+
nn.Linear(2048, self.out_channels),
|
922 |
+
)
|
923 |
+
else:
|
924 |
+
raise NotImplementedError(f"Unexpected {pool} pooling")
|
925 |
+
|
926 |
+
def convert_to_fp16(self):
|
927 |
+
"""
|
928 |
+
Convert the torso of the model to float16.
|
929 |
+
"""
|
930 |
+
self.input_blocks.apply(convert_module_to_f16)
|
931 |
+
self.middle_block.apply(convert_module_to_f16)
|
932 |
+
|
933 |
+
def convert_to_fp32(self):
|
934 |
+
"""
|
935 |
+
Convert the torso of the model to float32.
|
936 |
+
"""
|
937 |
+
self.input_blocks.apply(convert_module_to_f32)
|
938 |
+
self.middle_block.apply(convert_module_to_f32)
|
939 |
+
|
940 |
+
def forward(self, x, timesteps):
|
941 |
+
"""
|
942 |
+
Apply the model to an input batch.
|
943 |
+
:param x: an [N x C x ...] Tensor of inputs.
|
944 |
+
:param timesteps: a 1-D batch of timesteps.
|
945 |
+
:return: an [N x K] Tensor of outputs.
|
946 |
+
"""
|
947 |
+
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
948 |
+
|
949 |
+
results = []
|
950 |
+
h = x.type(self.dtype)
|
951 |
+
for module in self.input_blocks:
|
952 |
+
h = module(h, emb)
|
953 |
+
if self.pool.startswith("spatial"):
|
954 |
+
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
955 |
+
h = self.middle_block(h, emb)
|
956 |
+
if self.pool.startswith("spatial"):
|
957 |
+
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
958 |
+
h = th.cat(results, axis=-1)
|
959 |
+
return self.out(h)
|
960 |
+
else:
|
961 |
+
h = h.type(x.dtype)
|
962 |
+
return self.out(h)
|
medical_diffusion/external/stable_diffusion/util.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# adopted from
|
2 |
+
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
3 |
+
# and
|
4 |
+
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
5 |
+
# and
|
6 |
+
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
7 |
+
#
|
8 |
+
# thanks!
|
9 |
+
|
10 |
+
|
11 |
+
import os
|
12 |
+
import math
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import numpy as np
|
16 |
+
from einops import repeat
|
17 |
+
|
18 |
+
#--------------- Added ----------------
|
19 |
+
import importlib
|
20 |
+
def get_obj_from_str(string, reload=False):
|
21 |
+
module, cls = string.rsplit(".", 1)
|
22 |
+
if reload:
|
23 |
+
module_imp = importlib.import_module(module)
|
24 |
+
importlib.reload(module_imp)
|
25 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
26 |
+
|
27 |
+
def instantiate_from_config(config):
|
28 |
+
if not "target" in config:
|
29 |
+
if config == '__is_first_stage__':
|
30 |
+
return None
|
31 |
+
elif config == "__is_unconditional__":
|
32 |
+
return None
|
33 |
+
raise KeyError("Expected key `target` to instantiate.")
|
34 |
+
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
35 |
+
|
36 |
+
#--------------------------------
|
37 |
+
|
38 |
+
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
39 |
+
if schedule == "linear":
|
40 |
+
betas = (
|
41 |
+
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
|
42 |
+
)
|
43 |
+
|
44 |
+
elif schedule == "cosine":
|
45 |
+
timesteps = (
|
46 |
+
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
|
47 |
+
)
|
48 |
+
alphas = timesteps / (1 + cosine_s) * np.pi / 2
|
49 |
+
alphas = torch.cos(alphas).pow(2)
|
50 |
+
alphas = alphas / alphas[0]
|
51 |
+
betas = 1 - alphas[1:] / alphas[:-1]
|
52 |
+
betas = np.clip(betas, a_min=0, a_max=0.999)
|
53 |
+
|
54 |
+
elif schedule == "sqrt_linear":
|
55 |
+
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
|
56 |
+
elif schedule == "sqrt":
|
57 |
+
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
|
58 |
+
else:
|
59 |
+
raise ValueError(f"schedule '{schedule}' unknown.")
|
60 |
+
return betas.numpy()
|
61 |
+
|
62 |
+
|
63 |
+
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
|
64 |
+
if ddim_discr_method == 'uniform':
|
65 |
+
c = num_ddpm_timesteps // num_ddim_timesteps
|
66 |
+
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
|
67 |
+
elif ddim_discr_method == 'quad':
|
68 |
+
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
|
69 |
+
else:
|
70 |
+
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
|
71 |
+
|
72 |
+
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
|
73 |
+
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
74 |
+
steps_out = ddim_timesteps + 1
|
75 |
+
if verbose:
|
76 |
+
print(f'Selected timesteps for ddim sampler: {steps_out}')
|
77 |
+
return steps_out
|
78 |
+
|
79 |
+
|
80 |
+
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
|
81 |
+
# select alphas for computing the variance schedule
|
82 |
+
alphas = alphacums[ddim_timesteps]
|
83 |
+
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
|
84 |
+
|
85 |
+
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
86 |
+
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
87 |
+
if verbose:
|
88 |
+
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
89 |
+
print(f'For the chosen value of eta, which is {eta}, '
|
90 |
+
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
|
91 |
+
return sigmas, alphas, alphas_prev
|
92 |
+
|
93 |
+
|
94 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
95 |
+
"""
|
96 |
+
Create a beta schedule that discretizes the given alpha_t_bar function,
|
97 |
+
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
98 |
+
:param num_diffusion_timesteps: the number of betas to produce.
|
99 |
+
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
100 |
+
produces the cumulative product of (1-beta) up to that
|
101 |
+
part of the diffusion process.
|
102 |
+
:param max_beta: the maximum beta to use; use values lower than 1 to
|
103 |
+
prevent singularities.
|
104 |
+
"""
|
105 |
+
betas = []
|
106 |
+
for i in range(num_diffusion_timesteps):
|
107 |
+
t1 = i / num_diffusion_timesteps
|
108 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
109 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
110 |
+
return np.array(betas)
|
111 |
+
|
112 |
+
|
113 |
+
def extract_into_tensor(a, t, x_shape):
|
114 |
+
b, *_ = t.shape
|
115 |
+
out = a.gather(-1, t)
|
116 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
117 |
+
|
118 |
+
|
119 |
+
def checkpoint(func, inputs, params, flag):
|
120 |
+
"""
|
121 |
+
Evaluate a function without caching intermediate activations, allowing for
|
122 |
+
reduced memory at the expense of extra compute in the backward pass.
|
123 |
+
:param func: the function to evaluate.
|
124 |
+
:param inputs: the argument sequence to pass to `func`.
|
125 |
+
:param params: a sequence of parameters `func` depends on but does not
|
126 |
+
explicitly take as arguments.
|
127 |
+
:param flag: if False, disable gradient checkpointing.
|
128 |
+
"""
|
129 |
+
if flag:
|
130 |
+
args = tuple(inputs) + tuple(params)
|
131 |
+
return CheckpointFunction.apply(func, len(inputs), *args)
|
132 |
+
else:
|
133 |
+
return func(*inputs)
|
134 |
+
|
135 |
+
|
136 |
+
class CheckpointFunction(torch.autograd.Function):
|
137 |
+
@staticmethod
|
138 |
+
def forward(ctx, run_function, length, *args):
|
139 |
+
ctx.run_function = run_function
|
140 |
+
ctx.input_tensors = list(args[:length])
|
141 |
+
ctx.input_params = list(args[length:])
|
142 |
+
|
143 |
+
with torch.no_grad():
|
144 |
+
output_tensors = ctx.run_function(*ctx.input_tensors)
|
145 |
+
return output_tensors
|
146 |
+
|
147 |
+
@staticmethod
|
148 |
+
def backward(ctx, *output_grads):
|
149 |
+
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
150 |
+
with torch.enable_grad():
|
151 |
+
# Fixes a bug where the first op in run_function modifies the
|
152 |
+
# Tensor storage in place, which is not allowed for detach()'d
|
153 |
+
# Tensors.
|
154 |
+
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
|
155 |
+
output_tensors = ctx.run_function(*shallow_copies)
|
156 |
+
input_grads = torch.autograd.grad(
|
157 |
+
output_tensors,
|
158 |
+
ctx.input_tensors + ctx.input_params,
|
159 |
+
output_grads,
|
160 |
+
allow_unused=True,
|
161 |
+
)
|
162 |
+
del ctx.input_tensors
|
163 |
+
del ctx.input_params
|
164 |
+
del output_tensors
|
165 |
+
return (None, None) + input_grads
|
166 |
+
|
167 |
+
|
168 |
+
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
169 |
+
"""
|
170 |
+
Create sinusoidal timestep embeddings.
|
171 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
172 |
+
These may be fractional.
|
173 |
+
:param dim: the dimension of the output.
|
174 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
175 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
176 |
+
"""
|
177 |
+
if not repeat_only:
|
178 |
+
half = dim // 2
|
179 |
+
freqs = torch.exp(
|
180 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
181 |
+
).to(device=timesteps.device)
|
182 |
+
args = timesteps[:, None].float() * freqs[None]
|
183 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
184 |
+
if dim % 2:
|
185 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
186 |
+
else:
|
187 |
+
embedding = repeat(timesteps, 'b -> b d', d=dim)
|
188 |
+
return embedding
|
189 |
+
|
190 |
+
|
191 |
+
def zero_module(module):
|
192 |
+
"""
|
193 |
+
Zero out the parameters of a module and return it.
|
194 |
+
"""
|
195 |
+
for p in module.parameters():
|
196 |
+
p.detach().zero_()
|
197 |
+
return module
|
198 |
+
|
199 |
+
|
200 |
+
def scale_module(module, scale):
|
201 |
+
"""
|
202 |
+
Scale the parameters of a module and return it.
|
203 |
+
"""
|
204 |
+
for p in module.parameters():
|
205 |
+
p.detach().mul_(scale)
|
206 |
+
return module
|
207 |
+
|
208 |
+
|
209 |
+
def mean_flat(tensor):
|
210 |
+
"""
|
211 |
+
Take the mean over all non-batch dimensions.
|
212 |
+
"""
|
213 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
214 |
+
|
215 |
+
|
216 |
+
def normalization(channels):
|
217 |
+
"""
|
218 |
+
Make a standard normalization layer.
|
219 |
+
:param channels: number of input channels.
|
220 |
+
:return: an nn.Module for normalization.
|
221 |
+
"""
|
222 |
+
return GroupNorm32(32, channels)
|
223 |
+
|
224 |
+
|
225 |
+
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
226 |
+
class SiLU(nn.Module):
|
227 |
+
def forward(self, x):
|
228 |
+
return x * torch.sigmoid(x)
|
229 |
+
|
230 |
+
|
231 |
+
class GroupNorm32(nn.GroupNorm):
|
232 |
+
def forward(self, x):
|
233 |
+
return super().forward(x.float()).type(x.dtype)
|
234 |
+
|
235 |
+
def conv_nd(dims, *args, **kwargs):
|
236 |
+
"""
|
237 |
+
Create a 1D, 2D, or 3D convolution module.
|
238 |
+
"""
|
239 |
+
if dims == 1:
|
240 |
+
return nn.Conv1d(*args, **kwargs)
|
241 |
+
elif dims == 2:
|
242 |
+
return nn.Conv2d(*args, **kwargs)
|
243 |
+
elif dims == 3:
|
244 |
+
return nn.Conv3d(*args, **kwargs)
|
245 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
246 |
+
|
247 |
+
|
248 |
+
def linear(*args, **kwargs):
|
249 |
+
"""
|
250 |
+
Create a linear module.
|
251 |
+
"""
|
252 |
+
return nn.Linear(*args, **kwargs)
|
253 |
+
|
254 |
+
|
255 |
+
def avg_pool_nd(dims, *args, **kwargs):
|
256 |
+
"""
|
257 |
+
Create a 1D, 2D, or 3D average pooling module.
|
258 |
+
"""
|
259 |
+
if dims == 1:
|
260 |
+
return nn.AvgPool1d(*args, **kwargs)
|
261 |
+
elif dims == 2:
|
262 |
+
return nn.AvgPool2d(*args, **kwargs)
|
263 |
+
elif dims == 3:
|
264 |
+
return nn.AvgPool3d(*args, **kwargs)
|
265 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
266 |
+
|
267 |
+
|
268 |
+
class HybridConditioner(nn.Module):
|
269 |
+
|
270 |
+
def __init__(self, c_concat_config, c_crossattn_config):
|
271 |
+
super().__init__()
|
272 |
+
self.concat_conditioner = instantiate_from_config(c_concat_config)
|
273 |
+
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
|
274 |
+
|
275 |
+
def forward(self, c_concat, c_crossattn):
|
276 |
+
c_concat = self.concat_conditioner(c_concat)
|
277 |
+
c_crossattn = self.crossattn_conditioner(c_crossattn)
|
278 |
+
return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
|
279 |
+
|
280 |
+
|
281 |
+
def noise_like(shape, device, repeat=False):
|
282 |
+
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
283 |
+
noise = lambda: torch.randn(shape, device=device)
|
284 |
+
return repeat_noise() if repeat else noise()
|
medical_diffusion/external/stable_diffusion/util_attention.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import math
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import numpy as np
|
7 |
+
from einops import repeat
|
8 |
+
|
9 |
+
def checkpoint(func, inputs, params, flag):
|
10 |
+
"""
|
11 |
+
Evaluate a function without caching intermediate activations, allowing for
|
12 |
+
reduced memory at the expense of extra compute in the backward pass.
|
13 |
+
:param func: the function to evaluate.
|
14 |
+
:param inputs: the argument sequence to pass to `func`.
|
15 |
+
:param params: a sequence of parameters `func` depends on but does not
|
16 |
+
explicitly take as arguments.
|
17 |
+
:param flag: if False, disable gradient checkpointing.
|
18 |
+
"""
|
19 |
+
if flag:
|
20 |
+
args = tuple(inputs) + tuple(params)
|
21 |
+
return CheckpointFunction.apply(func, len(inputs), *args)
|
22 |
+
else:
|
23 |
+
return func(*inputs)
|
24 |
+
|
25 |
+
|
26 |
+
class CheckpointFunction(torch.autograd.Function):
|
27 |
+
@staticmethod
|
28 |
+
def forward(ctx, run_function, length, *args):
|
29 |
+
ctx.run_function = run_function
|
30 |
+
ctx.input_tensors = list(args[:length])
|
31 |
+
ctx.input_params = list(args[length:])
|
32 |
+
|
33 |
+
with torch.no_grad():
|
34 |
+
output_tensors = ctx.run_function(*ctx.input_tensors)
|
35 |
+
return output_tensors
|
36 |
+
|
37 |
+
@staticmethod
|
38 |
+
def backward(ctx, *output_grads):
|
39 |
+
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
40 |
+
with torch.enable_grad():
|
41 |
+
# Fixes a bug where the first op in run_function modifies the
|
42 |
+
# Tensor storage in place, which is not allowed for detach()'d
|
43 |
+
# Tensors.
|
44 |
+
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
|
45 |
+
output_tensors = ctx.run_function(*shallow_copies)
|
46 |
+
input_grads = torch.autograd.grad(
|
47 |
+
output_tensors,
|
48 |
+
ctx.input_tensors + ctx.input_params,
|
49 |
+
output_grads,
|
50 |
+
allow_unused=True,
|
51 |
+
)
|
52 |
+
del ctx.input_tensors
|
53 |
+
del ctx.input_params
|
54 |
+
del output_tensors
|
55 |
+
return (None, None) + input_grads
|
56 |
+
|
medical_diffusion/external/unet_lucidrains.py
ADDED
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn, einsum
|
2 |
+
from einops import rearrange, reduce
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from functools import partial
|
6 |
+
import math
|
7 |
+
|
8 |
+
# -------------------------------- Embeddings ------------------------------------------------------
|
9 |
+
class SinusoidalPosEmb(nn.Module):
|
10 |
+
def __init__(self, dim):
|
11 |
+
super().__init__()
|
12 |
+
self.dim = dim
|
13 |
+
|
14 |
+
def forward(self, x):
|
15 |
+
device = x.device
|
16 |
+
half_dim = self.dim // 2
|
17 |
+
emb = math.log(10000) / (half_dim - 1)
|
18 |
+
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
19 |
+
emb = x[:, None] * emb[None, :]
|
20 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
21 |
+
return emb
|
22 |
+
|
23 |
+
class LearnedSinusoidalPosEmb(nn.Module):
|
24 |
+
""" following @crowsonkb 's lead with learned sinusoidal pos emb """
|
25 |
+
""" https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
|
26 |
+
|
27 |
+
def __init__(self, dim):
|
28 |
+
super().__init__()
|
29 |
+
assert (dim % 2) == 0
|
30 |
+
half_dim = dim // 2
|
31 |
+
self.weights = nn.Parameter(torch.randn(half_dim))
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
x = rearrange(x, 'b -> b 1')
|
35 |
+
freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
|
36 |
+
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
|
37 |
+
fouriered = torch.cat((x, fouriered), dim = -1)
|
38 |
+
return fouriered
|
39 |
+
|
40 |
+
# -------------------------------------------------------------------
|
41 |
+
|
42 |
+
def exists(x):
|
43 |
+
return x is not None
|
44 |
+
|
45 |
+
def default(val, d):
|
46 |
+
if exists(val):
|
47 |
+
return val
|
48 |
+
return d() if callable(d) else d
|
49 |
+
|
50 |
+
def l2norm(t):
|
51 |
+
return F.normalize(t, dim = -1)
|
52 |
+
|
53 |
+
class Residual(nn.Module):
|
54 |
+
def __init__(self, fn):
|
55 |
+
super().__init__()
|
56 |
+
self.fn = fn
|
57 |
+
|
58 |
+
def forward(self, x, *args, **kwargs):
|
59 |
+
return self.fn(x, *args, **kwargs) + x
|
60 |
+
|
61 |
+
def Upsample(dim, dim_out = None):
|
62 |
+
return nn.Sequential(
|
63 |
+
nn.Upsample(scale_factor = 2, mode = 'nearest'),
|
64 |
+
nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
|
65 |
+
)
|
66 |
+
|
67 |
+
def Downsample(dim, dim_out = None):
|
68 |
+
return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1)
|
69 |
+
|
70 |
+
class WeightStandardizedConv2d(nn.Conv2d):
|
71 |
+
"""
|
72 |
+
https://arxiv.org/abs/1903.10520
|
73 |
+
weight standardization purportedly works synergistically with group normalization
|
74 |
+
"""
|
75 |
+
def forward(self, x):
|
76 |
+
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
|
77 |
+
|
78 |
+
weight = self.weight
|
79 |
+
mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
|
80 |
+
var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased = False))
|
81 |
+
normalized_weight = (weight - mean) * (var + eps).rsqrt()
|
82 |
+
|
83 |
+
return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
84 |
+
|
85 |
+
|
86 |
+
class LayerNorm(nn.Module):
|
87 |
+
def __init__(self, dim):
|
88 |
+
super().__init__()
|
89 |
+
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
90 |
+
|
91 |
+
def forward(self, x):
|
92 |
+
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
|
93 |
+
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
94 |
+
mean = torch.mean(x, dim = 1, keepdim = True)
|
95 |
+
return (x - mean) * (var + eps).rsqrt() * self.g
|
96 |
+
|
97 |
+
class PreNorm(nn.Module):
|
98 |
+
def __init__(self, dim, fn):
|
99 |
+
super().__init__()
|
100 |
+
self.fn = fn
|
101 |
+
self.norm = LayerNorm(dim)
|
102 |
+
|
103 |
+
def forward(self, x):
|
104 |
+
x = self.norm(x)
|
105 |
+
return self.fn(x)
|
106 |
+
|
107 |
+
class Block(nn.Module):
|
108 |
+
def __init__(self, dim, dim_out, groups = 8):
|
109 |
+
super().__init__()
|
110 |
+
self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
|
111 |
+
self.norm = nn.GroupNorm(groups, dim_out)
|
112 |
+
self.act = nn.SiLU()
|
113 |
+
|
114 |
+
def forward(self, x, scale_shift = None):
|
115 |
+
x = self.proj(x)
|
116 |
+
x = self.norm(x)
|
117 |
+
|
118 |
+
if exists(scale_shift):
|
119 |
+
scale, shift = scale_shift
|
120 |
+
x = x * (scale + 1) + shift
|
121 |
+
|
122 |
+
x = self.act(x)
|
123 |
+
return x
|
124 |
+
|
125 |
+
class ResnetBlock(nn.Module):
|
126 |
+
def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
|
127 |
+
super().__init__()
|
128 |
+
self.mlp = nn.Sequential(
|
129 |
+
nn.SiLU(),
|
130 |
+
nn.Linear(time_emb_dim, dim_out * 2)
|
131 |
+
) if exists(time_emb_dim) else None
|
132 |
+
|
133 |
+
self.block1 = Block(dim, dim_out, groups = groups)
|
134 |
+
self.block2 = Block(dim_out, dim_out, groups = groups)
|
135 |
+
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
136 |
+
|
137 |
+
def forward(self, x, time_emb = None):
|
138 |
+
|
139 |
+
scale_shift = None
|
140 |
+
if exists(self.mlp) and exists(time_emb):
|
141 |
+
time_emb = self.mlp(time_emb)
|
142 |
+
time_emb = rearrange(time_emb, 'b c -> b c 1 1')
|
143 |
+
scale_shift = time_emb.chunk(2, dim = 1)
|
144 |
+
|
145 |
+
h = self.block1(x, scale_shift = scale_shift)
|
146 |
+
|
147 |
+
h = self.block2(h)
|
148 |
+
|
149 |
+
return h + self.res_conv(x)
|
150 |
+
|
151 |
+
class LinearAttention(nn.Module):
|
152 |
+
def __init__(self, dim, heads = 4, dim_head = 32):
|
153 |
+
super().__init__()
|
154 |
+
self.scale = dim_head ** -0.5
|
155 |
+
self.heads = heads
|
156 |
+
hidden_dim = dim_head * heads
|
157 |
+
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
|
158 |
+
|
159 |
+
self.to_out = nn.Sequential(
|
160 |
+
nn.Conv2d(hidden_dim, dim, 1),
|
161 |
+
LayerNorm(dim)
|
162 |
+
)
|
163 |
+
|
164 |
+
def forward(self, x):
|
165 |
+
b, c, h, w = x.shape
|
166 |
+
qkv = self.to_qkv(x).chunk(3, dim = 1)
|
167 |
+
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
|
168 |
+
|
169 |
+
q = q.softmax(dim = -2)
|
170 |
+
k = k.softmax(dim = -1)
|
171 |
+
|
172 |
+
q = q * self.scale
|
173 |
+
v = v / (h * w)
|
174 |
+
|
175 |
+
context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
|
176 |
+
|
177 |
+
out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
|
178 |
+
out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
|
179 |
+
return self.to_out(out)
|
180 |
+
|
181 |
+
class Attention(nn.Module):
|
182 |
+
def __init__(self, dim, heads = 4, dim_head = 32, scale = 10):
|
183 |
+
super().__init__()
|
184 |
+
self.scale = scale
|
185 |
+
self.heads = heads
|
186 |
+
hidden_dim = dim_head * heads
|
187 |
+
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
|
188 |
+
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
189 |
+
|
190 |
+
def forward(self, x):
|
191 |
+
b, c, h, w = x.shape
|
192 |
+
qkv = self.to_qkv(x).chunk(3, dim = 1)
|
193 |
+
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
|
194 |
+
|
195 |
+
q, k = map(l2norm, (q, k))
|
196 |
+
|
197 |
+
sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
|
198 |
+
attn = sim.softmax(dim = -1)
|
199 |
+
out = einsum('b h i j, b h d j -> b h i d', attn, v)
|
200 |
+
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
|
201 |
+
return self.to_out(out)
|
202 |
+
|
203 |
+
|
204 |
+
|
205 |
+
class UNet(nn.Module):
|
206 |
+
def __init__(
|
207 |
+
self,
|
208 |
+
dim=32,
|
209 |
+
init_dim = None,
|
210 |
+
out_dim = None,
|
211 |
+
dim_mults=(1, 2, 4, 8),
|
212 |
+
channels = 3,
|
213 |
+
self_condition = False,
|
214 |
+
resnet_block_groups = 8,
|
215 |
+
learned_variance = False,
|
216 |
+
learned_sinusoidal_cond = False,
|
217 |
+
learned_sinusoidal_dim = 16,
|
218 |
+
**kwargs
|
219 |
+
):
|
220 |
+
super().__init__()
|
221 |
+
|
222 |
+
# determine dimensions
|
223 |
+
|
224 |
+
self.channels = channels
|
225 |
+
self.self_condition = self_condition
|
226 |
+
input_channels = channels * (2 if self_condition else 1)
|
227 |
+
|
228 |
+
init_dim = default(init_dim, dim)
|
229 |
+
self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)
|
230 |
+
|
231 |
+
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
|
232 |
+
in_out = list(zip(dims[:-1], dims[1:]))
|
233 |
+
|
234 |
+
block_klass = partial(ResnetBlock, groups = resnet_block_groups)
|
235 |
+
|
236 |
+
# time embeddings
|
237 |
+
|
238 |
+
time_dim = dim * 4
|
239 |
+
|
240 |
+
self.learned_sinusoidal_cond = learned_sinusoidal_cond
|
241 |
+
|
242 |
+
if learned_sinusoidal_cond:
|
243 |
+
sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim)
|
244 |
+
fourier_dim = learned_sinusoidal_dim + 1
|
245 |
+
else:
|
246 |
+
sinu_pos_emb = SinusoidalPosEmb(dim)
|
247 |
+
fourier_dim = dim
|
248 |
+
|
249 |
+
self.time_mlp = nn.Sequential(
|
250 |
+
sinu_pos_emb,
|
251 |
+
nn.Linear(fourier_dim, time_dim),
|
252 |
+
nn.GELU(),
|
253 |
+
nn.Linear(time_dim, time_dim)
|
254 |
+
)
|
255 |
+
|
256 |
+
# layers
|
257 |
+
|
258 |
+
self.downs = nn.ModuleList([])
|
259 |
+
self.ups = nn.ModuleList([])
|
260 |
+
num_resolutions = len(in_out)
|
261 |
+
|
262 |
+
for ind, (dim_in, dim_out) in enumerate(in_out):
|
263 |
+
is_last = ind >= (num_resolutions - 1)
|
264 |
+
|
265 |
+
self.downs.append(nn.ModuleList([
|
266 |
+
block_klass(dim_in, dim_in, time_emb_dim = time_dim),
|
267 |
+
block_klass(dim_in, dim_in, time_emb_dim = time_dim),
|
268 |
+
Residual(PreNorm(dim_in, LinearAttention(dim_in))),
|
269 |
+
Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
|
270 |
+
]))
|
271 |
+
|
272 |
+
mid_dim = dims[-1]
|
273 |
+
self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
|
274 |
+
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
|
275 |
+
self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
|
276 |
+
|
277 |
+
for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
|
278 |
+
is_last = ind == (len(in_out) - 1)
|
279 |
+
|
280 |
+
self.ups.append(nn.ModuleList([
|
281 |
+
block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
|
282 |
+
block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
|
283 |
+
Residual(PreNorm(dim_out, LinearAttention(dim_out))),
|
284 |
+
Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
|
285 |
+
]))
|
286 |
+
|
287 |
+
default_out_dim = channels * (1 if not learned_variance else 2)
|
288 |
+
self.out_dim = default(out_dim, default_out_dim)
|
289 |
+
|
290 |
+
self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
|
291 |
+
self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
|
292 |
+
|
293 |
+
def forward(self, x, time, condition=None, self_cond=None):
|
294 |
+
if self.self_condition:
|
295 |
+
x_self_cond = default(self_cond, lambda: torch.zeros_like(x))
|
296 |
+
x = torch.cat((x_self_cond, x), dim = 1)
|
297 |
+
|
298 |
+
x = self.init_conv(x)
|
299 |
+
r = x.clone()
|
300 |
+
|
301 |
+
t = self.time_mlp(time)
|
302 |
+
|
303 |
+
h = []
|
304 |
+
|
305 |
+
for block1, block2, attn, downsample in self.downs:
|
306 |
+
x = block1(x, t)
|
307 |
+
h.append(x)
|
308 |
+
|
309 |
+
x = block2(x, t)
|
310 |
+
x = attn(x)
|
311 |
+
h.append(x)
|
312 |
+
|
313 |
+
x = downsample(x)
|
314 |
+
|
315 |
+
x = self.mid_block1(x, t)
|
316 |
+
x = self.mid_attn(x)
|
317 |
+
x = self.mid_block2(x, t)
|
318 |
+
|
319 |
+
for block1, block2, attn, upsample in self.ups:
|
320 |
+
x = torch.cat((x, h.pop()), dim = 1)
|
321 |
+
x = block1(x, t)
|
322 |
+
|
323 |
+
x = torch.cat((x, h.pop()), dim = 1)
|
324 |
+
x = block2(x, t)
|
325 |
+
x = attn(x)
|
326 |
+
|
327 |
+
x = upsample(x)
|
328 |
+
|
329 |
+
x = torch.cat((x, r), dim = 1)
|
330 |
+
|
331 |
+
x = self.final_res_block(x, t)
|
332 |
+
return self.final_conv(x), []
|
medical_diffusion/loss/gan_losses.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
def exp_d_loss(logits_real, logits_fake):
|
7 |
+
loss_real = torch.mean(torch.exp(-logits_real))
|
8 |
+
loss_fake = torch.mean(torch.exp(logits_fake))
|
9 |
+
d_loss = 0.5 * (loss_real + loss_fake)
|
10 |
+
return d_loss
|
11 |
+
|
12 |
+
def hinge_d_loss(logits_real, logits_fake):
|
13 |
+
loss_real = torch.mean(F.relu(1. - logits_real))
|
14 |
+
loss_fake = torch.mean(F.relu(1. + logits_fake))
|
15 |
+
d_loss = 0.5 * (loss_real + loss_fake)
|
16 |
+
return d_loss
|
17 |
+
|
18 |
+
def vanilla_d_loss(logits_real, logits_fake):
|
19 |
+
d_loss = 0.5 * (
|
20 |
+
torch.mean(F.softplus(-logits_real)) +
|
21 |
+
torch.mean(F.softplus(logits_fake)))
|
22 |
+
return d_loss
|
medical_diffusion/loss/perceivers.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
import lpips
|
4 |
+
import torch
|
5 |
+
|
6 |
+
class LPIPS(torch.nn.Module):
|
7 |
+
"""Learned Perceptual Image Patch Similarity (LPIPS)"""
|
8 |
+
def __init__(self, linear_calibration=False, normalize=False):
|
9 |
+
super().__init__()
|
10 |
+
self.loss_fn = lpips.LPIPS(net='vgg', lpips=linear_calibration) # Note: only 'vgg' valid as loss
|
11 |
+
self.normalize = normalize # If true, normalize [0, 1] to [-1, 1]
|
12 |
+
|
13 |
+
|
14 |
+
def forward(self, pred, target):
|
15 |
+
# No need to do that because ScalingLayer was introduced in version 0.1 which does this indirectly
|
16 |
+
# if pred.shape[1] == 1: # convert 1-channel gray images to 3-channel RGB
|
17 |
+
# pred = torch.concat([pred, pred, pred], dim=1)
|
18 |
+
# if target.shape[1] == 1: # convert 1-channel gray images to 3-channel RGB
|
19 |
+
# target = torch.concat([target, target, target], dim=1)
|
20 |
+
|
21 |
+
if pred.ndim == 5: # 3D Image: Just use 2D model and compute average over slices
|
22 |
+
depth = pred.shape[2]
|
23 |
+
losses = torch.stack([self.loss_fn(pred[:,:,d], target[:,:,d], normalize=self.normalize) for d in range(depth)], dim=2)
|
24 |
+
return torch.mean(losses, dim=2, keepdim=True)
|
25 |
+
else:
|
26 |
+
return self.loss_fn(pred, target, normalize=self.normalize)
|
27 |
+
|
medical_diffusion/metrics/__init__.py
ADDED
File without changes
|
medical_diffusion/metrics/torchmetrics_pr_recall.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, List
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import Tensor
|
5 |
+
from torchmetrics import Metric
|
6 |
+
import torchvision.models as models
|
7 |
+
from torchvision import transforms
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE
|
12 |
+
|
13 |
+
if _TORCH_FIDELITY_AVAILABLE:
|
14 |
+
from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3
|
15 |
+
else:
|
16 |
+
class FeatureExtractorInceptionV3(Module): # type: ignore
|
17 |
+
pass
|
18 |
+
__doctest_skip__ = ["ImprovedPrecessionRecall", "IPR"]
|
19 |
+
|
20 |
+
class NoTrainInceptionV3(FeatureExtractorInceptionV3):
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
name: str,
|
24 |
+
features_list: List[str],
|
25 |
+
feature_extractor_weights_path: Optional[str] = None,
|
26 |
+
) -> None:
|
27 |
+
super().__init__(name, features_list, feature_extractor_weights_path)
|
28 |
+
# put into evaluation mode
|
29 |
+
self.eval()
|
30 |
+
|
31 |
+
def train(self, mode: bool) -> "NoTrainInceptionV3":
|
32 |
+
"""the inception network should not be able to be switched away from evaluation mode."""
|
33 |
+
return super().train(False)
|
34 |
+
|
35 |
+
def forward(self, x: Tensor) -> Tensor:
|
36 |
+
out = super().forward(x)
|
37 |
+
return out[0].reshape(x.shape[0], -1)
|
38 |
+
|
39 |
+
|
40 |
+
# -------------------------- VGG Trans ---------------------------
|
41 |
+
# class Normalize(object):
|
42 |
+
# """Rescale the image from 0-255 (uint8) to [0,1] (float32).
|
43 |
+
# Note, this doesn't ensure that min=0 and max=1 as a min-max scale would do!"""
|
44 |
+
|
45 |
+
# def __call__(self, image):
|
46 |
+
# return image/255
|
47 |
+
|
48 |
+
# # see https://pytorch.org/vision/main/models/generated/torchvision.models.vgg16.html
|
49 |
+
# VGG_Trans = transforms.Compose([
|
50 |
+
# transforms.Resize([224, 224], interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
|
51 |
+
# # transforms.Resize([256, 256], interpolation=InterpolationMode.BILINEAR),
|
52 |
+
# # transforms.CenterCrop(224),
|
53 |
+
# Normalize(), # scale to [0, 1]
|
54 |
+
# transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])
|
55 |
+
# ])
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
class ImprovedPrecessionRecall(Metric):
|
60 |
+
is_differentiable: bool = False
|
61 |
+
higher_is_better: bool = True
|
62 |
+
full_state_update: bool = False
|
63 |
+
|
64 |
+
|
65 |
+
def __init__(self, feature=2048, knn=3, splits_real=1, splits_fake=5):
|
66 |
+
super().__init__()
|
67 |
+
|
68 |
+
|
69 |
+
# ------------------------- Init Feature Extractor (VGG or Inception) ------------------------------
|
70 |
+
# Original VGG: https://github.com/kynkaat/improved-precision-and-recall-metric/blob/b0247eafdead494a5d243bd2efb1b0b124379ae9/utils.py#L40
|
71 |
+
# Compare Inception: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/evaluations/evaluator.py#L574
|
72 |
+
# TODO: Add option to switch between Inception and VGG feature extractor
|
73 |
+
# self.vgg_model = models.vgg16(weights='IMAGENET1K_V1').eval()
|
74 |
+
# self.feature_extractor = transforms.Compose([
|
75 |
+
# VGG_Trans,
|
76 |
+
# self.vgg_model.features,
|
77 |
+
# transforms.Lambda(lambda x: torch.flatten(x, 1)),
|
78 |
+
# self.vgg_model.classifier[:4] # [:4] corresponds to 4096 features
|
79 |
+
# ])
|
80 |
+
|
81 |
+
if isinstance(feature, int):
|
82 |
+
if not _TORCH_FIDELITY_AVAILABLE:
|
83 |
+
raise ModuleNotFoundError(
|
84 |
+
"FrechetInceptionDistance metric requires that `Torch-fidelity` is installed."
|
85 |
+
" Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`."
|
86 |
+
)
|
87 |
+
valid_int_input = [64, 192, 768, 2048]
|
88 |
+
if feature not in valid_int_input:
|
89 |
+
raise ValueError(
|
90 |
+
f"Integer input to argument `feature` must be one of {valid_int_input}, but got {feature}."
|
91 |
+
)
|
92 |
+
|
93 |
+
self.feature_extractor = NoTrainInceptionV3(name="inception-v3-compat", features_list=[str(feature)])
|
94 |
+
elif isinstance(feature, torch.nn.Module):
|
95 |
+
self.feature_extractor = feature
|
96 |
+
else:
|
97 |
+
raise TypeError("Got unknown input to argument `feature`")
|
98 |
+
|
99 |
+
# --------------------------- End Feature Extractor ---------------------------------------------------------------
|
100 |
+
|
101 |
+
self.knn = knn
|
102 |
+
self.splits_real = splits_real
|
103 |
+
self.splits_fake = splits_fake
|
104 |
+
self.add_state("real_features", [], dist_reduce_fx=None)
|
105 |
+
self.add_state("fake_features", [], dist_reduce_fx=None)
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
def update(self, imgs: Tensor, real: bool) -> None: # type: ignore
|
110 |
+
"""Update the state with extracted features.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
imgs: tensor with images feed to the feature extractor
|
114 |
+
real: bool indicating if ``imgs`` belong to the real or the fake distribution
|
115 |
+
"""
|
116 |
+
assert torch.is_tensor(imgs) and imgs.dtype == torch.uint8, 'Expecting image as torch.Tensor with dtype=torch.uint8'
|
117 |
+
|
118 |
+
features = self.feature_extractor(imgs).view(imgs.shape[0], -1)
|
119 |
+
|
120 |
+
if real:
|
121 |
+
self.real_features.append(features)
|
122 |
+
else:
|
123 |
+
self.fake_features.append(features)
|
124 |
+
|
125 |
+
def compute(self):
|
126 |
+
real_features = torch.concat(self.real_features)
|
127 |
+
fake_features = torch.concat(self.fake_features)
|
128 |
+
|
129 |
+
real_distances = _compute_pairwise_distances(real_features, self.splits_real)
|
130 |
+
real_radii = _distances2radii(real_distances, self.knn)
|
131 |
+
|
132 |
+
fake_distances = _compute_pairwise_distances(fake_features, self.splits_fake)
|
133 |
+
fake_radii = _distances2radii(fake_distances, self.knn)
|
134 |
+
|
135 |
+
precision = _compute_metric(real_features, real_radii, self.splits_real, fake_features, self.splits_fake)
|
136 |
+
recall = _compute_metric(fake_features, fake_radii, self.splits_fake, real_features, self.splits_real)
|
137 |
+
|
138 |
+
return precision, recall
|
139 |
+
|
140 |
+
def _compute_metric(ref_features, ref_radii, ref_splits, pred_features, pred_splits):
|
141 |
+
dist = _compute_pairwise_distances(ref_features, ref_splits, pred_features, pred_splits)
|
142 |
+
num_feat = pred_features.shape[0]
|
143 |
+
count = 0
|
144 |
+
for i in range(num_feat):
|
145 |
+
count += (dist[:, i] < ref_radii).any()
|
146 |
+
return count / num_feat
|
147 |
+
|
148 |
+
def _distances2radii(distances, knn):
|
149 |
+
return torch.topk(distances, knn+1, dim=1, largest=False)[0].max(dim=1)[0]
|
150 |
+
|
151 |
+
def _compute_pairwise_distances(X, splits_x, Y=None, splits_y=None):
|
152 |
+
# X = [B, features]
|
153 |
+
# Y = [B', features]
|
154 |
+
Y = X if Y is None else Y
|
155 |
+
# X = X.double()
|
156 |
+
# Y = Y.double()
|
157 |
+
splits_y = splits_x if splits_y is None else splits_y
|
158 |
+
dist = torch.concat([
|
159 |
+
torch.concat([
|
160 |
+
(torch.sum(X_batch**2, dim=1, keepdim=True) +
|
161 |
+
torch.sum(Y_batch**2, dim=1, keepdim=True).t() -
|
162 |
+
2 * torch.einsum("bd,dn->bn", X_batch, Y_batch.t()))
|
163 |
+
for Y_batch in Y.chunk(splits_y, dim=0)], dim=1)
|
164 |
+
for X_batch in X.chunk(splits_x, dim=0)])
|
165 |
+
|
166 |
+
# dist = torch.maximum(dist, torch.zeros_like(dist))
|
167 |
+
dist[dist<0] = 0
|
168 |
+
return torch.sqrt(dist)
|
169 |
+
|
170 |
+
|
medical_diffusion/models/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .model_base import BasicModel
|
medical_diffusion/models/embedders/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .time_embedder import TimeEmbbeding, LearnedSinusoidalPosEmb, SinusoidalPosEmb
|
2 |
+
from .cond_embedders import LabelEmbedder
|
medical_diffusion/models/embedders/cond_embedders.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch
|
4 |
+
from monai.networks.layers.utils import get_act_layer
|
5 |
+
|
6 |
+
class LabelEmbedder(nn.Module):
|
7 |
+
def __init__(self, emb_dim=32, num_classes=2, act_name=("SWISH", {})):
|
8 |
+
super().__init__()
|
9 |
+
self.emb_dim = emb_dim
|
10 |
+
self.embedding = nn.Embedding(num_classes, emb_dim)
|
11 |
+
|
12 |
+
# self.embedding = nn.Embedding(num_classes, emb_dim//4)
|
13 |
+
# self.emb_net = nn.Sequential(
|
14 |
+
# nn.Linear(1, emb_dim),
|
15 |
+
# get_act_layer(act_name),
|
16 |
+
# nn.Linear(emb_dim, emb_dim)
|
17 |
+
# )
|
18 |
+
|
19 |
+
def forward(self, condition):
|
20 |
+
c = self.embedding(condition) #[B,] -> [B, C]
|
21 |
+
# c = self.emb_net(c)
|
22 |
+
# c = self.emb_net(condition[:,None].float())
|
23 |
+
# c = (2*condition-1)[:, None].expand(-1, self.emb_dim).type(torch.float32)
|
24 |
+
return c
|
25 |
+
|
26 |
+
|
27 |
+
|
medical_diffusion/models/embedders/latent_embedders.py
ADDED
@@ -0,0 +1,1065 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torchvision.utils import save_image
|
8 |
+
from monai.networks.blocks import UnetOutBlock
|
9 |
+
|
10 |
+
|
11 |
+
from medical_diffusion.models.utils.conv_blocks import DownBlock, UpBlock, BasicBlock, BasicResBlock, UnetResBlock, UnetBasicBlock
|
12 |
+
from medical_diffusion.loss.gan_losses import hinge_d_loss
|
13 |
+
from medical_diffusion.loss.perceivers import LPIPS
|
14 |
+
from medical_diffusion.models.model_base import BasicModel, VeryBasicModel
|
15 |
+
|
16 |
+
|
17 |
+
from pytorch_msssim import SSIM, ssim
|
18 |
+
|
19 |
+
|
20 |
+
class DiagonalGaussianDistribution(nn.Module):
|
21 |
+
|
22 |
+
def forward(self, x):
|
23 |
+
mean, logvar = torch.chunk(x, 2, dim=1)
|
24 |
+
logvar = torch.clamp(logvar, -30.0, 20.0)
|
25 |
+
std = torch.exp(0.5 * logvar)
|
26 |
+
sample = torch.randn(mean.shape, generator=None, device=x.device)
|
27 |
+
z = mean + std * sample
|
28 |
+
|
29 |
+
batch_size = x.shape[0]
|
30 |
+
var = torch.exp(logvar)
|
31 |
+
kl = 0.5 * torch.sum(torch.pow(mean, 2) + var - 1.0 - logvar)/batch_size
|
32 |
+
|
33 |
+
return z, kl
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
|
40 |
+
class VectorQuantizer(nn.Module):
|
41 |
+
def __init__(self, num_embeddings, emb_channels, beta=0.25):
|
42 |
+
super().__init__()
|
43 |
+
self.num_embeddings = num_embeddings
|
44 |
+
self.emb_channels = emb_channels
|
45 |
+
self.beta = beta
|
46 |
+
|
47 |
+
self.embedder = nn.Embedding(num_embeddings, emb_channels)
|
48 |
+
self.embedder.weight.data.uniform_(-1.0 / self.num_embeddings, 1.0 / self.num_embeddings)
|
49 |
+
|
50 |
+
def forward(self, z):
|
51 |
+
assert z.shape[1] == self.emb_channels, "Channels of z and codebook don't match"
|
52 |
+
z_ch = torch.moveaxis(z, 1, -1) # [B, C, *] -> [B, *, C]
|
53 |
+
z_flattened = z_ch.reshape(-1, self.emb_channels) # [B, *, C] -> [Bx*, C], Note: or use contiguous() and view()
|
54 |
+
|
55 |
+
# distances from z to embeddings e: (z - e)^2 = z^2 + e^2 - 2 e * z
|
56 |
+
dist = ( torch.sum(z_flattened**2, dim=1, keepdim=True)
|
57 |
+
+ torch.sum(self.embedder.weight**2, dim=1)
|
58 |
+
-2* torch.einsum("bd,dn->bn", z_flattened, self.embedder.weight.t())
|
59 |
+
) # [Bx*, num_embeddings]
|
60 |
+
|
61 |
+
min_encoding_indices = torch.argmin(dist, dim=1) # [Bx*]
|
62 |
+
z_q = self.embedder(min_encoding_indices) # [Bx*, C]
|
63 |
+
z_q = z_q.view(z_ch.shape) # [Bx*, C] -> [B, *, C]
|
64 |
+
z_q = torch.moveaxis(z_q, -1, 1) # [B, *, C] -> [B, C, *]
|
65 |
+
|
66 |
+
# Compute Embedding Loss
|
67 |
+
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
|
68 |
+
|
69 |
+
# preserve gradients
|
70 |
+
z_q = z + (z_q - z).detach()
|
71 |
+
|
72 |
+
return z_q, loss
|
73 |
+
|
74 |
+
|
75 |
+
|
76 |
+
class Discriminator(nn.Module):
|
77 |
+
def __init__(self,
|
78 |
+
in_channels=1,
|
79 |
+
spatial_dims = 3,
|
80 |
+
hid_chs = [32, 64, 128, 256, 512],
|
81 |
+
kernel_sizes=[(1,3,3), (1,3,3), (1,3,3), 3, 3],
|
82 |
+
strides = [ 1, (1,2,2), (1,2,2), 2, 2],
|
83 |
+
act_name=("Swish", {}),
|
84 |
+
norm_name = ("GROUP", {'num_groups':32, "affine": True}),
|
85 |
+
dropout=None
|
86 |
+
):
|
87 |
+
super().__init__()
|
88 |
+
|
89 |
+
self.inc = BasicBlock(
|
90 |
+
spatial_dims=spatial_dims,
|
91 |
+
in_channels=in_channels,
|
92 |
+
out_channels=hid_chs[0],
|
93 |
+
kernel_size=kernel_sizes[0], # 2*pad = kernel-stride -> kernel = 2*pad + stride => 1 = 2*0+1, 3, =2*1+1, 2 = 2*0+2, 4 = 2*1+2
|
94 |
+
stride=strides[0],
|
95 |
+
norm_name=norm_name,
|
96 |
+
act_name=act_name,
|
97 |
+
dropout=dropout,
|
98 |
+
)
|
99 |
+
|
100 |
+
self.encoder = nn.Sequential(*[
|
101 |
+
BasicBlock(
|
102 |
+
spatial_dims=spatial_dims,
|
103 |
+
in_channels=hid_chs[i-1],
|
104 |
+
out_channels=hid_chs[i],
|
105 |
+
kernel_size=kernel_sizes[i],
|
106 |
+
stride=strides[i],
|
107 |
+
act_name=act_name,
|
108 |
+
norm_name=norm_name,
|
109 |
+
dropout=dropout)
|
110 |
+
for i in range(1, len(hid_chs))
|
111 |
+
])
|
112 |
+
|
113 |
+
|
114 |
+
self.outc = BasicBlock(
|
115 |
+
spatial_dims=spatial_dims,
|
116 |
+
in_channels=hid_chs[-1],
|
117 |
+
out_channels=1,
|
118 |
+
kernel_size=3,
|
119 |
+
stride=1,
|
120 |
+
act_name=None,
|
121 |
+
norm_name=None,
|
122 |
+
dropout=None,
|
123 |
+
zero_conv=True
|
124 |
+
)
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
def forward(self, x):
|
129 |
+
x = self.inc(x)
|
130 |
+
x = self.encoder(x)
|
131 |
+
return self.outc(x)
|
132 |
+
|
133 |
+
|
134 |
+
class NLayerDiscriminator(nn.Module):
|
135 |
+
def __init__(self,
|
136 |
+
in_channels=1,
|
137 |
+
spatial_dims = 3,
|
138 |
+
hid_chs = [64, 128, 256, 512, 512],
|
139 |
+
kernel_sizes=[4, 4, 4, 4, 4],
|
140 |
+
strides = [2, 2, 2, 1, 1],
|
141 |
+
act_name=("LeakyReLU", {'negative_slope': 0.2}),
|
142 |
+
norm_name = ("BATCH", {}),
|
143 |
+
dropout=None
|
144 |
+
):
|
145 |
+
super().__init__()
|
146 |
+
|
147 |
+
self.inc = BasicBlock(
|
148 |
+
spatial_dims=spatial_dims,
|
149 |
+
in_channels=in_channels,
|
150 |
+
out_channels=hid_chs[0],
|
151 |
+
kernel_size=kernel_sizes[0],
|
152 |
+
stride=strides[0],
|
153 |
+
norm_name=None,
|
154 |
+
act_name=act_name,
|
155 |
+
dropout=dropout,
|
156 |
+
)
|
157 |
+
|
158 |
+
self.encoder = nn.Sequential(*[
|
159 |
+
BasicBlock(
|
160 |
+
spatial_dims=spatial_dims,
|
161 |
+
in_channels=hid_chs[i-1],
|
162 |
+
out_channels=hid_chs[i],
|
163 |
+
kernel_size=kernel_sizes[i],
|
164 |
+
stride=strides[i],
|
165 |
+
act_name=act_name,
|
166 |
+
norm_name=norm_name,
|
167 |
+
dropout=dropout)
|
168 |
+
for i in range(1, len(strides))
|
169 |
+
])
|
170 |
+
|
171 |
+
|
172 |
+
self.outc = BasicBlock(
|
173 |
+
spatial_dims=spatial_dims,
|
174 |
+
in_channels=hid_chs[-1],
|
175 |
+
out_channels=1,
|
176 |
+
kernel_size=4,
|
177 |
+
stride=1,
|
178 |
+
norm_name=None,
|
179 |
+
act_name=None,
|
180 |
+
dropout=False,
|
181 |
+
)
|
182 |
+
|
183 |
+
def forward(self, x):
|
184 |
+
x = self.inc(x)
|
185 |
+
x = self.encoder(x)
|
186 |
+
return self.outc(x)
|
187 |
+
|
188 |
+
|
189 |
+
|
190 |
+
|
191 |
+
class VQVAE(BasicModel):
|
192 |
+
def __init__(
|
193 |
+
self,
|
194 |
+
in_channels=3,
|
195 |
+
out_channels=3,
|
196 |
+
spatial_dims = 2,
|
197 |
+
emb_channels = 4,
|
198 |
+
num_embeddings = 8192,
|
199 |
+
hid_chs = [32, 64, 128, 256],
|
200 |
+
kernel_sizes=[ 3, 3, 3, 3],
|
201 |
+
strides = [ 1, 2, 2, 2],
|
202 |
+
norm_name = ("GROUP", {'num_groups':32, "affine": True}),
|
203 |
+
act_name=("Swish", {}),
|
204 |
+
dropout=0.0,
|
205 |
+
use_res_block=True,
|
206 |
+
deep_supervision=False,
|
207 |
+
learnable_interpolation=True,
|
208 |
+
use_attention='none',
|
209 |
+
beta = 0.25,
|
210 |
+
embedding_loss_weight=1.0,
|
211 |
+
perceiver = LPIPS,
|
212 |
+
perceiver_kwargs = {},
|
213 |
+
perceptual_loss_weight = 1.0,
|
214 |
+
|
215 |
+
|
216 |
+
optimizer=torch.optim.Adam,
|
217 |
+
optimizer_kwargs={'lr':1e-4},
|
218 |
+
lr_scheduler= None,
|
219 |
+
lr_scheduler_kwargs={},
|
220 |
+
loss = torch.nn.L1Loss,
|
221 |
+
loss_kwargs={'reduction': 'none'},
|
222 |
+
|
223 |
+
sample_every_n_steps = 1000
|
224 |
+
|
225 |
+
):
|
226 |
+
super().__init__(
|
227 |
+
optimizer=optimizer,
|
228 |
+
optimizer_kwargs=optimizer_kwargs,
|
229 |
+
lr_scheduler=lr_scheduler,
|
230 |
+
lr_scheduler_kwargs=lr_scheduler_kwargs
|
231 |
+
)
|
232 |
+
self.sample_every_n_steps=sample_every_n_steps
|
233 |
+
self.loss_fct = loss(**loss_kwargs)
|
234 |
+
self.embedding_loss_weight = embedding_loss_weight
|
235 |
+
self.perceiver = perceiver(**perceiver_kwargs).eval() if perceiver is not None else None
|
236 |
+
self.perceptual_loss_weight = perceptual_loss_weight
|
237 |
+
use_attention = use_attention if isinstance(use_attention, list) else [use_attention]*len(strides)
|
238 |
+
self.depth = len(strides)
|
239 |
+
self.deep_supervision = deep_supervision
|
240 |
+
|
241 |
+
# ----------- In-Convolution ------------
|
242 |
+
ConvBlock = UnetResBlock if use_res_block else UnetBasicBlock
|
243 |
+
self.inc = ConvBlock(spatial_dims, in_channels, hid_chs[0], kernel_size=kernel_sizes[0], stride=strides[0],
|
244 |
+
act_name=act_name, norm_name=norm_name)
|
245 |
+
|
246 |
+
# ----------- Encoder ----------------
|
247 |
+
self.encoders = nn.ModuleList([
|
248 |
+
DownBlock(
|
249 |
+
spatial_dims,
|
250 |
+
hid_chs[i-1],
|
251 |
+
hid_chs[i],
|
252 |
+
kernel_sizes[i],
|
253 |
+
strides[i],
|
254 |
+
kernel_sizes[i],
|
255 |
+
norm_name,
|
256 |
+
act_name,
|
257 |
+
dropout,
|
258 |
+
use_res_block,
|
259 |
+
learnable_interpolation,
|
260 |
+
use_attention[i])
|
261 |
+
for i in range(1, self.depth)
|
262 |
+
])
|
263 |
+
|
264 |
+
# ----------- Out-Encoder ------------
|
265 |
+
self.out_enc = BasicBlock(spatial_dims, hid_chs[-1], emb_channels, 1)
|
266 |
+
|
267 |
+
|
268 |
+
# ----------- Quantizer --------------
|
269 |
+
self.quantizer = VectorQuantizer(
|
270 |
+
num_embeddings=num_embeddings,
|
271 |
+
emb_channels=emb_channels,
|
272 |
+
beta=beta
|
273 |
+
)
|
274 |
+
|
275 |
+
# ----------- In-Decoder ------------
|
276 |
+
self.inc_dec = ConvBlock(spatial_dims, emb_channels, hid_chs[-1], 3, act_name=act_name, norm_name=norm_name)
|
277 |
+
|
278 |
+
# ------------ Decoder ----------
|
279 |
+
self.decoders = nn.ModuleList([
|
280 |
+
UpBlock(
|
281 |
+
spatial_dims,
|
282 |
+
hid_chs[i+1],
|
283 |
+
hid_chs[i],
|
284 |
+
kernel_size=kernel_sizes[i+1],
|
285 |
+
stride=strides[i+1],
|
286 |
+
upsample_kernel_size=strides[i+1],
|
287 |
+
norm_name=norm_name,
|
288 |
+
act_name=act_name,
|
289 |
+
dropout=dropout,
|
290 |
+
use_res_block=use_res_block,
|
291 |
+
learnable_interpolation=learnable_interpolation,
|
292 |
+
use_attention=use_attention[i],
|
293 |
+
skip_channels=0)
|
294 |
+
for i in range(self.depth-1)
|
295 |
+
])
|
296 |
+
|
297 |
+
# --------------- Out-Convolution ----------------
|
298 |
+
self.outc = BasicBlock(spatial_dims, hid_chs[0], out_channels, 1, zero_conv=True)
|
299 |
+
if isinstance(deep_supervision, bool):
|
300 |
+
deep_supervision = self.depth-1 if deep_supervision else 0
|
301 |
+
self.outc_ver = nn.ModuleList([
|
302 |
+
BasicBlock(spatial_dims, hid_chs[i], out_channels, 1, zero_conv=True)
|
303 |
+
for i in range(1, deep_supervision+1)
|
304 |
+
])
|
305 |
+
|
306 |
+
|
307 |
+
def encode(self, x):
|
308 |
+
h = self.inc(x)
|
309 |
+
for i in range(len(self.encoders)):
|
310 |
+
h = self.encoders[i](h)
|
311 |
+
z = self.out_enc(h)
|
312 |
+
return z
|
313 |
+
|
314 |
+
def decode(self, z):
|
315 |
+
z, _ = self.quantizer(z)
|
316 |
+
h = self.inc_dec(z)
|
317 |
+
for i in range(len(self.decoders), 0, -1):
|
318 |
+
h = self.decoders[i-1](h)
|
319 |
+
x = self.outc(h)
|
320 |
+
return x
|
321 |
+
|
322 |
+
def forward(self, x_in):
|
323 |
+
# --------- Encoder --------------
|
324 |
+
h = self.inc(x_in)
|
325 |
+
for i in range(len(self.encoders)):
|
326 |
+
h = self.encoders[i](h)
|
327 |
+
z = self.out_enc(h)
|
328 |
+
|
329 |
+
# --------- Quantizer --------------
|
330 |
+
z_q, emb_loss = self.quantizer(z)
|
331 |
+
|
332 |
+
# -------- Decoder -----------
|
333 |
+
out_hor = []
|
334 |
+
h = self.inc_dec(z_q)
|
335 |
+
for i in range(len(self.decoders)-1, -1, -1):
|
336 |
+
out_hor.append(self.outc_ver[i](h)) if i < len(self.outc_ver) else None
|
337 |
+
h = self.decoders[i](h)
|
338 |
+
out = self.outc(h)
|
339 |
+
|
340 |
+
return out, out_hor[::-1], emb_loss
|
341 |
+
|
342 |
+
def perception_loss(self, pred, target, depth=0):
|
343 |
+
if (self.perceiver is not None) and (depth<2):
|
344 |
+
self.perceiver.eval()
|
345 |
+
return self.perceiver(pred, target)*self.perceptual_loss_weight
|
346 |
+
else:
|
347 |
+
return 0
|
348 |
+
|
349 |
+
def ssim_loss(self, pred, target):
|
350 |
+
return 1-ssim(((pred+1)/2).clamp(0,1), (target.type(pred.dtype)+1)/2, data_range=1, size_average=False,
|
351 |
+
nonnegative_ssim=True).reshape(-1, *[1]*(pred.ndim-1))
|
352 |
+
|
353 |
+
|
354 |
+
def rec_loss(self, pred, pred_vertical, target):
|
355 |
+
interpolation_mode = 'nearest-exact'
|
356 |
+
weights = [1/2**i for i in range(1+len(pred_vertical))] # horizontal (equal) + vertical (reducing with every step down)
|
357 |
+
tot_weight = sum(weights)
|
358 |
+
weights = [w/tot_weight for w in weights]
|
359 |
+
|
360 |
+
# Loss
|
361 |
+
loss = 0
|
362 |
+
loss += torch.mean(self.loss_fct(pred, target)+self.perception_loss(pred, target)+self.ssim_loss(pred, target))*weights[0]
|
363 |
+
|
364 |
+
for i, pred_i in enumerate(pred_vertical):
|
365 |
+
target_i = F.interpolate(target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None)
|
366 |
+
loss += torch.mean(self.loss_fct(pred_i, target_i)+self.perception_loss(pred_i, target_i)+self.ssim_loss(pred_i, target_i))*weights[i+1]
|
367 |
+
|
368 |
+
return loss
|
369 |
+
|
370 |
+
def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int):
|
371 |
+
# ------------------------- Get Source/Target ---------------------------
|
372 |
+
x = batch['source']
|
373 |
+
target = x
|
374 |
+
|
375 |
+
# ------------------------- Run Model ---------------------------
|
376 |
+
pred, pred_vertical, emb_loss = self(x)
|
377 |
+
|
378 |
+
# ------------------------- Compute Loss ---------------------------
|
379 |
+
loss = self.rec_loss(pred, pred_vertical, target)
|
380 |
+
loss += emb_loss*self.embedding_loss_weight
|
381 |
+
|
382 |
+
# --------------------- Compute Metrics -------------------------------
|
383 |
+
with torch.no_grad():
|
384 |
+
logging_dict = {'loss':loss, 'emb_loss': emb_loss}
|
385 |
+
logging_dict['L2'] = torch.nn.functional.mse_loss(pred, target)
|
386 |
+
logging_dict['L1'] = torch.nn.functional.l1_loss(pred, target)
|
387 |
+
logging_dict['ssim'] = ssim((pred+1)/2, (target.type(pred.dtype)+1)/2, data_range=1)
|
388 |
+
|
389 |
+
# ----------------- Log Scalars ----------------------
|
390 |
+
for metric_name, metric_val in logging_dict.items():
|
391 |
+
self.log(f"{state}/{metric_name}", metric_val, batch_size=x.shape[0], on_step=True, on_epoch=True)
|
392 |
+
|
393 |
+
# ----------------- Save Image ------------------------------
|
394 |
+
if self.global_step != 0 and self.global_step % self.sample_every_n_steps == 0:
|
395 |
+
log_step = self.global_step // self.sample_every_n_steps
|
396 |
+
path_out = Path(self.logger.log_dir)/'images'
|
397 |
+
path_out.mkdir(parents=True, exist_ok=True)
|
398 |
+
# for 3D images use depth as batch :[D, C, H, W], never show more than 16+16 =32 images
|
399 |
+
def depth2batch(image):
|
400 |
+
return (image if image.ndim<5 else torch.swapaxes(image[0], 0, 1))
|
401 |
+
images = torch.cat([depth2batch(img)[:16] for img in (x, pred)])
|
402 |
+
save_image(images, path_out/f'sample_{log_step}.png', nrow=x.shape[0], normalize=True)
|
403 |
+
|
404 |
+
return loss
|
405 |
+
|
406 |
+
|
407 |
+
|
408 |
+
class VQGAN(VeryBasicModel):
|
409 |
+
def __init__(
|
410 |
+
self,
|
411 |
+
in_channels=3,
|
412 |
+
out_channels=3,
|
413 |
+
spatial_dims = 2,
|
414 |
+
emb_channels = 4,
|
415 |
+
num_embeddings = 8192,
|
416 |
+
hid_chs = [ 64, 128, 256, 512],
|
417 |
+
kernel_sizes=[ 3, 3, 3, 3],
|
418 |
+
strides = [ 1, 2, 2, 2],
|
419 |
+
norm_name = ("GROUP", {'num_groups':32, "affine": True}),
|
420 |
+
act_name=("Swish", {}),
|
421 |
+
dropout=0.0,
|
422 |
+
use_res_block=True,
|
423 |
+
deep_supervision=False,
|
424 |
+
learnable_interpolation=True,
|
425 |
+
use_attention='none',
|
426 |
+
beta = 0.25,
|
427 |
+
embedding_loss_weight=1.0,
|
428 |
+
perceiver = LPIPS,
|
429 |
+
perceiver_kwargs = {},
|
430 |
+
perceptual_loss_weight: float = 1.0,
|
431 |
+
|
432 |
+
|
433 |
+
start_gan_train_step = 50000, # NOTE step increase with each optimizer
|
434 |
+
gan_loss_weight: float = 1.0, # = discriminator
|
435 |
+
|
436 |
+
optimizer_vqvae=torch.optim.Adam,
|
437 |
+
optimizer_gan=torch.optim.Adam,
|
438 |
+
optimizer_vqvae_kwargs={'lr':1e-6},
|
439 |
+
optimizer_gan_kwargs={'lr':1e-6},
|
440 |
+
lr_scheduler_vqvae= None,
|
441 |
+
lr_scheduler_vqvae_kwargs={},
|
442 |
+
lr_scheduler_gan= None,
|
443 |
+
lr_scheduler_gan_kwargs={},
|
444 |
+
|
445 |
+
pixel_loss = torch.nn.L1Loss,
|
446 |
+
pixel_loss_kwargs={'reduction':'none'},
|
447 |
+
gan_loss_fct = hinge_d_loss,
|
448 |
+
|
449 |
+
sample_every_n_steps = 1000
|
450 |
+
|
451 |
+
):
|
452 |
+
super().__init__()
|
453 |
+
self.sample_every_n_steps=sample_every_n_steps
|
454 |
+
self.start_gan_train_step = start_gan_train_step
|
455 |
+
self.gan_loss_weight = gan_loss_weight
|
456 |
+
self.embedding_loss_weight = embedding_loss_weight
|
457 |
+
|
458 |
+
self.optimizer_vqvae = optimizer_vqvae
|
459 |
+
self.optimizer_gan = optimizer_gan
|
460 |
+
self.optimizer_vqvae_kwargs = optimizer_vqvae_kwargs
|
461 |
+
self.optimizer_gan_kwargs = optimizer_gan_kwargs
|
462 |
+
self.lr_scheduler_vqvae = lr_scheduler_vqvae
|
463 |
+
self.lr_scheduler_vqvae_kwargs = lr_scheduler_vqvae_kwargs
|
464 |
+
self.lr_scheduler_gan = lr_scheduler_gan
|
465 |
+
self.lr_scheduler_gan_kwargs = lr_scheduler_gan_kwargs
|
466 |
+
|
467 |
+
self.pixel_loss_fct = pixel_loss(**pixel_loss_kwargs)
|
468 |
+
self.gan_loss_fct = gan_loss_fct
|
469 |
+
|
470 |
+
self.vqvae = VQVAE(in_channels, out_channels, spatial_dims, emb_channels, num_embeddings, hid_chs, kernel_sizes,
|
471 |
+
strides, norm_name, act_name, dropout, use_res_block, deep_supervision, learnable_interpolation, use_attention,
|
472 |
+
beta, embedding_loss_weight, perceiver, perceiver_kwargs, perceptual_loss_weight)
|
473 |
+
|
474 |
+
self.discriminator = nn.ModuleList([Discriminator(in_channels, spatial_dims, hid_chs, kernel_sizes, strides,
|
475 |
+
act_name, norm_name, dropout) for i in range(len(self.vqvae.outc_ver)+1)])
|
476 |
+
|
477 |
+
|
478 |
+
# self.discriminator = nn.ModuleList([NLayerDiscriminator(in_channels, spatial_dims)
|
479 |
+
# for _ in range(len(self.vqvae.decoder.outc_ver)+1)])
|
480 |
+
|
481 |
+
|
482 |
+
|
483 |
+
def encode(self, x):
|
484 |
+
return self.vqvae.encode(x)
|
485 |
+
|
486 |
+
def decode(self, z):
|
487 |
+
return self.vqvae.decode(z)
|
488 |
+
|
489 |
+
def forward(self, x):
|
490 |
+
return self.vqvae.forward(x)
|
491 |
+
|
492 |
+
|
493 |
+
def vae_img_loss(self, pred, target, dec_out_layer, step, discriminator, depth=0):
|
494 |
+
# ------ VQVAE -------
|
495 |
+
rec_loss = self.vqvae.rec_loss(pred, [], target)
|
496 |
+
|
497 |
+
# ------- GAN -----
|
498 |
+
if step > self.start_gan_train_step:
|
499 |
+
gan_loss = -torch.mean(discriminator[depth](pred))
|
500 |
+
lambda_weight = self.compute_lambda(rec_loss, gan_loss, dec_out_layer)
|
501 |
+
gan_loss = gan_loss*lambda_weight
|
502 |
+
|
503 |
+
with torch.no_grad():
|
504 |
+
self.log(f"train/gan_loss_{depth}", gan_loss, on_step=True, on_epoch=True)
|
505 |
+
self.log(f"train/lambda_{depth}", lambda_weight, on_step=True, on_epoch=True)
|
506 |
+
else:
|
507 |
+
gan_loss = 0 #torch.tensor([0.0], requires_grad=True, device=target.device)
|
508 |
+
|
509 |
+
return self.gan_loss_weight*gan_loss+rec_loss
|
510 |
+
|
511 |
+
|
512 |
+
def gan_img_loss(self, pred, target, step, discriminators, depth):
|
513 |
+
if (step > self.start_gan_train_step) and (depth<len(discriminators)):
|
514 |
+
logits_real = discriminators[depth](target.detach())
|
515 |
+
logits_fake = discriminators[depth](pred.detach())
|
516 |
+
loss = self.gan_loss_fct(logits_real, logits_fake)
|
517 |
+
else:
|
518 |
+
loss = torch.tensor(0.0, requires_grad=True, device=target.device)
|
519 |
+
|
520 |
+
with torch.no_grad():
|
521 |
+
self.log(f"train/loss_1_{depth}", loss, on_step=True, on_epoch=True)
|
522 |
+
return loss
|
523 |
+
|
524 |
+
def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int):
|
525 |
+
# ------------------------- Get Source/Target ---------------------------
|
526 |
+
x = batch['source']
|
527 |
+
target = x
|
528 |
+
|
529 |
+
# ------------------------- Run Model ---------------------------
|
530 |
+
pred, pred_vertical, emb_loss = self(x)
|
531 |
+
|
532 |
+
# ------------------------- Compute Loss ---------------------------
|
533 |
+
interpolation_mode = 'area'
|
534 |
+
weights = [1/2**i for i in range(1+len(pred_vertical))] # horizontal + vertical (reducing with every step down)
|
535 |
+
tot_weight = sum(weights)
|
536 |
+
weights = [w/tot_weight for w in weights]
|
537 |
+
logging_dict = {}
|
538 |
+
|
539 |
+
if optimizer_idx == 0:
|
540 |
+
# Horizontal/Top Layer
|
541 |
+
img_loss = self.vae_img_loss(pred, target, self.vqvae.outc.conv, step, self.discriminator, 0)*weights[0]
|
542 |
+
|
543 |
+
# Vertical/Deep Layer
|
544 |
+
for i, pred_i in enumerate(pred_vertical):
|
545 |
+
target_i = F.interpolate(target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None)
|
546 |
+
img_loss += self.vae_img_loss(pred_i, target_i, self.vqvae.outc_ver[i].conv, step, self.discriminator, i+1)*weights[i+1]
|
547 |
+
loss = img_loss+self.embedding_loss_weight*emb_loss
|
548 |
+
|
549 |
+
with torch.no_grad():
|
550 |
+
logging_dict[f'img_loss'] = img_loss
|
551 |
+
logging_dict[f'emb_loss'] = emb_loss
|
552 |
+
logging_dict['loss_0'] = loss
|
553 |
+
|
554 |
+
elif optimizer_idx == 1:
|
555 |
+
# Horizontal/Top Layer
|
556 |
+
loss = self.gan_img_loss(pred, target, step, self.discriminator, 0)*weights[0]
|
557 |
+
|
558 |
+
# Vertical/Deep Layer
|
559 |
+
for i, pred_i in enumerate(pred_vertical):
|
560 |
+
target_i = F.interpolate(target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None)
|
561 |
+
loss += self.gan_img_loss(pred_i, target_i, step, self.discriminator, i+1)*weights[i+1]
|
562 |
+
|
563 |
+
with torch.no_grad():
|
564 |
+
logging_dict['loss_1'] = loss
|
565 |
+
|
566 |
+
|
567 |
+
# --------------------- Compute Metrics -------------------------------
|
568 |
+
with torch.no_grad():
|
569 |
+
logging_dict['loss'] = loss
|
570 |
+
logging_dict[f'L2'] = torch.nn.functional.mse_loss(pred, x)
|
571 |
+
logging_dict[f'L1'] = torch.nn.functional.l1_loss(pred, x)
|
572 |
+
logging_dict['ssim'] = ssim((pred+1)/2, (target.type(pred.dtype)+1)/2, data_range=1)
|
573 |
+
|
574 |
+
# ----------------- Log Scalars ----------------------
|
575 |
+
for metric_name, metric_val in logging_dict.items():
|
576 |
+
self.log(f"{state}/{metric_name}", metric_val, batch_size=x.shape[0], on_step=True, on_epoch=True)
|
577 |
+
|
578 |
+
# ----------------- Save Image ------------------------------
|
579 |
+
if self.global_step != 0 and self.global_step % self.sample_every_n_steps == 0: # NOTE: step 1 (opt1) , step=2 (opt2), step=3 (opt1), ...
|
580 |
+
|
581 |
+
log_step = self.global_step // self.sample_every_n_steps
|
582 |
+
path_out = Path(self.logger.log_dir)/'images'
|
583 |
+
path_out.mkdir(parents=True, exist_ok=True)
|
584 |
+
# for 3D images use depth as batch :[D, C, H, W], never show more than 16+16 =32 images
|
585 |
+
def depth2batch(image):
|
586 |
+
return (image if image.ndim<5 else torch.swapaxes(image[0], 0, 1))
|
587 |
+
images = torch.cat([depth2batch(img)[:16] for img in (x, pred)])
|
588 |
+
save_image(images, path_out/f'sample_{log_step}.png', nrow=x.shape[0], normalize=True)
|
589 |
+
|
590 |
+
return loss
|
591 |
+
|
592 |
+
def configure_optimizers(self):
|
593 |
+
opt_vqvae = self.optimizer_vqvae(self.vqvae.parameters(), **self.optimizer_vqvae_kwargs)
|
594 |
+
opt_gan = self.optimizer_gan(self.discriminator.parameters(), **self.optimizer_gan_kwargs)
|
595 |
+
schedulers = []
|
596 |
+
if self.lr_scheduler_vqvae is not None:
|
597 |
+
schedulers.append({
|
598 |
+
'scheduler': self.lr_scheduler_vqvae(opt_vqvae, **self.lr_scheduler_vqvae_kwargs),
|
599 |
+
'interval': 'step',
|
600 |
+
'frequency': 1
|
601 |
+
})
|
602 |
+
if self.lr_scheduler_gan is not None:
|
603 |
+
schedulers.append({
|
604 |
+
'scheduler': self.lr_scheduler_gan(opt_gan, **self.lr_scheduler_gan_kwargs),
|
605 |
+
'interval': 'step',
|
606 |
+
'frequency': 1
|
607 |
+
})
|
608 |
+
return [opt_vqvae, opt_gan], schedulers
|
609 |
+
|
610 |
+
def compute_lambda(self, rec_loss, gan_loss, dec_out_layer, eps=1e-4):
|
611 |
+
"""Computes adaptive weight as proposed in eq. 7 of https://arxiv.org/abs/2012.09841"""
|
612 |
+
rec_grads = torch.autograd.grad(rec_loss, dec_out_layer.weight, retain_graph=True)[0]
|
613 |
+
gan_grads = torch.autograd.grad(gan_loss, dec_out_layer.weight, retain_graph=True)[0]
|
614 |
+
d_weight = torch.norm(rec_grads) / (torch.norm(gan_grads) + eps)
|
615 |
+
d_weight = torch.clamp(d_weight, 0.0, 1e4)
|
616 |
+
return d_weight.detach()
|
617 |
+
|
618 |
+
|
619 |
+
|
620 |
+
class VAE(BasicModel):
|
621 |
+
def __init__(
|
622 |
+
self,
|
623 |
+
in_channels=3,
|
624 |
+
out_channels=3,
|
625 |
+
spatial_dims = 2,
|
626 |
+
emb_channels = 4,
|
627 |
+
hid_chs = [ 64, 128, 256, 512],
|
628 |
+
kernel_sizes=[ 3, 3, 3, 3],
|
629 |
+
strides = [ 1, 2, 2, 2],
|
630 |
+
norm_name = ("GROUP", {'num_groups':8, "affine": True}),
|
631 |
+
act_name=("Swish", {}),
|
632 |
+
dropout=None,
|
633 |
+
use_res_block=True,
|
634 |
+
deep_supervision=False,
|
635 |
+
learnable_interpolation=True,
|
636 |
+
use_attention='none',
|
637 |
+
embedding_loss_weight=1e-6,
|
638 |
+
perceiver = LPIPS,
|
639 |
+
perceiver_kwargs = {},
|
640 |
+
perceptual_loss_weight = 1.0,
|
641 |
+
|
642 |
+
|
643 |
+
optimizer=torch.optim.Adam,
|
644 |
+
optimizer_kwargs={'lr':1e-4},
|
645 |
+
lr_scheduler= None,
|
646 |
+
lr_scheduler_kwargs={},
|
647 |
+
loss = torch.nn.L1Loss,
|
648 |
+
loss_kwargs={'reduction': 'none'},
|
649 |
+
|
650 |
+
sample_every_n_steps = 1000
|
651 |
+
|
652 |
+
):
|
653 |
+
super().__init__(
|
654 |
+
optimizer=optimizer,
|
655 |
+
optimizer_kwargs=optimizer_kwargs,
|
656 |
+
lr_scheduler=lr_scheduler,
|
657 |
+
lr_scheduler_kwargs=lr_scheduler_kwargs
|
658 |
+
)
|
659 |
+
self.sample_every_n_steps=sample_every_n_steps
|
660 |
+
self.loss_fct = loss(**loss_kwargs)
|
661 |
+
# self.ssim_fct = SSIM(data_range=1, size_average=False, channel=out_channels, spatial_dims=spatial_dims, nonnegative_ssim=True)
|
662 |
+
self.embedding_loss_weight = embedding_loss_weight
|
663 |
+
self.perceiver = perceiver(**perceiver_kwargs).eval() if perceiver is not None else None
|
664 |
+
self.perceptual_loss_weight = perceptual_loss_weight
|
665 |
+
use_attention = use_attention if isinstance(use_attention, list) else [use_attention]*len(strides)
|
666 |
+
self.depth = len(strides)
|
667 |
+
self.deep_supervision = deep_supervision
|
668 |
+
downsample_kernel_sizes = kernel_sizes
|
669 |
+
upsample_kernel_sizes = strides
|
670 |
+
|
671 |
+
# -------- Loss-Reg---------
|
672 |
+
# self.logvar = nn.Parameter(torch.zeros(size=()) )
|
673 |
+
|
674 |
+
# ----------- In-Convolution ------------
|
675 |
+
ConvBlock = UnetResBlock if use_res_block else UnetBasicBlock
|
676 |
+
self.inc = ConvBlock(
|
677 |
+
spatial_dims,
|
678 |
+
in_channels,
|
679 |
+
hid_chs[0],
|
680 |
+
kernel_size=kernel_sizes[0],
|
681 |
+
stride=strides[0],
|
682 |
+
act_name=act_name,
|
683 |
+
norm_name=norm_name,
|
684 |
+
emb_channels=None
|
685 |
+
)
|
686 |
+
|
687 |
+
# ----------- Encoder ----------------
|
688 |
+
self.encoders = nn.ModuleList([
|
689 |
+
DownBlock(
|
690 |
+
spatial_dims = spatial_dims,
|
691 |
+
in_channels = hid_chs[i-1],
|
692 |
+
out_channels = hid_chs[i],
|
693 |
+
kernel_size = kernel_sizes[i],
|
694 |
+
stride = strides[i],
|
695 |
+
downsample_kernel_size = downsample_kernel_sizes[i],
|
696 |
+
norm_name = norm_name,
|
697 |
+
act_name = act_name,
|
698 |
+
dropout = dropout,
|
699 |
+
use_res_block = use_res_block,
|
700 |
+
learnable_interpolation = learnable_interpolation,
|
701 |
+
use_attention = use_attention[i],
|
702 |
+
emb_channels = None
|
703 |
+
)
|
704 |
+
for i in range(1, self.depth)
|
705 |
+
])
|
706 |
+
|
707 |
+
# ----------- Out-Encoder ------------
|
708 |
+
self.out_enc = nn.Sequential(
|
709 |
+
BasicBlock(spatial_dims, hid_chs[-1], 2*emb_channels, 3),
|
710 |
+
BasicBlock(spatial_dims, 2*emb_channels, 2*emb_channels, 1)
|
711 |
+
)
|
712 |
+
|
713 |
+
|
714 |
+
# ----------- Reparameterization --------------
|
715 |
+
self.quantizer = DiagonalGaussianDistribution()
|
716 |
+
|
717 |
+
|
718 |
+
# ----------- In-Decoder ------------
|
719 |
+
self.inc_dec = ConvBlock(spatial_dims, emb_channels, hid_chs[-1], 3, act_name=act_name, norm_name=norm_name)
|
720 |
+
|
721 |
+
# ------------ Decoder ----------
|
722 |
+
self.decoders = nn.ModuleList([
|
723 |
+
UpBlock(
|
724 |
+
spatial_dims = spatial_dims,
|
725 |
+
in_channels = hid_chs[i+1],
|
726 |
+
out_channels = hid_chs[i],
|
727 |
+
kernel_size=kernel_sizes[i+1],
|
728 |
+
stride=strides[i+1],
|
729 |
+
upsample_kernel_size=upsample_kernel_sizes[i+1],
|
730 |
+
norm_name=norm_name,
|
731 |
+
act_name=act_name,
|
732 |
+
dropout=dropout,
|
733 |
+
use_res_block=use_res_block,
|
734 |
+
learnable_interpolation=learnable_interpolation,
|
735 |
+
use_attention=use_attention[i],
|
736 |
+
emb_channels=None,
|
737 |
+
skip_channels=0
|
738 |
+
)
|
739 |
+
for i in range(self.depth-1)
|
740 |
+
])
|
741 |
+
|
742 |
+
# --------------- Out-Convolution ----------------
|
743 |
+
self.outc = BasicBlock(spatial_dims, hid_chs[0], out_channels, 1, zero_conv=True)
|
744 |
+
if isinstance(deep_supervision, bool):
|
745 |
+
deep_supervision = self.depth-1 if deep_supervision else 0
|
746 |
+
self.outc_ver = nn.ModuleList([
|
747 |
+
BasicBlock(spatial_dims, hid_chs[i], out_channels, 1, zero_conv=True)
|
748 |
+
for i in range(1, deep_supervision+1)
|
749 |
+
])
|
750 |
+
# self.logvar_ver = nn.ParameterList([
|
751 |
+
# nn.Parameter(torch.zeros(size=()) )
|
752 |
+
# for _ in range(1, deep_supervision+1)
|
753 |
+
# ])
|
754 |
+
|
755 |
+
|
756 |
+
def encode(self, x):
|
757 |
+
h = self.inc(x)
|
758 |
+
for i in range(len(self.encoders)):
|
759 |
+
h = self.encoders[i](h)
|
760 |
+
z = self.out_enc(h)
|
761 |
+
z, _ = self.quantizer(z)
|
762 |
+
return z
|
763 |
+
|
764 |
+
def decode(self, z):
|
765 |
+
h = self.inc_dec(z)
|
766 |
+
for i in range(len(self.decoders), 0, -1):
|
767 |
+
h = self.decoders[i-1](h)
|
768 |
+
x = self.outc(h)
|
769 |
+
return x
|
770 |
+
|
771 |
+
def forward(self, x_in):
|
772 |
+
# --------- Encoder --------------
|
773 |
+
h = self.inc(x_in)
|
774 |
+
for i in range(len(self.encoders)):
|
775 |
+
h = self.encoders[i](h)
|
776 |
+
z = self.out_enc(h)
|
777 |
+
|
778 |
+
# --------- Quantizer --------------
|
779 |
+
z_q, emb_loss = self.quantizer(z)
|
780 |
+
|
781 |
+
# -------- Decoder -----------
|
782 |
+
out_hor = []
|
783 |
+
h = self.inc_dec(z_q)
|
784 |
+
for i in range(len(self.decoders)-1, -1, -1):
|
785 |
+
out_hor.append(self.outc_ver[i](h)) if i < len(self.outc_ver) else None
|
786 |
+
h = self.decoders[i](h)
|
787 |
+
out = self.outc(h)
|
788 |
+
|
789 |
+
return out, out_hor[::-1], emb_loss
|
790 |
+
|
791 |
+
def perception_loss(self, pred, target, depth=0):
|
792 |
+
if (self.perceiver is not None) and (depth<2):
|
793 |
+
self.perceiver.eval()
|
794 |
+
return self.perceiver(pred, target)*self.perceptual_loss_weight
|
795 |
+
else:
|
796 |
+
return 0
|
797 |
+
|
798 |
+
def ssim_loss(self, pred, target):
|
799 |
+
return 1-ssim(((pred+1)/2).clamp(0,1), (target.type(pred.dtype)+1)/2, data_range=1, size_average=False,
|
800 |
+
nonnegative_ssim=True).reshape(-1, *[1]*(pred.ndim-1))
|
801 |
+
|
802 |
+
def rec_loss(self, pred, pred_vertical, target):
|
803 |
+
interpolation_mode = 'nearest-exact'
|
804 |
+
|
805 |
+
# Loss
|
806 |
+
loss = 0
|
807 |
+
rec_loss = self.loss_fct(pred, target)+self.perception_loss(pred, target)+self.ssim_loss(pred, target)
|
808 |
+
# rec_loss = rec_loss/ torch.exp(self.logvar) + self.logvar # Note this is include in Stable-Diffusion but logvar is not used in optimizer
|
809 |
+
loss += torch.sum(rec_loss)/pred.shape[0]
|
810 |
+
|
811 |
+
|
812 |
+
for i, pred_i in enumerate(pred_vertical):
|
813 |
+
target_i = F.interpolate(target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None)
|
814 |
+
rec_loss_i = self.loss_fct(pred_i, target_i)+self.perception_loss(pred_i, target_i)+self.ssim_loss(pred_i, target_i)
|
815 |
+
# rec_loss_i = rec_loss_i/ torch.exp(self.logvar_ver[i]) + self.logvar_ver[i]
|
816 |
+
loss += torch.sum(rec_loss_i)/pred.shape[0]
|
817 |
+
|
818 |
+
return loss
|
819 |
+
|
820 |
+
def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int):
|
821 |
+
# ------------------------- Get Source/Target ---------------------------
|
822 |
+
x = batch['source']
|
823 |
+
target = x
|
824 |
+
|
825 |
+
# ------------------------- Run Model ---------------------------
|
826 |
+
pred, pred_vertical, emb_loss = self(x)
|
827 |
+
|
828 |
+
# ------------------------- Compute Loss ---------------------------
|
829 |
+
loss = self.rec_loss(pred, pred_vertical, target)
|
830 |
+
loss += emb_loss*self.embedding_loss_weight
|
831 |
+
|
832 |
+
# --------------------- Compute Metrics -------------------------------
|
833 |
+
with torch.no_grad():
|
834 |
+
logging_dict = {'loss':loss, 'emb_loss': emb_loss}
|
835 |
+
logging_dict['L2'] = torch.nn.functional.mse_loss(pred, target)
|
836 |
+
logging_dict['L1'] = torch.nn.functional.l1_loss(pred, target)
|
837 |
+
logging_dict['ssim'] = ssim((pred+1)/2, (target.type(pred.dtype)+1)/2, data_range=1)
|
838 |
+
# logging_dict['logvar'] = self.logvar
|
839 |
+
|
840 |
+
# ----------------- Log Scalars ----------------------
|
841 |
+
for metric_name, metric_val in logging_dict.items():
|
842 |
+
self.log(f"{state}/{metric_name}", metric_val, batch_size=x.shape[0], on_step=True, on_epoch=True)
|
843 |
+
|
844 |
+
# ----------------- Save Image ------------------------------
|
845 |
+
if self.global_step != 0 and self.global_step % self.sample_every_n_steps == 0:
|
846 |
+
log_step = self.global_step // self.sample_every_n_steps
|
847 |
+
path_out = Path(self.logger.log_dir)/'images'
|
848 |
+
path_out.mkdir(parents=True, exist_ok=True)
|
849 |
+
# for 3D images use depth as batch :[D, C, H, W], never show more than 16+16 =32 images
|
850 |
+
def depth2batch(image):
|
851 |
+
return (image if image.ndim<5 else torch.swapaxes(image[0], 0, 1))
|
852 |
+
images = torch.cat([depth2batch(img)[:16] for img in (x, pred)])
|
853 |
+
save_image(images, path_out/f'sample_{log_step}.png', nrow=x.shape[0], normalize=True)
|
854 |
+
|
855 |
+
return loss
|
856 |
+
|
857 |
+
|
858 |
+
|
859 |
+
|
860 |
+
class VAEGAN(VeryBasicModel):
|
861 |
+
def __init__(
|
862 |
+
self,
|
863 |
+
in_channels=3,
|
864 |
+
out_channels=3,
|
865 |
+
spatial_dims = 2,
|
866 |
+
emb_channels = 4,
|
867 |
+
hid_chs = [ 64, 128, 256, 512],
|
868 |
+
kernel_sizes=[ 3, 3, 3, 3],
|
869 |
+
strides = [ 1, 2, 2, 2],
|
870 |
+
norm_name = ("GROUP", {'num_groups':8, "affine": True}),
|
871 |
+
act_name=("Swish", {}),
|
872 |
+
dropout=0.0,
|
873 |
+
use_res_block=True,
|
874 |
+
deep_supervision=False,
|
875 |
+
learnable_interpolation=True,
|
876 |
+
use_attention='none',
|
877 |
+
embedding_loss_weight=1e-6,
|
878 |
+
perceiver = LPIPS,
|
879 |
+
perceiver_kwargs = {},
|
880 |
+
perceptual_loss_weight = 1.0,
|
881 |
+
|
882 |
+
|
883 |
+
start_gan_train_step = 50000, # NOTE step increase with each optimizer
|
884 |
+
gan_loss_weight: float = 1.0, # = discriminator
|
885 |
+
|
886 |
+
optimizer_vqvae=torch.optim.Adam,
|
887 |
+
optimizer_gan=torch.optim.Adam,
|
888 |
+
optimizer_vqvae_kwargs={'lr':1e-6}, # 'weight_decay':1e-2, {'lr':1e-6, 'betas':(0.5, 0.9)}
|
889 |
+
optimizer_gan_kwargs={'lr':1e-6}, # 'weight_decay':1e-2,
|
890 |
+
lr_scheduler_vqvae= None,
|
891 |
+
lr_scheduler_vqvae_kwargs={},
|
892 |
+
lr_scheduler_gan= None,
|
893 |
+
lr_scheduler_gan_kwargs={},
|
894 |
+
|
895 |
+
pixel_loss = torch.nn.L1Loss,
|
896 |
+
pixel_loss_kwargs={'reduction':'none'},
|
897 |
+
gan_loss_fct = hinge_d_loss,
|
898 |
+
|
899 |
+
sample_every_n_steps = 1000
|
900 |
+
|
901 |
+
):
|
902 |
+
super().__init__()
|
903 |
+
self.sample_every_n_steps=sample_every_n_steps
|
904 |
+
self.start_gan_train_step = start_gan_train_step
|
905 |
+
self.gan_loss_weight = gan_loss_weight
|
906 |
+
self.embedding_loss_weight = embedding_loss_weight
|
907 |
+
|
908 |
+
self.optimizer_vqvae = optimizer_vqvae
|
909 |
+
self.optimizer_gan = optimizer_gan
|
910 |
+
self.optimizer_vqvae_kwargs = optimizer_vqvae_kwargs
|
911 |
+
self.optimizer_gan_kwargs = optimizer_gan_kwargs
|
912 |
+
self.lr_scheduler_vqvae = lr_scheduler_vqvae
|
913 |
+
self.lr_scheduler_vqvae_kwargs = lr_scheduler_vqvae_kwargs
|
914 |
+
self.lr_scheduler_gan = lr_scheduler_gan
|
915 |
+
self.lr_scheduler_gan_kwargs = lr_scheduler_gan_kwargs
|
916 |
+
|
917 |
+
self.pixel_loss_fct = pixel_loss(**pixel_loss_kwargs)
|
918 |
+
self.gan_loss_fct = gan_loss_fct
|
919 |
+
|
920 |
+
self.vqvae = VAE(in_channels, out_channels, spatial_dims, emb_channels, hid_chs, kernel_sizes,
|
921 |
+
strides, norm_name, act_name, dropout, use_res_block, deep_supervision, learnable_interpolation, use_attention,
|
922 |
+
embedding_loss_weight, perceiver, perceiver_kwargs, perceptual_loss_weight)
|
923 |
+
|
924 |
+
self.discriminator = nn.ModuleList([Discriminator(in_channels, spatial_dims, hid_chs, kernel_sizes, strides,
|
925 |
+
act_name, norm_name, dropout) for i in range(len(self.vqvae.outc_ver)+1)])
|
926 |
+
|
927 |
+
|
928 |
+
# self.discriminator = nn.ModuleList([NLayerDiscriminator(in_channels, spatial_dims)
|
929 |
+
# for _ in range(len(self.vqvae.outc_ver)+1)])
|
930 |
+
|
931 |
+
|
932 |
+
|
933 |
+
def encode(self, x):
|
934 |
+
return self.vqvae.encode(x)
|
935 |
+
|
936 |
+
def decode(self, z):
|
937 |
+
return self.vqvae.decode(z)
|
938 |
+
|
939 |
+
def forward(self, x):
|
940 |
+
return self.vqvae.forward(x)
|
941 |
+
|
942 |
+
|
943 |
+
def vae_img_loss(self, pred, target, dec_out_layer, step, discriminator, depth=0):
|
944 |
+
# ------ VQVAE -------
|
945 |
+
rec_loss = self.vqvae.rec_loss(pred, [], target)
|
946 |
+
|
947 |
+
# ------- GAN -----
|
948 |
+
if (step > self.start_gan_train_step) and (depth<2):
|
949 |
+
gan_loss = -torch.sum(discriminator[depth](pred)) # clamp(..., None, 0) => only punish areas that were rated as fake (<0) by discriminator => ensures loss >0 and +- don't cannel out in sum
|
950 |
+
lambda_weight = self.compute_lambda(rec_loss, gan_loss, dec_out_layer)
|
951 |
+
gan_loss = gan_loss*lambda_weight
|
952 |
+
|
953 |
+
with torch.no_grad():
|
954 |
+
self.log(f"train/gan_loss_{depth}", gan_loss, on_step=True, on_epoch=True)
|
955 |
+
self.log(f"train/lambda_{depth}", lambda_weight, on_step=True, on_epoch=True)
|
956 |
+
else:
|
957 |
+
gan_loss = 0 #torch.tensor([0.0], requires_grad=True, device=target.device)
|
958 |
+
|
959 |
+
|
960 |
+
|
961 |
+
return self.gan_loss_weight*gan_loss+rec_loss
|
962 |
+
|
963 |
+
def gan_img_loss(self, pred, target, step, discriminators, depth):
|
964 |
+
if (step > self.start_gan_train_step) and (depth<len(discriminators)):
|
965 |
+
logits_real = discriminators[depth](target.detach())
|
966 |
+
logits_fake = discriminators[depth](pred.detach())
|
967 |
+
loss = self.gan_loss_fct(logits_real, logits_fake)
|
968 |
+
else:
|
969 |
+
loss = torch.tensor(0.0, requires_grad=True, device=target.device)
|
970 |
+
|
971 |
+
with torch.no_grad():
|
972 |
+
self.log(f"train/loss_1_{depth}", loss, on_step=True, on_epoch=True)
|
973 |
+
return loss
|
974 |
+
|
975 |
+
def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int):
|
976 |
+
# ------------------------- Get Source/Target ---------------------------
|
977 |
+
x = batch['source']
|
978 |
+
target = x
|
979 |
+
|
980 |
+
# ------------------------- Run Model ---------------------------
|
981 |
+
pred, pred_vertical, emb_loss = self(x)
|
982 |
+
|
983 |
+
# ------------------------- Compute Loss ---------------------------
|
984 |
+
interpolation_mode = 'area'
|
985 |
+
logging_dict = {}
|
986 |
+
|
987 |
+
if optimizer_idx == 0:
|
988 |
+
# Horizontal/Top Layer
|
989 |
+
img_loss = self.vae_img_loss(pred, target, self.vqvae.outc.conv, step, self.discriminator, 0)
|
990 |
+
|
991 |
+
# Vertical/Deep Layer
|
992 |
+
for i, pred_i in enumerate(pred_vertical):
|
993 |
+
target_i = F.interpolate(target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None)
|
994 |
+
img_loss += self.vae_img_loss(pred_i, target_i, self.vqvae.outc_ver[i].conv, step, self.discriminator, i+1)
|
995 |
+
loss = img_loss+self.embedding_loss_weight*emb_loss
|
996 |
+
|
997 |
+
with torch.no_grad():
|
998 |
+
logging_dict[f'img_loss'] = img_loss
|
999 |
+
logging_dict[f'emb_loss'] = emb_loss
|
1000 |
+
logging_dict['loss_0'] = loss
|
1001 |
+
|
1002 |
+
elif optimizer_idx == 1:
|
1003 |
+
# Horizontal/Top Layer
|
1004 |
+
loss = self.gan_img_loss(pred, target, step, self.discriminator, 0)
|
1005 |
+
|
1006 |
+
# Vertical/Deep Layer
|
1007 |
+
for i, pred_i in enumerate(pred_vertical):
|
1008 |
+
target_i = F.interpolate(target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None)
|
1009 |
+
loss += self.gan_img_loss(pred_i, target_i, step, self.discriminator, i+1)
|
1010 |
+
|
1011 |
+
with torch.no_grad():
|
1012 |
+
logging_dict['loss_1'] = loss
|
1013 |
+
|
1014 |
+
|
1015 |
+
# --------------------- Compute Metrics -------------------------------
|
1016 |
+
with torch.no_grad():
|
1017 |
+
logging_dict['loss'] = loss
|
1018 |
+
logging_dict[f'L2'] = torch.nn.functional.mse_loss(pred, x)
|
1019 |
+
logging_dict[f'L1'] = torch.nn.functional.l1_loss(pred, x)
|
1020 |
+
logging_dict['ssim'] = ssim((pred+1)/2, (target.type(pred.dtype)+1)/2, data_range=1)
|
1021 |
+
# logging_dict['logvar'] = self.vqvae.logvar
|
1022 |
+
|
1023 |
+
# ----------------- Log Scalars ----------------------
|
1024 |
+
for metric_name, metric_val in logging_dict.items():
|
1025 |
+
self.log(f"{state}/{metric_name}", metric_val, batch_size=x.shape[0], on_step=True, on_epoch=True)
|
1026 |
+
|
1027 |
+
# ----------------- Save Image ------------------------------
|
1028 |
+
if self.global_step != 0 and self.global_step % self.sample_every_n_steps == 0: # NOTE: step 1 (opt1) , step=2 (opt2), step=3 (opt1), ...
|
1029 |
+
|
1030 |
+
log_step = self.global_step // self.sample_every_n_steps
|
1031 |
+
path_out = Path(self.logger.log_dir)/'images'
|
1032 |
+
path_out.mkdir(parents=True, exist_ok=True)
|
1033 |
+
# for 3D images use depth as batch :[D, C, H, W], never show more than 16+16 =32 images
|
1034 |
+
def depth2batch(image):
|
1035 |
+
return (image if image.ndim<5 else torch.swapaxes(image[0], 0, 1))
|
1036 |
+
images = torch.cat([depth2batch(img)[:16] for img in (x, pred)])
|
1037 |
+
save_image(images, path_out/f'sample_{log_step}.png', nrow=x.shape[0], normalize=True)
|
1038 |
+
|
1039 |
+
return loss
|
1040 |
+
|
1041 |
+
def configure_optimizers(self):
|
1042 |
+
opt_vqvae = self.optimizer_vqvae(self.vqvae.parameters(), **self.optimizer_vqvae_kwargs)
|
1043 |
+
opt_gan = self.optimizer_gan(self.discriminator.parameters(), **self.optimizer_gan_kwargs)
|
1044 |
+
schedulers = []
|
1045 |
+
if self.lr_scheduler_vqvae is not None:
|
1046 |
+
schedulers.append({
|
1047 |
+
'scheduler': self.lr_scheduler_vqvae(opt_vqvae, **self.lr_scheduler_vqvae_kwargs),
|
1048 |
+
'interval': 'step',
|
1049 |
+
'frequency': 1
|
1050 |
+
})
|
1051 |
+
if self.lr_scheduler_gan is not None:
|
1052 |
+
schedulers.append({
|
1053 |
+
'scheduler': self.lr_scheduler_gan(opt_gan, **self.lr_scheduler_gan_kwargs),
|
1054 |
+
'interval': 'step',
|
1055 |
+
'frequency': 1
|
1056 |
+
})
|
1057 |
+
return [opt_vqvae, opt_gan], schedulers
|
1058 |
+
|
1059 |
+
def compute_lambda(self, rec_loss, gan_loss, dec_out_layer, eps=1e-4):
|
1060 |
+
"""Computes adaptive weight as proposed in eq. 7 of https://arxiv.org/abs/2012.09841"""
|
1061 |
+
rec_grads = torch.autograd.grad(rec_loss, dec_out_layer.weight, retain_graph=True)[0]
|
1062 |
+
gan_grads = torch.autograd.grad(gan_loss, dec_out_layer.weight, retain_graph=True)[0]
|
1063 |
+
d_weight = torch.norm(rec_grads) / (torch.norm(gan_grads) + eps)
|
1064 |
+
d_weight = torch.clamp(d_weight, 0.0, 1e4)
|
1065 |
+
return d_weight.detach()
|
medical_diffusion/models/embedders/time_embedder.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from monai.networks.layers.utils import get_act_layer
|
6 |
+
|
7 |
+
class SinusoidalPosEmb(nn.Module):
|
8 |
+
def __init__(self, emb_dim=16, downscale_freq_shift=1, max_period=10000, flip_sin_to_cos=False):
|
9 |
+
super().__init__()
|
10 |
+
self.emb_dim = emb_dim
|
11 |
+
self.downscale_freq_shift = downscale_freq_shift
|
12 |
+
self.max_period = max_period
|
13 |
+
self.flip_sin_to_cos=flip_sin_to_cos
|
14 |
+
|
15 |
+
def forward(self, x):
|
16 |
+
device = x.device
|
17 |
+
half_dim = self.emb_dim // 2
|
18 |
+
emb = math.log(self.max_period) / (half_dim - self.downscale_freq_shift)
|
19 |
+
emb = torch.exp(-emb*torch.arange(half_dim, device=device))
|
20 |
+
emb = x[:, None] * emb[None, :]
|
21 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
22 |
+
|
23 |
+
if self.flip_sin_to_cos:
|
24 |
+
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
25 |
+
|
26 |
+
if self.emb_dim % 2 == 1:
|
27 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
28 |
+
return emb
|
29 |
+
|
30 |
+
|
31 |
+
class LearnedSinusoidalPosEmb(nn.Module):
|
32 |
+
""" following @crowsonkb 's lead with learned sinusoidal pos emb """
|
33 |
+
""" https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
|
34 |
+
|
35 |
+
def __init__(self, emb_dim):
|
36 |
+
super().__init__()
|
37 |
+
self.emb_dim = emb_dim
|
38 |
+
half_dim = emb_dim // 2
|
39 |
+
self.weights = nn.Parameter(torch.randn(half_dim))
|
40 |
+
|
41 |
+
def forward(self, x):
|
42 |
+
x = x[:, None]
|
43 |
+
freqs = x * self.weights[None, :] * 2 * math.pi
|
44 |
+
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
|
45 |
+
fouriered = torch.cat((x, fouriered), dim = -1)
|
46 |
+
if self.emb_dim % 2 == 1:
|
47 |
+
fouriered = torch.nn.functional.pad(fouriered, (0, 1, 0, 0))
|
48 |
+
return fouriered
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
class TimeEmbbeding(nn.Module):
|
53 |
+
def __init__(
|
54 |
+
self,
|
55 |
+
emb_dim = 64,
|
56 |
+
pos_embedder = SinusoidalPosEmb,
|
57 |
+
pos_embedder_kwargs = {},
|
58 |
+
act_name=("SWISH", {}) # Swish = SiLU
|
59 |
+
):
|
60 |
+
super().__init__()
|
61 |
+
self.emb_dim = emb_dim
|
62 |
+
self.pos_emb_dim = pos_embedder_kwargs.get('emb_dim', emb_dim//4)
|
63 |
+
pos_embedder_kwargs['emb_dim'] = self.pos_emb_dim
|
64 |
+
self.pos_embedder = pos_embedder(**pos_embedder_kwargs)
|
65 |
+
|
66 |
+
|
67 |
+
self.time_emb = nn.Sequential(
|
68 |
+
self.pos_embedder,
|
69 |
+
nn.Linear(self.pos_emb_dim, self.emb_dim),
|
70 |
+
get_act_layer(act_name),
|
71 |
+
nn.Linear(self.emb_dim, self.emb_dim)
|
72 |
+
)
|
73 |
+
|
74 |
+
def forward(self, time):
|
75 |
+
return self.time_emb(time)
|
medical_diffusion/models/estimators/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .unet2 import UNet
|
medical_diffusion/models/estimators/unet.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from monai.networks.blocks import UnetOutBlock
|
5 |
+
|
6 |
+
from medical_diffusion.models.utils.conv_blocks import BasicBlock, UpBlock, DownBlock, UnetBasicBlock, UnetResBlock, save_add
|
7 |
+
from medical_diffusion.models.embedders import TimeEmbbeding
|
8 |
+
from medical_diffusion.models.utils.attention_blocks import SpatialTransformer, LinearTransformer
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
class UNet(nn.Module):
|
16 |
+
|
17 |
+
def __init__(self,
|
18 |
+
in_ch=1,
|
19 |
+
out_ch=1,
|
20 |
+
spatial_dims = 3,
|
21 |
+
hid_chs = [32, 64, 128, 256],
|
22 |
+
kernel_sizes=[ 1, 3, 3, 3],
|
23 |
+
strides = [ 1, 2, 2, 2],
|
24 |
+
downsample_kernel_sizes = None,
|
25 |
+
upsample_kernel_sizes = None,
|
26 |
+
act_name=("SWISH", {}),
|
27 |
+
norm_name = ("GROUP", {'num_groups':32, "affine": True}),
|
28 |
+
time_embedder=TimeEmbbeding,
|
29 |
+
time_embedder_kwargs={},
|
30 |
+
cond_embedder=None,
|
31 |
+
cond_embedder_kwargs={},
|
32 |
+
deep_supervision=True, # True = all but last layer, 0/False=disable, 1=only first layer, ...
|
33 |
+
use_res_block=True,
|
34 |
+
estimate_variance=False ,
|
35 |
+
use_self_conditioning = False,
|
36 |
+
dropout=0.0,
|
37 |
+
learnable_interpolation=True,
|
38 |
+
use_attention='none',
|
39 |
+
):
|
40 |
+
super().__init__()
|
41 |
+
use_attention = use_attention if isinstance(use_attention, list) else [use_attention]*len(strides)
|
42 |
+
self.use_self_conditioning = use_self_conditioning
|
43 |
+
self.use_res_block = use_res_block
|
44 |
+
self.depth = len(strides)
|
45 |
+
if downsample_kernel_sizes is None:
|
46 |
+
downsample_kernel_sizes = kernel_sizes
|
47 |
+
if upsample_kernel_sizes is None:
|
48 |
+
upsample_kernel_sizes = strides
|
49 |
+
|
50 |
+
|
51 |
+
# ------------- Time-Embedder-----------
|
52 |
+
if time_embedder is not None:
|
53 |
+
self.time_embedder=time_embedder(**time_embedder_kwargs)
|
54 |
+
time_emb_dim = self.time_embedder.emb_dim
|
55 |
+
else:
|
56 |
+
self.time_embedder = None
|
57 |
+
|
58 |
+
# ------------- Condition-Embedder-----------
|
59 |
+
if cond_embedder is not None:
|
60 |
+
self.cond_embedder=cond_embedder(**cond_embedder_kwargs)
|
61 |
+
else:
|
62 |
+
self.cond_embedder = None
|
63 |
+
|
64 |
+
# ----------- In-Convolution ------------
|
65 |
+
in_ch = in_ch*2 if self.use_self_conditioning else in_ch
|
66 |
+
ConvBlock = UnetResBlock if use_res_block else UnetBasicBlock
|
67 |
+
self.inc = ConvBlock(
|
68 |
+
spatial_dims = spatial_dims,
|
69 |
+
in_channels = in_ch,
|
70 |
+
out_channels = hid_chs[0],
|
71 |
+
kernel_size=kernel_sizes[0],
|
72 |
+
stride=strides[0],
|
73 |
+
act_name=act_name,
|
74 |
+
norm_name=norm_name,
|
75 |
+
emb_channels=time_emb_dim
|
76 |
+
)
|
77 |
+
|
78 |
+
|
79 |
+
# ----------- Encoder ----------------
|
80 |
+
self.encoders = nn.ModuleList([
|
81 |
+
DownBlock(
|
82 |
+
spatial_dims = spatial_dims,
|
83 |
+
in_channels = hid_chs[i-1],
|
84 |
+
out_channels = hid_chs[i],
|
85 |
+
kernel_size = kernel_sizes[i],
|
86 |
+
stride = strides[i],
|
87 |
+
downsample_kernel_size = downsample_kernel_sizes[i],
|
88 |
+
norm_name = norm_name,
|
89 |
+
act_name = act_name,
|
90 |
+
dropout = dropout,
|
91 |
+
use_res_block = use_res_block,
|
92 |
+
learnable_interpolation = learnable_interpolation,
|
93 |
+
use_attention = use_attention[i],
|
94 |
+
emb_channels = time_emb_dim
|
95 |
+
)
|
96 |
+
for i in range(1, self.depth)
|
97 |
+
])
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
# ------------ Decoder ----------
|
102 |
+
self.decoders = nn.ModuleList([
|
103 |
+
UpBlock(
|
104 |
+
spatial_dims = spatial_dims,
|
105 |
+
in_channels = hid_chs[i+1],
|
106 |
+
out_channels = hid_chs[i],
|
107 |
+
kernel_size=kernel_sizes[i+1],
|
108 |
+
stride=strides[i+1],
|
109 |
+
upsample_kernel_size=upsample_kernel_sizes[i+1],
|
110 |
+
norm_name=norm_name,
|
111 |
+
act_name=act_name,
|
112 |
+
dropout=dropout,
|
113 |
+
use_res_block=use_res_block,
|
114 |
+
learnable_interpolation=learnable_interpolation,
|
115 |
+
use_attention=use_attention[i],
|
116 |
+
emb_channels=time_emb_dim,
|
117 |
+
skip_channels=hid_chs[i]
|
118 |
+
)
|
119 |
+
for i in range(self.depth-1)
|
120 |
+
])
|
121 |
+
|
122 |
+
|
123 |
+
# --------------- Out-Convolution ----------------
|
124 |
+
out_ch_hor = out_ch*2 if estimate_variance else out_ch
|
125 |
+
self.outc = UnetOutBlock(spatial_dims, hid_chs[0], out_ch_hor, dropout=None)
|
126 |
+
if isinstance(deep_supervision, bool):
|
127 |
+
deep_supervision = self.depth-1 if deep_supervision else 0
|
128 |
+
self.outc_ver = nn.ModuleList([
|
129 |
+
UnetOutBlock(spatial_dims, hid_chs[i], out_ch, dropout=None)
|
130 |
+
for i in range(1, deep_supervision+1)
|
131 |
+
])
|
132 |
+
|
133 |
+
|
134 |
+
def forward(self, x_t, t=None, condition=None, self_cond=None):
|
135 |
+
# x_t [B, C, *]
|
136 |
+
# t [B,]
|
137 |
+
# condition [B,]
|
138 |
+
# self_cond [B, C, *]
|
139 |
+
x = [ None for _ in range(len(self.encoders)+1) ]
|
140 |
+
|
141 |
+
# -------- Time Embedding (Global) -----------
|
142 |
+
if t is None:
|
143 |
+
time_emb = None
|
144 |
+
else:
|
145 |
+
time_emb = self.time_embedder(t) # [B, C]
|
146 |
+
|
147 |
+
# -------- Condition Embedding (Global) -----------
|
148 |
+
if (condition is None) or (self.cond_embedder is None):
|
149 |
+
cond_emb = None
|
150 |
+
else:
|
151 |
+
cond_emb = self.cond_embedder(condition) # [B, C]
|
152 |
+
|
153 |
+
# ----------- Embedding Summation --------
|
154 |
+
emb = save_add(time_emb, cond_emb)
|
155 |
+
|
156 |
+
# ---------- Self-conditioning-----------
|
157 |
+
if self.use_self_conditioning:
|
158 |
+
self_cond = torch.zeros_like(x_t) if self_cond is None else x_t
|
159 |
+
x_t = torch.cat([x_t, self_cond], dim=1)
|
160 |
+
|
161 |
+
# -------- In-Convolution --------------
|
162 |
+
x[0] = self.inc(x_t, emb)
|
163 |
+
|
164 |
+
# --------- Encoder --------------
|
165 |
+
for i in range(len(self.encoders)):
|
166 |
+
x[i+1] = self.encoders[i](x[i], emb)
|
167 |
+
|
168 |
+
# -------- Decoder -----------
|
169 |
+
for i in range(len(self.decoders), 0, -1):
|
170 |
+
x[i-1] = self.decoders[i-1](x[i], x[i-1], emb)
|
171 |
+
|
172 |
+
# ---------Out-Convolution ------------
|
173 |
+
y = self.outc(x[0])
|
174 |
+
y_ver = [outc_ver_i(x[i+1]) for i, outc_ver_i in enumerate(self.outc_ver)]
|
175 |
+
|
176 |
+
return y, y_ver
|
177 |
+
|
178 |
+
|
179 |
+
|
180 |
+
|
181 |
+
if __name__=='__main__':
|
182 |
+
model = UNet(in_ch=3, use_res_block=False, learnable_interpolation=False)
|
183 |
+
input = torch.randn((1,3,16,128,128))
|
184 |
+
time = torch.randn((1,))
|
185 |
+
out_hor, out_ver = model(input, time)
|
186 |
+
print(out_hor[0].shape)
|
medical_diffusion/models/estimators/unet2.py
ADDED
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from monai.networks.blocks import UnetOutBlock
|
5 |
+
|
6 |
+
from medical_diffusion.models.utils.conv_blocks import BasicBlock, UpBlock, DownBlock, UnetBasicBlock, UnetResBlock, save_add, BasicDown, BasicUp, SequentialEmb
|
7 |
+
from medical_diffusion.models.embedders import TimeEmbbeding
|
8 |
+
from medical_diffusion.models.utils.attention_blocks import Attention, zero_module
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
class UNet(nn.Module):
|
16 |
+
|
17 |
+
def __init__(self,
|
18 |
+
in_ch=1,
|
19 |
+
out_ch=1,
|
20 |
+
spatial_dims = 3,
|
21 |
+
hid_chs = [256, 256, 512, 1024],
|
22 |
+
kernel_sizes=[ 3, 3, 3, 3],
|
23 |
+
strides = [ 1, 2, 2, 2], # WARNING, last stride is ignored (follows OpenAI)
|
24 |
+
act_name=("SWISH", {}),
|
25 |
+
norm_name = ("GROUP", {'num_groups':32, "affine": True}),
|
26 |
+
time_embedder=TimeEmbbeding,
|
27 |
+
time_embedder_kwargs={},
|
28 |
+
cond_embedder=None,
|
29 |
+
cond_embedder_kwargs={},
|
30 |
+
deep_supervision=True, # True = all but last layer, 0/False=disable, 1=only first layer, ...
|
31 |
+
use_res_block=True,
|
32 |
+
estimate_variance=False ,
|
33 |
+
use_self_conditioning = False,
|
34 |
+
dropout=0.0,
|
35 |
+
learnable_interpolation=True,
|
36 |
+
use_attention='none',
|
37 |
+
num_res_blocks=2,
|
38 |
+
):
|
39 |
+
super().__init__()
|
40 |
+
use_attention = use_attention if isinstance(use_attention, list) else [use_attention]*len(strides)
|
41 |
+
self.use_self_conditioning = use_self_conditioning
|
42 |
+
self.use_res_block = use_res_block
|
43 |
+
self.depth = len(strides)
|
44 |
+
self.num_res_blocks = num_res_blocks
|
45 |
+
|
46 |
+
# ------------- Time-Embedder-----------
|
47 |
+
if time_embedder is not None:
|
48 |
+
self.time_embedder=time_embedder(**time_embedder_kwargs)
|
49 |
+
time_emb_dim = self.time_embedder.emb_dim
|
50 |
+
else:
|
51 |
+
self.time_embedder = None
|
52 |
+
time_emb_dim = None
|
53 |
+
|
54 |
+
# ------------- Condition-Embedder-----------
|
55 |
+
if cond_embedder is not None:
|
56 |
+
self.cond_embedder=cond_embedder(**cond_embedder_kwargs)
|
57 |
+
cond_emb_dim = self.cond_embedder.emb_dim
|
58 |
+
else:
|
59 |
+
self.cond_embedder = None
|
60 |
+
cond_emb_dim = None
|
61 |
+
|
62 |
+
|
63 |
+
ConvBlock = UnetResBlock if use_res_block else UnetBasicBlock
|
64 |
+
|
65 |
+
# ----------- In-Convolution ------------
|
66 |
+
in_ch = in_ch*2 if self.use_self_conditioning else in_ch
|
67 |
+
self.in_conv = BasicBlock(spatial_dims, in_ch, hid_chs[0], kernel_size=kernel_sizes[0], stride=strides[0])
|
68 |
+
|
69 |
+
|
70 |
+
# ----------- Encoder ------------
|
71 |
+
in_blocks = []
|
72 |
+
for i in range(1, self.depth):
|
73 |
+
for k in range(num_res_blocks):
|
74 |
+
seq_list = []
|
75 |
+
seq_list.append(
|
76 |
+
ConvBlock(
|
77 |
+
spatial_dims=spatial_dims,
|
78 |
+
in_channels=hid_chs[i-1 if k==0 else i],
|
79 |
+
out_channels=hid_chs[i],
|
80 |
+
kernel_size=kernel_sizes[i],
|
81 |
+
stride=1,
|
82 |
+
norm_name=norm_name,
|
83 |
+
act_name=act_name,
|
84 |
+
dropout=dropout,
|
85 |
+
emb_channels=time_emb_dim
|
86 |
+
)
|
87 |
+
)
|
88 |
+
|
89 |
+
seq_list.append(
|
90 |
+
Attention(
|
91 |
+
spatial_dims=spatial_dims,
|
92 |
+
in_channels=hid_chs[i],
|
93 |
+
out_channels=hid_chs[i],
|
94 |
+
num_heads=8,
|
95 |
+
ch_per_head=hid_chs[i]//8,
|
96 |
+
depth=1,
|
97 |
+
norm_name=norm_name,
|
98 |
+
dropout=dropout,
|
99 |
+
emb_dim=time_emb_dim,
|
100 |
+
attention_type=use_attention[i]
|
101 |
+
)
|
102 |
+
)
|
103 |
+
in_blocks.append(SequentialEmb(*seq_list))
|
104 |
+
|
105 |
+
if i < self.depth-1:
|
106 |
+
in_blocks.append(
|
107 |
+
BasicDown(
|
108 |
+
spatial_dims=spatial_dims,
|
109 |
+
in_channels=hid_chs[i],
|
110 |
+
out_channels=hid_chs[i],
|
111 |
+
kernel_size=kernel_sizes[i],
|
112 |
+
stride=strides[i],
|
113 |
+
learnable_interpolation=learnable_interpolation
|
114 |
+
)
|
115 |
+
)
|
116 |
+
|
117 |
+
|
118 |
+
self.in_blocks = nn.ModuleList(in_blocks)
|
119 |
+
|
120 |
+
# ----------- Middle ------------
|
121 |
+
self.middle_block = SequentialEmb(
|
122 |
+
ConvBlock(
|
123 |
+
spatial_dims=spatial_dims,
|
124 |
+
in_channels=hid_chs[-1],
|
125 |
+
out_channels=hid_chs[-1],
|
126 |
+
kernel_size=kernel_sizes[-1],
|
127 |
+
stride=1,
|
128 |
+
norm_name=norm_name,
|
129 |
+
act_name=act_name,
|
130 |
+
dropout=dropout,
|
131 |
+
emb_channels=time_emb_dim
|
132 |
+
),
|
133 |
+
Attention(
|
134 |
+
spatial_dims=spatial_dims,
|
135 |
+
in_channels=hid_chs[-1],
|
136 |
+
out_channels=hid_chs[-1],
|
137 |
+
num_heads=8,
|
138 |
+
ch_per_head=hid_chs[-1]//8,
|
139 |
+
depth=1,
|
140 |
+
norm_name=norm_name,
|
141 |
+
dropout=dropout,
|
142 |
+
emb_dim=time_emb_dim,
|
143 |
+
attention_type=use_attention[-1]
|
144 |
+
),
|
145 |
+
ConvBlock(
|
146 |
+
spatial_dims=spatial_dims,
|
147 |
+
in_channels=hid_chs[-1],
|
148 |
+
out_channels=hid_chs[-1],
|
149 |
+
kernel_size=kernel_sizes[-1],
|
150 |
+
stride=1,
|
151 |
+
norm_name=norm_name,
|
152 |
+
act_name=act_name,
|
153 |
+
dropout=dropout,
|
154 |
+
emb_channels=time_emb_dim
|
155 |
+
)
|
156 |
+
)
|
157 |
+
|
158 |
+
|
159 |
+
|
160 |
+
# ------------ Decoder ----------
|
161 |
+
out_blocks = []
|
162 |
+
for i in range(1, self.depth):
|
163 |
+
for k in range(num_res_blocks+1):
|
164 |
+
seq_list = []
|
165 |
+
out_channels=hid_chs[i-1 if k==0 else i]
|
166 |
+
seq_list.append(
|
167 |
+
ConvBlock(
|
168 |
+
spatial_dims=spatial_dims,
|
169 |
+
in_channels=hid_chs[i]+hid_chs[i-1 if k==0 else i],
|
170 |
+
out_channels=out_channels,
|
171 |
+
kernel_size=kernel_sizes[i],
|
172 |
+
stride=1,
|
173 |
+
norm_name=norm_name,
|
174 |
+
act_name=act_name,
|
175 |
+
dropout=dropout,
|
176 |
+
emb_channels=time_emb_dim
|
177 |
+
)
|
178 |
+
)
|
179 |
+
|
180 |
+
seq_list.append(
|
181 |
+
Attention(
|
182 |
+
spatial_dims=spatial_dims,
|
183 |
+
in_channels=out_channels,
|
184 |
+
out_channels=out_channels,
|
185 |
+
num_heads=8,
|
186 |
+
ch_per_head=out_channels//8,
|
187 |
+
depth=1,
|
188 |
+
norm_name=norm_name,
|
189 |
+
dropout=dropout,
|
190 |
+
emb_dim=time_emb_dim,
|
191 |
+
attention_type=use_attention[i]
|
192 |
+
)
|
193 |
+
)
|
194 |
+
|
195 |
+
if (i >1) and k==0:
|
196 |
+
seq_list.append(
|
197 |
+
BasicUp(
|
198 |
+
spatial_dims=spatial_dims,
|
199 |
+
in_channels=out_channels,
|
200 |
+
out_channels=out_channels,
|
201 |
+
kernel_size=strides[i],
|
202 |
+
stride=strides[i],
|
203 |
+
learnable_interpolation=learnable_interpolation
|
204 |
+
)
|
205 |
+
)
|
206 |
+
|
207 |
+
out_blocks.append(SequentialEmb(*seq_list))
|
208 |
+
self.out_blocks = nn.ModuleList(out_blocks)
|
209 |
+
|
210 |
+
|
211 |
+
# --------------- Out-Convolution ----------------
|
212 |
+
out_ch_hor = out_ch*2 if estimate_variance else out_ch
|
213 |
+
self.outc = zero_module(UnetOutBlock(spatial_dims, hid_chs[0], out_ch_hor, dropout=None))
|
214 |
+
if isinstance(deep_supervision, bool):
|
215 |
+
deep_supervision = self.depth-2 if deep_supervision else 0
|
216 |
+
self.outc_ver = nn.ModuleList([
|
217 |
+
zero_module(UnetOutBlock(spatial_dims, hid_chs[i]+hid_chs[i-1], out_ch, dropout=None) )
|
218 |
+
for i in range(2, deep_supervision+2)
|
219 |
+
])
|
220 |
+
|
221 |
+
|
222 |
+
def forward(self, x_t, t=None, condition=None, self_cond=None):
|
223 |
+
# x_t [B, C, *]
|
224 |
+
# t [B,]
|
225 |
+
# condition [B,]
|
226 |
+
# self_cond [B, C, *]
|
227 |
+
|
228 |
+
|
229 |
+
# -------- Time Embedding (Gloabl) -----------
|
230 |
+
if t is None:
|
231 |
+
time_emb = None
|
232 |
+
else:
|
233 |
+
time_emb = self.time_embedder(t) # [B, C]
|
234 |
+
|
235 |
+
# -------- Condition Embedding (Gloabl) -----------
|
236 |
+
if (condition is None) or (self.cond_embedder is None):
|
237 |
+
cond_emb = None
|
238 |
+
else:
|
239 |
+
cond_emb = self.cond_embedder(condition) # [B, C]
|
240 |
+
|
241 |
+
emb = save_add(time_emb, cond_emb)
|
242 |
+
|
243 |
+
# ---------- Self-conditioning-----------
|
244 |
+
if self.use_self_conditioning:
|
245 |
+
self_cond = torch.zeros_like(x_t) if self_cond is None else x_t
|
246 |
+
x_t = torch.cat([x_t, self_cond], dim=1)
|
247 |
+
|
248 |
+
# --------- Encoder --------------
|
249 |
+
x = [self.in_conv(x_t)]
|
250 |
+
for i in range(len(self.in_blocks)):
|
251 |
+
x.append(self.in_blocks[i](x[i], emb))
|
252 |
+
|
253 |
+
# ---------- Middle --------------
|
254 |
+
h = self.middle_block(x[-1], emb)
|
255 |
+
|
256 |
+
# -------- Decoder -----------
|
257 |
+
y_ver = []
|
258 |
+
for i in range(len(self.out_blocks), 0, -1):
|
259 |
+
h = torch.cat([h, x.pop()], dim=1)
|
260 |
+
|
261 |
+
depth, j = i//(self.num_res_blocks+1), i%(self.num_res_blocks+1)-1
|
262 |
+
y_ver.append(self.outc_ver[depth-1](h)) if (len(self.outc_ver)>=depth>0) and (j==0) else None
|
263 |
+
|
264 |
+
h = self.out_blocks[i-1](h, emb)
|
265 |
+
|
266 |
+
# ---------Out-Convolution ------------
|
267 |
+
y = self.outc(h)
|
268 |
+
|
269 |
+
return y, y_ver[::-1]
|
270 |
+
|
271 |
+
|
272 |
+
|
273 |
+
|
274 |
+
if __name__=='__main__':
|
275 |
+
model = UNet(in_ch=3, use_res_block=False, learnable_interpolation=False)
|
276 |
+
input = torch.randn((1,3,16,32,32))
|
277 |
+
time = torch.randn((1,))
|
278 |
+
out_hor, out_ver = model(input, time)
|
279 |
+
print(out_hor[0].shape)
|
medical_diffusion/models/model_base.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from pathlib import Path
|
3 |
+
import json
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
10 |
+
from pytorch_lightning.utilities.migration import pl_legacy_patch
|
11 |
+
|
12 |
+
class VeryBasicModel(pl.LightningModule):
|
13 |
+
def __init__(self):
|
14 |
+
super().__init__()
|
15 |
+
self.save_hyperparameters()
|
16 |
+
self._step_train = 0
|
17 |
+
self._step_val = 0
|
18 |
+
self._step_test = 0
|
19 |
+
|
20 |
+
|
21 |
+
def forward(self, x_in):
|
22 |
+
raise NotImplementedError
|
23 |
+
|
24 |
+
def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int):
|
25 |
+
raise NotImplementedError
|
26 |
+
|
27 |
+
def training_step(self, batch: dict, batch_idx: int, optimizer_idx:int = 0 ):
|
28 |
+
self._step_train += 1 # =self.global_step
|
29 |
+
return self._step(batch, batch_idx, "train", self._step_train, optimizer_idx)
|
30 |
+
|
31 |
+
def validation_step(self, batch: dict, batch_idx: int, optimizer_idx:int = 0):
|
32 |
+
self._step_val += 1
|
33 |
+
return self._step(batch, batch_idx, "val", self._step_val, optimizer_idx )
|
34 |
+
|
35 |
+
def test_step(self, batch: dict, batch_idx: int, optimizer_idx:int = 0):
|
36 |
+
self._step_test += 1
|
37 |
+
return self._step(batch, batch_idx, "test", self._step_test, optimizer_idx)
|
38 |
+
|
39 |
+
def _epoch_end(self, outputs: list, state: str):
|
40 |
+
return
|
41 |
+
|
42 |
+
def training_epoch_end(self, outputs):
|
43 |
+
self._epoch_end(outputs, "train")
|
44 |
+
|
45 |
+
def validation_epoch_end(self, outputs):
|
46 |
+
self._epoch_end(outputs, "val")
|
47 |
+
|
48 |
+
def test_epoch_end(self, outputs):
|
49 |
+
self._epoch_end(outputs, "test")
|
50 |
+
|
51 |
+
@classmethod
|
52 |
+
def save_best_checkpoint(cls, path_checkpoint_dir, best_model_path):
|
53 |
+
with open(Path(path_checkpoint_dir) / 'best_checkpoint.json', 'w') as f:
|
54 |
+
json.dump({'best_model_epoch': Path(best_model_path).name}, f)
|
55 |
+
|
56 |
+
@classmethod
|
57 |
+
def _get_best_checkpoint_path(cls, path_checkpoint_dir, version=0, **kwargs):
|
58 |
+
path_version = 'lightning_logs/version_'+str(version)
|
59 |
+
with open(Path(path_checkpoint_dir) / path_version/ 'best_checkpoint.json', 'r') as f:
|
60 |
+
path_rel_best_checkpoint = Path(json.load(f)['best_model_epoch'])
|
61 |
+
return Path(path_checkpoint_dir)/path_rel_best_checkpoint
|
62 |
+
|
63 |
+
@classmethod
|
64 |
+
def load_best_checkpoint(cls, path_checkpoint_dir, version=0, **kwargs):
|
65 |
+
path_best_checkpoint = cls._get_best_checkpoint_path(path_checkpoint_dir, version)
|
66 |
+
return cls.load_from_checkpoint(path_best_checkpoint, **kwargs)
|
67 |
+
|
68 |
+
def load_pretrained(self, checkpoint_path, map_location=None, **kwargs):
|
69 |
+
if checkpoint_path.is_dir():
|
70 |
+
checkpoint_path = self._get_best_checkpoint_path(checkpoint_path, **kwargs)
|
71 |
+
|
72 |
+
with pl_legacy_patch():
|
73 |
+
if map_location is not None:
|
74 |
+
checkpoint = pl_load(checkpoint_path, map_location=map_location)
|
75 |
+
else:
|
76 |
+
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
|
77 |
+
return self.load_weights(checkpoint["state_dict"], **kwargs)
|
78 |
+
|
79 |
+
def load_weights(self, pretrained_weights, strict=True, **kwargs):
|
80 |
+
filter = kwargs.get('filter', lambda key:key in pretrained_weights)
|
81 |
+
init_weights = self.state_dict()
|
82 |
+
pretrained_weights = {key: value for key, value in pretrained_weights.items() if filter(key)}
|
83 |
+
init_weights.update(pretrained_weights)
|
84 |
+
self.load_state_dict(init_weights, strict=strict)
|
85 |
+
return self
|
86 |
+
|
87 |
+
|
88 |
+
|
89 |
+
|
90 |
+
class BasicModel(VeryBasicModel):
|
91 |
+
def __init__(self,
|
92 |
+
optimizer=torch.optim.AdamW,
|
93 |
+
optimizer_kwargs={'lr':1e-3, 'weight_decay':1e-2},
|
94 |
+
lr_scheduler= None,
|
95 |
+
lr_scheduler_kwargs={},
|
96 |
+
):
|
97 |
+
super().__init__()
|
98 |
+
self.save_hyperparameters()
|
99 |
+
self.optimizer = optimizer
|
100 |
+
self.optimizer_kwargs = optimizer_kwargs
|
101 |
+
self.lr_scheduler = lr_scheduler
|
102 |
+
self.lr_scheduler_kwargs = lr_scheduler_kwargs
|
103 |
+
|
104 |
+
def configure_optimizers(self):
|
105 |
+
optimizer = self.optimizer(self.parameters(), **self.optimizer_kwargs)
|
106 |
+
if self.lr_scheduler is not None:
|
107 |
+
lr_scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)
|
108 |
+
return [optimizer], [lr_scheduler]
|
109 |
+
else:
|
110 |
+
return [optimizer]
|
111 |
+
|
112 |
+
|
113 |
+
|
114 |
+
|
medical_diffusion/models/noise_schedulers/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .scheduler_base import BasicNoiseScheduler
|
2 |
+
from .gaussian_scheduler import GaussianNoiseScheduler
|
medical_diffusion/models/noise_schedulers/gaussian_scheduler.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
from medical_diffusion.models.noise_schedulers import BasicNoiseScheduler
|
7 |
+
|
8 |
+
class GaussianNoiseScheduler(BasicNoiseScheduler):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
timesteps=1000,
|
12 |
+
T = None,
|
13 |
+
schedule_strategy='cosine',
|
14 |
+
beta_start = 0.0001, # default 1e-4, stable-diffusion ~ 1e-3
|
15 |
+
beta_end = 0.02,
|
16 |
+
betas = None,
|
17 |
+
):
|
18 |
+
super().__init__(timesteps, T)
|
19 |
+
|
20 |
+
self.schedule_strategy = schedule_strategy
|
21 |
+
|
22 |
+
if betas is not None:
|
23 |
+
betas = torch.as_tensor(betas, dtype = torch.float64)
|
24 |
+
elif schedule_strategy == "linear":
|
25 |
+
betas = torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
|
26 |
+
elif schedule_strategy == "scaled_linear": # proposed as "quadratic" in https://arxiv.org/abs/2006.11239, used in stable-diffusion
|
27 |
+
betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64)**2
|
28 |
+
elif schedule_strategy == "cosine":
|
29 |
+
s = 0.008
|
30 |
+
x = torch.linspace(0, timesteps, timesteps + 1, dtype = torch.float64) # [0, T]
|
31 |
+
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
|
32 |
+
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
33 |
+
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
34 |
+
betas = torch.clip(betas, 0, 0.999)
|
35 |
+
else:
|
36 |
+
raise NotImplementedError(f"{schedule_strategy} does is not implemented for {self.__class__}")
|
37 |
+
|
38 |
+
|
39 |
+
alphas = 1-betas
|
40 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
41 |
+
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
|
42 |
+
|
43 |
+
|
44 |
+
register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
|
45 |
+
|
46 |
+
register_buffer('betas', betas) # (0 , 1)
|
47 |
+
|
48 |
+
register_buffer('alphas', alphas) # (1 , 0)
|
49 |
+
register_buffer('alphas_cumprod', alphas_cumprod)
|
50 |
+
register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
|
51 |
+
register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
|
52 |
+
register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
|
53 |
+
register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
|
54 |
+
register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
|
55 |
+
|
56 |
+
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
57 |
+
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
58 |
+
register_buffer('posterior_variance', betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod))
|
59 |
+
|
60 |
+
|
61 |
+
def estimate_x_t(self, x_0, t, x_T=None):
|
62 |
+
# NOTE: t == 0 means diffused for 1 step (https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils.py#L108)
|
63 |
+
# NOTE: t == 0 means not diffused for cold-diffusion (in contradiction to the above comment) https://github.com/arpitbansal297/Cold-Diffusion-Models/blob/c828140b7047ca22f995b99fbcda360bc30fc25d/denoising-diffusion-pytorch/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L361
|
64 |
+
x_T = self.x_final(x_0) if x_T is None else x_T
|
65 |
+
# ndim = x_0.ndim
|
66 |
+
# x_t = (self.extract(self.sqrt_alphas_cumprod, t, ndim)*x_0 +
|
67 |
+
# self.extract(self.sqrt_one_minus_alphas_cumprod, t, ndim)*x_T)
|
68 |
+
def clipper(b):
|
69 |
+
tb = t[b]
|
70 |
+
if tb<0:
|
71 |
+
return x_0[b]
|
72 |
+
elif tb>=self.T:
|
73 |
+
return x_T[b]
|
74 |
+
else:
|
75 |
+
return self.sqrt_alphas_cumprod[tb]*x_0[b]+self.sqrt_one_minus_alphas_cumprod[tb]*x_T[b]
|
76 |
+
x_t = torch.stack([clipper(b) for b in range(t.shape[0])])
|
77 |
+
return x_t
|
78 |
+
|
79 |
+
|
80 |
+
def estimate_x_t_prior_from_x_T(self, x_t, t, x_T, use_log=True, clip_x0=True, var_scale=0, cold_diffusion=False):
|
81 |
+
x_0 = self.estimate_x_0(x_t, x_T, t, clip_x0)
|
82 |
+
return self.estimate_x_t_prior_from_x_0(x_t, t, x_0, use_log, clip_x0, var_scale, cold_diffusion)
|
83 |
+
|
84 |
+
|
85 |
+
def estimate_x_t_prior_from_x_0(self, x_t, t, x_0, use_log=True, clip_x0=True, var_scale=0, cold_diffusion=False):
|
86 |
+
x_0 = self._clip_x_0(x_0) if clip_x0 else x_0
|
87 |
+
|
88 |
+
if cold_diffusion: # see https://arxiv.org/abs/2208.09392
|
89 |
+
x_T_est = self.estimate_x_T(x_t, x_0, t) # or use x_T estimated by UNet if available?
|
90 |
+
x_t_est = self.estimate_x_t(x_0, t, x_T=x_T_est)
|
91 |
+
x_t_prior = self.estimate_x_t(x_0, t-1, x_T=x_T_est)
|
92 |
+
noise_t = x_t_est-x_t_prior
|
93 |
+
x_t_prior = x_t-noise_t
|
94 |
+
else:
|
95 |
+
mean = self.estimate_mean_t(x_t, x_0, t)
|
96 |
+
variance = self.estimate_variance_t(t, x_t.ndim, use_log, var_scale)
|
97 |
+
std = torch.exp(0.5*variance) if use_log else torch.sqrt(variance)
|
98 |
+
std[t==0] = 0.0
|
99 |
+
x_T = self.x_final(x_t)
|
100 |
+
x_t_prior = mean+std*x_T
|
101 |
+
return x_t_prior, x_0
|
102 |
+
|
103 |
+
|
104 |
+
def estimate_mean_t(self, x_t, x_0, t):
|
105 |
+
ndim = x_t.ndim
|
106 |
+
return (self.extract(self.posterior_mean_coef1, t, ndim)*x_0+
|
107 |
+
self.extract(self.posterior_mean_coef2, t, ndim)*x_t)
|
108 |
+
|
109 |
+
|
110 |
+
def estimate_variance_t(self, t, ndim, log=True, var_scale=0, eps=1e-20):
|
111 |
+
min_variance = self.extract(self.posterior_variance, t, ndim)
|
112 |
+
max_variance = self.extract(self.betas, t, ndim)
|
113 |
+
if log:
|
114 |
+
min_variance = torch.log(min_variance.clamp(min=eps))
|
115 |
+
max_variance = torch.log(max_variance.clamp(min=eps))
|
116 |
+
return var_scale * max_variance + (1 - var_scale) * min_variance
|
117 |
+
|
118 |
+
|
119 |
+
def estimate_x_0(self, x_t, x_T, t, clip_x0=True):
|
120 |
+
ndim = x_t.ndim
|
121 |
+
x_0 = (self.extract(self.sqrt_recip_alphas_cumprod, t, ndim)*x_t -
|
122 |
+
self.extract(self.sqrt_recipm1_alphas_cumprod, t, ndim)*x_T)
|
123 |
+
x_0 = self._clip_x_0(x_0) if clip_x0 else x_0
|
124 |
+
return x_0
|
125 |
+
|
126 |
+
|
127 |
+
def estimate_x_T(self, x_t, x_0, t, clip_x0=True):
|
128 |
+
ndim = x_t.ndim
|
129 |
+
x_0 = self._clip_x_0(x_0) if clip_x0 else x_0
|
130 |
+
return ((self.extract(self.sqrt_recip_alphas_cumprod, t, ndim)*x_t - x_0)/
|
131 |
+
self.extract(self.sqrt_recipm1_alphas_cumprod, t, ndim))
|
132 |
+
|
133 |
+
|
134 |
+
@classmethod
|
135 |
+
def x_final(cls, x):
|
136 |
+
return torch.randn_like(x)
|
137 |
+
|
138 |
+
@classmethod
|
139 |
+
def _clip_x_0(cls, x_0):
|
140 |
+
# See "static/dynamic thresholding" in Imagen https://arxiv.org/abs/2205.11487
|
141 |
+
|
142 |
+
# "static thresholding"
|
143 |
+
m = 1 # Set this to about 4*sigma = 4 if latent diffusion is used
|
144 |
+
x_0 = x_0.clamp(-m, m)
|
145 |
+
|
146 |
+
# "dynamic thresholding"
|
147 |
+
# r = torch.stack([torch.quantile(torch.abs(x_0_b), 0.997) for x_0_b in x_0])
|
148 |
+
# r = torch.maximum(r, torch.full_like(r,m))
|
149 |
+
# x_0 = torch.stack([x_0_b.clamp(-r_b, r_b)/r_b*m for x_0_b, r_b in zip(x_0, r) ] )
|
150 |
+
|
151 |
+
return x_0
|
152 |
+
|
153 |
+
|
154 |
+
|
medical_diffusion/models/noise_schedulers/scheduler_base.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
|
7 |
+
class BasicNoiseScheduler(nn.Module):
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
timesteps=1000,
|
11 |
+
T=None,
|
12 |
+
):
|
13 |
+
super().__init__()
|
14 |
+
self.timesteps = timesteps
|
15 |
+
self.T = timesteps if T is None else T
|
16 |
+
|
17 |
+
self.register_buffer('timesteps_array', torch.linspace(0, self.T-1, self.timesteps, dtype=torch.long)) # NOTE: End is inclusive therefore use -1 to get [0, T-1]
|
18 |
+
|
19 |
+
def __len__(self):
|
20 |
+
return len(self.timesteps)
|
21 |
+
|
22 |
+
def sample(self, x_0):
|
23 |
+
"""Randomly sample t from [0,T] and return x_t and x_T based on x_0"""
|
24 |
+
t = torch.randint(0, self.T, (x_0.shape[0],), dtype=torch.long, device=x_0.device) # NOTE: High is exclusive, therefore [0, T-1]
|
25 |
+
x_T = self.x_final(x_0)
|
26 |
+
return self.estimate_x_t(x_0, t, x_T), x_T, t
|
27 |
+
|
28 |
+
def estimate_x_t_prior_from_x_T(self, x_T, t, **kwargs):
|
29 |
+
raise NotImplemented
|
30 |
+
|
31 |
+
def estimate_x_t_prior_from_x_0(self, x_0, t, **kwargs):
|
32 |
+
raise NotImplemented
|
33 |
+
|
34 |
+
def estimate_x_t(self, x_0, t, x_T=None, **kwargs):
|
35 |
+
"""Get x_t at time t"""
|
36 |
+
raise NotImplemented
|
37 |
+
|
38 |
+
@classmethod
|
39 |
+
def x_final(cls, x):
|
40 |
+
"""Get noise that should be obtained for t->T """
|
41 |
+
raise NotImplemented
|
42 |
+
|
43 |
+
@staticmethod
|
44 |
+
def extract(x, t, ndim):
|
45 |
+
"""Extract values from x at t and reshape them to n-dim tensor"""
|
46 |
+
return x.gather(0, t).reshape(-1, *((1,)*(ndim-1)))
|
47 |
+
|
48 |
+
|
49 |
+
|
medical_diffusion/models/pipelines/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .diffusion_pipeline import DiffusionPipeline
|
medical_diffusion/models/pipelines/diffusion_pipeline.py
ADDED
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
from pathlib import Path
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torchvision.utils import save_image
|
9 |
+
import streamlit as st
|
10 |
+
|
11 |
+
from medical_diffusion.models import BasicModel
|
12 |
+
from medical_diffusion.utils.train_utils import EMAModel
|
13 |
+
from medical_diffusion.utils.math_utils import kl_gaussians
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
class DiffusionPipeline(BasicModel):
|
21 |
+
def __init__(self,
|
22 |
+
noise_scheduler,
|
23 |
+
noise_estimator,
|
24 |
+
latent_embedder=None,
|
25 |
+
noise_scheduler_kwargs={},
|
26 |
+
noise_estimator_kwargs={},
|
27 |
+
latent_embedder_checkpoint='',
|
28 |
+
estimator_objective = 'x_T', # 'x_T' or 'x_0'
|
29 |
+
estimate_variance=False,
|
30 |
+
use_self_conditioning=False,
|
31 |
+
classifier_free_guidance_dropout=0.5, # Probability to drop condition during training, has only an effect for label-conditioned training
|
32 |
+
num_samples = 4,
|
33 |
+
do_input_centering = True, # Only for training
|
34 |
+
clip_x0=True, # Has only an effect during traing if use_self_conditioning=True, import for inference/sampling
|
35 |
+
use_ema = False,
|
36 |
+
ema_kwargs = {},
|
37 |
+
optimizer=torch.optim.AdamW,
|
38 |
+
optimizer_kwargs={'lr':1e-4}, # stable-diffusion ~ 1e-4
|
39 |
+
lr_scheduler= None, # stable-diffusion - LambdaLR
|
40 |
+
lr_scheduler_kwargs={},
|
41 |
+
loss=torch.nn.L1Loss,
|
42 |
+
loss_kwargs={},
|
43 |
+
sample_every_n_steps = 1000
|
44 |
+
):
|
45 |
+
# self.save_hyperparameters(ignore=['noise_estimator', 'noise_scheduler'])
|
46 |
+
super().__init__(optimizer, optimizer_kwargs, lr_scheduler, lr_scheduler_kwargs)
|
47 |
+
self.loss_fct = loss(**loss_kwargs)
|
48 |
+
self.sample_every_n_steps=sample_every_n_steps
|
49 |
+
|
50 |
+
noise_estimator_kwargs['estimate_variance'] = estimate_variance
|
51 |
+
noise_estimator_kwargs['use_self_conditioning'] = use_self_conditioning
|
52 |
+
|
53 |
+
self.noise_scheduler = noise_scheduler(**noise_scheduler_kwargs)
|
54 |
+
self.noise_estimator = noise_estimator(**noise_estimator_kwargs)
|
55 |
+
|
56 |
+
with torch.no_grad():
|
57 |
+
if latent_embedder is not None:
|
58 |
+
self.latent_embedder = latent_embedder.load_from_checkpoint(latent_embedder_checkpoint)
|
59 |
+
for param in self.latent_embedder.parameters():
|
60 |
+
param.requires_grad = False
|
61 |
+
else:
|
62 |
+
self.latent_embedder = None
|
63 |
+
|
64 |
+
self.estimator_objective = estimator_objective
|
65 |
+
self.use_self_conditioning = use_self_conditioning
|
66 |
+
self.num_samples = num_samples
|
67 |
+
self.classifier_free_guidance_dropout = classifier_free_guidance_dropout
|
68 |
+
self.do_input_centering = do_input_centering
|
69 |
+
self.estimate_variance = estimate_variance
|
70 |
+
self.clip_x0 = clip_x0
|
71 |
+
|
72 |
+
self.use_ema = use_ema
|
73 |
+
if use_ema:
|
74 |
+
self.ema_model = EMAModel(self.noise_estimator, **ema_kwargs)
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int):
|
79 |
+
results = {}
|
80 |
+
x_0 = batch['source']
|
81 |
+
condition = batch.get('target', None)
|
82 |
+
|
83 |
+
# Embed into latent space or normalize
|
84 |
+
if self.latent_embedder is not None:
|
85 |
+
self.latent_embedder.eval()
|
86 |
+
with torch.no_grad():
|
87 |
+
x_0 = self.latent_embedder.encode(x_0)
|
88 |
+
|
89 |
+
if self.do_input_centering:
|
90 |
+
x_0 = 2*x_0-1 # [0, 1] -> [-1, 1]
|
91 |
+
|
92 |
+
# if self.clip_x0:
|
93 |
+
# x_0 = torch.clamp(x_0, -1, 1)
|
94 |
+
|
95 |
+
|
96 |
+
# Sample Noise
|
97 |
+
with torch.no_grad():
|
98 |
+
# Randomly selecting t [0,T-1] and compute x_t (noisy version of x_0 at t)
|
99 |
+
x_t, x_T, t = self.noise_scheduler.sample(x_0)
|
100 |
+
|
101 |
+
# Use EMA Model
|
102 |
+
if self.use_ema and (state != 'train'):
|
103 |
+
noise_estimator = self.ema_model.averaged_model
|
104 |
+
else:
|
105 |
+
noise_estimator = self.noise_estimator
|
106 |
+
|
107 |
+
# Re-estimate x_T or x_0, self-conditioned on previous estimate
|
108 |
+
self_cond = None
|
109 |
+
if self.use_self_conditioning:
|
110 |
+
with torch.no_grad():
|
111 |
+
pred, pred_vertical = noise_estimator(x_t, t, condition, None)
|
112 |
+
if self.estimate_variance:
|
113 |
+
pred, _ = pred.chunk(2, dim = 1) # Seperate actual prediction and variance estimation
|
114 |
+
if self.estimator_objective == "x_T": # self condition on x_0
|
115 |
+
self_cond = self.noise_scheduler.estimate_x_0(x_t, pred, t=t, clip_x0=self.clip_x0)
|
116 |
+
elif self.estimator_objective == "x_0": # self condition on x_T
|
117 |
+
self_cond = self.noise_scheduler.estimate_x_T(x_t, pred, t=t, clip_x0=self.clip_x0)
|
118 |
+
else:
|
119 |
+
raise NotImplementedError(f"Option estimator_target={self.estimator_objective} not supported.")
|
120 |
+
|
121 |
+
# Classifier free guidance
|
122 |
+
if torch.rand(1)<self.classifier_free_guidance_dropout:
|
123 |
+
condition = None
|
124 |
+
|
125 |
+
# Run Denoise
|
126 |
+
pred, pred_vertical = noise_estimator(x_t, t, condition, self_cond)
|
127 |
+
|
128 |
+
# Separate variance (scale) if it was learned
|
129 |
+
if self.estimate_variance:
|
130 |
+
pred, pred_var = pred.chunk(2, dim = 1) # Separate actual prediction and variance estimation
|
131 |
+
|
132 |
+
# Specify target
|
133 |
+
if self.estimator_objective == "x_T":
|
134 |
+
target = x_T
|
135 |
+
elif self.estimator_objective == "x_0":
|
136 |
+
target = x_0
|
137 |
+
else:
|
138 |
+
raise NotImplementedError(f"Option estimator_target={self.estimator_objective} not supported.")
|
139 |
+
|
140 |
+
|
141 |
+
# ------------------------- Compute Loss ---------------------------
|
142 |
+
interpolation_mode = 'area'
|
143 |
+
loss = 0
|
144 |
+
weights = [1/2**i for i in range(1+len(pred_vertical))] # horizontal (equal) + vertical (reducing with every step down)
|
145 |
+
tot_weight = sum(weights)
|
146 |
+
weights = [w/tot_weight for w in weights]
|
147 |
+
|
148 |
+
# ----------------- MSE/L1, ... ----------------------
|
149 |
+
loss += self.loss_fct(pred, target)*weights[0]
|
150 |
+
|
151 |
+
# ----------------- Variance Loss --------------
|
152 |
+
if self.estimate_variance:
|
153 |
+
# var_scale = var_scale.clamp(-1, 1) # Should not be necessary
|
154 |
+
var_scale = (pred_var+1)/2 # Assumed to be in [-1, 1] -> [0, 1]
|
155 |
+
pred_logvar = self.noise_scheduler.estimate_variance_t(t, x_t.ndim, log=True, var_scale=var_scale)
|
156 |
+
# pred_logvar = pred_var # If variance is estimated directly
|
157 |
+
|
158 |
+
if self.estimator_objective == 'x_T':
|
159 |
+
pred_x_0 = self.noise_scheduler.estimate_x_0(x_t, x_T, t, clip_x0=self.clip_x0)
|
160 |
+
elif self.estimator_objective == "x_0":
|
161 |
+
pred_x_0 = pred
|
162 |
+
else:
|
163 |
+
raise NotImplementedError()
|
164 |
+
|
165 |
+
with torch.no_grad():
|
166 |
+
pred_mean = self.noise_scheduler.estimate_mean_t(x_t, pred_x_0, t)
|
167 |
+
true_mean = self.noise_scheduler.estimate_mean_t(x_t, x_0, t)
|
168 |
+
true_logvar = self.noise_scheduler.estimate_variance_t(t, x_t.ndim, log=True, var_scale=0)
|
169 |
+
|
170 |
+
kl_loss = torch.mean(kl_gaussians(true_mean, true_logvar, pred_mean, pred_logvar), dim=list(range(1, x_0.ndim)))
|
171 |
+
nnl_loss = torch.mean(F.gaussian_nll_loss(pred_x_0, x_0, torch.exp(pred_logvar), reduction='none'), dim=list(range(1, x_0.ndim)))
|
172 |
+
var_loss = torch.mean(torch.where(t == 0, nnl_loss, kl_loss))
|
173 |
+
loss += var_loss
|
174 |
+
|
175 |
+
results['variance_scale'] = torch.mean(var_scale)
|
176 |
+
results['variance_loss'] = var_loss
|
177 |
+
|
178 |
+
|
179 |
+
# ----------------------------- Deep Supervision -------------------------
|
180 |
+
for i, pred_i in enumerate(pred_vertical):
|
181 |
+
target_i = F.interpolate(target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None)
|
182 |
+
loss += self.loss_fct(pred_i, target_i)*weights[i+1]
|
183 |
+
results['loss'] = loss
|
184 |
+
|
185 |
+
|
186 |
+
|
187 |
+
# --------------------- Compute Metrics -------------------------------
|
188 |
+
with torch.no_grad():
|
189 |
+
results['L2'] = F.mse_loss(pred, target)
|
190 |
+
results['L1'] = F.l1_loss(pred, target)
|
191 |
+
# results['SSIM'] = SSIMMetric(data_range=pred.max()-pred.min(), spatial_dims=source.ndim-2)(pred, target)
|
192 |
+
|
193 |
+
# for i, pred_i in enumerate(pred_vertical):
|
194 |
+
# target_i = F.interpolate(target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None)
|
195 |
+
# results[f'L1_{i}'] = F.l1_loss(pred_i, target_i).detach()
|
196 |
+
|
197 |
+
|
198 |
+
|
199 |
+
# ----------------- Log Scalars ----------------------
|
200 |
+
for metric_name, metric_val in results.items():
|
201 |
+
self.log(f"{state}/{metric_name}", metric_val, batch_size=x_0.shape[0], on_step=True, on_epoch=True)
|
202 |
+
|
203 |
+
|
204 |
+
#------------------ Log Image -----------------------
|
205 |
+
if self.global_step != 0 and self.global_step % self.sample_every_n_steps == 0:
|
206 |
+
dataformats = 'NHWC' if x_0.ndim == 5 else 'HWC'
|
207 |
+
def norm(x):
|
208 |
+
return (x-x.min())/(x.max()-x.min())
|
209 |
+
|
210 |
+
sample_cond = condition[0:self.num_samples] if condition is not None else None
|
211 |
+
sample_img = self.sample(num_samples=self.num_samples, img_size=x_0.shape[1:], condition=sample_cond).detach()
|
212 |
+
|
213 |
+
log_step = self.global_step // self.sample_every_n_steps
|
214 |
+
# self.logger.experiment.add_images("predict_img", norm(torch.moveaxis(pred[0,-1:], 0,-1)), global_step=self.current_epoch, dataformats=dataformats)
|
215 |
+
# self.logger.experiment.add_images("target_img", norm(torch.moveaxis(target[0,-1:], 0,-1)), global_step=self.current_epoch, dataformats=dataformats)
|
216 |
+
|
217 |
+
# self.logger.experiment.add_images("source_img", norm(torch.moveaxis(x_0[0,-1:], 0,-1)), global_step=log_step, dataformats=dataformats)
|
218 |
+
# self.logger.experiment.add_images("sample_img", norm(torch.moveaxis(sample_img[0,-1:], 0,-1)), global_step=log_step, dataformats=dataformats)
|
219 |
+
|
220 |
+
path_out = Path(self.logger.log_dir)/'images'
|
221 |
+
path_out.mkdir(parents=True, exist_ok=True)
|
222 |
+
# for 3D images use depth as batch :[D, C, H, W], never show more than 32 images
|
223 |
+
def depth2batch(image):
|
224 |
+
return (image if image.ndim<5 else torch.swapaxes(image[0], 0, 1))
|
225 |
+
images = depth2batch(sample_img)[:32]
|
226 |
+
save_image(images, path_out/f'sample_{log_step}.png', normalize=True)
|
227 |
+
|
228 |
+
|
229 |
+
return loss
|
230 |
+
|
231 |
+
|
232 |
+
def forward(self, x_t, t, condition=None, self_cond=None, guidance_scale=1.0, cold_diffusion=False, un_cond=None):
|
233 |
+
# Note: x_t expected to be in range ~ [-1, 1]
|
234 |
+
if self.use_ema:
|
235 |
+
noise_estimator = self.ema_model.averaged_model
|
236 |
+
else:
|
237 |
+
noise_estimator = self.noise_estimator
|
238 |
+
|
239 |
+
# Concatenate inputs for guided and unguided diffusion as proposed by classifier-free-guidance
|
240 |
+
if (condition is not None) and (guidance_scale != 1.0):
|
241 |
+
# Model prediction
|
242 |
+
pred_uncond, _ = noise_estimator(x_t, t, condition=un_cond, self_cond=self_cond)
|
243 |
+
pred_cond, _ = noise_estimator(x_t, t, condition=condition, self_cond=self_cond)
|
244 |
+
pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
|
245 |
+
|
246 |
+
if self.estimate_variance:
|
247 |
+
pred_uncond, pred_var_uncond = pred_uncond.chunk(2, dim = 1)
|
248 |
+
pred_cond, pred_var_cond = pred_cond.chunk(2, dim = 1)
|
249 |
+
pred_var = pred_var_uncond + guidance_scale * (pred_var_cond - pred_var_uncond)
|
250 |
+
else:
|
251 |
+
pred, _ = noise_estimator(x_t, t, condition=condition, self_cond=self_cond)
|
252 |
+
if self.estimate_variance:
|
253 |
+
pred, pred_var = pred.chunk(2, dim = 1)
|
254 |
+
|
255 |
+
if self.estimate_variance:
|
256 |
+
pred_var_scale = pred_var/2+0.5 # [-1, 1] -> [0, 1]
|
257 |
+
pred_var_value = pred_var
|
258 |
+
else:
|
259 |
+
pred_var_scale = 0
|
260 |
+
pred_var_value = None
|
261 |
+
|
262 |
+
# pred_var_scale = pred_var_scale.clamp(0, 1)
|
263 |
+
|
264 |
+
if self.estimator_objective == 'x_0':
|
265 |
+
x_t_prior, x_0 = self.noise_scheduler.estimate_x_t_prior_from_x_0(x_t, t, pred, clip_x0=self.clip_x0, var_scale=pred_var_scale, cold_diffusion=cold_diffusion)
|
266 |
+
x_T = self.noise_scheduler.estimate_x_T(x_t, x_0=pred, t=t, clip_x0=self.clip_x0)
|
267 |
+
self_cond = x_T
|
268 |
+
elif self.estimator_objective == 'x_T':
|
269 |
+
x_t_prior, x_0 = self.noise_scheduler.estimate_x_t_prior_from_x_T(x_t, t, pred, clip_x0=self.clip_x0, var_scale=pred_var_scale, cold_diffusion=cold_diffusion)
|
270 |
+
x_T = pred
|
271 |
+
self_cond = x_0
|
272 |
+
else:
|
273 |
+
raise ValueError("Unknown Objective")
|
274 |
+
|
275 |
+
return x_t_prior, x_0, x_T, self_cond
|
276 |
+
|
277 |
+
|
278 |
+
@torch.no_grad()
|
279 |
+
def denoise(self, x_t, steps=None, condition=None, use_ddim=True, **kwargs):
|
280 |
+
self_cond = None
|
281 |
+
|
282 |
+
# ---------- run denoise loop ---------------
|
283 |
+
if use_ddim:
|
284 |
+
steps = self.noise_scheduler.timesteps if steps is None else steps
|
285 |
+
timesteps_array = torch.linspace(0, self.noise_scheduler.T-1, steps, dtype=torch.long, device=x_t.device) # [0, 1, 2, ..., T-1] if steps = T
|
286 |
+
else:
|
287 |
+
timesteps_array = self.noise_scheduler.timesteps_array[slice(0, steps)] # [0, ...,T-1] (target time not time of x_t)
|
288 |
+
|
289 |
+
st_prog_bar = st.progress(0)
|
290 |
+
for i, t in tqdm(enumerate(reversed(timesteps_array))):
|
291 |
+
st_prog_bar.progress((i+1)/len(timesteps_array))
|
292 |
+
|
293 |
+
# UNet prediction
|
294 |
+
x_t, x_0, x_T, self_cond = self(x_t, t.expand(x_t.shape[0]), condition, self_cond=self_cond, **kwargs)
|
295 |
+
self_cond = self_cond if self.use_self_conditioning else None
|
296 |
+
|
297 |
+
if use_ddim and (steps-i-1>0):
|
298 |
+
t_next = timesteps_array[steps-i-2]
|
299 |
+
alpha = self.noise_scheduler.alphas_cumprod[t]
|
300 |
+
alpha_next = self.noise_scheduler.alphas_cumprod[t_next]
|
301 |
+
sigma = kwargs.get('eta', 1) * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
302 |
+
c = (1 - alpha_next - sigma ** 2).sqrt()
|
303 |
+
noise = torch.randn_like(x_t)
|
304 |
+
x_t = x_0 * alpha_next.sqrt() + c * x_T + sigma * noise
|
305 |
+
|
306 |
+
# ------ Eventually decode from latent space into image space--------
|
307 |
+
if self.latent_embedder is not None:
|
308 |
+
x_t = self.latent_embedder.decode(x_t)
|
309 |
+
|
310 |
+
return x_t # Should be x_0 in final step (t=0)
|
311 |
+
|
312 |
+
@torch.no_grad()
|
313 |
+
def sample(self, num_samples, img_size, condition=None, **kwargs):
|
314 |
+
template = torch.zeros((num_samples, *img_size), device=self.device)
|
315 |
+
x_T = self.noise_scheduler.x_final(template)
|
316 |
+
x_0 = self.denoise(x_T, condition=condition, **kwargs)
|
317 |
+
return x_0
|
318 |
+
|
319 |
+
|
320 |
+
@torch.no_grad()
|
321 |
+
def interpolate(self, img1, img2, i = None, condition=None, lam = 0.5, **kwargs):
|
322 |
+
assert img1.shape == img2.shape, "Image 1 and 2 must have equal shape"
|
323 |
+
|
324 |
+
t = self.noise_scheduler.T-1 if i is None else i
|
325 |
+
t = torch.full(img1.shape[:1], i, device=img1.device)
|
326 |
+
|
327 |
+
img1_t = self.noise_scheduler.estimate_x_t(img1, t=t, clip_x0=self.clip_x0)
|
328 |
+
img2_t = self.noise_scheduler.estimate_x_t(img2, t=t, clip_x0=self.clip_x0)
|
329 |
+
|
330 |
+
img = (1 - lam) * img1_t + lam * img2_t
|
331 |
+
img = self.denoise(img, i, condition, **kwargs)
|
332 |
+
return img
|
333 |
+
|
334 |
+
def on_train_batch_end(self, *args, **kwargs):
|
335 |
+
if self.use_ema:
|
336 |
+
self.ema_model.step(self.noise_estimator)
|
337 |
+
|
338 |
+
def configure_optimizers(self):
|
339 |
+
optimizer = self.optimizer(self.noise_estimator.parameters(), **self.optimizer_kwargs)
|
340 |
+
if self.lr_scheduler is not None:
|
341 |
+
lr_scheduler = {
|
342 |
+
'scheduler': self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs),
|
343 |
+
'interval': 'step',
|
344 |
+
'frequency': 1
|
345 |
+
}
|
346 |
+
return [optimizer], [lr_scheduler]
|
347 |
+
else:
|
348 |
+
return [optimizer]
|
medical_diffusion/models/utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .attention_blocks import *
|
2 |
+
from .conv_blocks import *
|
medical_diffusion/models/utils/attention_blocks.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn.functional as F
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from monai.networks.blocks import TransformerBlock
|
6 |
+
from monai.networks.layers.utils import get_norm_layer, get_dropout_layer
|
7 |
+
from monai.networks.layers.factories import Conv
|
8 |
+
from einops import rearrange
|
9 |
+
|
10 |
+
|
11 |
+
class GEGLU(nn.Module):
|
12 |
+
def __init__(self, in_channels, out_channels):
|
13 |
+
super().__init__()
|
14 |
+
self.norm = nn.LayerNorm(in_channels)
|
15 |
+
self.proj = nn.Linear(in_channels, out_channels*2, bias=True)
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
# x expected to be [B, C, *]
|
19 |
+
# Workaround as layer norm can't currently be applied on arbitrary dimension: https://github.com/pytorch/pytorch/issues/71465
|
20 |
+
b, c, *spatial = x.shape
|
21 |
+
x = x.reshape(b, c, -1).transpose(1, 2) # -> [B, C, N] -> [B, N, C]
|
22 |
+
x = self.norm(x)
|
23 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
24 |
+
x = x * F.gelu(gate)
|
25 |
+
return x.transpose(1, 2).reshape(b, -1, *spatial) # -> [B, C, N] -> [B, C, *]
|
26 |
+
|
27 |
+
def zero_module(module):
|
28 |
+
"""
|
29 |
+
Zero out the parameters of a module and return it.
|
30 |
+
"""
|
31 |
+
for p in module.parameters():
|
32 |
+
p.detach().zero_()
|
33 |
+
return module
|
34 |
+
|
35 |
+
def compute_attention(q,k,v , num_heads, scale):
|
36 |
+
q, k, v = map(lambda t: rearrange(t, 'b (h d) n -> (b h) d n', h=num_heads), (q, k, v)) # [(BxHeads), Dim_per_head, N]
|
37 |
+
|
38 |
+
attn = (torch.einsum('b d i, b d j -> b i j', q*scale, k*scale)).softmax(dim=-1) # Matrix product = [(BxHeads), Dim_per_head, N] * [(BxHeads), Dim_per_head, N'] =[(BxHeads), N, N']
|
39 |
+
|
40 |
+
out = torch.einsum('b i j, b d j-> b d i', attn, v) # Matrix product: [(BxHeads), N, N'] * [(BxHeads), Dim_per_head, N'] = [(BxHeads), Dim_per_head, N]
|
41 |
+
out = rearrange(out, '(b h) d n-> b (h d) n', h=num_heads) # -> [B, (Heads x Dim_per_head), N]
|
42 |
+
|
43 |
+
return out
|
44 |
+
|
45 |
+
|
46 |
+
class LinearTransformerNd(nn.Module):
|
47 |
+
""" Combines multi-head self-attention and multi-head cross-attention.
|
48 |
+
|
49 |
+
Multi-Head Self-Attention:
|
50 |
+
Similar to multi-head self-attention (https://arxiv.org/abs/1706.03762) without Norm+MLP (compare Monai TransformerBlock)
|
51 |
+
Proposed here: https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
52 |
+
Similar to: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/diffusionmodules/openaimodel.py#L278
|
53 |
+
Similar to: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L80
|
54 |
+
Similar to: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/dfbafee555bdae80b55d63a989073836bbfc257e/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L209
|
55 |
+
Similar to: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py#L150
|
56 |
+
|
57 |
+
CrossAttention:
|
58 |
+
Proposed here: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L152
|
59 |
+
|
60 |
+
"""
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
spatial_dims,
|
64 |
+
in_channels,
|
65 |
+
out_channels, # WARNING: if out_channels != in_channels, skip connection is disabled
|
66 |
+
num_heads=8,
|
67 |
+
ch_per_head=32, # rule of thumb: 32 or 64 channels per head (see stable-diffusion / diffusion models beat GANs)
|
68 |
+
norm_name=("GROUP", {'num_groups':32, "affine": True}), # Or use LayerNorm but be aware of https://github.com/pytorch/pytorch/issues/71465 (=> GroupNorm with num_groups=1)
|
69 |
+
dropout=None,
|
70 |
+
emb_dim=None,
|
71 |
+
):
|
72 |
+
super().__init__()
|
73 |
+
hid_channels = num_heads*ch_per_head
|
74 |
+
self.num_heads = num_heads
|
75 |
+
self.scale = ch_per_head**-0.25 # Should be 1/sqrt("queries and keys of dimension"), Note: additional sqrt needed as it follows OpenAI: (q * scale) * (k * scale) instead of (q *k) * scale
|
76 |
+
|
77 |
+
self.norm_x = get_norm_layer(norm_name, spatial_dims=spatial_dims, channels=in_channels)
|
78 |
+
emb_dim = in_channels if emb_dim is None else emb_dim
|
79 |
+
|
80 |
+
Convolution = Conv["conv", spatial_dims]
|
81 |
+
self.to_q = Convolution(in_channels, hid_channels, 1)
|
82 |
+
self.to_k = Convolution(emb_dim, hid_channels, 1)
|
83 |
+
self.to_v = Convolution(emb_dim, hid_channels, 1)
|
84 |
+
|
85 |
+
self.to_out = nn.Sequential(
|
86 |
+
zero_module(Convolution(hid_channels, out_channels, 1)),
|
87 |
+
nn.Identity() if dropout is None else get_dropout_layer(name=dropout, dropout_dim=spatial_dims)
|
88 |
+
)
|
89 |
+
|
90 |
+
def forward(self, x, embedding=None):
|
91 |
+
# x expected to be [B, C, *] and embedding is None or [B, C*] or [B, C*, *]
|
92 |
+
# if no embedding is given, cross-attention defaults to self-attention
|
93 |
+
|
94 |
+
# Normalize
|
95 |
+
b, c, *spatial = x.shape
|
96 |
+
x_n = self.norm_x(x)
|
97 |
+
|
98 |
+
# Attention: embedding (cross-attention) or x (self-attention)
|
99 |
+
if embedding is None:
|
100 |
+
embedding = x_n # WARNING: This assumes that emb_dim==in_channels
|
101 |
+
else:
|
102 |
+
if embedding.ndim == 2:
|
103 |
+
embedding = embedding.reshape(*embedding.shape[:2], *[1]*(x.ndim-2)) # [B, C*] -> [B, C*, *]
|
104 |
+
# Why no normalization for embedding here?
|
105 |
+
|
106 |
+
# Convolution
|
107 |
+
q = self.to_q(x_n) # -> [B, (Heads x Dim_per_head), *]
|
108 |
+
k = self.to_k(embedding) # -> [B, (Heads x Dim_per_head), *]
|
109 |
+
v = self.to_v(embedding) # -> [B, (Heads x Dim_per_head), *]
|
110 |
+
|
111 |
+
# Flatten
|
112 |
+
q = q.reshape(b, c, -1) # -> [B, (Heads x Dim_per_head), N]
|
113 |
+
k = k.reshape(*embedding.shape[:2], -1) # -> [B, (Heads x Dim_per_head), N']
|
114 |
+
v = v.reshape(*embedding.shape[:2], -1) # -> [B, (Heads x Dim_per_head), N']
|
115 |
+
|
116 |
+
# Apply attention
|
117 |
+
out = compute_attention(q, k, v, self.num_heads, self.scale)
|
118 |
+
|
119 |
+
out = out.reshape(*out.shape[:2], *spatial) # -> [B, (Heads x Dim_per_head), *]
|
120 |
+
out = self.to_out(out) # -> [B, C', *]
|
121 |
+
|
122 |
+
|
123 |
+
if x.shape == out.shape:
|
124 |
+
out = x + out
|
125 |
+
return out # [B, C', *]
|
126 |
+
|
127 |
+
|
128 |
+
class LinearTransformer(nn.Module):
|
129 |
+
""" See LinearTransformer, however this implementation is fixed to Conv1d/Linear"""
|
130 |
+
def __init__(
|
131 |
+
self,
|
132 |
+
spatial_dims,
|
133 |
+
in_channels,
|
134 |
+
out_channels, # WARNING: if out_channels != in_channels, skip connection is disabled
|
135 |
+
num_heads,
|
136 |
+
ch_per_head=32, # rule of thumb: 32 or 64 channels per head (see stable-diffusion / diffusion models beat GANs)
|
137 |
+
norm_name=("GROUP", {'num_groups':32, "affine": True}),
|
138 |
+
dropout=None,
|
139 |
+
emb_dim=None
|
140 |
+
):
|
141 |
+
super().__init__()
|
142 |
+
hid_channels = num_heads*ch_per_head
|
143 |
+
self.num_heads = num_heads
|
144 |
+
self.scale = ch_per_head**-0.25 # Should be 1/sqrt("queries and keys of dimension"), Note: additional sqrt needed as it follows OpenAI: (q * scale) * (k * scale) instead of (q *k) * scale
|
145 |
+
|
146 |
+
self.norm_x = get_norm_layer(norm_name, spatial_dims=spatial_dims, channels=in_channels)
|
147 |
+
emb_dim = in_channels if emb_dim is None else emb_dim
|
148 |
+
|
149 |
+
# Note: Conv1d and Linear are interchangeable but order of input changes [B, C, N] <-> [B, N, C]
|
150 |
+
self.to_q = nn.Conv1d(in_channels, hid_channels, 1)
|
151 |
+
self.to_k = nn.Conv1d(emb_dim, hid_channels, 1)
|
152 |
+
self.to_v = nn.Conv1d(emb_dim, hid_channels, 1)
|
153 |
+
# self.to_qkv = nn.Conv1d(emb_dim, hid_channels*3, 1)
|
154 |
+
|
155 |
+
self.to_out = nn.Sequential(
|
156 |
+
zero_module(nn.Conv1d(hid_channels, out_channels, 1)),
|
157 |
+
nn.Identity() if dropout is None else get_dropout_layer(name=dropout, dropout_dim=spatial_dims)
|
158 |
+
)
|
159 |
+
|
160 |
+
def forward(self, x, embedding=None):
|
161 |
+
# x expected to be [B, C, *] and embedding is None or [B, C*] or [B, C*, *]
|
162 |
+
# if no embedding is given, cross-attention defaults to self-attention
|
163 |
+
|
164 |
+
# Normalize
|
165 |
+
b, c, *spatial = x.shape
|
166 |
+
x_n = self.norm_x(x)
|
167 |
+
|
168 |
+
# Attention: embedding (cross-attention) or x (self-attention)
|
169 |
+
if embedding is None:
|
170 |
+
embedding = x_n # WARNING: This assumes that emb_dim==in_channels
|
171 |
+
else:
|
172 |
+
if embedding.ndim == 2:
|
173 |
+
embedding = embedding.reshape(*embedding.shape[:2], *[1]*(x.ndim-2)) # [B, C*] -> [B, C*, *]
|
174 |
+
# Why no normalization for embedding here?
|
175 |
+
|
176 |
+
# Flatten
|
177 |
+
x_n = x_n.reshape(b, c, -1) # [B, C, *] -> [B, C, N]
|
178 |
+
embedding = embedding.reshape(*embedding.shape[:2], -1) # [B, C*, *] -> [B, C*, N']
|
179 |
+
|
180 |
+
# Convolution
|
181 |
+
q = self.to_q(x_n) # -> [B, (Heads x Dim_per_head), N]
|
182 |
+
k = self.to_k(embedding) # -> [B, (Heads x Dim_per_head), N']
|
183 |
+
v = self.to_v(embedding) # -> [B, (Heads x Dim_per_head), N']
|
184 |
+
# qkv = self.to_qkv(x_n)
|
185 |
+
# q,k,v = qkv.split(qkv.shape[1]//3, dim=1)
|
186 |
+
|
187 |
+
# Apply attention
|
188 |
+
out = compute_attention(q, k, v, self.num_heads, self.scale)
|
189 |
+
|
190 |
+
out = self.to_out(out) # -> [B, C', N]
|
191 |
+
out = out.reshape(*out.shape[:2], *spatial) # -> [B, C', *]
|
192 |
+
|
193 |
+
if x.shape == out.shape:
|
194 |
+
out = x + out
|
195 |
+
return out # [B, C', *]
|
196 |
+
|
197 |
+
|
198 |
+
|
199 |
+
|
200 |
+
class BasicTransformerBlock(nn.Module):
|
201 |
+
def __init__(
|
202 |
+
self,
|
203 |
+
spatial_dims,
|
204 |
+
in_channels,
|
205 |
+
out_channels, # WARNING: if out_channels != in_channels, skip connection is disabled
|
206 |
+
num_heads,
|
207 |
+
ch_per_head=32,
|
208 |
+
norm_name=("GROUP", {'num_groups':32, "affine": True}),
|
209 |
+
dropout=None,
|
210 |
+
emb_dim=None
|
211 |
+
):
|
212 |
+
super().__init__()
|
213 |
+
self.self_atn = LinearTransformer(spatial_dims, in_channels, in_channels, num_heads, ch_per_head, norm_name, dropout, None)
|
214 |
+
if emb_dim is not None:
|
215 |
+
self.cros_atn = LinearTransformer(spatial_dims, in_channels, in_channels, num_heads, ch_per_head, norm_name, dropout, emb_dim)
|
216 |
+
self.proj_out = nn.Sequential(
|
217 |
+
GEGLU(in_channels, in_channels*4),
|
218 |
+
nn.Identity() if dropout is None else get_dropout_layer(name=dropout, dropout_dim=spatial_dims),
|
219 |
+
Conv["conv", spatial_dims](in_channels*4, out_channels, 1, bias=True)
|
220 |
+
)
|
221 |
+
|
222 |
+
|
223 |
+
def forward(self, x, embedding=None):
|
224 |
+
# x expected to be [B, C, *] and embedding is None or [B, C*] or [B, C*, *]
|
225 |
+
x = self.self_atn(x)
|
226 |
+
if embedding is not None:
|
227 |
+
x = self.cros_atn(x, embedding=embedding)
|
228 |
+
out = self.proj_out(x)
|
229 |
+
if out.shape[1] == x.shape[1]:
|
230 |
+
return out + x
|
231 |
+
return x
|
232 |
+
|
233 |
+
class SpatialTransformer(nn.Module):
|
234 |
+
""" Proposed here: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L218
|
235 |
+
Unrelated to: https://arxiv.org/abs/1506.02025
|
236 |
+
"""
|
237 |
+
def __init__(
|
238 |
+
self,
|
239 |
+
spatial_dims,
|
240 |
+
in_channels,
|
241 |
+
out_channels, # WARNING: if out_channels != in_channels, skip connection is disabled
|
242 |
+
num_heads,
|
243 |
+
ch_per_head=32, # rule of thumb: 32 or 64 channels per head (see stable-diffusion / diffusion models beat GANs)
|
244 |
+
norm_name = ("GROUP", {'num_groups':32, "affine": True}),
|
245 |
+
dropout=None,
|
246 |
+
emb_dim=None,
|
247 |
+
depth=1
|
248 |
+
):
|
249 |
+
super().__init__()
|
250 |
+
self.in_channels = in_channels
|
251 |
+
self.norm = get_norm_layer(norm_name, spatial_dims=spatial_dims, channels=in_channels)
|
252 |
+
conv_class = Conv["conv", spatial_dims]
|
253 |
+
hid_channels = num_heads*ch_per_head
|
254 |
+
|
255 |
+
self.proj_in = conv_class(
|
256 |
+
in_channels,
|
257 |
+
hid_channels,
|
258 |
+
kernel_size=1,
|
259 |
+
stride=1,
|
260 |
+
padding=0,
|
261 |
+
)
|
262 |
+
|
263 |
+
self.transformer_blocks = nn.ModuleList([
|
264 |
+
BasicTransformerBlock(spatial_dims, hid_channels, hid_channels, num_heads, ch_per_head, norm_name, dropout=dropout, emb_dim=emb_dim)
|
265 |
+
for _ in range(depth)]
|
266 |
+
)
|
267 |
+
|
268 |
+
self.proj_out = conv_class( # Note: zero_module is used in original code
|
269 |
+
hid_channels,
|
270 |
+
out_channels,
|
271 |
+
kernel_size=1,
|
272 |
+
stride=1,
|
273 |
+
padding=0,
|
274 |
+
)
|
275 |
+
|
276 |
+
def forward(self, x, embedding=None):
|
277 |
+
# x expected to be [B, C, *] and embedding is None or [B, C*] or [B, C*, *]
|
278 |
+
# Note: if no embedding is given, cross-attention is disabled
|
279 |
+
h = self.norm(x)
|
280 |
+
h = self.proj_in(h)
|
281 |
+
|
282 |
+
for block in self.transformer_blocks:
|
283 |
+
h = block(h, embedding=embedding)
|
284 |
+
|
285 |
+
h = self.proj_out(h) # -> [B, C'', *]
|
286 |
+
if h.shape == x.shape:
|
287 |
+
return h + x
|
288 |
+
return h
|
289 |
+
|
290 |
+
|
291 |
+
class Attention(nn.Module):
|
292 |
+
def __init__(
|
293 |
+
self,
|
294 |
+
spatial_dims,
|
295 |
+
in_channels,
|
296 |
+
out_channels,
|
297 |
+
num_heads=8,
|
298 |
+
ch_per_head=32, # rule of thumb: 32 or 64 channels per head (see stable-diffusion / diffusion models beat GANs)
|
299 |
+
norm_name = ("GROUP", {'num_groups':32, "affine": True}),
|
300 |
+
dropout=0,
|
301 |
+
emb_dim=None,
|
302 |
+
depth=1,
|
303 |
+
attention_type='linear'
|
304 |
+
) -> None:
|
305 |
+
super().__init__()
|
306 |
+
if attention_type == 'spatial':
|
307 |
+
self.attention = SpatialTransformer(
|
308 |
+
spatial_dims=spatial_dims,
|
309 |
+
in_channels=in_channels,
|
310 |
+
out_channels=out_channels,
|
311 |
+
num_heads=num_heads,
|
312 |
+
ch_per_head=ch_per_head,
|
313 |
+
depth=depth,
|
314 |
+
norm_name=norm_name,
|
315 |
+
dropout=dropout,
|
316 |
+
emb_dim=emb_dim
|
317 |
+
)
|
318 |
+
elif attention_type == 'linear':
|
319 |
+
self.attention = LinearTransformer(
|
320 |
+
spatial_dims=spatial_dims,
|
321 |
+
in_channels=in_channels,
|
322 |
+
out_channels=out_channels,
|
323 |
+
num_heads=num_heads,
|
324 |
+
ch_per_head=ch_per_head,
|
325 |
+
norm_name=norm_name,
|
326 |
+
dropout=dropout,
|
327 |
+
emb_dim=emb_dim
|
328 |
+
)
|
329 |
+
|
330 |
+
|
331 |
+
def forward(self, x, emb=None):
|
332 |
+
if hasattr(self, 'attention'):
|
333 |
+
return self.attention(x, emb)
|
334 |
+
else:
|
335 |
+
return x
|
medical_diffusion/models/utils/conv_blocks.py
ADDED
@@ -0,0 +1,528 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Sequence, Tuple, Union, Type
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
|
9 |
+
from monai.networks.blocks.dynunet_block import get_padding, get_output_padding
|
10 |
+
from monai.networks.layers import Pool, Conv
|
11 |
+
from monai.networks.layers.utils import get_act_layer, get_norm_layer, get_dropout_layer
|
12 |
+
from monai.utils.misc import ensure_tuple_rep
|
13 |
+
|
14 |
+
from medical_diffusion.models.utils.attention_blocks import Attention, zero_module
|
15 |
+
|
16 |
+
def save_add(*args):
|
17 |
+
args = [arg for arg in args if arg is not None]
|
18 |
+
return sum(args) if len(args)>0 else None
|
19 |
+
|
20 |
+
|
21 |
+
class SequentialEmb(nn.Sequential):
|
22 |
+
def forward(self, input, emb):
|
23 |
+
for module in self:
|
24 |
+
input = module(input, emb)
|
25 |
+
return input
|
26 |
+
|
27 |
+
|
28 |
+
class BasicDown(nn.Module):
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
spatial_dims,
|
32 |
+
in_channels,
|
33 |
+
out_channels,
|
34 |
+
kernel_size=3,
|
35 |
+
stride=2,
|
36 |
+
learnable_interpolation=True,
|
37 |
+
use_res=False
|
38 |
+
) -> None:
|
39 |
+
super().__init__()
|
40 |
+
|
41 |
+
if learnable_interpolation:
|
42 |
+
Convolution = Conv[Conv.CONV, spatial_dims]
|
43 |
+
self.down_op = Convolution(
|
44 |
+
in_channels,
|
45 |
+
out_channels,
|
46 |
+
kernel_size=kernel_size,
|
47 |
+
stride=stride,
|
48 |
+
padding=get_padding(kernel_size, stride),
|
49 |
+
dilation=1,
|
50 |
+
groups=1,
|
51 |
+
bias=True,
|
52 |
+
)
|
53 |
+
|
54 |
+
if use_res:
|
55 |
+
self.down_skip = nn.PixelUnshuffle(2) # WARNING: Only supports 2D, , out_channels == 4*in_channels
|
56 |
+
|
57 |
+
else:
|
58 |
+
Pooling = Pool['avg', spatial_dims]
|
59 |
+
self.down_op = Pooling(
|
60 |
+
kernel_size=kernel_size,
|
61 |
+
stride=stride,
|
62 |
+
padding=get_padding(kernel_size, stride)
|
63 |
+
)
|
64 |
+
|
65 |
+
|
66 |
+
def forward(self, x, emb=None):
|
67 |
+
y = self.down_op(x)
|
68 |
+
if hasattr(self, 'down_skip'):
|
69 |
+
y = y+self.down_skip(x)
|
70 |
+
return y
|
71 |
+
|
72 |
+
class BasicUp(nn.Module):
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
spatial_dims,
|
76 |
+
in_channels,
|
77 |
+
out_channels,
|
78 |
+
kernel_size=2,
|
79 |
+
stride=2,
|
80 |
+
learnable_interpolation=True,
|
81 |
+
use_res=False,
|
82 |
+
) -> None:
|
83 |
+
super().__init__()
|
84 |
+
self.learnable_interpolation = learnable_interpolation
|
85 |
+
if learnable_interpolation:
|
86 |
+
# TransConvolution = Conv[Conv.CONVTRANS, spatial_dims]
|
87 |
+
# padding = get_padding(kernel_size, stride)
|
88 |
+
# output_padding = get_output_padding(kernel_size, stride, padding)
|
89 |
+
# self.up_op = TransConvolution(
|
90 |
+
# in_channels,
|
91 |
+
# out_channels,
|
92 |
+
# kernel_size=kernel_size,
|
93 |
+
# stride=stride,
|
94 |
+
# padding=padding,
|
95 |
+
# output_padding=output_padding,
|
96 |
+
# groups=1,
|
97 |
+
# bias=True,
|
98 |
+
# dilation=1
|
99 |
+
# )
|
100 |
+
|
101 |
+
self.calc_shape = lambda x: tuple((np.asarray(x)-1)*np.atleast_1d(stride)+np.atleast_1d(kernel_size)
|
102 |
+
-2*np.atleast_1d(get_padding(kernel_size, stride)))
|
103 |
+
Convolution = Conv[Conv.CONV, spatial_dims]
|
104 |
+
self.up_op = Convolution(
|
105 |
+
in_channels,
|
106 |
+
out_channels,
|
107 |
+
kernel_size=3,
|
108 |
+
stride=1,
|
109 |
+
padding=1,
|
110 |
+
dilation=1,
|
111 |
+
groups=1,
|
112 |
+
bias=True,
|
113 |
+
)
|
114 |
+
|
115 |
+
if use_res:
|
116 |
+
self.up_skip = nn.PixelShuffle(2) # WARNING: Only supports 2D, out_channels == in_channels/4
|
117 |
+
else:
|
118 |
+
self.calc_shape = lambda x: tuple((np.asarray(x)-1)*np.atleast_1d(stride)+np.atleast_1d(kernel_size)
|
119 |
+
-2*np.atleast_1d(get_padding(kernel_size, stride)))
|
120 |
+
|
121 |
+
def forward(self, x, emb=None):
|
122 |
+
if self.learnable_interpolation:
|
123 |
+
new_size = self.calc_shape(x.shape[2:])
|
124 |
+
x_res = F.interpolate(x, size=new_size, mode='nearest-exact')
|
125 |
+
y = self.up_op(x_res)
|
126 |
+
if hasattr(self, 'up_skip'):
|
127 |
+
y = y+self.up_skip(x)
|
128 |
+
return y
|
129 |
+
else:
|
130 |
+
new_size = self.calc_shape(x.shape[2:])
|
131 |
+
return F.interpolate(x, size=new_size, mode='nearest-exact')
|
132 |
+
|
133 |
+
|
134 |
+
class BasicBlock(nn.Module):
|
135 |
+
"""
|
136 |
+
A block that consists of Conv-Norm-Drop-Act, similar to blocks.Convolution.
|
137 |
+
|
138 |
+
Args:
|
139 |
+
spatial_dims: number of spatial dimensions.
|
140 |
+
in_channels: number of input channels.
|
141 |
+
out_channels: number of output channels.
|
142 |
+
kernel_size: convolution kernel size.
|
143 |
+
stride: convolution stride.
|
144 |
+
norm_name: feature normalization type and arguments.
|
145 |
+
act_name: activation layer type and arguments.
|
146 |
+
dropout: dropout probability.
|
147 |
+
zero_conv: zero out the parameters of the convolution.
|
148 |
+
"""
|
149 |
+
|
150 |
+
def __init__(
|
151 |
+
self,
|
152 |
+
spatial_dims: int,
|
153 |
+
in_channels: int,
|
154 |
+
out_channels: int,
|
155 |
+
kernel_size: Union[Sequence[int], int],
|
156 |
+
stride: Union[Sequence[int], int]=1,
|
157 |
+
norm_name: Union[Tuple, str, None]=None,
|
158 |
+
act_name: Union[Tuple, str, None] = None,
|
159 |
+
dropout: Optional[Union[Tuple, str, float]] = None,
|
160 |
+
zero_conv: bool = False,
|
161 |
+
):
|
162 |
+
super().__init__()
|
163 |
+
Convolution = Conv[Conv.CONV, spatial_dims]
|
164 |
+
conv = Convolution(
|
165 |
+
in_channels,
|
166 |
+
out_channels,
|
167 |
+
kernel_size=kernel_size,
|
168 |
+
stride=stride,
|
169 |
+
padding=get_padding(kernel_size, stride),
|
170 |
+
dilation=1,
|
171 |
+
groups=1,
|
172 |
+
bias=True,
|
173 |
+
)
|
174 |
+
self.conv = zero_module(conv) if zero_conv else conv
|
175 |
+
|
176 |
+
if norm_name is not None:
|
177 |
+
self.norm = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)
|
178 |
+
if dropout is not None:
|
179 |
+
self.drop = get_dropout_layer(name=dropout, dropout_dim=spatial_dims)
|
180 |
+
if act_name is not None:
|
181 |
+
self.act = get_act_layer(name=act_name)
|
182 |
+
|
183 |
+
|
184 |
+
def forward(self, inp):
|
185 |
+
out = self.conv(inp)
|
186 |
+
if hasattr(self, "norm"):
|
187 |
+
out = self.norm(out)
|
188 |
+
if hasattr(self, 'drop'):
|
189 |
+
out = self.drop(out)
|
190 |
+
if hasattr(self, "act"):
|
191 |
+
out = self.act(out)
|
192 |
+
return out
|
193 |
+
|
194 |
+
class BasicResBlock(nn.Module):
|
195 |
+
"""
|
196 |
+
A block that consists of Conv-Act-Norm + skip.
|
197 |
+
|
198 |
+
Args:
|
199 |
+
spatial_dims: number of spatial dimensions.
|
200 |
+
in_channels: number of input channels.
|
201 |
+
out_channels: number of output channels.
|
202 |
+
kernel_size: convolution kernel size.
|
203 |
+
stride: convolution stride.
|
204 |
+
norm_name: feature normalization type and arguments.
|
205 |
+
act_name: activation layer type and arguments.
|
206 |
+
dropout: dropout probability.
|
207 |
+
zero_conv: zero out the parameters of the convolution.
|
208 |
+
"""
|
209 |
+
def __init__(
|
210 |
+
self,
|
211 |
+
spatial_dims: int,
|
212 |
+
in_channels: int,
|
213 |
+
out_channels: int,
|
214 |
+
kernel_size: Union[Sequence[int], int],
|
215 |
+
stride: Union[Sequence[int], int]=1,
|
216 |
+
norm_name: Union[Tuple, str, None]=None,
|
217 |
+
act_name: Union[Tuple, str, None] = None,
|
218 |
+
dropout: Optional[Union[Tuple, str, float]] = None,
|
219 |
+
zero_conv: bool = False
|
220 |
+
):
|
221 |
+
super().__init__()
|
222 |
+
self.basic_block = BasicBlock(spatial_dims, in_channels, out_channels, kernel_size, stride, norm_name, act_name, dropout, zero_conv)
|
223 |
+
Convolution = Conv[Conv.CONV, spatial_dims]
|
224 |
+
self.conv_res = Convolution(
|
225 |
+
in_channels,
|
226 |
+
out_channels,
|
227 |
+
kernel_size=1,
|
228 |
+
stride=stride,
|
229 |
+
padding=get_padding(1, stride),
|
230 |
+
dilation=1,
|
231 |
+
groups=1,
|
232 |
+
bias=True,
|
233 |
+
) if in_channels != out_channels else nn.Identity()
|
234 |
+
|
235 |
+
|
236 |
+
def forward(self, inp):
|
237 |
+
out = self.basic_block(inp)
|
238 |
+
residual = self.conv_res(inp)
|
239 |
+
out = out+residual
|
240 |
+
return out
|
241 |
+
|
242 |
+
|
243 |
+
|
244 |
+
class UnetBasicBlock(nn.Module):
|
245 |
+
"""
|
246 |
+
A modified version of monai.networks.blocks.UnetBasicBlock with additional embedding
|
247 |
+
|
248 |
+
Args:
|
249 |
+
spatial_dims: number of spatial dimensions.
|
250 |
+
in_channels: number of input channels.
|
251 |
+
out_channels: number of output channels.
|
252 |
+
kernel_size: convolution kernel size.
|
253 |
+
stride: convolution stride.
|
254 |
+
norm_name: feature normalization type and arguments.
|
255 |
+
act_name: activation layer type and arguments.
|
256 |
+
dropout: dropout probability.
|
257 |
+
emb_channels: Number of embedding channels
|
258 |
+
"""
|
259 |
+
|
260 |
+
def __init__(
|
261 |
+
self,
|
262 |
+
spatial_dims: int,
|
263 |
+
in_channels: int,
|
264 |
+
out_channels: int,
|
265 |
+
kernel_size: Union[Sequence[int], int],
|
266 |
+
stride: Union[Sequence[int], int]=1,
|
267 |
+
norm_name: Union[Tuple, str]=None,
|
268 |
+
act_name: Union[Tuple, str]=None,
|
269 |
+
dropout: Optional[Union[Tuple, str, float]] = None,
|
270 |
+
emb_channels: int = None,
|
271 |
+
blocks = 2
|
272 |
+
):
|
273 |
+
super().__init__()
|
274 |
+
self.block_seq = nn.ModuleList([
|
275 |
+
BasicBlock(spatial_dims, in_channels if i==0 else out_channels, out_channels, kernel_size, stride, norm_name, act_name, dropout, i==blocks-1)
|
276 |
+
for i in range(blocks)
|
277 |
+
])
|
278 |
+
|
279 |
+
if emb_channels is not None:
|
280 |
+
self.local_embedder = nn.Sequential(
|
281 |
+
get_act_layer(name=act_name),
|
282 |
+
nn.Linear(emb_channels, out_channels),
|
283 |
+
)
|
284 |
+
|
285 |
+
def forward(self, x, emb=None):
|
286 |
+
# ------------ Embedding ----------
|
287 |
+
if emb is not None:
|
288 |
+
emb = self.local_embedder(emb)
|
289 |
+
b,c, *_ = emb.shape
|
290 |
+
sp_dim = x.ndim-2
|
291 |
+
emb = emb.reshape(b, c, *((1,)*sp_dim) )
|
292 |
+
# scale, shift = emb.chunk(2, dim = 1)
|
293 |
+
# x = x * (scale + 1) + shift
|
294 |
+
# x = x+emb
|
295 |
+
|
296 |
+
# ----------- Convolution ---------
|
297 |
+
n_blocks = len(self.block_seq)
|
298 |
+
for i, block in enumerate(self.block_seq):
|
299 |
+
x = block(x)
|
300 |
+
if (emb is not None) and i<n_blocks:
|
301 |
+
x += emb
|
302 |
+
return x
|
303 |
+
|
304 |
+
|
305 |
+
class UnetResBlock(nn.Module):
|
306 |
+
"""
|
307 |
+
A modified version of monai.networks.blocks.UnetResBlock with additional skip connection and embedding
|
308 |
+
|
309 |
+
Args:
|
310 |
+
spatial_dims: number of spatial dimensions.
|
311 |
+
in_channels: number of input channels.
|
312 |
+
out_channels: number of output channels.
|
313 |
+
kernel_size: convolution kernel size.
|
314 |
+
stride: convolution stride.
|
315 |
+
norm_name: feature normalization type and arguments.
|
316 |
+
act_name: activation layer type and arguments.
|
317 |
+
dropout: dropout probability.
|
318 |
+
emb_channels: Number of embedding channels
|
319 |
+
"""
|
320 |
+
|
321 |
+
def __init__(
|
322 |
+
self,
|
323 |
+
spatial_dims: int,
|
324 |
+
in_channels: int,
|
325 |
+
out_channels: int,
|
326 |
+
kernel_size: Union[Sequence[int], int],
|
327 |
+
stride: Union[Sequence[int], int]=1,
|
328 |
+
norm_name: Union[Tuple, str]=None,
|
329 |
+
act_name: Union[Tuple, str]=None,
|
330 |
+
dropout: Optional[Union[Tuple, str, float]] = None,
|
331 |
+
emb_channels: int = None,
|
332 |
+
blocks = 2
|
333 |
+
):
|
334 |
+
super().__init__()
|
335 |
+
self.block_seq = nn.ModuleList([
|
336 |
+
BasicResBlock(spatial_dims, in_channels if i==0 else out_channels, out_channels, kernel_size, stride, norm_name, act_name, dropout, i==blocks-1)
|
337 |
+
for i in range(blocks)
|
338 |
+
])
|
339 |
+
|
340 |
+
if emb_channels is not None:
|
341 |
+
self.local_embedder = nn.Sequential(
|
342 |
+
get_act_layer(name=act_name),
|
343 |
+
nn.Linear(emb_channels, out_channels),
|
344 |
+
)
|
345 |
+
|
346 |
+
|
347 |
+
def forward(self, x, emb=None):
|
348 |
+
# ------------ Embedding ----------
|
349 |
+
if emb is not None:
|
350 |
+
emb = self.local_embedder(emb)
|
351 |
+
b,c, *_ = emb.shape
|
352 |
+
sp_dim = x.ndim-2
|
353 |
+
emb = emb.reshape(b, c, *((1,)*sp_dim) )
|
354 |
+
# scale, shift = emb.chunk(2, dim = 1)
|
355 |
+
# x = x * (scale + 1) + shift
|
356 |
+
# x = x+emb
|
357 |
+
|
358 |
+
# ----------- Convolution ---------
|
359 |
+
n_blocks = len(self.block_seq)
|
360 |
+
for i, block in enumerate(self.block_seq):
|
361 |
+
x = block(x)
|
362 |
+
if (emb is not None) and i<n_blocks-1:
|
363 |
+
x += emb
|
364 |
+
return x
|
365 |
+
|
366 |
+
|
367 |
+
|
368 |
+
class DownBlock(nn.Module):
|
369 |
+
def __init__(
|
370 |
+
self,
|
371 |
+
spatial_dims: int,
|
372 |
+
in_channels: int,
|
373 |
+
out_channels: int,
|
374 |
+
kernel_size: Union[Sequence[int], int],
|
375 |
+
stride: Union[Sequence[int], int],
|
376 |
+
downsample_kernel_size: Union[Sequence[int], int],
|
377 |
+
norm_name: Union[Tuple, str],
|
378 |
+
act_name: Union[Tuple, str],
|
379 |
+
dropout: Optional[Union[Tuple, str, float]] = None,
|
380 |
+
use_res_block: bool = False,
|
381 |
+
learnable_interpolation: bool = True,
|
382 |
+
use_attention: str = 'none',
|
383 |
+
emb_channels: int = None
|
384 |
+
):
|
385 |
+
super(DownBlock, self).__init__()
|
386 |
+
enable_down = ensure_tuple_rep(stride, spatial_dims) != ensure_tuple_rep(1, spatial_dims)
|
387 |
+
down_out_channels = out_channels if learnable_interpolation and enable_down else in_channels
|
388 |
+
|
389 |
+
# -------------- Down ----------------------
|
390 |
+
self.down_op = BasicDown(
|
391 |
+
spatial_dims,
|
392 |
+
in_channels,
|
393 |
+
out_channels,
|
394 |
+
kernel_size=downsample_kernel_size,
|
395 |
+
stride=stride,
|
396 |
+
learnable_interpolation=learnable_interpolation,
|
397 |
+
use_res=False
|
398 |
+
) if enable_down else nn.Identity()
|
399 |
+
|
400 |
+
|
401 |
+
# ---------------- Attention -------------
|
402 |
+
self.attention = Attention(
|
403 |
+
spatial_dims=spatial_dims,
|
404 |
+
in_channels=down_out_channels,
|
405 |
+
out_channels=down_out_channels,
|
406 |
+
num_heads=8,
|
407 |
+
ch_per_head=down_out_channels//8,
|
408 |
+
depth=1,
|
409 |
+
norm_name=norm_name,
|
410 |
+
dropout=dropout,
|
411 |
+
emb_dim=emb_channels,
|
412 |
+
attention_type=use_attention
|
413 |
+
)
|
414 |
+
|
415 |
+
# -------------- Convolution ----------------------
|
416 |
+
ConvBlock = UnetResBlock if use_res_block else UnetBasicBlock
|
417 |
+
self.conv_block = ConvBlock(
|
418 |
+
spatial_dims,
|
419 |
+
down_out_channels,
|
420 |
+
out_channels,
|
421 |
+
kernel_size=kernel_size,
|
422 |
+
stride=1,
|
423 |
+
dropout=dropout,
|
424 |
+
norm_name=norm_name,
|
425 |
+
act_name=act_name,
|
426 |
+
emb_channels=emb_channels
|
427 |
+
)
|
428 |
+
|
429 |
+
|
430 |
+
def forward(self, x, emb=None):
|
431 |
+
# ----------- Down ---------
|
432 |
+
x = self.down_op(x)
|
433 |
+
|
434 |
+
# ----------- Attention -------------
|
435 |
+
if self.attention is not None:
|
436 |
+
x = self.attention(x, emb)
|
437 |
+
|
438 |
+
# ------------- Convolution --------------
|
439 |
+
x = self.conv_block(x, emb)
|
440 |
+
|
441 |
+
return x
|
442 |
+
|
443 |
+
|
444 |
+
class UpBlock(nn.Module):
|
445 |
+
def __init__(
|
446 |
+
self,
|
447 |
+
spatial_dims,
|
448 |
+
in_channels: int,
|
449 |
+
out_channels: int,
|
450 |
+
kernel_size: Union[Sequence[int], int],
|
451 |
+
stride: Union[Sequence[int], int],
|
452 |
+
upsample_kernel_size: Union[Sequence[int], int],
|
453 |
+
norm_name: Union[Tuple, str],
|
454 |
+
act_name: Union[Tuple, str],
|
455 |
+
dropout: Optional[Union[Tuple, str, float]] = None,
|
456 |
+
use_res_block: bool = False,
|
457 |
+
learnable_interpolation: bool = True,
|
458 |
+
use_attention: str = 'none',
|
459 |
+
emb_channels: int = None,
|
460 |
+
skip_channels: int = 0
|
461 |
+
):
|
462 |
+
super(UpBlock, self).__init__()
|
463 |
+
enable_up = ensure_tuple_rep(stride, spatial_dims) != ensure_tuple_rep(1, spatial_dims)
|
464 |
+
skip_out_channels = out_channels if learnable_interpolation and enable_up else in_channels+skip_channels
|
465 |
+
self.learnable_interpolation = learnable_interpolation
|
466 |
+
|
467 |
+
|
468 |
+
# -------------- Up ----------------------
|
469 |
+
self.up_op = BasicUp(
|
470 |
+
spatial_dims=spatial_dims,
|
471 |
+
in_channels=in_channels,
|
472 |
+
out_channels=out_channels,
|
473 |
+
kernel_size=upsample_kernel_size,
|
474 |
+
stride=stride,
|
475 |
+
learnable_interpolation=learnable_interpolation,
|
476 |
+
use_res=False
|
477 |
+
) if enable_up else nn.Identity()
|
478 |
+
|
479 |
+
# ---------------- Attention -------------
|
480 |
+
self.attention = Attention(
|
481 |
+
spatial_dims=spatial_dims,
|
482 |
+
in_channels=skip_out_channels,
|
483 |
+
out_channels=skip_out_channels,
|
484 |
+
num_heads=8,
|
485 |
+
ch_per_head=skip_out_channels//8,
|
486 |
+
depth=1,
|
487 |
+
norm_name=norm_name,
|
488 |
+
dropout=dropout,
|
489 |
+
emb_dim=emb_channels,
|
490 |
+
attention_type=use_attention
|
491 |
+
)
|
492 |
+
|
493 |
+
|
494 |
+
# -------------- Convolution ----------------------
|
495 |
+
ConvBlock = UnetResBlock if use_res_block else UnetBasicBlock
|
496 |
+
self.conv_block = ConvBlock(
|
497 |
+
spatial_dims,
|
498 |
+
skip_out_channels,
|
499 |
+
out_channels,
|
500 |
+
kernel_size=kernel_size,
|
501 |
+
stride=1,
|
502 |
+
dropout=dropout,
|
503 |
+
norm_name=norm_name,
|
504 |
+
act_name=act_name,
|
505 |
+
emb_channels=emb_channels
|
506 |
+
)
|
507 |
+
|
508 |
+
|
509 |
+
|
510 |
+
def forward(self, x_enc, x_skip=None, emb=None):
|
511 |
+
# ----------- Up -------------
|
512 |
+
x = self.up_op(x_enc)
|
513 |
+
|
514 |
+
# ----------- Skip Connection ------------
|
515 |
+
if x_skip is not None:
|
516 |
+
if self.learnable_interpolation: # Channel of x_enc and x_skip are equal and summation is possible
|
517 |
+
x = x+x_skip
|
518 |
+
else:
|
519 |
+
x = torch.cat((x, x_skip), dim=1)
|
520 |
+
|
521 |
+
# ----------- Attention -------------
|
522 |
+
if self.attention is not None:
|
523 |
+
x = self.attention(x, emb)
|
524 |
+
|
525 |
+
# ----------- Convolution ------------
|
526 |
+
x = self.conv_block(x, emb)
|
527 |
+
|
528 |
+
return x
|
medical_diffusion/utils/math_utils.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
def kl_gaussians(mean1, logvar1, mean2, logvar2):
|
4 |
+
""" Compute the KL divergence between two gaussians."""
|
5 |
+
return 0.5 * (logvar2-logvar1 + torch.exp(logvar1 - logvar2) + torch.pow(mean1 - mean2, 2) * torch.exp(-logvar2)-1.0)
|
6 |
+
|
medical_diffusion/utils/train_utils.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
class EMAModel(nn.Module):
|
6 |
+
# See: https://github.com/huggingface/diffusers/blob/3100bc967084964480628ae61210b7eaa7436f1d/src/diffusers/training_utils.py#L42
|
7 |
+
"""
|
8 |
+
Exponential Moving Average of models weights
|
9 |
+
"""
|
10 |
+
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
model,
|
14 |
+
update_after_step=0,
|
15 |
+
inv_gamma=1.0,
|
16 |
+
power=2 / 3,
|
17 |
+
min_value=0.0,
|
18 |
+
max_value=0.9999,
|
19 |
+
):
|
20 |
+
super().__init__()
|
21 |
+
"""
|
22 |
+
@crowsonkb's notes on EMA Warmup:
|
23 |
+
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
|
24 |
+
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
|
25 |
+
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
|
26 |
+
at 215.4k steps).
|
27 |
+
Args:
|
28 |
+
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
|
29 |
+
power (float): Exponential factor of EMA warmup. Default: 2/3.
|
30 |
+
min_value (float): The minimum EMA decay rate. Default: 0.
|
31 |
+
"""
|
32 |
+
|
33 |
+
self.averaged_model = copy.deepcopy(model).eval()
|
34 |
+
self.averaged_model.requires_grad_(False)
|
35 |
+
|
36 |
+
self.update_after_step = update_after_step
|
37 |
+
self.inv_gamma = inv_gamma
|
38 |
+
self.power = power
|
39 |
+
self.min_value = min_value
|
40 |
+
self.max_value = max_value
|
41 |
+
|
42 |
+
self.averaged_model = self.averaged_model #.to(device=model.device)
|
43 |
+
|
44 |
+
self.decay = 0.0
|
45 |
+
self.optimization_step = 0
|
46 |
+
|
47 |
+
def get_decay(self, optimization_step):
|
48 |
+
"""
|
49 |
+
Compute the decay factor for the exponential moving average.
|
50 |
+
"""
|
51 |
+
step = max(0, optimization_step - self.update_after_step - 1)
|
52 |
+
value = 1 - (1 + step / self.inv_gamma) ** -self.power
|
53 |
+
|
54 |
+
if step <= 0:
|
55 |
+
return 0.0
|
56 |
+
|
57 |
+
return max(self.min_value, min(value, self.max_value))
|
58 |
+
|
59 |
+
@torch.no_grad()
|
60 |
+
def step(self, new_model):
|
61 |
+
ema_state_dict = {}
|
62 |
+
ema_params = self.averaged_model.state_dict()
|
63 |
+
|
64 |
+
self.decay = self.get_decay(self.optimization_step)
|
65 |
+
|
66 |
+
for key, param in new_model.named_parameters():
|
67 |
+
if isinstance(param, dict):
|
68 |
+
continue
|
69 |
+
try:
|
70 |
+
ema_param = ema_params[key]
|
71 |
+
except KeyError:
|
72 |
+
ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
|
73 |
+
ema_params[key] = ema_param
|
74 |
+
|
75 |
+
if not param.requires_grad:
|
76 |
+
ema_params[key].copy_(param.to(dtype=ema_param.dtype).data)
|
77 |
+
ema_param = ema_params[key]
|
78 |
+
else:
|
79 |
+
ema_param.mul_(self.decay)
|
80 |
+
ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
|
81 |
+
|
82 |
+
ema_state_dict[key] = ema_param
|
83 |
+
|
84 |
+
for key, param in new_model.named_buffers():
|
85 |
+
ema_state_dict[key] = param
|
86 |
+
|
87 |
+
self.averaged_model.load_state_dict(ema_state_dict, strict=False)
|
88 |
+
self.optimization_step += 1
|
requirements.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch # pip install torch --extra-index-url https://download.pytorch.org/whl/cu113
|
2 |
+
numpy
|
3 |
+
sklearn
|
4 |
+
pytorch-lightning
|
5 |
+
pytorch_msssim
|
6 |
+
monai
|
7 |
+
torchmetrics
|
8 |
+
torch-fidelity
|
9 |
+
torchio
|
10 |
+
pillow
|
11 |
+
einops
|
12 |
+
torchvision
|
13 |
+
matplotlib
|
14 |
+
pandas
|
15 |
+
lpips
|
16 |
+
|
17 |
+
streamlit
|
scripts/evaluate_images.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import logging
|
3 |
+
from datetime import datetime
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from torch.utils.data.dataloader import DataLoader
|
9 |
+
from torchvision.datasets import ImageFolder
|
10 |
+
from torch.utils.data import TensorDataset, Subset
|
11 |
+
from torchmetrics.image.fid import FrechetInceptionDistance as FID
|
12 |
+
from torchmetrics.image.inception import InceptionScore as IS
|
13 |
+
|
14 |
+
from medical_diffusion.metrics.torchmetrics_pr_recall import ImprovedPrecessionRecall
|
15 |
+
|
16 |
+
|
17 |
+
# ----------------Settings --------------
|
18 |
+
batch_size = 100
|
19 |
+
max_samples = None # set to None for all
|
20 |
+
# path_out = Path.cwd()/'results'/'MSIvsMSS_2'/'metrics'
|
21 |
+
# path_out = Path.cwd()/'results'/'AIROGS'/'metrics'
|
22 |
+
path_out = Path.cwd()/'results'/'CheXpert'/'metrics'
|
23 |
+
path_out.mkdir(parents=True, exist_ok=True)
|
24 |
+
|
25 |
+
|
26 |
+
# ----------------- Logging -----------
|
27 |
+
current_time = datetime.now().strftime("%Y_%m_%d_%H%M%S")
|
28 |
+
logger = logging.getLogger()
|
29 |
+
logging.basicConfig(level=logging.INFO)
|
30 |
+
logger.addHandler(logging.FileHandler(path_out/f'metrics_{current_time}.log', 'w'))
|
31 |
+
|
32 |
+
# -------------- Helpers ---------------------
|
33 |
+
pil2torch = lambda x: torch.as_tensor(np.array(x)).moveaxis(-1, 0) # In contrast to ToTensor(), this will not cast 0-255 to 0-1 and destroy uint8 (required later)
|
34 |
+
|
35 |
+
|
36 |
+
# ---------------- Dataset/Dataloader ----------------
|
37 |
+
# ds_real = ImageFolder('/mnt/hdd/datasets/pathology/kather_msi_mss_2/train', transform=pil2torch)
|
38 |
+
# ds_fake = ImageFolder('/mnt/hdd/datasets/pathology/kather_msi_mss_2/synthetic_data/SYNTH-CRC-10K/', transform=pil2torch)
|
39 |
+
# ds_fake = ImageFolder('/mnt/hdd/datasets/pathology/kather_msi_mss_2/synthetic_data/diffusion2_250', transform=pil2torch)
|
40 |
+
|
41 |
+
# ds_real = ImageFolder('/mnt/hdd/datasets/eye/AIROGS/data_256x256_ref/', transform=pil2torch)
|
42 |
+
# ds_fake = ImageFolder('/mnt/hdd/datasets/eye/AIROGS/data_generated_stylegan3/', transform=pil2torch)
|
43 |
+
# ds_fake = ImageFolder('/mnt/hdd/datasets/eye/AIROGS/data_generated_diffusion', transform=pil2torch)
|
44 |
+
|
45 |
+
ds_real = ImageFolder('/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/reference/', transform=pil2torch)
|
46 |
+
# ds_fake = ImageFolder('/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/generated_progan/', transform=pil2torch)
|
47 |
+
ds_fake = ImageFolder('/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/generated_diffusion3_250/', transform=pil2torch)
|
48 |
+
|
49 |
+
ds_real.samples = ds_real.samples[slice(max_samples)]
|
50 |
+
ds_fake.samples = ds_fake.samples[slice(max_samples)]
|
51 |
+
|
52 |
+
|
53 |
+
# --------- Select specific class ------------
|
54 |
+
# target_class = 'MSIH'
|
55 |
+
# ds_real = Subset(ds_real, [i for i in range(len(ds_real)) if ds_real.samples[i][1] == ds_real.class_to_idx[target_class]])
|
56 |
+
# ds_fake = Subset(ds_fake, [i for i in range(len(ds_fake)) if ds_fake.samples[i][1] == ds_fake.class_to_idx[target_class]])
|
57 |
+
|
58 |
+
# Only for testing metrics against OpenAI implementation
|
59 |
+
# ds_real = TensorDataset(torch.from_numpy(np.load('/home/gustav/Documents/code/guided-diffusion/data/VIRTUAL_imagenet64_labeled.npz')['arr_0']).swapaxes(1,-1))
|
60 |
+
# ds_fake = TensorDataset(torch.from_numpy(np.load('/home/gustav/Documents/code/guided-diffusion/data/biggan_deep_imagenet64.npz')['arr_0']).swapaxes(1,-1))
|
61 |
+
|
62 |
+
|
63 |
+
dm_real = DataLoader(ds_real, batch_size=batch_size, num_workers=8, shuffle=False, drop_last=False)
|
64 |
+
dm_fake = DataLoader(ds_fake, batch_size=batch_size, num_workers=8, shuffle=False, drop_last=False)
|
65 |
+
|
66 |
+
logger.info(f"Samples Real: {len(ds_real)}")
|
67 |
+
logger.info(f"Samples Fake: {len(ds_fake)}")
|
68 |
+
|
69 |
+
# ------------- Init Metrics ----------------------
|
70 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
71 |
+
calc_fid = FID().to(device) # requires uint8
|
72 |
+
# calc_is = IS(splits=1).to(device) # requires uint8, features must be 1008 see https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/evaluations/evaluator.py#L603
|
73 |
+
calc_pr = ImprovedPrecessionRecall(splits_real=1, splits_fake=1).to(device)
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
# --------------- Start Calculation -----------------
|
79 |
+
for real_batch in tqdm(dm_real):
|
80 |
+
imgs_real_batch = real_batch[0].to(device)
|
81 |
+
|
82 |
+
# -------------- FID -------------------
|
83 |
+
calc_fid.update(imgs_real_batch, real=True)
|
84 |
+
|
85 |
+
# ------ Improved Precision/Recall--------
|
86 |
+
calc_pr.update(imgs_real_batch, real=True)
|
87 |
+
|
88 |
+
# torch.save(torch.concat(calc_fid.real_features), 'real_fid.pt')
|
89 |
+
# torch.save(torch.concat(calc_pr.real_features), 'real_ipr.pt')
|
90 |
+
|
91 |
+
|
92 |
+
for fake_batch in tqdm(dm_fake):
|
93 |
+
imgs_fake_batch = fake_batch[0].to(device)
|
94 |
+
|
95 |
+
# -------------- FID -------------------
|
96 |
+
calc_fid.update(imgs_fake_batch, real=False)
|
97 |
+
|
98 |
+
# -------------- IS -------------------
|
99 |
+
# calc_is.update(imgs_fake_batch)
|
100 |
+
|
101 |
+
# ---- Improved Precision/Recall--------
|
102 |
+
calc_pr.update(imgs_fake_batch, real=False)
|
103 |
+
|
104 |
+
# torch.save(torch.concat(calc_fid.fake_features), 'fake_fid.pt')
|
105 |
+
# torch.save(torch.concat(calc_pr.fake_features), 'fake_ipr.pt')
|
106 |
+
|
107 |
+
# --------------- Load features --------------
|
108 |
+
# real_fid = torch.as_tensor(torch.load('real_fid.pt'), device=device)
|
109 |
+
# real_ipr = torch.as_tensor(torch.load('real_ipr.pt'), device=device)
|
110 |
+
# fake_fid = torch.as_tensor(torch.load('fake_fid.pt'), device=device)
|
111 |
+
# fake_ipr = torch.as_tensor(torch.load('fake_ipr.pt'), device=device)
|
112 |
+
|
113 |
+
# calc_fid.real_features = real_fid.chunk(batch_size)
|
114 |
+
# calc_pr.real_features = real_ipr.chunk(batch_size)
|
115 |
+
# calc_fid.fake_features = fake_fid.chunk(batch_size)
|
116 |
+
# calc_pr.fake_features = fake_ipr.chunk(batch_size)
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
# -------------- Summary -------------------
|
121 |
+
fid = calc_fid.compute()
|
122 |
+
logger.info(f"FID Score: {fid}")
|
123 |
+
|
124 |
+
# is_mean, is_std = calc_is.compute()
|
125 |
+
# logger.info(f"IS Score: mean {is_mean} std {is_std}")
|
126 |
+
|
127 |
+
precision, recall = calc_pr.compute()
|
128 |
+
logger.info(f"Precision: {precision}, Recall {recall} ")
|
129 |
+
|
scripts/evaluate_latent_embedder.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import logging
|
3 |
+
from datetime import datetime
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torchvision.transforms.functional as tF
|
9 |
+
from torch.utils.data.dataloader import DataLoader
|
10 |
+
from torchvision.datasets import ImageFolder
|
11 |
+
from torch.utils.data import TensorDataset, Subset
|
12 |
+
|
13 |
+
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS
|
14 |
+
from torchmetrics.functional import multiscale_structural_similarity_index_measure as mmssim
|
15 |
+
|
16 |
+
from medical_diffusion.models.embedders.latent_embedders import VAE
|
17 |
+
|
18 |
+
|
19 |
+
# ----------------Settings --------------
|
20 |
+
batch_size = 100
|
21 |
+
max_samples = None # set to None for all
|
22 |
+
target_class = None # None for no specific class
|
23 |
+
# path_out = Path.cwd()/'results'/'MSIvsMSS_2'/'metrics'
|
24 |
+
# path_out = Path.cwd()/'results'/'AIROGS'/'metrics'
|
25 |
+
path_out = Path.cwd()/'results'/'CheXpert'/'metrics'
|
26 |
+
path_out.mkdir(parents=True, exist_ok=True)
|
27 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
28 |
+
|
29 |
+
# ----------------- Logging -----------
|
30 |
+
current_time = datetime.now().strftime("%Y_%m_%d_%H%M%S")
|
31 |
+
logger = logging.getLogger()
|
32 |
+
logging.basicConfig(level=logging.INFO)
|
33 |
+
logger.addHandler(logging.FileHandler(path_out/f'metrics_{current_time}.log', 'w'))
|
34 |
+
|
35 |
+
|
36 |
+
# -------------- Helpers ---------------------
|
37 |
+
pil2torch = lambda x: torch.as_tensor(np.array(x)).moveaxis(-1, 0) # In contrast to ToTensor(), this will not cast 0-255 to 0-1 and destroy uint8 (required later)
|
38 |
+
|
39 |
+
# ---------------- Dataset/Dataloader ----------------
|
40 |
+
ds_real = ImageFolder('/mnt/hdd/datasets/pathology/kather_msi_mss_2/train/', transform=pil2torch)
|
41 |
+
# ds_real = ImageFolder('/mnt/hdd/datasets/eye/AIROGS/data_256x256_ref/', transform=pil2torch)
|
42 |
+
# ds_real = ImageFolder('/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/reference_test/', transform=pil2torch)
|
43 |
+
|
44 |
+
# ---------- Limit Sample Size
|
45 |
+
ds_real.samples = ds_real.samples[slice(max_samples)]
|
46 |
+
|
47 |
+
|
48 |
+
# --------- Select specific class ------------
|
49 |
+
if target_class is not None:
|
50 |
+
ds_real = Subset(ds_real, [i for i in range(len(ds_real)) if ds_real.samples[i][1] == ds_real.class_to_idx[target_class]])
|
51 |
+
dm_real = DataLoader(ds_real, batch_size=batch_size, num_workers=8, shuffle=False, drop_last=False)
|
52 |
+
|
53 |
+
logger.info(f"Samples Real: {len(ds_real)}")
|
54 |
+
|
55 |
+
|
56 |
+
# --------------- Load Model ------------------
|
57 |
+
model = VAE.load_from_checkpoint('runs/2022_12_12_133315_chest_vaegan/last_vae.ckpt')
|
58 |
+
model.to(device)
|
59 |
+
|
60 |
+
# from diffusers import StableDiffusionPipeline
|
61 |
+
# with open('auth_token.txt', 'r') as file:
|
62 |
+
# auth_token = file.read()
|
63 |
+
# pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32, use_auth_token=auth_token)
|
64 |
+
# model = pipe.vae
|
65 |
+
# model.to(device)
|
66 |
+
|
67 |
+
|
68 |
+
# ------------- Init Metrics ----------------------
|
69 |
+
calc_lpips = LPIPS().to(device)
|
70 |
+
|
71 |
+
|
72 |
+
# --------------- Start Calculation -----------------
|
73 |
+
mmssim_list, mse_list = [], []
|
74 |
+
for real_batch in tqdm(dm_real):
|
75 |
+
imgs_real_batch = real_batch[0].to(device)
|
76 |
+
|
77 |
+
imgs_real_batch = tF.normalize(imgs_real_batch/255, 0.5, 0.5) # [0, 255] -> [-1, 1]
|
78 |
+
with torch.no_grad():
|
79 |
+
imgs_fake_batch = model(imgs_real_batch)[0].clamp(-1, 1)
|
80 |
+
|
81 |
+
# -------------- LPIP -------------------
|
82 |
+
calc_lpips.update(imgs_real_batch, imgs_fake_batch) # expect input to be [-1, 1]
|
83 |
+
|
84 |
+
# -------------- MS-SSIM + MSE -------------------
|
85 |
+
for img_real, img_fake in zip(imgs_real_batch, imgs_fake_batch):
|
86 |
+
img_real, img_fake = (img_real+1)/2, (img_fake+1)/2 # [-1, 1] -> [0, 1]
|
87 |
+
mmssim_list.append(mmssim(img_real[None], img_fake[None], normalize='relu'))
|
88 |
+
mse_list.append(torch.mean(torch.square(img_real-img_fake)))
|
89 |
+
|
90 |
+
|
91 |
+
# -------------- Summary -------------------
|
92 |
+
mmssim_list = torch.stack(mmssim_list)
|
93 |
+
mse_list = torch.stack(mse_list)
|
94 |
+
|
95 |
+
lpips = 1-calc_lpips.compute()
|
96 |
+
logger.info(f"LPIPS Score: {lpips}")
|
97 |
+
logger.info(f"MS-SSIM: {torch.mean(mmssim_list)} ± {torch.std(mmssim_list)}")
|
98 |
+
logger.info(f"MSE: {torch.mean(mse_list)} ± {torch.std(mse_list)}")
|
scripts/helpers/dump_discrimnator.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import torch
|
3 |
+
from medical_diffusion.models.embedders.latent_embedders import VQVAE, VQGAN, VAE, VAEGAN
|
4 |
+
from pytorch_lightning.trainer import Trainer
|
5 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
6 |
+
|
7 |
+
path_root = Path('runs/2022_12_01_210017_patho_vaegan')
|
8 |
+
|
9 |
+
# Load model
|
10 |
+
model = VAEGAN.load_from_checkpoint(path_root/'last.ckpt')
|
11 |
+
# model = torch.load(path_root/'last.ckpt')
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
# Save model-part
|
16 |
+
# torch.save(model.vqvae, path_root/'last_vae.ckpt') # Not working
|
17 |
+
# ------ Ugly workaround ----------
|
18 |
+
checkpointing = ModelCheckpoint()
|
19 |
+
trainer = Trainer(callbacks=[checkpointing])
|
20 |
+
trainer.strategy._lightning_module = model.vqvae
|
21 |
+
trainer.model = model.vqvae
|
22 |
+
trainer.save_checkpoint(path_root/'last_vae.ckpt')
|
23 |
+
# -----------------
|
24 |
+
|
25 |
+
model = VAE.load_from_checkpoint(path_root/'last_vae.ckpt')
|
26 |
+
# model = torch.load(path_root/'last_vae.ckpt') # load_state_dict
|