Source code for pywrapid.webclient.web

#!/usr/bin/python3
"""
pywrapid web client base

This library is for educational purposes only.
Do no evil, do not break local or international laws!
By using this code, you take full responsibility for your actions.
The author has granted code access for educational purposes and is
not liable for any misuse.
"""
# __author__ = "Jonas Werme"
# __copyright__ = "Copyright (c) 2021 Jonas Werme"
# __credits__ = ["nsahq"]
# __license__ = "MIT"
# __version__ = "0.1.0"
# __maintainer__ = "Jonas Werme"
# __email__ = "jonas[dot]werme[at]hoofbite[dot]com"
# __status__ = "Prototype"


import logging
from datetime import datetime, timedelta
from enum import Enum
from time import time
from typing import Any, Optional, Type, Union
from urllib.parse import urlparse

import jwt
from requests import HTTPError, RequestException, Response, Timeout, TooManyRedirects, request

from pywrapid.config import ConfigSubSection, WrapidConfig
from pywrapid.utils import is_file_readable

from .exceptions import (
    ClientAuthenticationError,
    ClientAuthorizationError,
    ClientConnectionError,
    ClientError,
    ClientHTTPError,
    ClientTimeout,
    CredentialCertificateFileError,
    CredentialError,
    CredentialKeyFileError,
    CredentialURLError,
)

log = logging.getLogger(__name__)


class AuthorizationType(Enum):
    """Auth type enum

    Selects how the web client in this module handles the login response
    and whether an access token is attached to outgoing requests.
    """

    # No Authorization header is added to outgoing requests for this type
    NONE = 0
    BASIC = 1
    # BEARER and JWT: the token expiry is read from the claims of the access token
    BEARER = 2
    JWT = 3
    # Access and refresh tokens are read from the JSON body of the login response
    OAUTH2 = 4


class WebCredentials:
    """Credential base class

    Stores the request options and the unified configuration that the
    concrete credential classes build in their constructors.
    """

    def __init__(self) -> None:
        """Init class for web credentials"""
        self._options: dict = {}
        self._config: dict = {}
        self.type = type(self).__name__

    @property
    def options(self) -> dict:
        """Getter for options"""
        return self._options

    @property
    def config(self) -> dict:
        """Getter for config"""
        return self._config

    def _unify_configuration(
        self, params: dict, config: Union[Type[WrapidConfig], dict, None]
    ) -> ConfigSubSection:
        """Unify params and config values

        Produces wrapid config (ConfigSubSection) object after validating standard stuff.
        Exactly one configuration source may be used: either the explicit
        parameters or the ``config`` object/dict.

        Args:
            params (dict): Configuration parameters, e.g. locals()
            config (WrapidConfig|dict): Configuration object, dict or None

        Raises:
            CredentialError: If both parameters and a config are supplied, or
                if no usable configuration remains.

        Returns:
            The configuration object. Dicts (including the collected
            parameters) are wrapped in a ConfigSubSection; a WrapidConfig
            derivative is returned as passed in.
        """
        # Drop the bookkeeping names that locals() drags along and every
        # parameter that was left at a falsy default.
        params = {
            k: v
            for k, v in params.items()
            if k != "self" and k != "kwargs" and k != "config" and k != "__class__" and v
        }

        if params and config:
            raise CredentialError(
                "Multiple configuration options used, "
                "use a WrapidConfig derivative/dict OR passed parameters"
            )

        if params:
            config = params

        if isinstance(config, dict):
            config = ConfigSubSection({"dummy_key": config}, "dummy_key")  # type: ignore

        if config and isinstance(config, WrapidConfig):
            return config

        raise CredentialError("Config parameter must be of type dict or a WrapidConfig derivative")

    # def import_dependencies(self, dependencies: list) -> bool:
    #     """Import dependencies for credential type"""
    #     missing_dependencies = []
    #     for dependency in dependencies:
    #         if not is_module_available(dependency):
    #             log.error("Dependency %s is not available", dependency)
    #             missing_dependencies.append(dependency)
    #             continue
    #         if not is_module_already_imported(dependency):
    #             setattr(self, dependency, import_module(dependency))
    #     if missing_dependencies:
    #         raise DependencyError(f"Missing credential dependencies: {missing_dependencies}")
    #     log.debug("All credential dependencies (%s) are available", dependencies)
    #     return True

    def validate_url(self, url: str = "", raise_on_fail: bool = False) -> bool:
        """Validate URL strings

        This is a sanity check only: the URL must be non-empty and must be
        parseable by ``urllib.parse.urlparse``.

        Args:
            url (str, optional): The URL to validate. Defaults to "".
            raise_on_fail (bool, optional): Raise error if validation fails.
                Defaults to False.

        Raises:
            CredentialURLError: If validation fails and raise_on_fail is True.

        Returns:
            bool: True if the URL passed validation, False otherwise.
        """
        valid = bool(url)
        if valid:
            try:
                urlparse(url)
            except ValueError:
                # urlparse rejects some malformed inputs, e.g. a broken IPv6 host
                valid = False

        if not valid:
            # Log the URL itself; it is a value, not a key of the configuration
            log.error("URL validation failed for URL: %s", url)
            if raise_on_fail:
                raise CredentialURLError("Validation failed for URL")
            return False

        return True
