# Full fine-tune (FFT) config for Llama 3.2 11B Vision Instruct, sharded with FSDP.
# Note: the `vision_model` layers are frozen (see `model.freeze_layers`).
#
# Requirements:
#   - Log into WandB (`wandb login`) or set `training.enable_wandb: False`
#   - Log into HF: `hf auth login`
#   - Request access to Llama 3.2: https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct
#
# Usage:
#   oumi train -c configs/recipes/vision/llama3_2_vision/sft/11b_full/train.yaml
#
# See Also:
#   - Documentation: https://oumi.ai/docs/en/latest/user_guides/train/train.html
#   - Config class: oumi.core.configs.TrainingConfig
#   - Config source: https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/training_config.py
#   - Other training configs: configs/**/*train.yaml

# Base checkpoint plus load-time options.
model:
  model_name: "meta-llama/Llama-3.2-11B-Vision-Instruct"
  chat_template: "llama3-instruct"

  # Numerics / attention backend.
  torch_dtype_str: "bfloat16"
  attn_implementation: "sdpa"

  model_max_length: 1024

  # Layers listed here are excluded from training (per the key name; only the
  # vision tower is listed, so the remaining weights stay trainable).
  freeze_layers:
    - "vision_model"

# Training data mixture. Exactly one dataset is active; alternatives are kept
# below as commented-out entries.
data:
  train:
    use_torchdata: True
    collator_name: "vision_language_with_padding"
    datasets:
      # NOTE(review): training reads the "validation" split of this dataset —
      # presumably the only split it publishes; confirm before swapping datasets.
      - dataset_name: "merve/vqav2-small"
        split: "validation"
        seed: 42
        shuffle: True
        transform_num_workers: "auto"
        dataset_kwargs:
          # Same checkpoint id as `model.model_name`.
          processor_name: "meta-llama/Llama-3.2-11B-Vision-Instruct"
          return_tensors: True
          # limit: 4096 # Uncomment to limit dataset size!

      # Alternative: LLaVA instruct mix.
      # - dataset_name: "HuggingFaceH4/llava-instruct-mix-vsft"
      #   split: "train"
      #   shuffle: True
      #   seed: 42
      #   transform_num_workers: "auto"
      #   dataset_kwargs:
      #     processor_name: "meta-llama/Llama-3.2-11B-Vision-Instruct"
      #     return_tensors: True

      # Alternative: local JSONL file.
      # - dataset_name: vision_language_jsonl
      #   dataset_path: "training.jsonl"  # See the notebook for an example of how to generate this file
      #   dataset_kwargs:
      #     data_column: "messages"
      #     processor_name: "meta-llama/Llama-3.2-11B-Vision-Instruct"

# Trainer settings: run length, batching, optimizer, and logging/checkpointing.
training:
  output_dir: "output/vlm_finetuned"
  trainer_type: "TRL_SFT"

  # --- Run length and batching ---
  per_device_train_batch_size: 2
  gradient_accumulation_steps: 4
  max_steps: 20

  # --- Gradient checkpointing ---
  # TODO: OPE-875 - Re-enable. Currently broken at `transformers==4.48.2`.
  # GitHub issue: https://github.com/huggingface/transformers/issues/36040.
  enable_gradient_checkpointing: False
  # Kept alongside the disabled flag so the setting is already in place when
  # gradient checkpointing is re-enabled.
  gradient_checkpointing_kwargs:
    # Reentrant docs: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
    use_reentrant: False

  # --- Distributed / memory ---
  ddp_find_unused_parameters: False
  empty_device_cache_steps: 1
  compile: False

  # --- Optimizer and schedule ---
  optimizer: "adamw_torch_fused"
  # Written with an explicit mantissa dot: YAML 1.1 loaders (e.g. plain PyYAML)
  # resolve a bare `2e-5` as the string "2e-5" instead of a float.
  learning_rate: 2.0e-5
  warmup_ratio: 0.03
  weight_decay: 0.0
  lr_scheduler_type: "cosine"

  # --- Logging and checkpoints ---
  logging_steps: 5
  # No intermediate checkpoints are requested (0); only the final model is saved.
  save_steps: 0
  save_final_model: True
  include_performance_metrics: True
  log_model_summary: False
  enable_wandb: True

  # --- Data loading ---
  dataloader_num_workers: "auto"
  dataloader_prefetch_factor: 16

# Fully Sharded Data Parallel (FSDP) settings.
fsdp:
  enable_fsdp: True
  sharding_strategy: "HYBRID_SHARD"

  # Wrapping: `transformer_layer_cls` is a single comma-separated string of
  # fully-qualified Mllama layer classes (text self-attention, text
  # cross-attention, vision encoder). Keep it on one line with no spaces:
  # a folded scalar would insert spaces into the value.
  auto_wrap_policy: "TRANSFORMER_BASED_WRAP"
  transformer_layer_cls: "transformers.models.mllama.modeling_mllama.MllamaSelfAttentionDecoderLayer,transformers.models.mllama.modeling_mllama.MllamaCrossAttentionDecoderLayer,transformers.models.mllama.modeling_mllama.MllamaVisionEncoderLayer"

  forward_prefetch: True
