cloud-detection / model.py
erfaneshrati's picture
jp2 as input type
024f5b3
raw
history blame contribute delete
4.52 kB
import numpy as np
import triton_python_backend_utils as pb_utils
from omnicloudmask import predict_from_array
import rasterio
from rasterio.io import MemoryFile
from rasterio.enums import Resampling
class TritonPythonModel:
    """Triton Python-backend model running OmniCloudMask on JP2-encoded bands.

    Every request must carry a string input named ``input_jp2_bytes`` holding
    exactly three JP2-encoded rasters, in the order Red, Green, NIR. Each
    raster is decoded in memory with rasterio. Green and NIR are resampled
    (bilinear) onto the Red band's grid when their size differs. The stacked
    (3, H, W) float32 array is passed to ``predict_from_array``, and the
    prediction is returned as a uint8 tensor named ``output_mask``.
    """

    # Tensor names used for this model's input and output in Triton.
    _INPUT_NAME = "input_jp2_bytes"
    _OUTPUT_NAME = "output_mask"
    # Number of encoded bands expected per request (Red, Green, NIR).
    _NUM_BANDS = 3

    def initialize(self, args):
        """Initialize the model; called once by Triton when the model is loaded.

        Args:
            args: Model metadata supplied by Triton (unused here).
        """
        # Models or other resources could be loaded here if needed.
        # rasterio must be installed in the Python backend environment.
        print('Initialized Cloud Detection model with JP2 input')

    def execute(self, requests):
        """Run cloud detection on a list of inference requests.

        Args:
            requests: Sequence of ``pb_utils.InferenceRequest`` objects.

        Returns:
            A list with exactly one ``pb_utils.InferenceResponse`` per request,
            in request order. Failures are reported per request through the
            response's ``error`` field, so one malformed request does not
            abort the whole batch.
        """
        return [self._handle_request(request) for request in requests]

    def _handle_request(self, request):
        """Build the response for one request; any failure becomes an error response."""
        try:
            input_tensor = pb_utils.get_input_tensor_by_name(request, self._INPUT_NAME)
            if input_tensor is None:
                return self._error_response(
                    f"Missing required input tensor '{self._INPUT_NAME}'"
                )

            # as_numpy() on a string tensor yields an ndarray of bytes objects.
            # Flatten it so both an unbatched [3] and a batched [1, 3] shape
            # are accepted.
            jp2_bytes_list = input_tensor.as_numpy().reshape(-1)
            if len(jp2_bytes_list) != self._NUM_BANDS:
                return self._error_response(
                    f"Expected {self._NUM_BANDS} JP2 byte strings, "
                    f"received {len(jp2_bytes_list)}"
                )

            # Order is assumed to be Red, Green, NIR, as sent by the client.
            red_bytes, green_bytes, nir_bytes = jp2_bytes_list

            # Red defines the output grid. Green and NIR are resampled onto it
            # when their native size differs. Presumably the client sends
            # Sentinel-2 B04/B03 (10 m) and B8A (20 m); confirm with the client.
            red_data, target_shape = self._read_band(red_bytes)
            green_data, _ = self._read_band(green_bytes, target_shape)
            nir_data, _ = self._read_band(nir_bytes, target_shape)

            # CHW stack; channel order must match what predict_from_array expects.
            input_array = np.stack([red_data, green_data, nir_data], axis=0)
            pred_mask = predict_from_array(input_array)

            output_tensor = pb_utils.Tensor(self._OUTPUT_NAME, pred_mask.astype(np.uint8))
            return pb_utils.InferenceResponse(output_tensors=[output_tensor])
        except Exception as e:
            # Top-level boundary: report per request (e.g. invalid JP2 data)
            # instead of failing every request in the batch.
            return self._error_response(f"Error processing JP2 data: {e}")

    @staticmethod
    def _read_band(jp2_bytes, target_shape=None):
        """Decode the first band of an in-memory JP2 raster as float32.

        Args:
            jp2_bytes: Encoded raster bytes.
            target_shape: Optional ``(height, width)``. When given and
                different from the raster's native size, the band is
                resampled to it with bilinear interpolation.

        Returns:
            Tuple ``(data, native_shape)``: the 2-D float32 band and the
            raster's native ``(height, width)``.
        """
        with MemoryFile(jp2_bytes) as memfile, memfile.open() as src:
            native_shape = (src.height, src.width)
            if target_shape is None or tuple(target_shape) == native_shape:
                data = src.read(1)
            else:
                # With a single int band index, rasterio returns a 2-D array
                # even though out_shape includes the band axis.
                data = src.read(
                    1,
                    out_shape=(1, *target_shape),
                    resampling=Resampling.bilinear,
                )
            return data.astype(np.float32), native_shape

    @staticmethod
    def _error_response(message):
        """Return an InferenceResponse that carries only a TritonError with *message*."""
        return pb_utils.InferenceResponse(
            output_tensors=[], error=pb_utils.TritonError(message)
        )

    def finalize(self):
        """Called once by Triton when the model is unloaded; perform any cleanup."""
        print('Finalizing Cloud Detection model')