Skip to content

Inference

geostudio.backends.v2.ginference.client

Client

Client(
    api_config: GeoFmSettings = None,
    session: Session = None,
    api_token: str = None,
    api_key: str = None,
    api_key_file: str = None,
    geostudio_config_file: str = None,
    *args,
    **kwargs
)

Bases: BaseClient

A client for interacting with the Geospatial Studio inference API endpoints

Source code in geostudio/backends/base_client.py
def __init__(
    self,
    api_config: GeoFmSettings = None,
    session: requests.Session = None,
    api_token: str = None,
    api_key: str = None,
    api_key_file: str = None,
    geostudio_config_file: str = None,
    *args,
    **kwargs,
):
    """
    Initializes the GeoFmClient with the provided configuration.

    Exactly one authentication source is used, checked in this order: ``api_token``,
    ``api_key``, ``api_key_file``, ``geostudio_config_file``, and finally a
    caller-supplied ``session``.

    Args:
        api_config (GeoFmSettings, optional): The configuration settings for the GeoFm API.
            Defaults to a fresh ``GeoFmSettings()``.
        session (requests.Session, optional): A pre-configured requests session. Used only
            when none of the credential arguments below is given. Defaults to None.
        api_token (str, optional): Access token, passed to ``gfm_session(access_token=...)``.
            Defaults to None.
        api_key (str, optional): API key, passed to ``gfm_session(api_key=...)``. Defaults to None.
        api_key_file (str, optional): Path to an existing file containing the API key.
            Defaults to None.
        geostudio_config_file (str, optional): Path to an existing dotenv-style file providing
            ``GEOSTUDIO_API_KEY``, ``BASE_GATEWAY_API_URL`` and ``BASE_STUDIO_UI_URL``.
            Defaults to None.
        *args: Accepted but not used here.
        **kwargs: Accepted but not used here.

    Raises:
        GeoFMException: If ``api_key_file`` or ``geostudio_config_file`` is given but is not an
            existing file, or if no credential argument and no ``session`` is provided.

    Attributes:
        api_config (GeoFmSettings): The configuration settings for the GeoFm API.
        session (requests.Session): The session used for API requests.
        logger (logging.Logger): The root logger (``logging.getLogger()``).
    """
    self.api_config = api_config or GeoFmSettings()

    if api_token:
        print("Using api_token")
        self.session = gfm_session(access_token=api_token)
    elif api_key:
        print("Using api_key from sdk command")
        self.session = gfm_session(api_key=api_key)
    elif api_key_file:
        if not os.path.isfile(api_key_file):
            raise GeoFMException("Config file does not exist, Please provide a valid config file.")
        print("Using api_key from file")
        self.session = gfm_session(api_key_file=api_key_file)
    elif geostudio_config_file:
        if not os.path.isfile(geostudio_config_file):
            raise GeoFMException("Config file does not exist, Please provide a valid config file.")
        print("Using api key and base urls from geostudio config file")
        geostudio_config_file_values = dotenv_values(geostudio_config_file)
        # NOTE(review): ``settings`` is not defined in this function (presumably a shared
        # module-level object), so these assignments likely affect every client in the
        # process, not only this instance — confirm this is intended.
        settings.BASE_GATEWAY_API_URL = geostudio_config_file_values.get("BASE_GATEWAY_API_URL", "")
        settings.BASE_STUDIO_UI_URL = geostudio_config_file_values.get("BASE_STUDIO_UI_URL", "")
        settings.GEOSTUDIO_API_KEY = geostudio_config_file_values.get("GEOSTUDIO_API_KEY", None)
        self.session = gfm_session(api_key=settings.GEOSTUDIO_API_KEY)
    elif session is not None:
        # Honour the documented ``session`` argument instead of rejecting the call.
        self.session = session
    else:
        raise GeoFMException("Missing APIToken. Add `GEOFM_API_TOKEN` to env variables.")

    self.logger = logging.getLogger()

create_model

create_model(data: ModelCreateInput, output: str = 'json')

Creates a new model using the provided data.

Parameters:

Name Type Description Default
data ModelCreateInput

