# Load DIOR from Hugging Face (COCO format with annotations!)
# Using HichTala/dior which has proper bounding box annotations
from datasets import load_dataset
print("Loading DIOR from Hugging Face (with annotations)...")
hf_train = load_dataset("HichTala/dior", split="train")
print(f"โ Training samples: {len(hf_train)}")
print(f"โ All images are 800ร800 pixels (perfect for ViT!)")
# Examine structure
sample = hf_train[0]
print(f"\nโ Sample keys: {list(sample.keys())}")
print(f" image_id: {sample['image_id']}")
print(f" width: {sample['width']}, height: {sample['height']}")
print(f" num objects: {len(sample['objects']['category'])}")
# Show first object
if len(sample['objects']['category']) > 0:
print(f"\nโ First object:")
print(f" category: {sample['objects']['category'][0]}")
print(f" bbox: {sample['objects']['bbox'][0]}") # [x, y, w, h] COCO format
print(f" area: {sample['objects']['area'][0]}")Overview
This tutorial demonstrates object detection using the Prithvi geospatial foundation model on the DIOR aerial imagery dataset. You'll learn how to adapt a Vision Transformer (ViT) backbone for object detection, solving the architectural challenges that arise when integrating ViTs with traditional detection frameworks.
Dataset & Pipeline: - Load DIOR (23,463 aerial images, 20 classes) from Hugging Face - Build complete detection pipeline: Prithvi ViT → Faster R-CNN (single-scale) - Create custom data collation for ViT-friendly 512×512 images
ViT Detection Challenges & Solutions: - Challenge: Faster R-CNN expects multi-scale CNN features (FPN), ViT gives single-scale sequences - Solution 1: Skip FPN, work directly on ViT's 32×32 feature map - Solution 2: Custom single-scale RPN with 15,360 anchors tuned for aerial imagery - Solution 3: Replace MultiScaleRoIAlign with single-scale RoIAlign - Solution 4: Custom forward pass to reshape ViT sequences → spatial features
Training & Evaluation: - Train the object detection model with automatic GPU/CPU support - Monitor RPN and ROI losses separately - Debug predictions with confidence score analysis - Visualize detections vs ground truth
By the end, you'll understand how to integrate any ViT-based GFM with object detection frameworks!
The VHR10 Problem: - Variable image sizes (586×716 to 958×1024) - ViTs need consistent input sizes - Required complex padding/resizing logic
The DIOR Solution: - ✓ Fixed 800×800 pixel images - Every single image! - ✓ 23,463 images - Much larger dataset - ✓ 20 object classes - More diversity - ✓ Simple resize to 512×512 or 224×224 - No padding needed
This makes DIOR the ideal dataset for testing Prithvi-based object detection!
Background
Object Detection vs. Segmentation: - Object Detection: Localizes and classifies objects in an image - Segmentation: Assigns a label to each pixel independently
Load the DIOR Dataset
We are using the HichTala/dior dataset from Hugging Face, which has proper bounding box annotations.
Citation: @article{Li_2020, title={Object detection in optical remote sensing images: A survey and a new benchmark}, volume={159}, ISSN={0924-2716}, url={http://dx.doi.org/10.1016/j.isprsjprs.2019.11.023}, DOI={10.1016/j.isprsjprs.2019.11.023}, journal={ISPRS Journal of Photogrammetry and Remote Sensing}, publisher={Elsevier BV}, author={Li, Ke and Wan, Gang and Cheng, Gong and Meng, Liqiu and Han, Junwei}, year={2020}, month=jan, pages={296–307}}
DIOR (object Detection In Optical Remote sensing images) is an excellent benchmark for object detection in aerial imagery, hosted on Hugging Face.
Key Features: - 23,463 total images (19,000 train, 4,463 test) - All images are 800×800 pixels (consistent size!) - 20 object classes common in aerial imagery - Standard train/test splits
Classes (20 total): airplane, airport, baseball field, basketball court, bridge, chimney, dam, expressway service area, expressway toll station, golf course, ground track field, harbor, overpass, ship, stadium, storage tank, tennis court, train station, vehicle, wind mill
In the Hugging Face release, objects are annotated with COCO-style bounding boxes in [x, y, width, height] format; we will convert them to [xmin, ymin, xmax, ymax] for torchvision.
Now let's create a PyTorch dataset wrapper and visualize DIOR samples:
import torch
import numpy as np
from PIL import Image
# DIOR class names (20 classes)
DIOR_CLASSES = [
'airplane', 'airport', 'baseball field', 'basketball court', 'bridge',
'chimney', 'dam', 'expressway service area', 'expressway toll station',
'golf course', 'ground track field', 'harbor', 'overpass', 'ship',
'stadium', 'storage tank', 'tennis court', 'train station', 'vehicle', 'wind mill'
]
# First, inspect what keys HuggingFace DIOR has
sample = hf_train[0]
print("DIOR sample structure:")
for key in sample.keys():
val = sample[key]
print(f" {key}: {type(val)}")
if isinstance(val, (list, dict)):
print(f" Content: {val}")Now letโs create a proper PyTorch dataset wrapper:
class DIORPyTorchDataset(torch.utils.data.Dataset):
"""
Convert Hugging Face DIOR (COCO format) to PyTorch detection format.
Input (COCO format):
- bbox: [x, y, width, height]
- category: class index
Output (torchvision format):
- image: Tensor [3, 800, 800]
- bbox_xyxy: Tensor [N, 4] in [xmin, ymin, xmax, ymax] format
- label: Tensor [N]
"""
def __init__(self, hf_dataset):
self.hf_dataset = hf_dataset
def __len__(self):
return len(self.hf_dataset)
def __getitem__(self, idx):
sample = self.hf_dataset[idx]
# Convert PIL image to tensor [C, H, W]
image = sample['image']
if isinstance(image, Image.Image):
image = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
# Extract bounding boxes (COCO format: [x, y, w, h])
bboxes_coco = sample['objects']['bbox'] # List of [x, y, w, h]
labels = sample['objects']['category'] # List of class indices
# Convert COCO [x, y, w, h] to xyxy [xmin, ymin, xmax, ymax]
bboxes_xyxy = []
for bbox in bboxes_coco:
x, y, w, h = bbox
xmin, ymin = x, y
xmax, ymax = x + w, y + h
bboxes_xyxy.append([xmin, ymin, xmax, ymax])
# Convert to tensors
boxes = torch.tensor(bboxes_xyxy, dtype=torch.float32) if bboxes_xyxy else torch.zeros((0, 4), dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.int64) if labels else torch.zeros((0,), dtype=torch.int64)
return {
'image': image,
'bbox_xyxy': boxes,
'label': labels
}
# Create dataset
dataset_train = DIORPyTorchDataset(hf_train)
# Test it
test_sample = dataset_train[0]
print(f"\nโ PyTorch dataset created")
print(f" Image shape: {test_sample['image'].shape}")
print(f" Bounding boxes: {test_sample['bbox_xyxy'].shape}")
print(f" Labels: {test_sample['label'].shape}")
print(f" Expected: [3, 800, 800] images with variable number of objects")Now letโs visualize some DIOR samples with their annotations:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# DIOR class names
DIOR_CLASSES = [
'Airplane', 'Airport', 'Baseball field', 'Basketball court', 'Bridge',
'Chimney', 'Dam', 'Expressway service area', 'Expressway toll station',
'Golf course', 'Ground track field', 'Harbor', 'Overpass', 'Ship',
'Stadium', 'Storage tank', 'Tennis court', 'Train station', 'Vehicle', 'Wind mill'
]
# Visualize 4 samples
fig, axes = plt.subplots(2, 2, figsize=(16, 16))
axes = axes.flatten()
for idx in range(4):
sample = dataset_train[idx]
# Get image
image = sample['image'].permute(1, 2, 0).numpy() # [C,H,W] → [H,W,C]
# Get boxes and labels
boxes = sample['bbox_xyxy']
labels = sample['label']
# Display
ax = axes[idx]
ax.imshow(image)
# Draw bounding boxes
for box, label in zip(boxes, labels):
xmin, ymin, xmax, ymax = box.numpy()
width = xmax - xmin
height = ymax - ymin
# Rectangle
rect = patches.Rectangle(
(xmin, ymin), width, height,
linewidth=2, edgecolor='red', facecolor='none'
)
ax.add_patch(rect)
# Label
class_name = DIOR_CLASSES[int(label)]
ax.text(
xmin, ymin - 5,
class_name,
color='white',
fontsize=9,
weight='bold',
bbox=dict(boxstyle='round', facecolor='red', alpha=0.8)
)
ax.axis('off')
ax.set_title(f"Sample {idx}: {len(boxes)} object(s)", fontsize=12, weight='bold')
plt.suptitle("DIOR Dataset Examples (800ร800 pixels)", fontsize=16, weight='bold')
plt.tight_layout()
plt.show()
Understanding Object Detection Architecture
Object detection requires more than just a backbone model - we need a complete detection pipeline:
- Backbone: Extract features from the image (e.g., Prithvi ViT)
- Neck (optional): Process features at multiple scales (e.g., FPN)
- Detection Head: Predict bounding boxes and classes (e.g., Faster R-CNN, RetinaNet)
Unlike classification or segmentation, object detection must: - Predict variable numbers of objects per image - Localize objects with bounding boxes - Handle objects at different scales - Deal with overlapping predictions (Non-Maximum Suppression)
This complexity means we typically need specialized frameworks like Detectron2, MMDetection, or TerraTorch that provide pre-built detection pipelines.
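To make the last point concrete, here is a minimal, self-contained sketch (not part of the tutorial pipeline) of how Non-Maximum Suppression discards overlapping duplicates. It uses torchvision.ops.nms, the operation Faster R-CNN's detection head relies on internally.
import torch
from torchvision.ops import nms

# Three candidate boxes in [xmin, ymin, xmax, ymax] format with confidence scores
boxes = torch.tensor([
    [100., 100., 200., 200.],   # box A
    [105., 105., 205., 205.],   # near-duplicate of A (high overlap)
    [300., 300., 400., 400.],   # a separate object
])
scores = torch.tensor([0.90, 0.75, 0.80])
keep = nms(boxes, scores, iou_threshold=0.5)   # indices of boxes to keep, sorted by score
print(keep)                                    # tensor([0, 2]) - the duplicate of A is suppressed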
Setting Up Object Detection with TerraTorch
Use the TerraTorch Object Detection Model Factory to create a model for object detection. The complete pipeline has three components:
- Backbone (Prithvi): Extracts features from input images
- Neck (Optional - FPN): Can transform features into multi-scale pyramid
- Head (Faster R-CNN): Predicts bounding boxes and classes from features
For ViT, we skip the neck and use single-scale detection directly.
Traditional CNN detectors use FPN to create multi-scale features. However:
- Prithvi ViT: Outputs single-scale features (32×32)
- FPN Complexity: Feature map filtering and multi-scale expectations
- DIOR Dataset: Consistent image sizes and object scales
Our approach: Skip FPN entirely and use single-scale detection directly on ViT features. This is simpler and avoids compatibility issues.
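As a quick illustration of what "single-scale ViT features" means in practice, this minimal sketch (dummy tensors only, not the actual Prithvi output) shows how a sequence of 1,024 patch tokens from a 512×512 input maps onto one 32×32 spatial feature map, which is the shape the detection components will consume.
import torch

# Dummy ViT output: [batch, tokens, channels] for a 512x512 input with patch_size=16
tokens = torch.randn(2, 1024, 768)
B, N, C = tokens.shape
H = W = int(N ** 0.5)                               # 32x32 patch grid
feature_map = tokens.permute(0, 2, 1).reshape(B, C, H, W)
print(feature_map.shape)                            # torch.Size([2, 768, 32, 32])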
TerraTorch's ObjectDetectionModel is a wrapper around torchvision's detection models. While the wrapper is useful for model construction, it has some quirks during training:
The Issue: The wrapper's forward() method doesn't properly pass targets to the underlying model during training mode.
The Solution: Extract the underlying torchvision model using model.torchvision_model and train that directly.
This is a common pattern when working with model wrappers - use them for construction, but train the underlying model directly.
from terratorch.models.object_detection_model_factory import ObjectDetectionModelFactory
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.rpn import RegionProposalNetwork, RPNHead
from collections import OrderedDict
import types
import math
model_factory = ObjectDetectionModelFactory()
model = model_factory.build_model(
task="object_detection",
backbone="prithvi_eo_v1_100",
backbone_bands=[0, 1, 2],
backbone_num_frames=1,
backbone_pretrained=True,
neck=None, # No FPN - simpler single-scale detection
num_classes=20,
framework="faster-rcnn"
)
if hasattr(model, 'torchvision_model'):
detection_model = model.torchvision_model
else:
detection_model = model
# Fix image transform to keep the 512×512 size (don't resize again)
detection_model.transform.min_size = (512,)
detection_model.transform.max_size = 512
print("Model created with Prithvi backbone")
print("Image transform configured: 512ร512 (no resizing)")Now we need to configure custom anchors for ViTโs single-scale features:
# --- Anchor Generator and Custom Region Proposal Network for ViT Single-Scale Detection ---
def create_vit_anchor_generator():
"""
Create an AnchorGenerator specifically for Prithvi ViT's single-scale feature map.
In the classic Faster R-CNN pipeline, anchors are generated over multiple feature map scales (FPN).
Since our ViT backbone outputs just a single feature map (no FPN), we use a single set of anchor sizes and ratios:
For 512x512 input with patch_size=16:
- ViT outputs a single [32x32] feature map (stride 16).
- Each feature location has anchors of sizes 16, 32, 64, 128, and 256 pixels.
- Each anchor size is combined with aspect ratios: 0.5 (tall), 1.0 (square), 2.0 (wide).
This ensures the RPN can propose boxes covering a wide range of object sizes/shapes
despite only a single feature resolution.
"""
# Anchor sizes tuned for objects in 512×512 images (DIOR resized from 800×800)
# Cover range: small vehicles (~20px) to large buildings/airports (~200px)
anchor_sizes = ((16, 32, 64, 128, 256),)
aspect_ratios = ((0.5, 1.0, 2.0),)
return AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
def replace_rpn(model):
"""
Replace the model's RPN (Region Proposal Network) with a version
compatible with ViT's single-scale features.
- in_channels: Set to Prithvi's hidden dimension (768 for ViT-base)
because that's the channel count of the feature map.
- anchor_generator: Single-scale anchor generator from above.
- RPNHead: Computes objectness and box regression deltas per anchor.
- RegionProposalNetwork: Uses anchors and predictions to generate
object proposal regions. All RPN parameters
are tuned to standard values but operate only
on the single ViT feature map (not FPN).
This customization is required because the default RPN expects FPN-style input,
but our model has just one feature map level (from ViT).
"""
in_channels = 768 # Prithvi ViT-Base hidden dimension
anchor_generator = create_vit_anchor_generator()
# RPNHead = conv layers predicting objectness & bbox offsets for each anchor at each location.
rpn_head = RPNHead(
in_channels=in_channels,
num_anchors=anchor_generator.num_anchors_per_location()[0]
)
# The actual RPN module, configured for single-scale input.
rpn = RegionProposalNetwork(
anchor_generator=anchor_generator,
head=rpn_head,
fg_iou_thresh=0.7, # IoU threshold for matching anchors to GT boxes (foreground)
bg_iou_thresh=0.3, # IoU threshold for matching anchors as background
batch_size_per_image=256,
positive_fraction=0.5,
pre_nms_top_n=dict(training=2000, testing=1000),
post_nms_top_n=dict(training=2000, testing=1000),
nms_thresh=0.7 # Non-max suppression threshold for RPN proposals
)
# Override the model's RPN with our custom single-scale RPN.
model.rpn = rpn
return model
# Insert our custom single-scale RPN into the model pipeline.
detection_model = replace_rpn(detection_model)
print("Custom RPN configured for ViT features")Now replace the ROI pooler to work with single-scale features:
import torch.nn as nn
from torchvision.ops import RoIAlign
def replace_roi_pooler(model):
"""
Replace the default multi-scale ROI pooler in Faster R-CNN with a
custom single-scale ROIAlign module suitable for ViT-style (single feature map) backbones.
Why:
Faster R-CNN defaults to MultiScaleRoIAlign, which expects a dict of feature maps at
multiple pyramid levels (as with FPN). Vision Transformers output a single feature map,
so we need a simpler pooling module that operates on just one map.
What this does:
- Instantiates a single-scale RoIAlign set to a 7x7 output, which is standard for the detection head.
- Uses a spatial scale of 1/16, as Prithvi/ViT feature stride is 16 pixels (512 input -> 32x32 features).
- Defines a custom nn.Module that handles both dict and tensor feature inputs, always pulling out
the only feature map available, and ignoring FPN-style logic.
- Replaces model.roi_heads.box_roi_pool with this new single-scale implementation.
Returns:
The modified model, ready for use with a ViT backbone.
"""
# Create single-scale RoIAlign; output size 7x7 is default for detection heads
roi_align = RoIAlign(
output_size=7,
spatial_scale=1.0/16, # ViT stride=16 (e.g., input 512 -> 32x32)
sampling_ratio=2
)
class SingleScaleRoIPool(nn.Module):
"""
A wrapper for single-scale RoIAlign to mimic expected interface of MultiScaleRoIAlign,
allowing seamless plug-in to torchvision's Faster R-CNN code.
"""
def __init__(self, roi_align):
super().__init__()
self.roi_align = roi_align
def forward(self, features, proposals, image_shapes):
"""
Apply RoIAlign to a single feature map.
- features: Either a dict (assumed to contain only one entry) or a Tensor.
- proposals: List of proposal boxes per image (as required by RoIAlign).
- image_shapes: (Ignored) present for API compatibility only.
Returns: Tensor of pooled RoI features.
"""
# If features is a dict (from model), extract its sole value; otherwise, use it as-is
if isinstance(features, dict):
x = list(features.values())[0]
else:
x = features
# image_shapes argument is not used, as RoIAlign doesn't require it for single-scale
return self.roi_align(x, proposals)
# Swap out multi-scale RoI pooler for single-scale version
model.roi_heads.box_roi_pool = SingleScaleRoIPool(roi_align)
return model
# Patch the model to use the new single-scale ROI pooler
detection_model = replace_roi_pooler(detection_model)
print("Custom ROI pooler configured for single-scale features")To make Prithvi ViT work with Faster R-CNN, we made four critical changes:
- Skip FPN (neck=None): Avoid multi-scale complexity, work directly on the single 32×32 feature map
- Custom RPN: Single-scale anchor generator with 15,360 anchors tuned for 512×512 images
- Custom ROI Pooler: Replace MultiScaleRoIAlign with single-scale RoIAlign at 1/16 spatial scale
- Custom Forward Pass: Reshape ViT sequence [B,1024,768] → spatial [B,768,32,32] with MPS compatibility
These adaptations allow single-scale detection without FPN complexity while maintaining full compatibility with Faster R-CNNโs architecture.
Apply the compatibility patch for ViT features:
def fixed_fasterrcnn_forward(self, images, targets=None):
"""
Custom forward pass to enable Faster R-CNN with Prithvi ViT backbone.
This replaces the original forward to handle single-scale ViT features, which differ
from typical multi-scale CNN backbones (like FPN).
Steps performed:
1. Converts ViT sequence output from shape [B, N, C] into a spatial feature map [B, C, H, W],
where:
B=batch size,
N=number of tokens (patches),
C=channels, and
H and W are the height and width of the feature map (typically recovered from N as H × W = N).
- Handles [CLS] token (N==1025) by removing the first token.
- Reshapes flat sequence into square map (ViT patch output).
2. Wraps features in an OrderedDict, preserving required interface.
- Ensures tensor data is contiguous for MPS (Apple Silicon) and general backend compatibility.
3. Proceeds with proposal, ROI, and loss computation as usual.
4. Returns post-processed detections or losses depending on training mode.
Args:
images: List of input images, already as Tensors.
targets: List of ground truth targets (boxes/labels). Required in training mode.
Returns:
Loss dict (training mode) or list of detection outputs (eval mode).
"""
# Training mode sanity checkโtargets required for loss computation
if self.training and targets is None:
torch._assert(False, "targets should not be none when in training mode")
# Record original sizes before transforms (needed for inverse mapping in postprocess)
original_image_sizes = []
for img in images:
val = img.shape[-2:]
original_image_sizes.append((val[0], val[1]))
# Apply input transforms (e.g. resizing, normalization)
images, targets = self.transform(images, targets)
# Forward through the ViT backbone
features = self.backbone(images.tensors) # Could be Tensor, list, or OrderedDict
# If backbone outputs a list, use the last element (common for some backbones)
if isinstance(features, list):
features = features[-1]
# If features are [B, N, C] (ViT style: sequence length), reshape to [B, C, H, W]
if isinstance(features, torch.Tensor) and features.ndim == 3:
B, N, C = features.shape
# Some ViT models output CLS token: strip if N==1025 ([B, 1025, C])
if N == 1025:
features = features[:, 1:, :]
N = 1024 # Now we have only patch tokens
# Recover H and W; for ViT this is usually 32x32
H = W = int(math.sqrt(N))
# Rearrange to [B, C, H, W] for convolutional-style heads
features = features.permute(0, 2, 1).reshape(B, C, H, W)
# Convert to OrderedDict if not already, as expected by torchvision detection components.
# Always call .contiguous() for MPS (Metal) backend compatibility and for safety.
if not isinstance(features, OrderedDict):
# Only one feature map ('0' key)
features = OrderedDict([("0", features.contiguous())])
else:
# Make sure all tensors in the dict are contiguous
features = OrderedDict([(k, v.contiguous()) for k, v in features.items()])
# Generate proposals with the Region Proposal Network (RPN), then detect boxes/labels/scores
proposals, proposal_losses = self.rpn(images, features, targets)
detections, detector_losses = self.roi_heads(
features, proposals, images.image_sizes, targets
)
detections = self.transform.postprocess(
detections, images.image_sizes, original_image_sizes
)
# Aggregate all computed losses for training, or return detections in eval mode
losses = {}
losses.update(detector_losses)
losses.update(proposal_losses)
if self.training:
return losses
return detections
# Overwrite the model's forward method with our ViT-compatible implementation
detection_model.forward = types.MethodType(fixed_fasterrcnn_forward, detection_model)
print("Compatibility patch applied for ViT backbone (single-scale features)")
print("Model ready for training or inference")The Challenge: - Default Faster R-CNN expects multi-scale CNN features (FPN pyramid) - Prithvi ViT produces single-scale features (32ร32 at stride 16) - Multi-scale components (RPN, ROI pooler) fail with single-scale input
Our Four-Part Solution:
- Skip FPN (neck=None)
  - Avoid multi-scale complexity entirely
  - Direct single-scale detection
  - Works directly on ViT's 32×32 feature map
- Custom RPN
  - Single anchor tuple for the 32×32 feature map
  - Anchor sizes: 16, 32, 64, 128, 256 pixels (tuned for 512×512 input)
  - Aspect ratios: 0.5 (tall), 1.0 (square), 2.0 (wide)
  - Total: 5 sizes × 3 ratios × 32×32 locations = 15,360 anchors per image
- Custom ROI Pooler
  - Replace MultiScaleRoIAlign with a simple single-scale RoIAlign
  - Spatial scale: 1/16 (matches ViT stride)
  - Output: 7×7 features per proposal (standard for the detection head)
  - Bypass FPN-style feature map filtering
- Custom Forward Pass
  - Reshape ViT sequence [B, 1024, 768] → spatial [B, 768, 32, 32]
  - Handle [CLS] token removal if present
  - Ensure tensor contiguity for MPS (Apple Silicon) compatibility
  - Wrap features in OrderedDict for torchvision compatibility
Why This Works for DIOR: - Fixed 800×800 images → resize to 512×512 → consistent 32×32 features - Objects range 20-200 pixels after resize (anchors cover this range) - Single-scale sufficient for consistent aerial imagery - Simpler architecture, fewer failure points
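As a sanity check on the anchor arithmetic above, the short sketch below recomputes the per-image anchor count; it assumes the create_vit_anchor_generator function defined earlier is in scope.
# 32x32 feature locations x (5 sizes x 3 aspect ratios) anchors per location
feat_h = feat_w = 512 // 16
anchors_per_location = create_vit_anchor_generator().num_anchors_per_location()[0]
print(anchors_per_location)                      # 15
print(anchors_per_location * feat_h * feat_w)    # 15360 anchors per image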
Prepare Data for Training
Let's create train and validation splits from the DIOR dataset. We'll use a subset for faster training in this demonstration.
from torch.utils.data import DataLoader, Subset
# Custom collate function for DIOR with Prithvi
# DIOR has fixed 800×800 images - perfect for ViT!
# We'll resize to 512×512 (divisible by 16 for patch_size)
import torch.nn.functional as F
def detection_collate_fn(batch):
"""
Custom collate function for DIOR with Prithvi ViT.
Since DIOR images are all 800×800:
- Simple resize to 512×512 (no padding needed!)
- Scale bounding boxes proportionally
- ViT gets consistent 1024 patches (32×32 grid with patch_size=16)
"""
images = []
targets = []
target_size = 512 # Resize to 512×512 for ViT
for sample in batch:
# Extract image and convert to tensor if needed
img = sample['image']
if not isinstance(img, torch.Tensor):
# Convert PIL to tensor
import numpy as np
img = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
_, orig_h, orig_w = img.shape # Should be 800×800
# Resize to target size
img_resized = F.interpolate(
img.unsqueeze(0),
size=(target_size, target_size),
mode='bilinear',
align_corners=False
).squeeze(0)
# Scale bounding boxes (800×800 → 512×512)
scale = target_size / orig_w # 512 / 800 = 0.64
boxes = sample['bbox_xyxy'].clone().float()
boxes *= scale
images.append(img_resized)
target = {
'boxes': boxes,
'labels': sample['label']
}
targets.append(target)
return images, targets
print("โ Custom collate function defined for DIOR")
print(" - Resizes 800ร800 โ 512ร512 (ViT friendly!)")
print(" - Scales bounding boxes proportionally")
print(" - No padding needed (square images)")Patch Size = 16 pixels: - 512รท16 = 32 patches per dimension - Total: 32ร32 = 1,024 patches - Perfect square โ
Alternatives: - 224×224: 14×14 = 196 patches (smaller, faster) - 448×448: 28×28 = 784 patches (medium) - 800×800: 50×50 = 2,500 patches (original size, slower)
We chose 512×512 as a good balance between resolution and speed!
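The patch arithmetic above is easy to verify; this small, illustrative loop prints the patch grid for each candidate input size with patch_size=16.
patch_size = 16
for size in (224, 448, 512, 800):
    n = size // patch_size   # patches per dimension
    print(f"{size}x{size} input -> {n}x{n} = {n * n} patches")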
# Load validation split from Hugging Face
print("Loading validation split...")
hf_val = load_dataset("HichTala/dior", split="validation")
val_dataset = DIORPyTorchDataset(hf_val)
# For this demo, use subset of training data to speed things up
train_dataset = Subset(dataset_train, range(1000)) # Use first 1000 images
val_dataset = Subset(val_dataset, range(200)) # Use first 200 val images
print(f"โ Training samples: {len(train_dataset)}")
print(f"โ Validation samples: {len(val_dataset)}")
# Create data loaders with custom collate function
train_loader = DataLoader(
train_dataset,
batch_size=2, # Small batch size for detection
shuffle=True,
num_workers=0, # Set to 0 for compatibility
collate_fn=detection_collate_fn # Use custom collate function
)
val_loader = DataLoader(
val_dataset,
batch_size=2,
shuffle=False,
num_workers=0,
collate_fn=detection_collate_fn
)
print(f"โ Training batches: {len(train_loader)}")
print(f"โ Validation batches: {len(val_loader)}")
# Test the data loader
sample_batch = next(iter(train_loader))
images, targets = sample_batch
print(f"\nโ Batch structure:")
print(f" - Images: list of {len(images)} tensors")
print(f" - Image 0 shape: {images[0].shape}")
print(f" - Image 1 shape: {images[1].shape}")
print(f" - Targets: list of {len(targets)} dicts")
print(f" - Target 0 boxes: {targets[0]['boxes'].shape}")
print(f" - Target 0 labels: {targets[0]['labels'].shape}")Train the Object Detection Model
Now let's train the model! With all our adaptations in place, the model is ready for training.
What we've configured: - ✓ Prithvi ViT backbone with single-scale 32×32 features - ✓ Custom RPN with 15,360 anchors tuned for 512×512 images - ✓ Custom single-scale ROI pooler (spatial scale 1/16) - ✓ Custom forward pass handling ViT sequence → spatial reshaping - ✓ Fixed image transform (512×512, no resizing) - ✓ Data pipeline: 800×800 DIOR → 512×512 with bbox scaling
Object detection training differs from classification: - Model computes losses internally (RPN objectness/bbox + ROI classification/bbox) - Must pass both images AND targets during training - Returns dict of losses in train mode, predictions in eval mode - Training typically requires 20-50 epochs for convergence (we'll start with 2 for demo)
The training function automatically detects and uses available hardware acceleration: - Apple Silicon (M1/M2/M3): Uses MPS (Metal Performance Shaders) - NVIDIA GPUs: Uses CUDA - CPU: Falls back to CPU if no GPU available
MPS Compatibility: Some operations may not be fully supported on MPS. The training loop includes automatic fallback to CPU if MPS errors occur, ensuring training completes successfully.
This can significantly speed up training on compatible hardware!
import torch
def train_detection_model(model, train_loader, val_loader, num_epochs=2, lr=0.0001):
"""Train an object detection model with MPS support."""
# Try MPS (Apple Silicon) first, then CUDA, then CPU
if torch.backends.mps.is_available():
device = torch.device("mps")
print("Training on: Apple Silicon (MPS)")
print("Note: If MPS errors occur, model will fall back to CPU")
elif torch.cuda.is_available():
device = torch.device("cuda")
print("Training on: CUDA")
else:
device = torch.device("cpu")
print("Training on: CPU")
print()
try:
model = model.to(device)
except Exception as e:
print(f"Failed to move model to {device}: {e}")
print("Falling back to CPU")
device = torch.device("cpu")
model = model.to(device)
model.train()
# Use SGD optimizer (standard for object detection)
optimizer = torch.optim.SGD(
model.parameters(),
lr=lr,
momentum=0.9,
weight_decay=0.0005
)
for epoch in range(num_epochs):
print(f"Epoch {epoch+1}/{num_epochs}")
epoch_loss = 0.0
for batch_idx, (images, targets) in enumerate(train_loader):
try:
# Move data to device
images = [img.to(device) for img in images]
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
# Forward pass - model returns loss dict
loss_dict = model(images, targets)
# Sum all losses
losses = sum(loss for loss in loss_dict.values())
# Backward pass
optimizer.zero_grad()
losses.backward()
optimizer.step()
epoch_loss += losses.item()
# Print progress every 50 batches
if batch_idx % 50 == 0:
print(f" Batch {batch_idx}/{len(train_loader)}, Loss: {losses.item():.4f}")
print(f" RPN: cls={loss_dict.get('loss_objectness', 0):.3f}, "
f"bbox={loss_dict.get('loss_rpn_box_reg', 0):.3f}, "
f"ROI: cls={loss_dict.get('loss_classifier', 0):.3f}, "
f"bbox={loss_dict.get('loss_box_reg', 0):.3f}")
except RuntimeError as e:
if "MPS" in str(e) or "mps" in str(e):
print(f"\nMPS compatibility issue detected: {e}")
print("Switching to CPU for remaining training...")
device = torch.device("cpu")
model = model.to(device)
# Retry this batch on CPU
images = [img.to(device) for img in images]
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
loss_dict = model(images, targets)
losses = sum(loss for loss in loss_dict.values())
optimizer.zero_grad()
losses.backward()
optimizer.step()
epoch_loss += losses.item()
else:
raise
avg_loss = epoch_loss / len(train_loader)
print(f"โ Epoch {epoch+1} completed - Avg Loss: {avg_loss:.4f}\n")
return model# Train the model
# Note: Start with 2 epochs for quick demo. For production:
# - Use 20-50 epochs for convergence
# - Monitor validation loss to avoid overfitting
# - Consider learning rate scheduling (e.g., reduce on plateau)
trained_model = train_detection_model(
detection_model,
train_loader,
val_loader,
num_epochs=2, # Increase to 20-50 for better results
lr=0.0001
)
# Ensure transform is configured correctly for inference
trained_model.transform.min_size = (512,)
trained_model.transform.max_size = 512
# Lower score threshold for inference (default 0.05 is too high for early training)
# Note: Some models store these in box_predictor or have different structures
if hasattr(trained_model.roi_heads, 'score_thresh'):
trained_model.roi_heads.score_thresh = 0.001
if hasattr(trained_model.roi_heads, 'nms_thresh'):
trained_model.roi_heads.nms_thresh = 0.5
if hasattr(trained_model.roi_heads, 'detections_per_img'):
trained_model.roi_heads.detections_per_img = 100
print(f"ROI heads attributes: {[attr for attr in dir(trained_model.roi_heads) if 'thresh' in attr.lower() or 'detections' in attr.lower()]}")
print(f"Configured thresholds: score={getattr(trained_model.roi_heads, 'score_thresh', 'N/A')}, nms={getattr(trained_model.roi_heads, 'nms_thresh', 'N/A')}")##Evaluate the Trained Model
Letโs see how the model performs on validation data:
import torch
def evaluate_detection_model(model, val_loader, num_samples=5):
"""Evaluate object detection model on validation set."""
# Use same device detection as training
if torch.backends.mps.is_available():
device = torch.device("mps")
elif torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
try:
model = model.to(device)
except Exception as e:
print(f"Failed to move model to {device}: {e}")
print("Falling back to CPU for evaluation")
device = torch.device("cpu")
model = model.to(device)
model.eval()
print("Running inference on validation samples...\n")
print(f"Using device: {device}")
sample_images, sample_targets = next(iter(val_loader))
try:
images = [img.to(device) for img in sample_images]
with torch.no_grad():
predictions = model(images)
except RuntimeError as e:
if "MPS" in str(e) or "mps" in str(e):
print(f"\nMPS compatibility issue during inference: {e}")
print("Switching to CPU for evaluation...")
device = torch.device("cpu")
model = model.to(device)
images = [img.to(device) for img in sample_images]
with torch.no_grad():
predictions = model(images)
else:
raise
# Display results
for i in range(min(len(predictions), num_samples)):
pred = predictions[i]
target = sample_targets[i]
print(f"Sample {i+1}:")
print(f" Ground truth: {len(target['boxes'])} objects")
print(f" Predicted: {len(pred['boxes'])} detections")
# Show top 3 predictions
if len(pred['boxes']) > 0:
top_indices = pred['scores'].argsort(descending=True)[:3]
print(f" Top 3 detections:")
for idx in top_indices:
score = pred['scores'][idx].item()
label = pred['labels'][idx].item()
print(f" - Class {label}, confidence: {score:.3f}")
print()
evaluate_detection_model(trained_model, val_loader, num_samples=3)
Debug Model Predictions
Before visualizing, let's check if the model is generating any predictions at all:
import torch
def debug_model_predictions(model, dataset, sample_idx=0):
"""Debug what the model is predicting with detailed introspection."""
if torch.backends.mps.is_available():
device = torch.device("mps")
elif torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
try:
model = model.to(device)
except Exception as e:
device = torch.device("cpu")
model = model.to(device)
model.eval()
# Get sample
sample = dataset[sample_idx]
batch = detection_collate_fn([sample])
images, targets = batch
img = images[0].to(device)
print(f"Sample {sample_idx} Debug Info:")
print(f" Ground truth: {len(targets[0]['boxes'])} objects")
# Check current threshold settings
print(f"\n Model threshold settings:")
print(f" score_thresh: {getattr(model.roi_heads, 'score_thresh', 'N/A')}")
print(f" nms_thresh: {getattr(model.roi_heads, 'nms_thresh', 'N/A')}")
print(f" detections_per_img: {getattr(model.roi_heads, 'detections_per_img', 'N/A')}")
# Try with extremely low threshold
original_score_thresh = model.roi_heads.score_thresh
model.roi_heads.score_thresh = 0.0001 # Extremely low
with torch.no_grad():
predictions = model([img])
pred = predictions[0]
print(f"\n Model output (with score_thresh=0.0001):")
print(f" Total predictions after NMS: {len(pred['boxes'])}")
if len(pred['boxes']) == 0:
print(" โ ๏ธ STILL no predictions! This suggests:")
print(" 1. ROI head is not producing any boxes with positive scores")
print(" 2. All boxes might be getting filtered before score threshold")
print(" 3. Check if box_predictor is outputting valid scores")
model.roi_heads.score_thresh = original_score_thresh
return
print(f" Score range: [{pred['scores'].min():.6f}, {pred['scores'].max():.6f}]")
# Show score distribution
thresholds = [0.0001, 0.001, 0.01, 0.05, 0.1, 0.3, 0.5]
print(f"\n Predictions by confidence threshold:")
for thresh in thresholds:
count = (pred['scores'] > thresh).sum().item()
print(f" >{thresh}: {count} detections")
print(f"\n Top 10 predictions:")
top_indices = pred['scores'].argsort(descending=True)[:10]
for i, idx in enumerate(top_indices):
score = pred['scores'][idx].item()
label = pred['labels'][idx].item()
box = pred['boxes'][idx].cpu().numpy()
class_name = DIOR_CLASSES[label] if label < len(DIOR_CLASSES) else f"Class{label}"
print(f" {i+1}. {class_name}, score={score:.6f}")
# Restore original threshold
model.roi_heads.score_thresh = original_score_thresh
debug_model_predictions(trained_model, dataset_train, sample_idx=0)
Visualize Predictions vs Ground Truth
Let's compare the model's predictions against the actual labels for some of the images we explored earlier:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch
def visualize_predictions_vs_truth(model, dataset, indices=[0, 1], confidence_threshold=0.5):
"""
Compare model predictions with ground truth annotations.
"""
if torch.backends.mps.is_available():
device = torch.device("mps")
elif torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
try:
model = model.to(device)
except Exception as e:
device = torch.device("cpu")
model = model.to(device)
model.eval()
fig, axes = plt.subplots(len(indices), 2, figsize=(16, 8 * len(indices)))
if len(indices) == 1:
axes = axes.reshape(1, -1)
for plot_idx, data_idx in enumerate(indices):
sample = dataset[data_idx]
# Use collate function to process image (same as training)
batch = detection_collate_fn([sample])
images, targets = batch
img_resized = images[0]
boxes_scaled = targets[0]['boxes']
# Run inference
try:
with torch.no_grad():
predictions = model([img_resized.to(device)])
except RuntimeError as e:
if "MPS" in str(e) or "mps" in str(e):
device = torch.device("cpu")
model = model.to(device)
with torch.no_grad():
predictions = model([img_resized.to(device)])
else:
raise
pred = predictions[0]
# Filter by confidence
keep = pred['scores'] > confidence_threshold
pred_boxes = pred['boxes'][keep].cpu()
pred_labels = pred['labels'][keep].cpu()
pred_scores = pred['scores'][keep].cpu()
# Convert image back to 512x512 for visualization
img_display = img_resized.permute(1, 2, 0).numpy()
# Plot ground truth
ax = axes[plot_idx, 0]
ax.imshow(img_display)
for box, label in zip(boxes_scaled, sample['label']):
xmin, ymin, xmax, ymax = box.numpy()
width = xmax - xmin
height = ymax - ymin
rect = patches.Rectangle(
(xmin, ymin), width, height,
linewidth=2, edgecolor='lime', facecolor='none'
)
ax.add_patch(rect)
class_name = DIOR_CLASSES[int(label)]
ax.text(
xmin, ymin - 5, class_name,
color='white', fontsize=8, weight='bold',
bbox=dict(boxstyle='round', facecolor='lime', alpha=0.8)
)
ax.axis('off')
ax.set_title(f"Ground Truth (Sample {data_idx}): {len(sample['label'])} objects",
fontsize=12, weight='bold')
# Plot predictions
ax = axes[plot_idx, 1]
ax.imshow(img_display)
for box, label, score in zip(pred_boxes, pred_labels, pred_scores):
xmin, ymin, xmax, ymax = box.numpy()
width = xmax - xmin
height = ymax - ymin
rect = patches.Rectangle(
(xmin, ymin), width, height,
linewidth=2, edgecolor='red', facecolor='none'
)
ax.add_patch(rect)
class_name = DIOR_CLASSES[int(label)]
ax.text(
xmin, ymin - 5, f"{class_name} ({score:.2f})",
color='white', fontsize=8, weight='bold',
bbox=dict(boxstyle='round', facecolor='red', alpha=0.8)
)
ax.axis('off')
ax.set_title(f"Model Predictions: {len(pred_boxes)} detections (conf>{confidence_threshold})",
fontsize=12, weight='bold')
plt.suptitle("Object Detection: Ground Truth vs Predictions (512ร512)",
fontsize=16, weight='bold')
plt.tight_layout()
plt.show()
visualize_predictions_vs_truth(trained_model, dataset_train, indices=[0, 1], confidence_threshold=0.001)
Left column (Green boxes): Ground truth annotations from the DIOR dataset
Right column (Red boxes): Model predictions with confidence scores
Training Status After a Short Run (2-10 epochs): - Model generates predictions but with low confidence (~0.004-0.005) - Currently showing detections with confidence > 0.001 (very low threshold) - Typical object detection requires 20-50+ epochs to converge - Expected behavior: confidence scores should increase to 0.5+ with more training
Why Low Confidence? - Limited training data (1,000 samples) and few epochs - Single-scale detection (no FPN) is challenging - ViT backbone requires careful tuning for detection tasks
To Improve Results: - Train for 30-50 epochs with full dataset - Consider using learning rate scheduling - Experiment with different anchor sizes for your specific object scales
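One way to act on the learning rate scheduling suggestion above is ReduceLROnPlateau. The minimal sketch below is illustrative only and is not wired into train_detection_model; the tiny nn.Linear stands in for detection_model, and the constant validation loss stands in for a real validation loop.
import torch
import torch.nn as nn

toy_model = nn.Linear(4, 2)   # stand-in for detection_model
optimizer = torch.optim.SGD(toy_model.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=3)

for epoch in range(10):
    val_loss = 1.0                      # replace with the real validation loss each epoch
    scheduler.step(val_loss)            # lowers the LR once the loss stops improving
    print(epoch, optimizer.param_groups[0]["lr"])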
Note that images are resized to 512×512 for visualization, matching the model's input size.
Summary and Key Takeaways
Congratulations! You've successfully built and trained an object detection model. Here's what you accomplished:
✓ Complete Pipeline Built
- Dataset Loading & Exploration
- Loaded DIOR object detection dataset from Hugging Face
- Converted from COCO format to PyTorch/torchvision format
- Explored bounding box format (COCO [x, y, w, h] → [xmin, ymin, xmax, ymax])
- Visualized 800×800 aerial images with bounding boxes
- Model Architecture
- Built complete detection pipeline: Prithvi ViT → Faster R-CNN (single-scale)
- Used an actual Geospatial Foundation Model (HLS satellite pretrained!)
- Fixed ViT-detection compatibility issues
- Understood the role of each component (backbone, neck, head)
- Data Pipeline
- Created train/val splits with Subset
- Implemented a custom collate_fn that resizes 800×800 → 512×512
- Scaled bounding boxes proportionally (factor: 0.64)
- Formatted targets as list of dicts (torchvision standard)
- Training & Evaluation
- Successfully trained detection model with decreasing losses
- Understood detection-specific training (internal loss computation)
- Monitored RPN and ROI head losses separately
- Learned about convergence time (20-50 epochs typical for object detection)
- Debugged prediction confidence scores and thresholds
- Visualized predictions vs ground truth with adjustable confidence thresholds
Key Lessons
How We Made Prithvi ViT Work for Object Detection:
The Core Challenge: - Faster R-CNN expects multi-scale CNN features (FPN pyramid: multiple resolutions) - Prithvi ViT produces single-scale sequence features (1024 tokens of 768 dimensions) - Multi-scale components (RPN, ROI pooler) fail or filter incorrectly with single-scale input
Our Five-Part Solution:
- Skip FPN (neck=None)
  - No feature pyramid network
  - Work directly on ViT's single 32×32 feature map
  - Avoids multi-scale expectations entirely
- Custom RPN with Single-Scale Anchors
  - AnchorGenerator: 5 sizes × 3 ratios = 15 anchors per location
  - Anchor sizes: 16, 32, 64, 128, 256 pixels (tuned for 512×512 DIOR images after resize)
  - 32×32 feature map × 15 anchors = 15,360 anchor boxes per image
  - RPN configured for a single feature map (not an FPN pyramid)
- Custom Single-Scale ROI Pooler
  - Replace MultiScaleRoIAlign (expects FPN dict) with single RoIAlign
  - Spatial scale: 1/16 (ViT stride: 512 input → 32×32 features)
  - Output: 7×7 pooled features per proposal (standard for detection head)
  - Bypasses multi-scale feature map filtering that caused empty results
- Custom Forward Pass for ViT Features
  - Reshape sequence [B, 1024, 768] → spatial [B, 768, 32, 32]
  - Handle [CLS] token removal if present (1025 → 1024 patches)
  - Ensure tensor contiguity for MPS (Apple Silicon GPU) compatibility
  - Wrap in OrderedDict for torchvision Faster R-CNN interface
- Fixed Image Transform
  - Set min_size=512, max_size=512 to prevent resizing
  - Collate function handles 800×800 → 512×512 resize with bbox scaling
  - Model transform preserves 512×512 (doesn't resize again)
  - Result: consistent 32×32 ViT features throughout pipeline
Why DIOR Made It Possible: - Fixed 800×800 images (no variable size complexity) - Square images (no aspect ratio issues) - Consistent patch count (always 1,024) - Large dataset (19,000 train, 4,463 test)
Key Insight: - Prithvi (HLS pretrained) is a true Geospatial Foundation Model - With proper dataset (fixed-size) and architectural adaptation, ViT GFMs work for detection - The techniques you learned apply to integrating any new backbone architecture!
Object Detection vs Classification: - Detection models compute losses internally - Must pass both images AND targets during training - Returns loss dict (train) or predictions (eval) - Variable-sized images require custom collate functions
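To recap that contrast with the names used earlier in this tutorial (detection_model, plus images and targets from the data loader), the call pattern looks like this; it is a recap sketch, not new functionality.
# Training mode: pass images AND targets, get a dict of losses back
detection_model.train()
loss_dict = detection_model(images, targets)    # e.g. 'loss_objectness', 'loss_rpn_box_reg', ...
total_loss = sum(loss for loss in loss_dict.values())

# Eval mode: pass images only, get per-image predictions back
detection_model.eval()
with torch.no_grad():
    predictions = detection_model(images)       # list of {'boxes', 'labels', 'scores'}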
Next Steps
To Use Other Geospatial Foundation Models:
- Try other ViT-based GFMs - SatMAE, Clay (apply same adaptation techniques)
- Use larger Prithvi - Prithvi v2 300M or 600M for better performance
- Use DETR/Deformable DETR - Native transformer detection (no CNN conversion needed)
- Explore MMDetection - More flexible ViT integration than torchvision
To Improve This Model:
- Train Longer: 20-50 epochs with learning rate scheduling
- Current 2-10 epochs gives confidence ~0.004-0.01
- Target 30+ epochs for confidence >0.5
- Use ReduceLROnPlateau or CosineAnnealingLR
- Use Full Dataset: Currently using 1,000 samples subset
- Full DIOR: 19,000 training images
- More data = better convergence
- Add Data Augmentation: Carefully with bbox transforms
- Horizontal flips (easy; a box-aware flip sketch follows this list)
- Random crops (require bbox adjustment)
- Color jittering (safe - doesn't affect boxes)
- Tune Detection Thresholds:
- NMS threshold (currently 0.5)
- Score threshold (currently 0.001 for early training)
- Increase score_thresh as confidence improves
- Evaluate with Proper Metrics:
- mAP@0.5 and mAP@[0.5:0.95] (COCO metrics; see the torchmetrics sketch after this list)
- Per-class AP to find difficult classes
- Confusion matrix for misclassifications
- Experiment with Anchor Sizes:
- Current: 16, 32, 64, 128, 256
- Analyze DIOR object size distribution
- Adjust if objects are consistently smaller/larger
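Two of the suggestions above lend themselves to short sketches. First, a box-aware horizontal flip; this is illustrative only, and the helper name hflip_with_boxes is ours, not part of the tutorial pipeline.
import torch

def hflip_with_boxes(image, boxes, width=512):
    """Flip an image tensor [C, H, W] horizontally and mirror its [N, 4] xyxy boxes."""
    flipped_image = torch.flip(image, dims=[-1])
    flipped_boxes = boxes.clone()
    flipped_boxes[:, 0] = width - boxes[:, 2]   # new xmin = width - old xmax
    flipped_boxes[:, 2] = width - boxes[:, 0]   # new xmax = width - old xmin
    return flipped_image, flipped_boxes

image = torch.rand(3, 512, 512)
boxes = torch.tensor([[10., 20., 110., 220.]])
_, flipped = hflip_with_boxes(image, boxes)
print(flipped)                                  # tensor([[402.,  20., 502., 220.]])
Second, COCO-style mAP can be computed with the torchmetrics package (assuming it is installed; it is not used elsewhere in this tutorial), feeding it prediction and target dicts in the same format the model already produces.
import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision

metric = MeanAveragePrecision(iou_type="bbox")
preds = [{"boxes": torch.tensor([[50., 50., 120., 120.]]),
          "scores": torch.tensor([0.8]),
          "labels": torch.tensor([3])}]
targets = [{"boxes": torch.tensor([[48., 52., 118., 125.]]),
            "labels": torch.tensor([3])}]
metric.update(preds, targets)
results = metric.compute()
print(results["map_50"], results["map"])        # mAP@0.5 and mAP@[0.5:0.95]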
Applying What You Learned to GFMs:
When GFM+detection integration improves, you'll use the exact same: - Dataset loading and exploration patterns - Custom collate functions - Training loop structure
- Evaluation approaches
The skills transfer completely - only the backbone changes!
Further Reading
- RetinaNet Paper - Focal loss for dense object detection
- Feature Pyramid Networks - Multi-scale features
- MMDetection - Production detection framework
- DETR - Detection transformers for end-to-end detection
You now have a working foundation for object detection that you can build upon!