# DDP pretrain config for the FineWeb ablation model on GCP.
#
# Requirements:
#   - Log in to WandB (`wandb login`) or disable `enable_wandb`
#
# Usage:
#   oumi train -c configs/examples/fineweb_ablation_pretraining/ddp/train.yaml
#
# See Also:
#   - Documentation: https://oumi.ai/docs/en/latest/user_guides/train/train.html
#   - Config class: oumi.core.configs.TrainingConfig
#   - Config source: https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/training_config.py
#   - Other training configs: configs/**/*train.yaml

# Model and tokenizer settings for the FineWeb ablation model.
model:
  model_name: "HuggingFaceFW/ablation-model-fineweb-v1"
  # Same value as `seq_length` in the training dataset's `dataset_kwargs`.
  model_max_length: 2048
  torch_dtype_str: "bfloat16"
  attn_implementation: "sdpa"
  tokenizer_pad_token: "<|endoftext|>"
  # This is a pretraining config (see the file header), so pretrained weights
  # are not loaded.
  load_pretrained_weights: false
  trust_remote_code: true

# Training data: the `sample-10BT` subset of FineWeb-Edu, with streaming and
# packing both turned on.
data:
  train:
    stream: true
    pack: true
    datasets:
      - dataset_name: "HuggingFaceFW/fineweb-edu"
        subset: "sample-10BT"
        split: "train"
        dataset_kwargs:
          # Same value as `model.model_max_length`.
          seq_length: 2048
      # Polaris copy of the dataset:
      # - dataset_name: "/eagle/community_ai/datasets/fineweb-edu/sample-10BT"
      #   subset: "default"
      #   split: "train"

# Trainer, optimization, data-loading, and logging settings.
training:
  trainer_type: "TRL_SFT"  # or OUMI

  # Checkpointing and output location.
  output_dir: "output/fineweb.pt"
  save_steps: 500

  # Batching.
  # If gradient checkpointing is enabled: use 12 (~94% of 40GB VRAM)
  per_device_train_batch_size: 4
  gradient_accumulation_steps: 64

  # Gradient checkpointing is off; the kwargs below are still supplied.
  enable_gradient_checkpointing: false
  gradient_checkpointing_kwargs:
    use_reentrant: false
  ddp_find_unused_parameters: false
  compile: true

  # Optimizer and learning-rate schedule.
  optimizer: "adafactor"
  learning_rate: 3.0e-4
  weight_decay: 0.1
  warmup_steps: 500
  lr_scheduler_type: "cosine_with_min_lr"
  lr_scheduler_kwargs:
    # One tenth of `learning_rate`.
    min_lr: 3.0e-5

  # Data loading.
  dataloader_num_workers: "auto"
  dataloader_prefetch_factor: 32

  # Logging and metrics. WandB requires `wandb login` (see the file header).
  logging_steps: 10
  log_model_summary: false
  include_performance_metrics: true
  enable_wandb: true
