001-Introduction-to-Finetuning¶
📥 Download 001-Introduction-to-Finetuning.ipynb and try it out
Introduction¶
This notebook is an introduction to using the Python SDK to fine-tune a new model from a geospatial foundation model backbone in the Geospatial Studio.
For more information about the Geospatial Studio see the docs page: Geospatial Studio Docs
For more information about the Geospatial Studio SDK and all the functions available through it, see the SDK docs page: Geospatial Studio SDK Docs
Prerequisites¶
- Access to a deployed instance of the Geospatial Studio.
- Ability to run and edit a copy of this notebook.
Install SDK:¶
Prepare a Python 3.9+ environment using whichever tool you normally use (e.g. conda, pyenv, poetry) and activate it.
Install Jupyter into that environment:
python -m pip install --upgrade pip
pip install notebook
Install the SDK with:
python -m pip install geostudio
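You can verify the installation afterwards with:
python -m pip show geostudio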
Install notebook dependencies¶
!pip install seaborn
%load_ext autoreload
%autoreload 2
# first import the required packages
import json
import uuid
import pandas as pd
import wget
import rasterio
import matplotlib.pyplot as plt
from IPython.display import display, HTML
import seaborn as sns
import getpass # For use in Colab as well
import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
from geostudio import Client
from geostudio import gswidgets
Connecting to the platform¶
First, we set up the connection to the platform backend. To do this we need the base URL for the Studio UI and an API key.
To get an API Key:
- Go to the Geospatial Studio UI page and navigate to the Manage your API keys link.
- This should open a pop-up window where you can generate, view and delete your API keys. NB: every user is limited to a maximum of two active API keys at any one time.
Store the API key and the Geospatial Studio UI base URL in a local credentials file, for example in /User/bob/.geostudio_config_file. You can do this by:
echo "GEOSTUDIO_API_KEY=<paste_api_key_here>" > .geostudio_config_file
echo "BASE_STUDIO_UI_URL=<paste_ui_base_url_here>" >> .geostudio_config_file
Copy and paste the file path of this credentials file into the call below.
#############################################################
# Initialize Geostudio client using a geostudio config file
#############################################################
gfm_client = Client(geostudio_config_file=".geostudio_config_file")
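If you would rather not store credentials in a file (for example when running on Colab), you can enter the API key interactively using getpass, which was imported above. Note that the Client keyword arguments in this sketch are hypothetical; check the SDK docs for the exact constructor signature.
#############################################################
# Alternative: initialize the client with an interactively
# entered API key (e.g. on Colab). The keyword arguments below
# are hypothetical -- see the SDK docs for the real signature.
#############################################################
# api_key = getpass.getpass("Paste your Geospatial Studio API key: ")
# gfm_client = Client(base_studio_ui_url="<paste_ui_base_url_here>", api_key=api_key)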
Setting up a fine-tuning task¶
Now we are all set to prepare our fine-tuning task. This assumes that the tuning dataset to be used is already present in the platform (if it is not, please see the dataset factory examples and return here once the dataset is onboarded).
In order to run a fine-tuning task, you need to select the following items:
- tuning task type - what type of learning task are you attempting (e.g. segmentation, regression)?
- fine-tuning dataset - what dataset will you use to train the model for your particular application?
- base foundation model - which geospatial foundation model will you use as the starting point for your tuning task?
Below we walk through how to use the Geospatial Studio SDK to see what options the platform offers for each of these, and then, once you have made your selections, how to configure and submit the task.
Tuning task selection¶
The tuning task template specifies the type of learning task (segmentation, regression etc.) and exposes a range of optional hyperparameters. These all have reasonable defaults, but they give users the flexibility to configure model training as they wish. Below, we check which task templates are available, then update some parameters.
Advanced users can create and upload new task templates to the platform; instructions are found in the relevant notebook and documentation. The templates are for Terratorch (the backend tuning library), and more details of Terratorch and its configuration options can be found here: https://terrastackai.github.io/terratorch/
tasks = gfm_client.list_tune_templates(output="df")
display(tasks[['name','description', 'id','created_by','updated_at']])
# Choose a task from the options above. Copy and paste the id into the variable, task_id, below.
task_id = 'e4791b2c-bb17-4a5e-9f05-1be5411a4fa6'
# Now we can view the full metadata and details of the selected task
task_meta = gfm_client.get_task(task_id=task_id)
task_meta
If you are happy with your choice, you can decide which (if any) hyperparameters you want to set (otherwise defaults will be used).
Here we can see the available parameters and their associated defaults. To update a parameter you can just set values in the dictionary (as shown below for max_epochs).
task_params = gfm_client.get_task_param_defaults(task_id)
task_params
task_params['runner']['max_epochs'] = 5
task_params['optimizer']['type'] = 'AdamW'
task_params['data']['batch_size'] = 4
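Before moving on, it can be useful to pretty-print the customised parameters as a sanity check (this assumes task_params is a plain, JSON-serialisable dictionary):
# Review the customised hyperparameters (assumes a JSON-serialisable dict)
print(json.dumps(task_params, indent=2))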
Dataset selection¶
Now that we have chosen the type of tuning task we wish to carry out, we need to decide on the tuning dataset. There are two options available:
- use a dataset already registered in the Studio
- create a new dataset by uploading or curating a dataset
In this notebook, we use an already existing dataset. For a walkthrough of how to create new datasets, see the relevant example and documentation.
datasets = gfm_client.list_datasets(output='df')
display(datasets[['dataset_name','description','id','status','created_by','updated_at']])
# Explore the dataset
gfm_client.get_dataset("geodata-ferctkm2brxpkbqz9apa6z")
# Copy and paste the id of the dataset into the variable below
dataset_id = 'geodata-ferctkm2brxpkbqz9apa6z'
Foundation model selection¶
The final selection we need to make before kicking off our tuning task is to select the backbone/base model we wish to start from. Again, we can first view the available options in the studio, then make our selection.
base = gfm_client.list_base_models(output='df')
display(base[['name','description','id','updated_at']])
# Copy and paste the id of the base model you selected into the variable below
base_model_id = '71c82e28-c0ee-44b8-aba9-7facd94e08ec'
Submitting the tuning task¶
Now we put that information into the payload below and send the request to the cluster. In this case we use the asynchronous submission, which avoids HTTP timeouts for long-running requests.
tune_payload = {
"name": "test-fine-tuning",
"description": "testing",
"dataset_id": dataset_id,
"base_model_id": base_model_id,
"tune_template_id": task_id,
# "model_parameters": task_params # uncomment this line if you customised task_params in the cells above otherwise, defaults will be used
}
print(json.dumps(tune_payload, indent=2))
submitted = gfm_client.submit_tune(
data = tune_payload,
output = 'json'
)
print(submitted)
Monitor tuning status and progress¶
After submitting the request, we can poll the tuning service to check progress and get the output details once it's complete (this could take anywhere from minutes to hours depending on the size of the job and the current service load).
# If you wish, you can keep polling the tuning task to monitor its progress.
r = gfm_client.poll_finetuning_until_finished(tune_id=submitted['tune_id'])
tune_id = submitted["tune_id"]
tune_info = gfm_client.get_tune(tune_id, output='json')
tune_info
Check the training metrics from the tune¶
The metrics from the model training are logged in a backend MLflow service and can be accessed through the APIs, SDK and UI.
You can access the training metrics in full as JSON using:
gfm_client.get_tune_metrics(tune_id)
Or load them directly into a pandas dataframe for ready analysis using get_tune_metrics_df. In addition, the SDK's plot_tune_metrics function quickly plots top-level training and validation metrics, such as loss and multi-class accuracy.
mdf = gfm_client.get_tune_metrics_df(tune_id)
mdf.head()
gswidgets.plot_tune_metrics(client=gfm_client, tune_id=tune_id)
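Because the metrics are available as a plain dataframe, you can also build custom plots, for example with seaborn (installed and imported above). The column names used below ('step', 'value', 'metric') are assumptions about the dataframe layout; check mdf.head() above and adjust accordingly.
# A minimal custom plot of the logged metrics, assuming the dataframe has
# 'step', 'value' and 'metric' columns -- adjust to the actual layout
sns.relplot(
    data=mdf, x="step", y="value",
    col="metric", col_wrap=3, kind="line",
    facet_kws={"sharey": False},
)
plt.show()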
Try out the model for inference¶
Once your model has finished tuning, you can test it by running inference, passing either a location (bbox) or a URL to a pre-prepared file. The steps to test the model are:
- Define the inference payload
- Try out the tune temporarily
Using an S3 pre-signed link¶
If you have your image locally, you can upload it to S3 and generate a pre-signed link to it.
Personal buckets¶
Use create_upload_presigned_url to generate an upload link that you can use to push the file to your bucket.
This function assumes you have your own storage bucket to upload to.
upload_url = gfm_client.create_upload_presigned_url(
bucket_name="bucket_name", # bucket name
object_key="data/train/austin1_sdk_upload.tiff", # file path to upload in the bucket
endpoint_url="https://s3.us-east.cloud-object-storage.appdomain.cloud", # s3 endpoint url
service_name= "s3", # service to use
region_name="us-east", # cloud region
expiration=3600 # link expiration in seconds
# Add any other args to pass to the s3 client
)
upload_url
# Push your file to the bucket using the url generated.
!curl -X PUT -T <your_file.tif/.tiff/.zip> "<upload_url>"
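If you prefer to stay in Python rather than shelling out to curl, a pre-signed S3 upload URL accepts a plain HTTP PUT, for example via the requests library (the local file path below is a placeholder):
import requests

local_file = "austin1_sdk_upload.tiff"  # placeholder -- path to your local file
with open(local_file, "rb") as f:
    response = requests.put(upload_url, data=f)
response.raise_for_status()  # raises if the upload failed
print(response.status_code)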
Once the image is uploaded to your S3 bucket, create a download link to use in the inference request.
download_url = gfm_client.create_download_presigned_url(
bucket_name="geospatial-studio-example-data", # bucket name
object_key="data/train/austin1_sdk_upload.tiff", # file path to upload in the bucket
endpoint_url="https://s3.us-east.cloud-object-storage.appdomain.cloud", # s3 endpoint url
service_name= "s3", # service to use
region_name="us-east", # cloud region
expiration=7200 # link expiration in seconds
# Add any other args to pass to the s3 client
)
download_url
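As a quick sanity check, you can preview the image behind the pre-signed download link using rasterio and matplotlib (both imported above). This sketch assumes your rasterio/GDAL build supports reading directly over HTTPS:
# Preview the first band of the image behind the download link
# (assumes rasterio/GDAL can read remote HTTPS URLs)
with rasterio.open(download_url) as src:
    band = src.read(1)
plt.imshow(band)
plt.title("First band of the uploaded image")
plt.colorbar()
plt.show()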
Geostudio temporary buckets¶
If you would like to upload to a Geostudio temporary bucket, use the get_fileshare_links function.
# Unique object name to be used in the temporary COS bucket for each layer you want to upload
object_name = "austin1_sdk_upload.tiff"
fileshare_links = gfm_client.get_fileshare_links(object_name)
fileshare_links
# Push your file to the bucket using the url generated.
!curl -X PUT -T <your_file.tif/.tiff/.zip> "<upload_url>"
Submit Inference¶
Now you can create the inference payload using the download link.
# define the inference payload
bbox = [-121.837006,39.826468,-121.641312,40.038655] # [min_lon, min_lat, max_lon, max_lat]
download_url_tiff = download_url
# When using a bbox
request_payload_with_bbox = {
"description": "Park Fire 2024 SDK",
"location": "Red Bluff, California, United States",
"spatial_domain": {
"bbox": [bbox], # When using bboxes
"polygons": [],
"tiles": [],
"urls": []
},
"temporal_domain": [
"2024-08-12_2024-08-13"
]
}
# When using a presigned link
request_payload_with_url = {
"description": "Park Fire 2024 SDK",
"location": "Red Bluff, California, United States",
"spatial_domain": {
"bbox": [],
"polygons": [],
"tiles": [],
"urls": [download_url_tiff] # When using url
},
"temporal_domain": [
"2024-08-12_2024-08-13"
]
}
# Now submit the test inference request
# Use request_payload_with_url instead when testing with a presigned URL
inference_response = gfm_client.try_out_tune(tune_id=tune_id, data=request_payload_with_bbox)
inference_response
Downloading the tuned model artefacts¶
If you want to download the model artefacts (e.g. checkpoint and config) in order to run the model locally or elsewhere, you can use the following function.
gfm_client.download_tune(tune_id)