# Source code for qianfan.trainer.base

# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import pickle
from abc import ABC, abstractmethod
from threading import Lock
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
    TypeVar,
    Union,
    cast,
)

from qianfan.common.runnable.base import ExecuteSerializable
from qianfan.errors import InternalError, InvalidArgumentError
from qianfan.trainer.consts import ActionState
from qianfan.trainer.event import Event, EventHandler, dispatch_event
from qianfan.utils import log_debug, log_error, utils

Input = TypeVar("Input")
Output = TypeVar("Output")


class BaseAction(ExecuteSerializable[Input, Output], ABC):
    """
    BaseAction is a reusable, atomic operation component that can be freely
    orchestrated for use in Pipelines.
    """

    def __init__(
        self,
        id: Optional[str] = None,
        name: Optional[str] = None,
        event_handler: Optional[EventHandler] = None,
        **kwargs: Any,
    ) -> None:
        """
        init method

        Parameters:
            id (Optional[str], optional):
                action id used to identify the action. When None, a random
                id is generated with `utils.generate_letter_num_random_id()`.
                Defaults to None.
            name (Optional[str], optional):
                action name. When None, it becomes `f"action_{id}"`.
                Defaults to None.
            event_handler (Optional[EventHandler], optional):
                event handler used to track action state changes.
                Defaults to None.
            kwargs (Any):
                extra keyword arguments; accepted so subclasses can forward
                their own kwargs, and ignored here.
        """
        self.id = id if id is not None else utils.generate_letter_num_random_id()
        self.name = name if name is not None else f"action_{self.id}"
        self.state = ActionState.Preceding
        self.event_dispatcher = event_handler

    def dumps(self) -> Optional[bytes]:
        """
        dumps the whole action instance with pickle

        Returns:
            serialized bytes of the action
        """
        return pickle.dumps(self)

    def loads(self, data: bytes) -> Any:
        """
        loads an object from pickled bytes (e.g. produced by `dumps`)

        NOTE: this uses `pickle.loads`, which can execute arbitrary code
        while unpickling; only pass data that comes from a trusted source.

        Parameters:
            data (bytes): pickled bytes to load

        Returns:
            Any: the unpickled object, expected to be an action instance
        """
        return pickle.loads(data)

    @abstractmethod
    def exec(self, input: Optional[Input] = None, **kwargs: Dict) -> Output:
        """
        exec is an abstract method that executes the action.

        Parameters:
            input (Optional[Input], optional):
                input. Defaults to None.

        Returns:
            Output: output
        """
        ...

    @abstractmethod
    def resume(self, **kwargs: Dict) -> Output:
        """
        Resume the action from its last input. Subclasses should implement
        this method with their own resuming logic.

        BaseAction does not store the last input, because what needs to be
        kept differs between actions and between their action states.
        """
        ...

    def stop(self, **kwargs: Dict) -> None:
        """
        Action stop method. Subclasses should implement this method with
        their own stop logic; the base implementation only dispatches an
        `ActionState.Stopped` event.
        """
        self.action_event(ActionState.Stopped)

    def action_error_event(self, e: Exception) -> None:
        """
        dispatch an action error event

        Parameters:
            e (Exception): action runtime error; its string form is used in
                the event message and as `{"error": ...}` event data.
        """
        dispatch_event(
            self.event_dispatcher,
            Event(
                self.__class__,
                self.id,
                ActionState.Error,
                (
                    f"action_error: action_type[{self.__class__.__name__}]"
                    f" action_id[{self.id}], msg:{str(e)}"
                ),
                {"error": str(e)},
            ),
        )

    def action_event(self, state: ActionState, msg: str = "", data: Any = None) -> None:
        """
        dispatch an action event

        Parameters:
            state (ActionState):
                action state
            msg (str, optional):
                action custom description. Defaults to "".
            data (Any, optional):
                action custom data. Defaults to None.
        """
        dispatch_event(
            self.event_dispatcher,
            Event(
                self.__class__,
                self.id,
                state,
                (
                    f"action_event: action_type[{self.__class__.__name__}]"
                    f" action_id[{self.id}], msg:{msg}"
                ),
                data,
            ),
        )

    @classmethod
    def action_type(cls) -> str:
        """
        Return the action type identifier; BaseAction returns "base".
        """
        return "base"
def with_event(func: Callable[..., Any]) -> Callable[..., Any]:
    """
    decorator for action state tracking with event.

    Before calling the wrapped method it dispatches `ActionState.Preceding`.
    After the method returns it dispatches `ActionState.Done` with the
    result as event data. If the method raises, it logs the error,
    dispatches an action error event and re-raises the original exception.

    The wrapper forwards both positional and keyword arguments, so
    `action.exec(data)` and `action.exec(input=data)` both work. It also
    preserves the wrapped method's name and docstring.
    """
    import functools

    @functools.wraps(func)
    def wrapper(self: BaseAction, *args: Any, **kwargs: Any) -> Any:
        try:
            log_debug(f"action[{self.__class__.__name__}][{self.id}] Preceding")
            self.action_event(ActionState.Preceding, "", {})
            resp = func(self, *args, **kwargs)
            self.action_event(ActionState.Done, "", resp)
            log_debug(f"action[{self.__class__.__name__}][{self.id}] Done")
            return resp
        except Exception as e:
            log_error(f"action[{self.__class__.__name__}][{self.id}] error {e}")
            self.action_error_event(e)
            # bare raise keeps the original traceback untouched
            raise

    return wrapper
