Spaces:
Runtime error
Runtime error
update app
Browse files
app.py
CHANGED
@@ -2,9 +2,15 @@ import gradio as gr
|
|
2 |
|
3 |
from fastmri.data.subsample import create_mask_for_mask_type
|
4 |
from fastmri.data.transforms import apply_mask, to_tensor, center_crop
|
5 |
-
|
|
|
6 |
|
|
|
|
|
|
|
7 |
|
|
|
|
|
8 |
|
9 |
# st.title('FastMRI Kspace Reconstruction Masks')
|
10 |
# st.write('This app allows you to visualize the masks and their effects on the kspace data.')
|
@@ -15,41 +21,108 @@ def main_func(
|
|
15 |
mask_center_fractions: int,
|
16 |
accelerations: int,
|
17 |
seed: int,
|
18 |
-
|
|
|
19 |
):
|
20 |
|
21 |
-
file_dict = {
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
29 |
|
|
|
|
|
|
|
30 |
mask_func = create_mask_for_mask_type(
|
31 |
mask_name, center_fractions=[mask_center_fractions], accelerations=[accelerations]
|
32 |
)
|
33 |
-
mask =
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
demo = gr.Interface(
|
38 |
fn=main_func,
|
39 |
inputs=[
|
40 |
-
gr.
|
41 |
-
gr.
|
42 |
-
gr.
|
43 |
-
gr.
|
44 |
-
gr.
|
|
|
45 |
],
|
46 |
outputs=[
|
47 |
-
gr.
|
48 |
-
gr.
|
49 |
-
gr.
|
50 |
-
gr.
|
51 |
-
|
52 |
-
gr.outputs.Dataframe()
|
53 |
],
|
54 |
title="FastMRI Kspace Reconstruction Masks",
|
55 |
description="This app allows you to visualize the masks and their effects on the kspace data."
|
|
|
2 |
|
3 |
from fastmri.data.subsample import create_mask_for_mask_type
|
4 |
from fastmri.data.transforms import apply_mask, to_tensor, center_crop
|
5 |
+
import skimage
|
6 |
+
import fastmri
|
7 |
|
8 |
+
import numpy as np
|
9 |
+
import pandas as pd
|
10 |
+
import torch
|
11 |
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
import uuid
|
14 |
|
15 |
# st.title('FastMRI Kspace Reconstruction Masks')
|
16 |
# st.write('This app allows you to visualize the masks and their effects on the kspace data.')
|
|
|
21 |
mask_center_fractions: int,
|
22 |
accelerations: int,
|
23 |
seed: int,
|
24 |
+
slice_index: int,
|
25 |
+
# input_image: str,
|
26 |
):
|
27 |
|
28 |
+
# file_dict = {
|
29 |
+
# "knee singlecoil": "data/knee1_kspace.npy",
|
30 |
+
# "knee multicoil": "data/knee2_kspace.npy",
|
31 |
+
# "brain multicoil 1": "data/brain1_kspace.npy",
|
32 |
+
# "brain multicoil 2": "data/brain2_kspace.npy",
|
33 |
+
# "prostate multicoil 1": "data/prostate1_kspace.npy",
|
34 |
+
# "prostate multicoil 2": "data/prostate2_kspace.npy",
|
35 |
+
# }
|
36 |
+
# input_file_path = file_dict[input_image]
|
37 |
|
38 |
+
# kspace = np.load(input_file_path)
|
39 |
+
kspace = np.load("data/prostate1_kspace.npy")
|
40 |
+
kspace = to_tensor(kspace)
|
41 |
mask_func = create_mask_for_mask_type(
|
42 |
mask_name, center_fractions=[mask_center_fractions], accelerations=[accelerations]
|
43 |
)
|
44 |
+
subsampled_kspace, mask, num_low_frequencies = apply_mask(
|
45 |
+
kspace,
|
46 |
+
mask_func,
|
47 |
+
seed=seed,
|
48 |
+
)
|
49 |
+
|
50 |
+
print(mask.shape)
|
51 |
+
print(subsampled_kspace.shape)
|
52 |
+
print(kspace.shape)
|
53 |
+
|
54 |
+
mask = mask.squeeze() # 451
|
55 |
+
mask = mask.unsqueeze(0) # 1, 451
|
56 |
+
mask = mask.repeat(subsampled_kspace.shape[-3], 1).cpu().numpy()
|
57 |
+
|
58 |
+
print(mask.shape)
|
59 |
+
print()
|
60 |
+
|
61 |
+
subsampled_kspace = fastmri.rss(fastmri.complex_abs(fastmri.ifft2c(subsampled_kspace)), dim=1)
|
62 |
+
kspace = fastmri.rss(fastmri.complex_abs(fastmri.ifft2c(kspace)), dim=1)
|
63 |
+
|
64 |
+
print(subsampled_kspace.shape)
|
65 |
+
print(kspace.shape)
|
66 |
+
|
67 |
+
subsampled_kspace = subsampled_kspace[slice_index]
|
68 |
+
kspace = kspace[slice_index]
|
69 |
+
|
70 |
+
print(subsampled_kspace.shape)
|
71 |
+
print(kspace.shape)
|
72 |
+
|
73 |
+
|
74 |
+
subsampled_kspace = center_crop(subsampled_kspace, (320, 320))
|
75 |
+
kspace = center_crop(kspace, (320, 320))
|
76 |
+
|
77 |
+
# now that we have the reconstructions, we can calculate the SSIM and psnr
|
78 |
+
kspace = kspace.cpu().numpy()
|
79 |
+
subsampled_kspace = subsampled_kspace.cpu().numpy()
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
ssim = skimage.metrics.structural_similarity(subsampled_kspace, kspace, data_range=kspace.max() - kspace.min())
|
84 |
+
psnr = skimage.metrics.peak_signal_noise_ratio(subsampled_kspace, kspace, data_range=kspace.max() - kspace.min())
|
85 |
+
|
86 |
+
df = pd.DataFrame({"SSIM": [ssim], "PSNR": [psnr], "Num Low Frequencies": [num_low_frequencies]})
|
87 |
+
print(df)
|
88 |
+
|
89 |
+
# create a plot
|
90 |
+
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
|
91 |
+
ax[0].imshow(mask, cmap="gray")
|
92 |
+
ax[0].set_title("Mask")
|
93 |
+
ax[0].axis("off")
|
94 |
+
|
95 |
+
ax[1].imshow(subsampled_kspace, cmap="gray")
|
96 |
+
ax[1].set_title("Reconstructed Image")
|
97 |
+
ax[1].axis("off")
|
98 |
+
|
99 |
+
ax[2].imshow(kspace, cmap="gray")
|
100 |
+
ax[2].set_title("Original Image")
|
101 |
+
ax[2].axis("off")
|
102 |
+
|
103 |
+
plt.tight_layout()
|
104 |
+
plot_filename = f"data/{uuid.uuid4()}.png"
|
105 |
+
plt.savefig(plot_filename)
|
106 |
+
|
107 |
+
return df, plot_filename
|
108 |
+
|
109 |
|
110 |
demo = gr.Interface(
|
111 |
fn=main_func,
|
112 |
inputs=[
|
113 |
+
gr.Radio(['random', 'equispaced', "equispaced_fraction", "magic", "magic_fraction"], label="Mask Type", value="equispaced"),
|
114 |
+
gr.Slider(minimum=0.0, maximum=1.0, value=0.4, label="Center Fraction"),
|
115 |
+
gr.Number(value=4, label="Acceleration"),
|
116 |
+
gr.Number(value=42, label="Seed"),
|
117 |
+
gr.Number(value=15, label="Slice Index"),
|
118 |
+
# gr.Radio(["knee singlecoil", "knee multicoil", "brain multicoil 1", "brain multicoil 2", "prostate multicoil 1", "prostate multicoil 2"], label="Input Image")
|
119 |
],
|
120 |
outputs=[
|
121 |
+
gr.Dataframe(headers=["SSIM", "PSNR", "Num Low Frequencies"]),
|
122 |
+
gr.Image(type="filepath", label="Plot"),
|
123 |
+
# gr.Image(type="numpy", image_mode="L", label="Mask",),
|
124 |
+
# gr.Image(type="numpy", image_mode="L", label="Reconstructed Image", height=320, width=320),
|
125 |
+
# gr.Image(type="numpy", image_mode="L", label="Original Image", height=320, width=320),
|
|
|
126 |
],
|
127 |
title="FastMRI Kspace Reconstruction Masks",
|
128 |
description="This app allows you to visualize the masks and their effects on the kspace data."
|