Lab 3¶
Upload Model Checkpoints and Run Inference
⏱️ Estimated Duration: 30 minutes
📊 Difficulty Level: Intermediate
📥 Getting the Lab Materials¶
Clone the repository and open the lab notebook:
git clone https://github.com/terrastackai/geospatial-studio.git
cd geospatial-studio/workshop/docs/notebooks
jupyter notebook lab3-running-inference.ipynb
⚠️ Note: This lab requires JSON configuration files. Cloning the repository ensures you have all necessary files.
🎯 Learning Objectives¶
By the end of this lab, you will be able to:
- Understand what model checkpoints are and why they're important
- Upload fine-tuned model checkpoints to Geospatial Studio
- Configure task templates for different model types
- Run inference using uploaded models
- Customize inference payloads for different scenarios
- Bring your own model checkpoints to the platform
📖 What are Model Checkpoints?¶
A model checkpoint is a saved state of a trained machine learning model. It contains:
- Model weights - The learned parameters from training
- Configuration - Architecture and hyperparameter settings
- Metadata - Information about training data and performance
Why Upload Checkpoints?¶
- Reuse trained models - Use models trained elsewhere
- Share models - Collaborate with your team
- Version control - Track different model versions
- Production deployment - Deploy models for inference
🌊 About This Lab's Example¶
We'll upload a flood detection model checkpoint and run inference with it. This model:
- Task: Segmentation (pixel-wise classification)
- Input: Sentinel-2 satellite imagery (6 bands + SCL)
- Output: Flood/No-flood/Permanent water classification
- Use case: Disaster response and flood monitoring
⚠️ Important: Compute Requirements¶
GPU vs CPU Inference¶
Recommended: This lab is best run with GPU acceleration.
CPU Inference:
- ✅ Can run on CPU - Inference will work without a GPU
- ⚠️ Significantly slower - Expect 5-10x longer processing times
- ⚠️ High CPU usage - May consume substantial CPU resources
- ⚠️ Memory intensive - Requires sufficient RAM (8GB+ recommended)
- ⚠️ Process management - May need to stop other resource-intensive processes
Resource Considerations¶
When running inference on CPU:
- Close unnecessary applications to free up resources
- Monitor system resources during inference
- Be patient - CPU inference takes longer but will complete
- Consider smaller spatial extents for testing
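Before committing to CPU inference, it is worth checking what the machine actually has. Here is a pure-stdlib sketch (not part of the geostudio client; the memory check reads /proc/meminfo and so only works on Linux — use psutil for a cross-platform version):

```python
import os

def resource_summary():
    """Report CPU count and available memory before starting CPU inference.

    Pure-stdlib sketch; the memory check is Linux-only (/proc/meminfo).
    Returns (cpu_count, available_memory_gb_or_None).
    """
    cpus = os.cpu_count() or 1
    mem_gb = None
    try:
        with open("/proc/meminfo") as f:
            for line in f:
                if line.startswith("MemAvailable:"):
                    mem_gb = int(line.split()[1]) / 1024 / 1024  # kB -> GB
                    break
    except OSError:
        pass  # /proc not available (e.g. macOS, Windows)
    return cpus, mem_gb

cpus, mem_gb = resource_summary()
print(f"CPUs available: {cpus}")
if mem_gb is not None and mem_gb < 8:
    print("⚠️ Less than 8 GB of available RAM; consider a smaller spatial extent")
```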
For Cluster Deployments¶
If running on a cluster with limited GPUs:
- You may need to scale down the terratorch-inference deployment to release GPU resources
- Monitor pod status: kubectl get pods -n <namespace>
- Check logs if pods remain pending: kubectl logs <pod-name>
🛠️ Setup¶
Import the required libraries:
import json
from geostudio import Client
print("✅ Imports successful")
Connect to Geospatial Studio¶
Use the same configuration file from Lab 1:
# Initialize the client
client = Client(geostudio_config_file=".geostudio_config_file")
print("✅ Connected to Geospatial Studio")
📋 Step 1: Onboard a Task Template¶
What is a Task Template?¶
A task template defines:
- The type of ML task (segmentation, regression, classification)
- Model architecture options
- Training hyperparameters
- Default configurations
Why Do We Need It?¶
When you upload a model checkpoint, Studio needs to know:
- What task the model was trained for
- How to configure the model for inference
- What inputs the model expects
Available Task Types¶
- Segmentation - Pixel-wise classification (floods, burn scars, land cover)
- Regression - Continuous value prediction (biomass, temperature)
- Classification - Image-level labels (crop type, building detection)
Let's onboard the segmentation template:
# Load the segmentation task template from file
with open('../../../populate-studio/payloads/templates/template-seg.json', 'r') as f:
segmentation_template = json.load(f)
print("📋 Segmentation Task Template:")
print(json.dumps(segmentation_template, indent=2))
# Create the task template
print("⏳ Creating task template...")
template_response = client.create_task(segmentation_template)
print("\n✅ Task template created successfully")
print(f"📋 Template ID: {template_response['id']}")
📤 Step 2: Upload a Model Checkpoint¶
What You Need to Upload¶
To upload a model checkpoint, you need:
- Model checkpoint file (.ckpt) - The trained weights
- Configuration file (.yaml) - Model architecture and settings
- Metadata - Information about the model
Checkpoint Structure¶
The checkpoint payload includes:
- URLs - Where to download the checkpoint and config
- Model input spec - What bands/data the model expects
- Visualization config - How to display results
- Post-processing - Optional masking and filtering
Example: Flood Detection Model¶
Let's upload a pre-trained flood detection model:
# Load the flood detection model checkpoint from file
with open('../../../populate-studio/payloads/tunes/tune-prithvi-eo-flood.json', 'r') as f:
flood_checkpoint = json.load(f)
print("📋 Flood Detection Model Checkpoint:")
print(json.dumps(flood_checkpoint, indent=2))
# Upload the checkpoint
print("⏳ Uploading model checkpoint...")
print(" This may take 1-2 minutes to download and register the model.")
tune_response = client.upload_completed_tunes(flood_checkpoint)
print("\n✅ Model checkpoint uploaded successfully")
print(f"🆔 Tune ID: {tune_response['tune_id']}")
Monitor Upload Progress¶
The upload process involves:
- Downloading checkpoint and config files
- Validating the model structure
- Registering in the database
Let's wait for it to complete:
# Poll until upload is complete
tune_id = tune_response['tune_id']
print(f"⏳ Monitoring upload progress for tune ID: {tune_id}")
client.poll_finetuning_until_finished(tune_id)
print("\n✅ Model is ready for inference!")
print("💡 You can now view this model in the Studio UI under 'Models & Tunes'")
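`poll_finetuning_until_finished` handles the waiting for you. If you want explicit control over the polling interval and a hard timeout, a generic loop along these lines works; the status-fetching callable and the status names here are illustrative, not the platform's exact API:

```python
import time

def poll_until_done(get_status, tune_id, interval=15, timeout=1800):
    """Generic polling loop with a hard timeout.

    get_status is any callable taking a tune ID and returning a status
    string; 'COMPLETED'/'FAILED' are illustrative terminal values --
    substitute whatever your client actually reports.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        status = get_status(tune_id)
        if status in ("COMPLETED", "FAILED"):
            return status
        time.sleep(interval)
    raise TimeoutError(f"tune {tune_id} still running after {timeout}s")
```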
🚀 Step 3: Run Inference¶
Understanding Inference Payloads¶
An inference payload defines:
- Spatial domain - Where to run inference (bbox, tiles, or URLs)
- Temporal domain - When (date or date range)
- Model configuration - Which model to use
- Processing options - Masking, filtering, etc.
Example: Flood Detection in Assam, India¶
Let's run inference on a flood event:
# Define the inference payload
inference_payload = {
"model_display_name": "flood-detection-demo",
"fine_tuning_id": tune_id,
"location": "Dakhin Petbaha, Raha, Nagaon, Assam, India",
"description": "Flood detection in Assam using Sentinel-2",
# Spatial domain: bounding box
"spatial_domain": {
"bbox": [
[92.703396, 26.247896, 92.748087, 26.267903]
],
"urls": [],
"tiles": [],
"polygons": []
},
# Temporal domain: date range
"temporal_domain": [
"2024-07-25_2024-07-28"
]
}
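The temporal_domain entry above encodes a date range as "YYYY-MM-DD_YYYY-MM-DD". A small helper (a sketch of ours, not part of the geostudio client) keeps that formatting consistent and catches reversed ranges:

```python
from datetime import date

def date_range_key(start: date, end: date) -> str:
    """Build the 'YYYY-MM-DD_YYYY-MM-DD' string used in temporal_domain."""
    if end < start:
        raise ValueError("end date must not precede start date")
    return f"{start.isoformat()}_{end.isoformat()}"

# Matches the range used in the payload above
print(date_range_key(date(2024, 7, 25), date(2024, 7, 28)))
# 2024-07-25_2024-07-28
```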
# Submit the inference request
print("🚀 Submitting inference request...")
inference_response = client.try_out_tune(tune_id=tune_id, data=inference_payload)
print("\n✅ Inference submitted successfully")
print(f"🆔 Inference ID: {inference_response.get('inference_id', 'N/A')}")
print("\n💡 View results in the Studio UI:")
print(" 1. Go to 'Models & Tunes'")
print(" 2. Click on 'flood-detection-demo'")
print(" 3. Check the 'History' tab for your inference")
🎨 Step 4: Customize Inference Configuration¶
Advanced Payload Options¶
You can customize many aspects of inference:
- Pipeline steps - Control the processing workflow
- Post-processing - Apply masks and filters
- Visualization - Customize how results are displayed
- Band configuration - Specify input bands and scaling
Example: Full Configuration¶
Here's a more detailed inference payload:
# Advanced inference payload with full configuration
advanced_payload = {
"model_display_name": "flood-detection-advanced",
"fine_tuning_id": tune_id,
"location": "Custom Location",
"description": "Advanced inference with custom configuration",
# Spatial domain with URL
"spatial_domain": {
"bbox": [],
"urls": [
"https://your-data-url.com/imagery.tif"
],
"tiles": [],
"polygons": []
},
# Temporal domain
"temporal_domain": [
"2024-08-01"
],
# Pipeline steps (processing workflow)
"pipeline_steps": [
{
"status": "READY",
"process_id": "url-connector",
"step_number": 0
},
{
"status": "WAITING",
"process_id": "terratorch-inference",
"step_number": 1
},
{
"status": "WAITING",
"process_id": "postprocess-generic",
"step_number": 2
},
{
"status": "WAITING",
"process_id": "push-to-geoserver",
"step_number": 3
}
],
# Post-processing options
"post_processing": {
"cloud_masking": "True",
"ocean_masking": "True",
"snow_ice_masking": "True",
"permanent_water_masking": "False"
},
# Model input data specification
"model_input_data_spec": [
{
"bands": [
{"index": "0", "RGB_band": "B", "band_name": "Blue", "scaling_factor": "0.0001"},
{"index": "1", "RGB_band": "G", "band_name": "Green", "scaling_factor": "0.0001"},
{"index": "2", "RGB_band": "R", "band_name": "Red", "scaling_factor": "0.0001"},
{"index": "3", "band_name": "NIR_Narrow", "scaling_factor": "0.0001"},
{"index": "4", "band_name": "SWIR1", "scaling_factor": "0.0001"},
{"index": "5", "band_name": "SWIR2", "scaling_factor": "0.0001"}
],
"connector": "sentinelhub",
"collection": "hls_l30",
"file_suffix": "_merged.tif",
"modality_tag": "HLS_L30"
}
],
# GeoServer visualization configuration
"geoserver_push": [
{
"z_index": 0,
"workspace": "geofm",
"layer_name": "input_rgb",
"file_suffix": "",
"display_name": "Input image (RGB)",
"filepath_key": "model_input_original_image_rgb",
"geoserver_style": {
"rgb": [
{"label": "RedChannel", "channel": 1, "maxValue": 255, "minValue": 0},
{"label": "GreenChannel", "channel": 2, "maxValue": 255, "minValue": 0},
{"label": "BlueChannel", "channel": 3, "maxValue": 255, "minValue": 0}
]
},
"visible_by_default": "True"
},
{
"z_index": 1,
"workspace": "geofm",
"layer_name": "pred",
"file_suffix": "",
"display_name": "Model prediction",
"filepath_key": "model_output_image",
"geoserver_style": {
"segmentation": [
{"color": "#000000", "label": "No data", "opacity": 0, "quantity": "0"},
{"color": "#FA4D56", "label": "Flood", "opacity": 1, "quantity": "1"}
]
},
"visible_by_default": "True"
}
]
}
print("📋 Advanced payload configuration:")
print(json.dumps(advanced_payload, indent=2))
🎓 Bring Your Own Model (BYOM)¶
Requirements for Your Own Checkpoints¶
To upload your own model checkpoint, you need:
1. Model Files¶
- Checkpoint file (.ckpt) - PyTorch Lightning checkpoint
- Config file (.yaml) - TerraTorch configuration
- Both files must be accessible via HTTPS URLs
2. Task Template¶
- Know which task type your model was trained for: segmentation, regression, or classification
- The task template must exist in Studio (create it first if needed)
3. Model Metadata¶
- Input bands - Which spectral bands your model expects
- Band order - The sequence of bands
- Scaling factors - How to normalize input data
- Output classes - For segmentation/classification
4. Visualization Configuration¶
- Color schemes - For segmentation classes
- Value ranges - For regression outputs
- Layer names - For GeoServer display
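For segmentation outputs, the geoserver_style entries follow a regular shape, so you can generate them from a class map instead of writing each dict by hand. A minimal sketch (the helper name is ours, not part of the client; it emits the same entry shape used in the payloads above):

```python
def segmentation_style(classes):
    """Build GeoServer segmentation style entries from {value: (label, hex_color)}.

    Entries are emitted in ascending class-value order, matching the
    dict shape used in the geoserver_push payloads in this lab.
    """
    return [
        {"color": color, "label": label, "opacity": 1, "quantity": str(value)}
        for value, (label, color) in sorted(classes.items())
    ]

style = segmentation_style({0: ("No data", "#000000"), 1: ("Flood", "#FA4D56")})
print(style)
```

Remember to set opacity to 0 for a "No data" class if you want it transparent, as the flood example above does.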
Example Template for Your Model¶
Here's a template you can adapt:
# Template for your own model checkpoint
your_model_checkpoint = {
"name": "your-model-name",
"description": "Description of your model",
# URLs to your model files (must be publicly accessible)
"tune_config_url": "https://your-storage.com/path/to/config.yaml",
"tune_checkpoint_url": "https://your-storage.com/path/to/checkpoint.ckpt",
# Specify the input bands your model expects
"model_input_data_spec": [
{
"bands": [
# List each band with its index, name, and scaling
{"index": 0, "band_name": "Band1", "scaling_factor": 1},
{"index": 1, "band_name": "Band2", "scaling_factor": 1},
# Add more bands as needed
],
"connector": "sentinel_aws", # or "sentinelhub", "url-connector"
"collection": "sentinel-2-l2a", # or your data source
"file_suffix": "_suffix"
}
],
# Configure how results are visualized
"geoserver_push": [
{
"z_index": 0,
"layer_name": "input_rgb",
"display_name": "Input Image",
"filepath_key": "model_input_original_image_rgb",
"geoserver_style": {
"rgb": [
{"label": "Red", "channel": 1, "maxValue": 255, "minValue": 0},
{"label": "Green", "channel": 2, "maxValue": 255, "minValue": 0},
{"label": "Blue", "channel": 3, "maxValue": 255, "minValue": 0}
]
}
},
{
"z_index": 1,
"layer_name": "prediction",
"display_name": "Model Output",
"filepath_key": "model_output_image",
"geoserver_style": {
# For segmentation:
"segmentation": [
{"color": "#FF0000", "label": "Class 1", "opacity": 1, "quantity": "1"},
{"color": "#00FF00", "label": "Class 2", "opacity": 1, "quantity": "2"}
]
# For regression, use "regression" instead
}
}
],
# Post-processing options
"post_processing": {
"cloud_masking": "False",
"ocean_masking": "False",
        "snow_ice_masking": "False",
"permanent_water_masking": "False"
}
}
print("📝 Template for your own model:")
print("\n⚠️ Remember to:")
print(" 1. Replace URLs with your actual model files")
print(" 2. Update band configuration to match your model")
print(" 3. Ensure the task template exists in Studio")
print(" 4. Configure visualization for your output classes")
print("\n💡 To upload: client.upload_completed_tunes(your_model_checkpoint)")
💡 Tips and Best Practices¶
Model Upload¶
- ✅ Host checkpoint files on reliable cloud storage (S3, GCS, Azure Blob)
- ✅ Use presigned URLs if files are private
- ✅ Ensure URLs are accessible from your Studio deployment
- ✅ Test URLs in a browser before uploading
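As a quick sanity check before uploading, you can also verify the checkpoint and config URLs respond from Python with a stdlib HEAD request (a sketch; some presigned URLs reject HEAD, in which case a ranged GET is the alternative):

```python
from urllib.request import Request, urlopen

def url_reachable(url: str, timeout: float = 10.0) -> bool:
    """Return True if a HEAD request to the URL succeeds with a 2xx/3xx status."""
    try:
        req = Request(url, method="HEAD")
        with urlopen(req, timeout=timeout) as resp:
            return 200 <= resp.status < 400
    except (OSError, ValueError):
        # DNS failure, refused connection, timeout, HTTP error, or bad URL
        return False

# Example (hypothetical URLs -- substitute your own):
for url in ["https://your-storage.com/path/to/config.yaml",
            "https://your-storage.com/path/to/checkpoint.ckpt"]:
    print(url, "->", "reachable" if url_reachable(url) else "NOT reachable")
```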
Task Templates¶
- ✅ Create task template before uploading checkpoints
- ✅ Use the correct task type (segmentation/regression/classification)
- ✅ Match the template to your model's training configuration
Inference Configuration¶
- ✅ Start with simple payloads, then add complexity
- ✅ Test with small spatial extents first
- ✅ Use appropriate temporal ranges for your data source
- ✅ Enable post-processing masks as needed
Troubleshooting¶
- ❌ If upload fails, check pod logs: kubectl logs geofm-gateway-xxx
- ❌ If inference fails, verify that the band configuration matches the input data
- ❌ If visualization is wrong, check GeoServer style configuration
🎉 Summary¶
In this lab, you learned how to:
✅ Create and register task templates
✅ Upload pre-trained model checkpoints
✅ Run inference with uploaded models
✅ Customize inference payloads
✅ Prepare your own models for upload
Next Steps¶
- Lab 4: Complete end-to-end workflow with training and inference
- Explore: Try different models and configurations
- Experiment: Upload your own model checkpoints