The input data required to create a new model.

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The response containing the created model Metadata.

Source code in geostudio/backends/v2/ginference/client.py
def create_model(self, data: ModelCreateInput, output: str = "json"):
    """
    Creates a new model using the provided data.

    Args:
        data (ModelCreateInput | dict): The input data required to create a new model, either
            as a ``ModelCreateInput`` instance or as a mapping of its fields.
        output (str, optional): The desired output format. Defaults to "json".

    Returns:
        dict: The response containing the created model metadata.
    """
    # ``ModelCreateInput(**data)`` only works for mappings; a pydantic model instance (the
    # annotated type) is not a mapping, so use it directly instead of re-constructing it.
    model = data if isinstance(data, ModelCreateInput) else ModelCreateInput(**data)
    # Round-trip through JSON so non-JSON-native field values become plain JSON types.
    payload = json.loads(model.model_dump_json())
    response = self.http_post(f"{self.api_version}/models", data=payload, output=output, data_field="results")
    return response

list_models

list_models(output: str = 'json')

Lists all available models.

Parameters:

Name Type Description Default
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

A dictionary containing the list of models.

Source code in geostudio/backends/v2/ginference/client.py
def list_models(self, output: str = "json", limit: int = 1000, skip: int = 0):
    """
    Lists available models.

    Args:
        output (str, optional): The format of the response. Defaults to "json".
        limit (int, optional): Value sent as the ``limit`` query parameter. Defaults to 1000.
        skip (int, optional): Value sent as the ``skip`` query parameter. Defaults to 0.

    Returns:
        dict: A dictionary containing the list of models.
    """
    response = self.http_get(
        endpoint=f"{self.api_version}/models?limit={limit}&skip={skip}", output=output, data_field="results"
    )
    return response

update_model

update_model(
    model_id: UUID,
    data: ModelUpdateInput,
    output: str = "json",
)

Updates metadata of a specified model.

Parameters:

Name Type Description Default
model_id UUID

The unique identifier of the model to be updated.

required
data dict

A dictionary containing the new metadata for the model.

required
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The response from the server containing the updated metadata.

Source code in geostudio/backends/v2/ginference/client.py
def update_model(self, model_id: UUID, data: ModelUpdateInput, output: str = "json"):
    """
    Updates metadata of a specified model.

    Args:
        model_id (UUID): The unique identifier of the model to be updated.
        data (ModelUpdateInput | dict): The new metadata for the model, either as a
            ``ModelUpdateInput`` instance or as a mapping of its fields.
        output (str, optional): The format of the response. Defaults to "json".

    Returns:
        dict: The response from the server containing the updated metadata.
    """
    # A pydantic model instance is not a mapping, so ``ModelUpdateInput(**data)`` would fail
    # for the annotated type; use an existing instance as-is.
    model = data if isinstance(data, ModelUpdateInput) else ModelUpdateInput(**data)
    # Round-trip through JSON so non-JSON-native field values become plain JSON types.
    payload = json.loads(model.model_dump_json())
    response = self.http_patch(
        f"{self.api_version}/models/{model_id}", data=payload, output=output, data_field="results"
    )
    return response

deploy_model

deploy_model(
    model_id: str,
    data: ModelOnboardingInputSchema,
    output="json",
)

Deploys a model

Parameters:

Name Type Description Default
model_id str

The unique identifier of the model to be deployed

required
data ModelOnboardingInputSchema

Urls to the model checkpoint and configs

required
Source code in geostudio/backends/v2/ginference/client.py
def deploy_model(self, model_id: str, data: ModelOnboardingInputSchema, output="json"):
    """
    Deploys a model.

    Args:
        model_id (str): The unique identifier of the model to be deployed.
        data (ModelOnboardingInputSchema | dict): URLs to the model checkpoint and configs,
            either as a ``ModelOnboardingInputSchema`` instance or as a mapping of its fields.
        output (str, optional): The desired output format. Defaults to "json".

    Returns:
        The server's response to the deploy request, in the requested output format.
    """
    # A pydantic model instance is not a mapping, so ``Schema(**data)`` would fail for the
    # annotated type; use an existing instance as-is.
    if isinstance(data, ModelOnboardingInputSchema):
        model = data
    else:
        model = ModelOnboardingInputSchema(**data)
    # Round-trip through JSON so non-JSON-native field values become plain JSON types.
    payload = json.loads(model.model_dump_json())
    response = self.http_post(
        f"{self.api_version}/models/{model_id}/deploy", data=payload, output=output, data_field="results"
    )
    return response

