"""Streamlit app that runs landslide segmentation models on .h5 image files.

The user picks one model (or all of them) in the sidebar, then selects a
bundled example file and/or uploads .h5 files. For each file and model the app
shows the input, the predicted probability map and a thresholded overlay, and
offers the prediction as a downloadable .npy file.
"""

import copy
import io
import os
import sys
import traceback

# Add the parent directory to sys.path to allow imports from 'src'
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import h5py
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import yaml

# Import models
from src.mobilenetv2_model import LandslideModel as MobileNetV2Model
from src.vgg16_model import LandslideModel as VGG16Model
from src.resnet34_model import LandslideModel as ResNet34Model
from src.efficientnetb0_model import LandslideModel as EfficientNetB0Model
from src.mitb1_model import LandslideModel as MiTB1Model
from src.inceptionv4_model import LandslideModel as InceptionV4Model
from src.densenet121_model import LandslideModel as DenseNet121Model
from src.deeplabv3plus_model import LandslideModel as DeepLabV3PlusModel
from src.resnext50_32x4d_model import LandslideModel as ResNeXt50Model
from src.se_resnet50_model import LandslideModel as SEResNet50Model
from src.se_resnext50_32x4d_model import LandslideModel as SEResNeXt50Model
from src.segformer_model import LandslideModel as SegFormerB2Model
from src.inceptionresnetv2_model import LandslideModel as InceptionResNetV2Model
from src.model_downloader import ModelDownloader

# Define available models
AVAILABLE_MODELS = {
    "mobilenetv2": {"name": "MobileNetV2", "type": "mobilenet_v2"},
    "vgg16": {"name": "VGG16", "type": "vgg16"},
    "resnet34": {"name": "ResNet34", "type": "resnet34"},
    "efficientnetb0": {"name": "EfficientNetB0", "type": "efficientnet_b0"},
    "mitb1": {"name": "MiTB1", "type": "mitb1"},
    "inceptionv4": {"name": "InceptionV4", "type": "inception_v4"},
    "densenet121": {"name": "DenseNet121", "type": "densenet121"},
    "deeplabv3plus": {"name": "DeepLabV3Plus", "type": "deeplabv3plus"},
    "resnext50": {"name": "ResNeXt50", "type": "resnext50_32x4d", "downloader_key": "resnext50_32x4d"},
    "seresnet50": {"name": "SEResNet50", "type": "se_resnet50", "downloader_key": "se_resnet50"},
    "seresnext50": {"name": "SEResNeXt50", "type": "se_resnext50_32x4d", "downloader_key": "se_resnext50_32x4d"},
    "segformerb2": {"name": "SegFormerB2", "type": "segformer_b2", "downloader_key": "segformer"},
    "inceptionresnetv2": {"name": "InceptionResNetV2", "type": "inception_resnet_v2"},
}

# Explicit model-key -> class registry. Keys mirror AVAILABLE_MODELS; this
# replaces looking the class up by a name assembled at runtime in globals().
MODEL_CLASSES = {
    "mobilenetv2": MobileNetV2Model,
    "vgg16": VGG16Model,
    "resnet34": ResNet34Model,
    "efficientnetb0": EfficientNetB0Model,
    "mitb1": MiTB1Model,
    "inceptionv4": InceptionV4Model,
    "densenet121": DenseNet121Model,
    "deeplabv3plus": DeepLabV3PlusModel,
    "resnext50": ResNeXt50Model,
    "seresnet50": SEResNet50Model,
    "seresnext50": SEResNeXt50Model,
    "segformerb2": SegFormerB2Model,
    "inceptionresnetv2": InceptionResNetV2Model,
}

# Model descriptions with their respective types and descriptions
MODEL_DESCRIPTIONS = {
    model_key: {
        "type": model_info["type"],
        "description": f"{model_info['name']} - A model for landslide detection and segmentation.",
        "name": model_info["name"],
        "downloader_key": model_info.get("downloader_key", model_key),
    }
    for model_key, model_info in AVAILABLE_MODELS.items()
}

# Load the configuration file
config_str = """
model_config:
  model_type: "mobilenet_v2"
  in_channels: 14
  num_classes: 1
  encoder_weights: "imagenet"
  wce_weight: 0.5

dataset_config:
  num_classes: 1
  num_channels: 14
  channels: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
  normalize: False

train_config:
  dataset_path: ""
  checkpoint_path: "checkpoints"
  seed: 42
  train_val_split: 0.8
  batch_size: 16
  num_epochs: 100
  lr: 0.001
  device: "cuda:0"
  save_config: True
  experiment_name: "mobilenet_v2"

logging_config:
  wandb_project: "l4s"
  wandb_entity: "Silvamillion"
"""
config = yaml.safe_load(config_str)


