File size: 13,715 Bytes
8fded10
a809e1c
 
 
 
 
 
4e79cd8
8fded10
 
 
 
 
a809e1c
8fded10
abfc282
 
 
 
 
 
 
 
 
a809e1c
abfc282
a809e1c
abfc282
 
a809e1c
abfc282
 
 
 
 
 
 
 
 
 
 
a809e1c
 
 
 
abfc282
 
 
 
 
 
 
 
a809e1c
 
abfc282
 
 
 
 
a809e1c
abfc282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a809e1c
 
 
 
abfc282
 
a809e1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abfc282
 
4e26f10
abfc282
4e26f10
abfc282
 
20acb0b
 
 
 
 
abfc282
 
 
 
 
 
a809e1c
abfc282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20acb0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abfc282
 
 
 
 
 
 
20acb0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abfc282
 
 
 
 
20acb0b
abfc282
 
 
 
 
 
 
 
 
 
 
 
 
a809e1c
abfc282
a809e1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abfc282
a809e1c
 
 
abfc282
 
a809e1c
abfc282
 
a809e1c
abfc282
 
 
a809e1c
 
abfc282
4e79cd8
20acb0b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
import streamlit as st
import sys
import os

# Add the parent directory to sys.path to allow imports from 'src'
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import h5py
import torch
import numpy as np
import matplotlib.pyplot as plt
import yaml
import os
import io

# Import models
from src.mobilenetv2_model import LandslideModel as MobileNetV2Model
from src.vgg16_model import LandslideModel as VGG16Model
from src.resnet34_model import LandslideModel as ResNet34Model
from src.efficientnetb0_model import LandslideModel as EfficientNetB0Model
from src.mitb1_model import LandslideModel as MiTB1Model
from src.inceptionv4_model import LandslideModel as InceptionV4Model
from src.densenet121_model import LandslideModel as DenseNet121Model
from src.deeplabv3plus_model import LandslideModel as DeepLabV3PlusModel
from src.resnext50_32x4d_model import LandslideModel as ResNeXt50Model
from src.se_resnet50_model import LandslideModel as SEResNet50Model
from src.se_resnext50_32x4d_model import LandslideModel as SEResNeXt50Model
from src.segformer_model import LandslideModel as SegFormerB2Model
from src.inceptionresnetv2_model import LandslideModel as InceptionResNetV2Model
from src.model_downloader import ModelDownloader

# Registry of selectable models, in sidebar order. Each entry gives the display
# name, the architecture type written into config['model_config']['model_type'],
# and, only where it differs from the registry key, the key passed to
# ModelDownloader.download_model.
AVAILABLE_MODELS = {
    "mobilenetv2": dict(name="MobileNetV2", type="mobilenet_v2"),
    "vgg16": dict(name="VGG16", type="vgg16"),
    "resnet34": dict(name="ResNet34", type="resnet34"),
    "efficientnetb0": dict(name="EfficientNetB0", type="efficientnet_b0"),
    "mitb1": dict(name="MiTB1", type="mitb1"),
    "inceptionv4": dict(name="InceptionV4", type="inception_v4"),
    "densenet121": dict(name="DenseNet121", type="densenet121"),
    "deeplabv3plus": dict(name="DeepLabV3Plus", type="deeplabv3plus"),
    "resnext50": dict(name="ResNeXt50", type="resnext50_32x4d", downloader_key="resnext50_32x4d"),
    "seresnet50": dict(name="SEResNet50", type="se_resnet50", downloader_key="se_resnet50"),
    "seresnext50": dict(name="SEResNeXt50", type="se_resnext50_32x4d", downloader_key="se_resnext50_32x4d"),
    "segformerb2": dict(name="SegFormerB2", type="segformer_b2", downloader_key="segformer"),
    "inceptionresnetv2": dict(name="InceptionResNetV2", type="inception_resnet_v2"),
}


def _describe_model(model_key, model_info):
    """Build the sidebar/metadata entry for one registry item.

    The downloader key falls back to the registry key when the entry has none.
    """
    display_name = model_info["name"]
    return dict(
        type=model_info["type"],
        description=f"{display_name} - A model for landslide detection and segmentation.",
        name=display_name,
        downloader_key=model_info.get("downloader_key", model_key),
    )


