# Sample GRPO training config for the Qwen2-0.5B-Instruct model.
#
# Requirements:
#   - Log into WandB (`wandb login`) or disable `enable_wandb`
#   - Have vLLM installed (this config enables `use_vllm`), or disable `use_vllm`
#
# Usage:
#   oumi train -c configs/examples/grpo_tldr/train.yaml
#
# See Also:
#   - Documentation: https://oumi.ai/docs/en/latest/user_guides/train/train.html
#   - Config class: oumi.core.configs.TrainingConfig
#   - Config source: https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/training_config.py
#   - Other training configs: configs/**/*train.yaml

# Model to train and the settings used when loading it.
model:
  # Same model the file header names.
  model_name: "Qwen/Qwen2-0.5B-Instruct"
  # NOTE(review): presumably a sequence-length cap in tokens — confirm against
  # the model params class.
  model_max_length: 8192
  torch_dtype_str: "bfloat16"
  attn_implementation: "sdpa"

# Training data: a single dataset entry, using its "train" split.
data:
  train:
    datasets:
      - dataset_name: "trl-lib/tldr"
        split: "train"

# Trainer settings for the GRPO run (trainer_type "TRL_GRPO").
# Booleans are written in canonical lowercase form (`true` / `false`).
training:
  trainer_type: "TRL_GRPO"
  # Same value as max_steps below — presumably yields a single save at the end
  # of the run; confirm against the trainer's checkpointing behavior.
  save_steps: 500
  per_device_train_batch_size: 2  # 1 for GPUs with 24GB VRAM
  gradient_accumulation_steps: 1

  # Reward functions, referenced by name.
  reward_functions: ["soft_20tokens_completions"]

  enable_gradient_checkpointing: false
  # NOTE(review): presumably only consulted when gradient checkpointing is
  # enabled above; kept so the feature can be toggled with one flag — confirm.
  gradient_checkpointing_kwargs:
    use_reentrant: false
  ddp_find_unused_parameters: false
  optimizer: "adamw_torch"
  empty_device_cache_steps: 1
  compile: false  # Enabling compilation makes training slower overall

  # GRPO-specific options: vLLM is enabled here.
  grpo:
    use_vllm: true
    # NOTE(review): presumably the fraction of GPU memory given to vLLM, with
    # the remainder left for training — confirm against the TRL GRPO docs.
    vllm_gpu_memory_utilization: 0.4

  dataloader_num_workers: "auto"
  dataloader_prefetch_factor: 32

  max_steps: 500
  logging_steps: 10
  log_model_summary: false
  output_dir: "output/tldr.grpo"
  # Cannot enable due to trl bug. See https://github.com/huggingface/trl/issues/3081.
  include_performance_metrics: false
  # See the Requirements note in the file header: log into WandB or set this
  # to false.
  enable_wandb: true

  # NOTE(review): presumably forwarded as extra keyword arguments to the
  # underlying trainer — verify against the config class.
  trainer_kwargs:
    vllm_mode: "colocate"
