Skip to content

Fine-tuning

geostudio.backends.v2.gtune.client

Client

Client(
    api_config: GeoFmSettings = None,
    session: Session = None,
    api_token: str = None,
    api_key: str = None,
    api_key_file: str = None,
    geostudio_config_file: str = None,
    *args,
    **kwargs
)

Bases: BaseClient

Source code in geostudio/backends/base_client.py
def __init__(
    self,
    api_config: GeoFmSettings = None,
    session: requests.Session = None,
    api_token: str = None,
    api_key: str = None,
    api_key_file: str = None,
    geostudio_config_file: str = None,
    *args,
    **kwargs,
):
    """
    Initializes the client and builds the HTTP session used for API calls.

    Credential sources are checked in this order and the first one supplied wins:
    ``api_token``, ``api_key``, ``api_key_file``, ``geostudio_config_file`` and,
    as a last resort, an already-built ``session``.

    Args:
        api_config (GeoFmSettings, optional): The configuration settings for the GeoFm API.
            A fresh ``GeoFmSettings()`` is created when omitted. Defaults to None.
        session (requests.Session, optional): A pre-configured requests session. Only used when
            none of the credential arguments is supplied. Defaults to None.
        api_token (str, optional): The API token for authentication. Defaults to None.
        api_key (str, optional): The API key for authentication. Defaults to None.
        api_key_file (str, optional): The path to the file containing the API key. Defaults to None.
        geostudio_config_file (str, optional): Path to a dotenv-style geostudio config file holding
            ``GEOSTUDIO_API_KEY``, ``BASE_GATEWAY_API_URL`` and ``BASE_STUDIO_UI_URL``. Defaults to None.
        *args: Variable length argument list (unused here).
        **kwargs: Arbitrary keyword arguments (unused here).

    Raises:
        GeoFMException: If ``api_key_file`` or ``geostudio_config_file`` is given but is not an
            existing file, or if neither credentials nor a session are provided.

    Attributes:
        api_config (GeoFmSettings): The configuration settings for the GeoFm API.
        session (requests.Session): The session used for requests.
        logger (logging.Logger): The logger instance for logging messages.
    """
    self.api_config = api_config or GeoFmSettings()

    if api_token:
        print("Using api_token")
        self.session = gfm_session(access_token=api_token)
    elif api_key:
        print("Using api_key from sdk command")
        self.session = gfm_session(api_key=api_key)
    elif api_key_file:
        if not os.path.isfile(api_key_file):
            raise GeoFMException("Config file does not exist, Please provide a valid config file.")
        print("Using api_key from file")
        self.session = gfm_session(api_key_file=api_key_file)
    elif geostudio_config_file:
        if not os.path.isfile(geostudio_config_file):
            raise GeoFMException("Config file does not exist, Please provide a valid config file.")
        print("Using api key and base urls from geostudio config file")
        geostudio_config_file_values = dotenv_values(geostudio_config_file)
        # NOTE: this overwrites the shared module-level ``settings`` object, so the
        # values from the file are visible to everything else that reads ``settings``.
        settings.BASE_GATEWAY_API_URL = geostudio_config_file_values.get("BASE_GATEWAY_API_URL", "")
        settings.BASE_STUDIO_UI_URL = geostudio_config_file_values.get("BASE_STUDIO_UI_URL", "")
        settings.GEOSTUDIO_API_KEY = geostudio_config_file_values.get("GEOSTUDIO_API_KEY", None)
        self.session = gfm_session(api_key=settings.GEOSTUDIO_API_KEY)
    elif session is not None:
        # Honour the documented ``session`` argument instead of silently ignoring it.
        # Explicit credentials above still take precedence, so existing callers are unaffected.
        self.session = session
    else:
        raise GeoFMException("Missing APIToken. Add `GEOFM_API_TOKEN` to env variables.")

    self.logger = logging.getLogger()

list_tunes

list_tunes(output: str = 'json')

Lists all fine tuning jobs in the studio.

Parameters:

Name Type Description Default
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

A dictionary containing the list of tunes found.

Source code in geostudio/backends/v2/gtune/client.py
def list_tunes(self, output: str = "json"):
    """
    Fetch the fine-tuning jobs known to the studio.

    Issues a GET on ``{api_version}/tunes`` and asks ``http_get`` for the
    "results" data field.

    Args:
        output (str, optional): Response format forwarded to ``http_get``. Defaults to "json".

    Returns:
        dict: The tunes found, as returned by ``http_get``.
    """
    endpoint = f"{self.api_version}/tunes"
    return self.http_get(endpoint, output=output, data_field="results")

get_tune

get_tune(tune_id: str, output: str = 'json')

Retrieves a tune by ID. If the tune's status is Failed, a pre-signed url for the logs is generated.

Parameters:

Name Type Description Default
tune_id str

The unique identifier of the tune to retrieve.

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The tune's status and information

Source code in geostudio/backends/v2/gtune/client.py
def get_tune(self, tune_id: str, output: str = "json"):
    """
    Look up a single tune by its identifier.

    Issues a GET on ``{api_version}/tunes/{tune_id}``. Per the API description, the
    response for a tune whose status is Failed also carries a pre-signed url for its
    logs (server-side behaviour, not implemented in this method).

    Args:
        tune_id (str): The unique identifier of the tune to retrieve.
        output (str, optional): Response format forwarded to ``http_get``. Defaults to "json".

    Returns:
        dict: The tune's status and information, as returned by ``http_get``.
    """
    endpoint = f"{self.api_version}/tunes/{tune_id}"
    return self.http_get(endpoint, output=output)

update_tune

update_tune(
    tune_id: str, data: TuneUpdateIn, output: str = "json"
)

Update a tune in the database

Parameters:

Name Type Description Default
tune_id str

The unique identifier of the tune to be updated.

required
data TuneUpdateIn

A dictionary containing the data to update for the tune.

required
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

A dictionary of the updated tune.

Source code in geostudio/backends/v2/gtune/client.py
def update_tune(self, tune_id: str, data: TuneUpdateIn, output: str = "json"):
    """
    Update a tune in the database

    Args:
        tune_id (str): The unique identifier of the tune to be updated.
        data (TuneUpdateIn | dict): The fields to update, either as an already-built
            ``TuneUpdateIn`` or as a mapping of its fields (validated here).
        output (str, optional): The format of the response. Defaults to "json".

    Returns:
        dict: A dictionary of the updated tune, as returned by ``http_patch``.
    """
    # ``TuneUpdateIn(**data)`` only works for mappings; also accept the model
    # instance that the signature advertises.
    model = data if isinstance(data, TuneUpdateIn) else TuneUpdateIn(**data)
    # Serialise through JSON so the payload only contains JSON-native values.
    payload = json.loads(model.model_dump_json())
    response = self.http_patch(f"{self.api_version}/tunes/{tune_id}", data=payload, output=output)
    return response

