# QLoRA config for DeepSeek R1 Distill Llama 3.3 70B.
#
# Requirements:
#   - Log into WandB (`wandb login`) or disable `enable_wandb`
#
# Usage:
#   oumi train -c configs/recipes/deepseek_r1/sft/distill_llama_70b/qlora_train.yaml
#
# See Also:
#   - Documentation: https://oumi.ai/docs/en/latest/user_guides/train/train.html
#   - Config class: oumi.core.configs.TrainingConfig
#   - Config source: https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/training_config.py
#   - Other training configs: configs/**/*train.yaml

# Base model: which checkpoint to load and how to load it.
model:
  model_name: "deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
  # NOTE(review): permits code shipped with the model repo to run at load
  # time -- confirm this checkpoint actually needs it.
  trust_remote_code: True

  # Precision and attention backend.
  torch_dtype_str: "bfloat16"
  attn_implementation: "sdpa"

  # Cap on sequence length.
  model_max_length: 2048

# Training data: a single dataset, referenced by name.
data:
  train:
    datasets:
      - dataset_name: "yahma/alpaca-cleaned" # 51,760 examples

# Trainer settings: a single epoch with the TRL SFT trainer and PEFT enabled.
training:
  trainer_type: "TRL_SFT"
  use_peft: True
  output_dir: "output/deepseek_r1_llama70b.qlora"

  # Run length and batching: 2 samples per device, no gradient accumulation.
  num_train_epochs: 1
  per_device_train_batch_size: 2
  gradient_accumulation_steps: 1

  # Optimizer and learning-rate schedule.
  optimizer: "adamw_torch_fused"
  learning_rate: 3.0e-04
  lr_scheduler_type: "cosine"
  warmup_steps: 100
  weight_decay: 0.01

  # Memory and runtime behaviour.
  enable_gradient_checkpointing: True
  gradient_checkpointing_kwargs:
    use_reentrant: False
  ddp_find_unused_parameters: False
  empty_device_cache_steps: 50
  compile: False

  # Data loading.
  dataloader_num_workers: "auto"
  dataloader_prefetch_factor: 32

  # Checkpointing, logging, and metrics.
  save_steps: 200
  logging_steps: 100
  log_model_summary: False
  include_performance_metrics: True
  # Needs a prior `wandb login`; set to False to train without WandB.
  enable_wandb: True

# PEFT settings: LoRA adapters on top of a 4-bit quantized base model.
peft:
  # LoRA adapter shape: rank 16, with alpha set to twice the rank.
  lora_r: 16
  lora_alpha: 32
  # Adapters are attached to each of the modules listed here.
  lora_target_modules:
    - "q_proj"
    - "k_proj"
    - "v_proj"
    - "o_proj"
    - "gate_proj"
    - "up_proj"
    - "down_proj"

  # 4-bit quantization of the base model (QLoRA).
  q_lora: True
  # https://github.com/pytorch/torchtune/blob/37337f71677da69f0967a9cde34b96ad7fec3cb6/torchtune/modules/peft/lora.py#L95
  bnb_4bit_quant_type: "nf4"
  # Must use a float type for quantized data storage. See:
  # https://huggingface.co/docs/bitsandbytes/main/en/fsdp_qlora#quantized-data-storage.
  bnb_4bit_quant_storage: "bfloat16"
  bnb_4bit_compute_dtype: "bfloat16"

# FSDP settings for this run.
fsdp:
  enable_fsdp: True

  # Auto-wrap policy and the layer class it is keyed on.
  auto_wrap_policy: "TRANSFORMER_BASED_WRAP"
  transformer_layer_cls: "LlamaDecoderLayer"

  forward_prefetch: True
