# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Union
from qianfan.config import get_config
from qianfan.dataset import Dataset
from qianfan.errors import InvalidArgumentError
from qianfan.resources.console import consts as console_consts
from qianfan.trainer.actions import (
LoadDataSetAction,
TrainAction,
action_mapping,
)
from qianfan.trainer.base import (
BaseAction,
EventHandler,
Pipeline,
Trainer,
)
from qianfan.trainer.configs import (
ModelInfo,
PostPreTrainModelInfoMapping,
TrainConfig,
)
from qianfan.trainer.consts import (
TrainStatus,
)
class PostPreTrain(Trainer):
    """
    Trainer implementing the post-pretrain training pipeline.

    The pipeline runs two actions in order: a `LoadDataSetAction` using the
    generic-text data template, followed by a `TrainAction` in
    `TrainMode.PostPretrain`. Use `run()` to synchronously execute the
    pipeline until model training is finished; the pipeline's return value
    is then available through `output`.
    """

    def __init__(
        self,
        train_type: str,
        dataset: Optional[Union[Dataset, str]] = None,
        train_config: Optional[Union[TrainConfig, str]] = None,
        event_handler: Optional[EventHandler] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initialize a post-pretrain task.

        Parameters:
            train_type: str
                Name of the model version to post-pretrain, like
                'ERNIE-Bot-turbo-0725' or 'ChatGLM2-6b'. See
                `train_type_list()` for the supported names.
            dataset: Optional[Union[Dataset, str]] = None
                A post-pretrain `Dataset` instance, or a BOS path string
                pointing to the post-pretrain data.
            train_config: Optional[Union[TrainConfig, str]] = None
                A `TrainConfig` holding the post-pretrain training
                parameters, or a string that is loaded via
                `TrainConfig.load`. When omitted, `None` is handed to the
                train action, which is expected to fall back to the model's
                default parameters.
            event_handler: Optional[EventHandler] = None
                An `EventHandler` that receives events emitted during the
                training process.
            **kwargs: Any
                Extra keyword arguments, forwarded unchanged to both the
                dataset-loading action and the train action.

        Raises:
            InvalidArgumentError: if `train_type` is None or empty.

        Example:
            ```
            ds = Dataset.load(qianfan_dataset_id="", ...)
            ppt_task = PostPreTrain(
                train_type="ERNIE-Bot-turbo-0725",
                dataset=ds,
                train_config=TrainConfig(...),
                event_handler=eh,
            )
            ppt_task.run()
            ```
        """
        # Validate train_type: None and "" are both rejected.
        if not train_type:
            raise InvalidArgumentError("train_type is empty")
        # A string train_config is treated as something TrainConfig.load
        # understands and is materialized here.
        if isinstance(train_config, str):
            train_config = TrainConfig.load(train_config)

        actions: List[BaseAction] = []

        # Step 1: load the dataset using the generic-text template.
        self.load_data_action = LoadDataSetAction(
            dataset,
            console_consts.DataTemplateType.GenericText,
            event_handler=event_handler,
            **kwargs,
        )
        actions.append(self.load_data_action)

        # Step 2: run training in post-pretrain mode.
        self.train_action = TrainAction(
            train_config=train_config,
            train_type=train_type,
            train_mode=console_consts.TrainMode.PostPretrain,
            event_handler=event_handler,
            **kwargs,
        )
        actions.append(self.train_action)

        ppl = Pipeline(
            actions=actions,
            event_handler=event_handler,
        )
        # This trainer always owns exactly one pipeline; `result` holds the
        # output of that pipeline's latest exec/resume (None until then).
        self.ppls = [ppl]
        self.result: List[Any] = [None]

    def run(self, **kwargs: Any) -> Trainer:
        """
        Synchronously run the post-pretrain pipeline.

        Parameters:
            **kwargs: Any
                Additional keyword arguments passed to `Pipeline.exec`.
                An `input` entry (e.g. {"input": {}}) may be given if
                needed; it is also stored on `self.input`.
                `backoff_factor` and `retry_count` default to the config
                values `TRAINER_STATUS_POLLING_BACKOFF_FACTOR` and
                `TRAINER_STATUS_POLLING_RETRY_TIMES` when not supplied.

        Raises:
            InvalidArgumentError: if the trainer is not bound to exactly
                one pipeline.

        Returns:
            Trainer: self, for chain invocation.
        """
        self.input: Any = kwargs.get("input")
        if len(self.ppls) != 1:
            raise InvalidArgumentError("invalid pipeline to run")
        # Fill in status-polling defaults from the global config without
        # overriding anything the caller passed explicitly.
        config = get_config()
        kwargs.setdefault(
            "backoff_factor", config.TRAINER_STATUS_POLLING_BACKOFF_FACTOR
        )
        kwargs.setdefault(
            "retry_count", config.TRAINER_STATUS_POLLING_RETRY_TIMES
        )
        self.result[0] = self.ppls[0].exec(**kwargs)
        return self

    @property
    def status(self) -> str:
        """
        Current status of the post-pretrain task.

        The action indicated by the pipeline's current state is looked up,
        and its own state is translated through `action_mapping` (keyed by
        the action's class name).

        Raises:
            InvalidArgumentError: if the trainer is not bound to exactly
                one pipeline.

        Returns:
            str: the mapped training status, or `TrainStatus.Unknown` when
            no action or no mapping entry is found.
        """
        if len(self.ppls) != 1:
            raise InvalidArgumentError("invalid pipeline to get status")
        # NOTE(review): relies on Pipeline's private `_state` and on
        # Pipeline.__getitem__ returning None for an unknown key — confirm.
        action = self.ppls[0][str(self.ppls[0]._state)]
        if action is None:
            return TrainStatus.Unknown
        action_name = action.__class__.__name__
        return action_mapping.get(action_name, {}).get(
            action.state, TrainStatus.Unknown
        )

    def stop(self, **kwargs: Any) -> Trainer:
        """
        Stop the post-pretrain task.

        Calls `stop()` on every pipeline owned by this trainer. Since
        PostPreTrain only ever builds one pipeline, this is equivalent to
        stopping the first entry of `ppls`.

        Parameters:
            **kwargs: Any
                Accepted for interface compatibility; currently ignored.

        Returns:
            Trainer: self, for chain invocation.
        """
        for ppl in self.ppls:
            ppl.stop()
        return self

    def resume(self, **kwargs: Any) -> "PostPreTrain":
        """
        Resume the post-pretrain pipeline.

        Parameters:
            **kwargs: Any
                Keyword arguments passed through to `Pipeline.resume`.

        Raises:
            InvalidArgumentError: if the trainer is not bound to exactly
                one pipeline (same guard as `run`).

        Returns:
            PostPreTrain: self, for chain invocation. The pipeline's return
            value is stored and available via `output`.
        """
        if len(self.ppls) != 1:
            raise InvalidArgumentError("invalid pipeline to resume")
        self.result[0] = self.ppls[0].resume(**kwargs)
        return self

    @property
    def output(self) -> Any:
        """
        Return value of the most recent `run()` / `resume()` of the
        pipeline, or None if neither has been called yet.
        """
        return self.result[0]

    @classmethod
    def train_type_list(cls) -> Dict[str, ModelInfo]:
        """
        Supported post-pretrain model types.

        Returns:
            Dict[str, ModelInfo]: mapping from `train_type` name to its
            `ModelInfo` (`PostPreTrainModelInfoMapping`).
        """
        return PostPreTrainModelInfoMapping