Source code for qianfan.trainer.model

# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import time
from typing import Any, Dict, Iterator, Optional, Union

from qianfan import resources as api
from qianfan.config import get_config
from qianfan.errors import InternalError, InvalidArgumentError
from qianfan.resources import (
    ChatCompletion,
    Completion,
    Embedding,
    QfResponse,
    Text2Image,
)
from qianfan.resources.console import consts as console_const
from qianfan.trainer.base import ExecuteSerializable
from qianfan.trainer.configs import DeployConfig
from qianfan.trainer.consts import ServiceType
from qianfan.utils import log_warn


class Model(
    ExecuteSerializable[Dict, Union[QfResponse, Iterator[QfResponse]]],
):
    id: Optional[int]
    """remote model id"""
    version_id: Optional[int]
    """remote model version id"""
    name: str
    """model name"""
    service: Optional["Service"] = None
    """model service"""
    task_id: Optional[int]
    """train task id"""
    job_id: Optional[int]
    """train job id"""

    def __init__(
        self,
        id: Optional[int] = None,
        version_id: Optional[int] = None,
        task_id: Optional[int] = None,
        job_id: Optional[int] = None,
    ):
        """
        Class for model in qianfan, which is deployable by using deploy() to
        get a custom model service.

        Parameters:
            id (Optional[int], optional):
                qianfan model remote id. Defaults to None.
            version_id (Optional[int], optional):
                model version id. Defaults to None.
            task_id (Optional[int], optional):
                model train task id. Defaults to None.
            job_id (Optional[int], optional):
                model train job id. Defaults to None.
        """
        self.id = id
        self.version_id = version_id
        self.task_id = task_id
        self.job_id = job_id

    def exec(
        self, input: Optional[Dict] = None, **kwargs: Dict
    ) -> Union[QfResponse, Iterator[QfResponse]]:
        """
        model execution, for different model service type, please input
        a dict with different keys. Concretely, take
        `input={"messages": [{"role": "user", "content": "hello world"}]}`
        as input, when the model is a chat io Model.

        The call is delegated unchanged to the attached service.

        Parameters:
            input (Optional[Dict], optional):
                input data. Defaults to None.
            **kwargs:
                extra arguments passed through to the service's `exec`.

        Raises:
            InternalError: model with no service deployed is unable to call exec

        Returns:
            Union[QfResponse, Iterator[QfResponse]]: output data
        """
        if self.service is None:
            raise InternalError(
                "model not deployed, call `model_deploy()` to instantiate a service"
            )
        return self.service.exec(input, **kwargs)

    def deploy(self, deploy_config: DeployConfig) -> "Service":
        """
        model deploy; the created service is stored on `self.service`.

        Parameters:
            deploy_config (DeployConfig): model service deploy config

        Returns:
            Service: model service instance
        """
        self.service = model_deploy(self, deploy_config)
        return self.service

    def publish(self, name: str = "") -> "Model":
        """
        model publish, before deploying a model, it should be published.

        Steps, in order:
          1. if a version id is set, read model id / task id / job id from the
             version detail;
          2. if a model id is set, take the first listed version of that model
             and read task id / job id from its detail;
          3. publish a new model from (task id, job id);
          4. poll the train job until it is no longer Running;
          5. take the first listed version of the published model and wait
             until it is Ready.

        Parameters:
            name (str, optional):
                model name. Defaults to "m_{task_id}{job_id}".

        Raises:
            InvalidArgumentError: a required id (task, job, model, version)
                is missing or no model version is listed.
            InternalError: the model version ended in the Fail state.

        Returns:
            Model: this instance, with `id` and `version_id` refreshed.
        """
        if self.version_id:
            # already released
            model_detail_resp = api.Model.detail(model_version_id=self.version_id)
            self.id = model_detail_resp["result"]["modelId"]
            self._apply_version_detail(model_detail_resp["result"])

        if self.id:
            list_resp = api.Model.list(self.id)
            known_versions = list_resp["result"]["modelVersionList"]
            # `not` also covers a None list, matching the check done after
            # publishing below (len(None) would raise TypeError instead).
            if not known_versions:
                raise InvalidArgumentError(
                    "not model version matched, please train and publish first"
                )
            self.version_id = known_versions[0]["modelVersionId"]
            if self.version_id is None:
                raise InvalidArgumentError("model version id not found")
            model_detail_resp = api.Model.detail(model_version_id=self.version_id)
            self._apply_version_detail(model_detail_resp["result"])

        # Validate before the ids are used: they are interpolated into the
        # default model name and sent to the publish API.
        if self.task_id is None or self.job_id is None:
            raise InvalidArgumentError("task id or job id not found")

        # publish the model
        self.model_name = name if name != "" else f"m_{self.task_id}{self.job_id}"
        model_publish_resp = api.Model.publish(
            is_new=True,
            model_name=self.model_name,
            version_meta={"taskId": self.task_id, "iterationId": self.job_id},
        )
        self.id = model_publish_resp["result"]["modelId"]

        # wait until the train job leaves the Running state
        while True:
            job_status_resp = api.FineTune.get_job(
                task_id=self.task_id, job_id=self.job_id
            )
            job_status = job_status_resp["result"]["trainStatus"]
            if job_status != console_const.TrainStatus.Running:
                break
            time.sleep(get_config().TRAIN_STATUS_POLLING_INTERVAL)

        if self.id is None:
            raise InvalidArgumentError("model id not found")

        # fetch the model version list
        model_list_resp = api.Model.list(model_id=self.id)
        model_version_list = model_list_resp["result"]["modelVersionList"]
        if not model_version_list:
            raise InvalidArgumentError("not model version matched")
        self.version_id = model_version_list[0]["modelVersionId"]
        if self.version_id is None:
            raise InvalidArgumentError("model version id not found")

        self._wait_for_publish()
        return self

    def _apply_version_detail(self, detail_result: Dict) -> None:
        """
        Copy train task / job ids out of a model version detail result and
        block until that version is Ready.

        Parameters:
            detail_result (Dict):
                the "result" part of an `api.Model.detail` response.
        """
        train_source = detail_result["sourceExtra"]["trainSourceExtra"]
        self.task_id = train_source["taskId"]
        self.job_id = train_source["runId"]
        if detail_result["state"] != console_const.ModelState.Ready:
            self._wait_for_publish()

    def _wait_for_publish(self) -> None:
        """
        call a polling loop to wait until the model is published.

        The loop has no timeout: it ends only when the version state is
        Ready (return) or Fail (raise).

        Raises:
            InvalidArgumentError: `version_id` is not set.
            InternalError: the model version state is Fail.
        """
        # fetch the model version detail
        if self.version_id is None:
            raise InvalidArgumentError("model version id not found")
        while True:
            model_detail_info = api.Model.detail(model_version_id=self.version_id)
            model_version_state = model_detail_info["result"]["state"]
            if model_version_state == console_const.ModelState.Ready:
                break
            elif model_version_state == console_const.ModelState.Fail:
                raise InternalError("model published failed")
            time.sleep(get_config().MODEL_PUBLISH_STATUS_POLLING_INTERVAL)

    def dumps(self) -> Optional[bytes]:
        """
        Serialize the model to bytes.

        Returns:
            Optional[bytes]: bytes of this model
        """
        return pickle.dumps(self)

    def loads(self, data: bytes) -> Any:
        """
        load model instance from bytes

        The unpickled object is returned; `self` is not modified.

        Parameters:
            data (bytes): bytes of this model

        Returns:
            Any: model instance
        """
        # NOTE(review): pickle.loads can execute arbitrary code; `data` must
        # come from a trusted source (e.g. produced by `dumps`).
        return pickle.loads(data)
