# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import time
from typing import Any, Dict, Iterator, Optional, Union
from qianfan import resources as api
from qianfan.config import get_config
from qianfan.dataset import Dataset, QianfanDataSource
from qianfan.errors import InternalError, InvalidArgumentError, QianfanError
from qianfan.resources import (
ChatCompletion,
Completion,
Embedding,
QfResponse,
Text2Image,
)
from qianfan.resources.console import consts as console_const
from qianfan.resources.console.model import Model as ResourceModel
from qianfan.trainer.base import ExecuteSerializable
from qianfan.trainer.configs import DeployConfig
from qianfan.trainer.consts import ServiceType
from qianfan.utils import log_debug, log_error, log_info, log_warn
from qianfan.utils.utils import generate_letter_num_random_id
class Model(
    ExecuteSerializable[Dict, Union[QfResponse, Iterator[QfResponse]]],
):
    """
    Local handle of a model on the qianfan platform.

    A model can be published with `publish()`, turned into a callable
    service with `deploy()`, and executed through that service with
    `exec()`.
    """

    id: Optional[int]
    """remote model id"""
    version_id: Optional[int]
    """remote model version id"""
    name: Optional[str] = None
    """model name"""
    service: Optional["Service"] = None
    """model service"""
    task_id: Optional[int]
    """train task id"""
    job_id: Optional[int]
    """train job id"""

    def __init__(
        self,
        id: Optional[int] = None,
        version_id: Optional[int] = None,
        task_id: Optional[int] = None,
        job_id: Optional[int] = None,
        name: Optional[str] = None,
    ):
        """
        Class for model in qianfan, which is deployable by using deploy() to
        get a custom model service.

        Parameters:
            id (Optional[int], optional):
                qianfan model remote id. Defaults to None.
            version_id (Optional[int], optional):
                model version id. Defaults to None.
            task_id (Optional[int], optional):
                model train task id. Defaults to None.
            job_id (Optional[int], optional):
                model train job id. Defaults to None.
            name (Optional[str], optional):
                model name. Defaults to None.
        """
        self.id = id
        self.version_id = version_id
        self.task_id = task_id
        self.job_id = job_id
        self.name = name

    def exec(
        self, input: Optional[Dict] = None, **kwargs: Dict
    ) -> Union[QfResponse, Iterator[QfResponse]]:
        """
        model execution, for different model service type, please input
        a dict with different keys.
        Concretely, take
        `input={"messages": [{"role": "user",
        "content": "hello world"}]}`
        as input, when the model is a chat io Model.

        Parameters:
            input (Optional[Dict], optional):
                input data. Defaults to None.
            **kwargs:
                extra keyword arguments forwarded to the service `exec`.

        Raises:
            InternalError: model with no service deployed is unable to call exec

        Returns:
            Union[QfResponse, Iterator[QfResponse]]:
                output data
        """
        if self.service is None:
            raise InternalError(
                "model not deployed, call `model_deploy()` to instantiate a service"
            )
        return self.service.exec(input, **kwargs)

    def deploy(self, deploy_config: DeployConfig, **kwargs: Any) -> "Service":
        """
        model deploy

        The service is created only once; later calls return the service
        that is already attached to this model.

        Parameters:
            deploy_config (DeployConfig):
                model service deploy config
            **kwargs:
                extra keyword arguments forwarded to `model_deploy`.

        Returns:
            Service: model service instance
        """
        if self.service is not None:
            log_info("model service already existed")
            return self.service
        self.service = model_deploy(self, deploy_config, **kwargs)
        return self.service

    def publish(self, name: str = "", **kwargs: Any) -> "Model":
        """
        model publish, before deploying a model, it should be published.

        Parameters:
            name str:
                model name. Defaults to "m_{task_id}_{job_id}".
            **kwargs:
                extra keyword arguments forwarded to every console api call.

        Raises:
            InvalidArgumentError: a required id (task, job, model or model
                version) is missing, no model version is found, or the
                train job ended in a state other than finished.
            InternalError: raised by `_wait_for_publish` when the model
                version ends in the failed state.

        Returns:
            Model: this instance, with ids refreshed from the platform.
        """
        # NOTE(review): when `version_id` / `id` are already set, the steps
        # below refresh the ids and then still fall through to publish a new
        # model. This mirrors the original control flow; confirm it is
        # intended for already-published models.
        if self.version_id:
            # a version is already known: sync the ids from its detail
            detail = api.Model.detail(model_version_id=self.version_id, **kwargs)
            result = detail["result"]
            self.id = result["modelId"]
            train_extra = result["sourceExtra"]["trainSourceExtra"]
            self.task_id = train_extra["taskId"]
            self.job_id = train_extra["runId"]
            log_info(f"check model {self.id}/{self.version_id} published...")
            if result["state"] != console_const.ModelState.Ready:
                self._wait_for_publish(**kwargs)

        if self.id:
            list_resp = api.Model.list(self.id, **kwargs)
            version_list = list_resp["result"]["modelVersionList"]
            # `not` also covers a None list, matching the later check below
            if not version_list:
                raise InvalidArgumentError(
                    "not model version matched, please train and publish first"
                )
            log_info("model publish get the first version in model list as default")
            self.version_id = version_list[0]["modelVersionId"]
            if self.version_id is None:
                raise InvalidArgumentError("model version id not found")
            detail = api.Model.detail(model_version_id=self.version_id, **kwargs)
            result = detail["result"]
            train_extra = result["sourceExtra"]["trainSourceExtra"]
            self.task_id = train_extra["taskId"]
            self.job_id = train_extra["runId"]
            if result["state"] != console_const.ModelState.Ready:
                self._wait_for_publish(**kwargs)

        # validate before the remote publish call, which needs both ids for
        # the default model name and the version meta
        if self.task_id is None or self.job_id is None:
            raise InvalidArgumentError("task id or job id not found")

        # publish the model
        self.model_name = name if name != "" else f"m_{self.task_id}_{self.job_id}"
        model_publish_resp = api.Model.publish(
            is_new=True,
            model_name=self.model_name,
            version_meta={"taskId": self.task_id, "iterationId": self.job_id},
            **kwargs,
        )
        log_info(
            f"check train job: {self.task_id}/{self.job_id} status before publishing"
            " model"
        )
        self.id = model_publish_resp["result"]["modelId"]

        # make sure the train job has finished
        while True:
            job_status_resp = api.FineTune.get_job(
                task_id=self.task_id,
                job_id=self.job_id,
                **kwargs,
            )
            job_status = job_status_resp["result"]["trainStatus"]
            log_info(f"model publishing keep polling, current status {job_status}")
            if job_status == console_const.TrainStatus.Running:
                time.sleep(get_config().TRAIN_STATUS_POLLING_INTERVAL)
                continue
            elif job_status == console_const.TrainStatus.Finish:
                break
            else:
                raise InvalidArgumentError("invalid train task job to publish model")

        if self.id is None:
            raise InvalidArgumentError("model id not found")

        # fetch the model version list and take its first entry
        model_list_resp = api.Model.list(model_id=self.id, **kwargs)
        model_version_list = model_list_resp["result"]["modelVersionList"]
        if not model_version_list:
            raise InvalidArgumentError("not model version matched")
        self.version_id = model_version_list[0]["modelVersionId"]
        if self.version_id is None:
            raise InvalidArgumentError("model version id not found")

        self._wait_for_publish(**kwargs)
        return self

    def _wait_for_publish(self, **kwargs: Any) -> None:
        """
        call a polling loop to wait until the model is published.

        Parameters:
            **kwargs:
                extra keyword arguments forwarded to `api.Model.detail`.

        Raises:
            InvalidArgumentError: `version_id` is not set.
            InternalError: the model version reached the failed state.
        """
        if self.version_id is None:
            raise InvalidArgumentError("model version id not found")
        log_info("model ready to publish")
        while True:
            # fetch the model version detail
            model_detail_info = api.Model.detail(
                model_version_id=self.version_id, **kwargs
            )
            model_version_state = model_detail_info["result"]["state"]
            log_info(f"check model publish status: {model_version_state}")
            if model_version_state == console_const.ModelState.Ready:
                log_info(f"model {self.id}/{self.version_id} published successfully")
                break
            elif model_version_state == console_const.ModelState.Fail:
                raise InternalError(
                    "model published failed, check error msg and retry."
                    f" {model_detail_info}"
                )
            time.sleep(get_config().MODEL_PUBLISH_STATUS_POLLING_INTERVAL)

    def dumps(self) -> Optional[bytes]:
        """
        Serialize the model to bytes.

        Returns:
            Optional[bytes]:
                bytes of this model
        """
        return pickle.dumps(self)

    def loads(self, data: bytes) -> Any:
        """
        load model instance from bytes

        Parameters:
            data (bytes):
                bytes of this model

        Returns:
            Any: model instance
        """
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only pass bytes that come from a trusted `dumps()` call.
        return pickle.loads(data)

    def batch_run_on_qianfan(self, dataset: Dataset, **kwargs: Any) -> Dataset:
        """
        create batch run using specific dataset on qianfan
        by evaluation ability of platform

        Parameters:
            dataset (Dataset):
                A dataset instance which indicates a dataset on qianfan platform
            **kwargs (Any):
                Arbitrary keyword arguments

        Raises:
            ValueError: the dataset is not located on qianfan.
            QianfanError: the evaluation task stopped in a state other than
                `DoingWithManualBegin` or `Done`.

        Returns:
            Dataset: batch result contained in dataset
        """
        if not dataset.is_dataset_located_in_qianfan():
            err_msg = "can't start a batch run task on non-qianfan dataset"
            log_error(err_msg)
            raise ValueError(err_msg)

        qianfan_data_source = dataset.inner_data_source_cache
        # narrows the type for the attribute accesses below
        assert isinstance(qianfan_data_source, QianfanDataSource)

        log_info("start to create evaluation task in model")
        resp = ResourceModel.create_evaluation_task(
            name=f"model_run_{generate_letter_num_random_id()}",
            version_info=[
                {
                    "modelId": self.id,
                    "modelVersionId": self.version_id,
                }
            ],
            dataset_id=qianfan_data_source.id,
            eval_config={
                "evalMode": "manual",
                "evaluationDimension": [
                    {"dimension": "满意度"},
                ],
            },
            dataset_name=qianfan_data_source.name,
            **kwargs,
        ).body

        eval_id = resp["result"]["evalId"]
        log_debug(f"create evaluation task in model response: {resp}")
        log_info(f"start to polling status of evaluation task {eval_id}")

        # states in which the task is still in progress
        running_states = [
            console_const.EvaluationTaskStatus.Pending.value,
            console_const.EvaluationTaskStatus.Doing.value,
        ]
        # states accepted as a usable outcome; compared by `.value` just like
        # the in-progress states above
        finished_states = [
            console_const.EvaluationTaskStatus.DoingWithManualBegin.value,
            console_const.EvaluationTaskStatus.Done.value,
        ]

        while True:
            eval_info = ResourceModel.get_evaluation_info(eval_id)
            eval_state = eval_info["result"]["state"]
            log_debug(f"current evaluation task info: {eval_info}")
            log_info(f"current eval_state: {eval_state}")
            if eval_state not in running_states:
                break
            time.sleep(30)

        if eval_state not in finished_states:
            err_msg = f"can't finish evaluation task and failed with state {eval_state}"
            log_error(err_msg)
            raise QianfanError(err_msg)

        result_dataset_id = eval_info["result"]["evalStandardConf"]["resultDatasetId"]
        log_info(f"get result dataset id {result_dataset_id}")
        return Dataset.load(qianfan_dataset_id=result_dataset_id, **kwargs)