get_model

get_model(model_id: UUID, output: str = 'json')

Retrieves a model's information using its ID.

Parameters:

Name Type Description Default
model_id UUID

The unique identifier of the model to retrieve.

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The model's status and information

Source code in geostudio/backends/v2/ginference/client.py
def get_model(self, model_id: UUID, output: str = "json"):
    """
    Fetch a single model by its identifier.

    Args:
        model_id (UUID): Identifier of the model to look up.
        output (str, optional): Format requested for the response. Defaults to "json".

    Returns:
        dict: The model's status and information.
    """
    model_endpoint = f"{self.api_version}/models/{model_id}"
    return self.http_get(model_endpoint, output=output, data_field="results")

delete_model

delete_model(model_id: str, output: str = 'json')

Deletes a specified model using its ID.

Parameters:

Name Type Description Default
model_id str

The ID of the model to be deleted.

required
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The response from the server after deleting the model.

Source code in geostudio/backends/v2/ginference/client.py
def delete_model(self, model_id: str, output: str = "json"):
    """
    Remove a model identified by ``model_id``.

    Args:
        model_id (str): Identifier of the model to delete.
        output (str, optional): Format requested for the response. Defaults to "json".

    Returns:
        dict: The server's reply to the deletion request.
    """
    models_root = f"{self.api_version}/models"
    return self.http_delete(f"{models_root}/{model_id}", output=output)

submit_inference

submit_inference(
    data: InferenceCreateInput, output: str = "json"
)

Submits an inference task to the server.

Parameters:

Name Type Description Default
data InferenceCreateInput

The input data for the inference task.

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The server's response containing the results of the inference task.

Source code in geostudio/backends/v2/ginference/client.py
def submit_inference(self, data: InferenceCreateInput, output: str = "json"):
    """
    Submits an inference task to the server.

    Args:
        data (InferenceCreateInput | dict): The input data for the inference task, either as an
            ``InferenceCreateInput`` instance or as a mapping of its fields.
        output (str, optional): The desired output format. Defaults to "json".

    Returns:
        dict: The server's response for the submitted inference task.
    """
    # A pydantic model instance is not a mapping, so ``InferenceCreateInput(**data)`` would
    # fail for the annotated type; use an existing instance as-is.
    model = data if isinstance(data, InferenceCreateInput) else InferenceCreateInput(**data)
    # Round-trip through JSON so non-JSON-native field values become plain JSON types.
    payload = json.loads(model.model_dump_json())
    response = self.http_post(f"{self.api_version}/inference", data=payload, output=output, data_field="results")
    return response

list_inferences

list_inferences(output: str = 'json')

Lists inferences submitted to the Studio. Limit to most recent 10.

Parameters:

Name Type Description Default
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

A dictionary containing a list of inference tasks submitted to the studio

Source code in geostudio/backends/v2/ginference/client.py
def list_inferences(self, output: str = "json", limit: int = 10, skip: int = 0):
    """
    Lists inferences submitted to the Studio. By default at most 10 are requested.

    Args:
        output (str, optional): The desired output format. Defaults to "json".
        limit (int, optional): Value sent as the ``limit`` query parameter. Defaults to 10.
        skip (int, optional): Value sent as the ``skip`` query parameter. Defaults to 0.

    Returns:
        dict: A dictionary containing a list of inference tasks submitted to the studio
    """
    response = self.http_get(
        f"{self.api_version}/inference?limit={limit}&skip={skip}", output=output, data_field="results"
    )
    return response

get_inference

get_inference(inference_id: UUID, output: str = 'json')

Retrieves the inference with the given inference ID.

Parameters:

Name Type Description Default
inference_id UUID

