Source code for qianfan.trainer.model

# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import time
from typing import Any, Dict, Iterator, Optional, Union

from qianfan import resources as api
from qianfan.config import get_config
from qianfan.dataset import Dataset, QianfanDataSource
from qianfan.errors import InternalError, InvalidArgumentError, QianfanError
from qianfan.resources import (
    ChatCompletion,
    Completion,
    Embedding,
    QfResponse,
    Text2Image,
)
from qianfan.resources.console import consts as console_const
from qianfan.resources.console.model import Model as ResourceModel
from qianfan.trainer.base import ExecuteSerializable
from qianfan.trainer.configs import DeployConfig
from qianfan.trainer.consts import ServiceType
from qianfan.utils import log_debug, log_error, log_info, log_warn
from qianfan.utils.utils import generate_letter_num_random_id


class Model(
    ExecuteSerializable[Dict, Union[QfResponse, Iterator[QfResponse]]],
):
    """
    Handle for a model on qianfan, deployable through `deploy()` to obtain
    a custom model service.
    """

    id: Optional[int]
    """remote model id"""
    version_id: Optional[int]
    """remote model version id"""
    name: Optional[str] = None
    """model name"""
    service: Optional["Service"] = None
    """model service"""
    task_id: Optional[int]
    """train task id"""
    job_id: Optional[int]
    """train job id"""

    def __init__(
        self,
        id: Optional[int] = None,
        version_id: Optional[int] = None,
        task_id: Optional[int] = None,
        job_id: Optional[int] = None,
        name: Optional[str] = None,
    ):
        """
        Class for model in qianfan, which is deployable by using deploy() to
        get a custom model service.

        Parameters:
            id (Optional[int], optional):
                qianfan model remote id. Defaults to None.
            version_id (Optional[int], optional):
                model version id. Defaults to None.
            task_id (Optional[int], optional):
                model train task id. Defaults to None.
            job_id (Optional[int], optional):
                model train job id. Defaults to None.
            name (Optional[str], optional):
                model name. Defaults to None.
        """
        self.id = id
        self.version_id = version_id
        self.task_id = task_id
        self.job_id = job_id
        self.name = name

    def exec(
        self, input: Optional[Dict] = None, **kwargs: Dict
    ) -> Union[QfResponse, Iterator[QfResponse]]:
        """
        model execution, for different model service type, please input
        a dict with different keys. Concretely, take
        `input={"messages": [{"role": "user", "content": "hello world"}]}`
        as input, when the model is a chat io Model.

        Parameters:
            input (Optional[Dict], optional):
                input data. Defaults to None.
            **kwargs:
                forwarded unchanged to the service's `exec`.

        Raises:
            InternalError: model with no service deployed is unable to call exec

        Returns:
            Union[QfResponse, Iterator[QfResponse]]: output data
        """
        if self.service is None:
            raise InternalError(
                "model not deployed, call `model_deploy()` to instantiate a service"
            )
        return self.service.exec(input, **kwargs)

    def deploy(self, deploy_config: DeployConfig, **kwargs: Any) -> "Service":
        """
        model deploy. If a service is already attached to this model it is
        returned as is and no new deployment is started.

        Parameters:
            deploy_config (DeployConfig): model service deploy config
            **kwargs: forwarded to `model_deploy`

        Returns:
            Service: model service instance
        """
        if self.service is None:
            self.service = model_deploy(self, deploy_config, **kwargs)
            return self.service
        log_info("model service already existed")
        return self.service

    def publish(self, name: str = "", **kwargs: Any) -> "Model":
        """
        model publish, before deploying a model, it should be published.

        Parameters:
            name str:
                model name. Defaults to "m_{task_id}_{job_id}".
            **kwargs:
                forwarded to every remote API call made here.

        Raises:
            InvalidArgumentError: a required id (task, job, model, version)
                is missing, no model version is found, or the train job is
                neither running nor finished.
            InternalError: raised by `_wait_for_publish` when the remote
                model version reports a failed state.

        Returns:
            Model: this instance, with ids refreshed from the remote side
        """
        if self.version_id:
            # a version id is already known: refresh the other ids from it
            model_detail_resp = api.Model.detail(
                model_version_id=self.version_id, **kwargs
            )
            detail = model_detail_resp["result"]
            self.id = detail["modelId"]
            self.task_id = detail["sourceExtra"]["trainSourceExtra"]["taskId"]
            self.job_id = detail["sourceExtra"]["trainSourceExtra"]["runId"]
            log_info(f"check model {self.id}/{self.version_id} published...")
            if detail["state"] != console_const.ModelState.Ready:
                self._wait_for_publish(**kwargs)
        if self.id:
            list_resp = api.Model.list(self.id, **kwargs)
            version_list = list_resp["result"]["modelVersionList"]
            # `not` also covers a None list, matching the lookup further down
            if not version_list:
                raise InvalidArgumentError(
                    "not model version matched, please train and publish first"
                )
            log_info("model publish get the first version in model list as default")
            self.version_id = version_list[0]["modelVersionId"]
            if self.version_id is None:
                raise InvalidArgumentError("model version id not found")
            model_detail_resp = api.Model.detail(
                model_version_id=self.version_id, **kwargs
            )
            detail = model_detail_resp["result"]
            self.task_id = detail["sourceExtra"]["trainSourceExtra"]["taskId"]
            self.job_id = detail["sourceExtra"]["trainSourceExtra"]["runId"]
            if detail["state"] != console_const.ModelState.Ready:
                self._wait_for_publish(**kwargs)

        # Validate before the ids are used to build the model name and the
        # publish request, instead of after the request has been sent.
        if self.task_id is None or self.job_id is None:
            raise InvalidArgumentError("task id or job id not found")

        # publish the model
        self.model_name = name if name != "" else f"m_{self.task_id}_{self.job_id}"
        model_publish_resp = api.Model.publish(
            is_new=True,
            model_name=self.model_name,
            version_meta={"taskId": self.task_id, "iterationId": self.job_id},
            **kwargs,
        )
        log_info(
            f"check train job: {self.task_id}/{self.job_id} status before publishing"
            " model"
        )
        self.id = model_publish_resp["result"]["modelId"]

        # make sure the training job has finished
        while True:
            job_status_resp = api.FineTune.get_job(
                task_id=self.task_id,
                job_id=self.job_id,
                **kwargs,
            )
            job_status = job_status_resp["result"]["trainStatus"]
            log_info(f"model publishing keep polling, current status {job_status}")
            if job_status == console_const.TrainStatus.Running:
                time.sleep(get_config().TRAIN_STATUS_POLLING_INTERVAL)
                continue
            elif job_status == console_const.TrainStatus.Finish:
                break
            else:
                raise InvalidArgumentError("invalid train task job to publish model")

        if self.id is None:
            raise InvalidArgumentError("model id not found")

        # fetch the model version info
        model_list_resp = api.Model.list(model_id=self.id, **kwargs)
        model_version_list = model_list_resp["result"]["modelVersionList"]
        if not model_version_list:
            raise InvalidArgumentError("not model version matched")
        self.version_id = model_version_list[0]["modelVersionId"]

        if self.version_id is None:
            raise InvalidArgumentError("model version id not found")
        self._wait_for_publish(**kwargs)

        return self

    def _wait_for_publish(self, **kwargs: Any) -> None:
        """
        call a polling loop to wait until the model is published.

        Parameters:
            **kwargs: forwarded to `api.Model.detail` on every poll.

        Raises:
            InvalidArgumentError: `version_id` is not set.
            InternalError: the model version reports `ModelState.Fail`.
        """
        # fetch the model version detail
        if self.version_id is None:
            raise InvalidArgumentError("model version id not found")
        log_info("model ready to publish")
        while True:
            model_detail_info = api.Model.detail(
                model_version_id=self.version_id, **kwargs
            )
            model_version_state = model_detail_info["result"]["state"]
            log_info(f"check model publish status: {model_version_state}")
            if model_version_state == console_const.ModelState.Ready:
                log_info(f"model {self.id}/{self.version_id} published successfully")
                break
            elif model_version_state == console_const.ModelState.Fail:
                raise InternalError(
                    "model published failed, check error msg and retry."
                    f" {model_detail_info}"
                )
            time.sleep(get_config().MODEL_PUBLISH_STATUS_POLLING_INTERVAL)

    def dumps(self) -> Optional[bytes]:
        """
        Serialize the model to bytes with `pickle`.

        Returns:
            Optional[bytes]: bytes of this model
        """
        return pickle.dumps(self)

    def loads(self, data: bytes) -> Any:
        """
        load model instance from bytes

        NOTE: this unpickles `data`; unpickling can execute arbitrary code,
        so only pass bytes that come from a trusted source.

        Parameters:
            data (bytes): bytes of this model

        Returns:
            Any: model instance
        """
        return pickle.loads(data)

    def batch_run_on_qianfan(self, dataset: Dataset, **kwargs: Any) -> Dataset:
        """
        create batch run using specific dataset on qianfan
        by evaluation ability of platform

        Parameters:
            dataset (Dataset):
                A dataset instance which indicates a dataset on qianfan platform
            **kwargs (Any):
                Arbitrary keyword arguments, forwarded to the evaluation task
                creation and to `Dataset.load`

        Raises:
            ValueError: the dataset is not located in qianfan.
            QianfanError: the evaluation task ended in a state other than
                `DoingWithManualBegin` or `Done`.

        Returns:
            Dataset: batch result contained in dataset
        """
        if not dataset.is_dataset_located_in_qianfan():
            err_msg = "can't start a batch run task on non-qianfan dataset"
            log_error(err_msg)
            raise ValueError(err_msg)

        qianfan_data_source = dataset.inner_data_source_cache
        assert isinstance(qianfan_data_source, QianfanDataSource)

        log_info("start to create evaluation task in model")
        resp = ResourceModel.create_evaluation_task(
            name=f"model_run_{generate_letter_num_random_id()}",
            version_info=[
                {
                    "modelId": self.id,
                    "modelVersionId": self.version_id,
                }
            ],
            dataset_id=qianfan_data_source.id,
            eval_config={
                "evalMode": "manual",
                "evaluationDimension": [
                    {"dimension": "满意度"},
                ],
            },
            dataset_name=qianfan_data_source.name,
            **kwargs,
        ).body

        eval_id = resp["result"]["evalId"]

        log_debug(f"create evaluation task in model response: {resp}")
        log_info(f"start to polling status of evaluation task {eval_id}")

        # Compare against `.value` in both checks so they behave the same
        # whether or not the status enum compares equal to plain strings.
        unfinished_states = [
            console_const.EvaluationTaskStatus.Pending.value,
            console_const.EvaluationTaskStatus.Doing.value,
        ]
        success_states = [
            console_const.EvaluationTaskStatus.DoingWithManualBegin.value,
            console_const.EvaluationTaskStatus.Done.value,
        ]

        while True:
            # NOTE(review): kwargs are not forwarded to this call in the
            # original code either -- confirm that is intentional.
            eval_info = ResourceModel.get_evaluation_info(eval_id)
            eval_state = eval_info["result"]["state"]

            log_debug(f"current evaluation task info: {eval_info}")
            log_info(f"current eval_state: {eval_state}")

            if eval_state not in unfinished_states:
                break
            time.sleep(30)

        if eval_state not in success_states:
            err_msg = f"can't finish evaluation task and failed with state {eval_state}"
            log_error(err_msg)
            raise QianfanError(err_msg)

        result_dataset_id = eval_info["result"]["evalStandardConf"]["resultDatasetId"]
        log_info(f"get result dataset id {result_dataset_id}")
        return Dataset.load(qianfan_dataset_id=result_dataset_id, **kwargs)