class Service(ExecuteSerializable[Dict, Union[QfResponse, Iterator[QfResponse]]]):
    """
    Local handle of a model service on the qianfan platform.

    A service is either bound to an existing endpoint, or built from a model
    and deployed with `deploy()`; `exec()` sends a request through the
    resource class matching `service_type`.
    """

    id: Optional[int]
    """remote service id"""
    model: Optional[Model]
    """service model instance"""
    deploy_config: Optional[DeployConfig]
    """service deploy config"""
    endpoint: Optional[str]
    """service endpoint to call"""
    service_type: Optional[ServiceType]
    """service type, for user use service as a execution must specify"""

    # service type may get from model ioModel
    def __init__(
        self,
        id: Optional[int] = None,
        endpoint: Optional[str] = None,
        model: Optional[Union[Model, str]] = None,
        deploy_config: Optional[DeployConfig] = None,
        service_type: Optional[ServiceType] = None,
    ) -> None:
        """
        Class for model service in qianfan, which is either bound to an
        existing endpoint or created from a model by using deploy().

        Parameters:
            id (Optional[int], optional):
                qianfan service id. Defaults to None.
            endpoint (Optional[str], optional):
                qianfan service endpoint. Takes precedence over `model`
                when both are given. Defaults to None.
            model (Optional[Union[Model, str]], optional):
                service's corresponding model, or a model name which is
                wrapped into a `Model`. Defaults to None.
            deploy_config (Optional[DeployConfig], optional):
                service deploy config. Defaults to None.
            service_type (Optional[ServiceType], optional):
                service type, for user use service as a execution must specify,
                Defaults to None.

        Raises:
            InvalidArgumentError: neither an endpoint nor a usable model
                is passed in.
        """
        self.id = id
        self.service_type = service_type
        if self.service_type is None:
            log_warn("service type should be specified before exec")

        if endpoint is not None:
            self.model = None
            self.endpoint = endpoint
        elif isinstance(model, str):
            self.model = Model(name=model)
            self.endpoint = None
        elif isinstance(model, Model):
            # need to deploy
            self.model = model
            self.endpoint = None
        else:
            raise InvalidArgumentError("invalid model service")
        self.deploy_config = deploy_config

    @property
    def status(self) -> str:
        """
        get the service status

        Returns:
            str: the `serviceStatus` reported by the platform, or an empty
            string when this service has no remote id.
        """
        if self.id is None:
            return ""
        resp = api.Service.get(
            id=self.id,
            retry_count=get_config().TRAINER_STATUS_POLLING_RETRY_TIMES,
            backoff_factor=get_config().TRAINER_STATUS_POLLING_BACKOFF_FACTOR,
        )
        return resp["result"]["serviceStatus"]

    def exec(
        self, input: Optional[Dict] = None, **kwargs: Dict
    ) -> Union[QfResponse, Iterator[QfResponse]]:
        """
        exec

        Parameters:
            input (Optional[Dict], optional):
                input of execution of service, its items are passed as
                keyword arguments to the resource `do()`. Defaults to None.
            **kwargs:
                additional keyword arguments; they override the same keys
                in `input`.

        Raises:
            InvalidArgumentError: input is None, or the service type is
                unsupported.

        Returns:
            Union[QfResponse, Iterator[QfResponse]]:
                output
        """
        if input is None:
            raise InvalidArgumentError("input is none")
        return self.get_res().do(**{**input, **kwargs})

    def get_res(self) -> Union[ChatCompletion, Completion, Embedding, Text2Image]:
        """
        convert to the specific model resources. e.g.
        `ChatCompletion`, `Completion`, `Embeddings`,
        `Text2Image`

        Raises:
            InvalidArgumentError: an endpoint is set without a service type,
                or the service type is unsupported.

        Returns:
            Union[ChatCompletion, Completion, Embedding, Text2Image]:
                resource object
        """
        if self.endpoint is not None and self.service_type is None:
            raise InvalidArgumentError(
                "service type must be specified when endpoint passed in"
            )
        # the reported status is a plain value, so compare against `.value`
        # as the polling loop in `deploy()` does for the other states
        if self.status != console_const.ServiceStatus.Done.value:
            log_warn("service status unknown, service could be unavailable.")

        model_name = self.model.name if self.model is not None else None
        resource_table = (
            (ServiceType.Chat, ChatCompletion),
            (ServiceType.Completion, Completion),
            (ServiceType.Embedding, Embedding),
            (ServiceType.Text2Image, Text2Image),
        )
        # compared with `==` in declaration order, like an if/elif chain
        for candidate_type, resource_cls in resource_table:
            if self.service_type == candidate_type:
                return resource_cls(model=model_name, endpoint=self.endpoint)
        raise InvalidArgumentError(f"unsupported service type {self.service_type}")

    def deploy(self, **kwargs: Any) -> "Service":
        """
        create the remote service for the bound model and poll until the
        deployment ends.

        Parameters:
            **kwargs:
                extra keyword arguments forwarded to `api.Service.create`
                and `api.Service.get`.

        Raises:
            InvalidArgumentError: model, model id, model version id or
                deploy config is missing.
            InternalError: no service id is returned, or the service ends
                in a state other than done.

        Returns:
            Service: this instance with `id` and `endpoint` filled in.
        """
        if self.model is None:
            raise InvalidArgumentError("model not found")
        model = self.model
        if model.id is None or model.version_id is None:
            raise InvalidArgumentError("model id | model version id not found")
        if self.deploy_config is None:
            raise InvalidArgumentError("deploy config not found")

        log_info(f"ready to deploy service with model {model.id}/{model.version_id}")
        svc_publish_resp = api.Service.create(
            model_id=model.id,
            model_version_id=model.version_id,
            name=(
                self.deploy_config.name
                if self.deploy_config.name != ""
                else f"svc{model.id}_{model.version_id}"
            ),
            uri=(
                self.deploy_config.endpoint_prefix
                if self.deploy_config.endpoint_prefix != ""
                else f"ep{model.id}_{model.version_id}"
            ),
            replicas=self.deploy_config.replicas,
            pool_type=self.deploy_config.pool_type,
            **kwargs,
        )

        self.id = svc_publish_resp["result"]["serviceId"]
        if self.id is None:
            # format the response into the message so it shows up in the log
            log_error(f"create service error: {svc_publish_resp}")
            raise InternalError("service id not found")

        # after the resource payment is completed, serviceStatus becomes
        # Deploying; keep polling the service status
        in_progress_states = [
            console_const.ServiceStatus.Deploying.value,
            console_const.ServiceStatus.New.value,
        ]
        while True:
            resp = api.Service.get(id=self.id, **kwargs)
            svc_status = resp["result"]["serviceStatus"]
            if svc_status in in_progress_states:
                log_info(
                    "please check web console"
                    " `https://console.bce.baidu.com/qianfan/ais/console/onlineService`,for"
                    " service deployment payment."
                )
            elif svc_status == console_const.ServiceStatus.Done.value:
                sft_model_endpoint = resp["result"]["uri"]
                log_info(
                    f"service {self.id} has been deployed in `/{sft_model_endpoint}` "
                )
                self.endpoint = sft_model_endpoint
                return self
            else:
                # no endpoint exists for a failed deployment, so report the
                # failure instead of reading an endpoint that was never set
                err_msg = f"service {self.id} has been ended in {svc_status}"
                log_error(err_msg)
                raise InternalError(err_msg)
            time.sleep(get_config().DEPLOY_STATUS_POLLING_INTERVAL)

    def dumps(self) -> Optional[bytes]:
        """
        serialize the service instance to bytes

        Returns:
            Optional[bytes]:
                bytes of the service instance
        """
        return pickle.dumps(self)

    def loads(self, data: bytes) -> Any:
        """
        load service instance from bytes

        Parameters:
            data (bytes):
                bytes of service instance

        Returns:
            Any: service instance
        """
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only pass bytes that come from a trusted `dumps()` call.
        return pickle.loads(data)
def model_deploy(model: Model, deploy_config: DeployConfig, **kwargs: Any) -> Service:
    """
    Build a service for `model` and deploy it.

    The service takes its type from `deploy_config.service_type`; the
    status polling is done inside `Service.deploy`.

    Parameters:
        model (Model):
            model to deploy
        deploy_config (DeployConfig):
            service deploy config, also supplying the service type
        **kwargs:
            extra keyword arguments forwarded to `Service.deploy`

    Returns:
        Service: the service instance after `deploy` has been called on it
    """
    deployed_service = Service(
        service_type=deploy_config.service_type,
        deploy_config=deploy_config,
        model=model,
    )
    deployed_service.deploy(**kwargs)
    return deployed_service