File size: 4,522 Bytes
36d88d2
 
 
024f5b3
 
 
36d88d2
 
 
024f5b3
 
 
0d9b472
 
 
024f5b3
d03e9df
 
0d9b472
d03e9df
 
0d9b472
d03e9df
0d9b472
 
 
 
 
 
 
 
 
 
 
d03e9df
0d9b472
 
 
 
d03e9df
0d9b472
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d03e9df
e47d31c
0d9b472
 
 
 
 
 
 
 
d03e9df
0d9b472
 
 
d03e9df
0d9b472
 
d03e9df
0d9b472
 
 
 
 
 
024f5b3
 
0d9b472
 
024f5b3
 
0d9b472
 
 
36d88d2
 
 
024f5b3
0d9b472
024f5b3
0d9b472
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import numpy as np
import triton_python_backend_utils as pb_utils
from omnicloudmask import predict_from_array
import rasterio
from rasterio.io import MemoryFile
from rasterio.enums import Resampling

class TritonPythonModel:
    """Triton Python-backend model that runs cloud detection on JP2 band inputs.

    Each request carries three in-memory JP2 images in the ``input_jp2_bytes``
    tensor. Their order is assumed to be Red, Green, NIR; that order is set by
    the client and is not verified here. The bands are decoded with rasterio and
    aligned to the Red band's pixel grid. They are then stacked into a
    (3, H, W) float32 array and passed to ``predict_from_array``. The resulting
    mask is returned as ``uint8`` in the ``output_mask`` tensor.
    """

    # Number of JP2 byte strings every request must provide (Red, Green, NIR).
    EXPECTED_BANDS = 3

    def initialize(self, args):
        """
        Initialize the model. This function is called once when the model is loaded.

        ``args`` is not used. rasterio must be installed in the Python backend
        environment for ``execute`` to work.
        """
        print('Initialized Cloud Detection model with JP2 input')

    def execute(self, requests):
        """
        Process inference requests.

        Returns one ``InferenceResponse`` per request, in request order. A
        failure while handling one request produces an error response for that
        request only; the remaining requests in the batch are still processed.
        """
        return [self._process_request(request) for request in requests]

    def finalize(self):
        """
        Called when the model is unloaded. Perform any necessary cleanup.
        """
        print('Finalizing Cloud Detection model')

    # ------------------------------------------------------------------ helpers

    def _process_request(self, request):
        """Build the (success or error) response for a single request."""
        input_tensor = pb_utils.get_input_tensor_by_name(request, "input_jp2_bytes")
        if input_tensor is None:
            # Without this guard, ``None.as_numpy()`` would raise outside any
            # handler and fail every request in the batch.
            return self._error_response("Missing required input tensor 'input_jp2_bytes'")

        try:
            red_bytes, green_bytes, nir_bytes = self._unpack_band_bytes(input_tensor.as_numpy())
        except ValueError as e:
            return self._error_response(str(e))

        try:
            input_array = self._decode_bands(red_bytes, green_bytes, nir_bytes)
            pred_mask = predict_from_array(input_array)
            output_tensor = pb_utils.Tensor(
                "output_mask",
                pred_mask.astype(np.uint8)
            )
            return pb_utils.InferenceResponse([output_tensor])
        except Exception as e:
            # Request boundary: any decode/inference failure (e.g. invalid JP2
            # data) becomes an error response for this request only.
            return self._error_response(f"Error processing JP2 data: {e}")

    @classmethod
    def _unpack_band_bytes(cls, raw):
        """Return the (red, green, nir) byte strings held in ``raw``.

        ``raw`` is flattened first, so both a ``[3]`` and a ``[1, 3]`` tensor
        are accepted. The latter presumably appears when the model config adds
        a batch dimension (TODO confirm against config.pbtxt).

        Raises:
            ValueError: if ``raw`` does not contain exactly three elements.
        """
        flat = np.asarray(raw, dtype=object).reshape(-1)
        if flat.size != cls.EXPECTED_BANDS:
            raise ValueError(
                f"Expected {cls.EXPECTED_BANDS} JP2 byte strings, received {flat.size}"
            )
        return tuple(flat)

    @classmethod
    def _decode_bands(cls, red_bytes, green_bytes, nir_bytes):
        """Decode the three JP2 bands and stack them as a (3, H, W) float32 array.

        The Red band defines the target (H, W). Green and NIR are resampled
        bilinearly onto that grid when their size differs. The client is
        expected to send NIR as B8A and Red/Green as B04/B03; NIR may therefore
        be at a different resolution (assumption — TODO confirm).
        """
        red = cls._read_band(red_bytes)
        height, width = red.shape
        green = cls._read_band(green_bytes, height, width)
        nir = cls._read_band(nir_bytes, height, width)
        # Channel order (Red, Green, NIR) must match what predict_from_array expects.
        return np.stack([red, green, nir], axis=0)

    @staticmethod
    def _read_band(jp2_bytes, out_height=None, out_width=None):
        """Read band 1 of an in-memory JP2 image as a 2D float32 array.

        If a target size is given and differs from the image size, the band is
        resampled bilinearly to (out_height, out_width). Otherwise it is read at
        native resolution.
        """
        with MemoryFile(jp2_bytes) as memfile, memfile.open() as src:
            native = (src.height, src.width) == (out_height, out_width)
            if out_height is None or out_width is None or native:
                band = src.read(1)
            else:
                band = src.read(
                    1,
                    out_shape=(1, out_height, out_width),
                    resampling=Resampling.bilinear,
                )
        return band.astype(np.float32)

    @staticmethod
    def _error_response(message):
        """Wrap ``message`` in an ``InferenceResponse`` carrying a ``TritonError``."""
        error = pb_utils.TritonError(message)
        return pb_utils.InferenceResponse(output_tensors=[], error=error)