# Per-model metadata shown in the sidebar and passed to process_and_visualize,
# keyed (and ordered) exactly like AVAILABLE_MODELS.
MODEL_DESCRIPTIONS = {key: _describe_model(key, info) for key, info in AVAILABLE_MODELS.items()}

# Inline YAML configuration, embedded in this script (not read from a file)
# and parsed once when the script runs.
# In this file, process_and_visualize copies the whole dict, overrides
# model_config.model_type per model and passes it to the model constructor;
# process_h5_file reads dataset_config.channels as 1-based band indices.
# Other sections (train_config, logging_config, ...) are not read directly
# here -- NOTE(review): the model classes receive the whole dict and may
# read them; confirm before pruning.
config_str = """
model_config:
  model_type: "mobilenet_v2"
  in_channels: 14
  num_classes: 1
  encoder_weights: "imagenet"
  wce_weight: 0.5

dataset_config:
  num_classes: 1
  num_channels: 14
  channels: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
  normalize: False

train_config:
  dataset_path: ""
  checkpoint_path: "checkpoints"
  seed: 42
  train_val_split: 0.8
  batch_size: 16
  num_epochs: 100
  lr: 0.001
  device: "cuda:0"
  save_config: True
  experiment_name: "mobilenet_v2"

logging_config:
  wandb_project: "l4s"
  wandb_entity: "Silvamillion"
"""

# safe_load parses plain YAML types only (no arbitrary Python object construction).
config = yaml.safe_load(config_str)

def process_and_visualize(model_key, model_info, image_tensor, original_image, uploaded_file_name):
    """
    Run one model on a prepared image and render the results in the Streamlit page.

    Args:
        model_key: Key into AVAILABLE_MODELS; selects the model class and is
            embedded in the download file name.
        model_info: Entry from MODEL_DESCRIPTIONS. 'name' and 'type' are read;
            'downloader_key' (falling back to model_key) selects the checkpoint.
        image_tensor: Model input; the callers in this file pass a float tensor
            shaped (1, C, H, W).
        original_image: The same image as a (C, H, W) numpy array, used only
            for display.
        uploaded_file_name: Name of the source file; its stem prefixes the
            downloadable prediction file name.

    Any exception is reported in the app (message plus traceback) instead of
    being raised, so one failing model does not stop the remaining ones.
    """
    import copy
    import traceback

    try:
        st.write(f"Using model: {model_info['name']}")

        # Deep copy: config.copy() is shallow and would share the nested
        # 'model_config' dict, so this assignment would leak into the global
        # config seen by every later call.
        current_config = copy.deepcopy(config)
        current_config['model_config']['model_type'] = model_info['type']

        # Model classes are imported at module level as "<DisplayName>Model".
        model_class_name = AVAILABLE_MODELS[model_key]['name'].replace('-', '') + 'Model'
        model_class = globals().get(model_class_name)
        if model_class is None:
            raise KeyError(f"No imported model class named {model_class_name!r} for {model_key!r}")

        downloader = ModelDownloader()
        download_key = model_info.get('downloader_key', model_key)
        model_path = downloader.download_model(download_key)
        st.info(f"Using model from: {model_path}")

        model = model_class(current_config)
        # NOTE(review): torch.load unpickles the checkpoint, so it must come
        # from a trusted source; consider weights_only=True if the installed
        # torch version supports it.
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        # strict=False tolerates key mismatches; surface them so weights that
        # silently failed to load are not mistaken for a real prediction.
        load_result = model.load_state_dict(state_dict, strict=False)
        missing = getattr(load_result, 'missing_keys', None)
        unexpected = getattr(load_result, 'unexpected_keys', None)
        if missing or unexpected:
            st.warning(
                f"{model_info['name']}: {len(missing or [])} missing and "
                f"{len(unexpected or [])} unexpected keys while loading weights."
            )
        model.eval()

        with torch.no_grad():
            logits = model(image_tensor)
            prediction = torch.sigmoid(logits).cpu().numpy()
        prob_map = prediction.squeeze()

        st.header(f"Prediction Results - {model_info['name']}")

        # Min-max scale all channels jointly for display. A constant image
        # would divide by zero (NaNs), so it is shown as black instead.
        img_display = original_image.transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
        lo, hi = img_display.min(), img_display.max()
        if hi > lo:
            img_display = (img_display - lo) / (hi - lo)
        else:
            img_display = np.zeros(img_display.shape)
        rgb = img_display[:, :, :3]  # first three channels displayed as RGB

        fig, ax = plt.subplots(1, 3, figsize=(15, 5))
        try:
            ax[0].imshow(rgb)
            ax[0].set_title("Input Image")
            ax[0].axis('off')

            ax[1].imshow(prob_map, cmap='plasma')  # Raw prediction map
            ax[1].set_title("Prediction Probability")
            ax[1].axis('off')

            ax[2].imshow(rgb)
            ax[2].imshow(prob_map > 0.5, cmap='plasma', alpha=0.4)  # Overlay
            ax[2].set_title("Overlay (Threshold > 0.5)")
            ax[2].axis('off')

            st.pyplot(fig)
        finally:
            # Release the figure even if plotting fails.
            plt.close(fig)

        # np.save writes the .npy header (dtype and shape); raw tobytes()
        # would produce a file that np.load cannot read.
        npy_buffer = io.BytesIO()
        np.save(npy_buffer, prob_map)
        # splitext keeps dotted stems intact ("a.b.h5" -> "a.b").
        file_stem = os.path.splitext(uploaded_file_name)[0]

        st.write(f"Download the prediction as a .npy file for {model_info['name']}:")
        st.download_button(
            label=f"Download Prediction - {model_info['name']}",
            data=npy_buffer.getvalue(),
            file_name=f"{file_stem}_{model_key}_prediction.npy",
            mime="application/octet-stream"
        )

    except Exception as e:
        st.error(f"Error with model {model_info['name']}: {str(e)}")
        st.error(traceback.format_exc())