The unique identifier of the inference task.

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The inference task data in the specified output format.

Source code in geostudio/backends/v2/ginference/client.py
def get_inference(self, inference_id: UUID, output: str = "json"):
    """
    Fetch one inference record by its identifier.

    Args:
        inference_id (uuid.UUID): Identifier of the inference to fetch.
        output (str, optional): Format requested for the response. Defaults to "json".

    Returns:
        dict: The inference data in the requested output format.
    """
    inference_endpoint = "{}/inference/{}".format(self.api_version, inference_id)
    return self.http_get(inference_endpoint, output=output, data_field="results")

delete_inference

delete_inference(inference_id: UUID, output: str = 'json')

Deletes an inference using its ID.

Parameters:

Name Type Description Default
inference_id UUID

The ID of the inference to be deleted.

required
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The response from the server after deleting the inference.

Source code in geostudio/backends/v2/ginference/client.py
def delete_inference(self, inference_id: UUID, output: str = "json"):
    """
    Remove the inference identified by ``inference_id``.

    Args:
        inference_id (uuid.UUID): Identifier of the inference to delete.
        output (str, optional): Format requested for the response. Defaults to "json".

    Returns:
        dict: The server's reply to the deletion request.
    """
    inference_root = f"{self.api_version}/inference"
    target = f"{inference_root}/{inference_id}"
    return self.http_delete(target, output=output)

get_inference_tasks

get_inference_tasks(
    inference_id: UUID, output: str = "json"
)

Retrieves the tasks associated with an inference.

Parameters:

Name Type Description Default
inference_id UUID

The unique identifier of the inference.

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The inference task data in the specified output format.

Source code in geostudio/backends/v2/ginference/client.py
def get_inference_tasks(self, inference_id: UUID, output: str = "json"):
    """
    Fetch the tasks that belong to an inference.

    Args:
        inference_id (uuid.UUID): Identifier of the parent inference.
        output (str, optional): Format requested for the response. Defaults to "json".

    Returns:
        dict: The inference's task data in the requested output format.
    """
    inference_base = f"{self.api_version}/inference/{inference_id}"
    return self.http_get(f"{inference_base}/tasks", output=output, data_field="results")

get_task_output_url

get_task_output_url(task_id: UUID, output: str = 'json')

Retrieves the output url for a specific inference task.

Parameters:

Name Type Description Default
task_id UUID

The unique identifier of the task.

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The server's response for the task's output URL, in the specified output format.

Source code in geostudio/backends/v2/ginference/client.py
def get_task_output_url(self, task_id: UUID, output: str = "json"):
    """
    Fetch the output URL information for a single inference task.

    Args:
        task_id (UUID): Identifier of the task.
        output (str, optional): Format requested for the response. Defaults to "json".

    Returns:
        The server's response from ``<api_version>/tasks/<task_id>/output`` in the
        requested output format.
    """
    task_root = f"{self.api_version}/tasks/{task_id}"
    return self.http_get(task_root + "/output", output=output, data_field="results")

get_task_step_logs

get_task_step_logs(
    task_id: UUID, step_id: str, output: str = "json"
)

Retrieves the logs for a specific step of an inference task.

Parameters:

Name Type Description Default
task_id UUID

The unique identifier of the task.

required
output str

The desired output format. Defaults to "json".

'json'

Returns:

Name Type Description
dict

The step logs in the specified output format.

Source code in geostudio/backends/v2/ginference/client.py
def get_task_step_logs(self, task_id: UUID, step_id: str, output: str = "json"):
    """
    Fetch the logs produced by one step of an inference task.

    Args:
        task_id (UUID): Identifier of the task.
        step_id (str): Identifier of the step whose logs are requested.
        output (str, optional): Format requested for the response. Defaults to "json".

    Returns:
        The server's response from ``<api_version>/tasks/<task_id>/logs/<step_id>`` in the
        requested output format.
    """
    logs_root = f"{self.api_version}/tasks/{task_id}/logs"
    return self.http_get(f"{logs_root}/{step_id}", output=output, data_field="results")

check_data_availability

