Cloud & Scalable Computing for Geospatial AI
This cheatsheet covers cloud computing strategies, distributed training, and scalable deployment for geospatial foundation models.
Cloud Computing Fundamentals
Key Cloud Platforms for Geospatial AI
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import rasterio
from rasterio.windows import Window
import numpy as np
from pathlib import Path
import os
import dask
import dask.array as da
from dask.distributed import Client
import xarray as xr
# Cloud platform configurations
cloud_configs = {
'gcp': {
'compute': ['n1-highmem-32', 'n1-standard-96'],
'gpu': ['nvidia-tesla-v100', 'nvidia-tesla-t4', 'nvidia-tesla-a100'],
'storage': 'gs://bucket-name/',
'earth_engine': True
},
'aws': {
'compute': ['m5.24xlarge', 'c5.24xlarge'],
'gpu': ['p3.16xlarge', 'p4d.24xlarge'],
'storage': 's3://bucket-name/',
'sagemaker': True
},
'azure': {
'compute': ['Standard_D64s_v3', 'Standard_M128s'],
'gpu': ['Standard_NC24rs_v3', 'Standard_ND40rs_v2'],
'storage': 'https://account.blob.core.windows.net/',
'machine_learning': True
}
}
print("Cloud Platform Comparison:")
for platform, config in cloud_configs.items():
print(f"{platform.upper()}: {config['compute'][0]} | {config['gpu'][0]}")Distributed Data Loading
class DistributedGeospatialDataset(torch.utils.data.Dataset):
"""Dataset for distributed training with large geospatial tiles"""
def __init__(self, image_paths, rank=0, world_size=1, tile_size=256):
self.image_paths = image_paths
self.rank = rank
self.world_size = world_size
self.tile_size = tile_size
# Distribute files across processes
files_per_rank = len(image_paths) // world_size
start_idx = rank * files_per_rank
end_idx = start_idx + files_per_rank if rank < world_size - 1 else len(image_paths)
self.local_paths = image_paths[start_idx:end_idx]
def __len__(self):
return len(self.local_paths) * 4 # 4 tiles per image
def __getitem__(self, idx):
file_idx = idx // 4
tile_idx = idx % 4
with rasterio.open(self.local_paths[file_idx]) as src:
height, width = src.height, src.width
# Calculate tile position
row = (tile_idx // 2) * (height // 2)
col = (tile_idx % 2) * (width // 2)
            window = Window(col, row, self.tile_size, self.tile_size)
            # boundless=True pads tiles at the raster edge so all tiles share one shape
            tile = src.read(window=window, boundless=True, fill_value=0)
return torch.from_numpy(tile.astype(np.float32))
# Usage example
def setup_distributed_training():
"""Initialize distributed training environment"""
if 'RANK' in os.environ:
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
local_rank = int(os.environ['LOCAL_RANK'])
torch.distributed.init_process_group('nccl', rank=rank, world_size=world_size)
torch.cuda.set_device(local_rank)
return rank, world_size, local_rank
else:
return 0, 1, 0
rank, world_size, local_rank = setup_distributed_training()
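# DDP and DistributedSampler are imported above but not exercised; below is a minimal,
# hedged sketch of wrapping a model for multi-GPU training. The toy backbone and the
# 'tiles/' directory are hypothetical placeholders, not part of the original example.
if world_size > 1:
    backbone = torch.nn.Sequential(
        torch.nn.Conv2d(4, 16, kernel_size=3, padding=1),  # 4-band imagery, e.g. RGB+NIR
        torch.nn.AdaptiveAvgPool2d(1),
        torch.nn.Flatten(),
        torch.nn.Linear(16, 2),
    ).cuda(local_rank)
    ddp_model = DDP(backbone, device_ids=[local_rank])
    # DistributedGeospatialDataset already shards files by rank, so a plain DataLoader
    # is enough here; use DistributedSampler instead when the dataset is not rank-aware.
    image_paths = sorted(Path('tiles/').glob('*.tif'))
    dataset = DistributedGeospatialDataset(image_paths, rank=rank, world_size=world_size)
    loader = torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=4)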
print(f"Process {rank}/{world_size} on device {local_rank}")Dask for Large-Scale Processing
Chunked Array Operations
# Large raster processing with Dask
def process_large_raster_dask(file_path, chunk_size=1024):
"""Process large rasters using Dask arrays"""
with rasterio.open(file_path) as src:
# Create dask array from raster
        # Defer the full read so it runs lazily; capture the path (picklable), not the
        # open dataset, so the task also works on distributed workers
        def _read_all(path=file_path):
            with rasterio.open(path) as ds:
                return ds.read()
        dask_array = da.from_delayed(
            dask.delayed(_read_all)(),
            shape=(src.count, src.height, src.width),
            dtype=src.dtypes[0]
        )
# Rechunk for optimal processing
dask_array = dask_array.rechunk((1, chunk_size, chunk_size))
# Normalize per chunk
normalized = (dask_array - dask_array.mean()) / dask_array.std()
# Compute NDVI if sufficient bands
if src.count >= 4: # Assuming NIR is band 4, Red is band 3
            nir = dask_array[3].astype('float32')  # cast so unsigned integer bands don't underflow
            red = dask_array[2].astype('float32')
ndvi = (nir - red) / (nir + red)
return ndvi.compute()
return normalized.compute()
# Distributed client setup
def setup_dask_cluster():
"""Setup Dask distributed cluster"""
# Local cluster
from dask.distributed import LocalCluster
cluster = LocalCluster(n_workers=4, threads_per_worker=2, memory_limit='4GB')
client = Client(cluster)
# Or cloud cluster (example for GCP)
# from dask_kubernetes import KubeCluster
# cluster = KubeCluster.from_yaml('dask-worker-spec.yaml')
# cluster.scale(10) # Scale to 10 workers
print(f"Dask dashboard: {client.dashboard_link}")
return client
client = setup_dask_cluster()
Parallel Model Inference
def distributed_model_inference(model, data_paths, client):
"""Run model inference across distributed workers"""
def inference_task(path_batch):
"""Single worker inference task"""
import torch
results = []
for path in path_batch:
with rasterio.open(path) as src:
data = src.read()
tensor = torch.from_numpy(data).unsqueeze(0).float()
with torch.no_grad():
output = model(tensor)
results.append(output.numpy())
return results
# Distribute paths across workers
n_workers = len(client.scheduler_info()['workers'])
    batch_size = max(1, len(data_paths) // n_workers)
futures = []
for i in range(0, len(data_paths), batch_size):
batch = data_paths[i:i+batch_size]
future = client.submit(inference_task, batch)
futures.append(future)
# Gather results
results = client.gather(futures)
    return np.concatenate([r for batch in results for r in batch])
Google Earth Engine Integration
Scalable Earth Engine Processing
import ee
# Initialize Earth Engine
ee.Initialize()
def large_scale_ee_processing():
"""Large-scale processing using Earth Engine"""
    # Define region of interest (here, a near-global band between 60°S and 60°N)
region = ee.Geometry.Polygon([
[[-180, -60], [180, -60], [180, 60], [-180, 60]]
])
# Load Sentinel-2 collection
collection = (ee.ImageCollection('COPERNICUS/S2_SR')
.filterDate('2023-01-01', '2023-12-31')
.filterBounds(region)
.filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 20)))
# Create composite
composite = collection.median().clip(region)
# Calculate NDVI
ndvi = composite.normalizedDifference(['B8', 'B4']).rename('NDVI')
# Export to Cloud Storage
task = ee.batch.Export.image.toCloudStorage(
image=ndvi,
description='global_ndvi_2023',
bucket='your-gcs-bucket',
scale=10,
region=region,
maxPixels=1e13,
shardSize=256
)
task.start()
print(f"Task started: {task.status()}")
return task
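# Optional helper (an assumption, not shown in the original workflow): poll an export
# task until it leaves the queue; the 60-second interval is an arbitrary choice.
def wait_for_ee_task(task, poll_seconds=60):
    """Block until an Earth Engine export task finishes, then report its final state."""
    import time
    while task.active():
        time.sleep(poll_seconds)
    print(f"Final task state: {task.status()['state']}")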
# Batch processing function
def batch_ee_export(regions, collection_name):
"""Export multiple regions in batch"""
tasks = []
for i, region in enumerate(regions):
collection = (ee.ImageCollection(collection_name)
.filterBounds(region)
.filterDate('2023-01-01', '2023-12-31'))
composite = collection.median().clip(region)
task = ee.batch.Export.image.toCloudStorage(
image=composite,
description=f'region_{i:03d}',
bucket='your-processing-bucket',
scale=10,
region=region,
maxPixels=1e9
)
task.start()
tasks.append(task)
    return tasks
Model Optimization Strategies
Model Quantization and Pruning
import torch.quantization as quantization
from torch.nn.utils import prune
class OptimizedGeoModel(torch.nn.Module):
"""Optimized model for deployment"""
def __init__(self, base_model):
super().__init__()
self.backbone = base_model.backbone
self.classifier = base_model.classifier
def forward(self, x):
features = self.backbone(x)
return self.classifier(features)
def optimize_model_for_deployment(model, sample_input):
"""Apply optimization techniques for deployment"""
# 1. Pruning (remove 30% of weights)
parameters_to_prune = []
for module in model.modules():
if isinstance(module, torch.nn.Conv2d):
parameters_to_prune.append((module, 'weight'))
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.3
)
# Make pruning permanent
for module, param_name in parameters_to_prune:
prune.remove(module, param_name)
# 2. Quantization
model.eval()
    # Post-training dynamic quantization (dynamic quantization covers Linear and RNN
    # layers; Conv2d weights need static quantization, so only Linear is listed here)
    quantized_model = quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
# 3. TorchScript compilation
traced_model = torch.jit.trace(quantized_model, sample_input)
return traced_model
# Example usage (assumes `model` is a trained network, e.g. an OptimizedGeoModel instance)
sample_input = torch.randn(1, 4, 256, 256)  # Batch, Channels, Height, Width
model_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / 1e6
optimized_model = optimize_model_for_deployment(model, sample_input)
# Save optimized model and compare sizes (parameter footprint vs. on-disk TorchScript file)
torch.jit.save(optimized_model, 'optimized_geo_model.pt')
optimized_size_mb = Path('optimized_geo_model.pt').stat().st_size / 1e6
print(f"Model size reduced from {model_size_mb:.1f}MB to {optimized_size_mb:.1f}MB")
ONNX Export for Cross-Platform Deployment
import onnx
import onnxruntime as ort
def export_to_onnx(model, sample_input, output_path):
"""Export PyTorch model to ONNX format"""
model.eval()
# Export to ONNX
torch.onnx.export(
model,
sample_input,
output_path,
input_names=['satellite_image'],
output_names=['predictions'],
dynamic_axes={
'satellite_image': {0: 'batch_size', 2: 'height', 3: 'width'},
'predictions': {0: 'batch_size'}
},
opset_version=11
)
# Verify ONNX model
onnx_model = onnx.load(output_path)
onnx.checker.check_model(onnx_model)
# Test with ONNX Runtime
ort_session = ort.InferenceSession(output_path)
ort_inputs = {ort_session.get_inputs()[0].name: sample_input.numpy()}
ort_outputs = ort_session.run(None, ort_inputs)
print(f"ONNX model exported successfully to {output_path}")
return ort_session
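# Optional: simple batched inference with the ONNX Runtime session; the batch size is an
# illustrative choice, and `tiles` is assumed to be a list of equally shaped (C, H, W)
# numpy arrays.
def onnx_batch_inference(ort_session, tiles, batch_size=16):
    """Run the exported model over tiles in fixed-size batches."""
    input_name = ort_session.get_inputs()[0].name
    outputs = []
    for i in range(0, len(tiles), batch_size):
        batch = np.stack(tiles[i:i + batch_size]).astype(np.float32)
        outputs.append(ort_session.run(None, {input_name: batch})[0])
    return np.concatenate(outputs, axis=0)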
# Export model
onnx_session = export_to_onnx(model, sample_input, 'geo_model.onnx')
Container Deployment
Docker Configuration
# Create Dockerfile programmatically
dockerfile_content = '''
FROM nvidia/cuda:11.8.0-runtime-ubuntu20.04
# Install system dependencies
RUN apt-get update && apt-get install -y \\
python3 python3-pip \\
gdal-bin libgdal-dev \\
&& rm -rf /var/lib/apt/lists/*
# Set GDAL environment variables
ENV GDAL_DATA=/usr/share/gdal
ENV PROJ_LIB=/usr/share/proj
# Install Python dependencies
COPY requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt
# Copy model and inference code
COPY geo_model.onnx /app/
COPY inference_api.py /app/
WORKDIR /app
# Expose port
EXPOSE 8000
# Run inference API
CMD ["python3", "inference_api.py"]
'''
# Requirements file
requirements_content = '''
torch>=1.12.0
torchvision>=0.13.0
rasterio>=1.3.0
numpy>=1.21.0
fastapi>=0.68.0
uvicorn>=0.15.0
onnxruntime-gpu>=1.12.0
pillow>=8.3.0
'''
# Save files
with open('Dockerfile', 'w') as f:
f.write(dockerfile_content)
with open('requirements.txt', 'w') as f:
f.write(requirements_content)
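# The Dockerfile above copies an inference_api.py that is not shown elsewhere in this
# cheatsheet; here is a minimal, hedged sketch of what it might contain (endpoint name,
# payload format, and model path are assumptions).
inference_api_content = '''
import numpy as np
import onnxruntime as ort
from fastapi import FastAPI

app = FastAPI()
session = ort.InferenceSession("/app/geo_model.onnx")

@app.post("/predict")
def predict(payload: dict):
    # Expect {"image": nested list shaped (bands, height, width)}
    image = np.asarray(payload["image"], dtype=np.float32)[np.newaxis, ...]
    outputs = session.run(None, {session.get_inputs()[0].name: image})
    return {"predictions": outputs[0].tolist()}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
'''
with open('inference_api.py', 'w') as f:
    f.write(inference_api_content)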
print("Docker configuration files created")Kubernetes Deployment
# Kubernetes deployment YAML
k8s_deployment = '''
apiVersion: apps/v1
kind: Deployment
metadata:
name: geo-ai-inference
spec:
replicas: 3
selector:
matchLabels:
app: geo-ai-inference
template:
metadata:
labels:
app: geo-ai-inference
spec:
containers:
- name: geo-ai-inference
image: your-registry/geo-ai:latest
ports:
- containerPort: 8000
resources:
requests:
memory: "4Gi"
cpu: "2"
nvidia.com/gpu: 1
limits:
memory: "8Gi"
cpu: "4"
nvidia.com/gpu: 1
env:
- name: MODEL_PATH
value: "/app/geo_model.onnx"
---
apiVersion: v1
kind: Service
metadata:
name: geo-ai-service
spec:
selector:
app: geo-ai-inference
ports:
- port: 80
targetPort: 8000
type: LoadBalancer
'''
with open('k8s-deployment.yaml', 'w') as f:
f.write(k8s_deployment)
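# Optional: pair the Deployment with a HorizontalPodAutoscaler for auto-scaling. This is
# a hedged sketch; replica counts and the CPU threshold are illustrative choices.
k8s_hpa = '''
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: geo-ai-hpa
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: geo-ai-inference
  minReplicas: 2
  maxReplicas: 20
  metrics:
  - type: Resource
    resource:
      name: cpu
      target:
        type: Utilization
        averageUtilization: 70
'''
with open('k8s-hpa.yaml', 'w') as f:
    f.write(k8s_hpa)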
print("Kubernetes deployment configuration created")Performance Monitoring
Resource Monitoring
import psutil
import pynvml as nvml  # NVML bindings from the nvidia-ml-py / pynvml package
import time
import logging
class ResourceMonitor:
"""Monitor system resources during inference"""
def __init__(self):
self.logger = logging.getLogger('resource_monitor')
try:
nvml.nvmlInit()
self.gpu_available = True
self.device_count = nvml.nvmlDeviceGetCount()
        except Exception:
self.gpu_available = False
self.device_count = 0
def get_system_stats(self):
"""Get current system resource usage"""
stats = {
'cpu_percent': psutil.cpu_percent(interval=1),
'memory_percent': psutil.virtual_memory().percent,
'memory_gb': psutil.virtual_memory().used / (1024**3),
'disk_io': psutil.disk_io_counters(),
'network_io': psutil.net_io_counters()
}
if self.gpu_available:
gpu_stats = []
for i in range(self.device_count):
handle = nvml.nvmlDeviceGetHandleByIndex(i)
# GPU utilization
util = nvml.nvmlDeviceGetUtilizationRates(handle)
# Memory info
mem_info = nvml.nvmlDeviceGetMemoryInfo(handle)
gpu_stats.append({
'gpu_id': i,
'gpu_util_percent': util.gpu,
'memory_util_percent': util.memory,
'memory_used_gb': mem_info.used / (1024**3),
'memory_total_gb': mem_info.total / (1024**3)
})
stats['gpu'] = gpu_stats
return stats
def log_performance_metrics(self, inference_time, batch_size):
"""Log performance metrics"""
stats = self.get_system_stats()
throughput = batch_size / inference_time
self.logger.info(f"Inference Time: {inference_time:.3f}s")
self.logger.info(f"Throughput: {throughput:.1f} images/sec")
self.logger.info(f"CPU Usage: {stats['cpu_percent']:.1f}%")
self.logger.info(f"Memory Usage: {stats['memory_percent']:.1f}%")
if 'gpu' in stats:
for gpu in stats['gpu']:
self.logger.info(f"GPU {gpu['gpu_id']} Util: {gpu['gpu_util_percent']:.1f}%")
# Usage
monitor = ResourceMonitor()
start_time = time.time()
# Run inference here...
# predictions = model(batch)
inference_time = time.time() - start_time
monitor.log_performance_metrics(inference_time, batch_size=32)
Best Practices Summary
Scalability Checklist
scalability_checklist = {
'Data Management': [
'✓ Use chunked/tiled data formats (COG, Zarr)',
'✓ Implement distributed data loading',
'✓ Cache frequently accessed data',
'✓ Use cloud-native data formats'
],
'Model Optimization': [
'✓ Apply quantization for deployment',
'✓ Use model pruning to reduce size',
'✓ Convert to ONNX for cross-platform deployment',
'✓ Implement batch inference'
],
'Infrastructure': [
'✓ Use auto-scaling compute resources',
'✓ Implement load balancing',
'✓ Monitor resource utilization',
'✓ Use container orchestration (Kubernetes)'
],
'Cost Optimization': [
'✓ Use spot/preemptible instances',
'✓ Implement lifecycle policies for storage',
'✓ Monitor and alert on costs',
'✓ Use appropriate instance types for workload'
]
}
for category, items in scalability_checklist.items():
print(f"\n{category}:")
for item in items:
print(f" {item}")Key Takeaways
- Distributed Training: Use DDP for multi-GPU training across nodes
- Data Parallelism: Distribute large datasets using Dask and cloud storage (see the Zarr/xarray sketch below)
- Model Optimization: Apply quantization, pruning, and ONNX export for deployment
- Container Deployment: Use Docker and Kubernetes for scalable inference
- Resource Monitoring: Track CPU, GPU, memory usage for optimization
- Cloud Integration: Leverage Earth Engine, cloud storage, and managed services
- Cost Management: Use spot instances and lifecycle policies for cost control
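The "cloud-native data formats" point deserves a concrete illustration. The sketch below shows lazy, chunked access to a hypothetical Zarr store on object storage using the xarray import from the top of this cheatsheet; the bucket path, band names, and time dimension are assumptions, and gs:// access additionally requires the zarr and gcsfs packages.
# Lazy open: each variable is backed by a chunked Dask array
ds = xr.open_zarr('gs://your-bucket/sentinel2_mosaic.zarr')
# NDVI stays lazy; chunks stream through Dask workers when computed
ndvi = (ds['B08'] - ds['B04']) / (ds['B08'] + ds['B04'])
monthly = ndvi.resample(time='1MS').mean()
# Writing the result triggers the distributed computation
monthly.to_dataset(name='ndvi').to_zarr('gs://your-bucket/ndvi_monthly.zarr')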
Together, these techniques make it possible to process continent-scale geospatial data and to deploy models that serve millions of inference requests efficiently.