# Streamlit app
st.set_page_config(page_title="DeepSlide: Landslide Detection", layout="wide")

st.title("DeepSlide: Landslide Detection")
st.markdown("""
## Instructions
1. **Model Selection**: Choose a single model from the sidebar or select "Run all models".
2. **Data Input**: 
   - Try an example image from the dropdown, or
   - Upload your own .h5 files
3. **Results**: View predictions and download results as .npy files.
""")

# Sidebar for model selection. model_option, selected_model_key and (in the
# single-model case) selected_model_info are read later when files are processed.
st.sidebar.title("Model Selection")
model_option = st.sidebar.radio("Choose an option", ["Select a single model", "Run all models"])

selected_model_key = None
if model_option == "Select a single model":
    selected_model_key = st.sidebar.selectbox("Select Model", list(MODEL_DESCRIPTIONS.keys()))
    selected_model_info = MODEL_DESCRIPTIONS[selected_model_key]

    # Display model details in the sidebar
    st.sidebar.markdown("### Model Details")
    st.sidebar.markdown(f"**Model Name:** {selected_model_info['name']}")
    st.sidebar.markdown(f"**Model Type:** {selected_model_info['type']}")
    st.sidebar.markdown(f"**Description:** {selected_model_info['description']}")

# Main content
st.header("Upload Data")

# Initialize session state for error tracking if not exists.
# NOTE(review): upload_errors is not read elsewhere in this file.
if 'upload_errors' not in st.session_state:
    st.session_state.upload_errors = []

# Example images: *.h5 files in an "examples" directory next to the directory
# containing this script. abspath keeps this correct when __file__ is relative.
st.subheader("Try Example Images")
examples_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "examples")
example_files = []

try:
    if os.path.isdir(examples_dir):
        example_files = sorted(f for f in os.listdir(examples_dir) if f.endswith('.h5'))
except OSError as exc:
    # Examples are optional: report the problem instead of silently hiding it,
    # but keep the app usable.
    st.warning(f"Could not read examples directory {examples_dir}: {exc}")

if example_files:
    selected_example = st.selectbox(
        "Select an example image to test:",
        options=["None"] + example_files,
        help="Choose an example .h5 file to quickly test the models"
    )
