Skip to content

Inference

geostudio.backends.v2.ginference.client

Client

Client(
    api_config: GeoFmSettings = None,
    session: Session = None,
    api_token: str = None,
    api_key: str = None,
    api_key_file: str = None,
    geostudio_config_file: str = None,
    *args,
    **kwargs
)

Bases: BaseClient

A client for interacting with the Geospatial Studio inference API endpoints.

Source code in geostudio/backends/base_client.py
def __init__(
    self,
    api_config: GeoFmSettings = None,
    session: requests.Session = None,
    api_token: str = None,
    api_key: str = None,
    api_key_file: str = None,
    geostudio_config_file: str = None,
    *args,
    **kwargs,
):
    """
    Initializes the GeoFmClient with the provided configuration.

    Credentials are resolved in the following order; the first one supplied wins:

    1. ``api_token``
    2. ``api_key``
    3. ``api_key_file``
    4. ``geostudio_config_file``
    5. ``session`` (a pre-configured session, used as-is)
    6. ``GEOFM_API_TOKEN`` read from ``api_config``

    Args:
        api_config (GeoFmSettings, optional): The configuration settings for the GeoFm API.
            Defaults to a fresh ``GeoFmSettings()``.
        session (requests.Session, optional): A pre-configured requests session, used when no
            explicit credential argument is given. Defaults to None.
        api_token (str, optional): The API token for authentication. Defaults to None.
        api_key (str, optional): The API key for authentication. Defaults to None.
        api_key_file (str, optional): The path to the file containing the API key. Defaults to None.
        geostudio_config_file (str, optional): The file path to the geostudio config file containing
            the api_key and base URLs. Defaults to None.
        *args: Variable length argument list (unused).
        **kwargs: Arbitrary keyword arguments (unused).

    Raises:
        GeoFMException: If ``api_key_file`` or ``geostudio_config_file`` points to a missing file,
            or if none of the sources above yields credentials.

    Attributes:
        api_config (GeoFmSettings): The configuration settings for the GeoFm API.
        session (requests.Session): The authenticated session used for API calls.
        logger (logging.Logger): The logger instance for logging messages.
    """
    self.api_config = api_config or GeoFmSettings()

    if api_token:
        print("Using api_token")
        self.session = gfm_session(access_token=api_token)
    elif api_key:
        print("Using api_key from sdk command")
        self.session = gfm_session(api_key=api_key)
    elif api_key_file:
        if not os.path.isfile(api_key_file):
            raise GeoFMException("Config file does not exist, Please provide a valid config file.")
        print("Using api_key from file")
        self.session = gfm_session(api_key_file=api_key_file)
    elif geostudio_config_file:
        if not os.path.isfile(geostudio_config_file):
            raise GeoFMException("Config file does not exist, Please provide a valid config file.")
        print("Using api key and base urls from geostudio config file")
        geostudio_config_file_values = dotenv_values(geostudio_config_file)
        # NOTE(review): these values are written to a `settings` object defined outside this
        # method, not to `self.api_config`, so they presumably affect every client in the
        # process — confirm this is intended.
        settings.BASE_GATEWAY_API_URL = geostudio_config_file_values.get("BASE_GATEWAY_API_URL", "")
        settings.BASE_STUDIO_UI_URL = geostudio_config_file_values.get("BASE_STUDIO_UI_URL", "")
        settings.GEOSTUDIO_API_KEY = geostudio_config_file_values.get("GEOSTUDIO_API_KEY", None)
        self.session = gfm_session(api_key=settings.GEOSTUDIO_API_KEY)
    elif session is not None:
        # A caller-supplied session was previously accepted but silently ignored.
        self.session = session
    else:
        # Fall back to the token from settings, as the error message below instructs.
        # (The old `api_token or GeoFmSettings.GEOFM_API_TOKEN` fallback was unreachable
        # because it only ran when `api_token` was already truthy.)
        env_token = getattr(self.api_config, "GEOFM_API_TOKEN", None)
        if not env_token:
            raise GeoFMException("Missing APIToken. Add `GEOFM_API_TOKEN` to env variables.")
        print("Using api_token")
        self.session = gfm_session(access_token=env_token)

    self.logger = logging.getLogger()

create_model

create_model(data: ModelCreateInput, output: str = 'json')

Creates a new model using the provided data.

Parameters:

Name Type Description Default
data ModelCreateInput

The input data required to create a new model.

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The response containing the created model Metadata.

Source code in geostudio/backends/v2/ginference/client.py
def create_model(self, data: ModelCreateInput, output: str = "json"):
    """
    Creates a new model using the provided data.

    Args:
        data (ModelCreateInput | dict): The input data required to create a new model, either as a
            ``ModelCreateInput`` instance or as a mapping of its fields.
        output (str, optional): The desired output format. Defaults to "json".

    Returns:
        dict: The response containing the created model metadata.
    """
    # Accept a ready-made model (as the annotation advertises) as well as a plain mapping;
    # `ModelCreateInput(**instance)` would raise TypeError because a model is not a mapping.
    model = data if isinstance(data, ModelCreateInput) else ModelCreateInput(**data)
    # Round-trip through JSON so the payload contains only JSON-native values.
    payload = json.loads(model.model_dump_json())
    response = self.http_post(f"{self.api_version}/models", data=payload, output=output, data_field="results")
    return response

list_models

list_models(output: str = 'json')

Lists all available models.

Parameters:

Name Type Description Default
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

A dictionary containing the list of models.

Source code in geostudio/backends/v2/ginference/client.py
def list_models(self, output: str = "json", *, limit: int = 1000, skip: int = 0):
    """
    Lists available models.

    Args:
        output (str, optional): The format of the response. Defaults to "json".
        limit (int, optional): Maximum number of models to request. Defaults to 1000.
        skip (int, optional): Number of models to skip, for paging. Defaults to 0.

    Returns:
        dict: A dictionary containing the list of models.
    """
    response = self.http_get(
        endpoint=f"{self.api_version}/models?limit={limit}&skip={skip}", output=output, data_field="results"
    )
    return response

update_model

update_model(
    model_id: UUID,
    data: ModelUpdateInput,
    output: str = "json",
)

Updates metadata of a specified model.

Parameters:

Name Type Description Default
model_id UUID

The unique identifier of the model to be updated.

required
data dict

A dictionary containing the new metadata for the model.

required
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The response from the server containing the updated metadata.

Source code in geostudio/backends/v2/ginference/client.py
def update_model(self, model_id: UUID, data: ModelUpdateInput, output: str = "json"):
    """
    Updates metadata of a specified model.

    Args:
        model_id (UUID): The unique identifier of the model to be updated.
        data (ModelUpdateInput | dict): The new metadata for the model, either as a
            ``ModelUpdateInput`` instance or as a mapping of its fields.
        output (str, optional): The format of the response. Defaults to "json".

    Returns:
        dict: The response from the server containing the updated metadata.
    """
    # Accept a ready-made model as well as a plain mapping; unpacking a model with `**` fails.
    model = data if isinstance(data, ModelUpdateInput) else ModelUpdateInput(**data)
    # Round-trip through JSON so the payload contains only JSON-native values.
    payload = json.loads(model.model_dump_json())
    response = self.http_patch(
        f"{self.api_version}/models/{model_id}", data=payload, output=output, data_field="results"
    )
    return response

deploy_model

deploy_model(
    model_id: str,
    data: ModelOnboardingInputSchema,
    output="json",
)

Deploys a model

Parameters:

Name Type Description Default
model_id str

The unique identifier of the model to be deployed

required
data ModelOnboardingInputSchema

Urls to the model checkpoint and configs

required
Source code in geostudio/backends/v2/ginference/client.py
def deploy_model(self, model_id: str, data: ModelOnboardingInputSchema, output="json"):
    """
    Deploys a model.

    Args:
        model_id (str): The unique identifier of the model to be deployed.
        data (ModelOnboardingInputSchema | dict): URLs to the model checkpoint and configs, either as a
            ``ModelOnboardingInputSchema`` instance or as a mapping of its fields.
        output (str, optional): The desired output format. Defaults to "json".

    Returns:
        dict: The server's response to the deploy request.
    """
    # Accept a ready-made schema object as well as a plain mapping; unpacking a model with `**` fails.
    schema = data if isinstance(data, ModelOnboardingInputSchema) else ModelOnboardingInputSchema(**data)
    # Round-trip through JSON so the payload contains only JSON-native values.
    payload = json.loads(schema.model_dump_json())
    response = self.http_post(
        f"{self.api_version}/models/{model_id}/deploy", data=payload, output=output, data_field="results"
    )
    return response

get_model

get_model(model_id: UUID, output: str = 'json')

Retrieves a model's information using its ID.

Parameters:

Name Type Description Default
model_id UUID

The unique identifier of the model to retrieve.

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The model's status and information

Source code in geostudio/backends/v2/ginference/client.py
def get_model(self, model_id: UUID, output: str = "json"):
    """
    Fetch a single model by its identifier.

    Args:
        model_id (UUID): Unique identifier of the model to look up.
        output (str, optional): Desired output format. Defaults to "json".

    Returns:
        dict: The model's status and information.
    """
    model_endpoint = f"{self.api_version}/models/{model_id}"
    return self.http_get(model_endpoint, output=output, data_field="results")

delete_model

delete_model(model_id: str, output: str = 'json')

Deletes a specified model using its ID.

Parameters:

Name Type Description Default
model_id str

The ID of the model to be deleted.

required
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The response from the server after deleting the model.

Source code in geostudio/backends/v2/ginference/client.py
def delete_model(self, model_id: str, output: str = "json"):
    """
    Remove a model, identified by its ID, from the studio.

    Args:
        model_id (str): Identifier of the model to remove.
        output (str, optional): Desired response format. Defaults to "json".

    Returns:
        dict: The server's reply to the delete request.
    """
    target = f"{self.api_version}/models/{model_id}"
    return self.http_delete(target, output=output)

submit_inference

submit_inference(
    data: InferenceCreateInput, output: str = "json"
)

Submits an inference task to the server.

Parameters:

Name Type Description Default
data InferenceCreateInput

The input data for the inference task.

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The server's response containing the results of the inference task.

Source code in geostudio/backends/v2/ginference/client.py
def submit_inference(self, data: InferenceCreateInput, output: str = "json"):
    """
    Submits an inference task to the server.

    Args:
        data (InferenceCreateInput | dict): The input data for the inference task, either as an
            ``InferenceCreateInput`` instance or as a mapping of its fields.
        output (str, optional): The desired output format. Defaults to "json".

    Returns:
        dict: The server's response containing the results of the inference task.
    """
    # Accept a ready-made model as well as a plain mapping; unpacking a model with `**` fails.
    request = data if isinstance(data, InferenceCreateInput) else InferenceCreateInput(**data)
    # Round-trip through JSON so the payload contains only JSON-native values.
    payload = json.loads(request.model_dump_json())
    response = self.http_post(f"{self.api_version}/inference", data=payload, output=output, data_field="results")
    return response

list_inferences

list_inferences(output: str = 'json')

Lists inferences submitted to the Studio. Limit to most recent 10.

Parameters:

Name Type Description Default
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

A dictionary containing a list of inference tasks submitted to the studio

Source code in geostudio/backends/v2/ginference/client.py
def list_inferences(self, output: str = "json", *, limit: int = 10, skip: int = 0):
    """
    Lists inferences submitted to the Studio, 10 per request by default.

    Args:
        output (str, optional): The desired output format. Defaults to "json".
        limit (int, optional): Maximum number of inferences to request. Defaults to 10.
        skip (int, optional): Number of inferences to skip, for paging. Defaults to 0.

    Returns:
        dict: A dictionary containing a list of inference tasks submitted to the studio
    """
    response = self.http_get(
        f"{self.api_version}/inference?limit={limit}&skip={skip}", output=output, data_field="results"
    )
    return response

get_inference

get_inference(inference_id: UUID, output: str = 'json')

Retrieves the inference with the given inference ID.

Parameters:

Name Type Description Default
inference_id UUID

The unique identifier of the inference task.

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The inference task data in the specified output format.

Source code in geostudio/backends/v2/ginference/client.py
def get_inference(self, inference_id: UUID, output: str = "json"):
    """
    Look up a single inference by its identifier.

    Args:
        inference_id (uuid.UUID): Unique identifier of the inference.
        output (str, optional): Desired output format. Defaults to "json".

    Returns:
        dict: The inference record in the requested output format.
    """
    inference_endpoint = f"{self.api_version}/inference/{inference_id}"
    return self.http_get(inference_endpoint, output=output, data_field="results")

delete_inference

delete_inference(inference_id: UUID, output: str = 'json')

Deletes an inference using its ID.

Parameters:

Name Type Description Default
inference_id UUID

The ID of the inference to be deleted.

required
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The response from the server after deleting the inference.

Source code in geostudio/backends/v2/ginference/client.py
def delete_inference(self, inference_id: UUID, output: str = "json"):
    """
    Remove an inference, identified by its ID.

    Args:
        inference_id (uuid.UUID): Identifier of the inference to remove.
        output (str, optional): Desired response format. Defaults to "json".

    Returns:
        dict: The server's reply to the delete request.
    """
    target = f"{self.api_version}/inference/{inference_id}"
    return self.http_delete(target, output=output)

get_inference_tasks

get_inference_tasks(
    inference_id: UUID, output: str = "json"
)

Retrieves the tasks associated with an inference.

Parameters:

Name Type Description Default
inference_id UUID

The unique identifier of the inference.

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The inference task data in the specified output format.

Source code in geostudio/backends/v2/ginference/client.py
def get_inference_tasks(self, inference_id: UUID, output: str = "json"):
    """
    Fetch the tasks that belong to an inference.

    Args:
        inference_id (uuid.UUID): Unique identifier of the inference.
        output (str, optional): Desired output format. Defaults to "json".

    Returns:
        dict: The inference's tasks in the requested output format.
    """
    tasks_endpoint = f"{self.api_version}/inference/{inference_id}/tasks"
    return self.http_get(tasks_endpoint, output=output, data_field="results")

get_task_output_url

get_task_output_url(task_id: UUID, output: str = 'json')

Retrieves the output url for a specific inference task.

Parameters:

Name Type Description Default
task_id UUID

The unique identifier of the task.

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The inference task data in the specified output format.

Source code in geostudio/backends/v2/ginference/client.py
def get_task_output_url(self, task_id: UUID, output: str = "json"):
    """
    Fetch the output URL(s) for one inference task.

    Args:
        task_id (UUID): Unique identifier of the task.
        output (str, optional): Desired output format. Defaults to "json".

    Returns:
        dict: The server's response for the task-output endpoint.
    """
    output_endpoint = f"{self.api_version}/tasks/{task_id}/output"
    return self.http_get(output_endpoint, output=output, data_field="results")

get_task_step_logs

get_task_step_logs(
    task_id: UUID, step_id: str, output: str = "json"
)

Retrieves the logs for a specific step of an inference task.

Parameters:

Name Type Description Default
task_id UUID

The unique identifier of the task.

required
step_id str

The identifier of the step within the task whose logs are retrieved.

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The inference task data in the specified output format.

Source code in geostudio/backends/v2/ginference/client.py
def get_task_step_logs(self, task_id: UUID, step_id: str, output: str = "json"):
    """
    Fetch the logs produced by one step of an inference task.

    Args:
        task_id (UUID): Unique identifier of the task.
        step_id (str): Identifier of the step whose logs are requested.
        output (str, optional): Desired output format. Defaults to "json".

    Returns:
        dict: The server's response for the step-logs endpoint.
    """
    logs_endpoint = f"{self.api_version}/tasks/{task_id}/logs/{step_id}"
    return self.http_get(logs_endpoint, output=output, data_field="results")

check_data_availability

check_data_availability(
    datasource: str,
    data: DataAdvisorIn,
    output: str = "json",
)

Query data-advisor service to check data availability before running an inference.

Parameters:

Name Type Description Default
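datasource str

The data source to query; used as the final path segment of the data-advice endpoint.

required
data dict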

A dictionary containing the necessary parameters for the data availability check.

required
output str

The desired output format. Default is "json".

'json'

Returns:

Name Type Description
dict

The response from the server containing the data availability information.

Source code in geostudio/backends/v2/ginference/client.py
def check_data_availability(self, datasource: str, data: DataAdvisorIn, output: str = "json"):
    """
    Query data-advisor service to check data availability before running an inference.

    Args:
        datasource (str): The data source to query; used as the final path segment of the
            ``data-advice`` endpoint.
        data (DataAdvisorIn | dict): Parameters for the availability check, either as a
            ``DataAdvisorIn`` instance or as a mapping of its fields.
        output (str, optional): The desired output format. Default is "json".

    Returns:
        dict: The response from the server containing the data availability information.
    """
    # Accept a ready-made model as well as a plain mapping; unpacking a model with `**` fails.
    query = data if isinstance(data, DataAdvisorIn) else DataAdvisorIn(**data)
    # Round-trip through JSON so the payload contains only JSON-native values.
    payload = json.loads(query.model_dump_json())
    response = self.http_post(
        f"{self.api_version}/data-advice/{datasource}", data=payload, output=output, data_field="results"
    )
    return response

list_datasource_collections

list_datasource_collections(
    datasource: str, output: str = "json"
)

Query data-advisor to list collections available for a specific data source.

Source code in geostudio/backends/v2/ginference/client.py
def list_datasource_collections(self, datasource: str, output: str = "json"):
    """
    Query data-advisor to list collections available for a specific data source.

    Args:
        datasource (str): The data source whose collections are listed; used as the final
            path segment of the ``data-advice`` endpoint.
        output (str, optional): Desired output format. Defaults to "json".

    Returns:
        The server's response from the ``data-advice/{datasource}`` endpoint.
    """
    advice_endpoint = f"{self.api_version}/data-advice/{datasource}"
    return self.http_get(advice_endpoint, output=output, data_field="results")

list_datasource

list_datasource(
    connector: str = None,
    collection: str = None,
    limit: int = 25,
    skip: int = 0,
    output: str = "json",
)

Lists all data sources available in the studio.

Parameters:

Name Type Description Default
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

A dictionary containing a list of data sources available through the studio

Source code in geostudio/backends/v2/ginference/client.py
def list_datasource(
    self, connector: str = None, collection: str = None, limit: int = 25, skip: int = 0, output: str = "json"
):
    """
    Lists data sources available in the studio.

    Args:
        connector (str, optional): Value sent as the ``connector`` query filter; omitted when falsy.
            Defaults to None.
        collection (str, optional): Value sent as the ``collection`` query filter; omitted when falsy.
            Defaults to None.
        limit (int, optional): Maximum number of data sources to request. Defaults to 25.
        skip (int, optional): Number of data sources to skip, for paging. Defaults to 0.
        output (str, optional): The format of the response. Defaults to "json".

    Returns:
        dict: A dictionary containing a list of data sources available through the studio
    """
    from urllib.parse import quote, urlencode

    params = [("limit", limit), ("skip", skip)]
    if connector:
        params.append(("connector", connector))
    if collection:
        params.append(("collection", collection))
    # Percent-encode every value so characters such as `&`, `#`, `+` or spaces in a filter
    # cannot split or truncate the query string.
    query = urlencode(params, quote_via=quote)
    response = self.http_get(f"{self.api_version}/data-sources?{query}", output=output, data_field="results")
    return response

get_datasource

get_datasource(datasource_id: UUID, output: str = 'json')

Retrieves a specific data source's information.

Parameters:

Name Type Description Default
datasource_id UUID

The unique identifier of the data source to retrieve.

required
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The response from the server containing the data source details.

Source code in geostudio/backends/v2/ginference/client.py
def get_datasource(self, datasource_id: UUID, output: str = "json"):
    """
    Retrieves a specific data source's information.

    Pages through :py:meth:`list_datasource` until an entry whose ``id`` matches is found,
    or until the listing is exhausted.

    Args:
        datasource_id (UUID | str): The unique identifier of the data source to retrieve.
            Compared as a string, so both ``UUID`` objects and plain strings work.
        output (str, optional): Currently unused; kept for interface compatibility. Defaults to "json".

    Returns:
        list: The matching data-source entries (an empty list if none match).
    """
    # The listing returns ids as strings, so a UUID argument must be stringified before comparing.
    target = str(datasource_id)
    seen_ids = set()
    skip = 0
    while True:
        page = self.list_datasource(skip=skip).get("results") or []
        page_ids = [str(item.get("id")) for item in page]
        # Stop when the listing is exhausted, or when a page brings nothing new
        # (guards against looping forever if the server ignores `skip`).
        if not page or seen_ids.issuperset(page_ids):
            return []
        matches = [item for item, item_id in zip(page, page_ids) if item_id == target]
        if matches:
            return matches
        seen_ids.update(page_ids)
        # Advance by what was actually returned, in case the server caps the page size.
        skip += len(page)
get_fileshare_links

get_fileshare_links(object_name: str)

Generates presigned URLs for sharing files, i.e. for uploading and downloading them.

Parameters:

Name Type Description Default
object_name str

The name of the object (file) for which to generate upload links.

required

Returns:

Name Type Description
dict

A dictionary containing the upload links.

Source code in geostudio/backends/v2/ginference/client.py
def get_fileshare_links(self, object_name: str):
    """
    Generate presigned URLs for sharing files, i.e. for uploading and downloading them.

    Args:
        object_name (str): The name of the object (file) for which to generate links.

    Returns:
        dict: A dictionary containing the generated links.
    """
    from urllib.parse import quote, urlencode

    print("Going to generate the upload link")
    # Percent-encode the name: a raw `#` would truncate the URL, and `&`/`+`/spaces would
    # corrupt the query string.
    query = urlencode({"object_name": object_name}, quote_via=quote)
    response = self.http_get(f"{self.api_version}/file-share?{query}", output="json", data_field="results")
    return response

upload_file_to_url

upload_file_to_url(upload_url: str, filepath: str)

Uploads a file to a specified URL using a PUT request with a rich progress bar.

Parameters:

Name Type Description Default
upload_url str

The URL to which the file will be uploaded.

required
filepath str

The path to the file that will be uploaded.

required

Returns:

Type Description

requests.Response: The response from the server after the file upload.

Source code in geostudio/backends/v2/ginference/client.py
def upload_file_to_url(self, upload_url: str, filepath: str):
    """
    Upload a local file to ``upload_url`` with an HTTP PUT, showing a rich progress bar.

    Args:
        upload_url (str): Destination URL for the PUT request.
        filepath (str): Path of the local file to send.

    Returns:
        requests.Response: The server's response to the upload.
    """
    print("Going to upload the file to the url.")

    source = Path(filepath)
    size_bytes = source.stat().st_size
    guessed_type = mimetypes.guess_type(source.name)[0]
    headers = {
        "Content-Type": guessed_type or "application/octet-stream",
        # Sent explicitly; presumably requests cannot infer the length from the progress
        # wrapper object — TODO confirm.
        "Content-Length": str(size_bytes),
    }

    columns = (
        TextColumn("[bold black]{task.description}"),
        BarColumn(),
        DownloadColumn(),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
    )
    with Progress(*columns) as progress:
        bar = progress.add_task(description=source.name, total=size_bytes)
        with open(filepath, "rb") as handle:
            tracked_body = RichProgressWrapper(handle, progress, bar)
            # NOTE(review): verify=False disables TLS certificate verification for this upload;
            # confirm this is required (e.g. self-signed on-prem endpoints) rather than accidental.
            response = requests.put(upload_url, data=tracked_body, headers=headers, verify=False)
    return response

upload_file

upload_file(filename: str)

Combines get_fileshare_links and upload_file_to_url: uploads a file using the presigned upload link generated for it.

Source code in geostudio/backends/v2/ginference/client.py
def upload_file(self, filename: str):
    """
    Upload a local file through the studio's file-share service.

    Combines :py:meth:`get_fileshare_links` and :py:meth:`upload_file_to_url`: requests
    presigned links for the file's base name, then PUTs the file to the upload link if one
    was returned.

    Args:
        filename (str): Path of the local file; its last ``/``-separated component is used as
            the object name.

    Returns:
        dict: The links returned by :py:meth:`get_fileshare_links`, whether or not an upload
        was attempted.
    """
    object_name = filename.split("/")[-1]
    links = self.get_fileshare_links(object_name=object_name)
    upload_url = links.get("upload_url", None)
    if not upload_url:
        return links

    response = self.upload_file_to_url(upload_url, filename)
    if response.status_code == 200:
        print(f"\n[bold green]✓ Verified:[/bold green] {filename} uploaded.")
    else:
        print(f"\n[bold red]✗ Failed:[/bold red] Status {response.status_code}")
    return links

poll_inference_until_finished

poll_inference_until_finished(
    inference_id, poll_frequency=10
)

Polls the status of an inference task until it is completed, failed or stopped. Poll intervals below 10 seconds are raised to 10 seconds.

Parameters:

Name Type Description Default
inference_id str

The unique identifier of the inference task.

required
poll_frequency int

The time interval in seconds between polls. Defaults to 10 seconds; values below 10 are raised to 10.

10

Returns:

Name Type Description
dict

The response from the inference task when it is completed or failed.

Source code in geostudio/backends/v2/ginference/client.py
def poll_inference_until_finished(self, inference_id, poll_frequency=10):
    """
    Polls the status of an inference until it is completed, failed or stopped.

    Args:
        inference_id (str): The unique identifier of the inference.
        poll_frequency (int, optional): Seconds to wait between polls. Values below 10 are
            raised to 10. Defaults to 10.

    Returns:
        dict: The record returned by :py:meth:`get_inference` once its status is terminal
        (contains "COMPLETED", or equals "FAILED" or "STOPPED").
    """
    poll_frequency = max(poll_frequency, 10)

    while True:
        r = self.get_inference(inference_id)
        status = r["status"]

        created_raw = r["created_at"]
        try:
            created_at = datetime.strptime(created_raw, "%Y-%m-%dT%H:%M:%S.%fZ")
        except ValueError:
            # Timestamps on a whole second may be serialised without a fractional part.
            created_at = datetime.strptime(created_raw, "%Y-%m-%dT%H:%M:%SZ")
        created_at = created_at.replace(tzinfo=timezone.utc)
        # total_seconds(): `timedelta.seconds` is only the seconds *component* and wraps every 24 h.
        time_taken = int((datetime.now(timezone.utc) - created_at).total_seconds())
        message = f"{status} - {time_taken} seconds"

        # Any status containing "COMPLETED" is terminal, as are exactly FAILED and STOPPED.
        if "COMPLETED" in status or status in ("FAILED", "STOPPED"):
            print(message)
            return r

        # Still running: overwrite the same console line on the next poll.
        print(message, end="\r")
        sleep(poll_frequency)