# Config for fine-tuning the Llama 3.1 8B Instruct model with long context (32K tokens).
#
# Usage (for 8xA100-40GB GPUs):
#   accelerate launch \
#     --config_file configs/recipes/llama3_1/sft/8b_full/accelerate.yaml \
#     --use_fsdp \
#     --num_processes 8 \
#     --dynamo_backend inductor \
#     --mixed_precision no \
#     -m oumi train \
#     -c "configs/recipes/llama3_1/sft/8b_full/longctx_train.yaml"
#
# See Also:
#   - Documentation: https://oumi.ai/docs/en/latest/user_guides/train/train.html
#   - Config class: oumi.core.configs.TrainingConfig
#   - Config source: https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/training_config.py
#   - Other training configs: configs/**/*train.yaml

model:
  # Checkpoint to fine-tune.
  model_name: 'meta-llama/Llama-3.1-8B-Instruct'
  # Same value as the training dataset's `seq_length` under `data`.
  model_max_length: 32_768
  # BF16 training saves memory.
  torch_dtype_str: 'bfloat16'
  # PyTorch-native flash-attention path; helps reduce VRAM.
  attn_implementation: 'sdpa'
  load_pretrained_weights: true
  trust_remote_code: true
  # The Liger kernel helps reduce required VRAM.
  enable_liger_kernel: true

training:
  trainer_type: 'TRL_SFT'
  learning_rate: 2.0e-05
  output_dir: 'output/llama8b.longctx.fft'

  # Compilation helps reduce required VRAM, but might cause issues with
  # checkpointing depending on the accelerate version.
  compile: true
  # Kept low to maximize context length.
  per_device_train_batch_size: 1
  # Update as needed for your desired effective batch size.
  gradient_accumulation_steps: 1
  # Use this optimizer for up to 32K context.
  optimizer: 'adamw_torch_fused'

  # The next four values are set for debugging. Customize as needed.
  logging_steps: 5
  max_steps: 100
  save_steps: 0
  save_final_model: false

  include_performance_metrics: true

data:
  train:
    # Streaming and packing are both enabled for the training split.
    pack: true
    stream: true
    datasets:
      - dataset_name: 'HuggingFaceFW/fineweb-edu'
        subset: 'sample-10BT'
        split: 'train'
        dataset_kwargs:
          # Same value as `model.model_max_length`.
          seq_length: 32_768
