# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Schemas for dataset validation.
Currently only Qianfan schemas are provided.
"""
from abc import ABC, abstractmethod
from qianfan.dataset.table import Table
from qianfan.utils import log_error
class Schema(ABC):
    """Abstract base class of every schema able to validate a Table."""

    @abstractmethod
    def validate(self, table: Table) -> bool:
        """
        validate a dataset.Table object

        currently check field and type only, not content in table

        Args:
            table (Table): table need to be validated

        Returns:
            bool: whether table is valid
        """
class QianfanSchema(Schema):
    """Base class of the Qianfan schemas; carries the `is_annotated` flag."""

    def __init__(self) -> None:
        """
        initialize a new Schema instance
        """
        # Used by Qianfan as a result value: tells whether the data is
        # annotated.
        self.is_annotated: bool = False

    def validate(self, table: Table) -> bool:
        """
        return the current `is_annotated` flag without inspecting the table

        Args:
            table (Table): table need to be validated (unused here)

        Returns:
            bool: current value of `is_annotated`
        """
        return self.is_annotated
# Non-sorted conversation
class QianfanNonSortedConversation(QianfanSchema):
    """validator for non-sorted, conversational dataset"""

    def validate(self, table: Table) -> bool:
        """
        validate a table

        The table passes when it has at least one row, a `prompt` column whose
        `null_count` is zero and, if a `response` column exists, a `response`
        column whose `null_count` is zero and whose every entry is a list
        holding exactly one list holding exactly one non-empty string,
        e.g. `[["answer"]]`.

        On success `is_annotated` is set to whether a `response` column
        exists; on failure it is False and the reason is reported
        through `log_error`.

        Args:
            table (Table): table need to be validated

        Returns:
            bool: whether table is valid
        """
        # Reset first, so a failed validation never exposes a flag left over
        # from an earlier successful call on the same instance.
        self.is_annotated = False

        if table.get_row_count() == 0:
            log_error("no data in table")
            return False

        col_names = table.col_names()

        # Validation rules for local single-turn conversation data that is to
        # be used with Qianfan.
        if "prompt" not in col_names:
            log_error("no prompt column in dataset column")
            return False

        if table.inner_table.column("prompt").null_count:
            log_error("prompt column has empty data in dataset column")
            return False

        has_response = "response" in col_names
        if has_response:
            if table.inner_table.column("response").null_count:
                log_error("response column has empty data in dataset column")
                return False

            response_list = table.col_list("response")["response"]
            for index, response_record in enumerate(response_list):
                # Expected shape: [["<non-empty text>"]]
                if not (
                    isinstance(response_record, list)
                    and len(response_record) == 1
                    and isinstance(response_record[0], list)
                    and len(response_record[0]) == 1
                    and isinstance(response_record[0][0], str)
                    and response_record[0][0]
                ):
                    log_error(
                        f"response illegal in dataset row {index}. response data:"
                        f" {response_record}"
                    )
                    return False

        self.is_annotated = has_response
        return True
# Sorted conversation
class QianfanSortedConversation(QianfanSchema):
    """validator for sorted, conversational dataset"""

    def validate(self, table: Table) -> bool:
        """
        validate a table

        The table passes when it has at least one row, a `prompt` column whose
        `null_count` is zero and, if a `response` column exists, a `response`
        column whose `null_count` is zero and whose every entry is a non-empty
        list of one-element lists, each holding a non-empty string,
        e.g. `[["best"], ["second"]]`.

        On success `is_annotated` is set to whether a `response` column
        exists; on failure it is False and the reason is reported
        through `log_error`.

        Args:
            table (Table): table need to be validated

        Returns:
            bool: whether table is valid
        """
        # Reset first, so a failed validation never exposes a flag left over
        # from an earlier successful call on the same instance.
        self.is_annotated = False

        if table.get_row_count() == 0:
            log_error("no data in table")
            return False

        col_names = table.col_names()

        # Validation rules for local single-turn conversation data with
        # sorting that is to be used with Qianfan.
        if "prompt" not in col_names:
            log_error("no prompt column in dataset column")
            return False

        if table.inner_table.column("prompt").null_count:
            log_error("prompt column has empty data in dataset column")
            return False

        has_response = "response" in col_names
        if has_response:
            if table.inner_table.column("response").null_count:
                log_error("response column has empty data in dataset column")
                return False

            response_list = table.col_list("response")["response"]
            for index, response_record in enumerate(response_list):
                # The row itself must be a non-empty list of candidates.
                if not (isinstance(response_record, list) and len(response_record) > 0):
                    log_error(
                        f"response records illegal in dataset row {index}. response"
                        f" data: {response_record}"
                    )
                    return False

                # Every candidate must look like ["<non-empty text>"].
                for single_response_record in response_record:
                    if not (
                        isinstance(single_response_record, list)
                        and len(single_response_record) == 1
                        and isinstance(single_response_record[0], str)
                        and single_response_record[0]
                    ):
                        log_error(
                            f"response illegal in dataset row {index}. response data:"
                            f" {response_record}"
                        )
                        return False

        self.is_annotated = has_response
        return True
# Generic text
class QianfanGenericText(QianfanSchema):
    """validator for generic text dataset"""

    def validate(self, table: Table) -> bool:
        """
        validate a table

        The table passes when it has at least one row and exactly one column
        whose `null_count` is zero. `is_annotated` is not modified here.

        Args:
            table (Table): table need to be validated

        Returns:
            bool: whether table is valid
        """
        if table.get_row_count() == 0:
            log_error("no data in table")
            return False

        col_names = table.col_names()
        column_count = len(col_names)

        # A column-less table is rejected with its own message: reporting
        # "more than 1 column" for zero columns would be misleading.
        if column_count == 0:
            log_error("dataset has no column")
            return False

        if column_count > 1:
            log_error(f"dataset has more than 1 column: {col_names}")
            return False

        if table.inner_table.column(col_names[0]).null_count:
            log_error("empty row in dataset")
            return False

        return True
# Query set
class QianfanQuerySet(QianfanSchema):
    """validator for query set dataset"""

    def validate(self, table: Table) -> bool:
        """
        validate a table

        The table passes when it has at least one row and a `prompt` column
        whose `null_count` is zero. Other columns are not inspected and
        `is_annotated` is not modified here.

        Args:
            table (Table): table need to be validated

        Returns:
            bool: whether table is valid
        """
        if table.get_row_count() == 0:
            log_error("no data in table")
            return False

        col_names = table.col_names()
        if "prompt" not in col_names:
            log_error("no prompt column in dataset column")
            return False

        if table.inner_table.column("prompt").null_count:
            log_error("prompt column has empty data in dataset column")
            return False

        return True
# Text to image
class QianfanText2Image(QianfanSchema):
    """validator for text to image dataset"""

    def validate(self, table: Table) -> bool:
        """
        validate a table

        Every table is currently rejected: the method returns False without
        inspecting `table` and without logging anything.
        NOTE(review): presumably a placeholder until text-to-image validation
        rules exist -- confirm.

        Args:
            table (Table): table need to be validated (unused here)

        Returns:
            bool: whether table is valid, always False
        """
        return False