HarshShinde0 commited on
Commit
a809e1c
·
1 Parent(s): abfc282

Prepare for HF Spaces: Fix paths, add .gitignore, update app logic

Browse files
.gitignore ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Models
2
+ models/
3
+ *.pth
4
+ *.h5
5
+ *.npy
6
+
7
+ # Python
8
+ __pycache__/
9
+ *.py[cod]
10
+ *$py.class
11
+
12
+ # Environment
13
+ .env
14
+ .venv
15
+ venv/
16
+ ENV/
17
+
18
+ # IDE
19
+ .vscode/
20
+ .idea/
21
+
22
+ # OS
23
+ .DS_Store
24
+ Thumbs.db
src/app.py DELETED
@@ -1,206 +0,0 @@
1
- import streamlit as st
2
- import h5py
3
- import torch
4
- import numpy as np
5
- import matplotlib.pyplot as plt
6
- import yaml
7
- import os
8
-
9
- # Import models
10
- from mobilenetv2_model import LandslideModel as MobileNetV2Model
11
- from vgg16_model import LandslideModel as VGG16Model
12
- from resnet34_model import LandslideModel as ResNet34Model
13
- from efficientnetb0_model import LandslideModel as EfficientNetB0Model
14
- from mitb1_model import LandslideModel as MiTB1Model
15
- from inceptionv4_model import LandslideModel as InceptionV4Model
16
- from densenet121_model import LandslideModel as DenseNet121Model
17
- from deeplabv3plus_model import LandslideModel as DeepLabV3PlusModel
18
- from resnext50_32x4d_model import LandslideModel as ResNeXt50_32X4DModel
19
- from se_resnet50_model import LandslideModel as SEResNet50Model
20
- from se_resnext50_32x4d_model import LandslideModel as SEResNeXt50_32X4DModel
21
- from segformer_model import LandslideModel as SegFormerB2Model
22
- from inceptionresnetv2_model import LandslideModel as InceptionResNetV2Model
23
-
24
- # Load the configuration file
25
- config = """
26
- model_config:
27
- model_type: "mobilenet_v2"
28
- in_channels: 14
29
- num_classes: 1
30
- encoder_weights: "imagenet"
31
- wce_weight: 0.5
32
-
33
- dataset_config:
34
- num_classes: 1
35
- num_channels: 14
36
- channels: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
37
- normalize: False
38
-
39
- train_config:
40
- dataset_path: ""
41
- checkpoint_path: "checkpoints"
42
- seed: 42
43
- train_val_split: 0.8
44
- batch_size: 16
45
- num_epochs: 100
46
- lr: 0.001
47
- device: "cuda:0"
48
- save_config: True
49
- experiment_name: "mobilenet_v2"
50
-
51
- logging_config:
52
- wandb_project: "l4s"
53
- wandb_entity: "Silvamillion"
54
- """
55
-
56
- config = yaml.safe_load(config)
57
-
58
- # Model descriptions
59
- model_descriptions = {
60
- "MobileNetV2": {"path": "mobilenetv2.pth", "type": "mobilenet_v2", "description": "MobileNetV2 is a lightweight deep learning model for image classification and segmentation."},
61
- "VGG16": {"path": "vgg16.pth", "type": "vgg16", "description": "VGG16 is a popular deep learning model known for its simplicity and depth."},
62
- "ResNet34": {"path": "resnet34.pth", "type": "resnet34", "description": "ResNet34 is a deep residual network that helps in training very deep networks."},
63
- "EfficientNetB0": {"path": "effucientnetb0.pth", "type": "efficientnet_b0", "description": "EfficientNetB0 is part of the EfficientNet family, known for its efficiency and performance."},
64
- "MiT-B1": {"path": "mitb1.pth", "type": "mit_b1", "description": "MiT-B1 is a transformer-based model designed for segmentation tasks."},
65
- "InceptionV4": {"path": "inceptionv4.pth", "type": "inceptionv4", "description": "InceptionV4 is a convolutional neural network known for its inception modules."},
66
- "DeepLabV3+": {"path": "deeplabv3.pth", "type": "deeplabv3+", "description": "DeepLabV3+ is an advanced model for semantic image segmentation."},
67
- "DenseNet121": {"path": "densenet121.pth", "type": "densenet121", "description": "DenseNet121 is a densely connected convolutional network for image classification and segmentation."},
68
- "ResNeXt50_32X4D": {"path": "resnext50-32x4d.pth", "type": "resnext50_32x4d", "description": "ResNeXt50_32X4D is a highly modularized network aimed at improving accuracy."},
69
- "SEResNet50": {"path": "se_resnet50.pth", "type": "se_resnet50", "description": "SEResNet50 is a ResNet model with squeeze-and-excitation blocks for better feature recalibration."},
70
- "SEResNeXt50_32X4D": {"path": "se_resnext50_32x4d.pth", "type": "se_resnext50_32x4d", "description": "SEResNeXt50_32X4D combines ResNeXt and SE blocks for improved performance."},
71
- "SegFormerB2": {"path": "segformer.pth", "type": "segformer_b2", "description": "SegFormerB2 is a transformer-based model for semantic segmentation."},
72
- "InceptionResNetV2": {"path": "inceptionresnetv2.pth", "type": "inceptionresnetv2", "description": "InceptionResNetV2 is a hybrid model combining Inception and ResNet architectures."},
73
- }
74
-
75
- # Streamlit app
76
- st.set_page_config(page_title="Landslide Detection", layout="wide")
77
-
78
- st.title("Landslide Detection")
79
- st.markdown("""
80
- ## Instructions
81
- 1. Select a model from the sidebar or choose to run all models.
82
- 2. Upload one or more `.h5` files.
83
- 3. The app will process the files and display the input image, prediction, and overlay.
84
- 4. You can download the prediction results.
85
- """)
86
-
87
- # Sidebar for model selection
88
- st.sidebar.title("Model Selection")
89
- model_option = st.sidebar.radio("Choose an option", ["Select a single model", "Run all models"])
90
- if model_option == "Select a single model":
91
- model_type = st.sidebar.selectbox("Select Model", list(model_descriptions.keys()))
92
- config['model_config']['model_type'] = model_descriptions[model_type]['type']
93
- if model_type == "DeepLabV3+":
94
- model_class = DeepLabV3PlusModel
95
- else:
96
- model_class = locals()[model_type.replace("-", "") + "Model"]
97
- model_path = model_descriptions[model_type]['path']
98
-
99
- # Display model details in the sidebar
100
- st.sidebar.markdown(f"**Model Type:** {model_descriptions[model_type]['type']}")
101
- st.sidebar.markdown(f"**Model Path:** {model_descriptions[model_type]['path']}")
102
- st.sidebar.markdown(f"**Description:** {model_descriptions[model_type]['description']}")
103
-
104
- # Main content
105
- st.header("Upload Data")
106
- uploaded_files = st.file_uploader("Choose .h5 files...", type="h5", accept_multiple_files=True)
107
- if uploaded_files:
108
- for uploaded_file in uploaded_files:
109
- st.write(f"Processing file: {uploaded_file.name}")
110
- with st.spinner('Classifying...'):
111
- with h5py.File(uploaded_file, 'r') as hdf:
112
- data = np.array(hdf.get('img'))
113
- data[np.isnan(data)] = 0.000001
114
- channels = config["dataset_config"]["channels"]
115
- image = np.zeros((128, 128, len(channels)))
116
- for i, channel in enumerate(channels):
117
- image[:, :, i] = data[:, :, channel-1]
118
-
119
- # Transform the image to the required format
120
- image = image.transpose((2, 0, 1)) # (H, W, C) to (C, H, W)
121
- image = torch.from_numpy(image).float().unsqueeze(0) # Add batch dimension
122
-
123
- if model_option == "Select a single model":
124
- # Process the image with the selected model
125
- st.write(f"Using model: {model_type}")
126
-
127
- # Load the model
128
- model = model_class(config)
129
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
130
- model.eval()
131
-
132
- # Make prediction
133
- with torch.no_grad():
134
- prediction = model(image)
135
- prediction = torch.sigmoid(prediction).cpu().numpy()
136
-
137
- # Display prediction
138
- st.header(f"Prediction Results - {model_type}")
139
- fig, ax = plt.subplots(1, 3, figsize=(15, 5))
140
- img = image.squeeze().permute(1, 2, 0).numpy()
141
- img = (img - img.min()) / (img.max() - img.min()) # Normalize the image to [0, 1] range for display
142
- ax[0].imshow(img[:, :, 1:4]) # Display first three channels as RGB
143
- ax[0].set_title("Input Image")
144
- ax[1].imshow(prediction.squeeze() > 0.5, cmap='plasma') # Apply threshold
145
- ax[1].set_title("Prediction")
146
- ax[2].imshow(img[:, :, :3]) # Display first three channels as RGB
147
- ax[2].imshow(prediction.squeeze() > 0.5, cmap='plasma', alpha=0.3) # Overlay prediction
148
- ax[2].set_title("Overlay")
149
- st.pyplot(fig)
150
-
151
- # Option to download the prediction
152
- st.write(f"Download the prediction as a .npy file for {model_type}:")
153
- npy_data = prediction.squeeze()
154
- st.download_button(
155
- label=f"Download Prediction - {model_type}",
156
- data=npy_data.tobytes(),
157
- file_name=f"{uploaded_file.name.split('.')[0]}_{model_type}_prediction.npy",
158
- mime="application/octet-stream"
159
- )
160
-
161
- else:
162
- # Process the image with each model
163
- for model_name, model_info in model_descriptions.items():
164
- st.write(f"Using model: {model_name}")
165
- if model_name == "DeepLabV3+":
166
- model_class = DeepLabV3PlusModel
167
- else:
168
- model_class = locals()[model_name.replace("-", "") + "Model"]
169
- model_path = model_info['path']
170
- config['model_config']['model_type'] = model_info['type']
171
-
172
- # Load the model
173
- model = model_class(config)
174
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
175
- model.eval()
176
-
177
- # Make prediction
178
- with torch.no_grad():
179
- prediction = model(image)
180
- prediction = torch.sigmoid(prediction).cpu().numpy()
181
-
182
- # Display prediction
183
- st.header(f"Prediction Results - {model_name}")
184
- fig, ax = plt.subplots(1, 3, figsize=(15, 5))
185
- img = image.squeeze().permute(1, 2, 0).numpy()
186
- img = (img - img.min()) / (img.max() - img.min()) # Normalize the image to [0, 1] range for display
187
- ax[0].imshow(img[:, :, :3]) # Display first three channels as RGB
188
- ax[0].set_title("Input Image")
189
- ax[1].imshow(prediction.squeeze() > 0.5, cmap='plasma') # Apply threshold
190
- ax[1].set_title("Prediction")
191
- ax[2].imshow(img[:, :, :3]) # Display first three channels as RGB
192
- ax[2].imshow(prediction.squeeze() > 0.5, cmap='plasma', alpha=0.3) # Overlay prediction
193
- ax[2].set_title("Overlay")
194
- st.pyplot(fig)
195
-
196
- # Option to download the prediction
197
- st.write(f"Download the prediction as a .npy file for {model_name}:")
198
- npy_data = prediction.squeeze()
199
- st.download_button(
200
- label=f"Download Prediction - {model_name}",
201
- data=npy_data.tobytes(),
202
- file_name=f"{uploaded_file.name.split('.')[0]}_{model_name}_prediction.npy",
203
- mime="application/octet-stream"
204
- )
205
-
206
- st.success('Done!')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/deeplabv3plus_model.py CHANGED
@@ -10,7 +10,7 @@ from torch.optim.lr_scheduler import StepLR
10
  class smp_model(nn.Module):
