truthdotphd commited on
Commit
d03e9df
·
verified ·
1 Parent(s): 024f5b3

Upload 3 files

Browse files
Files changed (3) hide show
  1. config.pbtxt +12 -19
  2. model.py +198 -73
  3. requirements.txt +2 -0
config.pbtxt CHANGED
@@ -1,25 +1,18 @@
1
  backend: "python"
2
- max_batch_size: 0 # Keep batching disabled as per original config
3
 
4
  input [
5
- {
6
- name: "input_jp2_bytes" # New input name for JP2 bytes
7
- data_type: TYPE_STRING # Use TYPE_STRING for bytes
8
- dims: [ 3 ] # Expecting 3 elements: Red, Green, NIR bytes
9
- }
10
  ]
11
 
12
  output [
13
- {
14
- name: "output_mask"
15
- data_type: TYPE_UINT8
16
- dims: [-1, -1] # Variable height, width
17
- }
18
- ]
19
-
20
- # Optional: Specify instance_group if running on GPU
21
- # instance_group [
22
- # {
23
- # kind: KIND_GPU
24
- # }
25
- # ]
 
1
  backend: "python"
2
+ max_batch_size: 0
3
 
4
  input [
5
+ {
6
+ name: "input_jp2_bytes"
7
+ data_type: TYPE_STRING
8
+ dims: [ 3 ]
9
+ }
10
  ]
11
 
12
  output [
13
+ {
14
+ name: "output_mask"
15
+ data_type: TYPE_UINT8
16
+ dims: [-1, -1]
17
+ }
18
+ ]
 
 
 
 
 
 
 
model.py CHANGED
@@ -4,98 +4,223 @@ from omnicloudmask import predict_from_array
4
  import rasterio
5
  from rasterio.io import MemoryFile
6
  from rasterio.enums import Resampling
 
 
 
7
 
8
  class TritonPythonModel:
9
  def initialize(self, args):
10
  """
11
  Initialize the model. This function is called once when the model is loaded.
12
  """
13
- # You can load models or initialize resources here if needed.
14
- # Ensure rasterio is installed in the Python backend environment.
15
- print('Initialized Cloud Detection model with JP2 input')
16
 
17
- def execute(self, requests):
18
  """
19
- Process inference requests.
20
  """
21
- responses = []
22
- # Every request must contain three JP2 byte strings (Red, Green, NIR).
23
- for request in requests:
24
- # Get the input tensor containing the byte arrays
25
- input_tensor = pb_utils.get_input_tensor_by_name(request, "input_jp2_bytes")
26
- # as_numpy() for TYPE_STRING gives an ndarray of Python bytes objects
27
- jp2_bytes_list = input_tensor.as_numpy()
28
-
29
- if len(jp2_bytes_list) != 3:
30
- # Send an error response if the input shape is incorrect
31
- error = pb_utils.TritonError(f"Expected 3 JP2 byte strings, received {len(jp2_bytes_list)}")
32
- response = pb_utils.InferenceResponse(output_tensors=[], error=error)
33
- responses.append(response)
34
- continue # Skip to the next request
35
-
36
- # Assume order: Red, Green, NIR based on client logic
37
- red_bytes = jp2_bytes_list[0]
38
- green_bytes = jp2_bytes_list[1]
39
- nir_bytes = jp2_bytes_list[2]
40
-
41
  try:
