osbm commited on
Commit
c989bc3
·
1 Parent(s): d5dfe32

update app

Browse files
Files changed (1) hide show
  1. app.py +97 -24
app.py CHANGED
@@ -2,9 +2,15 @@ import gradio as gr
2
 
3
  from fastmri.data.subsample import create_mask_for_mask_type
4
  from fastmri.data.transforms import apply_mask, to_tensor, center_crop
5
- from pytorch_msssim import ssim
 
6
 
 
 
 
7
 
 
 
8
 
9
  # st.title('FastMRI Kspace Reconstruction Masks')
10
  # st.write('This app allows you to visualize the masks and their effects on the kspace data.')
@@ -15,41 +21,108 @@ def main_func(
15
  mask_center_fractions: int,
16
  accelerations: int,
17
  seed: int,
18
- input_image: str,
 
19
  ):
20
 
21
- file_dict = {
22
- "knee 1": "knee_singlecoil_train/file1000002.h5",
23
- "knee 2": "knee_singlecoil_train/file1000003.h5",
24
- "brain 1": "brain_axial_train/file1000002.h5",
25
- "prostate 1": "prostate_t1_tse_train/file1000002.h5",
26
- "prostate 2": "prostate_t2_tse_train/file1000002.h5",
27
- }
28
- input_file = file_dict[input_image]
 
29
 
 
 
 
30
  mask_func = create_mask_for_mask_type(
31
  mask_name, center_fractions=[mask_center_fractions], accelerations=[accelerations]
32
  )
33
- mask =
34
- masked_kspace, mask = mask(input_image, return_mask=True)
35
- return masked_kspace, mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  demo = gr.Interface(
38
  fn=main_func,
39
  inputs=[
40
- gr.inputs.Radio(['random', 'equispaced'], label="Mask Type"),
41
- gr.inputs.Slider(minimum=0.04, maximum=0.4, default=0.08, label="Center Fraction"),
42
- gr.inputs.Number(default=4, label="Acceleration"),
43
- gr.inputs.Number(default=0, label="Seed"),
44
- gr.inputs.Radio(["knee 1", "knee 2", "brain 1", "prostate 1", "prostate 2"], label="Input Image")
 
45
  ],
46
  outputs=[
47
- gr.outputs.Image(type="mask", label="Mask"),
48
- gr.outputs.Image(type="kspace", label="Masked Kspace"),
49
- gr.outputs.Image(type="kspace", label="Reconstructed Image"),
50
- gr.outputs.Image(type="kspace", label="Original Image"),
51
-
52
- gr.outputs.Dataframe()
53
  ],
54
  title="FastMRI Kspace Reconstruction Masks",
55
  description="This app allows you to visualize the masks and their effects on the kspace data."
 
2
 
3
  from fastmri.data.subsample import create_mask_for_mask_type
4
  from fastmri.data.transforms import apply_mask, to_tensor, center_crop
5
+ import skimage
6
+ import fastmri
7
 
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
 
12
+ import matplotlib.pyplot as plt
13
+ import uuid
14
 
15
  # st.title('FastMRI Kspace Reconstruction Masks')
16
  # st.write('This app allows you to visualize the masks and their effects on the kspace data.')
 
21
  mask_center_fractions: int,
22
  accelerations: int,
23
  seed: int,
24
+ slice_index: int,
25
+ # input_image: str,
26
  ):
27
 
28
+ # file_dict = {
29
+ # "knee singlecoil": "data/knee1_kspace.npy",
30
+ # "knee multicoil": "data/knee2_kspace.npy",
31
+ # "brain multicoil 1": "data/brain1_kspace.npy",
32
+ # "brain multicoil 2": "data/brain2_kspace.npy",
33
+ # "prostate multicoil 1": "data/prostate1_kspace.npy",
34
+ # "prostate multicoil 2": "data/prostate2_kspace.npy",
35
+ # }
36
+ # input_file_path = file_dict[input_image]
37
 
38
+ # kspace = np.load(input_file_path)
39
+ kspace = np.load("data/prostate1_kspace.npy")
40
+ kspace = to_tensor(kspace)
41
  mask_func = create_mask_for_mask_type(
42
  mask_name, center_fractions=[mask_center_fractions], accelerations=[accelerations]
43
  )