class BasicAuthCredentials(WebCredentials):
    """Credential class for basic auth

    The credentials can be given either as explicit parameters or through
    ``config`` (dict or WrapidConfig derivative), not both.
    """

    def __init__(  # pylint: disable=unused-argument
        self,
        username: str = "",
        password: str = "",
        login_url: str = "",
        config: Union[Type[WrapidConfig], dict, None] = None,
        **kwargs: dict[str, Any],
    ) -> None:
        """Build basic auth credentials

        Args:
            username (str, optional): User name. Defaults to "" so that the
                value can be supplied through config instead.
            password (str, optional): Password. Defaults to "" so that the
                value can be supplied through config instead.
            login_url (str, optional): URL used to log in. Defaults to "".
            config (WrapidConfig|dict, optional): Configuration holding the
                same keys as the parameters.
            **kwargs: Additional configuration parameters.

        Raises:
            CredentialError: If parameters and config are mixed or no usable
                configuration is given.
        """
        # locals() has to be captured first so that it only holds the parameters
        wrapid_config = self._unify_configuration({**locals(), **kwargs}, config)  # type: ignore
        super().__init__()

        required_keys = ["username", "password"]
        wrapid_config.validate_keys(expected_keys=required_keys)  # type: ignore

        self._config = dict(wrapid_config.cfg)

        if "login_url" in self._config:
            self.validate_url(str(self._config.get("login_url")))

        # Read the credentials from the unified configuration so that values
        # supplied through config are honoured; fall back to the parameters.
        self._options = {
            "auth": (
                self._config.get("username", username),
                self._config.get("password", password),
            )
        }
class X509Credentials(WebCredentials):
    """Credential class for x509 auth

    Exposes the certificate/key pair as the ``cert`` request option.
    """

    def __init__(  # pylint: disable=unused-argument, too-many-arguments
        self,
        cert_file: str = "",
        key_file: str = "",
        login_url: str = "",
        jwt_key: str = "",
        access_token_timeout: int = 0,
        token_expiry_offset: int = 0,
        config: Union[Type[WrapidConfig], dict, None] = None,
        **kwargs: dict[str, Any],
    ) -> None:
        """Load and check x509 credentials

        Raises:
            CredentialCertificateFileError: If the certificate file is not readable.
            CredentialKeyFileError: If the key file is not readable.
        """
        # locals() has to be captured first so that it only holds the parameters
        unified = self._unify_configuration({**locals(), **kwargs}, config)
        super().__init__()
        unified.validate_keys(["login_url", "key_file", "cert_file"])

        settings = dict(unified.cfg)
        self._config = settings
        self.cert_file = settings.get("cert_file", "")
        self.key_file = settings.get("key_file", "")
        self.login_url = settings.get("login_url", "")

        # Certificate is checked before the key; the first unreadable file aborts
        file_checks = (
            (
                self.cert_file,
                "Certificate file validation failed for %s",
                CredentialCertificateFileError,
                "Certificate file error",
            ),
            (
                self.key_file,
                "Key file validation failed for %s",
                CredentialKeyFileError,
                "Key file error",
            ),
        )
        for path, log_message, error_type, error_text in file_checks:
            if is_file_readable(path):
                continue
            log.error(log_message, path)
            raise error_type(error_text)

        log.debug(
            "Loading x509 web credentials: cert_file=%s, key_file=%s, login_url=%s",
            self.cert_file,
            self.key_file,
            self.login_url,
        )

        self._options = {"cert": (self.cert_file, self.key_file)}