class Service(ExecuteSerializable[Dict, Union[QfResponse, Iterator[QfResponse]]]):
    """
    Handle for a model service on qianfan, addressed either by an existing
    endpoint or by a model that can be deployed through `deploy()`.
    """

    id: Optional[int]
    """remote service id"""
    model: Optional[Model]
    """service model instance"""
    deploy_config: Optional[DeployConfig]
    """service deploy config"""
    endpoint: Optional[str]
    """service endpoint to call"""
    service_type: Optional[ServiceType]
    """service type, for user use service as a execution must specify"""

    def __init__(
        self,
        id: Optional[int] = None,
        endpoint: Optional[str] = None,
        model: Optional[Union[Model, str]] = None,
        deploy_config: Optional[DeployConfig] = None,
        service_type: Optional[ServiceType] = None,
    ) -> None:
        """
        Class for a model service in qianfan.

        Parameters:
            id (Optional[int], optional):
                qianfan service id. Defaults to None.
            endpoint (Optional[str], optional):
                qianfan service endpoint. When given it takes precedence
                over `model`. Defaults to None.
            model (Optional[Union[Model, str]], optional):
                service's corresponding model; a str is wrapped as
                `Model(name=model)`. Defaults to None.
            deploy_config (Optional[DeployConfig], optional):
                service deploy config. Defaults to None.
            service_type (Optional[ServiceType], optional):
                service type, for user use service as a execution must specify,
                Defaults to None.

        Raises:
            InvalidArgumentError: neither `endpoint` nor a usable `model`
                (str or Model) was passed.
        """
        self.id = id
        self.service_type = service_type
        if self.service_type is None:
            log_warn("service type should be specified before exec")
        if endpoint is not None:
            self.model = None
            self.endpoint = endpoint
        elif isinstance(model, str):
            self.model = Model(name=model)
            self.endpoint = None
        elif isinstance(model, Model):
            # need to deploy
            self.model = model
            self.endpoint = None
        else:
            raise InvalidArgumentError("invalid model service")
        self.deploy_config = deploy_config

    @property
    def status(self) -> str:
        """
        get the service status

        Returns:
            str: the remote `serviceStatus` value, or an empty string when
            this service has no id yet.
        """
        if self.id is None:
            return ""
        else:
            resp = api.Service.get(
                id=self.id,
                retry_count=get_config().TRAINER_STATUS_POLLING_RETRY_TIMES,
                backoff_factor=get_config().TRAINER_STATUS_POLLING_BACKOFF_FACTOR,
            )
            return resp["result"]["serviceStatus"]

    def exec(
        self, input: Optional[Dict] = None, **kwargs: Dict
    ) -> Union[QfResponse, Iterator[QfResponse]]:
        """
        exec

        Parameters:
            input (Optional[Dict], optional):
                input of execution of service; its items are passed as
                keyword arguments to the resource's `do`. Defaults to None.
            **kwargs:
                additional args Dict; on a key clash they override `input`.

        Raises:
            InvalidArgumentError: `input` is None, or the service type is
                missing/unsupported (raised by `get_res`).

        Returns:
            Union[QfResponse, Iterator[QfResponse]]: output
        """
        if input is None:
            raise InvalidArgumentError("input is none")
        return self.get_res().do(**{**input, **kwargs})

    def get_res(self) -> Union[ChatCompletion, Completion, Embedding, Text2Image]:
        """
        convert to the specific model resources.
        e.g. `ChatCompletion`, `Completion`, `Embeddings`, `Text2Image`

        Raises:
            InvalidArgumentError: an endpoint is set without a service type,
                or the service type is not one of the supported ones.

        Returns:
            Union[ChatCompletion, Completion, Embedding, Text2Image]:
                resource object
        """
        if self.endpoint is not None and self.service_type is None:
            raise InvalidArgumentError(
                "service type must be specified when endpoint passed in"
            )
        svc_status = self.status
        if svc_status != console_const.ServiceStatus.Done:
            # only a warning: the resource is still built and returned
            log_warn("service status unknown, service could be unavailable.")
        if self.service_type == ServiceType.Chat:
            return ChatCompletion(
                model=(self.model.name if self.model is not None else None),
                endpoint=self.endpoint,
            )
        elif self.service_type == ServiceType.Completion:
            return Completion(
                model=(self.model.name if self.model is not None else None),
                endpoint=self.endpoint,
            )
        elif self.service_type == ServiceType.Embedding:
            return Embedding(
                model=(self.model.name if self.model is not None else None),
                endpoint=self.endpoint,
            )
        elif self.service_type == ServiceType.Text2Image:
            return Text2Image(
                model=(self.model.name if self.model is not None else None),
                endpoint=self.endpoint,
            )
        else:
            raise InvalidArgumentError(f"unsupported service type {self.service_type}")

    def deploy(self, **kwargs: Any) -> "Service":
        """
        create the remote service for `self.model` and poll until it is
        deployed.

        Parameters:
            **kwargs: forwarded to `api.Service.create` and `api.Service.get`

        Raises:
            InvalidArgumentError: model, model id / version id or deploy
                config is missing.
            InternalError: no service id came back from creation, or the
                service ended in a status other than Done.

        Returns:
            Service: this instance, with `id` and `endpoint` filled in
        """
        if self.model is None:
            raise InvalidArgumentError("model not found")
        model = self.model
        if model.id is None or model.version_id is None:
            raise InvalidArgumentError("model id | model version id not found")
        if self.deploy_config is None:
            raise InvalidArgumentError("deploy config not found")
        log_info(f"ready to deploy service with model {model.id}/{model.version_id}")
        svc_publish_resp = api.Service.create(
            model_id=model.id,
            model_version_id=model.version_id,
            name=(
                self.deploy_config.name
                if self.deploy_config.name != ""
                else f"svc{model.id}_{model.version_id}"
            ),
            uri=(
                self.deploy_config.endpoint_prefix
                if self.deploy_config.endpoint_prefix != ""
                else f"ep{model.id}_{model.version_id}"
            ),
            replicas=self.deploy_config.replicas,
            pool_type=self.deploy_config.pool_type,
            **kwargs,
        )

        self.id = svc_publish_resp["result"]["serviceId"]
        if self.id is None:
            # single formatted message, like every other log call in this file
            log_error(f"create service error: {svc_publish_resp}")
            raise InternalError("service id not found")

        # after the resource payment completes, serviceStatus becomes
        # Deploying; poll the model service status
        while True:
            resp = api.Service.get(id=self.id, **kwargs)
            svc_status = resp["result"]["serviceStatus"]
            if svc_status in [
                console_const.ServiceStatus.Deploying.value,
                console_const.ServiceStatus.New.value,
            ]:
                log_info(
                    "please check web console"
                    " `https://console.bce.baidu.com/qianfan/ais/console/onlineService`,for"
                    " service deployment payment."
                )
            elif svc_status == console_const.ServiceStatus.Done.value:
                sft_model_endpoint = resp["result"]["uri"]
                log_info(
                    f"service {self.id} has been deployed in `/{sft_model_endpoint}` "
                )
                self.endpoint = sft_model_endpoint
                return self
            else:
                # Terminal non-Done status: there is no endpoint to record,
                # so fail explicitly rather than reading an unassigned local.
                log_error(f"service {self.id} has been ended in {svc_status}")
                raise InternalError(
                    f"service {self.id} has been ended in {svc_status}"
                )
            time.sleep(get_config().DEPLOY_STATUS_POLLING_INTERVAL)

    def dumps(self) -> Optional[bytes]:
        """
        serialize the service instance to bytes with `pickle`

        Returns:
            Optional[bytes]: bytes of the service instance
        """
        return pickle.dumps(self)

    def loads(self, data: bytes) -> Any:
        """
        load service instance from bytes

        NOTE: this unpickles `data`; unpickling can execute arbitrary code,
        so only pass bytes that come from a trusted source.

        Parameters:
            data (bytes): bytes of service instance

        Returns:
            Any: service instance
        """
        return pickle.loads(data)


def model_deploy(model: Model, deploy_config: DeployConfig, **kwargs: Any) -> Service:
    """
    Build a `Service` for `model` and deploy it.

    The service type is taken from `deploy_config.service_type`; the call
    returns once `Service.deploy` has returned.

    Parameters:
        model (Model): model to deploy
        deploy_config (DeployConfig): service deploy config
        **kwargs: forwarded to `Service.deploy`

    Returns:
        Service: the service object that was deployed
    """
    service_type = deploy_config.service_type
    deployed_service = Service(
        model=model,
        deploy_config=deploy_config,
        service_type=service_type,
    )
    deployed_service.deploy(**kwargs)
    return deployed_service