11
  def __init__(self, in_channels, out_channels, model_type, num_classes, encoder_weights):
12
  super(smp_model, self).__init__()
13
- if model_type == "deeplabv3+":
14
  self.model = smp.DeepLabV3Plus(
15
  encoder_name="resnet50", # Change this to a valid encoder
16
  encoder_weights=encoder_weights,
 
10
  class smp_model(nn.Module):
11
  def __init__(self, in_channels, out_channels, model_type, num_classes, encoder_weights):
12
  super(smp_model, self).__init__()
13
+ if model_type == "deeplabv3plus":
14
  self.model = smp.DeepLabV3Plus(
15
  encoder_name="resnet50", # Change this to a valid encoder
16
  encoder_weights=encoder_weights,
src/densenet121_model.py CHANGED
@@ -18,7 +18,7 @@ class smp_model(nn.Module):
18
  )
19
 
20
  def load_pretrained_weights(self):
21
- state_dict = torch.load('/home/hks/MOU/DenseNet121_14C_L4S/densenet121-fbdb23505-trainWeights.pth', map_location='cpu')
22
  conv1_weight = state_dict['features.conv0.weight']
23
  new_conv1_weight = torch.zeros((conv1_weight.shape[0], 14, *conv1_weight.shape[2:]))
24
  new_conv1_weight[:, :3, :, :] = conv1_weight # Copy weights for the first 3 channels
@@ -50,7 +50,7 @@ class LandslideModel(pl.LightningModule):
50
  model_type=model_type,
51
  num_classes=num_classes,
52
  encoder_weights=encoder_weights)
53
- self.model.load_pretrained_weights()
54
 
55
  self.weights = torch.tensor([5], dtype=torch.float32).to(self.device)
56
  self.wce = nn.BCELoss(weight=self.weights)
 
18
  )
19
 
20
  def load_pretrained_weights(self):