class OAuth2Credentials(WebCredentials):
    """Credential class for OAuth2 authentication"""

    def __init__(  # pylint: disable=unused-argument, too-many-arguments # nosec
        self,
        login_url: str = "",
        token_url: str = "",
        redirect_uri: str = "",
        auth_data: Optional[dict] = None,
        legacy_auth: Optional[Union[BasicAuthCredentials, dict]] = None,
        refresh_token_timeout: int = 0,
        access_token_timeout: int = 0,
        token_expiry_offset: int = 0,
        config: Union[Type[WrapidConfig], dict, None] = None,
        **kwargs: dict[str, Any],
    ) -> None:
        """Load OAuth2 credentials

        Values are given either as explicit parameters or through ``config``.
        ``login_url`` and ``auth_data`` are required; ``auth_data`` is moved
        out of the configuration into ``credential_body``.

        Args:
            login_url (str, optional): URL used to log in.
            token_url (str, optional): Token URL.
            redirect_uri (str, optional): Redirect URI.
            auth_data (dict, optional): Body sent with the login request.
            legacy_auth (BasicAuthCredentials|dict, optional): Basic auth
                credentials (or a dict with username/password) whose request
                options are reused for this credential object.
            refresh_token_timeout (int, optional): Stored in the configuration.
            access_token_timeout (int, optional): Stored in the configuration.
            token_expiry_offset (int, optional): Stored in the configuration.
            config (WrapidConfig|dict, optional): Configuration holding the
                same keys as the parameters.
            **kwargs: Additional configuration parameters.

        Raises:
            CredentialError: If parameters and config are mixed or no usable
                configuration is given.
        """
        # self.import_dependencies(["requests_oauthlib"])
        # locals() has to be captured first so that it only holds the parameters
        wrapid_config = self._unify_configuration({**locals(), **kwargs}, config)
        super().__init__()

        required_keys = ["login_url", "auth_data"]
        self._config = dict(wrapid_config.cfg)

        for url in ["login_url", "token_url", "redirect_uri"]:
            use_url = ""
            if url in self._config:
                use_url = str(self._config.get(url))
                # Validate the configured value, not the name of the key
                if use_url:
                    self.validate_url(use_url)
                if url not in required_keys:
                    required_keys.append(url)
            setattr(self, url, use_url)

        wrapid_config.validate_keys(required_keys)

        self.legacy_auth = self._config.get("legacy_auth", None)

        log.debug(
            "Loading OAuth2 web credentials: login_url=%s, "
            "token_url=%s, redirect_uri=%s, legacy_auth=%s",
            self.login_url,  # type: ignore # pylint: disable=no-member
            self.token_url,  # type: ignore # pylint: disable=no-member
            self.redirect_uri,  # type: ignore # pylint: disable=no-member
            bool(self.legacy_auth),
        )

        if self.legacy_auth:
            if isinstance(self.legacy_auth, dict):
                self.legacy_auth = BasicAuthCredentials(
                    username=self.legacy_auth.get("username", ""),
                    password=self.legacy_auth.get("password", ""),
                )
            self._options = self.legacy_auth._options  # type: ignore

        self.credential_body = self._config.pop("auth_data")