check_data_availability(
    datasource: str,
    data: DataAdvisorIn,
    output: str = "json",
)

Query data-advisor service to check data availability before running an inference.

Parameters:

Name Type Description Default
datasource str

The data source to query (used in the request path).

required
data dict

A dictionary containing the necessary parameters for the data availability check.

required
output str

The desired output format. Default is "json".

'json'

Returns:

Name Type Description
dict

The response from the server containing the data availability information.

Source code in geostudio/backends/v2/ginference/client.py
def check_data_availability(self, datasource: str, data: DataAdvisorIn, output: str = "json"):
    """
    Query data-advisor service to check data availability before running an inference.

    Args:
        datasource (str): The data source to query; inserted into the request path
            ``<api_version>/data-advice/<datasource>``.
        data (DataAdvisorIn | dict): The parameters for the availability check, either as a
            ``DataAdvisorIn`` instance or as a mapping of its fields.
        output (str, optional): The desired output format. Default is "json".

    Returns:
        dict: The response from the server containing the data availability information.
    """
    # A pydantic model instance is not a mapping, so ``DataAdvisorIn(**data)`` would fail for
    # the annotated type; use an existing instance as-is.
    model = data if isinstance(data, DataAdvisorIn) else DataAdvisorIn(**data)
    # Round-trip through JSON so non-JSON-native field values become plain JSON types.
    payload = json.loads(model.model_dump_json())
    response = self.http_post(
        f"{self.api_version}/data-advice/{datasource}", data=payload, output=output, data_field="results"
    )
    return response

list_datasource_collections

list_datasource_collections(
    datasource: str, output: str = "json"
)

Query data-advisor to list collections available for a specific data source

Source code in geostudio/backends/v2/ginference/client.py
def list_datasource_collections(self, datasource: str, output: str = "json"):
    """
    Ask the data advisor which collections it knows for a given data source.

    Args:
        datasource (str): The data source whose collections are listed.
        output (str, optional): Format requested for the response. Defaults to "json".

    Returns:
        The server's response from ``<api_version>/data-advice/<datasource>`` in the
        requested output format.
    """
    advice_endpoint = f"{self.api_version}/data-advice/{datasource}"
    return self.http_get(advice_endpoint, output=output, data_field="results")

list_datasource

list_datasource(
    connector: str = None,
    collection: str = None,
    limit: int = 25,
    skip: int = 0,
    output: str = "json",
)

Lists all data sources available in the studio.

Parameters:

Name Type Description Default
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

A dictionary containing a list of data sources available through the studio

Source code in geostudio/backends/v2/ginference/client.py
def list_datasource(
    self, connector: str = None, collection: str = None, limit: int = 25, skip: int = 0, output: str = "json"
):
    """
    Lists data sources available in the studio.

    Args:
        connector (str, optional): Sent as the ``connector`` query parameter; omitted when falsy.
        collection (str, optional): Sent as the ``collection`` query parameter; omitted when falsy.
        limit (int, optional): Sent as the ``limit`` query parameter. Defaults to 25.
        skip (int, optional): Sent as the ``skip`` query parameter. Defaults to 0.
        output (str, optional): The format of the response. Defaults to "json".

    Returns:
        dict: A dictionary containing a list of data sources available through the studio
    """
    from urllib.parse import quote, urlencode

    params = {"limit": limit, "skip": skip}
    if connector:
        params["connector"] = connector
    if collection:
        params["collection"] = collection
    # Percent-encode values so characters such as "&", "=" or spaces in user-supplied
    # filters cannot break or inject into the query string.
    query = urlencode(params, quote_via=quote)
    response = self.http_get(f"{self.api_version}/data-sources?{query}", output=output, data_field="results")
    return response

get_datasource

get_datasource(datasource_id: UUID, output: str = 'json')

Retrieves a specific data source's information.

Parameters:

Name Type Description Default
datasource_id UUID

The unique identifier of the data source to retrieve.

required
output str

The format of the response. Defaults to "json".

'json'

Returns:

Name Type Description
dict

A list of the matching data source records (empty if none match).