class Pipeline(BaseAction[Dict[str, Any], Dict[str, Any]]):
    """
    Pipeline is a sequentially executed chain composed of multiple actions,
    and users can customize the action chain according to their needs.
    At any given moment, the Pipeline retains the id of the currently
    executing action, allowing users to retrieve information about the
    action currently in progress.
    By registering an EventHandler, users can listen to events generated
    during the Pipeline running process.
    """

    def __init__(
        self,
        actions: Sequence[BaseAction],
        post_actions: Optional[Sequence[BaseAction]] = None,
        event_handler: Optional[EventHandler] = None,
        **kwargs: Any,
    ) -> None:
        """
        Parameters:
            actions Sequence[BaseAction]:
                The actions to be executed in the pipeline, in order.
            post_actions: Optional[Sequence[BaseAction]]:
                The actions to be executed after the pipeline is completed.
                Each one receives a deep copy of the pipeline output.
                Defaults to None (no post actions).
            event_handler: Optional[EventHandler]
                event_handler to receive events.
            kwargs (Any):
                Additional keyword arguments, forwarded to BaseAction.

        Raises:
            ValueError: if two actions share the same id.

        ```
        ppl = Pipeline(
            actions=actions,
        )
        ```
        """
        super().__init__(event_handler=event_handler, **kwargs)
        self.actions: Dict[str, BaseAction] = {}
        self.seq: List[str] = []
        for action in actions:
            if action.id in self.actions:
                raise ValueError(f"action id {action.id} is duplicated")
            self.actions[action.id] = action
            self.seq.append(action.id)
        # fresh list per instance so pipelines never share a mutable default
        self.post_actions: Sequence[BaseAction] = (
            post_actions if post_actions is not None else []
        )
        # id of the action currently (or most recently) executed; "" before start
        self._state: str = ""
        self._sync_lock = Lock()
        self._stop: bool = False
        self._last_output: Optional[Dict[str, Any]] = None

    @with_event
    def exec(
        self, input: Optional[Dict[str, Any]] = None, **kwargs: Dict
    ) -> Dict[str, Any]:
        """
        Parameters:
            input: Optional[Dict[str, Any]]
                input of the pipeline.
            kwargs:
                additional keyword arguments, forwarded to every action.

        Return:
            Dict[str, Any]: The output of the pipeline.
        """
        return self.exec_from(input, 0, **kwargs)

    def exec_from(
        self,
        input: Optional[Dict[str, Any]] = None,
        start: Optional[Union[int, str]] = 0,
        **kwargs: Dict,
    ) -> Dict[str, Any]:
        """
        Execute the pipeline starting from a given action.

        Parameters:
            input: Optional[Dict[str, Any]]
                input of the first executed action; it is deep-copied before
                use. None is treated as an empty dict.
            start: Optional[Union[int, str]]
                index into the action sequence, or the id of the action to
                start from. Defaults to 0.
            kwargs:
                additional keyword arguments, forwarded to every action and
                post action.

        Returns:
            Dict[str, Any]: output of the last executed action (or the copied
            input when no action runs).

        Raises:
            InvalidArgumentError: if `start` is neither an int nor a str.
            ValueError: if `start` is a str that is not a known action id.
            InternalError: if an action's output has a non-None "error" value.
        """
        if isinstance(start, str):
            start_idx = self.seq.index(start)
        elif isinstance(start, int):
            start_idx = start
        else:
            raise InvalidArgumentError(
                "pipeline start must be index of sequence or key of action"
            )
        output: Dict[str, Any] = copy.deepcopy(input) if input is not None else {}
        for i, k in enumerate(self.seq):
            # stop() flips this flag; actions not yet started are skipped
            if self._stop:
                break
            if i < start_idx:
                continue
            if self.event_dispatcher is not None:
                self.action_event(
                    ActionState.Running, "pipeline running", {"action": k}
                )
            self._state = k
            output = self.actions[k].exec(input=output, **kwargs)
            if output.get("error") is not None:
                raise InternalError(cast(str, output.get("error")))
        for post_action in self.post_actions:
            # pass input by keyword, the same way the main actions are invoked
            post_action.exec(input=copy.deepcopy(output), **kwargs)
        return output

    def __getitem__(self, key: str) -> Optional[BaseAction]:
        """
        get action by key, which is the action id.

        Args:
            key (str): action id generated when the action was created.

        Returns:
            Optional[BaseAction]: the action with the given id if it exists,
            otherwise None.
        """
        return self.actions.get(key)

    @with_event
    def resume(self, **kwargs: Dict) -> Dict[str, Any]:
        """
        resume pipeline running from the last stopped or failed action.

        The last executed action is resumed first. If it was the final action,
        its output is returned directly. Otherwise execution continues with
        the following action.
        """
        self._stop = False
        last_output = self.actions[self._state].resume(**kwargs)
        if self.seq[-1] == self._state:
            # last node return directly
            return last_output
        idx = self.seq.index(self._state) + 1
        return self.exec_from(last_output, idx, **kwargs)

    def stop(self, **kwargs: Dict) -> None:
        """
        stop pipeline running.

        Marks the pipeline as stopped so that no further actions are started,
        calls `stop()` on the current action, then dispatches the stopped
        event for the pipeline itself.

        Raises:
            InternalError: if there is no current action (e.g. the pipeline
                has not started yet).
        """
        with self._sync_lock:
            self._stop = True
            action = self.actions.get(self._state)
            if action is None:
                raise InternalError("unknown action to stop")
            action.stop()
        return super().stop()

    def register_event_handler(
        self, event_handler: EventHandler, action_id: Optional[str] = None
    ) -> None:
        """
        Register the event handler on the pipeline and its actions.

        Args:
            event_handler (EventHandler): The event handler instance.
            action_id (Optional[str]): when given, only the action with this
                id (besides the pipeline itself) gets the handler; when None,
                every action gets it.

        Raises:
            InvalidArgumentError: if `action_id` does not match any action.
        """
        if action_id is not None and action_id not in self.actions:
            raise InvalidArgumentError(f"action id {action_id} not found")
        self.event_dispatcher = event_handler
        for id, action in self.actions.items():
            if action_id is None or id == action_id:
                action.event_dispatcher = event_handler
