# Source code for qianfan.resources.auth.oauth

# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
import time
from typing import Any, Dict, Optional, Tuple

from qianfan.config import get_config
from qianfan.consts import Consts
from qianfan.errors import InternalError, InvalidArgumentError
from qianfan.resources.http_client import HTTPClient
from qianfan.resources.typing import QfRequest, RetryConfig
from qianfan.utils import (
    AsyncLock,
    log_error,
    log_info,
    log_warn,
)
from qianfan.utils.helper import Singleton


def _masked_ak(ak: str) -> str:
    """
    mask ak, only display first 6 characters
    """
    return ak[:6] + "***"


class AuthManager(metaclass=Singleton):
    """
    AuthManager is singleton to manage all access token in SDK

    Tokens are cached per `(ak, sk)` pair in `_token_map`. The manager-wide
    locks (`_lock` for sync code, `_alock` for async code) only guard lookups
    and insertions in that map. Each cached `AccessToken` has its own
    `lock` / `alock`, which guard reading and refreshing that one token.
    """

    class AccessToken:
        """
        Access Token object
        """

        # cached token; None until it is provided by the user or fetched
        token: Optional[str]
        # guards this object on synchronous code paths
        lock: threading.Lock
        # guards this object on async code paths
        alock: AsyncLock
        # `time.time()` of the last successful refresh; 0 allows the next
        # refresh to run without waiting for ACCESS_TOKEN_REFRESH_MIN_INTERVAL
        refresh_at: float

        def __init__(self, access_token: Optional[str] = None):
            """
            Init access token object
            """
            self.token = access_token
            self.lock = threading.Lock()
            self.alock = AsyncLock()
            self.refresh_at = 0

    _token_map: Dict[Tuple[str, str], AccessToken]

    def __init__(self) -> None:
        """
        Init Auth manager
        """
        self._token_map = {}
        self._client = HTTPClient()
        self._lock = threading.Lock()
        self._alock = AsyncLock()

    def _register(
        self, ak: str, sk: str, access_token: Optional[str] = None
    ) -> bool:
        """
        add `(ak, sk)` to manager and return whether provided `(ak, sk)` is existed

        If the pair already exists and a new `access_token` is given, the cached
        token is replaced and its refresh timestamp reset to 0.

        this function is not thread safe !!!
        """
        obj = self._token_map.get((ak, sk))
        if obj is None:
            self._token_map[(ak, sk)] = AuthManager.AccessToken(access_token)
            return False
        # if user provide new access token for existed (ak, sk), update it
        if access_token is not None:
            obj.token = access_token
            obj.refresh_at = 0
        return True

    def register(self, ak: str, sk: str, access_token: Optional[str] = None) -> None:
        """
        add `(ak, sk)` to manager and update access token

        A token is fetched immediately only for a newly added pair registered
        without an explicit `access_token`.
        """
        with self._lock:
            existed = self._register(ak, sk, access_token)
        # must run after `self._lock` is released: refresh_access_token
        # acquires it again and threading.Lock is not reentrant
        if not existed and access_token is None:
            self.refresh_access_token(ak, sk)

    async def aregister(
        self, ak: str, sk: str, access_token: Optional[str] = None
    ) -> None:
        """
        async add `(ak, sk)` to manager and update access token

        A token is fetched immediately only for a newly added pair registered
        without an explicit `access_token`.
        """
        async with self._alock:
            existed = self._register(ak, sk, access_token)
        # must run after `self._alock` is released: arefresh_access_token
        # acquires it again
        if not existed and access_token is None:
            await self.arefresh_access_token(ak, sk)

    def _get_access_token_object(
        self, ak: str, sk: str
    ) -> AccessToken:  # pylint:disable=undefined-variable
        """
        get access token object by `(ak, sk)`

        Raises:
            InternalError: `(ak, sk)` has not been registered.

        this function is not thread safe !!!
        """
        obj = self._token_map.get((ak, sk), None)
        if obj is None:
            raise InternalError("provided ak and sk are not registered")
        return obj

    def _get_token_from_access_token_object(
        self, obj: AccessToken, ak: str = "", sk: str = ""
    ) -> str:
        """
        get access token from access token object

        Returns "" (and logs a warning) when no token is available yet.
        `ak` is only used for the log message; `sk` is unused.

        this function is not thread safe and should be protected by lock from obj !!!
        """
        if obj.token is None:
            log_warn(f"access token is not available for ak `{_masked_ak(ak)}`")
            return ""
        return obj.token

    def get_access_token(self, ak: str, sk: str) -> str:
        """
        get access token by `(ak, sk)`

        Raises:
            InternalError: `(ak, sk)` has not been registered.
        """
        with self._lock:
            obj = self._get_access_token_object(ak, sk)
        with obj.lock:
            return self._get_token_from_access_token_object(obj, ak, sk)

    async def aget_access_token(self, ak: str, sk: str) -> str:
        """
        async get access token by `(ak, sk)`

        Raises:
            InternalError: `(ak, sk)` has not been registered.
        """
        async with self._alock:
            obj = self._get_access_token_object(ak, sk)
        async with obj.alock:
            return self._get_token_from_access_token_object(obj, ak, sk)

    def _auth_request(self, ak: str, sk: str) -> QfRequest:
        """
        generate auth request

        Builds a POST to `BASE_URL + Consts.AuthAPI` using the
        client-credentials grant, with `AUTH_TIMEOUT` as the retry timeout.
        """
        return QfRequest(
            method="POST",
            url="{}{}".format(get_config().BASE_URL, Consts.AuthAPI),
            query={
                "grant_type": "client_credentials",
                "client_id": ak,
                "client_secret": sk,
            },
            retry_config=RetryConfig(timeout=get_config().AUTH_TIMEOUT),
        )

    def _update_access_token(
        self, obj: AccessToken, response: Dict[str, Any], ak: str = "", sk: str = ""
    ) -> bool:
        """
        update access token from response of auth request

        Returns:
            True if `obj` now holds the token from `response`; False if the
            response reported an error or carried no `access_token`. In the
            False case `obj` is left untouched.

        this function is not thread safe and should be protected by lock from obj !!!
        """
        if "error" in response:
            # `error_description` is not guaranteed to accompany `error`;
            # fall back to the error code so the real cause is still logged
            log_error(
                "refresh access_token for ak `{}` failed, error description={}".format(
                    _masked_ak(ak),
                    response.get("error_description", response["error"]),
                )
            )
            return False
        if "access_token" not in response:
            log_error(
                f"refresh access_token for ak `{_masked_ak(ak)}` failed, "
                "no access_token found in response"
            )
            return False
        obj.token = response["access_token"]
        obj.refresh_at = time.time()
        return True

    def _refresh_access_token_too_often(self, obj: AccessToken) -> bool:
        """
        check if access token is refreshed too often

        True when the last successful refresh happened less than
        `ACCESS_TOKEN_REFRESH_MIN_INTERVAL` seconds ago.
        """
        if (
            time.time() - obj.refresh_at
            < get_config().ACCESS_TOKEN_REFRESH_MIN_INTERVAL
        ):
            log_info("access_token is already refreshed, skip refresh.")
            return True
        return False

    def refresh_access_token(self, ak: str, sk: str) -> None:
        """
        refresh access token of `(ak, sk)`

        Best-effort: any failure is logged and the previously cached token (if
        any) is kept.

        Raises:
            InternalError: `(ak, sk)` has not been registered.
        """
        with self._lock:
            obj = self._get_access_token_object(ak, sk)
        with obj.lock:
            log_info(f"trying to refresh access_token for ak `{_masked_ak(ak)}`")
            # in case multiple threads try to refresh access token at the same time
            # the token should not be refreshed multiple times
            if self._refresh_access_token_too_often(obj):
                return
            try:
                resp = self._client.request(self._auth_request(ak, sk))
                json_body = resp.json()
                updated = self._update_access_token(obj, json_body, ak, sk)
            except Exception as e:
                log_error(f"refresh access token failed with exception {str(e)}")
                return
            # only report success when the response actually carried a token
            if updated:
                log_info(
                    f"successfully refresh access_token for ak `{_masked_ak(ak)}`"
                )

    async def arefresh_access_token(self, ak: str, sk: str) -> None:
        """
        async refresh access token of `(ak, sk)`

        Best-effort: any failure is logged and the previously cached token (if
        any) is kept.

        Raises:
            InternalError: `(ak, sk)` has not been registered.
        """
        async with self._alock:
            obj = self._get_access_token_object(ak, sk)
        async with obj.alock:
            log_info(f"trying to refresh access_token for ak `{_masked_ak(ak)}`")
            # in case multiple coroutines try to refresh access token at the same
            # time the token should not be refreshed multiple times
            if self._refresh_access_token_too_often(obj):
                return
            try:
                resp, session = await self._client.arequest(self._auth_request(ak, sk))
                async with session:
                    json_body = await resp.json()
                updated = self._update_access_token(obj, json_body, ak, sk)
            except Exception as e:
                log_error(f"refresh access token failed with exception {str(e)}")
                return
            # only report success when the response actually carried a token
            if updated:
                log_info(
                    f"successfully refresh access_token for ak `{_masked_ak(ak)}`"
                )
