Source code for oumi.analyze.analyzers.quality

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data quality analyzer implementation."""

import re

from pydantic import BaseModel, Field

from oumi.analyze.base import ConversationAnalyzer
from oumi.core.registry import register_sample_analyzer
from oumi.core.types.conversation import Conversation, Message, Role

__all__ = ["DataQualityMetrics", "DataQualityAnalyzer"]

# Invalid serialization patterns: (regex, display_name)
_INVALID_VALUE_PATTERNS = [
    (re.compile(r"\bNaN\b"), "NaN"),
    (re.compile(r"\bnan\b"), "nan"),
    (re.compile(r"\bnull\b"), "null"),
    (re.compile(r"\bNone\b"), "None"),
    (re.compile(r"\bundefined\b"), "undefined"),
]


[docs] class DataQualityMetrics(BaseModel): """Result model for data quality checks on a conversation. Example: >>> result = DataQualityMetrics( ... has_non_alternating_turns=False, ... has_no_user_message=False, ... has_system_message_not_at_start=False, ... has_empty_turns=False, ... empty_turn_count=0, ... has_invalid_values=False, ... invalid_value_patterns=[], ... ) >>> print(result.has_non_alternating_turns) False """ has_non_alternating_turns: bool = Field( description=( "True if non-system messages do NOT strictly alternate between " "user and assistant roles (i.e. consecutive same-role messages exist)" ) ) has_no_user_message: bool = Field( description=( "True if the conversation contains no user message " "(including empty conversations)" ) ) has_system_message_not_at_start: bool = Field( description=( "True if any system message appears after position 0 in the conversation" ) ) has_empty_turns: bool = Field( description="True if any message has empty or whitespace-only content" ) empty_turn_count: int = Field( description="Number of messages with empty or whitespace-only content" ) has_invalid_values: bool = Field( description=( "True if any message contains values serialized as strings " "(e.g. 'NaN', 'null', 'None', 'undefined')" ) ) invalid_value_patterns: list[str] = Field( description="List of invalid value patterns found across all messages" )
[docs] @register_sample_analyzer("quality") class DataQualityAnalyzer(ConversationAnalyzer[DataQualityMetrics]): """Analyzer for basic data quality checks on conversations. Checks for five common data quality issues without requiring an LLM: - Non-alternating user/assistant message patterns - Missing user messages - System messages not at the start of the conversation - Empty or whitespace-only turns - Values serialized as strings (NaN, null, None, undefined) Example: >>> from oumi.analyze.analyzers.quality import DataQualityAnalyzer >>> from oumi.core.types.conversation import Conversation, Message, Role >>> >>> analyzer = DataQualityAnalyzer() >>> conversation = Conversation(messages=[ ... Message(role=Role.USER, content="Hello"), ... Message(role=Role.ASSISTANT, content="Hi there!"), ... ]) >>> result = analyzer.analyze(conversation) >>> print(result.has_non_alternating_turns) False """ _result_model = DataQualityMetrics
[docs] @classmethod def get_config_schema(cls) -> dict: """Get JSON schema for DataQualityAnalyzer configuration.""" return {"properties": {}}
[docs] def analyze(self, conversation: Conversation) -> DataQualityMetrics: """Analyze data quality for a conversation. Args: conversation: The conversation to analyze. Returns: DataQualityMetrics with the quality check results. """ messages = conversation.messages # 1. Non-alternating turns check (ignoring system messages) roles = [m.role.value for m in messages] non_system = [r for r in roles if r != "system"] has_non_alternating = False for i in range(1, len(non_system)): if non_system[i] == non_system[i - 1]: has_non_alternating = True break # 2. No user message check has_no_user = not any(m.role == Role.USER for m in messages) # 3. System message not at position 0 check has_system_not_at_start = any(m.role == Role.SYSTEM for m in messages[1:]) # 4. Empty turns check def _text(m: Message) -> str: return DataQualityAnalyzer.get_text_content(m) empty_count = sum(1 for m in messages if not _text(m).strip()) # 5. Invalid serialized values check patterns_found: set[str] = set() for message in messages: content = _text(message) for pattern, name in _INVALID_VALUE_PATTERNS: if pattern.search(content): patterns_found.add(name) return DataQualityMetrics( has_non_alternating_turns=has_non_alternating, has_no_user_message=has_no_user, has_system_message_not_at_start=has_system_not_at_start, has_empty_turns=empty_count > 0, empty_turn_count=empty_count, has_invalid_values=len(patterns_found) > 0, invalid_value_patterns=sorted(patterns_found), )