class WebClient:  # pylint: disable=too-many-instance-attributes, too-many-arguments
    """Web Client base

    Generic web client class as base for creating application specific clients
    or to be used directly as a general use web client.

    Wraps the request library and adds generic exceptions.
    Passes web calls transparently to requests, meaning you can use any
    requests option you see fit, such as proxy settings etc by passing them as
    key word arguments.

    If a configuration section named client_options is passed to the client,
    these options will be set for the web communication. Passed arguments will
    have precedence over configuration items.

    The client allows you to mix and match authentication types with
    authorization types to fit strange combinations used in some APIs.

    Can be used with a wrapid config or straight up dict config for use in
    clients extending this class.

    Allows raise of exception on non-2xx responses (optional).
    """

    def __init__(
        self,
        authorization_type: AuthorizationType = AuthorizationType.NONE,
        credentials: Optional[Type[WebCredentials]] = None,
        dict_config: Optional[dict] = None,
        wrapid_config: Optional[Type[WrapidConfig]] = None,
    ):
        """Init function for web client class

        Args:
            authorization_type (AuthorizationType (ENUM), optional):
                wrapid authorization type to use for clients communication.
            credentials (Type[WebCredentials], optional):
                wrapid credentials object to use for clients communication.
            dict_config (dict, optional, mutually exclusive with wrapid_config):
                dict object to store in the clients config parameter.
            wrapid_config (Type[WrapidConfig], optional, mutually exclusive with dict_config):
                wrapid configuration object to store configuration in the
                clients config parameter from.

        Raises:
            ClientError: If both dict_config and wrapid_config are given.
            ClientAuthorizationError: If authorization_type is not a valid
                AuthorizationType.
        """
        # Initialize all instance variables with proper type hints
        self._config: dict = {}
        self._authorization_type: AuthorizationType = authorization_type
        self._credential_options: dict = {}
        self._credential_config: dict = {}
        self._login_url: str = ""
        self._credential_body: dict = {}
        self._access_token_expiry: datetime = datetime.now()
        self._refresh_token_expiry: datetime = datetime.now()
        self._access_token: str = ""  # nosec
        self._refresh_token: str = ""  # nosec

        if wrapid_config and dict_config:
            raise ClientError(
                "Initiation error: dict_config and wrapid_config are mutually exclusive"
            )

        if wrapid_config:
            self._config = wrapid_config.cfg
        elif dict_config:
            self._config = dict_config

        if credentials:
            self._credential_options = {**credentials.options}  # type: ignore[dict-item]
            self._login_url = credentials.config.get("login_url", "")  # type: ignore[attr-defined]
            self._credential_config = credentials.config  # type: ignore[assignment]
            if isinstance(credentials, OAuth2Credentials):
                self._credential_body = credentials.credential_body

        try:
            resolved_type = AuthorizationType(authorization_type)
        except ValueError as error:
            raise ClientAuthorizationError(error) from error

        log.debug(
            "Initiating new client with authorization type %s and credential type %s",
            resolved_type.name,
            type(credentials).__name__,
        )

    def _unpack_jwt(self, token: str) -> dict:
        """Decodes and unpacks JWT tokens content

        Does not include signature verification or encrypted parts

        Args:
            token (str): Token to unpack

        Returns:
            dict: Unpacked JWT token
        """
        additionals: dict = {}

        jwt_secret = self._credential_options.get("jwt_secret", None)
        if jwt_secret:
            # The client configuration keeps precedence when it carries the
            # secret as well; otherwise use the one from the credentials
            # instead of failing on a missing configuration key.
            additionals["key"] = self._config.get("jwt_secret", jwt_secret)

        jwt_algorithms = self._credential_options.get("jwt_algorithms", None)
        if jwt_algorithms:
            additionals["algorithms"] = jwt_algorithms

        return jwt.decode(token, options={"verify_signature": False}, **additionals)

    def session_expired(self) -> bool:
        """Check if our session has expired

        Returns:
            bool: True if token is expired, False if still valid
        """
        if not self._access_token_expiry or not self._access_token:
            return True

        time_offset = datetime.now() + timedelta(
            seconds=self._config.get("token_expiry_offset", 10)
        )  # Offset to avoid ms/ns race condition

        if time_offset < self._access_token_expiry:
            return False

        return True

    def generate_session(self, method: str = "POST", **options: Any) -> None:
        """Authenticate and generate new token

        Sends the login request to the login URL of the credentials and parses
        the tokens out of the response.

        Args:
            method (str, optional): HTTP Method to use. Defaults to "POST".
            **options: Request options; these take precedence over the
                credential options and the configured auth_options.

        Raises:
            ClientAuthenticationError: If the login response is not a 2xx.
        """
        # Merge credential options (like auth) with any provided options
        login_options = {
            **self._credential_options,
            **self._config.get("auth_options", {}),
            **options,
        }

        if self._credential_body:
            login_options["data"] = self._credential_body

        response = self.call(
            method,
            str(self._login_url),
            raise_for_status=False,
            skip_authentication=True,
            **login_options,
        )

        if response.status_code > 299 or response.status_code < 200:
            log.error(
                "Unable to generate new session: [%s] %s @ %s",
                response.status_code,
                response.content,
                self._login_url,
            )
            raise ClientAuthenticationError(
                f"Unable to generate new session: [{response.status_code}] {response.content!r}"
            )

        self._parse_authentication_data(response)

    def _parse_authentication_data(  # pylint: disable=too-many-branches
        self, response: Response
    ) -> None:
        """Extract tokens and their expiry from a login response

        Args:
            response (Response): Response of the login request

        Raises:
            ClientAuthenticationError: If a configured token header is missing
                or the OAuth2 response body cannot be processed.
            ClientAuthorizationError: If a JWT/bearer access token is missing
                or already expired.
        """
        # Custom headers or custom bodies are common locations of bearer tokens.
        # We need to make this more dynamic later. Adding response Authorization header
        # and a few more for now

        # Typically used for custom x509 auth but also common for basic auth and custom
        # implementations
        if "Authorization" in response.headers:
            self._set_access_token(response.headers["Authorization"])

        # Custom configured header
        token_header = self._config.get("access_token_header", "")
        if token_header:
            if token_header in response.headers:
                self._set_access_token(response.headers[token_header])
            else:
                log.error(
                    "Unable to find access token header %s in response headers: %s",
                    token_header,
                    response.headers,
                )
                raise ClientAuthenticationError("Unable to find configured access token header")

        # Manage JWT data extraction
        if self._authorization_type == AuthorizationType.OAUTH2:
            try:
                auth_response_data = response.json()
                # Oauth2 implementations differ vastly. Some gives only access_tokens,
                # some give both at authentication, some give both at every refresh
                # some give only refresh_tokens for offline scopes. Splitting ifs to handle all.
                if "access_token" in auth_response_data:
                    self._set_access_token(auth_response_data["access_token"])
                    configured_timeout = self._config.get("access_token_timeout", 0)
                    if configured_timeout == 0:
                        expires_in = auth_response_data.get("expires_in")
                        if expires_in is None:
                            raise ClientAuthenticationError(
                                "Login response has no expires_in and "
                                "no access_token_timeout is configured"
                            )
                        expiry = time() + float(expires_in)
                    else:
                        expiry = time() + configured_timeout
                    self._set_access_token_expiry(expiry)

                if "refresh_token" in auth_response_data:
                    self._set_refresh_token(auth_response_data["refresh_token"])
                    # NOTE(review): 84600 looks like a typo for 86400 (24 h) - confirm
                    expiry = time() + self._config.get("refresh_token_timeout", 84600)
                    self._set_refresh_token_expiry(expiry)
            except ClientAuthenticationError:
                raise
            except Exception as error:
                # Covers undecodable bodies (ValueError) as well as unexpected payloads
                raise ClientAuthenticationError(error) from error

        if self._authorization_type in [
            AuthorizationType.JWT,
            AuthorizationType.BEARER,
        ]:
            if not self._access_token:
                # Without a token there is nothing to decode; fail with a client error
                raise ClientAuthorizationError("JWT Access token expired or could not be set")

            self._set_access_token_expiry(self._get_jwt_expiry(self._access_token))

            if self._access_token_expiry < datetime.now():
                raise ClientAuthorizationError("JWT Access token expired or could not be set")

    def _get_jwt_expiry(self, token: str) -> float:
        """Get JWT token expiry time

        The first of the claims exp, expiresIn, expires_in and expires that is
        present is used. Without any of them the current time is returned.

        Args:
            token (str): JWT token

        Returns:
            float: Expiry time as unix timestamp
        """
        jwt_data = self._unpack_jwt(token)
        log.debug("JWT data: %s", jwt_data)

        expiry = time()
        for exp in ["exp", "expiresIn", "expires_in", "expires"]:  # standard = exp
            if exp in jwt_data:
                if jwt_data[exp] < 44640:
                    # Handle poor expiry implementations with offset:
                    # small values are added to the current time
                    expiry = expiry + jwt_data[exp]
                else:
                    expiry = jwt_data[exp]
                break

        return expiry

    # flake8: noqa: C901
    def call(
        self,
        method: str,
        url: str,
        raise_for_status: bool = False,
        skip_authentication: bool = False,
        **options: Any,
    ) -> Response:
        """Send web request to the target url

        Args:
            method (str): Method of the HTTP request
            url (str): URL of the request
            raise_for_status (bool): Raise for non 2xx responses
            skip_authentication (bool): Skip authentication and skip token refresh controls
            **options (dict): request options

        Raises:
            ClientHTTPError
            ClientTimeout
            ClientConnectionError
            ClientError
            ClientAuthenticationError

        Returns:
            Response: requests.Response object
        """
        # Without an authorization type and without a login URL there is
        # nothing to authenticate against; a login attempt would only fail on
        # the empty URL and block every request of a plain client.
        session_applicable = (
            self._authorization_type != AuthorizationType.NONE or bool(self._login_url)
        )

        if not skip_authentication and session_applicable and self.session_expired():
            login_options = {**self._credential_options, **self._config.get("auth_options", {})}
            self.generate_session(**login_options)

        if self._access_token and self._authorization_type != AuthorizationType.NONE:
            if "headers" not in options:
                options["headers"] = {"Authorization": f"Bearer {self._access_token}"}
            elif "Authorization" not in options["headers"]:
                # An Authorization header passed by the caller is left untouched
                options["headers"] = {
                    "Authorization": f"Bearer {self._access_token}",
                    **options["headers"],
                }

        if "client_options" in self.get_config:
            # Explicitly passed options win over the configured client options
            options = {**self.get_config["client_options"], **options}

        try:
            response = request(method, url, **options)
            if raise_for_status:
                response.raise_for_status()
        except HTTPError as error:
            raise ClientHTTPError(error) from error
        except Timeout as error:
            raise ClientTimeout(error) from error
        except TooManyRedirects as error:
            raise ClientConnectionError(error) from error
        except RequestException as error:
            raise ClientError(error) from error

        return response

    @property
    def get_config(self) -> dict:
        """Get current configuration

        Returns:
            configuration {dict} -- Dict representation of configuration"""
        return self._config

    def _set_refresh_token(self, refresh_token: str) -> None:
        """Set refresh token"""
        self._refresh_token = refresh_token
        # The token value is a secret and is deliberately kept out of the log
        log.debug("Refresh token set")

    def _set_refresh_token_expiry(self, refresh_expiry: float) -> None:
        """Set refresh token expiry time from a unix timestamp"""
        self._refresh_token_expiry = datetime.fromtimestamp(refresh_expiry)
        log.debug("Refresh token expiry set to: %s", self._refresh_token_expiry)

    def _set_access_token(self, access_token: str) -> None:
        """Set access token, stripping a leading "Bearer " prefix"""
        # Following 4 lines can be simplified with token.removeprefix("Bearer ")
        # for python version > 3.9
        # Keeping for backwards compatibility for now
        bearer = "Bearer "
        if access_token.startswith(bearer):
            access_token = access_token[len(bearer) :]

        self._access_token = access_token
        # The token value is a secret and is deliberately kept out of the log
        log.debug("Access token set")

    def _set_access_token_expiry(self, access_expiry: float) -> None:
        """Set access token expiry time from a unix timestamp"""
        self._access_token_expiry = datetime.fromtimestamp(access_expiry)
        log.debug("Access token expiry set to: %s", self._access_token_expiry)