42
- # Process JP2 bytes using rasterio in memory
43
- with MemoryFile(red_bytes) as memfile_red:
44
- with memfile_red.open() as src_red:
45
- red_data = src_red.read(1).astype(np.float32)
46
- target_height = src_red.height
47
- target_width = src_red.width
48
-
49
- with MemoryFile(green_bytes) as memfile_green:
50
- with memfile_green.open() as src_green:
51
- # Ensure green band matches red band dimensions (should if B03)
52
- if src_green.height != target_height or src_green.width != target_width:
53
- # Optional: Resample green if necessary, though B03 usually matches B04
54
- green_data = src_green.read(
55
- 1,
56
- out_shape=(1, target_height, target_width),
57
- resampling=Resampling.bilinear
58
- ).astype(np.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  else:
60
- green_data = src_green.read(1).astype(np.float32)
61
-
62
-
63
- with MemoryFile(nir_bytes) as memfile_nir:
64
- with memfile_nir.open() as src_nir:
65
- # Resample NIR (B8A) to match Red/Green (B04/B03) resolution
66
- nir_data = src_nir.read(
67
- 1, # Read the first band
68
- out_shape=(1, target_height, target_width),
69
- resampling=Resampling.bilinear
70
- ).astype(np.float32)
71
-
72
- # Stack bands in CHW format (Red, Green, NIR) for the model
73
- # Match the channel order expected by predict_from_array
74
- input_array = np.stack([red_data, green_data, nir_data], axis=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- # Perform inference using the original function
77
- pred_mask = predict_from_array(input_array)
 
 
 
78
 
79
- # Create output tensor
80
- output_tensor = pb_utils.Tensor(
81
- "output_mask",
82
- pred_mask.astype(np.uint8)
83
- )
84
- response = pb_utils.InferenceResponse([output_tensor])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  except Exception as e:
87
- # Handle errors during processing (e.g., invalid JP2 data)
88
- error = pb_utils.TritonError(f"Error processing JP2 data: {str(e)}")
 
 
89
  response = pb_utils.InferenceResponse(output_tensors=[], error=error)
 
90
 
91
- responses.append(response)
92
-
93
- # Return a list of responses
94
  return responses
95
 
96
  def finalize(self):
97
  """
98
- Called when the model is unloaded. Perform any necessary cleanup.
99
  """
100
- print('Finalizing Cloud Detection model')
101
-
 
4
  import rasterio
5
  from rasterio.io import MemoryFile
6
  from rasterio.enums import Resampling
7
+ import tempfile
8
+ import os
9
+ from io import BytesIO
10
 
11
  class TritonPythonModel:
12
  def initialize(self, args):
13
  """
14
  Initialize the model. This function is called once when the model is loaded.
15
  """
16
+ print('Initialized Cloud Detection model with JP2 input and robust GDAL handling')
 
 
17
 
18
+ def safe_read_jp2_bytes(self, jp2_bytes):
19
  """
20
+ Safely read JP2 bytes with multiple fallback methods
21
  """
22
+ try:
23
+ # Method 1: Try direct MemoryFile approach (works if GDAL drivers are properly configured)
24
+ with MemoryFile(jp2_bytes) as memfile:
25
+ with memfile.open() as src:
26
+ data = src.read(1).astype(np.float32)
27
+ height, width = src.height, src.width
28
+ profile = src.profile
29
+ return data, height, width, profile
30
+
31
+ except Exception as e1:
32
+ print(f"Method 1 (MemoryFile) failed: {e1}")
 
 
 
 
 
 
 
 
 
33
  try:
34
+ # Method 2: Write to temporary file and read from disk
35
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.jp2') as tmp_file:
36
+ tmp_file.write(jp2_bytes)
37
+ tmp_file.flush()
38
+
39
+ with rasterio.open(tmp_file.name) as src:
40
+ data = src.read(1).astype(np.float32)
41
+ height, width = src.height, src.width
42
+ profile = src.profile
43
+
44
+ # Clean up temporary file
45
+ os.unlink(tmp_file.name)
46
+ return data, height, width, profile
47
+
48
+ except Exception as e2:
49
+ print(f"Method 2 (temporary file) failed: {e2}")
50
+ try:
51
+ # Method 3: Try with different suffix and basic profile
52
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.tiff') as tmp_file:
53
+ tmp_file.write(jp2_bytes)
54
+ tmp_file.flush()
55
+
56
+ with rasterio.open(tmp_file.name) as src:
57
+ data = src.read(1).astype(np.float32)
58
+ height, width = src.height, src.width
59
+ profile = {'driver': 'GTiff', 'height': height, 'width': width, 'count': 1, 'dtype': 'float32'}
60
+
61
+ os.unlink(tmp_file.name)
62
+ return data, height, width, profile
63
+
64
+ except Exception as e3:
65
+ print(f"Method 3 (tiff fallback) failed: {e3}")
66
+ # Method 4: Final fallback - try to interpret as raw numpy array
67
+ try:
68
+ # This assumes the client is sending raw numpy bytes as fallback
69
+ data_array = np.frombuffer(jp2_bytes, dtype=np.float32)
70
+
71
+ # Try to guess square dimensions
72
+ side_length = int(np.sqrt(len(data_array)))
73
+ if side_length * side_length == len(data_array):
74
+ data = data_array.reshape(side_length, side_length)
75
+ height, width = side_length, side_length
76
+ profile = {'driver': 'GTiff', 'height': height, 'width': width, 'count': 1, 'dtype': 'float32'}
77
+ return data, height, width, profile
78
  else:
79
+ # Try common satellite image dimensions
80
+ common_dims = [(10980, 10980), (5490, 5490), (1024, 1024), (512, 512)]
81
+ for h, w in common_dims:
82
+ if h * w == len(data_array):
83
+ data = data_array.reshape(h, w)
84
+ height, width = h, w
85
+ profile = {'driver': 'GTiff', 'height': height, 'width': width, 'count': 1, 'dtype': 'float32'}
86
+ return data, height, width, profile
87
+
88
+ raise ValueError(f"Cannot interpret data array of length {len(data_array)} as image")
89
+
90
+ except Exception as e4:
91
+ raise Exception(f"All fallback methods failed: MemoryFile({e1}), TempFile({e2}), TiffFallback({e3}), RawBytes({e4})")
92
+
93
+ def safe_resample_data(self, data, current_height, current_width, target_height, target_width, profile):
94
+ """
95
+ Safely resample data to target dimensions with fallback methods
96
+ """
97
+ if current_height == target_height and current_width == target_width:
98
+ return data
99
+
100
+ try:
101
+ # Method 1: Use rasterio resampling
102
+ temp_profile = profile.copy()
103
+ temp_profile.update({
104
+ 'height': current_height,
105
+ 'width': current_width,
106
+ 'count': 1,
107
+ 'dtype': 'float32'
108
+ })
109
+
110
+ with MemoryFile() as memfile:
111
+ with memfile.open(**temp_profile) as temp_dataset:
112
+ temp_dataset.write(data, 1)
113
+
114
+ resampled = temp_dataset.read(
115
+ out_shape=(1, target_height, target_width),
116
+ resampling=Resampling.bilinear
117
+ )[0].astype(np.float32)
118
+
119
+ return resampled
120
+
121
+ except Exception as e1:
122
+ print(f"Rasterio resampling failed: {e1}")
123
+ try:
124
+ # Method 2: Use scipy if available
125
+ from scipy import ndimage
126
+ zoom_factors = (target_height / current_height, target_width / current_width)
127
+ resampled = ndimage.zoom(data, zoom_factors, order=1)
128
+ return resampled.astype(np.float32)
129
+
130
+ except ImportError:
131
+ print("Scipy not available for resampling")
132
+ # Method 3: Simple nearest-neighbor resampling
133
+ h_indices = np.round(np.linspace(0, current_height - 1, target_height)).astype(int)
134
+ w_indices = np.round(np.linspace(0, current_width - 1, target_width)).astype(int)
135
+
136
+ resampled = data[np.ix_(h_indices, w_indices)]
137
+ return resampled.astype(np.float32)
138
+
139
+ except Exception as e2:
140
+ print(f"Scipy resampling failed: {e2}")
141
+ # Method 3: Simple nearest-neighbor resampling
142
+ h_indices = np.round(np.linspace(0, current_height - 1, target_height)).astype(int)
143
+ w_indices = np.round(np.linspace(0, current_width - 1, target_width)).astype(int)
144
+
145
+ resampled = data[np.ix_(h_indices, w_indices)]
146
+ return resampled.astype(np.float32)
147
 
148
+ def execute(self, requests):
149
+ """
150
+ Process inference requests with robust error handling.
151
+ """
152
+ responses = []
153
 
154
+ for request in requests:
155
+ try:
156
+ input_tensor = pb_utils.get_input_tensor_by_name(request, "input_jp2_bytes")
157
+ jp2_bytes_list = input_tensor.as_numpy()
158
+
159
+ if len(jp2_bytes_list) != 3:
160
+ error_msg = f"Expected 3 JP2 byte strings, received {len(jp2_bytes_list)}"
161
+ error = pb_utils.TritonError(error_msg)
162
+ response = pb_utils.InferenceResponse(output_tensors=[], error=error)
163
+ responses.append(response)
164
+ continue
165
+
166
+ red_bytes = jp2_bytes_list[0]
167
+ green_bytes = jp2_bytes_list[1]
168
+ nir_bytes = jp2_bytes_list[2]
169
+
170
+ print(f"Processing JP2 data - sizes: Red={len(red_bytes)}, Green={len(green_bytes)}, NIR={len(nir_bytes)}")
171
+
172
+ # Read red band data (use as reference for dimensions)
173
+ red_data, target_height, target_width, red_profile = self.safe_read_jp2_bytes(red_bytes)
174
+ print(f"Red band: {red_data.shape}, target dimensions: {target_height}x{target_width}")
175
+
176
+ # Read and resample green band
177
+ green_data, green_height, green_width, green_profile = self.safe_read_jp2_bytes(green_bytes)
178
+ green_data = self.safe_resample_data(green_data, green_height, green_width, target_height, target_width, green_profile)
179
+ print(f"Green band after resampling: {green_data.shape}")
180
+
181
+ # Read and resample NIR band
182
+ nir_data, nir_height, nir_width, nir_profile = self.safe_read_jp2_bytes(nir_bytes)
183
+ nir_data = self.safe_resample_data(nir_data, nir_height, nir_width, target_height, target_width, nir_profile)
184
+ print(f"NIR band after resampling: {nir_data.shape}")
185
+
186
+ # Verify all bands have the same shape
187
+ if not (red_data.shape == green_data.shape == nir_data.shape):
188
+ shapes = [red_data.shape, green_data.shape, nir_data.shape]
189
+ error_msg = f"Band shape mismatch after resampling: {shapes}"
190
+ error = pb_utils.TritonError(error_msg)
191
+ response = pb_utils.InferenceResponse(output_tensors=[], error=error)
192
+ responses.append(response)
193
+ continue
194
+
195
+ # Stack bands in CHW format for prediction (channels, height, width)
196
+ prediction_array = np.stack([red_data, green_data, nir_data], axis=0)
197
+ print(f"Final prediction array shape: {prediction_array.shape}")
198
+
199
+ # Run cloud detection prediction
200
+ cloud_mask = predict_from_array(prediction_array)
201
+ print(f"Cloud mask shape: {cloud_mask.shape}")
202
+
203
+ # Flatten the mask for output
204
+ if cloud_mask.ndim > 1:
205
+ cloud_mask = cloud_mask.flatten()
206
+
207
+ # Create output tensor (config expects TYPE_UINT8)
208
+ output_tensor = pb_utils.Tensor("output_mask", cloud_mask.astype(np.uint8))
209
+ response = pb_utils.InferenceResponse(output_tensors=[output_tensor])
210
+ responses.append(response)
211
 
212
  except Exception as e:
213
+ # Enhanced error reporting
214
+ error_msg = f"Error processing JP2 data: {str(e)}"
215
+ print(f"Model execution error: {error_msg}")
216
+ error = pb_utils.TritonError(error_msg)
217
  response = pb_utils.InferenceResponse(output_tensors=[], error=error)
218
+ responses.append(response)
219
 
 
 
 
220
  return responses
221
 
222
  def finalize(self):
223
  """
224
+ Clean up when the model is unloaded.
225
  """
226
+ print('Cloud Detection model finalized')
 
requirements.txt CHANGED
@@ -5,3 +5,5 @@ timm>=0.9
5
  tqdm>=4.0
6
  gdown>=5.1.0
7
  torch>=2.2
 
 
 
5
  tqdm>=4.0
6
  gdown>=5.1.0
7
  torch>=2.2
8
+ scipy>=1.9.0
9
+ numpy>=1.21.0