pidajay committed on
Commit
83e314f
1 Parent(s): 4da0685

Committed model weights and demo code

LICENSE.md ADDED
@@ -0,0 +1,9 @@
+ MIT License
+
+ Copyright (c) Microsoft Corporation
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
README.md ADDED
@@ -0,0 +1,51 @@
+ ---
+ license: mit
+ language:
+ - en
+ library_name: diffusers
+ tags:
+ - MRI
+ - medical-imaging
+ - VAE
+ - autoencoder
+ ---
+ # MRI Autoencoder v0.1
+
+ ## Model
+ The MRI autoencoder is a Variational Autoencoder (VAE) trained on the fastMRI multi-coil brain and knee datasets. The model is trained from scratch and uses the same architecture as the Stable Diffusion SDXL VAE.
+
+ Latent Diffusion Models (LDMs) have been extremely popular for synthesizing images and videos, yet they remain relatively under-explored in medical imaging. One possible reason is the lack of domain-specific autoencoders that can encode and decode high-dimensional medical imaging data to and from a lower-dimensional latent representation. MRI images, for example, differ from general-domain images in that they are complex-valued, carrying both magnitude and phase information. To this end, we are publishing an autoencoder that can encode and decode complex-valued MRI images to and from their latent representation.
+
+ ## Use
+
+ ```python
+ from diffusers.models import AutoencoderKL
+ autoencoder = AutoencoderKL.from_pretrained("microsoft/mri-autoencoder-v0.1")
+ ```
+
+ For more details, please refer to the provided autoencoders_demo.ipynb notebook. For details on how the fastMRI data was preprocessed, please refer to data_preprocessing_recipe.py.
+
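+ As a minimal end-to-end sketch (mirroring run_inference_two_channels in inference.py; `complex_img` is assumed to be a complex-valued, coil-combined slice of shape [1, 256, 256]):
+
+ ```python
+ import torch
+ from diffusers.models import AutoencoderKL
+ import data_utils as du
+
+ autoencoder = AutoencoderKL.from_pretrained("microsoft/mri-autoencoder-v0.1").eval()
+
+ complex_img = du.normalize_complex_coil_image(complex_img)  # scale by 99.5th-percentile magnitude
+ x = torch.from_numpy(du.complex_to_two_channel_image(complex_img))[None].float()  # [1, 2, 256, 256]
+
+ with torch.no_grad():
+     latents = autoencoder.encode(x).latent_dist.mean  # latent representation
+     recon = autoencoder.decode(latents).sample        # [1, 2, 256, 256]
+
+ recon_complex = du.two_channel_to_complex_image(recon.cpu().numpy())
+ ```
+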
+ ## Intended Use
+
+ The model is intended solely for future research in medical imaging. Stakeholders would benefit from treating this model as a building block for exploring latent-space generative models applied to complex-valued MRI images.
+
+ ## Out-of-Scope Use
+
+ Any deployed use case of the model, commercial or otherwise, is out of scope. The model weights and code are not intended for clinical use.
+
+ ## Evaluation
+
+ The PSNR and SSIM scores on 8,000 randomly chosen slices from the fastMRI multi-coil validation dataset are as follows:
+
+ | Autoencoder | Median PSNR | Mean PSNR | PSNR 95% CI | Median SSIM | Mean SSIM | SSIM 95% CI |
+ | ----------- | ----------- | --------- | ----------- | ----------- | --------- | ----------- |
+ | MRI-AUTOENCODER-v0.1 | 34.31 | 33.98 | (28.55, 37.79) | 0.91 | 0.88 | (0.54, 0.97) |
+ | SDXL-VAE | 31.45 | 31.51 | (27.85, 35.63) | 0.89 | 0.86 | (0.58, 0.94) |
+
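+ For illustration, such scores could be computed with the helpers in metrics.py (a sketch; `gt` and `pred` here are magnitude volumes of shape [slices, H, W], and the exact evaluation protocol may differ):
+
+ ```python
+ from metrics import psnr, ssim
+
+ # gt, pred: ground-truth and reconstructed magnitude volumes, shape [S, H, W]
+ print("PSNR:", psnr(gt, pred))
+ print("SSIM:", ssim(gt, pred))
+ ```
+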
+ ## Data
+
+ This model was trained, with permission, on the NYU fastMRI dataset (https://fastmri.med.nyu.edu/), a deidentified imaging dataset provided by NYU Langone that comprises raw k-space data in several sub-dataset groups.
+
+ ## Limitations
+
+ A model trained on this dataset may overfit and fail to generalize to new data. This model has not been evaluated for clinical use or across a range of scanner types.
autoencoders_demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
config.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "_class_name": "AutoencoderKL",
+   "_diffusers_version": "0.24.0",
+   "act_fn": "silu",
+   "block_out_channels": [
+     128,
+     256,
+     512
+   ],
+   "down_block_types": [
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D"
+   ],
+   "force_upcast": true,
+   "in_channels": 2,
+   "latent_channels": 4,
+   "layers_per_block": 2,
+   "norm_num_groups": 32,
+   "out_channels": 2,
+   "sample_size": 256,
+   "scaling_factor": 0.18215,
+   "up_block_types": [
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D"
+   ]
+ }
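With two input/output channels (real and imaginary) and three encoder blocks, this config should map a 2×256×256 input to a 4×64×64 latent, since in diffusers each DownEncoderBlock2D except the last downsamples by 2. A quick sanity check, sketched under the assumption that the published weights match this config:

```python
import torch
from diffusers.models import AutoencoderKL

vae = AutoencoderKL.from_pretrained("microsoft/mri-autoencoder-v0.1").eval()
x = torch.randn(1, 2, 256, 256)  # two channels: real and imaginary parts
with torch.no_grad():
    z = vae.encode(x).latent_dist.mean
print(z.shape)  # expected: torch.Size([1, 4, 64, 64])
```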
data_preprocessing_recipe.py ADDED
@@ -0,0 +1,202 @@
+ '''This file contains the recipe for the data preprocessing used to generate the combined coil images from the fastMRI multi-coil brain and knee datasets.
+ These combined coil images were then used to train the autoencoder. The combined coil images are generated by combining the coil images using
+ the sensitivity maps calculated with BART. To run this recipe, install the BART toolbox and then follow the steps outlined in
+ the preprocess_data_recipe function.'''
+
+ import os
+ import sys
+
+ import h5py
+ import numpy as np
+ from tqdm import tqdm
+
+ # BART toolbox installation instructions - https://mrirecon.github.io/bart/installation.html
+ _BART_TOOLBOX_PATH = ''
+
+ os.environ["TOOLBOX_PATH"] = _BART_TOOLBOX_PATH
+ sys.path.append(os.path.join(_BART_TOOLBOX_PATH, 'python'))
+ from bart import bart
+ os.environ["OMP_NUM_THREADS"] = "1"
+
+ def fftc(input, axes=None, norm='ortho'):
+     """
+     Perform a centered Fast Fourier Transform on the input array.
+
+     Parameters:
+     input (numpy.ndarray): The input array to transform.
+     axes (tuple, optional): Axes over which to compute the FFT. If not specified, compute over all axes.
+     norm (str, optional): Normalization mode. Default is 'ortho' for an orthonormal transform.
+
+     Returns:
+     numpy.ndarray: The transformed output array.
+     """
+     tmp = np.fft.ifftshift(input, axes=axes)
+     tmp = np.fft.fftn(tmp, axes=axes, norm=norm)
+     output = np.fft.fftshift(tmp, axes=axes)
+     return output
+
+ def ifftc(input, axes=None, norm='ortho'):
+     """
+     Perform a centered Inverse Fast Fourier Transform on the input array.
+
+     Parameters:
+     input (numpy.ndarray): The input array to transform.
+     axes (tuple, optional): Axes over which to compute the inverse FFT. If not specified, compute over all axes.
+     norm (str, optional): Normalization mode. Default is 'ortho' for an orthonormal transform.
+
+     Returns:
+     numpy.ndarray: The transformed output array.
+     """
+     tmp = np.fft.ifftshift(input, axes=axes)
+     tmp = np.fft.ifftn(tmp, axes=axes, norm=norm)
+     output = np.fft.fftshift(tmp, axes=axes)
+     return output
+
+ def adjoint(ksp, maps, mask):
+     """
+     Perform the adjoint operation on k-space data with coil sensitivity maps and a mask.
+
+     Parameters:
+     ksp (numpy.ndarray): The input k-space data, shape: [1, C, H, W].
+     maps (numpy.ndarray): The coil sensitivity maps, shape: [1, C, H, W].
+     mask (numpy.ndarray): The mask to apply to the k-space data, shape: [1, 1, H, W].
+
+     Returns:
+     numpy.ndarray: The output image after applying the adjoint operation, shape: [1, 1, H, W].
+     """
+     masked_ksp = ksp * mask
+     coil_imgs = ifftc(masked_ksp, axes=(-2, -1))
+     img_out = np.sum(coil_imgs * np.conj(maps), axis=1)[:, None, ...]
+     return img_out
+
+ def _expand_shapes(*shapes):
+     """
+     Expand the dimensions of the given shapes to match the maximum dimension.
+
+     This function prepends 1s to the shapes with fewer dimensions to match the maximum number of dimensions.
+
+     Parameters:
+     *shapes (tuple): A variable-length tuple containing shapes (as lists or tuples of integers).
+
+     Returns:
+     tuple: A tuple of expanded shapes, where each shape is a list of integers.
+     """
+     shapes = [list(shape) for shape in shapes]
+     max_ndim = max(len(shape) for shape in shapes)
+     shapes_exp = [[1] * (max_ndim - len(shape)) + shape
+                   for shape in shapes]
+
+     return tuple(shapes_exp)
+
+ def resize(input, oshape, ishift=None, oshift=None):
+     """
+     Resize with zero-padding or cropping.
+
+     Parameters:
+     input (array): Input array.
+     oshape (tuple of ints): Output shape.
+     ishift (None or tuple of ints): Input shift.
+     oshift (None or tuple of ints): Output shift.
+
+     Returns:
+     array: Zero-padded or cropped result.
+     """
+     ishape1, oshape1 = _expand_shapes(input.shape, oshape)
+
+     if ishape1 == oshape1:
+         return input.reshape(oshape)
+
+     if ishift is None:
+         ishift = [max(i // 2 - o // 2, 0) for i, o in zip(ishape1, oshape1)]
+
+     if oshift is None:
+         oshift = [max(o // 2 - i // 2, 0) for i, o in zip(ishape1, oshape1)]
+
+     copy_shape = [min(i - si, o - so)
+                   for i, si, o, so in zip(ishape1, ishift, oshape1, oshift)]
+     islice = tuple([slice(si, si + c) for si, c in zip(ishift, copy_shape)])
+     oslice = tuple([slice(so, so + c) for so, c in zip(oshift, copy_shape)])
+
+     output = np.zeros(oshape1, dtype=input.dtype)
+     input = input.reshape(ishape1)
+     output[oslice] = input[islice]
+
+     return output.reshape(oshape)
+
+ def shape_data(ksp, final_res):
+     """
+     Reshape coil k-space data into coil images with isotropic pixels, a field of view equal to the original image width, and the square image size given by final_res.
+
+     This function assumes that the k-space data has already been padded to make the corresponding images have isotropic pixels.
+
+     Parameters:
+     ksp (numpy.ndarray): The input coil k-space data, shape: [S, C, H, W].
+     final_res (int): The final resolution for the output image.
+
+     Returns:
+     numpy.ndarray: The output image after reshaping, shape: [S, C, final_res, final_res].
+     """
+     H = ksp.shape[-2]
+     W = ksp.shape[-1]
+     S = ksp.shape[0]
+     C = ksp.shape[1]
+     # bring the coil k-space into coil image space
+     img1 = ifftc(ksp, axes=(-2, -1))
+     img1_cropped = resize(img1, oshape=(S, C, W, W))
+     # the FOV is now the same in both directions without modifying the resolution
+     ksp1 = fftc(img1_cropped, axes=(-2, -1))
+     # crop or pad the k-space isotropically in Fourier space to the correct image size while maintaining the same field of view (in the width direction) as the original image
+     ksp1_cropped = resize(ksp1, oshape=(S, C, final_res, final_res))
+     img_out = ifftc(ksp1_cropped, axes=(-2, -1))
+
+     return img_out
+
+ def read_fastmri_data(file_path):
+     """
+     Read k-space data from a .h5 file.
+
+     Parameters:
+     file_path (str): The path to the .h5 file containing fastMRI data.
+
+     Returns:
+     numpy.ndarray: The k-space data as a numpy array.
+     """
+     with h5py.File(file_path, 'r') as hf:
+         ksp = np.asarray(hf['kspace'])
+     return ksp
+
+ def combine_coils(ksp):
+     """
+     Combine multi-coil k-space data into a single coil image.
+
+     This function reshapes the raw multi-coil k-space data, calculates sensitivity maps for the reshaped data using the BART toolbox's 'ecalib' command, and then uses these maps to create a single coil image via a fully sampled adjoint operation.
+
+     Parameters:
+     ksp (numpy.ndarray): The input multi-coil k-space data, shape: [B, C, H, W].
+
+     Returns:
+     numpy.ndarray: The output single coil image, shape: [B, 1, H, W].
+     """
+     # reshape raw multi-coil k-space to the desired shape (e.g., [B, C, 256, 256])
+     coil_img_rs = shape_data(ksp, final_res=256)
+     coil_ksp_rs = fftc(coil_img_rs, axes=(-2, -1))
+
+     # calculate sensitivity maps for the reshaped coil k-space
+     ksp_rs = coil_ksp_rs.transpose((2, 3, 0, 1))
+     maps = np.array(ksp_rs)
+     # calculate ESPIRiT maps with BART; ecalib expects (and returns) data of the form
+     # (row, column, 1, coil), which is later transposed to (slice, coil, rows, columns)
+     for j in tqdm(range(ksp_rs.shape[2])):
+         sens = bart(1, 'ecalib -m1 -W -c0', ksp_rs[:, :, j, None, :])
+         maps[:, :, j, :] = sens[:, :, 0, :]
+
+     maps_rs = maps.transpose((2, 3, 0, 1))
+     # use the new maps to create a single coil image via a fully sampled adjoint operation
+     single_coil_rs_img = adjoint(ksp=coil_ksp_rs, maps=maps_rs, mask=np.ones_like(coil_ksp_rs))
+     return single_coil_rs_img
+
+ def preprocess_data_recipe(input_dir, output_dir):
+     """Run the full recipe (a sketch; the directory arguments are placeholders)."""
+     os.makedirs(output_dir, exist_ok=True)
+     # for each file in the fastMRI dataset
+     for fname in sorted(os.listdir(input_dir)):
+         if not fname.endswith('.h5'):
+             continue
+         ksp = read_fastmri_data(os.path.join(input_dir, fname))  # read the k-space data
+         combined = combine_coils(ksp)  # create the combined coil image
+         np.save(os.path.join(output_dir, fname.replace('.h5', '.npy')), combined)
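A hypothetical invocation of the recipe, using the directory arguments sketched above (both paths are placeholders; the BART toolbox must be installed and _BART_TOOLBOX_PATH set first):

```python
from data_preprocessing_recipe import preprocess_data_recipe

# placeholder paths for illustration only
preprocess_data_recipe(
    input_dir="fastmri/multicoil_train",      # folder of fastMRI .h5 files
    output_dir="preprocessed/combined_coil",  # where combined coil images are written
)
```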
data_utils.py ADDED
@@ -0,0 +1,24 @@
+ import numpy as np
+
+ def complex_to_two_channel_image(complex_img: np.ndarray) -> np.ndarray:
+     """Converts a complex-valued image to a 2-channel image (real and imaginary channels)"""
+     real, imag = np.real(complex_img), np.imag(complex_img)
+     return np.concatenate((real, imag), axis=0)
+
+ def two_channel_to_complex_image(two_ch_img: np.ndarray) -> np.ndarray:
+     """Converts a batched 2-channel image (real and imaginary channels) back to a complex-valued image"""
+     two_ch_img = two_ch_img[0]  # drop the batch dimension
+     real = two_ch_img[0]
+     imag = two_ch_img[1]
+     complex_image = real + 1j*imag
+     return complex_image[None, ...]
+
+ def normalize_complex_coil_image(complex_coil_img: np.ndarray) -> np.ndarray:
+     """Scales the complex-valued coil image by the 99.5th percentile of its magnitude"""
+     max_val = np.percentile(np.abs(complex_coil_img), 99.5)
+     return complex_coil_img / max_val
+
+ def create_three_channel_image(complex_coil_img: np.ndarray) -> np.ndarray:
+     """Converts a complex-valued coil image to a 3-channel image (magnitude channel repeated 3 times)"""
+     mag = np.abs(complex_coil_img)
+     return np.concatenate((mag, mag, mag), axis=0)
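A quick illustration of the two-channel round trip (a sketch with a synthetic complex image, not repo data):

```python
import numpy as np
import data_utils as du

img = np.random.randn(1, 256, 256) + 1j * np.random.randn(1, 256, 256)  # synthetic complex image
img = du.normalize_complex_coil_image(img)
two_ch = du.complex_to_two_channel_image(img)             # shape (2, 256, 256)
restored = du.two_channel_to_complex_image(two_ch[None])  # shape (1, 256, 256)
assert np.allclose(restored, img)
```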
diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4b07cfaad692d5e60669b5bfe0432de71eb11ad6925bf9e0bd333b69d15c5e62
+ size 221317280
example_data/mri_complex_images.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:69a6b217303b2147957ac2f1dbcee976b5ed509621acd94ca55110a1c8f02e5c
+ size 2022432
inference.py ADDED
@@ -0,0 +1,28 @@
+ import torch
+ import data_utils as du
+
+ def run_inference_two_channels(coil_complex_image, autoencoder, device="cuda"):
+     """Encodes and decodes a complex-valued coil image as a 2-channel (real/imaginary) tensor."""
+     coil_complex_image = du.normalize_complex_coil_image(coil_complex_image)
+     two_channel_image = du.complex_to_two_channel_image(coil_complex_image)
+     two_channel_tensor = torch.from_numpy(two_channel_image)[None, ...].float().to(device)
+     autoencoder = autoencoder.to(device)
+     with torch.no_grad():
+         autoencoder_output = autoencoder.encode(two_channel_tensor)
+         latents = autoencoder_output.latent_dist.mean
+         decoded_image = autoencoder.decode(latents).sample
+     recon = du.two_channel_to_complex_image(decoded_image.detach().cpu().numpy())
+     input_image = coil_complex_image
+     return input_image, recon
+
+ def run_inference_three_channels(coil_complex_image, autoencoder, device="cuda"):
+     """Encodes and decodes the magnitude of a complex-valued coil image as a 3-channel tensor."""
+     coil_complex_image = du.normalize_complex_coil_image(coil_complex_image)
+     three_channel_image = du.create_three_channel_image(coil_complex_image)
+     three_channel_tensor = torch.from_numpy(three_channel_image)[None, ...].float().to(device)
+     autoencoder = autoencoder.to(device)
+     with torch.no_grad():
+         autoencoder_output = autoencoder.encode(three_channel_tensor)
+         latents = autoencoder_output.latent_dist.mean
+         decoded_image = autoencoder.decode(latents).sample
+     recon = decoded_image[0].detach().cpu().numpy()
+     input_image = three_channel_image
+     return input_image, recon
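A hedged usage sketch with the bundled example data (the layout of mri_complex_images.npz is an assumption; inspect data.files to confirm the key and shapes):

```python
import numpy as np
from diffusers.models import AutoencoderKL
from inference import run_inference_two_channels

autoencoder = AutoencoderKL.from_pretrained("microsoft/mri-autoencoder-v0.1")
data = np.load("example_data/mri_complex_images.npz")
img = data[data.files[0]][0:1]  # assumed: first array, first slice, shape [1, H, W]
inp, recon = run_inference_two_channels(img, autoencoder, device="cpu")
print(inp.shape, recon.shape)
```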
metrics.py ADDED
@@ -0,0 +1,30 @@
+ from typing import Optional
+
+ import numpy as np
+ from skimage.metrics import peak_signal_noise_ratio, structural_similarity
+
+ def ssim(
+     gt: np.ndarray, pred: np.ndarray, data_range: Optional[float] = None
+ ) -> float:
+     """Compute the Structural Similarity Index Metric (SSIM), averaged over slices"""
+     if gt.ndim != 3:
+         raise ValueError("Unexpected number of dimensions in ground truth.")
+     if gt.ndim != pred.ndim:
+         raise ValueError("Ground truth dimensions do not match prediction dimensions.")
+
+     data_range = gt.max() if data_range is None else data_range
+
+     # average the per-slice SSIM over the slice dimension
+     total = 0.0
+     for slice_num in range(gt.shape[0]):
+         total += structural_similarity(
+             gt[slice_num], pred[slice_num], data_range=data_range
+         )
+
+     return total / gt.shape[0]
+
+ def psnr(
+     gt: np.ndarray, pred: np.ndarray, data_range: Optional[float] = None
+ ) -> float:
+     """Compute the Peak Signal-to-Noise Ratio (PSNR) metric"""
+     data_range = gt.max() if data_range is None else data_range
+     return peak_signal_noise_ratio(gt, pred, data_range=data_range)