Source code for qianfan.utils.utils

# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import importlib.util
import os
import re
import secrets
import string
import threading
import uuid as uuid_lib
from threading import current_thread
from types import TracebackType
from typing import Dict, List, Optional, Type

from qianfan.errors import InvalidArgumentError
from qianfan.utils import log_info

thread_local = threading.local()


def _get_value_from_dict_or_var_or_env(
    dictionary: Dict, key: str, value: Optional[str], env_key: str
) -> Optional[str]:
    """
    Resolve a setting from three sources, in order of precedence:

    1. `dictionary[key]`, whenever `key` is present in `dictionary`
       (the stored value is returned as-is, even if it is None);
    2. `value`, when it is not None;
    3. the environment variable named `env_key`.

    Args:
        dictionary (Dict): the dict to search first
        key (str): the key to look up in `dictionary`
        value (Optional[str]): explicit fallback used when `key` is absent
        env_key (str): the name of the environment variable used last

    Returns:
        the resolved value, or None if none of the sources provides one
    """
    if key not in dictionary:
        # Fall back to the explicit value, then to the environment.
        return value if value is not None else os.environ.get(env_key)
    return dictionary[key]


def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
    """
    Copy `src[key]` into `dst[key]`, but only when `key` is present in `src`.
    `dst` is left untouched otherwise.

    Args:
        src (Dict): the dict to read from
        dst (Dict): the dict to write into (modified in place)
        key (str): the key to copy

    Returns:
        None
    """
    if key not in src:
        return
    dst[key] = src[key]


def _get_from_env_or_default(env_name: str, default_value: str) -> str:
    """
    Read an environment variable, falling back to a default when it is unset.

    Args:
        env_name (str): the name of the environment variable
        default_value (str): returned when `env_name` is not set

    Return:
        str, the value of the environment variable or default_value
    """
    # An environment variable that is set (even to "") wins over the default.
    return os.environ.get(env_name, default_value)


def _strtobool(val: str) -> bool:
    """
    Convert a string to a bool (case-insensitive).

    Args:
        val (str): one of "y", "yes", "t", "true", "on", "1" for True, or
            "n", "no", "f", "false", "off", "0" for False

    Returns:
        bool, the boolean the string stands for

    Raises:
        InvalidArgumentError: if `val` is not one of the accepted spellings
    """
    normalized = val.lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return True
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return False
    # Report the value exactly as the caller passed it, with balanced quoting.
    raise InvalidArgumentError(f"invalid boolean value `{val}`")


def _none_if_empty(val: str) -> Optional[str]:
    """
    Map the empty string and any casing of "none" to None; return every other
    string unchanged.
    """
    is_blank_or_none = val.lower() in ("", "none")
    return None if is_blank_or_none else val


class AsyncLock:
    """
    wrapper of asyncio.Lock

    The constructor tries to create an `asyncio.Lock`. If that raises
    `RuntimeError`, the instance is kept without a lock, an info message is
    logged (at most once per thread, tracked through the module-level
    `thread_local.event_loop_checked` flag) and every later attempt to enter
    or exit the lock raises `InvalidArgumentError`.
    """

    # Shared by __aenter__ and __aexit__ when no underlying lock exists.
    _NO_EVENT_LOOP_MESSAGE = (
        "no event loop found in current thread, please make sure the event loop"
        " is available when the object is initialized"
    )

    def __init__(self) -> None:
        self._lock: Optional[asyncio.Lock] = None
        try:
            self._lock = asyncio.Lock()
        except RuntimeError:
            # only log once for each thread
            if getattr(thread_local, "event_loop_checked", None) is None:
                log_info(
                    f"no event loop in thread `{current_thread().name}`, async feature"
                    " won't be available. Please make sure the object is initialized"
                    " in the thread with event loop."
                )
                thread_local.event_loop_checked = True

    def _require_lock(self) -> asyncio.Lock:
        """
        Return the wrapped lock, or raise InvalidArgumentError if the
        constructor could not create one.
        """
        if self._lock is None:
            raise InvalidArgumentError(self._NO_EVENT_LOOP_MESSAGE)
        return self._lock

    async def __aenter__(self) -> None:
        await self._require_lock().__aenter__()

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        await self._require_lock().__aexit__(exc_type, exc_val, exc_tb)
def uuid() -> str:
    """
    Generate a random (version 4) UUID and return it as a 32-character
    lowercase hexadecimal string without dashes.
    """
    # `.hex` is the canonical string form with the dashes already removed.
    return uuid_lib.uuid4().hex