class Service(ExecuteSerializable[Dict, Union[QfResponse, Iterator[QfResponse]]]):
    id: Optional[int]
    """remote service id"""
    model: Optional[Model]
    """service model instance"""
    deploy_config: Optional[DeployConfig]
    """service deploy config"""
    endpoint: Optional[str]
    """service endpoint to call"""
    service_type: Optional[ServiceType]
    """service type, must be specified when the service is used for execution"""
    # service type may get from model ioModel

    def __init__(
        self,
        id: Optional[int] = None,
        endpoint: Optional[str] = None,
        model: Optional[Model] = None,
        deploy_config: Optional[DeployConfig] = None,
        service_type: Optional[ServiceType] = None,
    ) -> None:
        """
        Class for a model service in qianfan, either created by deploying a
        model or built directly from an existing endpoint.

        Parameters:
            id (Optional[int], optional):
                qianfan service id. Defaults to None.
            endpoint (Optional[str], optional):
                qianfan service endpoint. Defaults to None.
            model (Optional[Model], optional):
                service's corresponding model. Defaults to None.
            deploy_config (Optional[DeployConfig], optional):
                service deploy config. Defaults to None.
            service_type (Optional[ServiceType], optional):
                service type, must be specified when the service is used
                for execution. Defaults to None.
        """
        self.id = id
        self.endpoint = endpoint
        self.model = model
        self.deploy_config = deploy_config
        self.service_type = service_type
        # Only a warning here; `exec` rejects this combination with an error.
        if self.endpoint is not None and self.service_type is None:
            log_warn("service type should be specified when endpoint passed in")

    @property
    def status(self) -> console_const.ServiceStatus:
        """
        get the service status; queries the remote service on every access.

        Raises:
            InternalError: id not found

        Returns:
            console_const.ServiceStatus
        """
        if self.id is None:
            raise InternalError("service id not found")
        resp = api.Service.get(id=self.id)
        return resp["result"]["serviceStatus"]

    def exec(
        self, input: Optional[Dict] = None, **kwargs: Dict
    ) -> Union[QfResponse, Iterator[QfResponse]]:
        """
        Call the service endpoint with `input` expanded as keyword arguments
        of the resource matching `service_type`.

        Parameters:
            input (Optional[Dict], optional):
                input of execution of service. Defaults to None.
            **kwargs:
                additional args Dict; accepted but not forwarded to the
                resource call.

        Raises:
            InvalidArgumentError: input is None, the service type is missing
                while an endpoint is set, or the service type is unsupported.
            InternalError: the service is not ready, or neither a service id
                nor an endpoint is available.

        Returns:
            Union[QfResponse, Iterator[QfResponse]]: output
        """
        if input is None:
            raise InvalidArgumentError("input is none")
        if self.endpoint is not None and self.service_type is None:
            raise InvalidArgumentError(
                "service type must be specified when endpoint passed in"
            )

        # A service built from an endpoint alone has no remote id, so its
        # status cannot be queried (`status` would raise); call the endpoint
        # directly in that case. Every other case keeps the readiness check.
        endpoint_only = self.id is None and self.endpoint is not None
        if not endpoint_only and self.status != console_const.ServiceStatus.Done:
            raise InternalError("service is not ready")

        # Compared with `==` (not a dict lookup) to keep the original
        # matching semantics of the service type values.
        resources_by_type = (
            (ServiceType.Chat, ChatCompletion),
            (ServiceType.Completion, Completion),
            (ServiceType.Embedding, Embedding),
            (ServiceType.Text2Image, Text2Image),
        )
        for kind, resource_cls in resources_by_type:
            if self.service_type == kind:
                return resource_cls().do(endpoint=self.endpoint, **input)
        raise InvalidArgumentError(f"unsupported service type {self.service_type}")

    def dumps(self) -> Optional[bytes]:
        """
        serialize the service instance to bytes

        Returns:
            Optional[bytes]: bytes of the service instance
        """
        return pickle.dumps(self)

    def loads(self, data: bytes) -> Any:
        """
        load service instance from bytes

        The unpickled object is returned; `self` is not modified.

        Parameters:
            data (bytes): bytes of service instance

        Returns:
            Any: service instance
        """
        # NOTE(review): pickle.loads can execute arbitrary code; `data` must
        # come from a trusted source (e.g. produced by `dumps`).
        return pickle.loads(data)
