{ "cells": [ { "cell_type": "markdown", "id": "f541ffd4", "metadata": {}, "source": [ "# Synthetic High-Resolution DEM Generation for Marrakech, Morocco\n", "# Using Only McKinley Dataset for Training\n", "\n", "This notebook implements the full pipeline, training only on the McKinley dataset to generate a model for super-resolving 30m SRTM to 10m DEMs fused with Sentinel-2 imagery for Marrakech, Morocco.\n", "\n", "**Key Assumptions:**\n", "- Training on McKinley Mine NM high-res LiDAR DEM.\n", "- Inference on Marrakech mountain area.\n", "- Adapted DeepDEM model with 7 input channels.\n", "\n", "Run cells sequentially." ] }, { "cell_type": "code", "execution_count": null, "id": "b7aa9465", "metadata": {}, "outputs": [], "source": [ "# Cell 1: Install Dependencies\n", "!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n", "!pip install pytorch-lightning torchgeo segmentation-models-pytorch rasterio geopandas albumentations scipy gdown earthengine-api\n", "!apt-get install -y libspatialindex-dev libgdal-dev\n", "!pip install gdal==$(gdal-config --version)" ] }, { "cell_type": "code", "execution_count": null, "id": "c4f399ac", "metadata": {}, "outputs": [], "source": [ "# Cell 2: Mount Google Drive and Set Up Directories\n", "from google.colab import drive\n", "drive.mount('/content/drive')\n", "%cd /content/drive/MyDrive/DEM_Project\n", "!mkdir -p Training_Data/McKinley Inference_Data/Marrackech Models" ] }, { "cell_type": "code", "execution_count": null, "id": "ed508a36", "metadata": {}, "outputs": [], "source": [ "# Cell 3: Clone DeepDEM Repo and Adapt for Our Use Case\n", "!git clone https://github.com/uw-cryo/DeepDEM.git\n", "%cd DeepDEM\n", "\n", "# Adapt model for our inputs: Modify task_module.py to accept ['dsm', 'ortho_r', 'ortho_g', 'ortho_b', 'ortho_nir', 'ndvi', 'nodata_mask'] (7 channels)\n", "# Set model in_channels=7, out_channels=1 (residuals)\n", "# For simplicity, assume manual edit or duplicate code here.\n", "\n", "import os\n", "os.environ['PYTHONPATH'] += ':/content/drive/MyDrive/DEM_Project/DeepDEM'" ] }, { "cell_type": "code", "execution_count": null, "id": "d7bb1f40", "metadata": {}, "outputs": [], "source": [ "# Cell 4: Authenticate and Initialize Earth Engine\n", "from google.colab import auth\n", "import ee\n", "\n", "# 1. Authenticate your Google user\n", "auth.authenticate_user()\n", "\n", "# 2. 
 { "cell_type": "code", "execution_count": null, "id": "d7bb1f40", "metadata": {}, "outputs": [], "source": [
 "# Cell 4: Authenticate and initialize Earth Engine\n",
 "from google.colab import auth\n",
 "import ee\n",
 "\n",
 "# 1. Authenticate your Google user\n",
 "auth.authenticate_user()\n",
 "\n",
 "# 2. Initialize Earth Engine with your Google Cloud project ID\n",
 "#    (replace 'dem-collab' below with your own project ID)\n",
 "try:\n",
 "    ee.Initialize(project='dem-collab')\n",
 "    print(\"Earth Engine initialized successfully!\")\n",
 "except ee.EEException as e:\n",
 "    print(f\"Error during initialization: {e}\")" ] },
 { "cell_type": "code", "execution_count": null, "id": "091b3f03", "metadata": {}, "outputs": [], "source": [
 "# Cell 5: Data acquisition (SRTM + Sentinel-2 for McKinley and Marrakech)\n",
 "import os\n",
 "\n",
 "def fetch_gee_data(bbox, output_dir, dataset_name):\n",
 "    os.makedirs(output_dir, exist_ok=True)\n",
 "\n",
 "    geom = ee.Geometry.BBox(*bbox)  # region of interest\n",
 "\n",
 "    # SRTM (30 m). USGS/SRTMGL1_003 is the 30 m product;\n",
 "    # CGIAR/SRTM90_V4 (used previously) is actually 90 m.\n",
 "    srtm = ee.Image('USGS/SRTMGL1_003').clip(geom).rename('dsm')\n",
 "    # NOTE: Drive exports land in a top-level Drive folder with this name\n",
 "    # (created at the root if missing), not at the nested path; move the\n",
 "    # files into output_dir after the tasks finish.\n",
 "    task_srtm = ee.batch.Export.image.toDrive(\n",
 "        image=srtm,\n",
 "        description=f'{dataset_name}_srtm',\n",
 "        folder=output_dir.split('/')[-1],\n",
 "        scale=30,\n",
 "        fileFormat='GeoTIFF',\n",
 "        region=geom\n",
 "    )\n",
 "    task_srtm.start()\n",
 "\n",
 "    # Sentinel-2 (10 m, least-cloudy single scene, RGB + NIR):\n",
 "    # sort by cloud cover and take the best scene from a clear season.\n",
 "    collection = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED') \\\n",
 "        .filterBounds(geom) \\\n",
 "        .filterDate('2023-06-01', '2023-10-31') \\\n",
 "        .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 10)) \\\n",
 "        .sort('CLOUDY_PIXEL_PERCENTAGE')\n",
 "\n",
 "    num_images = collection.size().getInfo()\n",
 "    print(f'For {dataset_name} bbox: {num_images} Sentinel-2 images found.')\n",
 "\n",
 "    # Take the least-cloudy image\n",
 "    best_image = collection.first()\n",
 "\n",
 "    # Export the raw 4-band image for the model (B4=R, B3=G, B2=B, B8=NIR)\n",
 "    sentinel_raw = best_image.select(['B4', 'B3', 'B2', 'B8'])\n",
 "    task_s2_raw = ee.batch.Export.image.toDrive(\n",
 "        image=sentinel_raw,\n",
 "        description=f'{dataset_name}_sentinel',  # keep this name for downstream cells\n",
 "        folder=output_dir.split('/')[-1],\n",
 "        scale=10,\n",
 "        fileFormat='GeoTIFF',\n",
 "        region=geom\n",
 "    )\n",
 "    task_s2_raw.start()\n",
 "\n",
 "    # Export a separate, color-corrected visual version for inspection\n",
 "    sentinel_viz = best_image.visualize(min=0, max=3000, bands=['B4', 'B3', 'B2'])\n",
 "    task_s2_viz = ee.batch.Export.image.toDrive(\n",
 "        image=sentinel_viz,\n",
 "        description=f'{dataset_name}_sentinel_viz',\n",
 "        folder=output_dir.split('/')[-1],\n",
 "        scale=10,\n",
 "        fileFormat='GeoTIFF',\n",
 "        region=geom\n",
 "    )\n",
 "    task_s2_viz.start()\n",
 "\n",
 "\n",
 "# Bounding boxes\n",
 "bbox_mckinley = [-109.03892074228675, 35.58282920746211, -108.87077846472735, 35.736434167381475]\n",
 "fetch_gee_data(bbox_mckinley, '/content/drive/MyDrive/DEM_Project/Training_Data/McKinley', 'mckinley')\n",
 "\n",
 "bbox_marrakech = [-8.1, 31.5, -7.9, 31.7]\n",
 "fetch_gee_data(bbox_marrakech, '/content/drive/MyDrive/DEM_Project/Inference_Data/Marrakech', 'marrakech')" ] },
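 { "cell_type": "markdown", "id": "2b7e91aa", "metadata": {}, "source": [
 "Drive exports run asynchronously, and the preprocessing below needs the files to exist before it runs. A minimal polling sketch (the helper name and 30 s interval are our choices) blocks until no Earth Engine task is queued or running:" ] },
 { "cell_type": "code", "execution_count": null, "id": "3c8fa2bb", "metadata": {}, "outputs": [], "source": [
 "# Block until the Earth Engine exports started above have finished.\n",
 "import time\n",
 "\n",
 "def wait_for_tasks(poll_seconds=30):\n",
 "    while True:\n",
 "        states = [t.status()['state'] for t in ee.batch.Task.list()]\n",
 "        active = [s for s in states if s in ('READY', 'RUNNING')]\n",
 "        print(f'{len(active)} export task(s) still active')\n",
 "        if not active:\n",
 "            break\n",
 "        time.sleep(poll_seconds)\n",
 "\n",
 "wait_for_tasks()" ] },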
 { "cell_type": "code", "execution_count": null, "id": "5f9b0934", "metadata": {}, "outputs": [], "source": [
 "# Cell 6: Download high-resolution DEM tiles and merge (McKinley only)\n",
 "!pip install boto3 retry\n",
 "!apt install gdal-bin\n",
 "import boto3\n",
 "import os\n",
 "from botocore import UNSIGNED\n",
 "from botocore.client import Config\n",
 "from retry import retry\n",
 "\n",
 "# S3 configuration (OpenTopography public bucket, anonymous access)\n",
 "endpoint_url = 'https://opentopography.s3.sdsc.edu'\n",
 "client = boto3.client('s3', endpoint_url=endpoint_url, config=Config(signature_version=UNSIGNED))\n",
 "\n",
 "# Temp and data dirs\n",
 "temp_base = '/content/hr_temp'\n",
 "data_base = '/content/drive/MyDrive/DEM_Project/Training_Data'\n",
 "os.makedirs(temp_base, exist_ok=True)\n",
 "os.makedirs(data_base, exist_ok=True)\n",
 "\n",
 "# Only McKinley\n",
 "datasets = {'mckinley': 'NM23_McKinley'}\n",
 "folder_names = {'mckinley': 'McKinley'}\n",
 "\n",
 "@retry(tries=3, delay=2, backoff=2)\n",
 "def download_file_with_retry(bucket, key, filename):\n",
 "    client.download_file(Bucket=bucket, Key=key, Filename=filename)\n",
 "\n",
 "!df -h /content\n",
 "for local_name, s3_dir in datasets.items():\n",
 "    temp_dir = os.path.join(temp_base, local_name)\n",
 "    dataset_dir = os.path.join(data_base, folder_names[local_name])\n",
 "    os.makedirs(temp_dir, exist_ok=True)\n",
 "    os.makedirs(dataset_dir, exist_ok=True)\n",
 "\n",
 "    output_tif = os.path.join(dataset_dir, f'{local_name}_hr_dem.tif')\n",
 "    if os.path.exists(output_tif):\n",
 "        print(f'Merged DEM already exists for {local_name}: {output_tif}, skipping download and merge.')\n",
 "        continue\n",
 "\n",
 "    paginator = client.get_paginator('list_objects_v2')\n",
 "    prefix = f'{s3_dir}/{s3_dir}_be/'\n",
 "    downloaded_files = []\n",
 "    try:\n",
 "        for page in paginator.paginate(Bucket='raster', Prefix=prefix):\n",
 "            for obj in page.get('Contents', []):\n",
 "                key = obj['Key']\n",
 "                if key.endswith('.tif'):\n",
 "                    file_path = os.path.join(temp_dir, os.path.basename(key))\n",
 "                    # Skip tiles that were already downloaded\n",
 "                    if os.path.exists(file_path):\n",
 "                        print(f'Tile {os.path.basename(key)} already exists for {local_name}, skipping download.')\n",
 "                        downloaded_files.append(file_path)\n",
 "                        continue\n",
 "                    try:\n",
 "                        download_file_with_retry('raster', key, file_path)\n",
 "                        downloaded_files.append(file_path)\n",
 "                        print(f'Downloaded {os.path.basename(key)} for {local_name}')\n",
 "                    except Exception as e:\n",
 "                        print(f'Error downloading {key} for {local_name}: {e}')\n",
 "    except Exception as e:\n",
 "        print(f'Error listing tiles for {local_name}: {e}')\n",
 "\n",
 "    if not downloaded_files:\n",
 "        print(f'No tiles downloaded for {local_name}; skipping merge.')\n",
 "        continue\n",
 "\n",
 "    # Merge with gdalbuildvrt + gdal_translate for better performance.\n",
 "    # Use a file list to avoid overlong shell command lines, and check the\n",
 "    # output file instead of try/except: '!' shell commands never raise.\n",
 "    list_file = os.path.join(temp_dir, 'tile_list.txt')\n",
 "    with open(list_file, 'w') as f:\n",
 "        f.write('\\n'.join(downloaded_files))\n",
 "    !gdalbuildvrt -input_file_list {list_file} merged.vrt\n",
 "    !gdal_translate -of GTiff merged.vrt \"{output_tif}\" -co TILED=YES -co COMPRESS=DEFLATE -co NUM_THREADS=ALL_CPUS\n",
 "    if os.path.exists(output_tif):\n",
 "        print(f'Merged tiles to TIFF for {local_name}: {output_tif}')\n",
 "    else:\n",
 "        print(f'Error merging tiles for {local_name}')\n",
 "    # Temp tiles are kept on local disk so reruns can skip downloads.\n",
 "\n",
 "!df -h /content\n",
 "print('Download complete for McKinley!')" ] },
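 { "cell_type": "markdown", "id": "4d9ab3cc", "metadata": {}, "source": [
 "Before preprocessing, it is worth confirming the merged product opened correctly. A quick inspection sketch (we expect a projected CRS and roughly 1 m resolution for OpenTopography lidar, but that is an assumption about this dataset):" ] },
 { "cell_type": "code", "execution_count": null, "id": "5eab04dd", "metadata": {}, "outputs": [], "source": [
 "# Quick sanity check on the merged high-resolution DEM.\n",
 "import rasterio\n",
 "\n",
 "merged = '/content/drive/MyDrive/DEM_Project/Training_Data/McKinley/mckinley_hr_dem.tif'\n",
 "with rasterio.open(merged) as src:\n",
 "    print({'crs': str(src.crs), 'res': src.res, 'size': (src.width, src.height),\n",
 "           'nodata': src.nodata, 'bounds': tuple(src.bounds)})" ] },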
 { "cell_type": "code", "execution_count": null, "id": "7f82f16c", "metadata": {}, "outputs": [], "source": [
 "# Cell 7: Data preprocessing (McKinley and Marrakech)\n",
 "import rasterio\n",
 "from rasterio.enums import Resampling\n",
 "import numpy as np\n",
 "from scipy.ndimage import gaussian_filter\n",
 "import os\n",
 "\n",
 "folder_names = {'mckinley': 'McKinley', 'marrakech': 'Marrakech'}\n",
 "\n",
 "def preprocess_dataset(dataset_name, is_training=True, custom_base=None):\n",
 "    # NOTE: this assumes all rasters share the same CRS and grid. The GEE\n",
 "    # exports do; the lidar DEM may first need gdalwarp onto the Sentinel-2 grid.\n",
 "    if custom_base:\n",
 "        base_dir = custom_base\n",
 "    else:\n",
 "        base_dir = f'/content/drive/MyDrive/DEM_Project/Training_Data/{folder_names[dataset_name]}'\n",
 "\n",
 "    srtm_path = os.path.join(base_dir, f'{dataset_name}_srtm.tif')\n",
 "    s2_path = os.path.join(base_dir, f'{dataset_name}_sentinel.tif')\n",
 "    hr_path = os.path.join(base_dir, f'{dataset_name}_hr_dem.tif') if is_training else None\n",
 "    output_dir = base_dir\n",
 "\n",
 "    with rasterio.open(srtm_path) as srtm_src, rasterio.open(s2_path) as s2_src:\n",
 "        # Upsample SRTM to the 10 m Sentinel-2 grid\n",
 "        target_shape = (s2_src.height, s2_src.width)\n",
 "        srtm = srtm_src.read(1, out_shape=target_shape, resampling=Resampling.cubic).astype(np.float32)\n",
 "\n",
 "        # Cast to float32 before band math: the raw export is integer-typed,\n",
 "        # and (nir - r) would underflow on unsigned integers.\n",
 "        s2 = s2_src.read().astype(np.float32)\n",
 "        r, g, b, nir = s2\n",
 "\n",
 "        ndvi = (nir - r) / (nir + r + 1e-10)\n",
 "\n",
 "        if srtm_src.nodata is not None:\n",
 "            mask = (srtm == srtm_src.nodata).astype(np.float32)  # 1 = nodata\n",
 "        else:\n",
 "            mask = np.zeros_like(srtm, dtype=np.float32)\n",
 "\n",
 "        if is_training:\n",
 "            with rasterio.open(hr_path) as hr_src:\n",
 "                hr = (hr_src.read(1, out_shape=target_shape, resampling=Resampling.cubic)\n",
 "                      if hr_src.shape != target_shape else hr_src.read(1)).astype(np.float32)\n",
 "\n",
 "            # Split the HR DEM into a smooth trend and a high-frequency residual;\n",
 "            # the residual is the learning target, and at inference time the trend\n",
 "            # is approximated by smoothing the upsampled SRTM.\n",
 "            trend = gaussian_filter(hr, sigma=5)\n",
 "            residual = hr - trend\n",
 "\n",
 "            target_profile = s2_src.profile.copy()\n",
 "            target_profile.update(count=1, dtype='float32')\n",
 "            with rasterio.open(os.path.join(output_dir, 'target.tif'), 'w', **target_profile) as dst:\n",
 "                dst.write(residual, 1)\n",
 "\n",
 "        input_profile = s2_src.profile.copy()\n",
 "        input_profile.update(count=7, dtype='float32')  # float32: NDVI and mask are not integer\n",
 "        with rasterio.open(os.path.join(output_dir, 'input.tif'), 'w', **input_profile) as dst:\n",
 "            dst.write(srtm, 1)\n",
 "            dst.write(r, 2)\n",
 "            dst.write(g, 3)\n",
 "            dst.write(b, 4)\n",
 "            dst.write(nir, 5)\n",
 "            dst.write(ndvi, 6)\n",
 "            dst.write(mask, 7)\n",
 "\n",
 "# Preprocess McKinley\n",
 "preprocess_dataset('mckinley', is_training=True)\n",
 "\n",
 "# Preprocess Marrakech (no HR DEM)\n",
 "preprocess_dataset('marrakech', is_training=False, custom_base='/content/drive/MyDrive/DEM_Project/Inference_Data/Marrakech')\n",
 "\n",
 "!df -h /content/drive/MyDrive" ] },
 { "cell_type": "code", "execution_count": null, "id": "7933d058", "metadata": {}, "outputs": [], "source": [
 "# Cell 8: Custom dataset class\n",
 "import albumentations as A\n",
 "from albumentations.pytorch import ToTensorV2\n",
 "from torch.utils.data import Dataset\n",
 "import rasterio.windows\n",
 "\n",
 "class CustomDEMDataset(Dataset):\n",
 "    def __init__(self, data_dirs, tile_size=256, transform=None):\n",
 "        self.pairs = []\n",
 "        for d_dir in data_dirs:\n",
 "            input_path = os.path.join(d_dir, 'input.tif')\n",
 "            target_path = os.path.join(d_dir, 'target.tif')\n",
 "            if os.path.exists(input_path) and os.path.exists(target_path):\n",
 "                self.pairs.append((input_path, target_path))\n",
 "        self.tile_size = tile_size\n",
 "        # Note: bands are left unnormalized, and the GaussNoise variance is tiny\n",
 "        # relative to raw reflectance values; per-band standardization may help.\n",
 "        self.transform = transform or A.Compose([\n",
 "            A.RandomCrop(height=tile_size, width=tile_size),\n",
 "            A.RandomRotate90(),\n",
 "            A.HorizontalFlip(),\n",
 "            A.VerticalFlip(),\n",
 "            A.GaussNoise(var_limit=(0.01, 0.01)),\n",
 "            ToTensorV2()\n",
 "        ])\n",
 "\n",
 "    def __len__(self):\n",
 "        return len(self.pairs) * 50  # 50 random crops per image pair per epoch\n",
 "\n",
 "    def __getitem__(self, idx):\n",
 "        input_path, target_path = self.pairs[idx % len(self.pairs)]\n",
 "        with rasterio.open(input_path) as inp, rasterio.open(target_path) as tgt:\n",
 "            max_col = inp.width - self.tile_size\n",
 "            max_row = inp.height - self.tile_size\n",
 "            col_off = np.random.randint(0, max_col + 1)\n",
 "            row_off = np.random.randint(0, max_row + 1)\n",
 "            window = rasterio.windows.Window(col_off, row_off, self.tile_size, self.tile_size)\n",
 "            input_data = inp.read(window=window)\n",
 "            target_data = tgt.read(1, window=window)\n",
 "\n",
 "        data = {'image': input_data.transpose(1, 2, 0).astype(np.float32), 'target': target_data.astype(np.float32)}\n",
 "        augmented = self.transform(image=data['image'], mask=data['target'])\n",
 "        return augmented['image'], augmented['mask'].unsqueeze(0)\n",
 "\n",
 "train_dirs = ['/content/drive/MyDrive/DEM_Project/Training_Data/McKinley']\n",
 "dataset = CustomDEMDataset(train_dirs)" ] },
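 { "cell_type": "markdown", "id": "6fbc15ee", "metadata": {}, "source": [
 "A one-sample smoke test of the dataset catches shape and dtype problems before any GPU time is spent. This is a usage example of `CustomDEMDataset`, not part of the training pipeline:" ] },
 { "cell_type": "code", "execution_count": null, "id": "70cd26ff", "metadata": {}, "outputs": [], "source": [
 "# Fetch one augmented sample and confirm the tensors look as expected.\n",
 "x, y = dataset[0]\n",
 "print('input :', tuple(x.shape), x.dtype)   # expected: (7, 256, 256) torch.float32\n",
 "print('target:', tuple(y.shape), y.dtype)   # expected: (1, 256, 256) torch.float32\n",
 "assert x.shape[0] == 7 and y.shape[0] == 1" ] },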
"code", "execution_count": null, "id": "46cf2c6d", "metadata": {}, "outputs": [], "source": [ "# Cell 9: Model and Training (Using Only McKinley)\n", "import pytorch_lightning as pl\n", "import segmentation_models_pytorch as smp\n", "from torch.utils.data import DataLoader\n", "import torch\n", "import torch.nn as nn\n", "\n", "class DeepDEMRefinement(pl.LightningModule):\n", " def __init__(self, lr=1e-4):\n", " super().__init__()\n", " self.model = smp.Unet(encoder_name='resnet34', in_channels=7, classes=1, activation=None)\n", " self.loss_fn = nn.L1Loss()\n", " self.lr = lr\n", "\n", " def forward(self, x):\n", " return self.model(x)\n", "\n", " def training_step(self, batch, batch_idx):\n", " inputs, targets = batch\n", " preds = self(inputs)\n", " loss = self.loss_fn(preds, targets)\n", " self.log('train_loss', loss)\n", " return loss\n", "\n", " def configure_optimizers(self):\n", " return torch.optim.Adam(self.parameters(), lr=self.lr)\n", "\n", "class DEMDataModule(pl.LightningDataModule):\n", " def __init__(self, train_dirs, batch_size=4):\n", " super().__init__()\n", " self.train_dataset = CustomDEMDataset(train_dirs)\n", " self.batch_size = batch_size\n", "\n", " def train_dataloader(self):\n", " return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=2)\n", "\n", "model = DeepDEMRefinement()\n", "datamodule = DEMDataModule(train_dirs)\n", "trainer = pl.Trainer(max_epochs=5, accelerator='gpu', devices=1) # Training for 5 epochs\n", "trainer.fit(model, datamodule)\n", "\n", "trainer.save_checkpoint('/content/drive/MyDrive/DEM_Project/Models/deepdem_model.ckpt')" ] }, { "cell_type": "code", "execution_count": null, "id": "30872060", "metadata": {}, "outputs": [], "source": [ "# Cell 10: Inference for Marrakech\n", "model = DeepDEMRefinement.load_from_checkpoint('/content/drive/MyDrive/DEM_Project/Models/deepdem_model.ckpt')\n", "model.eval()\n", "model.to('cuda' if torch.cuda.is_available() else 'cpu')\n", "\n", "input_path = '/content/drive/MyDrive/DEM_Project/Inference_Data/Marrakech/input.tif'\n", "with rasterio.open(input_path) as src:\n", " input_data = src.read().astype(np.float32)\n", " trend = gaussian_filter(input_data[0], sigma=5)\n", "\n", " input_tensor = torch.from_numpy(input_data).unsqueeze(0).to(model.device)\n", "\n", " with torch.no_grad():\n", " residual_pred = model(input_tensor)\n", "\n", " synth_dem = residual_pred.squeeze().cpu().numpy() + trend\n", "\n", " profile = src.profile\n", " profile['count'] = 1\n", " with rasterio.open('/content/drive/MyDrive/DEM_Project/synth_dem_marrakech.tif', 'w', **profile) as dst:\n", " dst.write(synth_dem, 1)\n", "\n", "print('Synthetic DEM generated for Marrakech!')" ] }, { "cell_type": "markdown", "id": "8943e470", "metadata": {}, "source": [ "# Quick correctness checks\n", "\n", "This section runs a few sanity checks on the trained model and data:\n", "\n", "- Validate shapes, CRS, and basic channel statistics of `input.tif` and `target.tif`\n", "- Compute masked MAE/RMSE on random training crops (McKinley) to gauge training fit\n", "- Flag obvious issues (e.g., all-zeros bands, nodata dominance)\n", "\n", "Run cells in order after training has completed." 
 { "cell_type": "markdown", "id": "8943e470", "metadata": {}, "source": [
 "# Quick correctness checks\n",
 "\n",
 "This section runs a few sanity checks on the trained model and data:\n",
 "\n",
 "- Validate shapes, CRS, and basic channel statistics of `input.tif` and `target.tif`\n",
 "- Compute masked MAE/RMSE on random training crops (McKinley) to gauge training fit\n",
 "- Flag obvious issues (e.g., all-zeros bands, nodata dominance)\n",
 "\n",
 "Run cells in order after training has completed." ] },
 { "cell_type": "code", "execution_count": null, "id": "3c35167d", "metadata": {}, "outputs": [], "source": [
 "# Check 1: Inspect input/target rasters (McKinley)\n",
 "import rasterio\n",
 "import numpy as np\n",
 "from pathlib import Path\n",
 "\n",
 "train_dir = Path('/content/drive/MyDrive/DEM_Project/Training_Data/McKinley')\n",
 "input_path = train_dir / 'input.tif'\n",
 "target_path = train_dir / 'target.tif'\n",
 "\n",
 "issues = []\n",
 "\n",
 "with rasterio.open(input_path) as src:\n",
 "    print('INPUT:')\n",
 "    print({'count': src.count, 'width': src.width, 'height': src.height, 'crs': str(src.crs), 'dtype': src.dtypes})\n",
 "    data = src.read(out_dtype='float32')\n",
 "    nodata = src.nodata\n",
 "    band_stats = []\n",
 "    for i in range(src.count):\n",
 "        b = data[i]\n",
 "        if nodata is not None:\n",
 "            mask = b == nodata\n",
 "            valid = np.where(mask, np.nan, b)\n",
 "        else:\n",
 "            valid = b\n",
 "            mask = np.zeros_like(b, dtype=bool)\n",
 "        s = {\n",
 "            'band': i + 1,\n",
 "            'nan_frac': float(np.mean(np.isnan(valid))),\n",
 "            'nodata_frac': float(np.mean(mask)),\n",
 "            'min': float(np.nanmin(valid)),\n",
 "            'max': float(np.nanmax(valid)),\n",
 "            'mean': float(np.nanmean(valid)),\n",
 "            'std': float(np.nanstd(valid)),\n",
 "        }\n",
 "        band_stats.append(s)\n",
 "    print('Input band stats (1:dsm, 2:R, 3:G, 4:B, 5:NIR, 6:NDVI, 7:mask):')\n",
 "    for s in band_stats:\n",
 "        print(s)\n",
 "    # Basic checks\n",
 "    if band_stats[5]['min'] < -1.01 or band_stats[5]['max'] > 1.01:\n",
 "        issues.append('NDVI out of expected [-1,1] range; check scaling and band order (R, NIR indices).')\n",
 "    if band_stats[6]['mean'] < 0.01 and band_stats[6]['max'] < 0.5:\n",
 "        issues.append('Mask band is effectively all zeros; fine if the SRTM tile has no nodata, otherwise ensure mask=1 at nodata pixels.')\n",
 "\n",
 "with rasterio.open(target_path) as src:\n",
 "    print('\\nTARGET:')\n",
 "    print({'count': src.count, 'width': src.width, 'height': src.height, 'crs': str(src.crs), 'dtype': src.dtypes})\n",
 "    t = src.read(1, out_dtype='float32')\n",
 "    print({'min': float(np.nanmin(t)), 'max': float(np.nanmax(t)), 'mean': float(np.nanmean(t)), 'std': float(np.nanstd(t))})\n",
 "    if np.allclose(t, 0):\n",
 "        issues.append('Target residual is all zeros; check HR DEM loading and the detrending step.')\n",
 "\n",
 "print('\\nPotential issues:')\n",
 "print(issues if issues else 'None detected')" ] },
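 { "cell_type": "markdown", "id": "a3f059cc", "metadata": {}, "source": [
 "A quick visual check often catches problems the band statistics miss (for example, a misaligned mask or an inverted residual). A minimal matplotlib sketch of the first 256x256 corner, using the band order written by Cell 7:" ] },
 { "cell_type": "code", "execution_count": null, "id": "b4a16add", "metadata": {}, "outputs": [], "source": [
 "# Visualize one corner crop of the training input and target.\n",
 "import matplotlib.pyplot as plt\n",
 "import rasterio.windows\n",
 "\n",
 "with rasterio.open(input_path) as inp, rasterio.open(target_path) as tgt:\n",
 "    win = rasterio.windows.Window(0, 0, 256, 256)\n",
 "    x = inp.read(window=win).astype('float32')\n",
 "    y = tgt.read(1, window=win).astype('float32')\n",
 "\n",
 "fig, axes = plt.subplots(1, 3, figsize=(12, 4))\n",
 "axes[0].imshow(x[0], cmap='terrain'); axes[0].set_title('SRTM (upsampled)')\n",
 "axes[1].imshow(x[5], cmap='RdYlGn', vmin=-1, vmax=1); axes[1].set_title('NDVI')\n",
 "axes[2].imshow(y, cmap='coolwarm'); axes[2].set_title('Target residual')\n",
 "for ax in axes:\n",
 "    ax.axis('off')\n",
 "plt.tight_layout()\n",
 "plt.show()" ] },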
 { "cell_type": "code", "execution_count": null, "id": "ad78e542", "metadata": {}, "outputs": [], "source": [
 "# Check 2: Compute quick masked MAE/RMSE on random training crops\n",
 "import random\n",
 "import torch\n",
 "from torch.utils.data import DataLoader\n",
 "import numpy as np\n",
 "\n",
 "# Reuse dataset and model classes already defined earlier\n",
 "try:\n",
 "    _ = CustomDEMDataset\n",
 "except NameError:\n",
 "    raise RuntimeError('CustomDEMDataset not defined; run earlier cells first.')\n",
 "\n",
 "try:\n",
 "    _ = DeepDEMRefinement\n",
 "except NameError:\n",
 "    raise RuntimeError('DeepDEMRefinement not defined; run training cells first.')\n",
 "\n",
 "# Load model\n",
 "ckpt = '/content/drive/MyDrive/DEM_Project/Models/deepdem_model.ckpt'\n",
 "model = DeepDEMRefinement.load_from_checkpoint(ckpt)\n",
 "model.eval()\n",
 "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
 "model.to(device)\n",
 "\n",
 "# Small eval dataset with repeatable crops: seed both numpy (window offsets)\n",
 "# and random (albumentations) so reruns see similar tiles.\n",
 "random.seed(42)\n",
 "np.random.seed(42)\n",
 "transform = A.Compose([\n",
 "    A.RandomCrop(height=256, width=256),\n",
 "    ToTensorV2()\n",
 "])\n",
 "\n",
 "val_ds = CustomDEMDataset([str(train_dir)], tile_size=256, transform=transform)\n",
 "val_loader = DataLoader(val_ds, batch_size=4, shuffle=False, num_workers=2)\n",
 "\n",
 "maes, rmses = [], []\n",
 "with torch.no_grad():\n",
 "    for i, (x, y) in enumerate(val_loader):\n",
 "        if i >= 10:  # ~40 tiles\n",
 "            break\n",
 "        x = x.to(device)\n",
 "        y = y.to(device)\n",
 "        pred = model(x)\n",
 "        # Use the nodata mask (channel 7) to exclude masked pixels from the error\n",
 "        mask = x[:, 6:7]\n",
 "        valid = (mask < 0.5).float()\n",
 "        diff = (pred - y) * valid\n",
 "        denom = valid.sum().clamp_min(1.0)\n",
 "        mae = diff.abs().sum() / denom\n",
 "        rmse = torch.sqrt(diff.pow(2).sum() / denom)\n",
 "        maes.append(mae.item())\n",
 "        rmses.append(rmse.item())\n",
 "\n",
 "print({'MAE_mean': float(np.mean(maes)), 'MAE_std': float(np.std(maes)), 'RMSE_mean': float(np.mean(rmses)), 'RMSE_std': float(np.std(rmses)), 'tiles': len(maes) * val_loader.batch_size})\n",
 "\n",
 "if np.mean(rmses) > 8.0:\n",
 "    print('Warning: high RMSE for residuals. Training may be underfit or target scaling may be off.')\n",
 "else:\n",
 "    print('Residual error looks reasonable for the training run.')" ] },
 { "cell_type": "code", "execution_count": null, "id": "def70a98", "metadata": {}, "outputs": [], "source": [
 "# Check 3: Sanity-check inference output alignment vs input for Marrakech\n",
 "from scipy.ndimage import gaussian_filter\n",
 "\n",
 "marrakech_input = '/content/drive/MyDrive/DEM_Project/Inference_Data/Marrakech/input.tif'\n",
 "marrakech_out = '/content/drive/MyDrive/DEM_Project/synth_dem_marrakech.tif'\n",
 "\n",
 "with rasterio.open(marrakech_input) as src_in, rasterio.open(marrakech_out) as src_out:\n",
 "    print('INFERENCE INPUT:', {'shape': (src_in.count, src_in.height, src_in.width), 'crs': str(src_in.crs), 'transform': tuple(src_in.transform)})\n",
 "    print('SYNTH OUTPUT:', {'shape': (src_out.count, src_out.height, src_out.width), 'crs': str(src_out.crs), 'transform': tuple(src_out.transform)})\n",
 "    if src_in.crs != src_out.crs:\n",
 "        print('Warning: CRS mismatch between input and output!')\n",
 "    if (src_in.height != src_out.height) or (src_in.width != src_out.width):\n",
 "        print('Warning: dimension mismatch between input and output!')\n",
 "\n",
 "    out_dem = src_out.read(1).astype('float32')\n",
 "    # Simple terrain sanity check: the synthetic DEM (trend + residual) should\n",
 "    # correlate strongly with the smoothed SRTM trend.\n",
 "    srtm = src_in.read(1).astype('float32')\n",
 "    trend = gaussian_filter(srtm, sigma=5)\n",
 "    finite = np.isfinite(trend) & np.isfinite(out_dem)\n",
 "    corr = np.corrcoef(trend[finite], out_dem[finite])[0, 1]\n",
 "    print('Correlation between SRTM trend and synthetic DEM:', float(corr))\n",
 "    if corr < 0.5:\n",
 "        print('Low correlation; output may be noisy or misaligned.')" ] }
], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 5 }