# Source code for qianfan.dataset.table

# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
wrapper for pyarrow.Table
"""

from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import pyarrow
from pyarrow import Table as PyarrowTable
from pydantic import BaseModel
from typing_extensions import Self

from qianfan.dataset.process_interface import (
    Appendable,
    Listable,
    Processable,
)


class _PyarrowRowManipulator(BaseModel, Appendable, Listable, Processable):
    """
    Row-oriented operations on a wrapped pyarrow table.

    No method assigns to ``self.table``; methods that "modify" the table
    return a pyarrow table and leave it to the caller to store the result.
    """

    class Config:
        # needed so that pydantic accepts the pyarrow table as a field type
        arbitrary_types_allowed = True

    table: PyarrowTable

    def append(self, elem: Union[List[Dict], Tuple[Dict], Dict]) -> Self:
        """
        append one row, or a sequence of rows, to pyarrow table

        Args:
            elem (Union[List[Dict], Tuple[Dict], Dict]): a single row given
                as a column-name -> value dict, or a non-empty list / tuple
                of such dicts
        Returns:
            Self: a new pyarrow table holding the current rows followed by
                the appended ones
        Raises:
            ValueError: if ``elem`` is an empty sequence, contains an item
                which isn't a dict, or is neither a dict nor a list / tuple
        """

        if isinstance(elem, dict):
            rows: Sequence[Dict] = [elem]
        elif isinstance(elem, (list, tuple)):
            if not elem:
                raise ValueError("element is empty")
            # Check every item, not only the first one, so that a bad item
            # in the middle is reported as a ValueError as well.
            for row in elem:
                if not isinstance(row, dict):
                    raise ValueError(
                        "element in sequence-like container cannot be instance of"
                        f" {type(row)}"
                    )
            rows = elem
        else:
            raise ValueError(f"element cannot be instance of {type(elem)}")

        # Each dict becomes its own single-row table; the tables are then
        # concatenated after the existing one (promote=True is forwarded to
        # pyarrow as in the original implementation).
        row_tables = [
            pyarrow.Table.from_pydict(mapping={k: [v] for k, v in row.items()})
            for row in rows
        ]
        return pyarrow.concat_tables([self.table, *row_tables], promote=True)

    def list(
        self, by: Optional[Union[slice, int, str, Sequence[int], Sequence[str]]] = None
    ) -> Any:
        """
        get row(s) from pyarrow table

        Args:
            by (Optional[Union[slice, int, Sequence[int]]]):
                index or indices of the wanted rows, default to None, in
                which case every row is returned.
                When a slice is given only ``start`` and ``stop`` are read
                (``step`` is ignored) and ``stop`` is treated as INCLUSIVE;
                a missing ``start`` means the first row and a missing
                ``stop`` means the last row.
        Returns:
            Any: python list of rows, each row being a dict
        Raises:
            ValueError: if ``by`` is a str, a list / tuple starting with a
                str, or any other unsupported type
        """

        if isinstance(by, str) or (
            isinstance(by, (list, tuple)) and by and isinstance(by[0], str)
        ):
            raise ValueError("cannot get row from table by str")
        if by is None:
            return self.table.to_pylist()

        if isinstance(by, int):
            return self.table.take([by]).to_pylist()
        if isinstance(by, (list, tuple)):
            if not by:
                # nothing requested; also avoids indexing an empty sequence
                return []
            return self.table.take(list(by)).to_pylist()
        if isinstance(by, slice):
            first = 0 if by.start is None else by.start
            last = self.table.num_rows - 1 if by.stop is None else by.stop
            # NOTE(review): "+ 1" makes the stop bound inclusive, unlike a
            # regular python slice; kept as is for backward compatibility.
            return self.table.slice(
                offset=first, length=last - first + 1
            ).to_pylist()
        raise ValueError(f"unsupported key type {type(by)} when get row from table")

    def map(self, op: Callable[[Any], Any]) -> Self:
        """
        map on pyarrow table's row

        Args:
            op (Callable[[Any], Any]): handler called once per row with the
                row as a dict; it must return a non-empty dict which has
                exactly the same keys as the row it received

        Returns:
            Self: a new pyarrow table built from the returned dicts, or the
                wrapped table itself when it has no row
        Raises:
            ValueError: if ``op`` returns a falsy value, something which
                isn't a dict, or a dict whose keys differ from the row's
        """

        # Table.from_pylist([]) has no row to take column names from, so an
        # empty table is handed back unchanged to keep its schema.
        if self.table.num_rows == 0:
            return self.table

        # The new table uses the keys of its first row as columns.
        mapped_rows: List[Dict[str, Any]] = []
        # Convert the table once instead of calling take() for every row.
        for row in self.table.to_pylist():
            returned_data = op(row)
            if not returned_data:
                raise ValueError("cant make data empty")
            if not isinstance(returned_data, dict):
                raise ValueError("returned value isn't dict")
            if row.keys() != returned_data.keys():
                raise ValueError("cant modify column name in map")

            mapped_rows.append(returned_data)

        return pyarrow.Table.from_pylist(mapped_rows)

    def filter(self, op: Callable[[Any], bool]) -> Self:
        """
        filter on pyarrow table's row

        Args:
            op (Callable[[Any], bool]): handler called once per row with
                the row as a dict; rows for which it returns True are kept

        Returns:
            Self: a new pyarrow table containing only the kept rows
        Raises:
            ValueError: if ``op`` returns None or a value which isn't bool
        """

        selection_masks: List[bool] = []
        # Convert the table once instead of calling take() for every row.
        for row in self.table.to_pylist():
            flag = op(row)
            if flag is None:
                raise ValueError("cant return None")
            if not isinstance(flag, bool):
                raise ValueError("returned value isn't bool")

            selection_masks.append(flag)

        # Build the mask with an explicit boolean type so that an empty
        # mask (empty table) is not left to pyarrow's type inference.
        mask = pyarrow.array(selection_masks, type=pyarrow.bool_())
        return self.table.filter(mask=mask)

    def delete(self, index: Union[int, str]) -> Self:
        """
        delete a row from pyarrow table

        Args:
            index (Union[int, str]): zero-based index of the row to delete;
                str is part of the signature but always rejected

        Returns:
            Self: a new pyarrow table without the row at ``index``
        Raises:
            ValueError: if ``index`` is a str
            OverflowError: if ``index`` is negative or not smaller than the
                row count
        """

        if isinstance(index, str):
            raise ValueError("cannot delete row by str")
        table_length = self.table.num_rows
        if index < 0 or index >= table_length:
            raise OverflowError(f"index overflow, table length is {table_length}")

        # First and last row only need a single slice.
        if index == 0:
            return self.table.slice(1)
        if index == table_length - 1:
            return self.table.slice(0, table_length - 1)

        # Otherwise glue the part before the row to the part after it.
        return pyarrow.concat_tables(
            [self.table.slice(0, index), self.table.slice(index + 1)]
        )


class _PyarrowColumnManipulator(BaseModel, Appendable, Listable, Processable):
    """
    Column-oriented operations on a wrapped pyarrow table.

    No method assigns to ``self.table``; methods that "modify" the table
    return a pyarrow table and leave it to the caller to store the result.
    """

    class Config:
        # needed so that pydantic accepts the pyarrow table as a field type
        arbitrary_types_allowed = True

    table: PyarrowTable

    def append(self, elem: Dict[str, List]) -> Self:
        """
        append a column to pyarrow table

        Args:
            elem (Dict[str, List]): dict describing the new column, it
                must have the column name under key "name" (a str) and the
                column values under key "data" (a non-empty list with one
                value per existing row)
        Returns:
            Self: a new pyarrow table with the column added
        Raises:
            ValueError: if ``elem`` isn't a dict, lacks "name" or "data",
                has empty data, or the data length differs from the row
                count of the table
            TypeError: if the name isn't a str or the data isn't a list
        """

        if not isinstance(elem, dict):
            raise ValueError(f"element appended must be dict, not {type(elem)}")
        if "name" not in elem:
            raise ValueError("no name has been provided")
        if "data" not in elem:
            raise ValueError("no data has been provided")

        name, data = elem["name"], elem["data"]
        if not isinstance(name, str):
            raise TypeError(f"name isn't str, rather than {type(name)}")
        if not isinstance(data, list):
            raise TypeError(f"data isn't list, rather than {type(data)}")
        if not data:
            raise ValueError("data can't be empty")

        row_count = self.table.num_rows
        if len(data) != row_count:
            raise ValueError(
                f"the length of data need to be {row_count}, rather than"
                f" {len(data)}"
            )
        # the values are wrapped in a list, i.e. passed as a single chunk
        return self.table.append_column(name, [data])

    def list(
        self, by: Optional[Union[slice, int, str, Sequence[int], Sequence[str]]] = None
    ) -> Any:
        """
        get column(s) from pyarrow table

        Args:
            by (Optional[Union[int, str, Sequence[int], Sequence[str]]]):
                column index / name, or a sequence of them, default to
                None, in which case every column is returned
        Returns:
            Any: dict mapping each selected column name to its values;
                an empty dict when ``by`` is an empty sequence
        Raises:
            ValueError: if ``by`` is a slice, or a name (sequence whose
                first item is a str) refers to a column which doesn't exist
        """

        if by is None:
            return self.table.to_pydict()

        if isinstance(by, slice):
            raise ValueError("cannot get column by slice")

        indices: Any = [by] if isinstance(by, (int, str)) else list(by)
        if not indices:
            # nothing requested; also avoids indexing an empty sequence
            return {}
        if isinstance(indices[0], str) and not set(indices).issubset(
            set(self.table.column_names)
        ):
            raise ValueError("contain not existed column name")
        return self.table.select(indices).to_pydict()

    def map(self, op: Callable[[Any], Any]) -> Self:
        """
        map on pyarrow table's column

        Args:
            op (Callable[[Any], Any]): handler called once per column with
                a single-entry dict {column name: values}; whatever it
                returns is merged into the result with ``dict.update``

        Returns:
            Self: a new pyarrow table built from the merged results
        """

        new_columns: Dict[str, List[Any]] = {}
        for column_index in range(self.table.num_columns):
            column = self.table.select([column_index]).to_pydict()
            # a later result overwrites an earlier one using the same name
            new_columns.update(op(column))

        return pyarrow.Table.from_pydict(new_columns)

    def filter(self, op: Callable[[Any], bool]) -> Self:
        """
        filter on pyarrow table's column

        Args:
            op (Callable[[Any], bool]): handler called once per column with
                a single-entry dict {column name: values}; a column is
                dropped when the returned value is falsy

        Returns:
            Self: a new pyarrow table without the dropped columns
        """

        dropped_column_name: List[str] = []
        for column_index in range(self.table.num_columns):
            column = self.table.select([column_index]).to_pydict()
            if not op(column):
                dropped_column_name.extend(column.keys())

        return self.table.drop_columns(dropped_column_name)

    def delete(self, index: Union[int, str]) -> Self:
        """
        delete a column from pyarrow table

        Args:
            index (str): column name to delete; int is part of the
                signature but always rejected

        Returns:
            Self: a new pyarrow table without the column
        Raises:
            ValueError: if ``index`` is an int
        """

        if isinstance(index, int):
            raise ValueError("cannot delete column by int")
        return self.table.drop_columns(index)


class Table(BaseModel, Appendable, Listable, Processable):
    """
    in-memory dataset representation which wraps a pyarrow.Table
    and implements the interfaces in process_interface.py

    The row / column methods which change the data replace ``inner_table``
    with the table returned by the manipulator and return ``self``.
    """

    class Config:
        # needed so that pydantic accepts the pyarrow table as a field type
        arbitrary_types_allowed = True

    inner_table: PyarrowTable

    def _row_op(self) -> _PyarrowRowManipulator:
        """build a row manipulator around the current inner table"""
        return _PyarrowRowManipulator(table=self.inner_table)

    def _col_op(self) -> _PyarrowColumnManipulator:
        """build a column manipulator around the current inner table"""
        return _PyarrowColumnManipulator(table=self.inner_table)

    # Interface methods called directly on a Table object
    # operate on rows by default.

    def map(self, op: Callable[[Any], Any]) -> Self:
        """
        map on pyarrow table's row

        Args:
            op (Callable[[Any], Any]): handler used to map

        Returns:
            Self: Table itself
        """
        manipulator = self._row_op()
        self.inner_table = manipulator.map(op)  # noqa
        return self

    def filter(self, op: Callable[[Any], bool]) -> Self:
        """
        filter on pyarrow table's row

        Args:
            op (Callable[[Any], bool]): handler used to filter

        Returns:
            Self: Table itself
        """
        manipulator = self._row_op()
        self.inner_table = manipulator.filter(op)
        return self

    def delete(self, index: Union[int, str]) -> Self:
        """
        delete a row from pyarrow table

        Args:
            index (Union[int, str]): row index to delete

        Returns:
            Self: Table itself
        """
        manipulator = self._row_op()
        self.inner_table = manipulator.delete(index)
        return self

    def append(self, elem: Any) -> Self:
        """
        append row(s) to pyarrow table

        Args:
            elem (Union[List[Dict], Tuple[Dict], Dict]):
                elements added to pyarrow table

        Returns:
            Self: Table itself
        """
        manipulator = self._row_op()
        self.inner_table = manipulator.append(elem)
        return self

    def list(
        self, by: Optional[Union[slice, int, str, Sequence[int], Sequence[str]]] = None
    ) -> Any:
        """
        get row(s) from pyarrow table

        Args:
            by (Optional[Union[slice, int, Sequence[int]]]):
                index or indices for rows, default to None, in which case
                return a python list of every pyarrow table row

        Returns:
            Any: pyarrow table row list
        """
        manipulator = self._row_op()
        return manipulator.list(by)

    def col_map(self, op: Callable[[Any], Any]) -> Self:
        """
        map on pyarrow table's column

        Args:
            op (Callable[[Any], Any]): handler used to map

        Returns:
            Self: Table itself
        """
        manipulator = self._col_op()
        self.inner_table = manipulator.map(op)  # noqa
        return self

    def col_filter(self, op: Callable[[Any], bool]) -> Self:
        """
        filter on pyarrow table's column

        Args:
            op (Callable[[Any], bool]): handler used to filter

        Returns:
            Self: Table itself
        """
        manipulator = self._col_op()
        self.inner_table = manipulator.filter(op)
        return self

    def col_delete(self, index: Union[int, str]) -> Self:
        """
        delete a column from pyarrow table

        Args:
            index (str): column name to delete

        Returns:
            Self: Table itself
        """
        manipulator = self._col_op()
        self.inner_table = manipulator.delete(index)
        return self

    def col_append(self, elem: Any) -> Self:
        """
        append a column to pyarrow table

        Args:
            elem (Dict[str, List]): dict describing the column added to
                pyarrow table, it must have column name "name" and column
                data list "data"

        Returns:
            Self: Table itself
        """
        manipulator = self._col_op()
        self.inner_table = manipulator.append(elem)
        return self

    def col_list(
        self, by: Optional[Union[slice, int, str, Sequence[int], Sequence[str]]] = None
    ) -> Any:
        """
        get column(s) from pyarrow table

        Args:
            by (Optional[Union[int, str, Sequence[int], Sequence[str]]]):
                index or indices for columns, default to None, in which
                case return every pyarrow table column

        Returns:
            Any: pyarrow table column list
        """
        manipulator = self._col_op()
        return manipulator.list(by)

    def col_names(self) -> List[str]:
        """
        get column name list

        Returns:
            List[str]: column name list
        """
        return self.inner_table.column_names

    # Override the get and del magic methods.

    def __getitem__(self, key: Any) -> Any:
        """
        ``table[key]``: a str, or a non-empty sequence whose first item is
        a str, selects column(s); every other key selects row(s)
        """
        if isinstance(key, str) or (
            # the length check keeps an empty sequence from raising
            # IndexError on key[0]
            isinstance(key, Sequence)
            and len(key) > 0
            and isinstance(key[0], str)
        ):
            return self.col_list(key)
        return self.list(key)

    def __delitem__(self, key: Any) -> None:
        """
        ``del table[key]``: a str deletes a column, an int deletes a row

        Raises:
            ValueError: if ``key`` is neither a str nor an int
        """
        if isinstance(key, str):
            self.col_delete(key)
        elif isinstance(key, int):
            self.delete(key)
        else:
            raise ValueError(f"Unsupported key type {type(key)}")

    def get_row_count(self) -> int:
        """
        get pyarrow table row count.

        Returns:
            int: row count.
        """
        return self.inner_table.num_rows

    def get_column_count(self) -> int:
        """
        get pyarrow table column count.

        Returns:
            int: column count.
        """
        return self.inner_table.num_columns

    def to_pylist(self) -> List:
        """
        convert a pyarrow table to list

        Returns:
            List: a list
        """
        return self.inner_table.to_pylist()

    def to_pydict(self) -> Dict:
        """
        convert a pyarrow table to dict

        Returns:
            Dict: a dict
        """
        return self.inner_table.to_pydict()