# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Qianfan API Requestor
"""
import copy
import json
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
Iterator,
Optional,
TypeVar,
Union,
)
from urllib.parse import urlparse
import qianfan.errors as errors
from qianfan.config import get_config
from qianfan.consts import APIErrorCode, Consts
from qianfan.resources.auth.iam import iam_sign
from qianfan.resources.auth.oauth import Auth
from qianfan.resources.requestor.base import (
BaseAPIRequestor,
_async_check_if_status_code_is_200,
_check_if_status_code_is_200,
_with_latency,
)
from qianfan.resources.typing import QfRequest, QfResponse, RetryConfig
from qianfan.utils.logging import log_error, log_info
_T = TypeVar("_T")
[docs]class QfAPIRequestor(BaseAPIRequestor):
"""
object to manage Qianfan API requests
"""
def __init__(self, **kwargs: Any) -> None:
"""
`ak`, `sk` and `access_token` can be provided in kwargs.
"""
super().__init__(**kwargs)
self._auth = Auth(**kwargs)
def _retry_if_token_expired(self, func: Callable[..., _T]) -> Callable[..., _T]:
"""
this is a wrapper to deal with token expired error
"""
token_refreshed = False
def retry_wrapper(*args: Any, **kwargs: Any) -> _T:
nonlocal token_refreshed
# if token is refreshed, token expired exception will not be dealt with
if not token_refreshed:
try:
return func(*args)
except errors.AccessTokenExpiredError:
# refresh token and set token_refreshed flag
self._auth.refresh_access_token()
token_refreshed = True
# then fallthrough and try again
return func(*args, **kwargs)
return retry_wrapper
@_with_latency
def _request_stream(
self,
request: QfRequest,
data_postprocess: Callable[[QfResponse], QfResponse] = lambda x: x,
) -> Iterator[QfResponse]:
"""
stream sync request
"""
with self._rate_limiter:
responses = self._client.request_stream(request)
event = ""
token_refreshed = False
while True:
try:
body, resp = next(responses)
except StopIteration:
break
_check_if_status_code_is_200(resp)
body_str = body.decode("utf-8")
if body_str == "":
continue
if body_str.startswith(Consts.STREAM_RESPONSE_EVENT_PREFIX):
# event indicator for the type of data
event = body_str[len(Consts.STREAM_RESPONSE_EVENT_PREFIX) :]
continue
elif not body_str.startswith(Consts.STREAM_RESPONSE_PREFIX):
try:
# the response might be error message in json format
json_body = json.loads(body_str)
self._check_error(json_body)
except errors.AccessTokenExpiredError:
if not token_refreshed:
token_refreshed = True
self._auth.refresh_access_token()
self._add_access_token(request)
with self._rate_limiter:
responses = self._client.request_stream(request)
continue
raise
except json.JSONDecodeError:
# the response is not json format, ignore and raise InternalError
pass
raise errors.RequestError(
f"got unexpected stream response from server: {body_str}"
)
body_str = body_str[len(Consts.STREAM_RESPONSE_PREFIX) :]
json_body = json.loads(body_str)
if event != "":
json_body["_event"] = event
event = ""
parsed = self._parse_response(json_body, resp)
parsed.request = QfRequest.from_requests(resp.request)
parsed.request.json_body = copy.deepcopy(request.json_body)
yield data_postprocess(parsed)
def _async_retry_if_token_expired(
self, func: Callable[..., Awaitable[_T]]
) -> Callable[..., Awaitable[_T]]:
"""
this is a wrapper to deal with token expired error
"""
token_refreshed = False
async def retry_wrapper(*args: Any, **kwargs: Any) -> _T:
nonlocal token_refreshed
# if token is refreshed, token expired exception will not be dealt with
if not token_refreshed:
try:
return await func(*args)
except errors.AccessTokenExpiredError:
# refresh token and set token_refreshed flag
await self._auth.arefresh_access_token()
token_refreshed = True
# then fallthrough and try again
return await func(*args, **kwargs)
return retry_wrapper
@_with_latency
async def _async_request_stream(
self,
request: QfRequest,
data_postprocess: Callable[[QfResponse], QfResponse] = lambda x: x,
) -> AsyncIterator[QfResponse]:
"""
async stream request
"""
async with self._rate_limiter:
responses = self._client.arequest_stream(request)
token_refreshed = False
async for body, resp in responses:
_async_check_if_status_code_is_200(resp)
body_str = body.decode("utf-8")
if body_str.strip() == "":
continue
if not body_str.startswith(Consts.STREAM_RESPONSE_PREFIX):
try:
# the response might be error message in json format
json_body: Dict[str, Any] = json.loads(body_str)
self._check_error(json_body)
except json.JSONDecodeError:
# the response is not json format, ignore and raise RequestError
pass
except errors.AccessTokenExpiredError:
if not token_refreshed:
token_refreshed = True
await self._auth.arefresh_access_token()
await self._async_add_access_token(request)
async with self._rate_limiter:
responses = self._client.arequest_stream(request)
continue
raise
raise errors.RequestError(
f"got unexpected stream response from server: {body_str}"
)
body_str = body_str[len(Consts.STREAM_RESPONSE_PREFIX) :]
json_body = json.loads(body_str)
parsed = self._parse_async_response(json_body, resp)
parsed.request = QfRequest.from_aiohttp(resp.request_info)
parsed.request.json_body = copy.deepcopy(request.json_body)
yield data_postprocess(parsed)
def _check_error(self, body: Dict[str, Any]) -> None:
"""
check whether error_code in response body
if there is an APITokenExpired error,
raise AccessTokenExpiredError
"""
if "error_code" in body:
req_id = body.get("id", "")
error_code = body["error_code"]
err_msg = body.get("error_msg", "no error message found in response body")
log_error(
f"api request req_id: {req_id} failed with error code: {error_code},"
f" err msg: {err_msg}, please check"
" https://cloud.baidu.com/doc/WENXINWORKSHOP/s/tlmyncueh"
)
if error_code in {
APIErrorCode.APITokenExpired.value,
APIErrorCode.APITokenInvalid.value,
}:
raise errors.AccessTokenExpiredError
raise errors.APIError(error_code, err_msg, req_id)
[docs] def llm(
self,
endpoint: str,
header: Dict[str, Any] = {},
query: Dict[str, Any] = {},
body: Dict[str, Any] = {},
stream: bool = False,
data_postprocess: Callable[[QfResponse], QfResponse] = lambda x: x,
retry_config: RetryConfig = RetryConfig(),
) -> Union[QfResponse, Iterator[QfResponse]]:
"""
llm related api request
"""
log_info(f"requesting llm api endpoint: {endpoint}")
@self._retry_if_token_expired
def _helper() -> Union[QfResponse, Iterator[QfResponse]]:
req = self._base_llm_request(
endpoint,
header=header,
query=query,
body=body,
retry_config=retry_config,
)
req = self._add_access_token(req)
if stream:
return self._request_stream(req, data_postprocess=data_postprocess)
return self._request(req, data_postprocess=data_postprocess)
return self._with_retry(retry_config, _helper)
[docs] async def async_llm(
self,
endpoint: str,
header: Dict[str, Any] = {},
query: Dict[str, Any] = {},
body: Dict[str, Any] = {},
stream: bool = False,
data_postprocess: Callable[[QfResponse], QfResponse] = lambda x: x,
retry_config: RetryConfig = RetryConfig(),
) -> Union[QfResponse, AsyncIterator[QfResponse]]:
"""
llm related api request
"""
log_info(f"async requesting llm api endpoint: {endpoint}")
@self._async_retry_if_token_expired
async def _helper() -> Union[QfResponse, AsyncIterator[QfResponse]]:
req = self._base_llm_request(
endpoint,
header=header,
query=query,
body=body,
retry_config=retry_config,
)
req = await self._async_add_access_token(req)
if stream:
return self._async_request_stream(
req, data_postprocess=data_postprocess
)
return await self._async_request(req, data_postprocess=data_postprocess)
return await self._async_with_retry(retry_config, _helper)
def _base_llm_request(
self,
endpoint: str,
header: Dict[str, Any] = {},
query: Dict[str, Any] = {},
body: Dict[str, Any] = {},
retry_config: RetryConfig = RetryConfig(),
) -> QfRequest:
"""
create base llm QfRequest from provided args
"""
req = QfRequest(method="POST", url=self._llm_api_url(endpoint))
req.headers = header
req.query = query
req.json_body = body
req.retry_config = retry_config
return req
@staticmethod
def _sign(request: QfRequest, ak: str, sk: str) -> None:
"""
sign the request
"""
url = request.url
parsed_uri = urlparse(request.url)
host = parsed_uri.netloc
request.url = parsed_uri.path
request.headers = {
"Content-Type": "application/json",
"Host": host,
**request.headers,
}
iam_sign(ak, sk, request)
request.url = url
def _add_access_token(
self, req: QfRequest, auth: Optional[Auth] = None
) -> QfRequest:
"""
add access token to QfRequest
"""
if auth is None:
auth = self._auth
access_token = auth.access_token()
if access_token == "":
# use IAM auth
access_key = auth._access_key
secret_key = auth._secret_key
if access_key is None or secret_key is None:
raise errors.AccessTokenExpiredError
self._sign(req, access_key, secret_key)
else:
# use openapi auth
req.query["access_token"] = access_token
return req
async def _async_add_access_token(
self, req: QfRequest, auth: Optional[Auth] = None
) -> QfRequest:
"""
async add access token to QfRequest
"""
if auth is None:
auth = self._auth
access_token = await auth.a_access_token()
if access_token == "":
# use IAM auth
access_key = auth._access_key
secret_key = auth._secret_key
if access_key is None or secret_key is None:
raise errors.AccessTokenExpiredError
self._sign(req, access_key, secret_key)
else:
# use openapi auth
req.query["access_token"] = access_token
return req
def _llm_api_url(self, endpoint: str) -> str:
"""
convert endpoint to llm api url
"""
return "{}{}{}".format(
get_config().BASE_URL,
Consts.ModelAPIPrefix,
endpoint,
)
def _request_api(self, req: QfRequest, auth: Optional[Auth] = None) -> QfResponse:
"""
request api with auth and retry
"""
@self._retry_if_token_expired
def _helper() -> QfResponse:
self._add_access_token(req, auth)
return self._request(req)
return self._with_retry(req.retry_config, _helper)
def _async_request_api(
self, req: QfRequest, auth: Optional[Auth] = None
) -> Awaitable[QfResponse]:
"""
async request api with auth and retry
"""
@self._async_retry_if_token_expired
async def _helper() -> QfResponse:
await self._async_add_access_token(req, auth)
return await self._async_request(req)
return self._async_with_retry(req.retry_config, _helper)
[docs]def create_api_requestor(*args: Any, **kwargs: Any) -> QfAPIRequestor:
if get_config().ENABLE_PRIVATE:
return PrivateAPIRequestor(**kwargs)
return QfAPIRequestor(**kwargs)
[docs]class PrivateAPIRequestor(QfAPIRequestor):
"""
qianfan private api requestor
"""
def __init__(self, **kwargs: Any) -> None:
"""
`ak`, `sk` and `access_token` can be provided in kwargs.
"""
super().__init__(**kwargs)
self._ak = kwargs.get("ak", None) or get_config().AK
self._sk = kwargs.get("sk", None) or get_config().SK
self._access_code = kwargs.get("access_code", None) or get_config().ACCESS_CODE
def _base_llm_request(
self,
endpoint: str,
header: Dict[str, Any] = {},
query: Dict[str, Any] = {},
body: Dict[str, Any] = {},
retry_config: RetryConfig = RetryConfig(),
) -> QfRequest:
"""
create base llm QfRequest from provided args
"""
req = QfRequest(
method="POST",
url="{}{}".format(
Consts.ModelAPIPrefix,
endpoint,
),
)
req.headers = header
req.query = query
req.json_body = body
req.retry_config = retry_config
return req
[docs] def llm(
self,
endpoint: str,
header: Dict[str, Any] = {},
query: Dict[str, Any] = {},
body: Dict[str, Any] = {},
stream: bool = False,
data_postprocess: Callable[[QfResponse], QfResponse] = lambda x: x,
retry_config: RetryConfig = RetryConfig(),
) -> Union[QfResponse, Iterator[QfResponse]]:
"""
llm related api request
"""
log_info(f"requesting llm api endpoint: {endpoint}")
def _helper() -> Union[QfResponse, Iterator[QfResponse]]:
req = self._base_llm_request(
endpoint,
header=header,
query=query,
body=body,
retry_config=retry_config,
)
parsed_uri = urlparse(get_config().BASE_URL)
host = parsed_uri.netloc
req.headers["content-type"] = "application/json;"
req.headers["Host"] = host
if self._access_code != "" and self._access_code is not None:
req.headers["Authorization"] = "ACCESSCODE {}".format(self._access_code)
elif self._ak != "" and self._sk != "":
iam_sign(str(self._ak), str(self._sk), req)
req.url = get_config().BASE_URL + req.url
if stream:
return self._request_stream(req, data_postprocess=data_postprocess)
return self._request(req, data_postprocess=data_postprocess)
return self._with_retry(retry_config, _helper)