FineTuning: HPO With Terratorch Iterate¶
📥 Download 003-Terratorch-Iterate.ipynb and try it out
Introduction¶
This notebook demonstrates how to use the FineTuning SDK to submit an HPO (Hyperparameter Optimization) job to the FineTuning service.
Prerequisites¶
Before proceeding with this notebook, ensure you have:
- Active GeoStudio Service Access: Valid credentials and permissions for the GeoStudio fine-tuning service
- SDK Installation: The GeoStudio SDK installed in your environment
- Authentication Setup: API keys configured (either via environment variables or key files)
- TerraTorch Iterate Config File: A prepared configuration file (.yaml) for running fine-tuning.
Note: This workflow assumes you have already prepared a configuration file for TerraTorch Iterate. If you need guidance on fine-tuning TerraTorch models with HPO enabled, refer to the TerraTorch-Iterate documentation first.
Imports & Setup¶
%load_ext autoreload
%autoreload 2
from geostudio import Client
Connecting to Geospatial Studio¶
First, we set up the connection to the platform backend. To do this, we need the base URL of the Studio UI and an API key.
To get an API Key:
- Go to the Geospatial Studio UI page and navigate to the Manage your API keys link.
- This should pop up a window where you can generate, access, and delete your API keys. NB: each user is limited to a maximum of two active API keys at any one time.
Store the API key and the GeoStudio UI base URL in a local credentials file, for example /Users/bob/.geostudio_config_file. You can do this by:
echo "GEOSTUDIO_API_KEY=<paste_api_key_here>" > .geostudio_config_file
echo "BASE_STUDIO_UI_URL=<paste_ui_base_url_here>" >> .geostudio_config_file
Pass the path to this credentials file in the call below.
#############################################################
# Initialize the client using the credentials file
#############################################################
geostudio_client = Client(geostudio_config_file=".geostudio_config_file")
Preparing and Onboarding Data Required for FineTuning¶
To onboard your dataset to the Geospatial Studio, you need a direct download URL pointing to a zip file of the dataset.
Review the dataset onboarding notebooks in ../dataset-onboarding to learn how to onboard new datasets into the studio.
If you already have onboarded datasets, list them and select the one you need for fine-tuning.
geostudio_client.api_url
datasets = geostudio_client.list_datasets(output="df")
display(datasets[['id','dataset_name', 'purpose', 'status','size','description', 'created_by']])
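Once you have identified the dataset you want to use, you can pull its id out of the dataframe. The dataset name used below is a hypothetical placeholder; substitute the name of your own onboarded dataset.
# Look up the id of the dataset to fine-tune on (the dataset name is a placeholder)
dataset_id = datasets.loc[datasets['dataset_name'] == "burn-scars", 'id'].iloc[0]
dataset_id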
Submitting the tune¶
Once the data is onboarded and you have a valid TerraTorch Iterate config YAML, you are ready to kick off your HPO tuning task. To run a fine-tuning task, you need to provide the following items:
- tune_metadata: Identifying info about your tune, such as:
  - name: A name to identify your tune.
  - description: A detailed description of your experiment.
  - dataset_id: The id of a dataset that has already been onboarded in the Studio.
- config_file: The path to your prepared TerraTorch Iterate configuration YAML file.
Prepare the fine-tuning payload and submit your tuning job.
tunes = geostudio_client.list_tunes(output="df")
# tunes
display(tunes[['id','active', 'created_by', 'name', 'description', 'dataset_id', 'status', 'metrics']])
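If the list is long, you can narrow it down with standard pandas filtering on the columns shown above, for example to keep only tunes that are still active. This uses only the dataframe returned by list_tunes, so no extra API calls are involved.
# Optional: filter the listing to active tunes only (pure pandas filtering)
active_tunes = tunes[tunes['active'] == True]
display(active_tunes[['id', 'name', 'status', 'dataset_id', 'created_by']])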
tune = geostudio_client.submit_hpo_tune(
    data={
        "tune_metadata": {
            "name": "fire-scars-hpo-tune-016",
            "description": "Fine-tuned TerraTorch model for fire scar detection",
            "dataset_id": "geodata-gdctf3vb3znbbtgptqvuku",
        },
        "config_file": "../sample_files/burnscars-iterate-hpo.yaml"
    }
)
display(tune)
Fetch tune results¶
tune_resp = geostudio_client.get_tune(tune["tune_id"])
tune_resp
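Tuning jobs can take a while to complete, so you may want to poll the tune until it reaches a terminal state. The loop below is a minimal sketch: it assumes the response behaves like a dictionary exposing a status field (consistent with the status column shown by list_tunes above) and that terminal states include values such as "completed" and "failed"; adjust these to the values your deployment actually returns.
import time
# Poll the tune until it reaches a terminal state.
# Assumptions: the response exposes a "status" field, and
# "completed"/"failed" are terminal statuses.
while True:
    tune_resp = geostudio_client.get_tune(tune["tune_id"])
    status = tune_resp.get("status") if isinstance(tune_resp, dict) else getattr(tune_resp, "status", None)
    print(f"Tune status: {status}")
    if status in ("completed", "failed"):
        break
    time.sleep(60)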
MLflow Experiment Visualization¶
import mlflow
import matplotlib.pyplot as plt
import pandas as pd
mlflow_tracking_uri = "<studio-mlflow-tracking-uri>"
mlflow.set_tracking_uri(mlflow_tracking_uri)
experiment_name = "geotune-qokd9fqyuhxbgyyiuurpxu"
client = mlflow.tracking.MlflowClient()
try:
    experiment = client.get_experiment_by_name(experiment_name)
    if not experiment:
        raise ValueError(f"Experiment '{experiment_name}' not found.")
    experiment_id = experiment.experiment_id
except Exception as e:
    print(f"Error: {e}")
    exit()
runs = client.search_runs(experiment_ids=[experiment_id])
if not runs:
    print("No runs found for this experiment.")
    exit()
# Create a DataFrame from the runs for easier data manipulation
def create_runs_dataframe(runs):
    run_data = []
    for run in runs:
        # Flatten each run's metrics and parameters into a single row
        metrics = run.data.metrics
        params = run.data.params
        row = {**metrics, **params, 'run_id': run.info.run_id}
        run_data.append(row)
    return pd.DataFrame(run_data)
# Create the dataframe and normalize any missing values to pd.NA
runs_df = create_runs_dataframe(runs)
runs_df = runs_df.fillna(value=pd.NA)
# Print a summary of the runs to check the data
display("Summary of MLflow Runs:")
display(runs_df.head())
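Because each HPO trial is logged as its own MLflow run, the dataframe above can be used to compare trials and pick the best one. The metric name below is an assumption; replace it with whichever validation metric your TerraTorch Iterate config actually logs.
# Rank runs by a validation metric to find the best HPO trials.
# Assumption: the experiment logs a "val/loss" metric; substitute your own metric name.
metric_col = "val/loss"
if metric_col in runs_df.columns:
    best_runs = runs_df.sort_values(metric_col).head(5)
    display(best_runs[['run_id', metric_col]])
else:
    print(f"Metric '{metric_col}' not found; available columns: {list(runs_df.columns)}")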
import matplotlib.pyplot as plt
import numpy as np
metrics_to_plot = ["train/loss", "epoch"]
step_interval = 30  # downsample by averaging points in chunks of 30 steps
# run = client.get_run(runs[0].info.run_id)
# metrics_keys = run.data.metrics.keys()
# metrics_to_plot = [metric for metric in metrics_keys if not metric.startswith("System")]
for metric_name in metrics_to_plot:
    plt.figure(figsize=(6, 4))
    plt.title(f"{metric_name.replace('_', ' ').title()} Over Steps", fontsize=10)
    plt.xlabel("Step", fontsize=8)
    plt.ylabel(metric_name.replace('_', ' ').title(), fontsize=8)
    plt.grid(True)
    for run in runs:
        metric_history = client.get_metric_history(run.info.run_id, metric_name)
        if metric_history:
            steps = [m.step for m in metric_history]
            values = [m.value for m in metric_history]
            # Downsample: group points into chunks of step_interval and average each chunk
            smoothed_steps = []
            smoothed_values = []
            for i in range(0, len(steps), step_interval):
                chunk_steps = steps[i:i+step_interval]
                chunk_values = values[i:i+step_interval]
                smoothed_steps.append(np.mean(chunk_steps))
                smoothed_values.append(np.mean(chunk_values))
            # Label each run with its id and the key hyperparameters, if logged
            run_params = run.data.params
            legend_label = f"Run {run.info.run_id[:8]}"
            if "lr" in run_params:
                legend_label += f" (lr={float(run_params['lr']):.2e})"
            if "batch_size" in run_params:
                legend_label += f" (bs={run_params['batch_size']})"
            plt.plot(smoothed_steps, smoothed_values, label=legend_label)
        else:
            print(f"No history for metric '{metric_name}' in run {run.info.run_id}")
    plt.legend(
        loc="upper center",
        bbox_to_anchor=(0.5, -0.20),
        ncol=1,  # number of legend columns (adjust based on the number of runs)
        frameon=False,
        fontsize=8,
    )
    plt.tight_layout()
    plt.show()