# Source code for qianfan.trainer.finetune

# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Union

from qianfan.config import get_config
from qianfan.errors import InvalidArgumentError
from qianfan.evaluation.evaluator import Evaluator
from qianfan.model.configs import DeployConfig
from qianfan.resources.console import consts as console_consts
from qianfan.trainer.actions import (
    DeployAction,
    EvaluateAction,
    LoadDataSetAction,
    ModelPublishAction,
    TrainAction,
    action_mapping,
)
from qianfan.trainer.base import (
    BaseAction,
    EventHandler,
    Pipeline,
    Trainer,
)
from qianfan.trainer.configs import (
    ModelInfo,
    ModelInfoMapping,
    TrainConfig,
)
from qianfan.trainer.consts import (
    TrainStatus,
)


class LLMFinetune(Trainer):
    """
    Class implements the SFT training pipeline with several actions.

    Use `run()` to synchronously run the training pipeline until the
    model training is finished.
    """

    def __init__(
        self,
        train_type: Optional[str] = None,
        dataset: Optional[Any] = None,
        train_config: Optional[Union[TrainConfig, str]] = None,
        deploy_config: Optional[DeployConfig] = None,
        event_handler: Optional[EventHandler] = None,
        eval_dataset: Optional[Any] = None,
        evaluators: Optional[List[Evaluator]] = None,
        dataset_bos_path: Optional[str] = None,
        previous_trainer: Optional[Trainer] = None,
        previous_task_id: Optional[str] = None,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initialization function for LLM fine-tuning.

        The pipeline is assembled in this order: load dataset -> train ->
        publish model (unless disabled) -> deploy (optional) ->
        evaluate (optional).

        Parameters:
            train_type: str
                A string representing the model version type,
                like 'ERNIE-Bot-turbo-0725', 'ChatGLM2-6b'.
                Only passed to the train action when neither
                `previous_trainer` nor `previous_task_id` is given.
            dataset: Dataset
                A dataset instance. Takes precedence over
                `dataset_bos_path`.
            train_config: TrainConfig
                A TrainConfig for fine-tuning training parameters, or a
                string which is loaded through `TrainConfig.load`.
                If not provided, default parameters of diverse
                models will be used.
            deploy_config: DeployConfig
                A DeployConfig for model service deployment parameters.
                Required if deployment is needed; a deploy action is only
                added when it is provided.
            event_handler: EventHandler
                An EventHandler instance for receiving events during
                the training process.
            eval_dataset: Dataset
                An optional dataset instance for evaluation.
            evaluators: List[Evaluator]
                A list of evaluators for evaluation. The evaluate action
                is only added when both `eval_dataset` and `evaluators`
                are provided.
            dataset_bos_path: Optional[str]
                A bos path for training, this will be ignored if
                `dataset` is provided.
            previous_trainer: Optional[Trainer]
                A trainer holding a `train_action`; its `task_id` is
                reused to start an incremental training.
            previous_task_id: Optional[str]
                A task id to start an incremental training from. Only
                used when `previous_trainer` is not given.
            name: Optional[str]
                Name of this trainer, also passed to the train action
                as `job_name`.
            **kwargs:
                Any additional keyword arguments, forwarded to the
                load-dataset, train and model-publish actions.
                A truthy `model_not_publish` skips the publish action.
                `base_model` (like 'ERNIE-Bot-turbo', 'ChatGLM2') may
                presumably be supplied here for the train action; it is
                mapped from the model version type if not set
                (NOTE(review): confirm against TrainAction).

        Raises:
            InvalidArgumentError: neither `dataset` nor `dataset_bos_path`
                is given, or `previous_trainer` has no `train_action`.

        for calling example:
        ```
        sft_task = LLMFinetune(
            train_type="ERNIE-Bot-turbo-0725",
            dataset={"datasets": [{"type": 1, "id": ds_id}]},
            train_config=TrainConfig(...),
            event_handler=eh,
        )
        ```
        """
        # set the trainer name
        self.name = name

        # a string train_config is loaded into a TrainConfig instance
        if isinstance(train_config, str):
            train_config = TrainConfig.load(train_config)

        actions: List[BaseAction] = []

        # validate the dataset input: an explicit dataset wins over a bos path
        if dataset is not None:
            self.load_data_action = LoadDataSetAction(
                dataset=dataset,
                dataset_template=console_consts.DataTemplateType.NonSortedConversation,
                event_handler=event_handler,
                **kwargs,
            )
        elif dataset_bos_path:
            self.load_data_action = LoadDataSetAction(
                dataset=dataset_bos_path,
                event_handler=event_handler,
                **kwargs,
            )
        else:
            raise InvalidArgumentError("either dataset or bos_path is required")
        actions.append(self.load_data_action)

        # Every branch below passes `event_handler` so that train events are
        # dispatched for incremental trainings as well as for trainings
        # started from a base model.
        if previous_trainer:
            # init an increment training
            if hasattr(previous_trainer, "train_action"):
                self.train_action = TrainAction(
                    train_config=train_config,
                    task_id=previous_trainer.train_action.task_id,
                    train_mode=console_consts.TrainMode.SFT,
                    event_handler=event_handler,
                    job_name=name,
                    **kwargs,
                )
            else:
                raise InvalidArgumentError(
                    "invalid trainer input without previous train action"
                )
        elif previous_task_id:
            # init an increment training from an explicit task id
            self.train_action = TrainAction(
                train_config=train_config,
                task_id=previous_task_id,
                train_mode=console_consts.TrainMode.SFT,
                event_handler=event_handler,
                job_name=name,
                **kwargs,
            )
        else:
            # init train action from base model
            self.train_action = TrainAction(
                train_config=train_config,
                train_type=train_type,
                train_mode=console_consts.TrainMode.SFT,
                event_handler=event_handler,
                job_name=name,
                **kwargs,
            )
        actions.append(self.train_action)

        # publishing is on by default; `model_not_publish=True` opts out
        if not kwargs.get("model_not_publish"):
            self.model_publish = ModelPublishAction(
                event_handler=event_handler,
                **kwargs,
            )
            actions.append(self.model_publish)

        if deploy_config is not None:
            self.deploy_action = DeployAction(
                deploy_config=deploy_config,
                event_handler=event_handler,
            )
            actions.append(self.deploy_action)

        # evaluation needs both the dataset and the evaluators
        if eval_dataset is not None and evaluators is not None:
            self.eval_action = EvaluateAction(
                eval_dataset=eval_dataset,
                evaluators=evaluators,
            )
            actions.append(self.eval_action)

        ppl = Pipeline(
            actions=actions,
            event_handler=event_handler,
        )
        # exactly one pipeline, and one result slot matching it
        self.ppls = [ppl]
        self.result = [None]

    def run(self, **kwargs: Any) -> Trainer:
        """
        Run the pipeline to perform the fine-tune process.

        Parameters:
            **kwargs:
                Any additional keyword arguments, forwarded to the
                pipeline's `exec`.
                {"input": {}} could be specified if needed.
                `backoff_factor` and `retry_count` fall back to the
                global config values when not given.

        Raises:
            InvalidArgumentError: no pipeline bind to run.

        Returns:
            Trainer:
                self, for chain invocation.
        """
        self.input: Any = kwargs.get("input")
        if len(self.ppls) != 1:
            raise InvalidArgumentError("invalid pipeline to run")

        # only fill in the polling parameters the caller did not supply
        kwargs.setdefault(
            "backoff_factor", get_config().TRAINER_STATUS_POLLING_BACKOFF_FACTOR
        )
        kwargs.setdefault(
            "retry_count", get_config().TRAINER_STATUS_POLLING_RETRY_TIMES
        )
        self.result[0] = self.ppls[0].exec(**kwargs)
        return self

    @property
    def status(self) -> str:
        """
        LLMFinetune status getter.

        Raises:
            InvalidArgumentError: the trainer does not hold exactly
                one pipeline.

        Returns:
            str: status for LLMFinetune, mapping from state of actions
                in pipeline. `TrainStatus.Unknown` is returned when the
                current action or its state cannot be mapped.
        """
        if len(self.ppls) != 1:
            raise InvalidArgumentError("invalid pipeline to get status")
        ppl = self.ppls[0]
        # the pipeline is indexed by the string form of its current state
        action = ppl[str(ppl._state)]
        if action is None:
            return TrainStatus.Unknown
        action_name = action.__class__.__name__
        return action_mapping.get(action_name, {}).get(
            action.state, TrainStatus.Unknown
        )

    def stop(self, **kwargs: Dict) -> Trainer:
        """
        stop method of LLMFinetune.

        LLMFinetune will stop all actions in pipeline. In fact,
        LLMFinetune only take one pipeline, so it will be equal to
        stop first of `ppls`.

        Parameters:
            **kwargs:
                Accepted for interface compatibility; not used here.

        Returns:
            Trainer:
                self, for chain invocation.
        """
        for ppl in self.ppls:
            ppl.stop()
        return self

    def resume(self, **kwargs: Dict) -> "LLMFinetune":
        """
        LLMFinetune resume method.

        Resumes the first pipeline and stores what it returns as the
        trainer result.

        Parameters:
            **kwargs:
                Any additional keyword arguments, forwarded to the
                pipeline's `resume`.

        Returns:
            LLMFinetune:
                self, for chain invocation.
        """
        self.result[0] = self.ppls[0].resume(**kwargs)
        return self

    @property
    def output(self) -> Any:
        """
        Result of the latest `run()` / `resume()` call.

        Returns:
            Any: the value stored from the pipeline execution, or None
                if the pipeline has not produced a result yet.
        """
        return self.result[0]

    @classmethod
    def train_type_list(cls) -> Dict[str, ModelInfo]:
        """
        Get the model info mapping known to the trainer.

        Returns:
            Dict[str, ModelInfo]: the `ModelInfoMapping` object itself
                (not a copy).
        """
        return ModelInfoMapping