BinLiunls commited on
Commit
3601075
·
1 Parent(s): 1c8621b

add load_image param

Browse files

Signed-off-by: binliu <binliu@nvidia.com>

Files changed (1) hide show
  1. vista3d_pipeline.py +7 -3
vista3d_pipeline.py CHANGED
@@ -52,6 +52,7 @@ class VISTA3DPipeline(Pipeline):
52
  "image_key",
53
  "resample_spacing",
54
  "metadata_path",
 
55
  ]
56
  INFERENCE_EXTRA_ARGS = [
57
  "mode",
@@ -105,6 +106,7 @@ class VISTA3DPipeline(Pipeline):
105
  image_key: str = "image",
106
  resample_spacing: Sequence = (1.5, 1.5, 1.5),
107
  metadata_path: str = os.path.join(FILE_PATH, "metadata.json"),
 
108
  ):
109
  device = self.device
110
  subclass = {
@@ -114,8 +116,7 @@ class VISTA3DPipeline(Pipeline):
114
  }
115
  metadata = json.loads(pathlib.Path(metadata_path).read_text())
116
  labels_dict = metadata["network_data_format"]["outputs"]["pred"]["channel_def"]
117
- preprocessing_transforms = Compose(
118
- [
119
  LoadImaged(keys=image_key, image_only=True),
120
  EnsureChannelFirstd(keys=image_key),
121
  EnsureTyped(keys=image_key, device=device, track_meta=True),
@@ -137,7 +138,10 @@ class VISTA3DPipeline(Pipeline):
137
  Orientationd(keys=image_key, axcodes="RAS"),
138
  CastToTyped(keys=image_key, dtype=torch.float32),
139
  ]
140
- )
 
 
 
141
  return preprocessing_transforms
142
 
143
  def _init_postprocessing_transforms(
 
52
  "image_key",
53
  "resample_spacing",
54
  "metadata_path",
55
+ "load_image",
56
  ]
57
  INFERENCE_EXTRA_ARGS = [
58
  "mode",
 
106
  image_key: str = "image",
107
  resample_spacing: Sequence = (1.5, 1.5, 1.5),
108
  metadata_path: str = os.path.join(FILE_PATH, "metadata.json"),
109
+ load_image: bool = True,
110
  ):
111
  device = self.device
112
  subclass = {
 
116
  }
117
  metadata = json.loads(pathlib.Path(metadata_path).read_text())
118
  labels_dict = metadata["network_data_format"]["outputs"]["pred"]["channel_def"]
119
+ preprocessing_list = [
 
120
  LoadImaged(keys=image_key, image_only=True),
121
  EnsureChannelFirstd(keys=image_key),
122
  EnsureTyped(keys=image_key, device=device, track_meta=True),
 
138
  Orientationd(keys=image_key, axcodes="RAS"),
139
  CastToTyped(keys=image_key, dtype=torch.float32),
140
  ]
141
+ if not load_image:
142
+ preprocessing_list.pop(0)
143
+
144
+ preprocessing_transforms = Compose(preprocessing_list)
145
  return preprocessing_transforms
146
 
147
  def _init_postprocessing_transforms(