Source code in geostudio/backends/v2/ginference/client.py
def get_datasource(self, datasource_id: UUID, output: str = "json"):
    """
    Retrieves a specific data source's information.

    The lookup is client-side: it calls :meth:`list_datasource` with its default arguments
    and keeps the entries of ``results`` whose ``id`` matches.

    Args:
        datasource_id (UUID | str): The identifier of the data source. Ids are compared as
            strings, so both ``uuid.UUID`` and string forms match.
        output (str, optional): Not used by this method; kept for interface compatibility.

    Returns:
        list: The matching data source records (empty if none match).
    """
    # Compare string forms: a UUID object never equals its string representation.
    target = str(datasource_id)
    # NOTE(review): list_datasource() uses its default limit/skip, so data sources beyond the
    # first page are not searched — confirm whether paging is needed here.
    results = self.list_datasource()["results"]
    return [entry for entry in results if str(entry["id"]) == target]
get_fileshare_links

get_fileshare_links(object_name: str)

Generate presigned urls for sharing files i.e uploading and downloading files.

Parameters:

Name Type Description Default
object_name str

The name of the object (file) for which to generate upload links.

required

Returns:

Name Type Description
dict

A dictionary containing the upload links.

Source code in geostudio/backends/v2/ginference/client.py
def get_fileshare_links(self, object_name: str):
    """
    Generate presigned urls for sharing files, i.e. uploading and downloading files.

    Args:
        object_name (str): The name of the object (file) for which to generate links. It is
            percent-encoded before being placed in the query string.

    Returns:
        dict: A dictionary containing the share links (e.g. ``upload_url``).
    """
    from urllib.parse import quote, urlencode

    print("Going to generate the upload link")
    # Encode the name so characters like "&", "=", "#" or spaces survive in the query string.
    query = urlencode({"object_name": object_name}, quote_via=quote)
    response = self.http_get(f"{self.api_version}/file-share?{query}", output="json", data_field="results")
    return response

upload_file_to_url

upload_file_to_url(upload_url: str, filepath: str)

Uploads a file to a specified URL using a PUT request.

Parameters:

Name Type Description Default
upload_url str

The URL to which the file will be uploaded.

required
filepath str

The path to the file that will be uploaded.

required

Returns:

Type Description

requests.Response: The response from the server after the file upload.

Source code in geostudio/backends/v2/ginference/client.py
def upload_file_to_url(self, upload_url: str, filepath: str):
    """
    Uploads a file to a specified URL using a PUT request, showing a progress bar.

    The file is streamed as a multipart/form-data body (single field ``"file"``) via
    ``MultipartEncoder``; a ``MultipartEncoderMonitor`` callback advances a ``rich``
    progress bar sized to the file's on-disk size.

    Args:
        upload_url (str): The URL to which the file will be uploaded.
        filepath (str): The path to the file that will be uploaded.

    Returns:
        requests.Response: The response from the server after the file upload. The status
        code is not checked here; callers must inspect it.
    """

    print("Going to upload the file to the url.")

    fields = {}
    path = Path(filepath)
    total_size = path.stat().st_size
    filename = path.name

    with Progress(
        TextColumn("[bold black]{task.description}"),
        BarColumn(),
        DownloadColumn(),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
    ) as progress_bar:
        task = progress_bar.add_task(filename, total=total_size)
        with open(filepath, "rb") as f:
            # NOTE(review): the multipart part's filename is the literal string "filename",
            # not the real file name held in `filename` — confirm the server ignores it.
            fields["file"] = ("filename", f)
            e = MultipartEncoder(fields=fields)
            last_bytes = 0

            # Advances the bar by the bytes read since the previous callback.
            # NOTE(review): bytes_read presumably counts the encoded multipart body
            # (boundaries + part headers), so progress may slightly exceed total_size.
            def monitor_callback(monitor):
                nonlocal last_bytes
                bytes_diff = monitor.bytes_read - last_bytes
                progress_bar.update(task, advance=bytes_diff)
                last_bytes = monitor.bytes_read

            m = MultipartEncoderMonitor(e, monitor_callback)

            # NOTE(review): if upload_url is an S3-style presigned PUT URL, the object store
            # will likely store the raw request body, multipart wrapper included — confirm
            # against the file-share service.
            headers = {"Content-Type": m.content_type}
            response = requests.put(upload_url, data=m, headers=headers)
    return response

