# Source code for qianfan.resources.requestor.base

# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
API Requestor for SDK
"""

import copy
import functools
import inspect
import json
import time
from typing import (
    Any,
    AsyncIterator,
    Awaitable,
    Callable,
    Dict,
    Iterator,
    Optional,
    TypeVar,
    Union,
)

import aiohttp
import requests
from tenacity import (
    retry,
    retry_if_exception,
    stop_after_attempt,
    wait_exponential_jitter,
)

import qianfan.errors as errors
from qianfan.resources.http_client import HTTPClient
from qianfan.resources.rate_limiter import RateLimiter
from qianfan.resources.typing import QfRequest, QfResponse, RetryConfig
from qianfan.utils.logging import log_error, log_trace, log_warn

# Generic result type for the retry wrappers (`_with_retry` /
# `_async_with_retry`): they return whatever the wrapped callable returns.
_T = TypeVar("_T")


def _is_utf8_encoded_bytes(byte_str: bytes) -> bool:
    """check whether bytes object is utf8 encoded"""
    try:
        byte_str.decode("utf-8")
        return True
    except UnicodeDecodeError:
        return False


def _get_body_str(byte_str: Optional[Union[bytes, str]]) -> Optional[Union[bytes, str]]:
    """get utf8-decoded str"""
    if byte_str is None:
        return ""

    if (
        not byte_str
        or isinstance(byte_str, str)
        or not _is_utf8_encoded_bytes(byte_str)
    ):
        return byte_str

    return str(byte_str, encoding="utf8")


def _check_if_status_code_is_200(response: requests.Response) -> None:
    """
    Ensure a ``requests`` response finished with HTTP status 200.

    For any other status, a failure description is assembled. It contains:

    * the url and status code;
    * the ``X-Bce-Error-Code`` / ``X-Bce-Error-Message`` response headers,
      when present and non-empty;
    * the request and response headers and bodies.

    The description is logged with `log_error` and then raised.

    Raises:
        errors.RequestError: if ``response.status_code`` is not 200.
    """
    if response.status_code == 200:
        return

    resp_headers = response.headers
    sent = response.request

    pieces = [
        f"http request url {response.url} failed "
        f"with http status code {response.status_code}\n"
    ]
    for header_name, label in (
        ("X-Bce-Error-Code", "error code"),
        ("X-Bce-Error-Message", "error message"),
    ):
        header_value = resp_headers.get(header_name, "")
        if header_value:
            pieces.append(f"{label} from baidu: {header_value}\n")

    # NOTE(review): request headers are copied verbatim into the log and the
    # exception message; if they carry credentials (e.g. an Authorization
    # header), consider redacting them - confirm against HTTPClient.
    sent_body = _get_body_str(sent.body)
    received_body = _get_body_str(response.content)
    pieces.append(
        f"request headers: {sent.headers}\n"
        f"request body: {sent_body!r}\n"
        f"response headers: {resp_headers}\n"
        f"response body: {received_body!r}"
    )

    failed_msg = "".join(pieces)
    log_error(failed_msg)
    raise errors.RequestError(failed_msg)


def _async_check_if_status_code_is_200(response: aiohttp.ClientResponse) -> None:
    """
    async check whether the status code of response is ok(200)
    if the status code is not 200, log the failure and raise a `RequestError`

    This check is synchronous, so it cannot read the response body (reading
    an aiohttp body requires ``await``). ``response.content`` is only an
    unread stream object, which is why the message omits the body. It reports
    the request url instead, like `_check_if_status_code_is_200` does.

    Raises:
        errors.RequestError: if ``response.status`` is not 200.
    """
    if response.status != 200:
        failed_msg = (
            f"request failed with status code `{response.status}`, "
            f"url: `{response.url}`, "
            f"headers: `{response.headers}`"
        )
        # log before raising, consistent with the sync variant
        log_error(failed_msg)
        raise errors.RequestError(failed_msg)


def _with_latency(func: Callable) -> Callable:
    """
    general decorator to add latency info into response

    Dispatches on the kind of ``func``:

    * coroutine function -> `_async_latency`
    * return annotation ``Iterator[QfResponse]`` -> `_stream_latency`
    * return annotation ``QfResponse`` -> `_latency`
    * return annotation ``AsyncIterator[QfResponse]`` -> `_async_stream_latency`

    Any other callable is returned unchanged.
    """
    if inspect.iscoroutinefunction(func):
        return _async_latency(func)

    return_annotation = inspect.signature(func).return_annotation
    # Compare with ``==`` rather than ``is``: two subscriptions such as
    # ``Iterator[QfResponse]`` are only the same object while typing's bounded
    # internal cache still holds it, whereas equality compares origin and
    # arguments and is always reliable.
    if return_annotation == Iterator[QfResponse]:
        return _stream_latency(func)
    if return_annotation == QfResponse:
        return _latency(func)
    if return_annotation == AsyncIterator[QfResponse]:
        return _async_stream_latency(func)
    return func


def _latency(func: Callable[..., QfResponse]) -> Callable[..., QfResponse]:
    """
    a decorator to add latency info into response

    The wall-clock duration of the wrapped call, in seconds measured with
    ``time.perf_counter``, is stored in ``resp.statistic["total_latency"]``.
    ``functools.wraps`` keeps the wrapped function's name, docstring and
    signature visible on the wrapper.
    """

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> QfResponse:
        start_time = time.perf_counter()
        resp = func(*args, **kwargs)
        resp.statistic["total_latency"] = time.perf_counter() - start_time
        return resp

    return wrapper


def _async_latency(
    func: Callable[..., Awaitable[QfResponse]]
) -> Callable[..., Awaitable[QfResponse]]:
    """
    a decorator to add latency info into async response

    The time from calling to finishing awaiting ``func``, in seconds measured
    with ``time.perf_counter``, is stored in
    ``resp.statistic["total_latency"]``. ``functools.wraps`` keeps the
    wrapped coroutine function's metadata on the wrapper.
    """

    @functools.wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> QfResponse:
        start_time = time.perf_counter()
        resp = await func(*args, **kwargs)
        resp.statistic["total_latency"] = time.perf_counter() - start_time
        return resp

    return wrapper


def _stream_latency(
    func: Callable[..., Iterator[QfResponse]]
) -> Callable[..., Iterator[QfResponse]]:
    """
    a decorator to add latency info into stream response

    For every chunk yielded by ``func``, these ``statistic`` keys are set
    (seconds, measured with ``time.perf_counter``):

    * ``request_latency``: time since the previous chunk was annotated, or
      since ``func`` returned for the first chunk. This includes any time
      the consumer spent before asking for the next chunk.
    * ``first_token_latency``: time from the start until the first chunk.
    * ``total_latency``: time from the start until this chunk.

    ``wrapper`` is itself a generator, so ``func`` is only invoked (and the
    clock only started) when the consumer first iterates.
    """

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Iterator[QfResponse]:
        start_time = time.perf_counter()
        first_token_latency: Optional[float] = None
        resp = func(*args, **kwargs)
        sse_block_receive_time = time.perf_counter()
        for r in resp:
            # one timestamp per chunk so the statistics agree with each other
            # (first_token_latency == total_latency for the first chunk)
            now = time.perf_counter()
            if first_token_latency is None:
                first_token_latency = now - start_time
            r.statistic["request_latency"] = now - sse_block_receive_time
            r.statistic["first_token_latency"] = first_token_latency
            r.statistic["total_latency"] = now - start_time
            sse_block_receive_time = time.perf_counter()
            yield r

    return wrapper


def _async_stream_latency(
    func: Callable[..., AsyncIterator[QfResponse]]
) -> Callable[..., AsyncIterator[QfResponse]]:
    """
    a decorator to add latency info into async stream response

    For every chunk yielded by ``func``, these ``statistic`` keys are set
    (seconds, measured with ``time.perf_counter``):

    * ``request_latency``: time since the previous chunk was annotated, or
      since ``func`` was called for the first chunk. This includes any time
      the consumer spent before asking for the next chunk.
    * ``first_token_latency``: time from the start until the first chunk.
    * ``total_latency``: time from the start until this chunk.

    ``wrapper`` is an async generator, so ``func`` is only invoked (and the
    clock only started) when the consumer first iterates.
    """

    @functools.wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> AsyncIterator[QfResponse]:
        start_time = time.perf_counter()
        first_token_latency: Optional[float] = None
        resp = func(*args, **kwargs)
        sse_block_receive_time = time.perf_counter()
        async for r in resp:
            # one timestamp per chunk so the statistics agree with each other
            now = time.perf_counter()
            if first_token_latency is None:
                first_token_latency = now - start_time
            r.statistic["request_latency"] = now - sse_block_receive_time
            r.statistic["first_token_latency"] = first_token_latency
            r.statistic["total_latency"] = now - start_time
            sse_block_receive_time = time.perf_counter()
            yield r

    return wrapper


class BaseAPIRequestor(object):
    """
    Base class of API Requestor

    Holds an `HTTPClient` that sends requests and a `RateLimiter` that gates
    them. Provides sync/async request helpers and tenacity-based retry
    wrappers. Subclasses must implement `_check_error`.
    """

    def __init__(self, **kwargs: Any) -> None:
        """
        `ak`, `sk` and `access_token` can be provided in kwargs.

        All kwargs are forwarded unchanged to both `HTTPClient` and
        `RateLimiter`.
        """
        self._client = HTTPClient(**kwargs)
        self._rate_limiter = RateLimiter(**kwargs)

    @_with_latency
    def _request(
        self,
        request: QfRequest,
        data_postprocess: Callable[[QfResponse], QfResponse] = lambda x: x,
    ) -> QfResponse:
        """
        simple sync request

        Args:
            request: the request to send.
            data_postprocess: hook applied to the parsed response before it
                is returned (identity by default).

        Returns:
            The parsed `QfResponse`. ``statistic["request_latency"]`` holds
            the ``elapsed`` time reported by ``requests``.

        Raises:
            errors.RequestError: on a non-200 status or a non-JSON body.
        """
        # only the actual send is gated by the rate limiter
        with self._rate_limiter:
            log_trace(f"raw request: {request}")
            response = self._client.request(request)
        _check_if_status_code_is_200(response)
        try:
            body = response.json()
        except requests.JSONDecodeError as e:
            raise errors.RequestError(
                f"Got invalid json response from server, body: {response.content!r}"
            ) from e
        resp = self._parse_response(body, response)
        resp.statistic["request_latency"] = response.elapsed.total_seconds()
        resp.request = QfRequest.from_requests(response.request)
        # store an independent copy of the original json body on the response
        resp.request.json_body = copy.deepcopy(request.json_body)

        return data_postprocess(resp)

    @_with_latency
    async def _async_request(
        self,
        request: QfRequest,
        data_postprocess: Callable[[QfResponse], QfResponse] = lambda x: x,
    ) -> QfResponse:
        """
        async request

        Args:
            request: the request to send.
            data_postprocess: hook applied to the parsed response before it
                is returned (identity by default).

        Returns:
            The parsed `QfResponse`. ``statistic["request_latency"]`` is the
            time from sending the request until the response is parsed.

        Raises:
            errors.RequestError: on a non-200 status or a non-JSON body.
        """
        async with self._rate_limiter:
            log_trace(f"raw request: {request}")
            # Start the clock before sending. Previously it started only
            # after ``arequest`` returned, so the network round trip was
            # left out of the measurement.
            start = time.perf_counter()
            response, session = await self._client.arequest(request)
        async with session:
            async with response:
                _async_check_if_status_code_is_200(response)
                # Read the raw body first. It is cached, so ``json()`` below
                # does not read it again, and it can be shown verbatim on
                # failure. ``response.content`` is an unread stream object.
                raw_body = await response.read()
                try:
                    body = await response.json()
                except json.JSONDecodeError as e:
                    raise errors.RequestError(
                        f"Got invalid json response from server, body: {raw_body!r}"
                    ) from e
                resp = self._parse_async_response(body, response)
                resp.statistic["request_latency"] = time.perf_counter() - start
                resp.request = QfRequest.from_aiohttp(response.request_info)
                resp.request.json_body = copy.deepcopy(request.json_body)
                return data_postprocess(resp)

    def _parse_response(
        self, body: Dict[str, Any], resp: requests.Response
    ) -> QfResponse:
        """
        parse response to QfResponse

        Runs `_check_error` on the decoded body first, then wraps the status
        code, headers and body.
        """
        self._check_error(body)
        qf_response = QfResponse(
            code=resp.status_code, headers=dict(resp.headers), body=body
        )
        return qf_response

    def _parse_async_response(
        self, body: Dict[str, Any], resp: aiohttp.ClientResponse
    ) -> QfResponse:
        """
        parse async response to QfResponse

        Runs `_check_error` on the decoded body first, then wraps the status
        code, headers and body.
        """
        self._check_error(body)
        qf_response = QfResponse(
            code=resp.status, headers=dict(resp.headers), body=body
        )
        return qf_response

    def _check_error(self, body: Dict[str, Any]) -> None:
        """
        check whether there is error in response

        Must be overridden by subclasses. The base implementation always
        raises ``NotImplementedError``.
        """
        raise NotImplementedError

    def _with_retry(
        self, config: RetryConfig, func: Callable[..., _T], *args: Any
    ) -> _T:
        """
        retry wrapper

        Calls ``func(*args)`` and retries it, with exponential backoff plus
        jitter capped at ``config.max_wait_interval`` and at most
        ``config.retry_count`` attempts, when it raises either:

        * an `errors.APIError` whose ``error_code`` is in
          ``config.retry_err_codes``, or
        * a `requests.RequestException`.

        The last exception is re-raised once attempts are exhausted.
        """

        def predicate_api_err_code(exc: Any) -> bool:
            if isinstance(exc, errors.APIError):
                if exc.error_code in config.retry_err_codes:
                    log_warn(
                        f"got error code {exc.error_code} from server, retrying... "
                    )
                    return True
            if isinstance(exc, requests.RequestException):
                log_error(f"request exception: {exc}, retrying...")
                return True
            return False

        @retry(
            wait=wait_exponential_jitter(
                jitter=config.jitter, max=config.max_wait_interval
            ),
            retry=retry_if_exception(predicate_api_err_code),
            stop=stop_after_attempt(config.retry_count),
            reraise=True,
        )
        def _retry_wrapper(*args: Any) -> _T:
            return func(*args)

        return _retry_wrapper(*args)

    async def _async_with_retry(
        self, config: RetryConfig, func: Callable[..., Awaitable[_T]], *args: Any
    ) -> _T:
        """
        async retry wrapper

        Same policy as `_with_retry`, except that the transport errors
        eligible for retry are ``aiohttp.ClientError`` instances.
        """

        def predicate_api_err_code(exc: Any) -> bool:
            if isinstance(exc, errors.APIError):
                if exc.error_code in config.retry_err_codes:
                    log_warn(
                        f"got error code {exc.error_code} from server, retrying... "
                    )
                    return True
            if isinstance(exc, aiohttp.ClientError):
                log_error(f"request exception: {exc}, retrying...")
                return True
            return False

        @retry(
            wait=wait_exponential_jitter(
                jitter=config.jitter, max=config.max_wait_interval
            ),
            retry=retry_if_exception(predicate_api_err_code),
            stop=stop_after_attempt(config.retry_count),
            reraise=True,
        )
        async def _retry_wrapper(*args: Any) -> _T:
            return await func(*args)

        return await _retry_wrapper(*args)