Source code for qianfan.resources.rate_limiter

# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
    Implementation of Rate Limiter
"""
import threading
import time
from types import TracebackType
from typing import Any, Optional, Type

from aiolimiter import AsyncLimiter

from qianfan.config import get_config


[docs]class RateLimiter: """ Implementation of Rate Limiter. They're different rules between synchronous and asynchronous method using, we recommend only use one of two method within single rate limiter at same time """ class _SyncLimiter: def __init__( self, query_per_period: float = 1, period_in_second: float = 1, **kwargs: Any ): """ initialize rate limiter Args: query_per_period (float): query times in one period, default to 1.0. period_in_second (float): time of period, default to 1.0. Raises: ValueError: A ValueError will be raised if `query_per_period` is smaller than 1 or `period_in_second` isn't positive """ if not (query_per_period >= 1 and period_in_second > 0): raise ValueError( "argument illegal with query_per_period {} and period_in_second {}" .format(query_per_period, period_in_second) ) self._query_per_period = query_per_period self._period_in_second = period_in_second self._query_per_second = query_per_period / period_in_second self._token_count = self._query_per_period self._last_leak_timestamp = time.time() self._sync_lock = threading.Lock() def _leak(self) -> None: timestamp = time.time() delta = timestamp - self._last_leak_timestamp self._last_leak_timestamp = timestamp self._token_count = min( self._query_per_period, self._token_count + delta * self._query_per_second, ) def __enter__(self) -> None: """ synchronous entrance of rate limiter """ with self._sync_lock: while True: self._leak() if self._token_count >= 1: self._token_count -= 1 return time.sleep((1 - self._token_count) / self._query_per_second) def __exit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: """ exit """ return def _check_is_closed(self) -> bool: return self._is_closed def __init__(self, query_per_second: float = 0, **kwargs: Any): """ initialize rate limiter Args: query_per_second (float): query times in one second, default to 0, meaning rate limiter close. """ if query_per_second == 0: query_per_second = get_config().QPS_LIMIT self._is_closed = query_per_second <= 0 if self._check_is_closed(): return period_length: float if query_per_second > 1: query_per_period = query_per_second period_length = 1 else: query_per_period = 1 period_length = 1 / query_per_second self._async_limiter = AsyncLimiter(query_per_period, period_length) self._sync_limiter = self._SyncLimiter(query_per_period, period_length) def __enter__(self) -> None: """ synchronous entrance of rate limiter """ if self._check_is_closed(): return with self._sync_limiter: return def __exit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: """ exit """ return async def __aenter__(self) -> None: """ asynchronous entrance of rate limiter """ if self._check_is_closed(): return async with self._async_limiter: return async def __aexit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: """ exit """ return