21
+ # self.model.load_pretrained_weights()
22
  conv1_weight = state_dict['features.conv0.weight']
23
  new_conv1_weight = torch.zeros((conv1_weight.shape[0], 14, *conv1_weight.shape[2:]))
24
  new_conv1_weight[:, :3, :, :] = conv1_weight # Copy weights for the first 3 channels
 
50
  model_type=model_type,
51
  num_classes=num_classes,
52
  encoder_weights=encoder_weights)
53
+ # self.model.load_pretrained_weights()
54
 
55
  self.weights = torch.tensor([5], dtype=torch.float32).to(self.device)
56
  self.wce = nn.BCELoss(weight=self.weights)
src/download_all_models.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ # Add the parent directory to sys.path to allow imports from 'src'
5
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
6
+
7
+ from src.model_downloader import ModelDownloader
8
+
9
+ def download_all():
10
+ print("Starting download of all models...")
11
+ downloader = ModelDownloader()
12
+ models = downloader.list_available_models()
13
+
14
+ for model_name in models:
15
+ try:
16
+ print(f"Checking/Downloading {model_name}...")
17
+ path = downloader.download_model(model_name)
18
+ print(f"✓ {model_name} is ready at {path}")
19
+ except Exception as e:
20
+ print(f"✗ Failed to download {model_name}: {e}")
21
+
22
+ print("\nAll downloads completed.")
23
+
24
+ if __name__ == "__main__":
25
+ download_all()
src/inceptionresnetv2_model.py CHANGED
@@ -11,7 +11,7 @@ class smp_model(nn.Module):
11
  def __init__(self, in_channels, out_channels, model_type, num_classes, encoder_weights):
12
  super(smp_model, self).__init__()