class Auth(object):
    """
    object to maintain access token for open api call

    Credentials can come from keyword arguments or, as a fallback, from the
    SDK config: an `(ak, sk)` pair, an `access_token`, or an
    `(access_key, secret_key)` pair.
    """

    _ak: Optional[str] = None
    _sk: Optional[str] = None
    _access_token: Optional[str] = None
    _access_key: Optional[str] = None
    _secret_key: Optional[str] = None
    _registered: bool = False
    # (access_key, secret_key) -> (ak, sk) map which converts console ak/sk to
    # qianfan ak/sk, intended as a cache to avoid querying console ak/sk
    # multiple times. It is a class attribute, so it is shared by every Auth
    # instance; no method of this class reads or writes it.
    _console_ak_to_app_ak: Dict[Tuple[str, str], Tuple[str, str]] = {}

    def __init__(self, **kwargs: Any) -> None:
        """
        Collect credentials from `kwargs`, falling back to the SDK config.

        Recognised keys are `ak`, `sk`, `access_token`, `access_key` and
        `secret_key`. A falsy kwarg value is treated as absent. When private
        mode (`ENABLE_PRIVATE`) is on, nothing is read and no validation is done.

        Raises:
            InvalidArgumentError: none of (access_key, secret_key), (ak, sk) or
                access_token is available.
        """
        if get_config().ENABLE_PRIVATE:
            return
        config = get_config()
        self._ak = kwargs.get("ak", None) or config.AK
        self._sk = kwargs.get("sk", None) or config.SK
        self._access_token = kwargs.get("access_token", None) or config.ACCESS_TOKEN
        self._access_key = kwargs.get("access_key", None) or config.ACCESS_KEY
        self._secret_key = kwargs.get("secret_key", None) or config.SECRET_KEY

        if not self._credential_available():
            raise InvalidArgumentError(
                "no enough credential found, any one of (access_key, secret_key),"
                " (ak, sk), access_token must be provided"
            )

        # Only console keys usable: there is nothing to hand to AuthManager,
        # so mark registration as done up front.
        console_keys_only = (
            self._access_token is None
            and (self._ak is None or self._sk is None)
            and (self._access_key is not None and self._secret_key is not None)
        )
        if console_keys_only:
            self._registered = True

    def _pending_registration(self) -> Optional[Tuple[str, str]]:
        """
        Return the `(ak, sk)` pair that still has to be registered, if any.

        Returns None when registration already happened, or when only an
        access_token is available (nothing for AuthManager to manage).

        Raises:
            InvalidArgumentError: neither an access_token nor a full (ak, sk)
                pair is available.
        """
        if self._registered:
            return None
        ak, sk = self._ak, self._sk
        if ak is not None and sk is not None:
            return ak, sk
        if self._access_token is None:
            raise InvalidArgumentError(
                "both ak and sk must be provided, otherwise access_token should"
                " be provided"
            )
        return None

    def _register(self) -> None:
        """
        Hand `(ak, sk)` (and any user access_token) to AuthManager once, so the
        token can be refreshed automatically.
        """
        pair = self._pending_registration()
        if pair is not None:
            AuthManager().register(pair[0], pair[1], self._access_token)
        self._registered = True

    async def _aregister(self) -> None:
        """
        Async variant of `_register`.
        """
        pair = self._pending_registration()
        if pair is not None:
            await AuthManager().aregister(pair[0], pair[1], self._access_token)
        self._registered = True

    def refresh_access_token(self) -> None:
        """
        Force a refresh of the access token through AuthManager.

        Only works when both `ak` and `sk` are set; otherwise a warning is
        logged. After refreshing, any user-supplied access_token is dropped so
        later lookups go through AuthManager.
        """
        if self._ak is not None and self._sk is not None:
            self._register()
            AuthManager().refresh_access_token(self._ak, self._sk)
            self._access_token = None
        else:
            log_warn("AK or SK is not set, refresh access_token will not work.")

    async def arefresh_access_token(self) -> None:
        """
        Async variant of `refresh_access_token`.
        """
        if self._ak is not None and self._sk is not None:
            await self._aregister()
            await AuthManager().arefresh_access_token(self._ak, self._sk)
            self._access_token = None
        else:
            log_warn("AK or SK is not set, refresh access_token will not work.")

    def _credential_available(self) -> bool:
        """
        True when at least one complete credential set is present.
        """
        return any(
            (
                self._access_token is not None,
                self._ak is not None and self._sk is not None,
                self._access_key is not None and self._secret_key is not None,
            )
        )

    def access_token(self) -> str:
        """
        Return the access token currently in use.

        A user-supplied token is returned directly when there is no (ak, sk)
        pair. Returns "" when authentication relies on access_key/secret_key.
        """
        ak, sk = self._ak, self._sk
        if self._access_token is not None and (ak is None or sk is None):
            return self._access_token
        self._register()
        if ak is None or sk is None:
            # authenticating with access_key / secret_key: no access_token here
            return ""
        return AuthManager().get_access_token(ak, sk)

    async def a_access_token(self) -> str:
        """
        Async variant of `access_token`.
        """
        ak, sk = self._ak, self._sk
        if self._access_token is not None and (ak is None or sk is None):
            return self._access_token
        await self._aregister()
        if ak is None or sk is None:
            # authenticating with access_key / secret_key: no access_token here
            return ""
        return await AuthManager().aget_access_token(ak, sk)