|
import numpy as np |
|
import triton_python_backend_utils as pb_utils |
|
from omnicloudmask import predict_from_array |
|
import rasterio |
|
from rasterio.io import MemoryFile |
|
from rasterio.enums import Resampling |
|
|
|
class TritonPythonModel:
    """Triton Python-backend model that produces cloud masks from JP2 band images.

    Each request must carry an input tensor named ``input_jp2_bytes`` holding
    exactly three encoded JP2 images, in the order red, green, NIR. The green
    and NIR bands are resampled (bilinear) onto the red band's pixel grid when
    their size differs, stacked into a ``(3, height, width)`` float32 array and
    passed to ``predict_from_array``. The prediction is returned as the
    ``output_mask`` tensor, cast to ``uint8``.
    """

    # Name of the request input and the number of encoded bands it must hold.
    _INPUT_NAME = "input_jp2_bytes"
    _EXPECTED_BANDS = 3

    def initialize(self, args):
        """Initialize the model. Called once by Triton when the model is loaded.

        ``args`` (the configuration/metadata Triton passes in) is not used;
        no state is prepared here.
        """
        print('Initialized Cloud Detection model with JP2 input')

    def execute(self, requests):
        """Process a batch of inference requests.

        Returns one ``InferenceResponse`` per request, in request order.
        Problems with an individual request (missing input, wrong number of
        bands, undecodable JP2, prediction failure) are reported as an error
        response for that request so the remaining requests still run.
        """
        return [self._handle_request(request) for request in requests]

    def finalize(self):
        """Called when the model is unloaded. There is nothing to release."""
        print('Finalizing Cloud Detection model')

    def _handle_request(self, request):
        """Validate one request, run cloud detection on it and build its response."""
        input_tensor = pb_utils.get_input_tensor_by_name(request, self._INPUT_NAME)
        # An absent input comes back as None; calling .as_numpy() on it would
        # raise outside any handler and take down the whole batch.
        if input_tensor is None:
            return self._error_response(
                f"Missing required input tensor '{self._INPUT_NAME}'"
            )

        jp2_bytes_list = self._flatten_jp2_inputs(input_tensor.as_numpy())
        if len(jp2_bytes_list) != self._EXPECTED_BANDS:
            return self._error_response(
                f"Expected {self._EXPECTED_BANDS} JP2 byte strings, "
                f"received {len(jp2_bytes_list)}"
            )
        red_bytes, green_bytes, nir_bytes = jp2_bytes_list

        try:
            # The red band defines the pixel grid the other bands are mapped onto.
            red_data = self._read_band(red_bytes)
            target_shape = red_data.shape
            green_data = self._read_band(green_bytes, target_shape)
            nir_data = self._read_band(nir_bytes, target_shape)

            input_array = np.stack([red_data, green_data, nir_data], axis=0)
            pred_mask = predict_from_array(input_array)

            output_tensor = pb_utils.Tensor("output_mask", pred_mask.astype(np.uint8))
            return pb_utils.InferenceResponse([output_tensor])
        except Exception as e:  # request boundary: report the failure, keep serving
            return self._error_response(f"Error processing JP2 data: {e}")

    @staticmethod
    def _flatten_jp2_inputs(raw):
        """Return the encoded band images held in ``raw`` as a flat list.

        ``raw`` is the array from ``as_numpy()`` on the input tensor.
        Flattening accepts a ``[3]`` tensor as well as a ``[1, 3]`` tensor
        (presumably what arrives when the model config declares a batch
        dimension -- confirm against config.pbtxt) and turns a 0-d array into a
        one-element list instead of raising.
        """
        return list(np.ravel(raw))

    @staticmethod
    def _read_band(jp2_bytes, target_shape=None):
        """Decode band 1 of an in-memory JP2 image as a float32 array.

        When ``target_shape`` ``(height, width)`` is given and differs from the
        image's native size, the band is resampled bilinearly to that size;
        otherwise it is read at native resolution.
        """
        with MemoryFile(jp2_bytes) as memfile, memfile.open() as src:
            if target_shape is None or (src.height, src.width) == tuple(target_shape):
                band = src.read(1)
            else:
                height, width = target_shape
                band = src.read(
                    1,
                    out_shape=(1, height, width),
                    resampling=Resampling.bilinear,
                )
        return band.astype(np.float32)

    @staticmethod
    def _error_response(message):
        """Build an ``InferenceResponse`` that carries ``message`` as a ``TritonError``."""
        return pb_utils.InferenceResponse(
            output_tensors=[], error=pb_utils.TritonError(message)
        )
|
|
|
|