def model_deploy(model: Model, deploy_config: DeployConfig) -> Service:
    """
    model deployment implement, a polling loop will be called
    after deploy task created.

    The loop has no timeout: it ends as soon as the reported service status
    is anything other than Deploying, whatever that status is.

    Parameters:
        model (Model): model to deploy
        deploy_config (DeployConfig): service deploy config,
            mainly including replicas and pool type.

    Raises:
        InvalidArgumentError: the model has no id or no version id.
        InternalError: the create response carries no service id.

    Returns:
        Service: deployed service with endpoint to call
    """
    svc = Service(
        model=model,
        deploy_config=deploy_config,
        service_type=deploy_config.service_type,
    )
    if model.id is None or model.version_id is None:
        raise InvalidArgumentError("model id | model version id not found")

    # Fall back to a generated uri when no prefix is configured. The test must
    # look at `endpoint_prefix`; comparing the config object itself to "" is
    # always unequal, which made the fallback unreachable.
    endpoint_uri = (
        deploy_config.endpoint_prefix
        if deploy_config.endpoint_prefix
        else f"ep{model.id}{model.version_id}"
    )
    svc_publish_resp = api.Service.create(
        model_id=model.id,
        model_version_id=model.version_id,
        iteration_id=model.version_id,
        name=f"svc{model.id}{model.version_id}",
        uri=endpoint_uri,
        replicas=deploy_config.replicas,
        pool_type=deploy_config.pool_type,
    )

    svc.id = svc_publish_resp["result"]["serviceId"]
    if svc.id is None:
        raise InternalError("service id not found")

    # After the resource payment completes, serviceStatus becomes Deploying;
    # poll the service until it leaves that state.
    while True:
        resp = api.Service.get(id=svc.id)
        svc_status = resp["result"]["serviceStatus"]
        if svc_status != console_const.ServiceStatus.Deploying.value:
            sft_model_endpoint = resp["result"]["uri"]
            break
        time.sleep(get_config().DEPLOY_STATUS_POLLING_INTERVAL)
    svc.endpoint = sft_model_endpoint
    return svc