# Fixed random seed for the run.
seed_everything: 42
# Trainer settings for a short, CPU-only run (a single epoch).
trainer:
  accelerator: cpu
  strategy: auto
  devices: auto
  num_nodes: 1
  # fp16 autocast ("16-mixed") is not supported on the CPU accelerator:
  # Lightning emits a warning and substitutes bf16-mixed. Request bf16-mixed
  # explicitly so the config states the precision that is actually used.
  precision: bf16-mixed
  logger:
    class_path: TensorBoardLogger
    init_args:
      save_dir: tests/
      name: all_ecos_random
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    - class_path: EarlyStopping
      init_args:
        # Same metric name that lr_scheduler monitors.
        monitor: val/loss
        # Far larger than max_epochs below, so patience-based stopping
        # presumably never triggers in this run.
        patience: 100
  max_epochs: 1
  check_val_every_n_epoch: 1
  log_every_n_steps: 20
  enable_checkpointing: true
  # Same directory as the logger's save_dir above.
  default_root_dir: tests/
# Data module configuration: image/label pairs for pixelwise regression,
# with every split reading from the same test-resources directory.
data:
  class_path: GenericNonGeoPixelwiseRegressionDataModule
  init_args:
    batch_size: 2
    num_workers: 4
    # Transform pipeline passed as train_transform.
    train_transform:
      - class_path: albumentations.HorizontalFlip
        init_args:
          p: 0.5
      - class_path: albumentations.Rotate
        init_args:
          limit: 30
          border_mode: 0  # cv2.BORDER_CONSTANT
          value: 0
          # mask_value: 1
          p: 0.5
      - class_path: ToTensorV2
    # Eleven entries: six named bands surrounded by integer entries
    # (presumably placeholders for unnamed channels in the files - TODO confirm).
    dataset_bands:
      - 0
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - 1
      - 2
      - 3
      - 4
    # The six named bands; identical to the model's `bands` list.
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    # Assuming these index into output_bands, they select RED, GREEN, BLUE
    # in that order - TODO confirm.
    rgb_indices: [2, 1, 0]
    # Train, val and test all use the same root for images and labels;
    # the grep patterns below tell image files from label files.
    train_data_root: tests/resources/inputs
    train_label_data_root: tests/resources/inputs
    val_data_root: tests/resources/inputs
    val_label_data_root: tests/resources/inputs
    test_data_root: tests/resources/inputs
    test_label_data_root: tests/resources/inputs
    img_grep: "regression*input*.tif"
    label_grep: "regression*label*.tif"
    # Six values each, matching the number of output_bands
    # (presumably in the same band order - TODO confirm).
    means:
      - 547.36707
      - 898.5121
      - 1020.9082
      - 2665.5352
      - 2340.584
      - 1610.1407
    stds:
      - 411.4701
      - 558.54065
      - 815.94025
      - 812.4403
      - 1113.7145
      - 1067.641
    # Same value as the model's ignore_index (-1).
    no_label_replace: -1
    no_data_replace: 0

# Task/model configuration: pixelwise regression with a Prithvi backbone
# and a UperNet decoder, built through PrithviModelFactory.
model:
  class_path: terratorch.tasks.PixelwiseRegressionTask
  init_args:
    model_args:
      decoder: UperNetDecoder
      # No pretrained weights are requested.
      pretrained: false
      backbone: prithvi_eo_v2_600
      backbone_drop_path_rate: 0.3
      backbone_window_size: 8
      decoder_channels: 64
      num_frames: 1
      # Equals the number of entries in `bands` below.
      in_channels: 6
      bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
      head_dropout: 0.5708022831486758
      head_final_act: torch.nn.ReLU
      head_learned_upscale_layers: 2
    loss: rmse
    # Same value as the data module's no_label_replace (-1).
    ignore_index: -1
    freeze_backbone: true
    freeze_decoder: false
    model_factory: PrithviModelFactory
    # Crop is 224 and stride is 192 in both dimensions, i.e. stride < crop.
    tiled_inference_parameters:
      h_crop: 224
      h_stride: 192
      w_crop: 224
      w_stride: 192
      average_patches: true
# Optimizer: AdamW with an explicit learning rate and weight decay.
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 0.00013524680528283027
    weight_decay: 0.047782217873995426
# LR scheduler: ReduceLROnPlateau monitoring val/loss, the same metric name
# the EarlyStopping callback in trainer.callbacks monitors.
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss

