The YAML configuration file: an overview#
If you are using the command-line interface (CLI) to run jobs with TerraTorch, you must become familiar with YAML, the format used to configure all the workflows within the toolkit. Writing a YAML file is very similar to coding: even though you are not directly handling the classes and other structures defined inside a codebase, you need to know how they work, their input arguments and their position in the pipeline. In this sense, we could call it a "low-code" task. The YAML file used by TerraTorch has an almost closed format, with a few fixed fields that must be filled with limited sets of classes, which makes it easier for new users to take a pre-existing YAML file and adapt it to their own purposes.
In the next sections, we describe each field of a YAML file used for Earth Observation Foundation Models (EOFM) and try to make it clearer for new users. However, we will not go into much detail, since the complementary documentation (Lightning, PyTorch, ...) should fill that gap. The example can be downloaded here.
Trainer#
In the section called trainer, we define all the arguments that are sent directly to the Lightning
Trainer object. If you need a deeper explanation of this object, check Lightning's documentation.
In the first lines we have:
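The exact snippet is not reproduced here; a minimal sketch consistent with the descriptions below (values are illustrative) would be:

trainer:
  accelerator: auto      # cpu, gpu, or automatic selection
  strategy: auto         # parallelism strategy; auto is fine for a single device
  devices: auto          # devices to use; leave as auto for a single device
  num_nodes: 1           # single-node jobs are the most tested setup
  precision: 16-mixed    # a usual choice for mixed precision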
- accelerator: the kind of device used to run the experiment. We are usually interested in cpu and gpu, but if you set it to auto, it will automatically allocate a GPU if one is available and otherwise run on the CPU.
- strategy: the kind of parallelism to use. As we usually run experiments on a single device for fine-tuning or inference, we do not worry about it and choose auto by default.
- devices: the list of devices to use for the experiment. Leave it as auto if you are running on a single device.
- num_nodes: self-explanatory. We have mostly tested TerraTorch for single-node jobs, so it is better to set it to 1 for now.
- precision: the numerical precision used for your model. 16-mixed has been a usual choice.
Just below this initial stage, we have the logger, which in the example points to tests/all_ecos_random.
Other frameworks, such as MLFlow, are also supported. Check the Lightning documentation
about logging for a more complete description.
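A minimal sketch of such a block, assuming a TensorBoardLogger and reusing the path mentioned above (both assumptions, not necessarily the exact content of the example file), would be:

  logger:
    class_path: TensorBoardLogger          # assumption; an MLFlowLogger would be configured similarly
    init_args:
      save_dir: tests/all_ecos_random      # hypothetical output directory, reusing the path above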
The callbacks field, followed by the remaining trainer options:
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 100
  max_epochs: 1
  check_val_every_n_epoch: 1
  log_every_n_steps: 20
  enable_checkpointing: true
  default_root_dir: tests/
- max_epochs: the maximum number of epochs to train the model. Notice that, if you are using early stopping, training may finish before reaching this number.
- check_val_every_n_epoch: how often (in epochs) to evaluate the model on the validation dataset. Validation is important to verify whether the model is tending to overfit and can be used, for example, to decide when to update the learning rate or to trigger early stopping.
- enable_checkpointing: enables checkpointing, the action of periodically saving the state of the model to a file.
- default_root_dir: the directory used to save the model checkpoints.
Datamodule#
In this section, we start directly handling TerraTorch's built-in structures. The field data is expected to
receive a generic datamodule or any other datamodule compatible with Lightning
Datamodules, such as those defined
in our collection of datamodules.
At the beginning of the field we have:
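The snippet itself is not reproduced here; a minimal sketch, assuming TerraTorch's generic pixel-wise regression datamodule and purely illustrative argument values, would look like:

data:
  class_path: GenericNonGeoPixelwiseRegressionDataModule   # generic regression datamodule
  init_args:
    batch_size: 4        # illustrative value
    num_workers: 2       # illustrative value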
It means that we have chosen the generic regression datamodule and we will pass all its required arguments below init_args, with one extra level of indentation. The best practice here is to check the documentation of the
datamodule class you are using (in our case, here)
and verify all the arguments it expects to receive, and then fill in the lines
with <argument_name>: <argument_value>.
As the TerraTorch and Lightning modules are already imported in the CLI script (terratorch/cli_tools.py),
you do not need to provide their complete import paths. However, if you are using a datamodule defined in an
external package, indicate the full import path, as in package.datamodules.SomeDatamodule.
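As a hedged sketch with a hypothetical external class, that would look like:

data:
  class_path: package.datamodules.SomeDatamodule   # full dotted path required for classes outside TerraTorch/Lightning
  init_args:
    batch_size: 4                                  # whatever arguments SomeDatamodule actually expects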
Model#
The field model is, in fact, the configuration for task + model: 
model:
  class_path: terratorch.tasks.PixelwiseRegressionTask
  init_args:
    model_args:
      decoder: UperNetDecoder
      pretrained: false
      backbone: prithvi_eo_v2_600
      backbone_drop_path_rate: 0.3
      backbone_window_size: 8
      decoder_channels: 64
      num_frames: 1
      in_channels: 6
      bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
      head_dropout: 0.5708022831486758
      head_final_act: torch.nn.ReLU
      head_learned_upscale_layers: 2
    loss: rmse
    ignore_index: -1
    freeze_backbone: true
    freeze_decoder: false
    model_factory: PrithviModelFactory
    tiled_inference_parameters:
       h_crop: 224
       h_stride: 192
       w_crop: 224
       w_stride: 192
       average_patches: true
The main subfield here is model_args, which is intended to receive all the configuration necessary to
instantiate the model itself, that is, the structure backbone + decoder + head. Inside model_args, it
is possible to define which arguments will be sent to each component by adding a prefix to the argument
names, as in backbone_<argument> or decoder_<other_argument>. Alternatively, it is possible to pass the
arguments using the dictionaries backbone_kwargs, decoder_kwargs and head_kwargs. The same recommendation
made for the data field applies here: check the documentation of the task and
model classes (backbones, decoders and heads) you are using in
order to decide which arguments to write in each subfield of model.
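To illustrate the second convention, the prefixed arguments above could be grouped into explicit dictionaries. This is only a hedged sketch, not part of the downloadable example, and the exact unprefixed key names should be confirmed in each component's documentation:

    model_args:
      backbone: prithvi_eo_v2_600
      backbone_kwargs:              # replaces backbone_drop_path_rate, backbone_window_size, ...
        drop_path_rate: 0.3
        window_size: 8
      decoder: UperNetDecoder
      decoder_kwargs:               # replaces decoder_channels
        channels: 64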
Optimizer and Learning Rate Scheduler#
The last two fields of our example configure the optimizer and the learning-rate scheduler. These fields are mostly self-explanatory for users already familiar with machine learning:
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 0.00013524680528283027
    weight_decay: 0.047782217873995426
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss
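The monitor argument tells ReduceLROnPlateau which logged metric to track when deciding to reduce the learning rate; here it is the same val/loss monitored by the EarlyStopping callback configured earlier.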