Skip to content

Sample files

Terratorch HPO Segmentation Config

burnscars-iterate-hpo
    # Hyperparameter-search (HPO) config for burn-scar segmentation.
    # `defaults` holds settings shared by the entries in `tasks`; presumably
    # they are merged into each task (confirm against the terratorch-iterate
    # docs). `optimization_space` lists the hyperparameters sampled in each
    # of the `n_trials` trials.
    defaults:
      terratorch_task:
        model_args:
          backbone: prithvi_eo_v2_300
          backbone_pretrained: true
        model_factory: EncoderDecoderFactory
        optimizer: AdamW
      trainer_args:
        accelerator: gpu
        log_every_n_steps: 1
        logger:
          class_path: lightning.pytorch.loggers.mlflow.MLFlowLogger
          init_args:
            experiment_name: prithvi-eo2
            run_name: prithvi-eo2-burnscars
            save_dir: /working/mlflow
            tracking_uri: http://mlflow.local
        max_epochs: 4
    experiment_name: test_geotune-burnscars-config
    n_trials: 4
    optimization_space:
      lr:
        log: true
        # Written with a decimal point: YAML 1.1 loaders (e.g. PyYAML) only
        # resolve exponent floats that contain a '.', so a bare `1e-3` would
        # be loaded as the string "1e-3".
        max: 1.0e-3
        min: 1.0e-6
        type: real
      optimizer_hparams:
        weight_decay:
          max: 0.4
          min: 0.0
          type: real
    run_repetitions: 1
    save_models: false
    storage_uri: /data/output
    tasks:
    - datamodule:
        class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule
        init_args:
          batch_size: 4
          constant_scale: 1.0
          dataset_bands:
          - 0
          - 1
          - 2
          - 3
          - 4
          - 5
          img_grep: '*_merged.tif'
          label_grep: '*.mask.tif'
          means:
          - 0.052501344674454685
          - 0.07818710276892571
          - 0.09527748749364956
          - 0.21442636938609524
          - 0.23594506038412516
          - 0.1707575734110789
          no_data_replace: 0
          # Must match terratorch_task.ignore_index below.
          no_label_replace: -1
          num_classes: 3
          num_workers: 2
          output_bands:
          - 0
          - 1
          - 2
          - 3
          - 4
          - 5
          rgb_indices:
          - 0
          - 1
          - 2
          stds:
          - 0.02980933467336373
          - 0.03671671799312583
          - 0.054445047276513434
          - 0.06953507438690602
          - 0.09133826565407105
          - 0.08313280812483531
          test_data_root: /data/geodata-r8sa4rfthfs9pnqaxxdvra/training_data/HLS_L30/
          test_label_data_root: /data/geodata-r8sa4rfthfs9pnqaxxdvra/labels/
          test_split: /data/geodata-r8sa4rfthfs9pnqaxxdvra/split_files/test_data.txt
          test_transform:
          - class_path: ToTensorV2
          train_data_root: /data/geodata-r8sa4rfthfs9pnqaxxdvra/training_data/HLS_L30/
          train_label_data_root: /data/geodata-r8sa4rfthfs9pnqaxxdvra/labels/
          train_split: /data/geodata-r8sa4rfthfs9pnqaxxdvra/split_files/train_data.txt
          val_data_root: /data/geodata-r8sa4rfthfs9pnqaxxdvra/training_data/HLS_L30/
          val_label_data_root: /data/geodata-r8sa4rfthfs9pnqaxxdvra/labels/
          val_split: /data/geodata-r8sa4rfthfs9pnqaxxdvra/split_files/val_data.txt
      # The objective is `metric`, a loss, so it has to be minimized;
      # "max" would rank the worst trial as the best one.
      direction: min
      early_prune: false
      early_stop_patience: 10
      metric: val/loss
      name: geotune-burnscars-config
      terratorch_task:
        freeze_backbone: false
        freeze_decoder: false
        ignore_index: -1
        loss: ce
        model_args:
          backbone: prithvi_eo_v2_300
          backbone_bands:
          - 0
          - 1
          - 2
          - 3
          - 4
          - 5
          backbone_drop_path: 0.1
          backbone_pretrained: true
          decoder: UNetDecoder
          decoder_channels:
          - 512
          - 256
          - 128
          - 64
          head_dropout: 0.1
          necks:
          # Feature maps from backbone layers 5, 11, 17 and 23 are passed on
          # to the decoder.
          - name: SelectIndices
            indices:
            - 5
            - 11
            - 17
            - 23
          - name: ReshapeTokensToImage
          - name: LearnedInterpolateToPyramidal
          num_classes: 3
        model_factory: EncoderDecoderFactory
        plot_on_val: 2
        tiled_inference_parameters:
          average_patches: false
          h_crop: 224
          h_stride: 196
          w_crop: 224
          w_stride: 196
      type: segmentation

Fine-tuning config file

floods-finetuning-config-file (Prithvi Swin-B backbone)
    # lightning.pytorch==2.2.0.post0
    # LightningCLI config: flood-mapping segmentation on Sen1Floods11
    # hand-labeled Sentinel-2 chips (6 bands) with a prithvi_swin_B backbone
    # and an UperNet decoder.
    seed_everything: 0
    trainer:
      accelerator: auto
      strategy: auto
      devices: auto
      num_nodes: 1
      precision: 16-mixed

      fast_dev_run: false
      max_epochs: 200
      min_epochs: null
      max_steps: -1
      min_steps: null
      max_time: null
      limit_train_batches: null
      limit_val_batches: null
      limit_test_batches: null
      limit_predict_batches: null
      overfit_batches: 0.0
      val_check_interval: null
      check_val_every_n_epoch: 1
      num_sanity_val_steps: null
      log_every_n_steps: 50
      enable_checkpointing: true
      enable_progress_bar: null
      enable_model_summary: null
      accumulate_grad_batches: 1
      gradient_clip_val: null
      gradient_clip_algorithm: null
      deterministic: null
      benchmark: null
      inference_mode: true
      use_distributed_sampler: true
      profiler: null
      detect_anomaly: false
      barebones: false
      plugins: null
      sync_batchnorm: false
      reload_dataloaders_every_n_epochs: 0
    predict_output_dir: null
    data:
      class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule
      init_args:
        batch_size: 4
        num_workers: 8
        train_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/S2GeodnHand6Bands
        val_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/S2GeodnHand6Bands
        test_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/S2GeodnHand6Bands
        img_grep: "*_S2GeodnHand.tif"
        label_grep: "*_LabelHand.tif"
        means:
          - 0.107582
          - 0.13471393
          - 0.12520133
          - 0.3236181
          - 0.2341743
          - 0.15878009
        stds:
          - 0.07145836
          - 0.06783548
          - 0.07323416
          - 0.09489725
          - 0.07938496
          - 0.07089546
        num_classes: 2
        predict_data_root: null
        train_label_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
        val_label_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
        test_label_data_root: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
        train_split: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/splits/splits/flood_handlabeled/flood_train_data_S2_geodn.txt
        val_split: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/splits/splits/flood_handlabeled/flood_valid_data_S2_geodn.txt
        test_split: /dccstor/geofm-finetuning/flood_mapping/sen1floods11/splits/splits/flood_handlabeled/flood_test_data_S2_geodn.txt
        ignore_split_file_extensions: true
        allow_substring_split_file: true
        output_bands:
          - BLUE
          - GREEN
          - RED
          - NIR_NARROW
          - SWIR_1
          - SWIR_2
        dataset_bands:
          # NOTE(review): this order starts RED, GREEN, BLUE. output_bands,
          # predict_dataset_bands and the model's `bands` all start BLUE,
          # GREEN, RED. This is only correct if the S2GeodnHand6Bands files
          # really store red first; confirm against the data.
          - RED
          - GREEN
          - BLUE
          - NIR_NARROW
          - SWIR_1
          - SWIR_2
        predict_dataset_bands:
          - BLUE
          - GREEN
          - RED
          - NIR_NARROW
          - SWIR_1
          - SWIR_2
        # Source pixels appear to be scaled integers; 0.0001 brings them to
        # the 0-1 range of `means`/`stds` (presumably reflectance, confirm).
        constant_scale: 0.0001
        # Positions in output_bands used for RGB plots: RED, GREEN, BLUE.
        rgb_indices:
          - 2
          - 1
          - 0
        train_transform: null
        val_transform: null
        test_transform: null
        expand_temporal_dimension: false
        reduce_zero_label: false
        no_data_replace: 0.0
        drop_last: true
    model:
      class_path: terratorch.tasks.SemanticSegmentationTask
      init_args:
        model_args:
          decoder: UperNetDecoder
          pretrained: false
          backbone: prithvi_swin_B
          backbone_drop_path_rate: 0.3
          backbone_window_size: 7
          decoder_channels: 256
          in_channels: 6
          bands:
            - BLUE
            - GREEN
            - RED
            - NIR_NARROW
            - SWIR_1
            - SWIR_2
          num_frames: 1
          num_classes: 2
          head_dropout: 0.1
          head_channel_list:
            - 256
        model_factory: PrithviModelFactory
        loss: ce
        aux_heads: null
        aux_loss: null
        class_weights:
          - 0.3
          - 0.7
        ignore_index: -1
        # When a top-level `optimizer:` block is present, LightningCLI
        # replaces the task's own optimizer setup with it (and logs a
        # warning). The task-level values below therefore mirror that block
        # exactly, so the config does not advertise an optimizer and learning
        # rate that are never used.
        lr: 6.0e-05
        optimizer: AdamW
        optimizer_hparams:
          weight_decay: 0.05
        scheduler: null
        scheduler_hparams: null
        freeze_backbone: false
        freeze_decoder: false
        plot_on_val: 2
        class_names: null
        tiled_inference_parameters: null
    # Optimizer actually used when training through LightningCLI.
    optimizer:
      class_path: torch.optim.AdamW
      init_args:
        lr: 6.0e-05
        betas:
          - 0.9
          - 0.999
        eps: 1.0e-08
        weight_decay: 0.05
        amsgrad: false
        maximize: false
        foreach: null
        capturable: false
        differentiable: false
        fused: null

Sample file for creating user-defined tuning templates

This sample file demonstrates how to create a user-defined tuning template. To use this template when submitting a task:

  1. Download or copy the YAML configuration below.
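  2. Encode the file content as a single-line Base64 string with no line breaks, because a JSON string cannot contain raw newlines (see "How to encode the YAML file to Base64" below).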
  3. Insert the encoded content into the "content" field of your task JSON payload.
sample-convnext-config
    # lightning.pytorch==2.1.1
    # User-defined tuning template: RGB binary segmentation with a timm
    # ConvNeXt-Large backbone and a UNet decoder.
    # NOTE(review): the `# ---- ... if/endif ----` comment lines look like
    # markers consumed by the submitting service; keep them verbatim.
    seed_everything: 0
    trainer:
      accelerator: auto
      strategy: auto
      devices: auto
      num_nodes: 1
      precision: 16-mixed

      callbacks:
        - class_path: RichProgressBar
        - class_path: LearningRateMonitor
          init_args:
            logging_interval: epoch
        # - class_path: ModelCheckpoint
        #   init_args:
        #       mode: min
        #       monitor: val/loss
        #       filename: best-{epoch:02d}
        - class_path: EarlyStopping
          init_args:
            monitor: val/loss
            patience: 20
        # ---- Early stop if ----
        # ---- Early stop endif ----
      max_epochs: 100
      check_val_every_n_epoch: 1
      log_every_n_steps: 5
      enable_checkpointing: false
      default_root_dir: logs/

    data:
      class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule
      init_args:
        batch_size: 16
        num_workers: 16
        # Must match model.init_args.ignore_index.
        no_label_replace: -1
        no_data_replace: 0
        constant_scale: 1.0
        dataset_bands:
          - 'RED'
          - 'GREEN'
          - 'BLUE'

        output_bands:
          - 'RED'
          - 'GREEN'
          - 'BLUE'

        rgb_indices:
          - 0
          - 1
          - 2

        train_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/
        train_label_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/
        val_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/
        val_label_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/
        test_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/
        test_label_data_root: AerialImageDatasetTiledMergedFixedLabels_sample/
        img_grep: "*train.tif"
        label_grep: "*label.tif"
        train_split: train.txt
        val_split: val.txt
        test_split: test.txt
        # constant_scale: 0.0039
        # means: [0.485, 0.456, 0.406]
        # stds: [0.229, 0.224, 0.225]
        # Statistics on the unscaled pixel values (constant_scale is 1.0).
        means:
          - 104.24203383423682
          - 109.92963788132441
          - 100.98120642006803

        stds:
          - 51.593745217159935
          - 47.218880227273814
          - 45.45813303733705

        check_stackability: false

        num_classes: 2

        train_transform:
          - class_path: albumentations.D4
          - class_path: ToTensorV2
        val_transform:
          - class_path: ToTensorV2
        test_transform:
          - class_path: ToTensorV2
    model:
      class_path: terratorch.tasks.SemanticSegmentationTask
      init_args:
        model_factory: EncoderDecoderFactory
        model_args:
          backbone: timm_convnext_large.fb_in22k
          num_classes: 2
          backbone_pretrained: true
          # No necks are used. To add some, replace `null` with a list, e.g.:
          #   - name: SelectIndices
          #     indices: [1,2,3,4]
          necks: null
          decoder: UNetDecoder
          decoder_channels: [512, 256, 128, 64]
          head_channel_list: [256]
          head_dropout: 0.1
        loss: dice
        # loss: ce
        plot_on_val: 2
        ignore_index: -1
        freeze_backbone: false
        freeze_decoder: false

        # tiled_inference_parameters:
        #   h_crop: 224
        #   h_stride: 198
        #   w_crop: 224
        #   w_stride: 198
        #   average_patches: True

    optimizer:
      class_path: torch.optim.AdamW
      init_args:
        # Written as `3.0e-05` rather than `3e-05`: YAML 1.1 loaders
        # (e.g. PyYAML) only resolve exponent floats that contain a '.',
        # otherwise the value is loaded as a string.
        lr: 3.0e-05
        # betas:
        # - 0.9
        # - 0.999
        # eps: 1.0e-08
        # weight_decay: 0.05
        # amsgrad: false
        # maximize: false
        # capturable: false
        # differentiable: false
        # ---- Optimizer stop if ----
    # lr_scheduler:
    #   class_path: CosineAnnealingLR
    #   init_args:
    #     T_max: 20

    # lr_scheduler_interval: step
    # lr_scheduler:
    #   class_path: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
    #   init_args:
    #     T_0: 1000   # first cycle: 1000 steps
    #     T_mult: 2   # cycles: 1000, 2000, 4000, ... (fits well in 10k)
    #     eta_min: 1.0e-6
    # Halve the learning rate when val/loss has not improved for 5 epochs.
    lr_scheduler:
      class_path: lightning.pytorch.cli.ReduceLROnPlateau
      init_args:
        monitor: val/loss
        mode: min
        factor: 0.5
        patience: 5
        threshold: 0.0001
        threshold_mode: rel
        cooldown: 0
        min_lr: 0.0
        eps: 1.0e-08

Example usage (task JSON payload). Keep "purpose" set to "Other", and paste your Base64-encoded YAML into the "content" field. Replace "dataset_id" with the ID of the dataset to tune on. JSON does not allow comments, so do not add any to the payload.

{
    "name": "user-new-task",
    "description": "Custom ConvNeXT tuning task",
    "purpose": "Other",
    "content": "<BASE64_ENCODED_YAML_HERE>",
    "extra_info": {
        "runtime_image": "quay.io/geospatial-studio/terratorch:latest",
        "model_framework": "terratorch-v2"
    },
    "model_params": {},
    "dataset_id": "selected_dataset"
}
How to encode the YAML file to Base64

Using the command line:

    # GNU base64 wraps output every 76 characters by default, which breaks the
    # JSON "content" string; stripping newlines gives one line on Linux and macOS.
    base64 < sample-convnext-config.yaml | tr -d '\n'
Using Python:
    import base64

    def encode_file_to_base64(file_path):
        """Return the bytes of ``file_path`` Base64-encoded as a ``str``.

        The file is read in binary mode so its exact bytes are encoded.
        ``base64.b64encode`` inserts no line breaks, so the result is a
        single line that can go straight into a JSON string field.
        """
        with open(file_path, "rb") as source:
            payload = source.read()
        # b64encode returns bytes; Base64 output is ASCII, so decoding is lossless.
        return base64.b64encode(payload).decode('utf-8')

    encoded_content = encode_file_to_base64("../sample_files/sample-convnext-config.yaml")
    print(encoded_content)
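Using online tools:

Visit base64encode.org and paste your YAML content. Avoid this for configs that contain internal paths, hostnames, or credentials, because the content may be sent to a third-party site.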