else:
    st.info("No example files found")
    selected_example = "None"

# File upload section
st.subheader("Upload Your Own Files")
uploaded_files = st.file_uploader(
    "Choose .h5 files...",
    type="h5",
    accept_multiple_files=True,
    help="Upload your .h5 files here. Maximum file size is 200MB."
)

# Number of bands expected in an .h5 'img' dataset; its position in the shape
# distinguishes channel-first (C, H, W) from channel-last (H, W, C) arrays.
_H5_NUM_BANDS = 14


def _select_bands(data):
    """Return the configured bands of a 3-D array as an (H, W, len(channels)) float array.

    The layout is detected from where the _H5_NUM_BANDS axis sits: axis 0 means
    (C, H, W), axis 2 means (H, W, C). Any other shape is warned about and
    treated as (C, H, W). Configured bands the array does not have stay zero.
    The spatial size is taken from the data rather than assumed to be 128x128.
    """
    channels = config["dataset_config"]["channels"]  # 1-based band indices
    if data.shape[0] == _H5_NUM_BANDS:
        channel_axis = 0
    elif data.shape[2] == _H5_NUM_BANDS:
        channel_axis = 2
    else:
        st.warning(f"Unexpected data shape: {data.shape}. Assuming (C, H, W).")
        channel_axis = 0

    # Move the band axis last so both layouts are indexed the same way.
    bands_last = np.moveaxis(data, channel_axis, -1)
    height, width, available = bands_last.shape
    image = np.zeros((height, width, len(channels)))
    for i, band in enumerate(channels):
        if band - 1 < available:
            image[:, :, i] = bands_last[:, :, band - 1]
    return image


def _run_selected_models(image_tensor, image_display, file_name):
    """Run the sidebar-selected model, or every registered model, on one image."""
    if model_option == "Select a single model":
        process_and_visualize(selected_model_key, selected_model_info, image_tensor, image_display, file_name)
    else:
        for model_key, model_info in MODEL_DESCRIPTIONS.items():
            process_and_visualize(model_key, model_info, image_tensor, image_display, file_name)


def process_h5_file(file_path, file_name):
    """Load an .h5 image, run the selected model(s) on it and display the results.

    Args:
        file_path: Anything h5py.File accepts -- a filesystem path or a binary
            file-like object such as io.BytesIO.
        file_name: Display name used in messages and download file names.

    The file must contain a 3-D 'img' dataset; NaNs are replaced by 1e-6 before
    inference. Errors are reported in the app (with traceback), not raised.
    """
    import traceback

    try:
        with h5py.File(file_path, 'r') as hdf:
            if 'img' not in hdf:
                st.error(f"Error: 'img' dataset not found in {file_name}")
                return
            data = np.array(hdf.get('img'))
        # The file is closed here; inference only needs the in-memory copy.

        data[np.isnan(data)] = 0.000001

        if data.ndim != 3:
            st.error(f"Data has {data.ndim} dimensions, expected 3.")
            return

        image_display = _select_bands(data).transpose(2, 0, 1)  # (C, H, W)
        image_tensor = torch.from_numpy(image_display).unsqueeze(0).float()  # (1, C, H, W)
        _run_selected_models(image_tensor, image_display, file_name)

    except Exception as e:
        st.error(f"Error processing file {file_name}: {str(e)}")
        st.error(traceback.format_exc())

# Process example file if selected
if selected_example != "None":
    st.write(f"Processing example: {selected_example}")
    example_path = os.path.join(examples_dir, selected_example)
    with st.spinner(f'Processing {selected_example}...'):
        process_h5_file(example_path, selected_example)

# Process uploaded files. Uploads are held in memory, so hand h5py an in-memory
# file object and reuse the same parsing/inference path as the examples
# (process_h5_file reports its own errors).
if uploaded_files:
    for uploaded_file in uploaded_files:
        st.write(f"Processing file: {uploaded_file.name}")
        st.write(f"File size: {uploaded_file.size} bytes")

        with st.spinner('Processing...'):
            process_h5_file(io.BytesIO(uploaded_file.getvalue()), uploaded_file.name)

if selected_example != "None" or uploaded_files:
    st.success('✅ Processing completed!')