Source code for qianfan.trainer.actions

# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Union, cast

from qianfan import resources as api
from qianfan.config import get_config
from qianfan.dataset import BosDataSource, Dataset, QianfanDataSource
from qianfan.errors import InternalError, InvalidArgumentError
from qianfan.evaluation import EvaluationManager
from qianfan.evaluation.evaluator import Evaluator, LocalEvaluator, QianfanEvaluator
from qianfan.model import Model, Service
from qianfan.model.configs import DeployConfig
from qianfan.resources.console import consts as console_consts
from qianfan.trainer.base import (
    ActionState,
    BaseAction,
    with_event,
)
from qianfan.trainer.configs import (
    DefaultPostPretrainTrainConfigMapping,
    DefaultTrainConfigMapping,
    PeftType,
    TrainConfig,
    TrainLimit,
    get_model_info,
)
from qianfan.trainer.consts import ServiceStatus, TrainStatus
from qianfan.utils import (
    bos_uploader,
    log_debug,
    log_error,
    log_info,
    log_warn,
    utils,
)
from qianfan.utils.bos_uploader import is_valid_bos_path
from qianfan.utils.utils import first_lower_case, snake_to_camel


[docs]class LoadDataSetAction(BaseAction[Dict[str, Any], Dict[str, Any]]): """LoadDataSetAction Action for dataset's loading, invokes the dataset's save method to guarantee the dataset is loaded in Qianfan platform. Sample: ``` load_action = LoadDataSetAction(dataset=Dataset(id=1)) load_action.exec() ``` input: none output: ``` {"datasets" : [{"id": 1, "name": "test_dataset"}]} ``` """ from qianfan.dataset.dataset import Dataset dataset: Optional[Dataset] = None bos_path: Optional[str] = None def __init__( self, dataset: Optional[Union[Dataset, str]] = None, dataset_template: Optional[console_consts.DataTemplateType] = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) if dataset is None: raise InvalidArgumentError("dataset must be set") if isinstance(dataset, str): if not is_valid_bos_path(dataset): raise InvalidArgumentError(f"invalid bos_path {dataset}") self.bos_path = dataset elif isinstance(dataset.inner_data_source_cache, QianfanDataSource): qf_data_src = cast(QianfanDataSource, dataset.inner_data_source_cache) if ( dataset_template is not None and qf_data_src.template_type != dataset_template ): raise InvalidArgumentError( f"dataset must be `{dataset_template}` template." ) self.dataset = dataset elif isinstance(dataset.inner_data_source_cache, BosDataSource): self.dataset = dataset else: raise InvalidArgumentError( "dataset must be either implemented with QianfanDataSource or" " BosDataSource or a bos path" ) @with_event def exec(self, input: Dict[str, Any] = {}, **kwargs: Dict) -> Dict[str, Any]: return self._exec(input, **kwargs) def _exec(self, input: Dict[str, Any] = {}, **kwargs: Dict) -> Dict[str, Any]: """ Load dataset implementation, may called by exec and resume. """ if self.bos_path is not None: if not self.bos_path.endswith("/"): bos_path = f'{Path(f"/{self.bos_path}").parent}' log_warn( f"input bos_path {self.bos_path} is a file, auto_convert to dir:" f" {bos_path}" ) else: bos_path = self.bos_path return { "datasets": { "sourceType": ( console_consts.TrainDatasetSourceType.PrivateBos.value ), "versions": [{"versionBosUri": bos_path}], } } from qianfan.dataset.data_source import BosDataSource, QianfanDataSource if self.dataset is None: raise InvalidArgumentError("dataset or bos_path must be set") if self.dataset.inner_data_source_cache is None: raise InvalidArgumentError("invalid dataset") if isinstance(self.dataset.inner_data_source_cache, QianfanDataSource): log_debug("[load_dataset_action] prepare train-set") qf_data_src = cast(QianfanDataSource, self.dataset.inner_data_source_cache) is_released = qf_data_src.release_dataset(**kwargs) if not is_released: log_error("[load_dataset_action] dataset not released") raise InvalidArgumentError("dataset must be released") log_debug("[load_dataset_action] dataset loaded successfully") self.qf_dataset_id = qf_data_src.id return { "datasets": { "sourceType": console_consts.TrainDatasetSourceType.Platform.value, "versions": [ { "versionId": qf_data_src.id, } ], } } elif isinstance(self.dataset.inner_data_source_cache, BosDataSource): log_debug("[load_dataset_action] prepare train-set in BOS") bos_data_src = cast(BosDataSource, self.dataset.inner_data_source_cache) return { "datasets": { "sourceType": ( console_consts.TrainDatasetSourceType.PrivateBos.value ), "versions": [ { "versionBosUri": bos_uploader.generate_bos_file_parent_path( bos_data_src.bucket, bos_data_src.bos_file_path ) } ], } } else: raise InvalidArgumentError("dataset must be set") @with_event def resume(self, **kwargs: Dict) -> Dict[str, Any]: """ resume method for load dataset action. Returns: Dict[str, Any]: datasets metainfo including dataset_id and dataset_type. """ if self.qf_dataset_id: log_debug("[load_dataset_action] dataset loading already done") return { "datasets": [ { "id": self.qf_dataset_id, "type": console_consts.TrainDatasetType.Platform.value, } ] } log_debug("[load_dataset_action] dataset loading resumed") return self._exec(**kwargs)
[docs]class TrainAction( BaseAction[Dict[str, Any], Dict[str, Any]], ): """ Class for Train Action, Synchronous invocation of the training API, taking a dataset metadata dict as input and producing a model metadata as output. Concretely, `exec` is called for running. Note: this action is not involved with model publishing, please use use `ModelPublishAction` for publishing model. Sample: Input: ``` {'datasets': {"sourceType": ( console_consts.TrainDatasetSourceType.PrivateBos.value ), "versions": [ { "versionBosUri": bos_uploader.generate_bos_file_parent_path( bos_data_src.bucket, bos_data_src.bos_file_path ) } ]} ``` Output: ``` {'task_id': "task-ddd", 'job_id': "job-xxxx"} Sample code: ``` """ task_id: Optional[str] = None """train task id""" job_id: Optional[str] = None """train job id""" train_type: Optional[str] = "" """train_type""" is_incr: bool = False """if it's incremental train or not""" train_config: Optional[TrainConfig] = None """train config""" train_mode: console_consts.TrainMode """train mode""" job_name: str = "" """train task name""" task_description: Optional[str] = None """train task description""" job_description: Optional[str] = None """train job description""" _input: Optional[Dict[str, Any]] = None """train input""" result: Optional[Dict[str, Any]] = None """"train result""" def __init__( self, train_mode: console_consts.TrainMode, train_type: Optional[str] = None, train_config: Optional[TrainConfig] = None, task_id: Optional[str] = None, job_id: Optional[str] = None, peft_type: PeftType = PeftType.ALL, job_name: Optional[str] = None, task_description: Optional[str] = None, job_description: Optional[str] = None, **kwargs: Any, ) -> None: """ Parameters: train_mode (Optional[console_consts.TrainMode], optional): train mode, e.g. `SFT`, `PostPretrain`. Defaults to None. train_type (Optional[str], optional): train_type, must be specified when it's not increment training like 'ERNIE-Bot-turbo-0725' train_config (Optional[TrainConfig], optional): train_config, e.g. `epoch=10, batch_size=32`. base_model (Optional[str], optional): base_mode, like 'ERNIE-Bot-turbo'. Defaults to None. task_id (Optional[int], optional): used in incr train, model train task_id. Defaults to None. job_id (Optional[int], optional): used in incr train, mod train job_id. Defaults to None. job_name (Optional[str], optional): train task name. Defaults to None. task_description (Optional[str], optional): train task description. Defaults to None. job_description (Optional[str], optional): train job description. Defaults to None. """ super().__init__(**kwargs) self.task_id = task_id self.job_id = job_id self.train_mode = train_mode if self.task_id is not None: # if incremental train pre_task_detail = api.FineTune.V2.task_detail(task_id=self.task_id) # 获取增量任务的训练model if pre_task_detail.get("result") is not None: self.train_type = pre_task_detail["result"]["model"] self.train_mode = train_mode self.is_incr = True else: if train_type is None: raise InvalidArgumentError("train_type must be specified") # 从基础模型开始训练 self.train_type = train_type model_info = get_model_info(train_mode, self.train_type) if model_info is None: log_warn(f"unknown train model type: {self.train_type} is not found") assert self.train_type is not None if train_config is None: train_config = self.get_default_train_config( self.train_type, self.train_mode, peft_type ) self.train_config = train_config self.validateTrainConfig(strict=kwargs.get("validate_strict", True)) self.job_name = self._generate_job_name(job_name, self.train_type) self.task_description = task_description self.job_description = job_description def _generate_job_name( self, job_name: Optional[str], train_type: Optional[str] ) -> str: if job_name is not None: return job_name model_info = ( get_model_info(self.train_mode, train_type) if train_type is not None else None ) return ( f"job_{utils.generate_letter_num_random_id()}" if model_info is None else f"{model_info.short_name}_{utils.generate_letter_num_random_id(5)}" )
[docs] def validateTrainConfig(self, strict: bool = True) -> None: """ validate train_config with ModelInfo Limits Raises: InvalidArgumentError: _description_ """ if self.train_config is None: raise InvalidArgumentError("none train_config") else: assert self.train_type train_type_model_info = get_model_info(self.train_mode, self.train_type) if train_type_model_info is None: return if ( self.train_config.peft_type not in train_type_model_info.support_peft_types ): log_warn( f"[train_action] train_type {self.train_type}, peft_type" f" {self.train_config.peft_type} not found, it may be not supported" ) if strict: raise InvalidArgumentError( f"[train_action] train_type {self.train_type}, peft_type" f" {self.train_config.peft_type} not found, it may be not" " supported" ) else: assert train_type_model_info res = False if ( train_type_model_info.specific_peft_types_params_limit is not None and self.train_config.peft_type in train_type_model_info.specific_peft_types_params_limit ): res = self._validate_train_config( train_type_model_info.specific_peft_types_params_limit[ self.train_config.peft_type ] | train_type_model_info.common_params_limit, ) else: res = self._validate_train_config( train_type_model_info.common_params_limit ) if not res and strict: raise InvalidArgumentError( "invalid train_config, please check the config" )
def _validate_train_config(self, train_limit: TrainLimit) -> bool: """ validate train_config with a specific train_limit Args: train_limit (TrainLimit): _description_ Raises: InvalidArgumentError: _description_ """ if self.train_config is None: raise InvalidArgumentError("validate train_config is none") return self.train_config.validate_config(train_limit) @with_event def exec(self, input: Dict[str, Any] = {}, **kwargs: Dict) -> Dict[str, Any]: """ exec method for train action Args: input (Dict[str, Any], optional): input with dataset meta: for example: Input: ``` {'datasets':[{'type': 1, 'id': 111}]} ``` Raises: InvalidArgumentError: invalid dataset input Returns: Dict[str, Any]: train task output with task_id and job_id for example: Output: ``` {'task_id': 47923, 'job_id': 33512} ``` """ # for resume if self._input is None: self._input = input return self._exec(self._input, **kwargs) def _exec(self, input: Dict[str, Any] = {}, **kwargs: Any) -> Dict[str, Any]: # 校验数据集 ds_config = input.get("datasets") if ds_config is None: raise InvalidArgumentError("train set must be set") assert isinstance(ds_config, dict) assert self.train_config ds_config["splitRatio"] = self.train_config.trainset_rate if self.job_id is None: # request for create model train task assert self.train_type is not None resp = api.FineTune.V2.create_job( name=self.job_name, description=self.task_description, model=self.train_type, train_mode=self.train_mode, **kwargs, ) self.job_id = str(resp["result"]["jobId"]) log_debug( f"[train_action] create {self.train_mode} train job: {self.job_id}" ) assert self.train_config is not None hyper_params_dict = { **self.train_config.dict(exclude={"peft_type", "trainset_rate", "extras"}), **self.train_config.extras, } hyper_params_dict = { first_lower_case(snake_to_camel(key)): value for key, value in hyper_params_dict.items() if value is not None } ds_config = input["datasets"] log_debug(f"train with ds_config: { ds_config}") log_debug(f"train with hyper_params: { hyper_params_dict}") if self.is_incr: # 增量训练 kwargs["incrementTaskId"] = self.task_id log_info(f"train with incrementTaskId: { self.task_id}") assert self.train_config.peft_type is not None create_task_resp = api.FineTune.V2.create_task( job_id=self.job_id, params_scale=self.train_config.peft_type, hyper_params=hyper_params_dict, dataset_config=ds_config, **kwargs, ) self.task_id = str(create_task_resp["result"]["taskId"]) log_debug(f"[train_action] create {self.train_mode} train task: {self.task_id}") # 获取job状态,是否训练完成 self._wait_model_trained(**kwargs) self.result = {**input, "task_id": self.task_id, "job_id": self.job_id} assert self.result is not None return self.result def _wait_model_trained(self, **kwargs: Dict) -> None: if self.task_id is None: raise InvalidArgumentError("task_id must not be None") while True: job_status_resp = api.FineTune.V2.task_detail( task_id=self.task_id, **kwargs, ) job_status = job_status_resp["result"]["runStatus"] job_progress = int(job_status_resp["result"]["runProgress"][:-1]) log_info( "[train_action] fine-tune running..." f" job_name:{self.job_name} current status: {job_status}," f" {job_progress}% check train task log in" f" https://console.bce.baidu.com/qianfan/train/sft/{self.job_id}/{self.task_id}/detail/traininglog" ) if job_progress >= 50: log_info(f" check vdl report in {job_status_resp['result']['vdlLink']}") self.action_event(ActionState.Running, "train running", job_status_resp) if job_status == console_consts.TrainStatus.Finish: break elif job_status in [ console_consts.TrainStatus.Fail, console_consts.TrainStatus.Stop, ]: log_error( "[train_action] fine-tune job" f" {self.job_id}/{self.task_id} has ended," f" {job_status_resp}" ) raise InternalError( f"fine-tune job {self.job_id}/{self.task_id} has ended with" f" status: {job_status}" ) else: time.sleep(get_config().TRAIN_STATUS_POLLING_INTERVAL) log_info( "[train_action] fine-tune job has ended:" f" {self.job_id}/{self.task_id} with status: {job_status}" ) @with_event def resume(self, **kwargs: Dict) -> Dict[str, Any]: """ resume method for train action Parameters: **kwargs (Dict): input args for action resume """ if self.result is not None: log_warn("[train_action] already done") return self.result self.action_event(ActionState.Running, "train resume") if self.task_id is not None and self.job_id is not None: log_info( f"[train_action] resume from created job {self.task_id}/{self.job_id}" ) self._wait_model_trained(**kwargs) self.result = {"task_id": self.task_id, "job_id": self.job_id} return self.result else: if self._input is None: self._input = {} return self._exec(self._input, **kwargs)
[docs] def stop(self, **kwargs: Dict) -> None: """ stop method for train action Parameters: **kwargs (Dict): input args for action stop """ if self.task_id is None or self.job_id is None: log_warn("[train_action] task_id or job_id not set, training not started") return api.FineTune.stop_job(self.task_id, self.job_id) log_debug(f"train job {self.task_id}/{self.job_id} stopped")
[docs] def get_default_train_config( self, model_type: str, train_mode: console_consts.TrainMode, peft_type: PeftType ) -> TrainConfig: if train_mode == console_consts.TrainMode.PostPretrain: model_info = DefaultPostPretrainTrainConfigMapping.get( model_type, # DefaultTrainConfigMapping[get_config().DEFAULT_FINE_TUNE_TRAIN_TYPE], ) else: model_info = DefaultTrainConfigMapping.get( model_type, # DefaultTrainConfigMapping[get_config().DEFAULT_FINE_TUNE_TRAIN_TYPE], ) if model_info is None: raise InvalidArgumentError( f"can not find default config for {model_type} in {peft_type}" ) train_config = model_info[peft_type] train_config.peft_type = peft_type return train_config
[docs]class ModelPublishAction(BaseAction[Dict[str, Any], Dict[str, Any]]): """ Class for Model publish action, Commonly used after `TrainAction`. Sample: Input: ``` {'task_id': 47923, 'job_id': 33512} ``` Output: ``` {'task_id': 47923, 'job_id': 33512, 'model_id': "xxx", 'model_version_id': "aaa"} ``` """ task_id: Optional[str] = None """task id""" job_id: Optional[str] = None """job id""" result: Optional[Dict[str, Any]] = None """result of model publish action""" model: Optional[Model] = None """model object""" @with_event def exec(self, input: Dict[str, Any] = {}, **kwargs: Dict) -> Dict[str, Any]: if self.task_id == "" or self.job_id == "": raise InvalidArgumentError("task_id or job_id must be set") self.task_id = input.get("task_id", "") self.job_id = input.get("job_id", "") self.model = Model(task_id=self.task_id, job_id=self.job_id) return self._exec(input, **kwargs) def _exec(self, input: Dict[str, Any] = {}, **kwargs: Dict) -> Dict[str, Any]: if self.model is None: raise InvalidArgumentError("model must be set when in model publish._exec") log_debug( f"[model_publish_action] start model publish task:, {self.task_id}," f" {self.job_id}" ) try: self.action_event( ActionState.Running, "model publish", { "task_id": self.task_id, "job_id": self.job_id, }, ) self.model.publish(name=input.get("name", ""), **kwargs) log_debug( f"[model publish] model: {self.task_id}_{self.job_id} has been" " published." ) self.result = { "task_id": self.task_id, "job_id": self.job_id, "model_id": self.model.id, "model_version_id": self.model.version_id, "model": self.model, } return self.result except Exception as e: log_error(f"[model_publish_action] model publish error: {e}") raise e @with_event def resume(self, **kwargs: Dict) -> Dict[str, Any]: # raise NotImplementedError("ModelPublishAction.resume() is not implemented") if self.result is not None: return self.result if self.model is not None: return self._exec() if self.task_id is None or self.job_id is None: raise InvalidArgumentError("task_id or job_id not set, resume failed") self.model = Model(task_id=self.task_id, job_id=self.job_id) return self._exec(**kwargs)
[docs]class DeployAction(BaseAction[Dict[str, Any], Dict[str, Any]]): """DeployAction Action for model service deployment. A TrainConfig must be supplied when instance initialized. Sample: ``` deploy_config = DeployConfig(replicas=1, pool_type=1) deploy_action = DeployAction(deploy_config=deploy_config) output = deploy_action.exec(input) ``` input: {'task_id': 47923, 'job_id': 33512, 'model_id': "xx", 'model_version_id': "xxx"} output: ``` {'task_id': 47923, 'job_id': 33512, 'model_id': "xx", 'model_version_id': "xxx", 'service_id': 164, 'service_endpoint': 'xbiimimv_xxx'} ``` """ deploy_config: Optional[DeployConfig] """deploy config include replicas and so on""" model_id: Optional[int] """model id""" model_id_str: Optional[str] """model str id""" model_version_id: Optional[int] """model version id""" model_version_id_str: Optional[str] """model version str id """ _input: Optional[Dict[str, Any]] = None """input of action""" result: Optional[Dict[str, Any]] = None """result of action""" def __init__(self, deploy_config: Optional[DeployConfig] = None, **kwargs: Any): """ Parameters: deploy_config (Optional[DeployConfig], optional): deploy config include replicas and so on. Defaults to None. """ super().__init__(**kwargs) self.deploy_config = deploy_config @with_event def exec(self, input: Dict[str, Any] = {}, **kwargs: Any) -> Dict[str, Any]: # for resume if self._input is None: self._input = input if self.deploy_config is None: raise InvalidArgumentError("deploy_config must be set") if input.get("model") is None: self.model_id = input.get("model_id") self.model_version_id = input.get("model_version_id") # TODO 迁移成str id if self.model_id is None or self.model_version_id is None: raise InvalidArgumentError("model_id or model_version_id must be set") self.model = Model(self.model_id, self.model_version_id, auto_complete=True) self.model.auto_complete_info() else: self.model = cast(Model, input.get("model")) if self.model is None: raise InvalidArgumentError( "must input with model or model id and version id" ) self.model.auto_complete_info() self.model_id = self.model.old_id self.model_version_id = self.model.old_version_id # 自动补全 return self._exec(**kwargs) def _exec(self, input: Dict[str, Any] = {}, **kwargs: Dict) -> Dict[str, Any]: if self.deploy_config is None: raise InvalidArgumentError("deploy_config must be set in deploy._exec") log_debug( f"[deploy_action] try deploy model {self.model.id}_{self.model.version_id}" ) self.action_event( ActionState.Running, "ready to deploy", { "model": self.model, }, ) # deploy model self.model.deploy(self.deploy_config, **kwargs) if self.model.service is not None: log_debug( "[deploy_action] model" f" {self.model_id}_{self.model_version_id} deployed successfully with" " service:" f" {self.model.service.id} endpoint:{self.model.service.endpoint}" ) return { **input, "service_id": self.model.service.id, "service_endpoint": self.model.service.endpoint, "service": self.model.service, "model": self.model, } else: raise InternalError("model.service is not available") @with_event def resume(self, **kwargs: Dict) -> Dict[str, Any]: """ resume method for deploy action Parameters: **kwargs (Dict): input args for action resume """ if self.model_id is not None and self.model_version_id is not None: self.model = Model(self.model_id_str, self.model_version_id_str) self.model.auto_complete_info() elif self.model is None: raise InvalidArgumentError( "either (model_id and version_id) or model must be set" ) return self._exec()
[docs]class EvaluateAction(BaseAction[Dict[str, Any], Dict[str, Any]]): """EvaluateAction Action for evaluate models or services. Sample: input: ``` {'model_id': 47923, 'model_version_id': 33512} ``` output: ``` {'eval_res': EvaluationResult ...} ``` """ eval_manager: Optional[EvaluationManager] = None """evaluation manager for evaluate models or services.""" eval_dataset: Optional[Dataset] = None _input: Optional[Dict[str, Any]] = None """input of action""" result: Optional[Dict[str, Any]] = None """result of action""" def __init__( self, eval_dataset: Dataset, evaluators: List[Evaluator], **kwargs: Any ): """ init method for evaluate action Parameters: eval_dataset Dataset: dataset for evaluation, use Dataset.load() to create. evaluators List[Evaluator]: evaluators for evaluation, include local and qianfan remote evaluators. Specifically, qianfan_evaluators are only available for Model. """ super().__init__(**kwargs) self.eval_dataset = eval_dataset local_evaluators = [ eval for eval in evaluators if isinstance(eval, LocalEvaluator) ] qianfan_evaluators = [ eval for eval in evaluators if isinstance(eval, QianfanEvaluator) ] self.eval_manager = EvaluationManager( local_evaluators=local_evaluators if len(local_evaluators) > 0 else None, qianfan_evaluators=( qianfan_evaluators if len(qianfan_evaluators) > 0 else None ), ) @with_event def exec(self, input: Dict[str, Any] = {}, **kwargs: Any) -> Dict[str, Any]: """ exec evaluation Parameters: input (Dict[str, Any], optional): input dict with model/service info. Defaults to {}. Returns: Dict[str, Any]: output the result with the original input """ self._input = input log_info(f"[evaluation_action] begin to do evaluation, input: {self._input}") llm = self._parse_from_input(self._input) res = self._exec(llm, **kwargs) self.result = {"eval_res": res, **input} return self.result def _parse_from_input(self, input: Dict[str, Any] = {}) -> Union[Model, Service]: """ Parses and returns the model or service object based on the input parameters. Parameters: input (Dict[str, Any], optional): . Defaults to {}. Returns: Union[Model, Service]: parsed model or service object """ if input.get("service"): llm = input.get("service") elif input.get("model"): llm = input.get("model") elif input.get("model_id") and input.get("model_version_id"): llm = Model(input["model_id"], input["model_version_id"]) else: log_error(f"[evaluation_action] invalid llm input error {self._input}") raise InvalidArgumentError( "model or service must be set in evaluation action" ) assert isinstance(llm, (Model, Service)) return llm def _exec(self, llm: Union[Model, Service], **kwargs: Dict) -> Any: """ accept a llm model/service to do evaluation Parameters: llm (Union[Model, Service]): model to do evaluation Returns: Any: evaluation result object """ assert self.eval_manager is not None if self.eval_dataset is None: raise InvalidArgumentError("eval_dataset must be set") self.action_event( ActionState.Running, "ready to evaluate", { "llm": llm, "dataset": self.eval_dataset, }, ) log_info("[evaluation_action] running evaluation...") return self.eval_manager.eval([llm], self.eval_dataset, **kwargs) @with_event def resume(self, **kwargs: Dict) -> Dict[str, Any]: """ resume method for eval action Parameters: **kwargs (Dict): input args for action resume """ if self._input is None: log_error( "[evaluation_action] previous input not found, call run() instead." ) raise ValueError("input not found") llm = self._parse_from_input(self._input) res = self._exec(llm, **kwargs) self.result = {"eval_res": res, **self._input} return self.result
action_mapping: Dict[str, Dict[str, Any]] = { LoadDataSetAction.__class__.__name__: { ActionState.Preceding: TrainStatus.DatasetLoading, ActionState.Running: TrainStatus.DatasetLoading, ActionState.Done: TrainStatus.DatasetLoaded, ActionState.Error: TrainStatus.DatasetLoadFailed, ActionState.Stopped: TrainStatus.DatasetLoadStopped, }, TrainAction.__class__.__name__: { ActionState.Preceding: TrainStatus.TrainCreated, ActionState.Running: TrainStatus.Training, ActionState.Done: TrainStatus.TrainFinished, ActionState.Error: TrainStatus.TrainFailed, ActionState.Stopped: TrainStatus.TrainStopped, }, ModelPublishAction.__class__.__name__: { ActionState.Preceding: TrainStatus.ModelPublishing, ActionState.Running: TrainStatus.ModelPublishing, ActionState.Done: TrainStatus.ModelPublished, ActionState.Error: TrainStatus.ModelPublishFailed, ActionState.Stopped: TrainStatus.ModelPublishFailed, }, DeployAction.__class__.__name__: { ActionState.Preceding: ServiceStatus.Created, ActionState.Running: ServiceStatus.Deploying, ActionState.Done: ServiceStatus.Deployed, ActionState.Error: ServiceStatus.DeployFailed, ActionState.Stopped: ServiceStatus.DeployStopped, }, EvaluateAction.__class__.__name__: { ActionState.Preceding: TrainStatus.EvaluationCreated, ActionState.Running: TrainStatus.EvaluationRunning, ActionState.Done: TrainStatus.EvaluationFinished, ActionState.Error: TrainStatus.EvaluationFailed, ActionState.Stopped: TrainStatus.EvaluationStopped, }, }