delete_tune

delete_tune(tune_id, output: str = 'json')

Deletes a specified tune using its ID.

Parameters:

Name Type Description Default
tune_id str

The ID of the tune to be deleted.

required
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

Message of successfully deleted tune

Source code in geostudio/backends/v2/gtune/client.py
def delete_tune(self, tune_id, output: str = "json"):
    """
    Remove the tune identified by ``tune_id``.

    Issues a DELETE on ``{api_version}/tunes/{tune_id}``.

    Args:
        tune_id (str): The ID of the tune to be deleted.
        output (str, optional): Response format forwarded to ``http_delete``. Defaults to "json".

    Returns:
        dict: The server's reply to the deletion, as returned by ``http_delete``.
    """
    endpoint = f"{self.api_version}/tunes/{tune_id}"
    return self.http_delete(endpoint, output=output)

submit_tune

submit_tune(data: TuneSubmitIn, output: str = 'json')

Submit a fine-tuning job to the Geospatial studio platform

Parameters:

Name Type Description Default
data TuneSubmitIn

Parameters for the tuning job.

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The server's response containing the submitted tune info.

Source code in geostudio/backends/v2/gtune/client.py
def submit_tune(self, data: TuneSubmitIn, output: str = "json"):
    """
    Submit a fine-tuning job to the Geospatial studio platform

    The tune name is normalised before submission: it is lower-cased and spaces
    and underscores are replaced by hyphens. The caller's ``data`` object is left
    untouched; the normalised name is applied to a copy.

    Args:
        data (TuneSubmitIn | dict): Parameters for the tuning job, either as an
            already-built ``TuneSubmitIn`` or as a mapping of its fields. Must
            contain a ``name``.
        output (str, optional): The desired output format. Defaults to "json".

    Returns:
        dict: The server's response containing the submitted tune info.

    Raises:
        KeyError: If ``data`` is a mapping without a "name" entry.
    """
    if isinstance(data, TuneSubmitIn):
        # The signature advertises a model instance; ``**data`` would fail on it.
        normalised = data.name.lower().replace(" ", "-").replace("_", "-")
        model = data.model_copy(update={"name": normalised})
    else:
        normalised = data["name"].lower().replace(" ", "-").replace("_", "-")
        # Build a new mapping rather than writing the name back into the caller's dict.
        model = TuneSubmitIn(**{**data, "name": normalised})

    payload = json.loads(model.model_dump_json())
    response = self.http_post(f"{self.api_version}/submit-tune", data=payload, output=output)
    return response

submit_hpo_tune

submit_hpo_tune(
    data: HpoTuneSubmitIn, output: str = "json"
)

Submit a fine-tuning job with terratorch-iterate enabled.

Parameters:

Name Type Description Default
data HpoTuneSubmitIn

Parameters for the tuning job

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The server's response containing the submitted tune info.

Source code in geostudio/backends/v2/gtune/client.py
def submit_hpo_tune(self, data: HpoTuneSubmitIn, output: str = "json"):
    """Submit a fine-tuning job with terratorch-iterate enabled.

    The config file referenced by ``data.config_file`` is read from disk and sent
    as a multipart upload together with the JSON-serialised ``data.tune_metadata``.

    Args:
        data (HpoTuneSubmitIn): Parameters for the tuning job. A plain dict is
            converted to ``HpoTuneSubmitIn`` first.
        output (str, optional): The desired output format. Defaults to "json".

    Returns:
        dict: The server's response containing the submitted tune info.

    Raises:
        ValueError: If the config file does not exist or is empty.
    """
    request = HpoTuneSubmitIn(**data) if isinstance(data, dict) else data
    config_path = request.config_file

    # Validate the file up front so a bad path fails before any request is made.
    if not os.path.isfile(config_path):
        raise ValueError(f"Config file not found: {config_path}")
    if os.path.getsize(config_path) == 0:
        raise ValueError(f"Config file is empty: {config_path}")

    upload_name = os.path.basename(config_path)
    with open(config_path, "rb") as handle:
        raw_config = handle.read()

    return self.http_post(
        f"{self.api_version}/submit-hpo-tune",
        data={"tune_metadata": request.tune_metadata.model_dump_json()},
        files={"config_file": (upload_name, raw_config, "application/x-yaml")},
        output=output,
    )

upload_completed_tunes

upload_completed_tunes(data: UploadTuneInput)

Upload a completed fine-tuning job to the Geostudio platform

Parameters:

Name Type Description Default
data UploadTuneInput

Parameters to update the tune with

required

Returns:

Name Type Description
dict

Message of successfully uploaded tune

Source code in geostudio/backends/v2/gtune/client.py
def upload_completed_tunes(self, data: UploadTuneInput):
    """
    Upload a completed fine-tuning job to the Geostudio platform

    Args:
        data (UploadTuneInput | dict): Parameters describing the completed tune,
            either as an already-built ``UploadTuneInput`` or as a mapping of its
            fields (validated here).

    Returns:
        dict: Message of successfully uploaded tune, as returned by ``http_post``.
    """
    # ``UploadTuneInput(**data)`` only works for mappings; also accept the model
    # instance that the signature advertises.
    model = data if isinstance(data, UploadTuneInput) else UploadTuneInput(**data)
    # Serialise through JSON so the payload only contains JSON-native values.
    payload = json.loads(model.model_dump_json())
    response = self.http_post(f"{self.api_version}/upload-completed-tunes", data=payload, output="json")
    return response

try_out_tune

try_out_tune(tune_id: str, data: TryOutTuneInput)

Try-out inference on a tune without deploying the model.

Parameters:

Name Type Description Default
tune_id str

The unique identifier of the tune experiment.

required
data TryOutTuneInput

The inference configurations to try the tune on

required

Returns:

Name Type Description
dict

Dictionary containing the details of the inference submitted.

Source code in geostudio/backends/v2/gtune/client.py
def try_out_tune(self, tune_id: str, data: TryOutTuneInput):
    """Try-out inference on a tune without deploying the model.

    Args:
        tune_id (str): The unique identifier of the tune experiment.
        data (TryOutTuneInput | dict): The inference configurations to try the tune on,
            either as an already-built ``TryOutTuneInput`` or as a mapping of its
            fields (validated here).

    Returns:
        dict: Dictionary containing the details of the inference submitted.
    """
    # ``TryOutTuneInput(**data)`` only works for mappings; also accept the model
    # instance that the signature advertises.
    model = data if isinstance(data, TryOutTuneInput) else TryOutTuneInput(**data)
    # Serialise through JSON so the payload only contains JSON-native values.
    payload = json.loads(model.model_dump_json())
    response = self.http_post(f"{self.api_version}/tunes/{tune_id}/try-out", data=payload, output="json")
    return response