44
+ subsampled_kspace, mask, num_low_frequencies = apply_mask(
45
+ kspace,
46
+ mask_func,
47
+ seed=seed,
48
+ )
49
+
50
+ print(mask.shape)
51
+ print(subsampled_kspace.shape)
52
+ print(kspace.shape)
53
+
54
+ mask = mask.squeeze() # 451
55
+ mask = mask.unsqueeze(0) # 1, 451
56
+ mask = mask.repeat(subsampled_kspace.shape[-3], 1).cpu().numpy()
57
+
58
+ print(mask.shape)
59
+ print()
60
+
61
+ subsampled_kspace = fastmri.rss(fastmri.complex_abs(fastmri.ifft2c(subsampled_kspace)), dim=1)
62
+ kspace = fastmri.rss(fastmri.complex_abs(fastmri.ifft2c(kspace)), dim=1)
63
+
64
+ print(subsampled_kspace.shape)
65
+ print(kspace.shape)
66
+
67
+ subsampled_kspace = subsampled_kspace[slice_index]
68
+ kspace = kspace[slice_index]
69
+
70
+ print(subsampled_kspace.shape)
71
+ print(kspace.shape)
72
+
73
+
74
+ subsampled_kspace = center_crop(subsampled_kspace, (320, 320))
75
+ kspace = center_crop(kspace, (320, 320))
76
+
77
+ # now that we have the reconstructions, we can calculate the SSIM and psnr
78
+ kspace = kspace.cpu().numpy()
79
+ subsampled_kspace = subsampled_kspace.cpu().numpy()
80
+
81
+
82
+
83
+ ssim = skimage.metrics.structural_similarity(subsampled_kspace, kspace, data_range=kspace.max() - kspace.min())
84
+ psnr = skimage.metrics.peak_signal_noise_ratio(subsampled_kspace, kspace, data_range=kspace.max() - kspace.min())
85
+
86
+ df = pd.DataFrame({"SSIM": [ssim], "PSNR": [psnr], "Num Low Frequencies": [num_low_frequencies]})
87
+ print(df)
88
+
89
+ # create a plot
90
+ fig, ax = plt.subplots(1, 3, figsize=(15, 5))
91
+ ax[0].imshow(mask, cmap="gray")
92
+ ax[0].set_title("Mask")
93
+ ax[0].axis("off")
94
+
95
+ ax[1].imshow(subsampled_kspace, cmap="gray")
96
+ ax[1].set_title("Reconstructed Image")
97
+ ax[1].axis("off")
98
+
99
+ ax[2].imshow(kspace, cmap="gray")
100
+ ax[2].set_title("Original Image")
101
+ ax[2].axis("off")
102
+
103
+ plt.tight_layout()
104
+ plot_filename = f"data/{uuid.uuid4()}.png"
105
+ plt.savefig(plot_filename)
106
+
107
+ return df, plot_filename
108
+
109
 
110
  demo = gr.Interface(
111
  fn=main_func,
112
  inputs=[
113
+ gr.Radio(['random', 'equispaced', "equispaced_fraction", "magic", "magic_fraction"], label="Mask Type", value="equispaced"),
114
+ gr.Slider(minimum=0.0, maximum=1.0, value=0.4, label="Center Fraction"),
115
+ gr.Number(value=4, label="Acceleration"),
116
+ gr.Number(value=42, label="Seed"),
117
+ gr.Number(value=15, label="Slice Index"),
118
+ # gr.Radio(["knee singlecoil", "knee multicoil", "brain multicoil 1", "brain multicoil 2", "prostate multicoil 1", "prostate multicoil 2"], label="Input Image")
119
  ],
120
  outputs=[
121
+ gr.Dataframe(headers=["SSIM", "PSNR", "Num Low Frequencies"]),
122
+ gr.Image(type="filepath", label="Plot"),
123
+ # gr.Image(type="numpy", image_mode="L", label="Mask",),
124
+ # gr.Image(type="numpy", image_mode="L", label="Reconstructed Image", height=320, width=320),
125
+ # gr.Image(type="numpy", image_mode="L", label="Original Image", height=320, width=320),
 
126
  ],
127
  title="FastMRI Kspace Reconstruction Masks",
128
  description="This app allows you to visualize the masks and their effects on the kspace data."