def _normalize_for_display(image):
    """Min-max scale *image* to [0, 1]; a constant image maps to all zeros."""
    lowest = image.min()
    span = image.max() - lowest
    if span == 0:
        # Avoid 0/0 -> NaN, which matplotlib would render unpredictably.
        return np.zeros_like(image, dtype=float)
    return (image - lowest) / span


def process_and_visualize(model_key, model_info, image_tensor, original_image, uploaded_file_name):
    """
    Process the image with the selected model and visualize results.

    Args:
        model_key: Key of the model in AVAILABLE_MODELS / MODEL_CLASSES.
        model_info: Entry from MODEL_DESCRIPTIONS ('name', 'type' and an
            optional 'downloader_key' are read).
        image_tensor: Tensor passed to the model; callers in this file pass a
            float tensor of shape (1, C, H, W).
        original_image: Array in (C, H, W) order used for the display panels.
        uploaded_file_name: Source file name, used to name the download.

    Any exception is caught and reported in the UI (message plus traceback)
    instead of propagating, so one failing model does not stop the others.
    """
    try:
        st.write(f"Using model: {model_info['name']}")

        # Deep copy: a shallow copy would share the nested 'model_config'
        # dict, so the assignment below would mutate the module-level config.
        current_config = copy.deepcopy(config)
        current_config['model_config']['model_type'] = model_info['type']

        # Get the model class
        model_class = MODEL_CLASSES[model_key]

        # Download/get model path
        downloader = ModelDownloader()
        download_key = model_info.get('downloader_key', model_key)
        model_path = downloader.download_model(download_key)
        st.info(f"Using model from: {model_path}")

        # Load the model
        model = model_class(current_config)
        # NOTE(review): torch.load may unpickle arbitrary objects depending on
        # the installed torch version -- only load checkpoints from trusted
        # sources. strict=False also hides missing/unexpected weight keys.
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
        model.eval()

        # Make prediction
        with torch.no_grad():
            prediction = model(image_tensor)
            prediction = torch.sigmoid(prediction).cpu().numpy()
        probability_map = prediction.squeeze()

        # Display prediction
        st.header(f"Prediction Results - {model_info['name']}")

        # (C, H, W) -> (H, W, C), scaled for display
        img_display = _normalize_for_display(original_image.transpose(1, 2, 0))

        fig, ax = plt.subplots(1, 3, figsize=(15, 5))
        try:
            ax[0].imshow(img_display[:, :, :3])  # Display first three channels as RGB
            ax[0].set_title("Input Image")
            ax[0].axis('off')

            ax[1].imshow(probability_map, cmap='plasma')  # Raw prediction map
            ax[1].set_title("Prediction Probability")
            ax[1].axis('off')

            ax[2].imshow(img_display[:, :, :3])
            ax[2].imshow(probability_map > 0.5, cmap='plasma', alpha=0.4)  # Overlay
            ax[2].set_title("Overlay (Threshold > 0.5)")
            ax[2].axis('off')

            st.pyplot(fig)
        finally:
            # Always release the figure, even if rendering fails.
            plt.close(fig)

        # Download button
        st.write(f"Download the prediction as a .npy file for {model_info['name']}:")
        # np.save writes the NPY header (dtype/shape) so np.load can read the
        # file back; raw ndarray.tobytes() output is not a valid .npy file.
        npy_buffer = io.BytesIO()
        np.save(npy_buffer, probability_map)
        # splitext drops only the final extension, so "scene.v2.h5" keeps "scene.v2".
        base_name = os.path.splitext(uploaded_file_name)[0]
        st.download_button(
            label=f"Download Prediction - {model_info['name']}",
            data=npy_buffer.getvalue(),
            file_name=f"{base_name}_{model_key}_prediction.npy",
            mime="application/octet-stream"
        )

    except Exception as e:
        st.error(f"Error with model {model_info['name']}: {str(e)}")
        st.error(traceback.format_exc())


# Streamlit app
st.set_page_config(page_title="DeepSlide: Landslide Detection", layout="wide")

st.title("DeepSlide: Landslide Detection")
st.markdown("""
## Instructions
1. **Model Selection**: Choose a single model from the sidebar or select "Run all models".
2. **Data Input**:
   - Try an example image from the dropdown, or
   - Upload your own .h5 files
3. **Results**: View predictions and download results as .npy files.
""")

# Sidebar for model selection
st.sidebar.title("Model Selection")
model_option = st.sidebar.radio("Choose an option", ["Select a single model", "Run all models"])