13
  self.model = smp.Unet(
14
- encoder_name=model_type,
15
  encoder_weights=encoder_weights,
16
  in_channels=in_channels,
17
  classes=num_classes,
 
11
  def __init__(self, in_channels, out_channels, model_type, num_classes, encoder_weights):
12
  super(smp_model, self).__init__()
13
  self.model = smp.Unet(
14
+ encoder_name="inceptionresnetv2",
15
  encoder_weights=encoder_weights,
16
  in_channels=in_channels,
17
  classes=num_classes,
src/model_downloader.py CHANGED
@@ -8,11 +8,11 @@ from tqdm.auto import tqdm
8
  class ModelDownloader:
9
  def __init__(self):
10
  # Create models directory for caching
11
- self.models_dir = Path("/app/models")
12
  self.models_dir.mkdir(exist_ok=True)
13
 
14
  # HuggingFace model repository details
15
- self.hf_model_url = "https://huggingface.co/harshinde/Sims/resolve/main/models/"
16
 
17
  # Model mapping with file names
18
  self.model_files = {
@@ -26,47 +26,47 @@ class ModelDownloader:
26
  },
27
  "efficientnetb0": {
28
  "file": "efficientnetb0.pth",
29
- "url": f"{self.hf_model_url}efficientnetb0.pth"
30
  },
31
  "inceptionresnetv2": {
32
  "file": "inceptionresnetv2.pth",
33
- "id": "inceptionresnetv2"
34
  },
35
  "inceptionv4": {
36
  "file": "inceptionv4.pth",
37
- "id": "inceptionv4"
38
  },
39
  "mitb1": {
40
  "file": "mitb1.pth",
41
- "id": "mitb1"
42
  },
43
  "mobilenetv2": {
44
  "file": "mobilenetv2.pth",
45
- "id": "mobilenetv2"
46
  },
47
  "resnet34": {
48
  "file": "resnet34.pth",
49
- "id": "resnet34"
50
  },
51
  "resnext50_32x4d": {
52
  "file": "resnext50-32x4d.pth",
53
- "id": "resnext50-32x4d"
54
  },
55
  "se_resnet50": {
56
  "file": "se_resnet50.pth",
57
- "id": "se_resnet50"
58
  },
59
  "se_resnext50_32x4d": {
60
  "file": "se_resnext50_32x4d.pth",
61
- "id": "se_resnext50_32x4d"
62
  },
63
  "segformer": {
64
  "file": "segformer.pth",
65
- "id": "segformer"
66
  },
67
  "vgg16": {
68
  "file": "vgg16.pth",
69
- "id": "vgg16"
70
  }
71
  }
72
 
@@ -86,7 +86,15 @@ class ModelDownloader:
86
 
87
  if not model_path.exists():
88
  print(f"Downloading {model_name} model...")
89
- response = requests.get(model_info['url'], stream=True)
 
 
 
 
 
 
 
 
90
  response.raise_for_status()
91
 
92
  total_size = int(response.headers.get('content-length', 0))
@@ -102,55 +110,6 @@ class ModelDownloader:
102
  print(f"Model downloaded successfully to {model_path}")
103
 
104
  return str(model_path)
105
-
106
- # If model already exists, return path
107
- if model_path.exists():
108
- return str(model_path)
109
-
110
- # Construct download URL for the specific model
111
- download_url = f"{self.kaggle_model_url}/{model_info['id']}/1"
112
-
113
- try:
114
- st.info(f"Downloading model {model_name} from Kaggle Models...")
115
- progress_bar = st.progress(0)
116
-
117
- # Download with progress
118
- response = requests.get(download_url, stream=True)
119
- response.raise_for_status()
120
-
121
- total_size = int(response.headers.get('content-length', 0))
122
- block_size = 1024 # 1 Kibibyte
123
-
124
- with open(model_path, 'wb') as f:
125
- for i, data in enumerate(response.iter_content(block_size)):
126
- progress_bar.progress(min(i * block_size / total_size, 1.0))
127
- f.write(data)
128
-
129
- st.success(f"Successfully downloaded {model_name}")
130
- return str(model_path)
131
-
132
- except requests.exceptions.RequestException as e:
133
- raise Exception(f"Failed to download model from Kaggle: {str(e)}")
134
-
135
- def get_model_path(self, model_name):
136
- """
137
- Get the path for a model file, downloading it from Kaggle if necessary
138
- Args:
139
- model_name (str): Name of the model (e.g., 'deeplabv3plus', 'densenet121', etc.)
140
- Returns:
141
- str: Path to the model file
142
- """
143
- if model_name not in self.model_files:
144
- raise ValueError(f"Model {model_name} not found. Available models: {list(self.model_files.keys())}")
145
-
146
- model_info = self.model_files[model_name]
147
- model_path = self.models_dir / model_info['file']
148
-
149
- # If model doesn't exist locally, download it
150
- if not model_path.exists():
151
- return self.download_from_kaggle(model_name)
152
-
153
- return str(model_path)
154
 
155
  def list_available_models(self):
156
  """
 
8
  class ModelDownloader:
9
  def __init__(self):
10
  # Create models directory for caching
11
+ self.models_dir = Path("models").resolve()
12
  self.models_dir.mkdir(exist_ok=True)
13
 
14
  # HuggingFace model repository details
15
+ self.hf_model_url = "https://huggingface.co/harshinde/DeepSlide_Models/resolve/main/"
16
 
17
  # Model mapping with file names
18
  self.model_files = {
 
26
  },
27
  "efficientnetb0": {
28
  "file": "efficientnetb0.pth",
29
+ "url": f"{self.hf_model_url}effucientnetb0.pth"
30
  },
31
  "inceptionresnetv2": {
32
  "file": "inceptionresnetv2.pth",
33
+ "url": f"{self.hf_model_url}inceptionresnetv2.pth"
34
  },
35
  "inceptionv4": {
36
  "file": "inceptionv4.pth",
37
+ "url": f"{self.hf_model_url}inceptionv4.pth"
38
  },
39
  "mitb1": {
40
  "file": "mitb1.pth",
41
+ "url": f"{self.hf_model_url}mitb1.pth"
42
  },
43
  "mobilenetv2": {
44
  "file": "mobilenetv2.pth",
45
+ "url": f"{self.hf_model_url}mobilenetv2.pth"
46
  },
47
  "resnet34": {
48
  "file": "resnet34.pth",
49
+ "url": f"{self.hf_model_url}resnet34.pth"
50
  },
51
  "resnext50_32x4d": {
52
  "file": "resnext50-32x4d.pth",
53
+ "url": f"{self.hf_model_url}resnext50-32x4d.pth"
54
  },
55
  "se_resnet50": {
56
  "file": "se_resnet50.pth",
57
+ "url": f"{self.hf_model_url}se_resnet50.pth"
58
  },
59
  "se_resnext50_32x4d": {
60
  "file": "se_resnext50_32x4d.pth",
61
+ "url": f"{self.hf_model_url}se_resnext50_32x4d.pth"
62
  },
63
  "segformer": {
64
  "file": "segformer.pth",
65
+ "url": f"{self.hf_model_url}segformer.pth"
66
  },
67
  "vgg16": {
68
  "file": "vgg16.pth",
69
+ "url": f"{self.hf_model_url}vgg16.pth"
70
  }
71
  }
72
 
 
86
 
87
  if not model_path.exists():
88
  print(f"Downloading {model_name} model...")
89
+ # Use 'url' if available, otherwise fallback or error (logic simplified for now as per plan)
90
+ if 'url' in model_info:
91
+ url = model_info['url']
92
+ else:
93
+ # Fallback for models without explicit URL in the map (though all seem to have it or use ID)
94
+ # Assuming the pattern from init for others
95
+ url = f"{self.hf_model_url}{model_info['file']}"
96
+
97
+ response = requests.get(url, stream=True)
98
  response.raise_for_status()
99
 
100
  total_size = int(response.headers.get('content-length', 0))
 
110
  print(f"Model downloaded successfully to {model_path}")
111
 
112
  return str(model_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  def list_available_models(self):
115
  """
src/streamlit_app.py CHANGED
@@ -1,10 +1,17 @@
1
  import streamlit as st
 
 
 
 
 
 
2
  import h5py
3
  import torch
4
  import numpy as np
5
  import matplotlib.pyplot as plt
6
  import yaml
7
  import os
 
8
 
9
  # Import models
10
  from src.mobilenetv2_model import LandslideModel as MobileNetV2Model
@@ -15,11 +22,12 @@ from src.mitb1_model import LandslideModel as MiTB1Model
15
  from src.inceptionv4_model import LandslideModel as InceptionV4Model
16
  from src.densenet121_model import LandslideModel as DenseNet121Model
17
  from src.deeplabv3plus_model import LandslideModel as DeepLabV3PlusModel
18
- from src.resnext50_32x4d_model import LandslideModel as ResNeXt50_32X4DModel
19
  from src.se_resnet50_model import LandslideModel as SEResNet50Model
20
- from src.se_resnext50_32x4d_model import LandslideModel as SEResNeXt50_32X4DModel
21
  from src.segformer_model import LandslideModel as SegFormerB2Model
22
  from src.inceptionresnetv2_model import LandslideModel as InceptionResNetV2Model
 
23
 
24
  # Define available models
25
  AVAILABLE_MODELS = {
@@ -31,10 +39,10 @@ AVAILABLE_MODELS = {
31
  "inceptionv4": {"name": "InceptionV4", "type": "inception_v4"},
32
  "densenet121": {"name": "DenseNet121", "type": "densenet121"},
33
  "deeplabv3plus": {"name": "DeepLabV3Plus", "type": "deeplabv3plus"},
34
- "resnext50": {"name": "ResNeXt50", "type": "resnext50_32x4d"},
35
- "seresnet50": {"name": "SEResNet50", "type": "se_resnet50"},
36
- "seresnext50": {"name": "SEResNeXt50", "type": "se_resnext50_32x4d"},
37
- "segformerb2": {"name": "SegFormerB2", "type": "segformer_b2"},
38
  "inceptionresnetv2": {"name": "InceptionResNetV2", "type": "inception_resnet_v2"}
39
  }
40
 
@@ -43,13 +51,14 @@ MODEL_DESCRIPTIONS = {
43
  model_key: {
44
  "type": model_info["type"],
45
  "description": f"{model_info['name']} - A model for landslide detection and segmentation.",
46
- "name": model_info["name"]
 
47
  }
48
  for model_key, model_info in AVAILABLE_MODELS.items()
49
  }
50
 
51
  # Load the configuration file
52
- config = """
53
  model_config:
54
  model_type: "mobilenet_v2"
55
  in_channels: 14
@@ -74,9 +83,89 @@ train_config:
74
  device: "cuda:0"
75
  save_config: True
76
  experiment_name: "mobilenet_v2"
 
 
 
 
77
  """
78
 
79
- config = yaml.safe_load(config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  # Streamlit app
82
  st.set_page_config(page_title="Landslide Detection", layout="wide")
@@ -94,10 +183,10 @@ st.markdown("""
94
  st.sidebar.title("Model Selection")
95
  model_option = st.sidebar.radio("Choose an option", ["Select a single model", "Run all models"])
96
 
 
97
  if model_option == "Select a single model":
98
  selected_model_key = st.sidebar.selectbox("Select Model", list(MODEL_DESCRIPTIONS.keys()))
99
  selected_model_info = MODEL_DESCRIPTIONS[selected_model_key]
100
- config['model_config']['model_type'] = selected_model_info['type']
101
 
102
  # Display model details in the sidebar
103
  st.sidebar.markdown("### Model Details")
@@ -127,7 +216,6 @@ if uploaded_files:
127
  with st.spinner('Classifying...'):
128
  try:
129
  # Read the file directly using BytesIO
130
- import io
131
  bytes_data = uploaded_file.getvalue()
132
  bytes_io = io.BytesIO(bytes_data)
133
 
@@ -139,413 +227,52 @@ if uploaded_files:
139
  data = np.array(hdf.get('img'))
140
  data[np.isnan(data)] = 0.000001
141
  channels = config["dataset_config"]["channels"]
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  image = np.zeros((128, 128, len(channels)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
- for i, band in enumerate(channels):
145
- image[:, :, i] = data[band-1]
146
-
147
- selected_channels = [image[:, :, i] for i in range(3)]
148
- image = np.transpose(image, (2, 0, 1))
149
 
150
  if model_option == "Select a single model":
151
- # Get the model class from AVAILABLE_MODELS
152
- model_class_name = AVAILABLE_MODELS[selected_model_key]['name'].replace('-', '') + 'Model'
153
- model_class = locals()[model_class_name]
154
-
155
- # Initialize model downloader
156
- from model_downloader import ModelDownloader
157
- downloader = ModelDownloader()
158
-
159
- try:
160
- # Download/get model path
161
- model_path = downloader.download_model(selected_model_key)
162
- st.info(f"Using model from: {model_path}")
163
-
164
- # Load the model
165
- model = model_class(config)
166
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
167
- model.eval()
168
-
169
- # Make prediction
170
- with torch.no_grad():
171
- prediction = model(torch.from_numpy(image).unsqueeze(0).float())
172
- prediction = torch.sigmoid(prediction).numpy()
173
-
174
- st.header(f"Prediction Results - {selected_model_info['name']}")
175
-
176
- # Create columns for input image, prediction, and overlay
177
- col1, col2, col3 = st.columns(3)
178
-
179
- # Display input image
180
- with col1:
181
- st.write("Input Image")
182
- plt.figure(figsize=(8, 8))
183
- plt.imshow(selected_channels[0], cmap='viridis')
184
- plt.colorbar()
185
- plt.axis('off')
186
- st.pyplot(plt)
187
-
188
- # Display prediction
189
- with col2:
190
- st.write("Prediction")
191
- plt.figure(figsize=(8, 8))
192
- plt.imshow(prediction.squeeze(), cmap='viridis')
193
- plt.colorbar()
194
- plt.axis('off')
195
- st.pyplot(plt)
196
-
197
- # Display overlay
198
- with col3:
199
- st.write("Overlay")
200
- plt.figure(figsize=(8, 8))
201
- plt.imshow(selected_channels[0], cmap='viridis')
202
- plt.imshow(prediction.squeeze(), cmap='viridis', alpha=0.5)
203
- plt.colorbar()
204
- plt.axis('off')
205
- st.pyplot(plt)
206
-
207
- # Download button for prediction
208
- st.write(f"Download the prediction as a .npy file for {selected_model_info['name']}:")
209
- npy_data = prediction.squeeze()
210
- st.download_button(
211
- label=f"Download Prediction - {selected_model_info['name']}",
212
- data=npy_data.tobytes(),
213
- file_name=f"{uploaded_file.name.split('.')[0]}_{selected_model_key}_prediction.npy",
214
- mime="application/octet-stream"
215
- )
216
-
217
- except Exception as e:
218
- st.error(f"Error with model {selected_model_info['name']}: {str(e)}")
219
  else:
220
- # Process the image with each model
221
  for model_key, model_info in MODEL_DESCRIPTIONS.items():
222
- st.write(f"Using model: {model_info['name']}")
223
- config['model_config']['model_type'] = model_info['type']
224
-
225
- # Get the model class from AVAILABLE_MODELS
226
- model_class_name = AVAILABLE_MODELS[model_key]['name'].replace('-', '') + 'Model'
227
- model_class = locals()[model_class_name]
228
-
229
- # Initialize model downloader
230
- from model_downloader import ModelDownloader
231
- downloader = ModelDownloader()
232
-
233
- try:
234
- # Download/get model path
235
- model_path = downloader.download_model(model_key)
236
- st.info(f"Using model from: {model_path}")
237
-
238
- # Load the model
239
- model = model_class(config)
240
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
241
- model.eval()
242
-
243
- # Make prediction
244
- with torch.no_grad():
245
- prediction = model(torch.from_numpy(image).unsqueeze(0).float())
246
- prediction = torch.sigmoid(prediction).numpy()
247
-
248
- st.header(f"Prediction Results - {model_info['name']}")
249
-
250
- # Create columns for input image, prediction, and overlay
251
- col1, col2, col3 = st.columns(3)
252
-
253
- # Display input image
254
- with col1:
255
- st.write("Input Image")
256
- plt.figure(figsize=(8, 8))
257
- plt.imshow(selected_channels[0], cmap='viridis')
258
- plt.colorbar()
259
- plt.axis('off')
260
- st.pyplot(plt)
261
-
262
- # Display prediction
263
- with col2:
264
- st.write("Prediction")
265
- plt.figure(figsize=(8, 8))
266
- plt.imshow(prediction.squeeze(), cmap='viridis')
267
- plt.colorbar()
268
- plt.axis('off')
269
- st.pyplot(plt)
270
-
271
- # Display overlay
272
- with col3:
273
- st.write("Overlay")
274
- plt.figure(figsize=(8, 8))
275
- plt.imshow(selected_channels[0], cmap='viridis')
276
- plt.imshow(prediction.squeeze(), cmap='viridis', alpha=0.5)
277
- plt.colorbar()
278
- plt.axis('off')
279
- st.pyplot(plt)
280
-
281
- # Download button for prediction
282
- st.write(f"Download the prediction as a .npy file for {model_info['name']}:")
283
- npy_data = prediction.squeeze()
284
- st.download_button(
285
- label=f"Download Prediction - {model_info['name']}",
286
- data=npy_data.tobytes(),
287
- file_name=f"{uploaded_file.name.split('.')[0]}_{model_key}_prediction.npy",
288
- mime="application/octet-stream"
289
- )
290
-
291
- except Exception as e:
292
- st.error(f"Error with model {model_info['name']}: {str(e)}")
293
- continue
294
 
295
  except Exception as e:
296
  st.error(f"Error processing file {uploaded_file.name}: {str(e)}")
 
 
297
  continue
298
- import h5py
299
- import torch
300
- import numpy as np
301
- import matplotlib.pyplot as plt
302
- import yaml
303
- import os
304
-
305
- # Import models
306
- from src.mobilenetv2_model import LandslideModel as MobileNetV2Model
307
- from src.vgg16_model import LandslideModel as VGG16Model
308
- from src.resnet34_model import LandslideModel as ResNet34Model
309
- from src.efficientnetb0_model import LandslideModel as EfficientNetB0Model
310
- from src.mitb1_model import LandslideModel as MiTB1Model
311
- from src.inceptionv4_model import LandslideModel as InceptionV4Model
312
- from src.densenet121_model import LandslideModel as DenseNet121Model
313
- from src.deeplabv3plus_model import LandslideModel as DeepLabV3PlusModel
314
- from src.resnext50_32x4d_model import LandslideModel as ResNeXt50_32X4DModel
315
- from src.se_resnet50_model import LandslideModel as SEResNet50Model
316
- from src.se_resnext50_32x4d_model import LandslideModel as SEResNeXt50_32X4DModel
317
- from segformer_model import LandslideModel as SegFormerB2Model
318
- from inceptionresnetv2_model import LandslideModel as InceptionResNetV2Model
319
-
320
- # Define available models
321
- AVAILABLE_MODELS = {
322
- "mobilenetv2": {"name": "MobileNetV2", "type": "mobilenet_v2"},
323
- "vgg16": {"name": "VGG16", "type": "vgg16"},
324
- "resnet34": {"name": "ResNet34", "type": "resnet34"},
325
- "efficientnetb0": {"name": "EfficientNetB0", "type": "efficientnet_b0"},
326
- "mitb1": {"name": "MiTB1", "type": "mitb1"},
327
- "inceptionv4": {"name": "InceptionV4", "type": "inception_v4"},
328
- "densenet121": {"name": "DenseNet121", "type": "densenet121"},
329
- "deeplabv3plus": {"name": "DeepLabV3Plus", "type": "deeplabv3plus"},
330
- "resnext50": {"name": "ResNeXt50", "type": "resnext50_32x4d"},
331
- "seresnet50": {"name": "SEResNet50", "type": "se_resnet50"},
332
- "seresnext50": {"name": "SEResNeXt50", "type": "se_resnext50_32x4d"},
333
- "segformerb2": {"name": "SegFormerB2", "type": "segformer_b2"},
334
- "inceptionresnetv2": {"name": "InceptionResNetV2", "type": "inception_resnet_v2"}
335
- }
336
-
337
- # Load the configuration file
338
- config = """
339
- model_config:
340
- model_type: "mobilenet_v2"
341
- in_channels: 14
342
- num_classes: 1
343
- encoder_weights: "imagenet"
344
- wce_weight: 0.5
345
-
346
- dataset_config:
347
- num_classes: 1
348
- num_channels: 14
349
- channels: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
350
- normalize: False
351
-
352
- train_config:
353
- dataset_path: ""
354
- checkpoint_path: "checkpoints"
355
- seed: 42
356
- train_val_split: 0.8
357
- batch_size: 16
358
- num_epochs: 100
359
- lr: 0.001
360
- device: "cuda:0"
361
- save_config: True
362
- experiment_name: "mobilenet_v2"
363
-
364
- logging_config:
365
- wandb_project: "l4s"
366
- wandb_entity: "Silvamillion"
367
- """
368
-
369
- config = yaml.safe_load(config)
370
-
371
- # Model descriptions with their respective types and descriptions
372
- MODEL_DESCRIPTIONS = {
373
- model_key: {
374
- "type": model_info["type"],
375
- "description": f"{model_info['name']} - A model for landslide detection and segmentation.",
376
- "name": model_info["name"]
377
- }
378
- for model_key, model_info in AVAILABLE_MODELS.items()
379
- }
380
-
381
- # Streamlit app
382
- st.set_page_config(page_title="Landslide Detection", layout="wide")
383
-
384
- st.title("Landslide Detection")
385
- st.markdown("""
386
- ## Instructions
387
- 1. Select a model from the sidebar or choose to run all models.
388
- 2. Upload one or more `.h5` files.
389
- 3. The app will process the files and display the input image, prediction, and overlay.
390
- 4. You can download the prediction results.
391
- """)
392
-
393
- # Sidebar for model selection
394
- st.sidebar.title("Model Selection")
395
- model_option = st.sidebar.radio("Choose an option", ["Select a single model", "Run all models"])
396
- if model_option == "Select a single model":
397
- selected_model = st.sidebar.selectbox("Select Model", list(MODEL_DESCRIPTIONS.keys()))
398
- config['model_config']['model_type'] = MODEL_DESCRIPTIONS[selected_model]['type']
399
-
400
- # Display model details in the sidebar
401
- st.sidebar.markdown(f"**Model Name:** {MODEL_DESCRIPTIONS[selected_model]['name']}")
402
- st.sidebar.markdown(f"**Model Type:** {MODEL_DESCRIPTIONS[selected_model]['type']}")
403
- st.sidebar.markdown(f"**Description:** {MODEL_DESCRIPTIONS[selected_model]['description']}")
404
-
405
- # Main content
406
- st.header("Upload Data")
407
-
408
- # Initialize session state for error tracking if not exists
409
- if 'upload_errors' not in st.session_state:
410
- st.session_state.upload_errors = []
411
-
412
- uploaded_files = st.file_uploader(
413
- "Choose .h5 files...",
414
- type="h5",
415
- accept_multiple_files=True,
416
- help="Upload your .h5 files here. Maximum file size is 200MB."
417
- )
418
-
419
- if uploaded_files:
420
- for uploaded_file in uploaded_files:
421
- st.write(f"Processing file: {uploaded_file.name}")
422
-
423
- # Display file details for debugging
424
- st.write(f"File size: {uploaded_file.size} bytes")
425
-
426
- with st.spinner('Classifying...'):
427
- try:
428
- # Read the file directly using BytesIO
429
- import io
430
- bytes_data = uploaded_file.getvalue()
431
- bytes_io = io.BytesIO(bytes_data)
432
-
433
- with h5py.File(bytes_io, 'r') as hdf:
434
- # Check if 'img' exists in the file
435
- if 'img' not in hdf:
436
- st.error(f"Error: 'img' dataset not found in {uploaded_file.name}")
437
- continue
438
-
439
- data = np.array(hdf.get('img'))
440
- data[np.isnan(data)] = 0.000001
441
- channels = config["dataset_config"]["channels"]
442
- image = np.zeros((128, 128, len(channels)))
443
-
444
- except h5py.Error as he:
445
- st.error(f"H5PY Error processing {uploaded_file.name}: {str(he)}")
446
- continue
447
- except Exception as e:
448
- st.error(f"Error processing {uploaded_file.name}: {str(e)}")
449
- continue
450
- for i, channel in enumerate(channels):
451
- image[:, :, i] = data[:, :, channel-1]
452
-
453
- # Transform the image to the required format
454
- image = image.transpose((2, 0, 1)) # (H, W, C) to (C, H, W)
455
- image = torch.from_numpy(image).float().unsqueeze(0) # Add batch dimension
456
-
457
- if model_option == "Select a single model":
458
- # Process the image with the selected model
459
- st.write(f"Using model: {model_type}")
460
-
461
- # Load the model
462
- model = model_class(config)
463
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
464
- model.eval()
465
-
466
- # Make prediction
467
- with torch.no_grad():
468
- prediction = model(image)
469
- prediction = torch.sigmoid(prediction).cpu().numpy()
470
-
471
- # Display prediction
472
- st.header(f"Prediction Results - {model_type}")
473
- fig, ax = plt.subplots(1, 3, figsize=(15, 5))
474
- img = image.squeeze().permute(1, 2, 0).numpy()
475
- img = (img - img.min()) / (img.max() - img.min()) # Normalize the image to [0, 1] range for display
476
- ax[0].imshow(img[:, :, 1:4]) # Display first three channels as RGB
477
- ax[0].set_title("Input Image")
478
- ax[1].imshow(prediction.squeeze() > 0.5, cmap='plasma') # Apply threshold
479
- ax[1].set_title("Prediction")
480
- ax[2].imshow(img[:, :, :3]) # Display first three channels as RGB
481
- ax[2].imshow(prediction.squeeze() > 0.5, cmap='plasma', alpha=0.3) # Overlay prediction
482
- ax[2].set_title("Overlay")
483
- st.pyplot(fig)
484
-
485
- # Option to download the prediction
486
- st.write(f"Download the prediction as a .npy file for {model_type}:")
487
- npy_data = prediction.squeeze()
488
- st.download_button(
489
- label=f"Download Prediction - {model_type}",
490
- data=npy_data.tobytes(),
491
- file_name=f"{uploaded_file.name.split('.')[0]}_{model_type}_prediction.npy",
492
- mime="application/octet-stream"
493
- )
494
-
495
- else:
496
- # Process the image with each model
497
- for model_key, model_info in MODEL_DESCRIPTIONS.items():
498
- st.write(f"Using model: {model_info['name']}")
499
- config['model_config']['model_type'] = model_info['type']
500
-
501
- # Get the model class from AVAILABLE_MODELS
502
- model_class_name = AVAILABLE_MODELS[model_key]['name'].replace('-', '') + 'Model'
503
- model_class = locals()[model_class_name]
504
-
505
- # Initialize model downloader
506
- from model_downloader import ModelDownloader
507
- downloader = ModelDownloader()
508
-
509
- try:
510
- # Download/get model path
511
- model_path = downloader.download_model(model_name.lower())
512
- st.info(f"Using model from: {model_path}")
513
-
514
- # Load the model
515
- model = model_class(config)
516
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
517
- model.eval()
518
- except Exception as e:
519
- st.error(f"Error loading model {model_name}: {str(e)}")
520
- continue
521
-
522
- # Make prediction
523
- with torch.no_grad():
524
- prediction = model(image)
525
- prediction = torch.sigmoid(prediction).cpu().numpy()
526
-
527
- # Display prediction
528
- st.header(f"Prediction Results - {model_name}")
529
- fig, ax = plt.subplots(1, 3, figsize=(15, 5))
530
- img = image.squeeze().permute(1, 2, 0).numpy()
531
- img = (img - img.min()) / (img.max() - img.min()) # Normalize the image to [0, 1] range for display
532
- ax[0].imshow(img[:, :, :3]) # Display first three channels as RGB
533
- ax[0].set_title("Input Image")
534
- ax[1].imshow(prediction.squeeze() > 0.5, cmap='plasma') # Apply threshold
535
- ax[1].set_title("Prediction")
536
- ax[2].imshow(img[:, :, :3]) # Display first three channels as RGB
537
- ax[2].imshow(prediction.squeeze() > 0.5, cmap='plasma', alpha=0.3) # Overlay prediction
538
- ax[2].set_title("Overlay")
539
- st.pyplot(fig)
540
-
541
- # Option to download the prediction
542
- st.write(f"Download the prediction as a .npy file for {model_name}:")
543
- npy_data = prediction.squeeze()
544
- st.download_button(
545
- label=f"Download Prediction - {model_name}",
546
- data=npy_data.tobytes(),
547
- file_name=f"{uploaded_file.name.split('.')[0]}_{model_name}_prediction.npy",
548
- mime="application/octet-stream"
549
- )
550
 
551
  st.success('Done!')
 
1
  import streamlit as st
2
+ import sys
3
+ import os
4
+
5
+ # Add the parent directory to sys.path to allow imports from 'src'
6
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
7
+
8
  import h5py
9
  import torch
10
  import numpy as np
11
  import matplotlib.pyplot as plt
12
  import yaml
13
  import os
14
+ import io
15
 
16
  # Import models
17
  from src.mobilenetv2_model import LandslideModel as MobileNetV2Model
 
22
  from src.inceptionv4_model import LandslideModel as InceptionV4Model
23
  from src.densenet121_model import LandslideModel as DenseNet121Model
24
  from src.deeplabv3plus_model import LandslideModel as DeepLabV3PlusModel
25
+ from src.resnext50_32x4d_model import LandslideModel as ResNeXt50Model
26
  from src.se_resnet50_model import LandslideModel as SEResNet50Model
27
+ from src.se_resnext50_32x4d_model import LandslideModel as SEResNeXt50Model
28
  from src.segformer_model import LandslideModel as SegFormerB2Model
29
  from src.inceptionresnetv2_model import LandslideModel as InceptionResNetV2Model
30
+ from src.model_downloader import ModelDownloader
31
 
32
  # Define available models
33
  AVAILABLE_MODELS = {
 
39
  "inceptionv4": {"name": "InceptionV4", "type": "inception_v4"},
40
  "densenet121": {"name": "DenseNet121", "type": "densenet121"},
41
  "deeplabv3plus": {"name": "DeepLabV3Plus", "type": "deeplabv3plus"},
42
+ "resnext50": {"name": "ResNeXt50", "type": "resnext50_32x4d", "downloader_key": "resnext50_32x4d"},
43
+ "seresnet50": {"name": "SEResNet50", "type": "se_resnet50", "downloader_key": "se_resnet50"},
44
+ "seresnext50": {"name": "SEResNeXt50", "type": "se_resnext50_32x4d", "downloader_key": "se_resnext50_32x4d"},
45
+ "segformerb2": {"name": "SegFormerB2", "type": "segformer_b2", "downloader_key": "segformer"},
46
  "inceptionresnetv2": {"name": "InceptionResNetV2", "type": "inception_resnet_v2"}
47
  }
48
 
 
51
  model_key: {
52
  "type": model_info["type"],
53
  "description": f"{model_info['name']} - A model for landslide detection and segmentation.",
54
+ "name": model_info["name"],
55
+ "downloader_key": model_info.get("downloader_key", model_key)
56
  }
57
  for model_key, model_info in AVAILABLE_MODELS.items()
58
  }
59
 
60
  # Load the configuration file
61
+ config_str = """
62
  model_config:
63
  model_type: "mobilenet_v2"
64
  in_channels: 14
 
83
  device: "cuda:0"
84
  save_config: True
85
  experiment_name: "mobilenet_v2"
86
+
87
+ logging_config:
88
+ wandb_project: "l4s"
89
+ wandb_entity: "Silvamillion"
90
  """
91
 
92
+ config = yaml.safe_load(config_str)
93
+
94
+ def process_and_visualize(model_key, model_info, image_tensor, original_image, uploaded_file_name):
95
+ """
96
+ Process the image with the selected model and visualize results.
97
+ """
98
+ try:
99
+ st.write(f"Using model: {model_info['name']}")
100
+
101
+ # Update config for the specific model
102
+ current_config = config.copy()
103
+ current_config['model_config']['model_type'] = model_info['type']
104
+
105
+ # Get the model class
106
+ model_class_name = AVAILABLE_MODELS[model_key]['name'].replace('-', '') + 'Model'
107
+ if model_class_name not in globals():
108
+ # Fallback for naming inconsistencies if any
109
+ # Try to find it in globals
110
+ pass
111
+ model_class = globals()[model_class_name]
112
+
113
+ # Initialize model downloader
114
+ downloader = ModelDownloader()
115
+
116
+ # Download/get model path
117
+ download_key = model_info.get('downloader_key', model_key)
118
+ model_path = downloader.download_model(download_key)
119
+ st.info(f"Using model from: {model_path}")
120
+
121
+ # Load the model
122
+ model = model_class(current_config)
123
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
124
+ model.eval()
125
+
126
+ # Make prediction
127
+ with torch.no_grad():
128
+ prediction = model(image_tensor)
129
+ prediction = torch.sigmoid(prediction).cpu().numpy()
130
+
131
+ # Display prediction
132
+ st.header(f"Prediction Results - {model_info['name']}")
133
+ fig, ax = plt.subplots(1, 3, figsize=(15, 5))
134
+
135
+ # Normalize image for display
136
+ img_display = original_image.transpose(1, 2, 0) # (C, H, W) -> (H, W, C)
137
+ img_display = (img_display - img_display.min()) / (img_display.max() - img_display.min())
138
+
139
+ ax[0].imshow(img_display[:, :, :3]) # Display first three channels as RGB
140
+ ax[0].set_title("Input Image")
141
+ ax[0].axis('off')
142
+
143
+ ax[1].imshow(prediction.squeeze(), cmap='plasma') # Raw prediction map
144
+ ax[1].set_title("Prediction Probability")
145
+ ax[1].axis('off')
146
+
147
+ ax[2].imshow(img_display[:, :, :3])
148
+ ax[2].imshow(prediction.squeeze() > 0.5, cmap='plasma', alpha=0.4) # Overlay
149
+ ax[2].set_title("Overlay (Threshold > 0.5)")
150
+ ax[2].axis('off')
151
+
152
+ st.pyplot(fig)
153
+ plt.close(fig)
154
+
155
+ # Download button
156
+ st.write(f"Download the prediction as a .npy file for {model_info['name']}:")
157
+ npy_data = prediction.squeeze()
158
+ st.download_button(
159
+ label=f"Download Prediction - {model_info['name']}",
160
+ data=npy_data.tobytes(),
161
+ file_name=f"{uploaded_file_name.split('.')[0]}_{model_key}_prediction.npy",
162
+ mime="application/octet-stream"
163
+ )
164
+
165
+ except Exception as e:
166
+ st.error(f"Error with model {model_info['name']}: {str(e)}")
167
+ import traceback
168
+ st.error(traceback.format_exc())
169
 
170
  # Streamlit app
171
  st.set_page_config(page_title="Landslide Detection", layout="wide")
 
183
  st.sidebar.title("Model Selection")
184
  model_option = st.sidebar.radio("Choose an option", ["Select a single model", "Run all models"])
185
 
186
+ selected_model_key = None
187
  if model_option == "Select a single model":
188
  selected_model_key = st.sidebar.selectbox("Select Model", list(MODEL_DESCRIPTIONS.keys()))
189
  selected_model_info = MODEL_DESCRIPTIONS[selected_model_key]
 
190
 
191
  # Display model details in the sidebar
192
  st.sidebar.markdown("### Model Details")
 
216
  with st.spinner('Classifying...'):
217
  try:
218
  # Read the file directly using BytesIO
 
219
  bytes_data = uploaded_file.getvalue()
220
  bytes_io = io.BytesIO(bytes_data)
221
 
 
227
  data = np.array(hdf.get('img'))
228
  data[np.isnan(data)] = 0.000001
229
  channels = config["dataset_config"]["channels"]
230
+
231
+ # Prepare image
232
+ # Assuming data shape is (14, 128, 128) based on typical satellite data or (128, 128, 14)
233
+ # The original code did: image[:, :, i] = data[band-1] implying data is (14, 128, 128) if accessed by index
234
+ # But later it did data[:, :, channel-1] in the else block?
235
+ # Let's check the original code logic again.
236
+ # Original code had two different logic blocks for data loading!
237
+ # Block 1 (single model): image[:, :, i] = data[band-1] -> implies data is (C, H, W)
238
+ # Block 2 (all models): image[:, :, i] = data[:, :, channel-1] -> implies data is (H, W, C)
239
+
240
+ # I will assume (C, H, W) is more standard for HDF5 'img' usually, but let's try to be robust or pick one.
241
+ # Given the inconsistency, I'll check data shape.
242
+
243
  image = np.zeros((128, 128, len(channels)))
244
+
245
+ if data.ndim == 3:
246
+ if data.shape[0] == 14: # (C, H, W)
247
+ for i, band in enumerate(channels):
248
+ image[:, :, i] = data[band-1, :, :]
249
+ elif data.shape[2] == 14: # (H, W, C)
250
+ for i, band in enumerate(channels):
251
+ image[:, :, i] = data[:, :, band-1]
252
+ else:
253
+ st.warning(f"Unexpected data shape: {data.shape}. Assuming (C, H, W).")
254
+ for i, band in enumerate(channels):
255
+ if band-1 < data.shape[0]:
256
+ image[:, :, i] = data[band-1, :, :]
257
+ else:
258
+ st.error(f"Data has {data.ndim} dimensions, expected 3.")
259
+ continue
260
 
261
+ # Prepare for model (Batch, Channel, Height, Width)
262
+ # image is currently (H, W, C)
263
+ image_display = image.transpose(2, 0, 1) # (C, H, W)
264
+ image_tensor = torch.from_numpy(image_display).unsqueeze(0).float() # (1, C, H, W)
 
265
 
266
  if model_option == "Select a single model":
267
+ process_and_visualize(selected_model_key, selected_model_info, image_tensor, image_display, uploaded_file.name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
  else:
 
269
  for model_key, model_info in MODEL_DESCRIPTIONS.items():
270
+ process_and_visualize(model_key, model_info, image_tensor, image_display, uploaded_file.name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
  except Exception as e:
273
  st.error(f"Error processing file {uploaded_file.name}: {str(e)}")
274
+ import traceback
275
+ st.error(traceback.format_exc())
276
  continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
  st.success('Done!')