# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Schemas for dataset validation.
Currently only Qianfan schemas are provided.
"""
import functools
from abc import ABC, abstractmethod
from typing import Any, Callable
from qianfan.dataset.table import Table
from qianfan.utils import log_error, log_info
def _data_format_converter(func: Callable) -> Callable:
    """
    Decorator that runs a ``validate``-style method on an unpacked table.

    If ``table.is_dataset_packed()`` is true, the table is unpacked before
    ``func`` runs and packed again afterwards, so the caller gets the table
    back in the layout it passed in. Tables that are not packed are handed
    to ``func`` untouched.

    Args:
        func (Callable): function called as ``func(schema, table, *args,
            **kwargs)`` and expected to return a bool

    Returns:
        Callable: the wrapped function
    """

    @functools.wraps(func)
    def inner(schema: Schema, table: Table, *args: Any, **kwargs: Any) -> bool:
        if not table.is_dataset_packed():
            return func(schema, table, *args, **kwargs)

        log_info("unpack dataset before validating")
        table.unpack()
        try:
            return func(schema, table, *args, **kwargs)
        finally:
            # Restore the packed layout even if validation raises, so the
            # caller's table is never left in an unpacked state.
            log_info("pack dataset after validation")
            table.pack()

    return inner
class Schema(ABC):
    """Abstract interface for schemas that validate a ``Table``."""

    @abstractmethod
    def validate(self, table: Table) -> bool:
        """
        Check whether ``table`` conforms to this schema.

        Each concrete schema decides what is inspected (for example required
        columns, null values, or the shape of individual records).

        Args:
            table (Table): the table to check

        Returns:
            bool: True when the table is valid, False otherwise
        """
class QianfanSchema(Schema):
    """
    Base class for Qianfan dataset schemas.

    Holds an ``is_annotated`` flag that subclasses may update while
    validating, to report whether the dataset carries annotations.
    """

    def __init__(self) -> None:
        """Create a schema whose ``is_annotated`` flag starts as False."""
        # Used by Qianfan as a return value: tells whether the data is annotated.
        self.is_annotated: bool = False

    def validate(self, table: Table) -> bool:
        """
        Default implementation: ignores ``table`` and reports ``is_annotated``.

        Args:
            table (Table): unused by this base implementation

        Returns:
            bool: the current value of ``self.is_annotated``
        """
        annotated: bool = self.is_annotated
        return annotated
# Non-sorted conversation
class QianfanNonSortedConversation(QianfanSchema):
    """validator for non-sorted, conversational dataset"""

    @_data_format_converter
    def validate(self, table: Table) -> bool:
        """
        Validate a table against the non-sorted conversation schema.

        The table must have at least one row and a ``prompt`` column with no
        null values. A ``response`` column is optional; when present it must
        have no null values, and every row must be a list holding exactly one
        list that holds exactly one non-empty string, e.g. ``[["answer"]]``.

        On success ``self.is_annotated`` is set to whether a ``response``
        column exists. Problems are reported through ``log_error``. Packed
        tables are unpacked for the duration of the check by
        ``_data_format_converter``.

        Args:
            table (Table): table need to be validated

        Returns:
            bool: whether table is valid
        """
        if table.row_number() == 0:
            log_error("no data in table")
            return False

        col_names = table.col_names()

        # Validation rules for local single-turn conversations imported
        # into Qianfan.
        if "prompt" not in col_names:
            log_error("no prompt column in dataset column")
            return False
        if table.inner_table.column("prompt").null_count:
            log_error("prompt column has empty data in dataset column")
            return False

        if "response" in col_names:
            if table.inner_table.column("response").null_count:
                log_error("response column has empty data in dataset column")
                return False
            response_list = table.col_list("response")["response"]
            for index, response_record in enumerate(response_list):
                # Expected shape: [[<non-empty str>]]
                if not (
                    isinstance(response_record, list)
                    and len(response_record) == 1
                    and isinstance(response_record[0], list)
                    and len(response_record[0]) == 1
                    and isinstance(response_record[0][0], str)
                    and response_record[0][0]
                ):
                    log_error(
                        f"response illegal in dataset row {index}. response data:"
                        f" {response_record}\n"
                        "for accurate dataset format, please check "
                        "https://cloud.baidu.com/doc/WENXINWORKSHOP/s/yliu6bqzw"
                        "#%E6%9C%89%E6%A0%87%E6%B3%A8%E4%BF%A1%E6%81%AF"
                        "-%E6%9C%AC%E5%9C%B0%E5%AF%BC%E5%85%A5"
                    )
                    return False

        self.is_annotated = "response" in col_names
        return True
# Sorted conversation
class QianfanSortedConversation(QianfanSchema):
    """validator for sorted, conversational dataset"""

    @_data_format_converter
    def validate(self, table: Table) -> bool:
        """
        Validate a table against the sorted conversation schema.

        The table must have at least one row and a ``prompt`` column with no
        null values. A ``response`` column is optional; when present it must
        have no null values, and every row must be a non-empty list whose
        items are each a list holding exactly one non-empty string, e.g.
        ``[["answer 1"], ["answer 2"]]``.

        On success ``self.is_annotated`` is set to whether a ``response``
        column exists. Problems are reported through ``log_error``. Packed
        tables are unpacked for the duration of the check by
        ``_data_format_converter``.

        Args:
            table (Table): table need to be validated

        Returns:
            bool: whether table is valid
        """
        if table.row_number() == 0:
            log_error("no data in table")
            return False

        col_names = table.col_names()

        # Validation rules for local single-turn conversations with sorted
        # responses imported into Qianfan.
        if "prompt" not in col_names:
            log_error("no prompt column in dataset column")
            return False
        if table.inner_table.column("prompt").null_count:
            log_error("prompt column has empty data in dataset column")
            return False

        if "response" in col_names:
            if table.inner_table.column("response").null_count:
                log_error("response column has empty data in dataset column")
                return False
            response_list = table.col_list("response")["response"]
            for index, response_record in enumerate(response_list):
                # Each row must be a non-empty list of candidate responses.
                if not (isinstance(response_record, list) and len(response_record) > 0):
                    log_error(
                        f"response records illegal in dataset row {index}. response"
                        f" data: {response_record}"
                    )
                    return False
                # Each candidate must be [<non-empty str>].
                for single_response_record in response_record:
                    if not (
                        isinstance(single_response_record, list)
                        and len(single_response_record) == 1
                        and isinstance(single_response_record[0], str)
                        and single_response_record[0]
                    ):
                        log_error(
                            f"response illegal in dataset row {index}. response data:"
                            f" {response_record}\n"
                            "for accurate dataset format, please check "
                            "https://cloud.baidu.com/doc/WENXINWORKSHOP/s/yliu6bqzw"
                            "#%E6%9C%89%E6%A0%87%E6%B3%A8%E4%BF%A1%E6%81%AF"
                            "-%E6%9C%AC%E5%9C%B0%E5%AF%BC%E5%85%A5"
                        )
                        return False

        self.is_annotated = "response" in col_names
        return True
# Generic text
class QianfanGenericText(QianfanSchema):
    """Validator for generic (free-form) text datasets."""

    def validate(self, table: Table) -> bool:
        """
        Check that ``table`` is a non-empty, single-column text table.

        Rejected when the table has no rows, when it does not have exactly
        one column, or when that column contains null values. Failures are
        reported through ``log_error``.

        Args:
            table (Table): the table to check

        Returns:
            bool: True when the table is valid, False otherwise
        """
        if table.row_number() == 0:
            log_error("no data in table")
            return False

        columns = table.col_names()
        if len(columns) != 1:
            log_error(f"dataset has more than 1 column: {columns}")
            return False

        text_column = table.inner_table.column(columns[0])
        if text_column.null_count:
            log_error("empty row in dataset")
            return False

        return True
# Query set
class QianfanQuerySet(QianfanSchema):
    """Validator for query-set datasets (prompts only)."""

    @_data_format_converter
    def validate(self, table: Table) -> bool:
        """
        Check that ``table`` has rows and a fully populated ``prompt`` column.

        Failures are reported through ``log_error``. Packed tables are
        unpacked for the duration of the check by ``_data_format_converter``.

        Args:
            table (Table): the table to check

        Returns:
            bool: True when the table is valid, False otherwise
        """
        if table.row_number() == 0:
            log_error("no data in table")
            return False

        if "prompt" not in table.col_names():
            log_error("no prompt column in dataset column")
            return False

        prompt_nulls = table.inner_table.column("prompt").null_count
        if prompt_nulls:
            log_error("prompt column has empty data in dataset column")
            return False

        return True
# Text to image
class QianfanText2Image(QianfanSchema):
    """Validator for text-to-image datasets."""

    def validate(self, table: Table) -> bool:
        """
        Validate ``table`` against the text-to-image schema.

        No validation is implemented for this schema yet: every table is
        rejected, whatever its content.

        Args:
            table (Table): the table to check (currently not inspected)

        Returns:
            bool: always False
        """
        # NOTE(review): placeholder implementation — unconditionally rejects.
        is_valid = False
        return is_valid