def generate_letter_num_random_id(len: int = 10) -> str:
    """
    Generate a random identifier made of ASCII letters and digits, drawing
    each character with `secrets.choice`.

    Args:
        len (int): number of characters to generate, 10 by default. The name
            shadows the builtin but is kept because callers may pass it by
            keyword.

    Returns:
        str, a random string of `len` characters (empty when `len` <= 0)
    """
    # Build the alphabet once instead of once per generated character.
    alphabet = string.ascii_letters + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(len))
def assert_package_installed(package_name: str) -> None:
    """
    Ensure that `package_name` is installed, as reported by
    `check_package_installed`.

    Args:
        package_name (str): the name of the package to look for

    Raises:
        ImportError: if the package is not installed; the message contains a
            `pip install` hint
    """
    if check_package_installed(package_name):
        return
    raise ImportError(
        f"Unable to import {package_name} package, "
        f"please install it using 'pip install {package_name}'."
    )
def check_package_installed(package_name: str) -> bool:
    """
    Tell whether a module or package can be found, without importing it.

    Args:
        package_name (str): the (possibly dotted) module name to look for

    Returns:
        bool, True if an import spec is found for `package_name`, False
        otherwise
    """
    try:
        return importlib.util.find_spec(package_name) is not None
    except ModuleNotFoundError:
        # For a dotted name `find_spec` imports the parent package first and
        # raises if that parent is missing; that also means "not installed".
        return False
def camel_to_snake(name: str) -> str:
    """
    Convert a CamelCase / camelCase identifier to snake_case.

    Args:
        name (str): the identifier to convert

    Returns:
        str, the lower-cased identifier with underscores at word boundaries
    """
    # Pass 1: put an underscore before each capitalised word ("Xyz") that
    # follows another character, e.g. "HTTPResponse" -> "HTTP_Response".
    split_words = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name)
    # Pass 2: put an underscore between a lowercase letter or digit and the
    # uppercase letter that follows it, e.g. "getHTTP" -> "get_HTTP".
    split_boundaries = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", split_words)
    return split_boundaries.lower()
def snake_to_camel(name: str) -> str:
    """
    Convert a snake_case identifier to CamelCase (first letter upper-cased).

    Each underscore-separated part goes through `str.capitalize`, so the
    remaining letters of every part are lower-cased.

    Args:
        name (str): the identifier to convert

    Returns:
        str, the capitalised parts joined without separators
    """
    parts = name.split("_")
    return "".join(map(str.capitalize, parts))
def first_lower_case(name: str) -> str:
    """
    Lower-case the first character of `name` and leave the rest unchanged.

    Args:
        name (str): the string to convert

    Returns:
        str, `name` with its first character lower-cased; an empty string is
        returned unchanged
    """
    if not name:
        return name
    head, tail = name[0], name[1:]
    return head.lower() + tail
def remove_suffix(name: str, suffix: str) -> str:
    """
    Remove `suffix` from the end of `name` if it is present.

    Args:
        name (str): the string to strip
        suffix (str): the suffix to remove

    Returns:
        str, `name` without the trailing `suffix`, or `name` unchanged when
        it does not end with `suffix` or `suffix` is empty
    """
    # An empty suffix matches every string, and `name[:-0]` is `name[:0]`,
    # i.e. "", so it must be excluded or the whole string would be dropped.
    if suffix and name.endswith(suffix):
        return name[: -len(suffix)]
    return name
def remove_suffix_list(name: str, suffix_list: List[str]) -> str:
    """
    Remove the first suffix in `suffix_list` that `name` ends with.

    Only one suffix is removed: the candidates are tried in list order and
    the search stops at the first match.

    Args:
        name (str): the string to strip
        suffix_list (List[str]): candidate suffixes, in order of priority

    Returns:
        str, `name` without the first matching suffix, or `name` unchanged
        when no non-empty suffix matches
    """
    for suffix in suffix_list:
        # Skip empty candidates: they match every string and `name[:-0]`
        # would return "" instead of `name`.
        if suffix and name.endswith(suffix):
            return name[: -len(suffix)]
    return name