```python
try:
    if data_ready and datamodule is not None:
        _ = trainer.test(model=task, datamodule=datamodule)
    else:
        print("Skipping test (data not prepared).")
except Exception as e:
    print("Test skipped due to error:", e)
```

```
Skipping test (data not prepared).
```
## Optional: visualize a prediction
This cell runs a forward pass on the single test image and displays the predicted mask (argmax over classes). It does not modify model state.
```python
import numpy as np

try:
    if data_ready and datamodule is not None:
        dm = datamodule
        dm.setup("test")
        test_loader = dm.test_dataloader()
        batch = next(iter(test_loader))
        images = batch[0]
        with torch.no_grad():
            preds = task(images).output
        pred_mask = preds.argmax(dim=1)[0].cpu().numpy()
        print("Prediction mask unique labels:", np.unique(pred_mask).tolist())
    else:
        print("Skipping prediction visualization (data not prepared).")
except Exception as e:
    print("Prediction skipped due to error:", e)
```

```
Skipping prediction visualization (data not prepared).
```
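The cell above only prints the unique class indices. To actually view the mask, a minimal plotting sketch along these lines can follow it (it assumes the `pred_mask` array from that cell is in scope; matplotlib is already used elsewhere in this tutorial):

```python
# Minimal sketch: render the predicted mask from the cell above.
# Assumes `pred_mask` is an (H, W) array of integer class indices.
import matplotlib.pyplot as plt

try:
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(pred_mask, interpolation="nearest")
    ax.set_title("Predicted mask (argmax over classes)")
    ax.axis("off")
    plt.show()
except NameError:
    print("Skipping mask display (no prediction available).")
```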
## Why this matters (reflection)
- You can compose a full segmentation workflow by combining a pretrained backbone, a decoder, and a lightweight datamodule.
- With a tiny toy dataset, you can validate I/O, augmentation, and training loops before scaling to larger data.
- Swapping backbones or decoders becomes a configuration change instead of a rewrite (see the sketch below).
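To make the last point concrete, here is a hedged sketch of such a configuration change: it repeats the earlier `EncoderDecoderFactory.build_model` call with only the `backbone` and `decoder` strings swapped. The alternative names used below (`prithvi_eo_v2_300`, `UperNetDecoder`) are assumptions for illustration; check `BACKBONE_REGISTRY` and your installed TerraTorch version for the names actually available.

```python
# Illustrative sketch only: swapping components is a configuration change.
# The backbone/decoder names below are assumed examples; verify they exist
# in your TerraTorch installation before running.
from terratorch.datasets import HLSBands
from terratorch.models import EncoderDecoderFactory

factory = EncoderDecoderFactory()
alt_model = factory.build_model(
    task="segmentation",
    backbone="prithvi_eo_v2_300",  # swapped backbone (assumed name)
    decoder="UperNetDecoder",      # swapped decoder (assumed name)
    backbone_bands=[
        HLSBands.BLUE,
        HLSBands.GREEN,
        HLSBands.RED,
        HLSBands.NIR_NARROW,
        HLSBands.SWIR_1,
        HLSBands.SWIR_2,
    ],
    num_classes=2,
    backbone_pretrained=True,
)
```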
## Source Code
---title: "Semantic Segmentation with TerraTorch (Toy Burn-Scars)"subtitle: "From backbone to datamodule to 1-epoch fit"jupyter: geoaiformat: html: toc: true toc-depth: 3 code-fold: showeditor_options: chunk_output_type: console---## OverviewIntent: walk through a minimal end-to-end semantic segmentation workflow using TerraTorch on a tiny burn-scar example.- Build a segmentation model from a pretrained backbone- Prepare a tiny toy dataset (1 image + 1 mask)- Visualize inputs and labels- Train for 1 epoch and evaluate---## Quick environment checkUse this cell to confirm your runtime and whether `terratorch` is available. If it is not installed, see the optional install cell below.```{python}# | echo: trueimport sys, platformprint(f"Python: {sys.version.split()[0]}")print(f"Platform: {platform.platform()}")try:import torchprint(f"PyTorch: {torch.__version__}; cuda={torch.cuda.is_available()}")exceptExceptionas e:print("PyTorch not available:", e)try:import terratorchfrom terratorch import BACKBONE_REGISTRYprint("TerraTorch is installed.")exceptExceptionas e:print("TerraTorch not available:", e)```Optional: install missing packages (run only if needed).```{bash}#| echo: true#| eval: false# If needed, install basics for this tutorialpip install --upgrade terratorch lightning rioxarray matplotlib```---## Inspect available Prithvi backbones```{python}# | echo: truetry:from terratorch import BACKBONE_REGISTRY prithvi_models = [name for name in BACKBONE_REGISTRY if"terratorch_prithvi"in name]print("Available Prithvi models:", prithvi_models)exceptExceptionas e:print("Skipping registry check (TerraTorch not available):", e)```---## Build a segmentation model and sanity-check forward passWe construct a small segmentation model with an encoder-decoder factory and confirm output shape on a dummy tensor.```{python}# | echo: trueimport torchtry:from terratorch.datasets import HLSBandsfrom terratorch.models import EncoderDecoderFactory factory = EncoderDecoderFactory() model = factory.build_model( task="segmentation", backbone="prithvi_eo_v1_100", decoder="FCNDecoder", backbone_bands=[ HLSBands.BLUE, HLSBands.GREEN, HLSBands.RED, HLSBands.NIR_NARROW, HLSBands.SWIR_1, HLSBands.SWIR_2, ], num_classes=2, backbone_pretrained=True, backbone_num_frames=1, decoder_channels=128, head_dropout=0.2, ) trial = torch.zeros(1, 6, 224, 224) out = model(trial)print("Sanity forward pass → output shape:", out.output.shape)exceptExceptionas e:print("Skipping model build (dependency missing or no GPU):", e)```What to notice:- The output has shape `[batch, num_classes, height, width]`.- You can swap `decoder` or `num_classes` without changing the backbone.---## Prepare a tiny burn-scar dataset (1 image + 1 mask)This tutorial uses a single sample from the burn-scar demo to create a minimal train/val/test split. 
Run the download cell if files are missing.```{python}# | echo: true# | eval: false# Download a single image and its mask from the burn-scars demo (Python-based)import osfrom urllib.request import urlretrieveIMG_URL ="https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-Burn-scars-demo/resolve/main/subsetted_512x512_HLS.S30.T10TGS.2018285.v1.4_merged.tif"MSK_URL ="https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-Burn-scars-demo/resolve/main/subsetted_512x512_HLS.S30.T10TGS.2018285.v1.4.mask.tif"input_file_name ="subsetted_512x512_HLS.S30.T10TGS.2018285.v1.4_merged.tif"label_file_name ="subsetted_512x512_HLS.S30.T10TGS.2018285.v1.4.mask.tif"def download_if_missing(url: str, path: str) ->None:if os.path.exists(path):print(f"Exists: {path}")returntry:print(f"Downloading {url} → {path}") urlretrieve(url, path) size_mb = os.path.getsize(path) /1e6print(f"Saved: {path} ({size_mb:.2f} MB)")exceptExceptionas e:print("Download failed:", e)download_if_missing(IMG_URL, input_file_name)download_if_missing(MSK_URL, label_file_name)``````{python}# | echo: trueimport os, shutil# Reuse variables from the previous cell if set; otherwise, fall back to defaultstry: input_file_name label_file_nameexceptNameError: input_file_name ="subsetted_512x512_HLS.S30.T10TGS.2018285.v1.4_merged.tif" label_file_name ="subsetted_512x512_HLS.S30.T10TGS.2018285.v1.4.mask.tif"data_ready = os.path.exists(input_file_name) and os.path.exists(label_file_name)if data_ready: root ="burn_scar_segmentation_toy"ifnot os.path.isdir(root): os.mkdir(root)for img_dir in ["train_images", "val_images", "test_images"]: os.mkdir(os.path.join(root, img_dir)) shutil.copy(input_file_name, os.path.join(root, img_dir, input_file_name))for lbl_dir in ["train_labels", "val_labels", "test_labels"]: os.mkdir(os.path.join(root, lbl_dir)) shutil.copy(label_file_name, os.path.join(root, lbl_dir, label_file_name))print("Toy dataset directory ready:", root)else:print("Toy files not found. Run the previous download cell (set eval: true) and re-run this cell.")```---## Visualize the image and mask```{python}# | echo: trueif data_ready:import matplotlib.pyplot as pltimport rioxarray as rio fig, ax = plt.subplots(ncols=2, figsize=(10, 5)) ax[0].imshow( rio.open_rasterio(input_file_name).sel(band=[3, 2, 1]).transpose("y", "x", "band").to_numpy() ) ax[0].set_title("RGB composite") ax[1].imshow(rio.open_rasterio(label_file_name).to_numpy()[0]) ax[1].set_title("Mask (burn vs. 
non-burn)") plt.show()else:print("Skipping visualization (data not present).")```What to notice:- The RGB composite uses bands `[RED, GREEN, BLUE]` from the HLS image.- The mask is a single-channel label image with two classes.---## Create a datamodule for segmentation```{python}# | echo: truetry:from terratorch.datasets import HLSBandsfrom terratorch.datamodules import GenericNonGeoSegmentationDataModule means = [0.033349706741586264,0.05701185520536176,0.05889748132001316,0.2323245113436119,0.1972854853760658,0.11944914225186566, ] stds = [0.02269135568823774,0.026807560223070237,0.04004109844362779,0.07791732423672691,0.08708738838140137,0.07241979477437814, ]if data_ready: datamodule = GenericNonGeoSegmentationDataModule( batch_size=1, num_workers=0, train_data_root="burn_scar_segmentation_toy/train_images", val_data_root="burn_scar_segmentation_toy/val_images", test_data_root="burn_scar_segmentation_toy/test_images", image_glob="*_merged.tif", label_glob="*.mask.tif", mean=means, std=stds, num_classes=2, train_label_data_root="burn_scar_segmentation_toy/train_labels", val_label_data_root="burn_scar_segmentation_toy/val_labels", test_label_data_root="burn_scar_segmentation_toy/test_labels", dataset_bands=[ HLSBands.BLUE, HLSBands.GREEN, HLSBands.RED, HLSBands.NIR_NARROW, HLSBands.SWIR_1, HLSBands.SWIR_2, ], output_bands=[ HLSBands.BLUE, HLSBands.GREEN, HLSBands.RED, HLSBands.NIR_NARROW, HLSBands.SWIR_1, HLSBands.SWIR_2, ], no_data_replace=0, no_label_replace=-1, ) datamodule.setup("fit")print("Datamodule ready. Train/Val/Test prepared.")else: datamodule =Noneprint("Skipping datamodule (data not present).")exceptExceptionas e: datamodule =Noneprint("Skipping datamodule setup (dependency missing):", e)```---## Define the training task and fit for 1 epoch```{python}# | echo: truetry:if data_ready and datamodule isnotNone:from terratorch.datasets import HLSBandsfrom terratorch.tasks import SemanticSegmentationTaskfrom lightning.pytorch import Trainerfrom lightning.pytorch.callbacks import ( EarlyStopping, LearningRateMonitor, ModelCheckpoint, RichProgressBar, )from lightning.pytorch.loggers import TensorBoardLogger model_args = {"backbone": "prithvi_vit_100","decoder": "FCNDecoder","num_classes": 2,"backbone_bands": [ HLSBands.RED, HLSBands.GREEN, HLSBands.BLUE, HLSBands.NIR_NARROW, HLSBands.SWIR_1, HLSBands.SWIR_2, ],"backbone_pretrained": True,"backbone_num_frames": 1,"decoder_channels": 128,"head_dropout": 0.2,"necks": [ {"name": "SelectIndices", "indices": [-1]}, {"name": "ReshapeTokensToImage"}, ], } task = SemanticSegmentationTask( model_args, model_factory_name="EncoderDecoderFactory", loss="ce", aux_loss={"fcn_aux_head": 0.4}, lr=1e-3, ignore_index=-1, optimizer="AdamW", optimizer_hparams={"weight_decay": 0.05}, ) accelerator ="auto" logger = TensorBoardLogger(save_dir="tutorial_experiments", name="seg_toy") trainer = Trainer( accelerator=accelerator, callbacks=[ RichProgressBar(), ModelCheckpoint(monitor=task.monitor, save_top_k=1, save_last=True), LearningRateMonitor(logging_interval="epoch"), ], logger=logger, max_epochs=1, log_every_n_steps=1, check_val_every_n_epoch=200, default_root_dir="tutorial_experiments/seg_toy", ) trainer.fit(model=task, datamodule=datamodule)else:print("Skipping training (data not prepared).")exceptExceptionas e:print("Training skipped due to error:", e)```---## Evaluate on the test split```{python}# | echo: truetry:if data_ready and datamodule isnotNone: _ = trainer.test(model=task, datamodule=datamodule)else:print("Skipping test (data not 
prepared).")exceptExceptionas e:print("Test skipped due to error:", e)```---## Optional: visualize a predictionThis cell runs a forward pass on the single test image and displays the predicted mask (argmax over classes). It does not modify model state.```{python}# | echo: trueimport numpy as nptry:if data_ready and datamodule isnotNone: dm = datamodule dm.setup("test") test_loader = dm.test_dataloader() batch =next(iter(test_loader)) images = batch[0]with torch.no_grad(): preds = task(images).output pred_mask = preds.argmax(dim=1)[0].cpu().numpy()print("Prediction mask unique labels:", np.unique(pred_mask).tolist())else:print("Skipping prediction visualization (data not prepared).")exceptExceptionas e:print("Prediction skipped due to error:", e)```---## Why this matters (reflection)- You can compose a full segmentation workflow by combining a pretrained backbone, a decoder, and a lightweight datamodule.- With a tiny toy dataset, you can validate I/O, augmentation, and training loops before scaling to larger data.- Swapping backbones or decoders becomes a configuration change instead of a rewrite.