download_tune

download_tune(tune_id: str, output: str = 'json')

Downloads a tuned model from the server.

Parameters:

Name Type Description Default
tune_id str

The unique identifier of the tuned model to download.

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

Dictionary with tune details including presigned urls to download the artifacts.

Source code in geostudio/backends/v2/gtune/client.py
def download_tune(self, tune_id: str, output: str = "json"):
    """
    Request the download details of a tuned model.

    Issues a GET on ``{api_version}/tunes/{tune_id}/download``. Nothing is written
    to disk by this method; it only returns the server's response.

    Args:
        tune_id (str): The unique identifier of the tuned model to download.
        output (str, optional): Response format forwarded to ``http_get``. Defaults to "json".

    Returns:
        dict: Tune details including presigned urls to download the artifacts,
            as returned by ``http_get``.
    """
    endpoint = f"{self.api_version}/tunes/{tune_id}/download"
    return self.http_get(endpoint, output=output)

get_mlflow_metrics

get_mlflow_metrics(tune_id: str, output: str = 'json')

Retrieves the MLflow URLs for the training and testing metrics of a given Tune experiment.

Parameters:

Name Type Description Default
tune_id str

The ID of the Tune experiment.

required
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

A dictionary containing the MLflow URLs for the training and testing metrics. The dictionary will have the keys "train_mlflow_url" and "test_mlflow_url". If no metrics are found, the value for "train_mlflow_url" will be None.

Source code in geostudio/backends/v2/gtune/client.py
def get_mlflow_metrics(self, tune_id: str, output: str = "json"):
    """
    Retrieves the MLflow URLs for the training and testing metrics of a given Tune experiment.

    The tune's ``metrics`` entry is expected to be a list of single-key mappings, e.g.
    ``[{'Train': '/experiments/exp_id/runs/run_id'}, {'Test': '/experiments/exp_id/runs/run_id'}]``.
    Each path is prefixed with ``{settings.BASE_STUDIO_UI_URL}mlflow/#``.

    Args:
        tune_id (str): The ID of the Tune experiment.
        output (str, optional): Currently unused; the tune is always fetched with
            ``get_tune``'s default output. Kept for interface compatibility.

    Returns:
        dict | None: A dictionary with the keys "train_mlflow_url" and "test_mlflow_url".
            A key's value is None when the tune has no corresponding "Train"/"Test" entry.
            None is returned (after printing a message) when the tune has no metrics at
            all or when the metrics could not be read.
    """
    response = self.get_tune(tune_id)

    ui_url_path = f"{settings.BASE_STUDIO_UI_URL}mlflow/#"
    try:
        if not response["metrics"]:
            print(f"No mlflow url found for {tune_id}")
            return None

        merged = {k: v for d in response["metrics"] for k, v in d.items()}
        train_path = merged.get("Train")
        test_path = merged.get("Test")

        # Only prefix paths that exist; previously a missing "Train" entry produced
        # a bogus URL ending in "#None" instead of the documented None.
        if train_path:
            train_path = f"{ui_url_path}{train_path}"
        if test_path:
            test_path = f"{ui_url_path}{test_path}"

        return {"train_mlflow_url": train_path, "test_mlflow_url": test_path}
    except Exception as e:
        # Deliberately best-effort: report the problem and return None rather than raise.
        print(f"Error getting metrics urls: {e} ")
        return None

get_tune_metrics

get_tune_metrics(tune_id: str, output: str = 'json')

Retrieves the MLflow metrics for a specific tune.

Parameters:

Name Type Description Default
tune_id str

The unique identifier of the tune.

required
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The metrics of the tune in the specified format.

Source code in geostudio/backends/v2/gtune/client.py
def get_tune_metrics(self, tune_id: str, output: str = "json"):
    """
    Fetch the MLflow metrics recorded for one tune.

    Issues a GET on ``{api_version}/tunes/{tune_id}/metrics``.

    Args:
        tune_id (str): The unique identifier of the tune.
        output (str, optional): Response format forwarded to ``http_get``. Defaults to "json".

    Returns:
        dict: The metrics of the tune in the requested format, as returned by ``http_get``.
    """
    endpoint = f"{self.api_version}/tunes/{tune_id}/metrics"
    return self.http_get(endpoint, output=output)

get_tune_metrics_df

get_tune_metrics_df(tune_id: str, run_name: str = 'Train')

Retrieves the MLflow metrics for a specific tune and displays them in a pandas DataFrame

Parameters:

Name Type Description Default
tune_id str

The unique identifier of the tune.

required

Returns:

Type Description

pd.DataFrame: A pandas DataFrame containing the tuning metrics.

Source code in geostudio/backends/v2/gtune/client.py
def get_tune_metrics_df(self, tune_id: str, run_name: str = "Train"):
    """
    Build a pandas DataFrame from the MLflow metrics of one run of a tune.

    The metrics are fetched with ``get_tune_metrics``; the first run whose "name"
    equals ``run_name`` is used. Each entry of that run's "metrics" is turned into
    a frame, the frames are joined column-wise and the result is sorted by "epoch".

    Args:
        tune_id (str): The unique identifier of the tune.
        run_name (str, optional): Name of the run to tabulate. Defaults to "Train".

    Returns:
        pd.DataFrame: The tuning metrics, or an empty DataFrame when there are no
            runs, no run with that name, or the run has no metrics.

    """
    runs = self.get_tune_metrics(tune_id).get("runs")
    if not runs:
        return pd.DataFrame()

    selected = {}
    for candidate in runs:
        if candidate.get("name") == run_name:
            selected = candidate
            break

    series = selected.get("metrics")
    if not series:
        return pd.DataFrame()

    frame = pd.DataFrame.from_records(series[0])

    # NOTE(review): the entry at index 1 is never merged (the original loop started
    # at 2). Preserved as-is; confirm whether skipping it is intentional.
    for records in series[2:]:
        # "epoch" is already present from the first frame, so drop the duplicate column.
        extra = pd.DataFrame.from_records(records).drop(["epoch"], axis=1)
        frame = pd.concat([frame, extra], axis=1)

    return frame.sort_values(["epoch"])

list_tuning_artefacts

list_tuning_artefacts(tune_id: str)

Resolve the MLflow training run referenced by a tune and list artefact paths.

This function: - Calls gfm_client.get_tune(tune_id) to obtain tune metadata that contains a reference to the MLflow training run (expected under a metric named 'Train'). - Queries the MLflow server's artifacts list endpoint for that run to obtain available artifact file paths.

