Source code for qianfan.dataset.schema

# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
schema for validation
currently qianfan schema only
"""
import functools
from abc import ABC, abstractmethod
from typing import Any, Callable

from qianfan.dataset.consts import (
    LLMOutputColumnName,
    NewInputChatColumnName,
    NewInputPromptColumnName,
    OldReferenceColumnName,
)
from qianfan.dataset.table import Table
from qianfan.utils import log_error, log_info


def _data_format_converter(func: Callable) -> Callable:
    @functools.wraps(func)
    def inner(schema: Schema, table: Table, *args: Any, **kwargs: Any) -> bool:
        if table.is_dataset_packed():
            log_info("unpack dataset before validating")
            table.unpack()
            result = func(schema, table, *args, **kwargs)
            log_info("pack dataset after validation")
            table.pack()
            return result

        return func(schema, table, *args, **kwargs)

    return inner


[docs]class Schema(ABC):
[docs] @abstractmethod def validate(self, table: Table) -> bool: """ validate a dataset.Table object currently check field and type only, not content in table Args: table (Table): table need to be validated Returns: bool:whether table is valid """
[docs]class QianfanSchema(Schema): def __init__(self) -> None: """ initialize a new Schema instance """ # 千帆使用,用于作为返回值,表示是否是带标注的数据 self.is_annotated: bool = False
[docs] def validate(self, table: Table) -> bool: return self.is_annotated
# 无排序对话
[docs]class QianfanNonSortedConversation(QianfanSchema): """validator for non-sorted, conversational dataset"""
[docs] @_data_format_converter def validate(self, table: Table) -> bool: """ validate a table Args: table (Table): table need to be validated Returns: bool:whether table is valid """ if table.row_number() == 0: log_error("no data in table") return False col_names = table.col_names() # 本地单轮对话对接千帆的校验规则 if "prompt" not in col_names: log_error("no prompt column in dataset column") return False if table.inner_table.column("prompt").null_count: log_error("prompt column has empty data in dataset column") return False if "response" in col_names: if table.inner_table.column("response").null_count: log_error("response column has empty data in dataset column") return False response_list = table.col_list("response")["response"] for index in range(len(response_list)): response_record = response_list[index] if not ( isinstance(response_record, list) and len(response_record) == 1 and isinstance(response_record[0], list) and len(response_record[0]) == 1 and isinstance(response_record[0][0], str) and response_record[0][0] ): log_error( f"response illegal in dataset row {index}. response data:" f" {response_record}\n" "for accurate dataset format, please check" "https://cloud.baidu.com/doc/WENXINWORKSHOP/s/yliu6bqzw" "#%E6%9C%89%E6%A0%87%E6%B3%A8%E4%BF%A1%E6%81%AF" "-%E6%9C%AC%E5%9C%B0%E5%AF%BC%E5%85%A5" ) return False self.is_annotated = "response" in col_names return True
# 有排序对话
[docs]class QianfanSortedConversation(QianfanSchema): """validator for sorted, conversational dataset"""
[docs] @_data_format_converter def validate(self, table: Table) -> bool: """ validate a table Args: table (Table): table need to be validated Returns: bool:whether table is valid """ if table.row_number() == 0: log_error("no data in table") return False col_names = table.col_names() # 本地单轮对话带排序对接千帆的校验规则 if "prompt" not in col_names: log_error("no prompt column in dataset column") return False if table.inner_table.column("prompt").null_count: log_error("prompt column has empty data in dataset column") return False if "response" in col_names: if table.inner_table.column("response").null_count: log_error("response column has empty data in dataset column") return False response_list = table.col_list("response")["response"] for index in range(len(response_list)): response_record = response_list[index] if not (isinstance(response_record, list) and len(response_record) > 0): log_error( f"response records illegal in dataset row {index}. response" f" data: {response_record}" ) return False for single_response_record in response_record: if not ( isinstance(single_response_record, list) and len(single_response_record) == 1 and isinstance(single_response_record[0], str) and single_response_record[0] ): log_error( f"response illegal in dataset row {index}. response data:" f" {response_record}" "for accurate dataset format, please check" "https://cloud.baidu.com/doc/WENXINWORKSHOP/s/yliu6bqzw" "#%E6%9C%89%E6%A0%87%E6%B3%A8%E4%BF%A1%E6%81%AF" "-%E6%9C%AC%E5%9C%B0%E5%AF%BC%E5%85%A5" ) return False self.is_annotated = "response" in col_names return True
# 泛文本对话
[docs]class QianfanGenericText(QianfanSchema): """validator for generic text dataset"""
[docs] def validate(self, table: Table) -> bool: """ validate a table Args: table (Table): table need to be validated Returns: bool:whether table is valid """ if table.row_number() == 0: log_error("no data in table") return False col_names = table.col_names() if len(col_names) != 1: log_error(f"dataset has more than 1 column: {col_names}") return False if table.inner_table.column(col_names[0]).null_count: log_error("empty row in dataset") return False return True
# 问答集
[docs]class QianfanQuerySet(QianfanSchema): """validator for query set dataset"""
[docs] @_data_format_converter def validate(self, table: Table) -> bool: """ validate a table Args: table (Table): table need to be validated Returns: bool:whether table is valid """ if table.row_number() == 0: log_error("no data in table") return False col_names = table.col_names() if "prompt" not in col_names: log_error("no prompt column in dataset column") return False if table.inner_table.column("prompt").null_count: log_error("prompt column has empty data in dataset column") return False return True
# 文生图
[docs]class QianfanText2Image(QianfanSchema): """validator for text to image dataset"""
[docs] def validate(self, table: Table) -> bool: """ validate a table Args: table (Table): table need to be validated Returns: bool:whether table is valid """ return False
[docs]class EvaluationSchema(Schema): """validator for evaluation used"""
[docs] def validate(self, table: Table) -> bool: """ validate a table Args: table (Table): table need to be validated Returns: bool:whether table is valid """ if len(table) == 0: log_error("table is empty") return False col_names = table.col_names() for column in [OldReferenceColumnName, LLMOutputColumnName]: if column not in col_names: log_error(f"{column} not in dataset columns") return False if ( NewInputPromptColumnName in col_names and NewInputChatColumnName in col_names ): log_error( f"can't have both {NewInputChatColumnName} and" f" {NewInputPromptColumnName} simultaneously" ) return False if NewInputPromptColumnName in col_names: elem_type = table[0][NewInputPromptColumnName] if not isinstance(elem_type, str): log_error( f"element in column {NewInputPromptColumnName} isn't str, rather" f" {type(elem_type)}" ) return False return True if NewInputChatColumnName in col_names: elem_type = table[0][NewInputChatColumnName] if not isinstance(elem_type, str): log_error( f"element in column {NewInputChatColumnName} isn't str, rather" f" {type(elem_type)}" ) return False return True log_error( f"no neither {NewInputChatColumnName} or {NewInputPromptColumnName} found" ) return False