osbm committed on
Commit
88176b6
1 Parent(s): c7f853f

how about this

Browse files
Files changed (2) hide show
  1. app.py +0 -39
  2. inference.py +30 -8
app.py DELETED
@@ -1,39 +0,0 @@
1
- import streamlit as st
2
- import huggingface_hub as hf_hub
3
- import monai
4
- import os
5
- import zipfile
6
- import torch
7
-
8
- hf_hub.login(token=st.secrets["HF_TOKEN"])
9
-
10
- with st.spinner("Downloading Dataset"):
11
- data_path = hf_hub.hf_hub_download(repo_id="osbm/prostate158", filename="data.zip", repo_type="dataset")
12
-
13
- st.write(data_path)
14
- with st.spinner("Unzipping..."):
15
- with zipfile.ZipFile(data_path, 'r') as zip_ref:
16
- zip_ref.extractall(".")
17
-
18
-
19
- # st.write(os.listdir(os.getcwd()))
20
- # st.write(os.getcwd())
21
-
22
-
23
- model = monai.networks.nets.UNet(
24
- in_channels=1,
25
- out_channels=3,
26
- spatial_dims=3,
27
- channels=[16, 32, 64, 128, 256, 512],
28
- strides=[2, 2, 2, 2, 2],
29
- num_res_units=4,
30
- act="PRELU",
31
- norm="BATCH",
32
- dropout=0.15,
33
- )
34
-
35
- # load this model using anatomy.pt
36
- model.load_state_dict(torch.load('anatomy.pt', map_location=torch.device('cpu')))
37
-
38
- print(model)
39
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
inference.py CHANGED
@@ -43,7 +43,32 @@ from monai.transforms import (
43
 
44
  # model.load_state_dict(torch.load("anatomy.pt", map_location=device))
45
 
46
- keys = ("t2", "t2_anatomy_reader1")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  transforms = Compose(
48
  [
49
  LoadImaged(keys=keys, image_only=False),
@@ -54,24 +79,21 @@ transforms = Compose(
54
  NormalizeIntensityd(keys=keys),
55
  EnsureTyped(keys=keys),
56
  ConcatItemsd(keys=("t2"), name=CommonKeys.IMAGE, dim=0),
57
- ConcatItemsd(keys=("t2_anatomy_reader1"), name=CommonKeys.LABEL, dim=0),
58
  ],
59
- allow_missing_keys=True,
60
  )
61
 
62
-
63
-
64
-
65
  postprocessing = Compose(
66
  [
67
- EnsureTyped(keys=[CommonKeys.PRED, CommonKeys.LABEL]),
68
  KeepLargestConnectedComponentd(
69
  keys=CommonKeys.PRED,
70
  applied_labels=list(range(1, 3))
71
  ),
72
  ],
73
- allow_missing_keys=True,
74
  )
 
 
 
75
  inferer = monai.inferers.SlidingWindowInferer(
76
  roi_size=(96, 96, 96),
77
  sw_batch_size=4,
 
43
 
44
  # model.load_state_dict(torch.load("anatomy.pt", map_location=device))
45
 
46
+ # keys = ("t2", "t2_anatomy_reader1")
47
+ # transforms = Compose(
48
+ # [
49
+ # LoadImaged(keys=keys, image_only=False),
50
+ # EnsureChannelFirstd(keys=keys),
51
+ # Spacingd(keys=keys, pixdim=[0.5, 0.5, 0.5], mode=("bilinear", "nearest")),
52
+ # Orientationd(keys=keys, axcodes="RAS"),
53
+ # ScaleIntensityd(keys=keys, minv=0, maxv=1),
54
+ # NormalizeIntensityd(keys=keys),
55
+ # EnsureTyped(keys=keys),
56
+ # ConcatItemsd(keys=("t2"), name=CommonKeys.IMAGE, dim=0),
57
+ # ConcatItemsd(keys=("t2_anatomy_reader1"), name=CommonKeys.LABEL, dim=0),
58
+ # ],
59
+ # )
60
+
61
+ # postprocessing = Compose(
62
+ # [
63
+ # EnsureTyped(keys=[CommonKeys.PRED, CommonKeys.LABEL]),
64
+ # KeepLargestConnectedComponentd(
65
+ # keys=CommonKeys.PRED,
66
+ # applied_labels=list(range(1, 3))
67
+ # ),
68
+ # ],
69
+ # )
70
+
71
+ keys = ("t2")
72
  transforms = Compose(
73
  [
74
  LoadImaged(keys=keys, image_only=False),
 
79
  NormalizeIntensityd(keys=keys),
80
  EnsureTyped(keys=keys),
81
  ConcatItemsd(keys=("t2"), name=CommonKeys.IMAGE, dim=0),
 
82
  ],
 
83
  )
84
 
 
 
 
85
  postprocessing = Compose(
86
  [
87
+ EnsureTyped(keys=[CommonKeys.PRED]),
88
  KeepLargestConnectedComponentd(
89
  keys=CommonKeys.PRED,
90
  applied_labels=list(range(1, 3))
91
  ),
92
  ],
 
93
  )
94
+
95
+
96
+
97
  inferer = monai.inferers.SlidingWindowInferer(
98
  roi_size=(96, 96, 96),
99
  sw_batch_size=4,