upload_file

upload_file(filename: str)

Streamlines :py:meth:get_fileshare_links and :py:meth:upload_file_to_url. Requests share links for the file's base name and, if an upload link is returned, uploads the file to it.

Source code in geostudio/backends/v2/ginference/client.py
def upload_file(self, filename: str):
    """
    Streamlines :py:meth:`get_fileshare_links` and :py:meth:`upload_file_to_url`.

    Requests share links for the file's base name and, when the response contains an
    ``upload_url``, uploads the local file to it.

    Args:
        filename (str): Path to the local file to upload.

    Returns:
        dict: The links returned by :py:meth:`get_fileshare_links`. They are returned even
        when no ``upload_url`` was present and therefore nothing was uploaded.
    """
    import os

    # os.path.basename also handles the platform's own separator (e.g. "\\" on Windows),
    # unlike splitting on "/".
    links = self.get_fileshare_links(object_name=os.path.basename(filename))
    upload_url = links.get("upload_url", None)
    if upload_url:
        self.upload_file_to_url(upload_url, filename)
    return links

create_download_presigned_url

create_download_presigned_url(
    bucket_name: str,
    object_key: str,
    endpoint_url: str,
    region_name: str,
    service_name: str,
    aws_access_key_id: str = None,
    aws_secret_access_key: str = None,
    expiration: int = 3600,
    **kwargs
)

Function to create presigned url to download object from bucket

Parameters

bucket_name (str): The bucket name in the instance.
object_key (str): Object path to pre-sign.
endpoint_url (str): S3 endpoint, e.g. https://s3.us-east.cloud-object-storage.appdomain.cloud
region_name (str): Region where the bucket lives, e.g. us-east.
service_name (str): Service to connect to, e.g. s3.
aws_access_key_id (str): AWS access key for the instance.
aws_secret_access_key (str): AWS secret access key for the instance.
expiration (int, optional): Expiration duration in seconds; defaults to 3600.

Returns

str Presigned download url

Source code in geostudio/backends/v2/ginference/client.py
def create_download_presigned_url(
    self,
    bucket_name: str,
    object_key: str,
    endpoint_url: str,
    region_name: str,
    service_name: str,
    aws_access_key_id: str = None,
    aws_secret_access_key: str = None,
    expiration: int = 3600,
    **kwargs,
):
    """Create a presigned URL that allows downloading an object from a bucket.

    Parameters
    ----------
    bucket_name : str
        The bucket name in the instance
    object_key : str
        Object path to pre-sign
    endpoint_url: str
        s3 Endpoint i.e https://s3.us-east.cloud-object-storage.appdomain.cloud
    region_name: str
        Region where bucket lives. i.e us-east
    service_name: str
        service to connect to i.e s3
    aws_access_key_id: str
        AWS Access key to the instance
    aws_secret_access_key: str
        AWS secret access key to the instance
    expiration : int, optional
        Expiration duration in seconds, by default 3600
    **kwargs
        Extra keyword arguments forwarded to ``boto3.client``.

    Returns
    -------
    str
        Presigned download url, or ``None`` if signing raised ``ClientError``
        (the error is printed).
    """
    client = boto3.client(
        service_name,
        endpoint_url=endpoint_url,
        region_name=region_name,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        **kwargs,
    )
    object_ref = {"Bucket": bucket_name, "Key": object_key}
    try:
        signed = client.generate_presigned_url(ClientMethod="get_object", Params=object_ref, ExpiresIn=expiration)
    except ClientError as err:
        print(f"Error creating presigned URL: {err}")
        return None
    return signed

create_upload_presigned_url

create_upload_presigned_url(
    bucket_name: str,
    object_key: str,
    endpoint_url: str,
    region_name: str,
    service_name: str,
    aws_access_key_id: str = None,
    aws_secret_access_key: str = None,
    expiration: int = 3600,
    **kwargs
)

Function to create presigned url to upload object to bucket

