# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from collections.abc import Callable
import oumi.core.constants as constants
from oumi.core.collators.text_collator_with_padding import TextCollatorWithPadding
from oumi.core.collators.text_completions_collator_with_padding import (
TextCompletionsCollatorWithPadding,
)
from oumi.core.collators.vision_language_collator_with_padding import (
VisionLanguageCollatorWithPadding,
)
from oumi.core.collators.vision_language_sft_collator import VisionLanguageSftCollator
from oumi.core.configs import DatasetSplit, TrainingConfig
from oumi.core.configs.internal.supported_models import (
find_internal_model_config,
)
from oumi.core.configs.params.data_params import TrainTarget
from oumi.core.tokenizers.base_tokenizer import BaseTokenizer
from oumi.utils.logging import logger

_VERY_LARGE_INTEGER = int(1e30)
_SENTINEL_SYS = "<<__S__>>"
_SENTINEL_USER = "<<__U__>>"
_SENTINEL_ASST = "<<__A__>>"
_FIX_HINT = (
"Fix: provide response_template (and end_of_turn_template for "
"all_assistant_turns) in collator_kwargs."
)


def _detect_eot_template(
tokenizer: "BaseTokenizer",
after_text: str,
between_text: str,
) -> tuple[list[int], str]:
"""Detect end-of-turn token IDs and template string.
Compares token-ID prefixes of the text after the last assistant turn
(end-of-sequence) with the text between assistant turns (mid-conversation).
Primary: longest common token-ID prefix.
Fallback: first token of between_text (for models like GPT OSS
that use different mid-conversation vs end-of-sequence tokens).
Returns:
(eot_ids, end_of_turn_template)
"""
after_ids = tokenizer.encode(after_text, add_special_tokens=False)
between_ids = tokenizer.encode(between_text, add_special_tokens=False)
    # The longest common token-ID prefix of the two boundary texts is the
    # end-of-turn marker shared mid-conversation and at end-of-sequence.
    prefix_len = 0
    for a, b in zip(after_ids, between_ids):
        if a != b:
            break
        prefix_len += 1
    eot_ids = after_ids[:prefix_len]
    # Fallback: no shared prefix (e.g., GPT OSS uses distinct end-of-sequence
    # vs. mid-conversation tokens), so take the first mid-conversation token.
    if not eot_ids and between_ids:
        eot_ids = between_ids[:1]
eot_decoded = tokenizer.decode(eot_ids, skip_special_tokens=False)
assert isinstance(eot_decoded, str)
return eot_ids, eot_decoded
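

# Illustrative sketch of `_detect_eot_template` (comments only, not executed;
# the model name and rendered boundary strings below are assumptions about a
# Llama-3-style chat template, not guarantees):
#
#   from transformers import AutoTokenizer
#
#   tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
#   eot_ids, eot_str = _detect_eot_template(
#       tok,
#       after_text="<|eot_id|>",
#       between_text="<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n",
#   )
#   # Both texts start with the <|eot_id|> token, so the longest common
#   # token-ID prefix yields eot_str == "<|eot_id|>".

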
def _detect_response_template(
tokenizer: "BaseTokenizer",
header_text: str,
eot_ids: list[int],
) -> str:
"""Detect the assistant response header from the user-to-assistant boundary.
Strips the leading end-of-turn prefix (which belongs to the previous
turn, not the response header) and any ``<think>`` blocks injected
by reasoning-model chat templates (e.g. Qwen3).
Returns:
response_template string
"""
resp_ids = tokenizer.encode(header_text, add_special_tokens=False)
eot_len = len(eot_ids)
if eot_len > 0 and resp_ids[:eot_len] == eot_ids:
resp_ids = resp_ids[eot_len:]
resp_decoded = tokenizer.decode(resp_ids, skip_special_tokens=False)
assert isinstance(resp_decoded, str)
response_template = resp_decoded
if "<think>" in response_template:
idx = response_template.index("<think>")
stripped = response_template[:idx].rstrip()
if stripped:
logger.info(
"Stripped <think> block from auto-detected response_template: %r -> %r",
response_template,
stripped,
)
response_template = stripped
else:
raise ValueError(
f"Extracted response_template is only a <think> block.\n{_FIX_HINT}"
)
response_template = response_template.rstrip("\n")
return response_template
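

# Illustrative sketch of `_detect_response_template`, continuing the sketch
# above (comments only; the header string is a hypothetical Qwen3-style
# rendering, not a guaranteed output of any particular tokenizer):
#
#   header = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
#   eot_ids = tok.encode("<|im_end|>", add_special_tokens=False)
#   resp = _detect_response_template(tok, header_text=header, eot_ids=eot_ids)
#   # The leading "<|im_end|>" belongs to the previous turn and is stripped,
#   # the "<think>" block is dropped, and trailing newlines are removed,
#   # leaving roughly "\n<|im_start|>assistant".

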
def resolve_collator_templates(
tokenizer: "BaseTokenizer",
) -> tuple[str, str]:
"""Auto-detect response_template and end_of_turn_template.
Applies the chat template to a known test conversation, then finds
the assistant boundary strings in the rendered output.
Returns:
(response_template, end_of_turn_template)
Raises:
ValueError: If templates cannot be extracted.
"""
msgs_with_sys = [
{"role": "system", "content": _SENTINEL_SYS},
{"role": "user", "content": _SENTINEL_USER},
{"role": "assistant", "content": _SENTINEL_ASST},
{"role": "user", "content": _SENTINEL_USER},
{"role": "assistant", "content": _SENTINEL_ASST},
]
msgs_no_sys = msgs_with_sys[1:]
rendered = None
for msgs in (msgs_with_sys, msgs_no_sys):
try:
rendered = tokenizer.apply_chat_template(
msgs, tokenize=False, add_generation_prompt=False
)
break
except Exception:
continue
if rendered is None:
raise ValueError(
f"Tokenizer has no chat template or it failed to render.\n{_FIX_HINT}"
)
if not isinstance(rendered, str):
raise ValueError(
f"Chat template returned a non-string type ({type(rendered).__name__}).\n"
f"{_FIX_HINT}"
)
# Locate boundaries around the second turn pair
# to avoid system-prompt effects on the first turn.
try:
first_asst = rendered.index(_SENTINEL_ASST)
first_asst_end = first_asst + len(_SENTINEL_ASST)
second_user = rendered.index(_SENTINEL_USER, first_asst_end)
second_user_end = second_user + len(_SENTINEL_USER)
second_asst = rendered.index(_SENTINEL_ASST, second_user_end)
second_asst_end = second_asst + len(_SENTINEL_ASST)
    except ValueError:
        raise ValueError(
            "Could not locate assistant turn boundaries in the rendered "
            f"chat template.\n{_FIX_HINT}"
        ) from None
eot_ids, end_of_turn_template = _detect_eot_template(
tokenizer,
after_text=rendered[second_asst_end:],
between_text=rendered[first_asst_end:second_user],
)
response_template = _detect_response_template(
tokenizer,
header_text=rendered[second_user_end:second_asst],
eot_ids=eot_ids,
)
if not response_template.strip():
raise ValueError(f"Extracted response_template is empty.\n{_FIX_HINT}")
if not end_of_turn_template.strip():
raise ValueError(f"Extracted end_of_turn_template is empty.\n{_FIX_HINT}")
return response_template, end_of_turn_template
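

# Illustrative usage (a minimal sketch; the model name is an assumption, and
# the extracted strings depend entirely on the tokenizer's chat template):
#
#   from transformers import AutoTokenizer
#
#   tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
#   response_template, end_of_turn_template = resolve_collator_templates(tok)
#   # For ChatML-style templates this would typically yield something like
#   # response_template ~ "<|im_start|>assistant" and
#   # end_of_turn_template ~ "<|im_end|>".

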
def build_data_collator(
collator_name: str,
tokenizer: BaseTokenizer,
*,
max_length: int | None,
label_ignore_index: int | None = constants.LABEL_IGNORE_INDEX,
debug: bool = False,
**kwargs,
) -> Callable:
"""Builds a data collator based on the given collator name.
Args:
collator_name: The name of the collator to build.
Supported values are:
- "text_with_padding": Uses `TextCollatorWithPadding`.
- "text_completions_only_with_padding": Uses
`TextCompletionsCollatorWithPadding`. Supports optional
``end_of_turn_template`` for tool-aware span-based masking.
- "vision_language_with_padding": Uses `VisionLanguageCollatorWithPadding`.
- "vision_language_sft": Uses `VisionLanguageSftCollator`.
tokenizer: A tokenizer.
max_length: An optional maximum sequence length.
label_ignore_index: If set, then label values of tokens that shouldn't
contribute to the loss computation will be replaced by this special value.
For example, this can be `PAD`, or image tokens.
PyTorch convention is to use -100 as the `ignore_index` label. Refer to
the `ignore_index` parameter of `torch.nn.CrossEntropyLoss()`
for more details.
debug: If True, logs a single example for debugging purposes.
**kwargs: Additional keyword arguments to pass to the collator constructor.
Returns:
Callable: The data collator function or class.
Raises:
ValueError: If an unsupported collator name is provided.
"""
if not collator_name:
raise ValueError("Empty data collator name.")
enable_truncation: bool = False
if max_length is not None and max_length > 0:
enable_truncation = True
if (
tokenizer.model_max_length is not None
and tokenizer.model_max_length < _VERY_LARGE_INTEGER
and max_length != tokenizer.model_max_length
):
            comparison = (
                "greater than"
                if max_length > tokenizer.model_max_length
                else "less than"
            )
            logger.warning(
                f"Data collator's maximum length ({max_length}) is {comparison} "
                f"the tokenizer's model maximum length "
                f"({tokenizer.model_max_length})."
            )
if collator_name == "text_with_padding":
return TextCollatorWithPadding(
tokenizer=tokenizer,
max_length=max_length,
truncation=enable_truncation,
label_ignore_index=label_ignore_index,
debug=debug,
**kwargs,
)
elif collator_name == "vision_language_with_padding":
return VisionLanguageCollatorWithPadding(
tokenizer=tokenizer,
max_length=max_length,
truncation=enable_truncation,
label_ignore_index=label_ignore_index,
debug=debug,
**kwargs,
)
elif collator_name == "vision_language_sft":
processor_name = kwargs.pop("processor_name", None)
if not processor_name:
raise ValueError(f"Empty processor_name for '{collator_name}'")
processor_kwargs = kwargs.pop("processor_kwargs", None)
return VisionLanguageSftCollator(
tokenizer=tokenizer,
processor_name=processor_name,
processor_kwargs=processor_kwargs,
max_length=max_length,
truncation=enable_truncation,
label_ignore_index=label_ignore_index,
**kwargs,
)
elif collator_name == "text_completions_only_with_padding":
if not kwargs.get("response_template"):
raise ValueError(
"'text_completions_only_with_padding' requires a response_template.\n"
"Fix: set train_target in your data config (auto-resolves templates "
"from the tokenizer), or provide response_template in collator_kwargs."
)
if not kwargs.get("train_target"):
raise ValueError(
"'text_completions_only_with_padding' requires a train_target.\n"
"Fix: set train_target in your data config, or provide "
"train_target in collator_kwargs."
)
ignore_index = kwargs.pop(
"ignore_index",
label_ignore_index if label_ignore_index is not None else -100,
)
return TextCompletionsCollatorWithPadding(
tokenizer=tokenizer,
debug=debug,
ignore_index=ignore_index,
**kwargs,
)
raise ValueError(f"Unknown data collator name: '{collator_name}'")
def build_collator_from_config(
config: TrainingConfig, tokenizer: BaseTokenizer | None, debug: bool = False
) -> Callable | None:
"""Creates data collator if specified in config."""
train_split = config.data.get_split(DatasetSplit.TRAIN)
if not train_split.collator_name:
return None
collator_name: str = train_split.collator_name
if tokenizer is None:
raise ValueError(
"Tokenizer must be provided if collator is specified! "
f"collator: '{collator_name}'"
)
model_config = find_internal_model_config(config.model)
label_ignore_index: int | None = (
config.training.label_ignore_index
if config.training.label_ignore_index is not None
else (
model_config.label_ignore_index
if model_config is not None
else constants.LABEL_IGNORE_INDEX
)
)
collator_kwargs = {}
if (
collator_name in ("vision_language_with_padding", "vision_language_sft")
and model_config is not None
and model_config.visual_config is not None
):
collator_kwargs["allow_multi_image_inputs"] = (
model_config.visual_config.supports_multiple_images
)
if collator_name == "vision_language_with_padding":
collator_kwargs["main_image_feature"] = (
model_config.visual_config.main_image_feature
)
if collator_name == "vision_language_sft":
processor_name = collator_kwargs.get(
"processor_name", config.model.tokenizer_name or config.model.model_name
)
if not processor_name:
raise ValueError(f"Processor name must be provided for '{collator_name}'!")
collator_kwargs["processor_name"] = processor_name
collator_kwargs["processor_kwargs"] = config.model.processor_kwargs
collator_kwargs["trust_remote_code"] = collator_kwargs.get(
"trust_remote_code", config.model.trust_remote_code
)
# --- Resolve train_target and templates ---
config_collator_kwargs = train_split.collator_kwargs or {}
if collator_name == "text_completions_only_with_padding":
if train_split.train_target is not None:
# Path 1: train_target is set, auto-detect templates from
# the tokenizer's chat template. Falls back to user-provided
# response_template in collator_kwargs if auto-detection fails.
collator_kwargs["train_target"] = train_split.train_target.value
try:
response_template, end_of_turn_template = resolve_collator_templates(
tokenizer
)
collator_kwargs["response_template"] = response_template
if train_split.train_target == TrainTarget.ALL_ASSISTANT_TURNS:
collator_kwargs["end_of_turn_template"] = end_of_turn_template
except ValueError:
if config_collator_kwargs.get("response_template") is None:
raise
if (
train_split.train_target == TrainTarget.ALL_ASSISTANT_TURNS
and "end_of_turn_template" not in collator_kwargs
and config_collator_kwargs.get("end_of_turn_template") is None
):
raise ValueError(
"train_target='all_assistant_turns' requires end_of_turn_template, "
"but auto-detection failed.\n"
"Fix: provide end_of_turn_template in collator_kwargs."
)
elif config_collator_kwargs.get("response_template") is not None:
# Path 2: train_target not set, templates provided manually
# via collator_kwargs. Infer train_target from which templates
# are present.
has_eot = config_collator_kwargs.get("end_of_turn_template") is not None
has_inst = config_collator_kwargs.get("instruction_template") is not None
if has_eot:
collator_kwargs["train_target"] = "all_assistant_turns"
elif has_inst:
            warnings.warn(
                "Instruction-based masking is deprecated.\n"
                "Use train_target='all_assistant_turns' "
                "or train_target='final_assistant_turn' instead.",
                DeprecationWarning,
                stacklevel=2,
            )
collator_kwargs["train_target"] = "_legacy_instruction_response"
else:
collator_kwargs["train_target"] = "final_assistant_turn"
else:
raise ValueError(
"'text_completions_only_with_padding' collator requires"
" configuration.\n"
"Fix: set train_target in your data config, "
"or provide response_template in collator_kwargs."
)
# User-provided collator_kwargs override auto-resolved values
collator_kwargs.update(config_collator_kwargs)
return build_data_collator(
collator_name=collator_name,
tokenizer=tokenizer,
max_length=config.model.model_max_length,
label_ignore_index=label_ignore_index,
debug=debug,
**collator_kwargs,
)
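

# Illustrative usage (a minimal sketch; assumes a `TrainingConfig` whose train
# split sets `collator_name`, e.g. "text_completions_only_with_padding" with
# `train_target` set so templates are auto-resolved from the tokenizer):
#
#   collator = build_collator_from_config(config, tokenizer)
#   # Returns None when the train split specifies no collator_name.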