File size: 4,522 Bytes
36d88d2 024f5b3 36d88d2 024f5b3 0d9b472 024f5b3 d03e9df 0d9b472 d03e9df 0d9b472 d03e9df 0d9b472 d03e9df 0d9b472 d03e9df 0d9b472 d03e9df e47d31c 0d9b472 d03e9df 0d9b472 d03e9df 0d9b472 d03e9df 0d9b472 024f5b3 0d9b472 024f5b3 0d9b472 36d88d2 024f5b3 0d9b472 024f5b3 0d9b472 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import numpy as np
import triton_python_backend_utils as pb_utils
from omnicloudmask import predict_from_array
import rasterio
from rasterio.io import MemoryFile
from rasterio.enums import Resampling
class TritonPythonModel:
    """Triton Python-backend model running OmniCloudMask cloud detection on JP2 bands.

    Each request carries one ``input_jp2_bytes`` tensor holding three encoded
    JP2 images in the order Red, Green, NIR (the order is fixed by the client).
    The bands are decoded in memory with rasterio. Green and NIR are bilinearly
    resampled onto the Red band's grid when their size differs (e.g. a coarser
    NIR band such as Sentinel-2 B8A). The stacked (3, H, W) float32 array is
    passed to ``predict_from_array``, and the result is returned as a uint8
    ``output_mask`` tensor.
    """

    def initialize(self, args):
        """Called once when Triton loads the model.

        No resources are prepared here. rasterio and omnicloudmask must be
        installed in the Python backend environment.
        """
        print('Initialized Cloud Detection model with JP2 input')

    def execute(self, requests):
        """Run cloud detection for every request in the batch.

        Triton expects exactly one response per request. Per-request problems
        (missing input, wrong band count, undecodable JP2, prediction failure)
        are therefore returned as error responses for that request only. They
        are never raised in a way that would fail the whole batch.

        Args:
            requests: list of ``pb_utils.InferenceRequest``.

        Returns:
            list of ``pb_utils.InferenceResponse`` in the same order as
            ``requests``.
        """
        return [self._handle_request(request) for request in requests]

    def _handle_request(self, request):
        """Build the inference response (or error response) for one request."""
        input_tensor = pb_utils.get_input_tensor_by_name(request, "input_jp2_bytes")
        if input_tensor is None:
            # get_input_tensor_by_name returns None for an absent input. Calling
            # as_numpy() on it would raise and abort every request in the batch.
            return self._error_response(
                'Missing required input tensor "input_jp2_bytes"'
            )

        # For TYPE_STRING inputs, as_numpy() yields an ndarray of Python bytes.
        jp2_bytes_list = input_tensor.as_numpy()
        if len(jp2_bytes_list) != 3:
            return self._error_response(
                f"Expected 3 JP2 byte strings, received {len(jp2_bytes_list)}"
            )

        # Order is fixed by the client: Red, Green, NIR.
        red_bytes, green_bytes, nir_bytes = jp2_bytes_list

        try:
            red_data = self._read_band(red_bytes)
            # Red defines the output grid; the other bands are aligned to it.
            target_shape = red_data.shape
            green_data = self._read_band(green_bytes, target_shape)
            nir_data = self._read_band(nir_bytes, target_shape)

            # CHW stack in the channel order expected by predict_from_array.
            input_array = np.stack([red_data, green_data, nir_data], axis=0)
            pred_mask = predict_from_array(input_array)

            output_tensor = pb_utils.Tensor("output_mask", pred_mask.astype(np.uint8))
            return pb_utils.InferenceResponse([output_tensor])
        except Exception as e:
            # Boundary handler: invalid JP2 data or model failures become an
            # error response for this request only.
            return self._error_response(f"Error processing JP2 data: {str(e)}")

    @staticmethod
    def _read_band(jp2_bytes, target_shape=None):
        """Decode band 1 of an in-memory JP2 image as a 2-D float32 array.

        Args:
            jp2_bytes: encoded JP2 file contents.
            target_shape: optional ``(height, width)``. When given and different
                from the image's native size, the band is bilinearly resampled
                to it. Otherwise the band is read at native resolution.

        Returns:
            The band as a float32 ndarray of shape ``(height, width)``.
        """
        with MemoryFile(jp2_bytes) as memfile, memfile.open() as src:
            if target_shape is None or tuple(target_shape) == (src.height, src.width):
                data = src.read(1)
            else:
                data = src.read(
                    1,
                    out_shape=(1, *target_shape),
                    resampling=Resampling.bilinear,
                )
        return data.astype(np.float32)

    @staticmethod
    def _error_response(message):
        """Return an InferenceResponse with no outputs carrying *message* as error."""
        return pb_utils.InferenceResponse(
            output_tensors=[], error=pb_utils.TritonError(message)
        )

    def finalize(self):
        """Called when Triton unloads the model; nothing to release."""
        print('Finalizing Cloud Detection model')
|