Parameters

bucket_name (str): The bucket name in the instance.
object_key (str): Object path to pre-sign.
endpoint_url (str): S3 endpoint, e.g. https://s3.us-east.cloud-object-storage.appdomain.cloud
region_name (str): Region where the bucket lives, e.g. us-east.
service_name (str): Service to connect to, e.g. s3.
aws_access_key_id (str): AWS access key for the instance.
aws_secret_access_key (str): AWS secret access key for the instance.
expiration (int, optional): Expiration duration in seconds; defaults to 3600.

Returns

str Presigned upload url

Source code in geostudio/backends/v2/ginference/client.py
def create_upload_presigned_url(
    self,
    bucket_name: str,
    object_key: str,
    endpoint_url: str,
    region_name: str,
    service_name: str,
    aws_access_key_id: str = None,
    aws_secret_access_key: str = None,
    expiration: int = 3600,
    **kwargs,
):
    """Create a presigned URL that allows uploading (PUT) an object to a bucket.

    Parameters
    ----------
    bucket_name : str
        The bucket name in the instance
    object_key : str
        Object path to pre-sign
    endpoint_url: str
        s3 Endpoint i.e https://s3.us-east.cloud-object-storage.appdomain.cloud
    region_name: str
        Region where bucket lives. i.e us-east
    service_name: str
        service to connect to i.e s3
    aws_access_key_id: str
        AWS Access key to the instance
    aws_secret_access_key: str
        AWS secret access key to the instance
    expiration : int, optional
        Expiration duration in seconds, by default 3600
    **kwargs
        Extra keyword arguments forwarded to ``boto3.client``.

    Returns
    -------
    str
        Presigned upload url, or ``None`` if signing raised ``ClientError``
        (the error is printed).
    """
    client = boto3.client(
        service_name,
        endpoint_url=endpoint_url,
        region_name=region_name,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        **kwargs,
    )
    object_ref = {"Bucket": bucket_name, "Key": object_key}
    try:
        signed = client.generate_presigned_url(ClientMethod="put_object", Params=object_ref, ExpiresIn=expiration)
    except ClientError as err:
        print(f"Error creating presigned URL: {err}")
        return None
    return signed

poll_inference_until_finished

poll_inference_until_finished(
    inference_id, poll_frequency=10
)

Polls the status of an inference task until it is completed, failed, or stopped. The poll interval is clamped to a minimum of 10 seconds.

Parameters:

Name Type Description Default
inference_id str

The unique identifier of the inference task.

required
poll_frequency int

The time interval in seconds between polls. Values below 10 are raised to 10. Defaults to 10 seconds.

10

Returns:

Name Type Description
dict

The response from the inference task when it is completed or failed.

Source code in geostudio/backends/v2/ginference/client.py
def poll_inference_until_finished(self, inference_id, poll_frequency=10):
    """
    Polls the status of an inference task until it is completed, failed, or stopped.

    The poll interval is clamped to a minimum of 10 seconds. After each poll, the status and
    the elapsed time since the inference's ``created_at`` timestamp are printed. Intermediate
    statuses overwrite the same console line.

    Args:
        inference_id (str): The unique identifier of the inference task.
        poll_frequency (int, optional): The time interval in seconds between polls. Values
            below 10 are raised to 10. Defaults to 10 seconds.

    Returns:
        dict: The inference record from :meth:`get_inference` once its status contains
        "COMPLETED" or equals "FAILED" or "STOPPED".
    """
    poll_frequency = max(poll_frequency, 10)

    while True:
        r = self.get_inference(inference_id)
        status = r["status"]
        created_at = datetime.strptime(r["created_at"], "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=timezone.utc)
        # total_seconds(), not .seconds: .seconds discards whole days (and wraps to ~86399
        # for a slightly negative delta caused by clock skew).
        time_taken = int((datetime.now(timezone.utc) - created_at).total_seconds())
        message = f"{status} - {time_taken} seconds"

        if "COMPLETED" in status or status in ("FAILED", "STOPPED"):
            print(message)
            return r

        print(message, end="\r")
        sleep(poll_frequency)