class Trainer(ABC):
    """
    Base Trainer class, which focuses on one-step calls to run the whole
    training process. It defines the 3 basic methods to operate training:

    - run() runs the specific training process, like fine-tuning
    - resume() resumes from the stopped or failed state
    - stop() stops the training process
    """

    # Pipelines for training; there may be multiple pipelines in the
    # training process.
    # NOTE(review): class-level mutable default shared by every subclass and
    # instance that does not rebind it; subclasses should assign their own
    # list (e.g. in __init__) rather than mutate this one.
    ppls: List[Pipeline] = []

    # pipeline running results, which may be an error or an object
    # NOTE(review): same shared-class-attribute caveat as `ppls`.
    result: List[Any] = []

    @abstractmethod
    def run(self, **kwargs: Dict) -> "Trainer":
        """
        Trainer abstract method. Subclasses override this method to implement
        their specific training process.

        Returns:
            Trainer: Trainer instance
        """
        ...

    @abstractmethod
    def stop(self, **kwargs: Dict) -> "Trainer":
        """
        Trainer abstract method. Subclasses implement it to support more
        controllable usage in concrete situations.

        Returns:
            Trainer: Trainer instance
        """
        return self

    @abstractmethod
    def resume(self, **kwargs: Dict) -> "Trainer":
        """
        Counterpart of the stop method. Users can resume the training process
        by calling resume().

        Returns:
            Trainer: Trainer instance
        """
        return self

    @property
    def status(self) -> str:
        """
        Trainer status. Subclasses implement different statuses for different
        processes like fine-tuning, RLHF, PreTrain and so on. The base
        implementation returns an empty string.
        """
        return ""

    def get_evaluate_result(self) -> Any:
        """
        Receive the evaluate result from the pipeline. [coming soon].
        """
        raise NotImplementedError("trainer get_evaluate_result")

    def get_log(self) -> Any:
        """
        Receive the training log during the pipeline execution. [coming soon].
        """
        raise NotImplementedError("trainer get_log")

    def register_event_handler(
        self, event_handler: EventHandler, ppl_id: Optional[str] = None
    ) -> None:
        """
        Register the event handler on the trainer's pipelines.

        Args:
            event_handler (EventHandler): The event handler instance.
            ppl_id (Optional[str]): when given, only the pipeline with this
                id gets the handler; when None, every pipeline gets it.

        Raises:
            InvalidArgumentError: if `ppl_id` does not match any pipeline.
        """
        if ppl_id is None:
            for ppl in self.ppls:
                ppl.register_event_handler(event_handler)
            return
        for ppl in self.ppls:
            if ppl.id == ppl_id:
                ppl.register_event_handler(event_handler)
                return
        raise InvalidArgumentError(f"pipeline id {ppl_id} not found")

    @property
    def actions(self) -> Dict[str, BaseAction]:
        """
        Get the available actions of the trainer's first pipeline.

        Returns:
            Dict[str, BaseAction]: mapping from action id to action.

        Raises:
            IndexError: if the trainer has no pipelines.
        """
        return self.ppls[0].actions

    @property
    @abstractmethod
    def output(self) -> Any:
        """
        Trainer output; subclasses define what it contains.
        """
        ...