Parameters

tune_id : str Identifier of the tune (used to lookup metrics that contain the MLflow train run).

Returns

tuple[list[str], str] A tuple (art_files, train_run_id) where: - art_files : list[str] — list of artifact paths returned by MLflow (from the 'files' array -> each element's 'path'). - train_run_id : str — the resolved MLflow run id extracted from the tune metadata.

Notes
  • The function expects the tune metadata (gfm_client.get_tune) to include a metric mapping containing a 'Train' entry whose value includes the MLflow run UUID (the run id is taken as the last path segment after splitting on '/').
  • The MLflow artifacts list response is assumed to include a JSON 'files' array where each item has a 'path' key.
  • Example: art_files, run_id = client.list_tuning_artefacts('geotune-xxxxx')
Source code in geostudio/backends/v2/gtune/client.py
def list_tuning_artefacts(self, tune_id: str):
    """
    List the artefact paths of the MLflow training run attached to a tune.

    Steps:
    - Derive the MLflow base URL from ``self.api_url`` by dropping its last
        character and last path segment and appending ``/mlflow``.
    - Call ``self.get_tune(tune_id)`` and read the 'Train' entry of its merged
        ``metrics`` mappings to find the training run.
    - Query the MLflow artifacts list endpoint for that run through ``self.http_get``.

    The listing URL and the number of artefacts found are printed.

    Parameters
    ----------
    tune_id : str
        Identifier of the tune (used to lookup metrics that contain the MLflow train run).

    Returns
    -------
    tuple[list[str], str]
        A tuple (art_files, train_run_id) where:
        - art_files : list[str] — the 'path' of every item in the response's 'files' array.
        - train_run_id : str — the last '/'-separated segment of the 'Train' metric value.

    Notes
    -----
    - Raises ``KeyError`` if the tune has no 'metrics' / 'Train' entry or the listing
    response has no 'files' key.
    - ``self.api_url`` is presumably expected to end with '/'; verify against the base client.
    - Example:
        art_files, run_id = client.list_tuning_artefacts('geotune-xxxxx')
    """

    # Text before the last '/' of api_url (minus its final character), i.e. the parent URL.
    parent_url = self.api_url[:-1].rpartition("/")[0]
    mlflow_url = f"{parent_url}/mlflow"

    tune_info = self.get_tune(tune_id)
    merged_metrics = {}
    for entry in tune_info["metrics"]:
        merged_metrics.update(entry.items())
    train_run_id = merged_metrics["Train"].rsplit("/", 1)[-1]

    listing_url = f"{mlflow_url}/api/2.0/mlflow/artifacts/list?run_id={train_run_id}"
    print(listing_url)
    listing = self.http_get(listing_url, output="json")

    art_files = []
    for item in listing["files"]:
        art_files.append(item["path"])
    print(f"Found {len(art_files)} artefacts")
    return art_files, train_run_id

get_tuning_artefacts

get_tuning_artefacts(
    tune_id: str,
    epochs: list = None,
    image_numbers: list = None,
)

Download fine‑tuning artefact images from an MLflow run referenced by a tune.

This function: - Resolves the MLflow training run id for the given tune via gfm_client.get_tune(...) - Lists artefacts for that run from the MLflow server - Optionally filters artefact filenames by epoch and/or image number - Downloads matching artefacts in parallel and returns a list of records

Parameters

tune_id : str Identifier of the tune (used to lookup metrics that contain the MLflow train run). epochs : list[int], optional If provided, only artefacts whose filename encodes an epoch contained in this list are retained. Filenames are assumed to contain epoch as the second underscore-separated token (e.g. "epoch_4_5.png" -> epoch 4). image_numbers : list[int], optional If provided, only artefacts whose filename encodes an image number contained in this list are retained. Filenames are assumed to contain the image number as the third underscore-separated token (e.g. "epoch_4_5.png" -> image_number 5).

Returns

list[dict] A list of dictionaries, one per downloaded artefact, with keys: - 'filename' (str): artefact path from MLflow - 'image' (bytes): raw downloaded bytes - 'epoch' (int): parsed epoch number - 'image_number' (int): parsed image/sample number

Notes
  • Downloads are performed in parallel using joblib (threads).
  • The function assumes artefact filenames follow the pattern "epoch_<epoch>_<image_number>.<ext>" (for example "epoch_4_5.png").
Source code in geostudio/backends/v2/gtune/client.py
def get_tuning_artefacts(self, tune_id: str, epochs: list = None, image_numbers: list = None):
    """
    Fetch the artefact images logged by the MLflow training run of a tune.

    Steps:
    - Resolve the run and its artefact paths with ``self.list_tuning_artefacts(tune_id)``
    - Optionally keep only the paths matching ``epochs`` and/or ``image_numbers``
    - Download every remaining artefact with ``self.get_training_image`` in parallel
    - Return one record per artefact

    Progress messages are printed, and urllib3 warnings are disabled process-wide
    before any request is made.

    Parameters
    ----------
    tune_id : str
        Identifier of the tune (used to lookup metrics that contain the MLflow train run).
    epochs : list[int], optional
        When given, only artefacts whose epoch is in this list are kept. The epoch is
        the second underscore-separated token of the path ("epoch_4_5.png" -> 4).
    image_numbers : list[int], optional
        When given, only artefacts whose image number is in this list are kept. The
        image number is the third underscore-separated token, up to the first '.'
        ("epoch_4_5.png" -> 5).

    Returns
    -------
    list[dict]
        One dictionary per downloaded artefact, with keys:
        - 'filename' (str): artefact path from MLflow
        - 'image': whatever ``self.get_training_image`` returned for that path
          (documented upstream as raw bytes)
        - 'epoch' (int): parsed epoch number
        - 'image_number' (int): parsed image/sample number

    Notes
    -----
    - Downloads run on up to 10 threads via ``Parallel(n_jobs=10, prefer="threads")``.
    - Paths are assumed to look like "epoch_<epoch>_<image_number>.<ext>"; a path that
    does not will raise ``IndexError`` or ``ValueError`` when parsed.
    """

    def _epoch_of(path):
        # Second '_'-separated token.
        return int(path.split("_")[1])

    def _image_number_of(path):
        # Third '_'-separated token, without the extension.
        return int(path.split("_")[2].split(".")[0])

    requests.packages.urllib3.disable_warnings()
    art_files, train_run_id = self.list_tuning_artefacts(tune_id)

    if epochs is not None:
        art_files = [path for path in art_files if _epoch_of(path) in epochs]
    if image_numbers is not None:
        art_files = [path for path in art_files if _image_number_of(path) in image_numbers]

    print(f"Downloading {len(art_files)} artefacts...")

    download_jobs = (delayed(self.get_training_image)(path, train_run_id) for path in art_files)
    fetched = Parallel(n_jobs=10, prefer="threads")(download_jobs)
    ans = list(track(fetched, total=len(art_files)))

    print("Downloaded all artefacts")
    records = []
    for position, path in enumerate(art_files):
        records.append(
            {
                "filename": path,
                "image": ans[position],
                "epoch": _epoch_of(path),
                "image_number": _image_number_of(path),
            }
        )

    return records

list_tune_templates

list_tune_templates(output: str = 'json')

Lists the tune templates available in the studio.

Parameters:

Name Type Description Default
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

A dictionary containing the list of tune templates in the studio

Source code in geostudio/backends/v2/gtune/client.py
def list_tune_templates(self, output: str = "json"):
    """
    Fetch the tune templates available in the studio.

    Issues a GET on ``{api_version}/tune-templates`` and asks ``http_get`` for the
    "results" data field.

    Args:
        output (str, optional): Response format forwarded to ``http_get``. Defaults to "json".

    Returns:
        dict: The tune templates found, as returned by ``http_get``.
    """
    endpoint = f"{self.api_version}/tune-templates"
    return self.http_get(endpoint, output=output, data_field="results")

create_task

create_task(data: TaskIn, output: str = 'json')

Creates a new task using the provided data.

Parameters:

Name Type Description Default
data TaskIn

The data required to create a new task.

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The response from the server containing the details of the newly created task.

Source code in geostudio/backends/v2/gtune/client.py
def create_task(self, data: TaskIn, output: str = "json"):
    """
    Register a new task (tune template).

    Issues a POST on ``{api_version}/tune-templates``. ``data`` is forwarded to
    ``http_post`` unchanged; no validation is performed in this method.

    Args:
        data (TaskIn): The data required to create a new task.
        output (str, optional): Response format forwarded to ``http_post``. Defaults to "json".

    Returns:
        dict: The server's response describing the newly created task.
    """
    endpoint = f"{self.api_version}/tune-templates"
    return self.http_post(endpoint, data=data, output=output)

get_task

get_task(task_id: str, output: str = 'json')

Retrieves a task by its ID.

Parameters:

Name Type Description Default
task_id str

The ID of the task to retrieve.

required
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The response from the server containing the task details.

Source code in geostudio/backends/v2/gtune/client.py
def get_task(self, task_id: str, output: str = "json"):
    """
    Look up a single task (tune template) by its ID.

    Issues a GET on ``{api_version}/tune-templates/{task_id}`` and asks ``http_get``
    for the "results" data field.

    Args:
        task_id (str): The ID of the task to retrieve.
        output (str, optional): Response format forwarded to ``http_get``. Defaults to "json".

    Returns:
        dict: The task details, as returned by ``http_get``.
    """
    endpoint = f"{self.api_version}/tune-templates/{task_id}"
    return self.http_get(endpoint, output=output, data_field="results")

delete_task

delete_task(task_id, output: str = 'json')

Deletes a task with the given task_id.

Parameters:

Name Type Description Default
task_id str

The unique identifier of the task to be deleted.

required
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

Message of successfully deleted task.

Source code in geostudio/backends/v2/gtune/client.py
def delete_task(self, task_id, output: str = "json"):
    """
    Remove the task (tune template) identified by ``task_id``.

    Issues a DELETE on ``{api_version}/tune-templates/{task_id}``.

    Args:
        task_id (str): The unique identifier of the task to be deleted.
        output (str, optional): Response format forwarded to ``http_delete``. Defaults to "json".

    Returns:
        dict: The server's reply to the deletion, as returned by ``http_delete``.

    """
    endpoint = f"{self.api_version}/tune-templates/{task_id}"
    return self.http_delete(endpoint, output=output)

get_task_template

get_task_template(task_id: str, output: str = 'text')

Retrieves the task template yaml for the selected task

Parameters:

Name Type Description Default
task_id str

The ID of the task to retrieve.

required
output str

The format of the response. Can either be "cell", "text" or "file".

'text'

Returns:

Name Type Description
dict

The response from the server containing the task template yaml.

Source code in geostudio/backends/v2/gtune/client.py
def get_task_template(self, task_id: str, output: str = "text"):
    """
    Retrieves the task template yaml for the selected task

    The template is read from the "reason" field of the server response and is
    delivered according to ``output``:

    - "text": returned as a string.
    - "cell": passed to ``create_new_cell`` as a ``ty = '''...'''`` assignment.
    - "file": written to ``<task_id>.yaml`` (relative to the current working
      directory) as UTF-8.

    Any other value does nothing.

    Args:
        task_id (str): The ID of the task to retrieve.
        output (str, optional): The format of the response. Can either be "cell", "text" or "file".

    Returns:
        str | None: The template yaml when ``output`` is "text"; otherwise None.
    """
    response = self.http_get(
        f"{self.api_version}/tune-templates/{task_id}/template", output="json", data_field="results"
    )
    if output not in ("text", "cell", "file"):
        return None

    template = response["reason"]
    if output == "text":
        return template
    if output == "cell":
        create_new_cell(f"ty = '''{template}''' ")
    else:
        # Pin the encoding: the platform default may be unable to encode non-ASCII
        # characters in the template (e.g. cp1252 on Windows).
        with open(task_id + ".yaml", "w", encoding="utf-8") as fp:
            fp.write(template)
    return None

update_task

update_task(
    task_id: str, file_path: str, output: str = "json"
)

Updates a task's content with a yaml file config

Parameters:

Name Type Description Default
task_id str

The ID of the task to upload.

required
file_path str

The path to the file containing the new template.

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

Message of successful task upload

Source code in geostudio/backends/v2/gtune/client.py
def update_task(self, task_id: str, file_path: str, output: str = "json"):
    """
    Replace a task's template with the contents of a yaml file.

    Issues a file PUT on ``{api_version}/tune-templates/{task_id}/template`` via
    ``http_put_file``.

    Args:
        task_id (str): The ID of the task to update.
        file_path (str): The path to the file containing the new template.
        output (str, optional): Response format forwarded to ``http_put_file``. Defaults to "json".

    Returns:
        dict: The server's reply to the upload, as returned by ``http_put_file``.
    """
    endpoint = f"{self.api_version}/tune-templates/{task_id}/template"
    return self.http_put_file(endpoint, file_path=file_path, output=output)

update_task_schema

update_task_schema(
    task_id: str, task_schema: Any, output: str = "json"
)

Update the JSONSchema of a task.

Parameters:

Name Type Description Default
task_id str

The ID of the task to update.

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

Message of successful task update

Source code in geostudio/backends/v2/gtune/client.py
def update_task_schema(self, task_id: str, task_schema: Any, output: str = "json"):
    """
    Replace the JSONSchema of a task.

    Issues a PUT on ``{api_version}/tune-templates/{task_id}/schema``.

    Args:
        task_id (str): The ID of the task to update.
        task_schema (Any): The new schema; forwarded unchanged as the request data.
        output (str, optional): Response format forwarded to ``http_put``. Defaults to "json".

    Returns:
        dict: The server's reply to the update, as returned by ``http_put``.
    """
    endpoint = f"{self.api_version}/tune-templates/{task_id}/schema"
    return self.http_put(endpoint, data=task_schema, output=output)

get_task_param_defaults

get_task_param_defaults(task_id: str)

Retrieves the default parameter values for a given task.

Parameters:

Name Type Description Default
task_id str

The unique identifier of the task.

required

Returns:

Name Type Description
dict

A dictionary containing the default parameter values for the task. The keys are the parameter names and the values are the default values.

Source code in geostudio/backends/v2/gtune/client.py
def get_task_param_defaults(self, task_id: str):
    """
    Collect the default values declared in a task's ``model_params`` schema.

    Args:
        task_id (str): The unique identifier of the task.

    Returns:
        dict: Maps parameter name to its ``"default"`` value. Only parameters whose
            schema entry contains a ``"properties"`` key are included.
    """
    param_specs = self.get_task(task_id)["model_params"]["properties"]
    # NOTE(review): the filter tests for "properties" while the value read is "default";
    # an entry with "properties" but no "default" raises KeyError. Confirm this is intended.
    return {name: spec["default"] for name, spec in param_specs.items() if "properties" in spec}

check_task_content

check_task_content(
    task_id: str,
    dataset_id: str,
    base_model_id: Any,
    output: str = "text",
)

Checks that the task renders correctly.

Parameters:

Name Type Description Default
task_id str

The ID of the task to check.

required
output str

The format of the returned template. Can be "text", "cell", or "file". Defaults to "text".

'text'

Returns:

Name Type Description
dict

Message of task content

Source code in geostudio/backends/v2/gtune/client.py
def check_task_content(self, task_id: str, dataset_id: str, base_model_id: Any, output: str = "text"):
    """
    Checks that the task renders correctly.

    Requests a test render of the task template and delivers the ``"reason"`` field
    of the response in the form selected by ``output``.

    Args:
        task_id (str): The ID of the task to check.
        dataset_id (str): Sent as the ``dataset_id`` query parameter.
        base_model_id (Any): Sent as the ``base_model`` query parameter.
        output (str, optional): The format of the returned template. Can be "text", "cell", or "file".
            Defaults to "text".

    Returns:
        str | None: The rendered content when ``output`` is "text". For "cell" the content is
            passed to ``create_new_cell`` and for "file" it is written to ``<task_id>.yaml`` in the
            current working directory; both return None, as does any other ``output`` value.
    """
    params = {"dataset_id": dataset_id, "base_model": base_model_id}
    # The endpoint is always queried as JSON; ``output`` only controls how the result is delivered.
    response = self.http_get(
        f"{self.api_version}/tune-templates/{task_id}/test-render", params=params, output="json"
    )
    if output == "text":
        return response["reason"]
    if output == "cell":
        create_new_cell(f"ty = '''{response['reason']}''' ")
    elif output == "file":
        # Explicit UTF-8 so the written YAML does not depend on the platform's default encoding.
        with open(task_id + ".yaml", "w", encoding="utf-8") as fp:
            fp.write(response["reason"])
    return None

render_template

render_template(
    task_id: str, dataset_id: str, output: str = "text"
)

Checks that the user-defined task renders correctly.

Parameters:

Name Type Description Default
task_id str

The ID of the task to check.

required
dataset_id str

The ID of the dataset associated with the task.

required
output str

The format of the returned template. Can be "text", "cell", or "file". Defaults to "text".

'text'

Returns:

Name Type Description
dict

The rendered template in the specified output format.

Source code in geostudio/backends/v2/gtune/client.py
def render_template(self, task_id: str, dataset_id: str, output: str = "text"):
    """
    Checks that the user defined task renders correctly.

    Requests a test render of the user-defined task template and delivers the ``"reason"``
    field of the response in the form selected by ``output``.

    Args:
        task_id (str): The ID of the task to check.
        dataset_id (str): The ID of the dataset associated with the task.
        output (str, optional): The format of the returned template. Can be "text", "cell", or "file".
            Defaults to "text".

    Returns:
        str | None: The rendered content when ``output`` is "text". For "cell" the content is
            passed to ``create_new_cell`` and for "file" it is written to ``<task_id>.yaml`` in the
            current working directory; both return None, as does any other ``output`` value.
    """
    t_params = {"dataset_id": dataset_id}
    # The endpoint is always queried as JSON; ``output`` only controls how the result is delivered.
    response = self.http_get(
        f"{self.api_version}/tune-templates/{task_id}/test-render-user-defined-task", params=t_params, output="json"
    )
    if output == "text":
        return response["reason"]
    if output == "cell":
        create_new_cell(f"ty = '''{response['reason']}''' ")
    elif output == "file":
        # Explicit UTF-8 so the written YAML does not depend on the platform's default encoding.
        with open(task_id + ".yaml", "w", encoding="utf-8") as fp:
            fp.write(response["reason"])
    return None

list_datasets

list_datasets(output: str = 'json')

Lists all datasets available in the studio.

Parameters:

Name Type Description Default
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

A dictionary containing a list of datasets found in the dataset factory

Source code in geostudio/backends/v2/gtune/client.py
def list_datasets(self, output: str = "json"):
    """
    Lists all datasets available in the studio.

    Args:
        output (str, optional): The format of the response. Defaults to "json".

    Returns:
        dict: A dictionary containing a list of datasets found in the dataset factory
    """
    # The listing is read from the "results" field of the response.
    datasets_url = f"{self.api_version}/datasets"
    return self.http_get(datasets_url, output=output, data_field="results")

pre_scan_dataset

pre_scan_dataset(
    data: PreScanDatasetIn, output: str = "json"
)

Scans a new dataset - checks accessibility of the dataset URL, ensures corresponding data and label files are present, and extracts bands and their descriptions from the dataset.

Parameters:

Name Type Description Default
data PreScanDatasetIn

Link to the dataset to scan

required

Returns:

Name Type Description
dict

A dictionary containing the scan results.

Source code in geostudio/backends/v2/gtune/client.py
def pre_scan_dataset(self, data: PreScanDatasetIn, output: str = "json"):
    """
    Scans a new dataset - checks accessibility of the dataset URL, ensures corresponding data and
    label files are present, and extracts bands and their descriptions from the dataset.

    Args:
        data (PreScanDatasetIn | dict): Link to the dataset to scan, given either as a
            ``PreScanDatasetIn`` instance or as a mapping of its fields.
        output (str, optional): The desired output format. Defaults to "json".

    Returns:
        dict: A dictionary containing the scan results.
    """
    # ``**`` unpacking needs a mapping, so a ready-made model (the annotated type) is used as-is
    # instead of being unpacked into the constructor.
    model = data if isinstance(data, PreScanDatasetIn) else PreScanDatasetIn(**data)
    # Round-trip through the model's JSON dump so the payload is plain JSON data.
    payload = json.loads(model.model_dump_json())
    return self.http_post(f"{self.api_version}/datasets/pre-scan", data=payload, output=output)

get_sample_images

get_sample_images(dataset_id: str, output: str = 'json')

Retrieves a sample of images from a specified dataset.

Parameters:

Name Type Description Default
dataset_id str

The unique identifier of the dataset.

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

A dictionary containing the sample data in the requested format.

Source code in geostudio/backends/v2/gtune/client.py
def get_sample_images(self, dataset_id: str, output: str = "json"):
    """
    Retrieves a sample of images from a specified dataset.

    Args:
        dataset_id (str): The unique identifier of the dataset.
        output (str, optional): The desired output format. Defaults to "json".

    Returns:
        dict: A dictionary containing the sample data in the requested format.
    """
    sample_url = f"{self.api_version}/datasets/{dataset_id}/sample"
    return self.http_get(sample_url, output=output)

update_dataset

update_dataset(
    dataset_id: str,
    data: DatasetUpdateIn,
    output: str = "json",
)

Update a dataset metadata in the database

Parameters:

Name Type Description Default
dataset_id str

The unique identifier of the dataset to be updated.

required
data DatasetUpdateIn

A dictionary containing the data to update for the dataset.

required
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

A dictionary of the updated dataset.

Source code in geostudio/backends/v2/gtune/client.py
def update_dataset(self, dataset_id: str, data: DatasetUpdateIn, output: str = "json"):
    """
    Update a dataset metadata in the database

    Args:
        dataset_id (str): The unique identifier of the dataset to be updated.
        data (DatasetUpdateIn | dict): The data to update for the dataset, given either as a
            ``DatasetUpdateIn`` instance or as a dictionary of its fields.
        output (str, optional): The format of the response. Defaults to "json".

    Returns:
        dict: A dictionary of the updated dataset.
    """
    # ``**`` unpacking needs a mapping, so a ready-made model (the annotated type) is used as-is
    # instead of being unpacked into the constructor.
    model = data if isinstance(data, DatasetUpdateIn) else DatasetUpdateIn(**data)
    # Round-trip through the model's JSON dump so the payload is plain JSON data.
    payload = json.loads(model.model_dump_json())
    return self.http_patch(f"{self.api_version}/datasets/{dataset_id}", data=payload, output=output)

get_dataset

get_dataset(dataset_id: str, output: str = 'json')

Retrieves a dataset from the studio.

Parameters:

Name Type Description Default
dataset_id str

The unique identifier of the dataset to retrieve.

required
output str

The format of the response. Default is "json".

'json'

Returns:

Name Type Description
dict

Information about the dataset found.

Source code in geostudio/backends/v2/gtune/client.py
def get_dataset(self, dataset_id: str, output: str = "json"):
    """
    Retrieves a dataset from the studio.

    Args:
        dataset_id (str): The unique identifier of the dataset to retrieve.
        output (str, optional): The format of the response. Default is "json".

    Returns:
        dict: Information about the dataset found.
    """
    dataset_url = f"{self.api_version}/datasets/{dataset_id}"
    return self.http_get(dataset_url, output=output)

delete_dataset

delete_dataset(dataset_id: str, output: str = 'json')

Deletes a dataset with the given ID.

Parameters:

Name Type Description Default
dataset_id str

The ID of the dataset to delete.

required
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

A dictionary with a message after dataset is deleted

Source code in geostudio/backends/v2/gtune/client.py
def delete_dataset(self, dataset_id: str, output: str = "json"):
    """
    Deletes a dataset with the given ID.

    Args:
        dataset_id (str): The ID of the dataset to delete.
        output (str, optional): The format of the response. Defaults to "json".

    Returns:
        dict: A dictionary with a message after dataset is deleted
    """
    dataset_url = f"{self.api_version}/datasets/{dataset_id}"
    return self.http_delete(dataset_url, output=output)

onboard_dataset

onboard_dataset(
    data: DatasetOnboardIn, output: str = "json"
)

Onboards a new dataset to the Geospatial studio.

Parameters:

Name Type Description Default
data DatasetOnboardIn

The dataset information to be onboarded.

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

A dictionary containing information about the onboarded dataset.

Source code in geostudio/backends/v2/gtune/client.py
def onboard_dataset(self, data: DatasetOnboardIn, output: str = "json"):
    """
    Onboards a new dataset to the Geospatial studio.

    Args:
        data (DatasetOnboardIn | dict): The dataset information to be onboarded, given either as a
            ``DatasetOnboardIn`` instance or as a mapping of its fields.
        output (str, optional): The desired output format. Defaults to "json".

    Returns:
        dict: A dictionary containing information about the onboarded dataset.
    """
    # ``**`` unpacking needs a mapping, so a ready-made model (the annotated type) is used as-is
    # instead of being unpacked into the constructor.
    model = data if isinstance(data, DatasetOnboardIn) else DatasetOnboardIn(**data)
    # Round-trip through the model's JSON dump so the payload is plain JSON data.
    payload = json.loads(model.model_dump_json())
    return self.http_post(f"{self.api_version}/datasets/onboard", data=payload, output=output)

list_base_models

list_base_models(output: str = 'json')

Lists all available base foundation models.

Parameters:

Name Type Description Default
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

A dictionary containing a list of base foundation models available in the studio

Source code in geostudio/backends/v2/gtune/client.py
def list_base_models(self, output: str = "json"):
    """
    Lists all available base foundation models.

    Args:
        output (str, optional): The format of the response. Defaults to "json".

    Returns:
        dict: A dictionary containing a list of base foundation models available in the studio
    """
    # The listing is read from the "results" field of the response.
    base_models_url = f"{self.api_version}/base-models"
    return self.http_get(base_models_url, output=output, data_field="results")

create_base_model

create_base_model(data: BaseModelsIn, output: str = 'json')

Create a base foundation model in the Studio.

Parameters:

Name Type Description Default
output str

The format of the response. Defaults to "json".

'json'
data BaseModelsIn

Parameters for creating the base model.

required

Returns:

Name Type Description
dict

A dictionary containing the response to the base foundation model creation request

Source code in geostudio/backends/v2/gtune/client.py
def create_base_model(self, data: BaseModelsIn, output: str = "json"):
    """
    Create a base foundation model in the Studio.

    Args:
        data (BaseModelsIn): Parameters for creating the base model; forwarded unchanged
            as the request data.
        output (str, optional): The format of the response. Defaults to "json".

    Returns:
        dict: The response to the creation request, as returned by ``http_post``.
    """
    base_models_url = f"{self.api_version}/base-models"
    return self.http_post(base_models_url, data=data, output=output, data_field="results")

get_base_model

get_base_model(base_id: str, output: str = 'json')

Get base foundation model by id.

Parameters:

Name Type Description Default
base_id str

Base model ID

required
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The base model found.

Source code in geostudio/backends/v2/gtune/client.py
def get_base_model(self, base_id: str, output: str = "json"):
    """
    Get base foundation model by id.

    Args:
        base_id (str): Base model ID
        output (str, optional): The format of the response. Defaults to "json".

    Returns:
        dict: The base model found.
    """
    base_model_url = f"{self.api_version}/base-models/{base_id}"
    return self.http_get(base_model_url, output=output, data_field="results")

update_base_model_params

update_base_model_params(
    base_id: str,
    data: BaseModelParamsIn,
    output: str = "json",
)

Update base foundation model params.

Parameters:

Name Type Description Default
base_id str

Base model ID.

required
data BaseModelParamsIn

Base model params to update.

required
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The updated base model params.

Source code in geostudio/backends/v2/gtune/client.py
def update_base_model_params(self, base_id: str, data: BaseModelParamsIn, output: str = "json"):
    """
    Update base foundation model params.

    Args:
        base_id (str): Base model ID.
        data (BaseModelParamsIn): Base model params to update; forwarded unchanged
            as the request data.
        output (str, optional): The format of the response. Defaults to "json".

    Returns:
        dict: The response to the update request, as returned by ``http_patch``.
    """
    params_url = f"{self.api_version}/base-models/{base_id}/model-params"
    return self.http_patch(params_url, data=data, output=output, data_field="results")

poll_onboard_dataset_until_finished

poll_onboard_dataset_until_finished(
    dataset_id, poll_frequency=10
)

Polls the status of an onboard dataset until it finishes processing. The poll frequency has a minimum of 10 seconds; smaller values are raised to 10.

Parameters:

Name Type Description Default
dataset_id str

The unique identifier of the dataset being onboarded.

required
poll_frequency int

The time interval in seconds between polls. Values below 10 are raised to 10. Defaults to 10 seconds.

10

Returns:

Name Type Description
dict

The final status of the dataset, either "Succeeded" or "Failed".

Source code in geostudio/backends/v2/gtune/client.py
def poll_onboard_dataset_until_finished(self, dataset_id, poll_frequency=10):
    """
    Polls the status of an onboard dataset until it finishes processing.

    On every poll the dataset is fetched with ``get_dataset`` and its status is printed together
    with the time elapsed since the dataset's ``created_at`` timestamp.

    Args:
        dataset_id (str): The unique identifier of the dataset being onboarded.
        poll_frequency (int, optional): The time interval in seconds between polls. Values below
            10 are raised to 10. Defaults to 10 seconds.

    Returns:
        dict: The dataset record from the final poll, whose status is either "Succeeded" or "Failed".
    """
    # Enforce a minimum of 10 seconds between polls.
    poll_frequency = max(10, poll_frequency)

    while True:
        r = self.get_dataset(dataset_id)
        status = r["status"]
        # NOTE(review): any UTC offset parsed by %z is overwritten with UTC here rather than
        # converted — confirm the API always reports created_at in UTC.
        created_at = datetime.strptime(r["created_at"], "%Y-%m-%dT%H:%M:%S.%f%z").replace(tzinfo=timezone.utc)
        # total_seconds() counts whole days too; timedelta.seconds alone wraps every 24 hours.
        time_taken = int((datetime.now(timezone.utc) - created_at).total_seconds())
        progress = f"{status} - {time_taken} seconds"

        if status in ("Succeeded", "Failed"):
            print(progress)
            return r

        # Still running: overwrite the same console line on the next poll.
        print(progress, end="\r")
        sleep(poll_frequency)

poll_finetuning_until_finished

poll_finetuning_until_finished(tune_id, poll_frequency=10)

Polls the status of a tune until it finishes or fails.

Parameters:

Name Type Description Default
tune_id str

The unique identifier of the tune to poll.

required
poll_frequency int

The time interval in seconds between polls. Values below 10 are raised to 10. Defaults to 10 seconds.

10

Returns:

Name Type Description
dict

The final status of the tune, including details such as the number of epochs and any error messages if the tune failed.

Source code in geostudio/backends/v2/gtune/client.py
def poll_finetuning_until_finished(self, tune_id, poll_frequency=10):
    """
    Polls the status of a tune until it finishes or fails.

    On every poll the tune is fetched with ``get_tune`` and its status is printed together with the
    epoch count from ``get_tune_metrics`` and the time elapsed since the tune's ``created_at``
    timestamp. When the tune has failed, the ``logs_presigned_url`` of the tune is printed as well.

    Args:
        tune_id (str): The unique identifier of the tune to poll.
        poll_frequency (int, optional): The time interval in seconds between polls. Values below
            10 are raised to 10. Defaults to 10 seconds.

    Returns:
        dict: The tune record from the final poll, whose status is either "Finished" or "Failed".
    """
    # Enforce a minimum of 10 seconds between polls.
    poll_frequency = max(10, poll_frequency)

    while True:
        r = self.get_tune(tune_id)
        status = r["status"]
        # The trailing "Z" is matched literally, so the parsed value is naive and is tagged as UTC.
        created_at = datetime.strptime(r["created_at"], "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=timezone.utc)
        # total_seconds() counts whole days too; timedelta.seconds alone wraps every 24 hours.
        time_taken = int((datetime.now(timezone.utc) - created_at).total_seconds())

        # Metrics are best-effort: a failure to fetch them must not interrupt polling.
        try:
            m_epochs = self.get_tune_metrics(tune_id).get("epochs")
        except Exception:
            m_epochs = "Unknown"

        progress = f"{status} - Epoch: {m_epochs} - {time_taken} seconds"

        if status == "Finished":
            print(progress)
            return r

        if status == "Failed":
            print(progress)
            print("Download the logs from the link below:")
            print(r["logs_presigned_url"])
            return r

        # Still running: overwrite the same console line on the next poll.
        print(progress, end="\r")
        sleep(poll_frequency)