selected_model_key = None
selected_model_info = None
if model_option == "Select a single model":
    selected_model_key = st.sidebar.selectbox("Select Model", list(MODEL_DESCRIPTIONS.keys()))
    selected_model_info = MODEL_DESCRIPTIONS[selected_model_key]

    # Display model details in the sidebar
    st.sidebar.markdown("### Model Details")
    st.sidebar.markdown(f"**Model Name:** {selected_model_info['name']}")
    st.sidebar.markdown(f"**Model Type:** {selected_model_info['type']}")
    st.sidebar.markdown(f"**Description:** {selected_model_info['description']}")

# Main content
st.header("Upload Data")

# Initialize session state for error tracking if not exists
if 'upload_errors' not in st.session_state:
    st.session_state.upload_errors = []

# Example images selection
st.subheader("Try Example Images")
examples_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "examples")
example_files = []
try:
    if os.path.exists(examples_dir):
        example_files = sorted(f for f in os.listdir(examples_dir) if f.endswith('.h5'))
except OSError as e:
    # Best effort: the app still works with uploads only, but say why the
    # examples are missing instead of failing silently.
    st.warning(f"Could not list example files: {e}")

if example_files:
    selected_example = st.selectbox(
        "Select an example image to test:",
        options=["None"] + example_files,
        help="Choose an example .h5 file to quickly test the models"
    )
else:
    st.info("No example files found")
    selected_example = "None"

# File upload section
st.subheader("Upload Your Own Files")
uploaded_files = st.file_uploader(
    "Choose .h5 files...",
    type="h5",
    accept_multiple_files=True,
    help="Upload your .h5 files here. Maximum file size is 200MB."
)


def _bands_to_image(data, channels):
    """Copy the configured bands of a 3-D array into a (128, 128, n) image.

    A leading axis of length 14 is read as (C, H, W), a trailing axis of
    length 14 as (H, W, C); any other shape triggers a UI warning and is read
    as (C, H, W), skipping bands beyond the first axis.
    """
    image = np.zeros((128, 128, len(channels)))
    if data.shape[0] == 14:  # (C, H, W)
        for i, band in enumerate(channels):
            image[:, :, i] = data[band - 1, :, :]
    elif data.shape[2] == 14:  # (H, W, C)
        for i, band in enumerate(channels):
            image[:, :, i] = data[:, :, band - 1]
    else:
        st.warning(f"Unexpected data shape: {data.shape}. Assuming (C, H, W).")
        for i, band in enumerate(channels):
            if band - 1 < data.shape[0]:
                image[:, :, i] = data[band - 1, :, :]
    return image


def process_h5_file(file_path, file_name):
    """Process a single h5 file.

    Args:
        file_path: Anything h5py.File can open -- a filesystem path or a
            binary file-like object such as io.BytesIO.
        file_name: Name used in UI messages and for the download file name.

    Reads the 'img' dataset, replaces NaNs, builds the model input and runs
    the selected model (or every model). Errors are reported in the UI and
    not raised.
    """
    try:
        with h5py.File(file_path, 'r') as hdf:
            if 'img' not in hdf:
                st.error(f"Error: 'img' dataset not found in {file_name}")
                return
            # np.array copies the dataset into memory, so the file can be
            # closed before any further processing.
            data = np.array(hdf.get('img'))

        data[np.isnan(data)] = 0.000001

        if data.ndim != 3:
            st.error(f"Data has {data.ndim} dimensions, expected 3.")
            return

        image = _bands_to_image(data, config["dataset_config"]["channels"])

        # Prepare for model (Batch, Channel, Height, Width)
        image_display = image.transpose(2, 0, 1)  # (C, H, W)
        image_tensor = torch.from_numpy(image_display).unsqueeze(0).float()  # (1, C, H, W)

        if model_option == "Select a single model":
            process_and_visualize(selected_model_key, selected_model_info, image_tensor, image_display, file_name)
        else:
            for model_key, model_info in MODEL_DESCRIPTIONS.items():
                process_and_visualize(model_key, model_info, image_tensor, image_display, file_name)

    except Exception as e:
        st.error(f"Error processing file {file_name}: {str(e)}")
        st.error(traceback.format_exc())


# Process example file if selected
if selected_example != "None":
    st.write(f"Processing example: {selected_example}")
    example_path = os.path.join(examples_dir, selected_example)
    with st.spinner(f'Processing {selected_example}...'):
        process_h5_file(example_path, selected_example)

# Process uploaded files
if uploaded_files:
    for uploaded_file in uploaded_files:
        st.write(f"Processing file: {uploaded_file.name}")
        st.write(f"File size: {uploaded_file.size} bytes")

        with st.spinner('Processing...'):
            # Hand the upload to the shared h5 pipeline as an in-memory file
            # so examples and uploads are processed identically.
            process_h5_file(io.BytesIO(uploaded_file.getvalue()), uploaded_file.name)

if selected_example != "None" or uploaded_files:
    st.success('✅ Processing completed!')