From bcd877266f0c65e00d488f3a128cac8c0e5ee4fd Mon Sep 17 00:00:00 2001 From: jo Date: Sat, 30 Jul 2022 21:15:39 +0200 Subject: [PATCH] feat(shared): load env config using jsonschema The env loader is now capable of loading lists of objects, union types or list of union types from the env variables. They are some limitations: for example it doesn't support unions of different shapes `list | dict` or `str | dict`. --- shared/libretime_shared/config/_base.py | 92 ++++---- shared/libretime_shared/config/_env.py | 266 ++++++++++++++++++++++++ shared/tests/config/env_test.py | 189 +++++++++++++++++ 3 files changed, 500 insertions(+), 47 deletions(-) create mode 100644 shared/libretime_shared/config/_env.py create mode 100644 shared/tests/config/env_test.py diff --git a/shared/libretime_shared/config/_base.py b/shared/libretime_shared/config/_base.py index ac9ba09cf..40c0c652a 100644 --- a/shared/libretime_shared/config/_base.py +++ b/shared/libretime_shared/config/_base.py @@ -1,16 +1,14 @@ import sys -from os import environ +from itertools import zip_longest from pathlib import Path from typing import Any, Dict, List, Optional, Union from loguru import logger - -# pylint: disable=no-name-in-module from pydantic import BaseModel, ValidationError -from pydantic.fields import ModelField -from pydantic.utils import deep_update from yaml import YAMLError, safe_load +from ._env import EnvLoader + DEFAULT_ENV_PREFIX = "LIBRETIME" DEFAULT_CONFIG_FILEPATH = Path("/etc/libretime/config.yml") @@ -36,54 +34,19 @@ class BaseConfig(BaseModel): if filepath is not None: filepath = Path(filepath) - file_values = self._load_file_values(filepath) - env_values = self._load_env_values(env_prefix, env_delimiter) + env_loader = EnvLoader(self.schema(), env_prefix, env_delimiter) + + values = deep_merge_dict( + self._load_file_values(filepath), + env_loader.load(), + ) try: - super().__init__(**deep_update(file_values, env_values)) + super().__init__(**values) except ValidationError as error: logger.critical(error) sys.exit(1) - def _load_env_values(self, env_prefix: str, env_delimiter: str) -> Dict[str, Any]: - return self._get_fields_from_env(env_prefix, env_delimiter, self.__fields__) - - def _get_fields_from_env( - self, - env_prefix: str, - env_delimiter: str, - fields: Dict[str, ModelField], - ) -> Dict[str, Any]: - result: Dict[str, Any] = {} - - if env_prefix != "": - env_prefix += env_delimiter - - for field in fields.values(): - env_name = (env_prefix + field.name).upper() - - if field.is_complex(): - children: Union[List[Any], Dict[str, Any]] = [] - - if field.sub_fields: - if env_name in environ: - children = [v.strip() for v in environ[env_name].split(",")] - - else: - children = self._get_fields_from_env( - env_name, - env_delimiter, - field.type_.__fields__, - ) - - if len(children) != 0: - result[field.name] = children - else: - if env_name in environ: - result[field.name] = environ[env_name] - - return result - def _load_file_values( self, filepath: Optional[Path] = None, @@ -102,3 +65,38 @@ class BaseConfig(BaseModel): logger.error(f"config file '{filepath}' is not a valid yaml file: {error}") return {} + + +def deep_merge_dict(base: Dict[str, Any], next_: Dict[str, Any]) -> Dict[str, Any]: + result = base.copy() + for key, value in next_.items(): + if key in result: + if isinstance(result[key], dict) and isinstance(value, dict): + result[key] = deep_merge_dict(result[key], value) + continue + + if isinstance(result[key], list) and isinstance(value, list): + result[key] = deep_merge_list(result[key], value) + continue + + if value: + result[key] = value + + return result + + +def deep_merge_list(base: List[Any], next_: List[Any]) -> List[Any]: + result: List[Any] = [] + for base_item, next_item in zip_longest(base, next_): + if isinstance(base_item, list) and isinstance(next_item, list): + result.append(deep_merge_list(base_item, next_item)) + continue + + if isinstance(base_item, dict) and isinstance(next_item, dict): + result.append(deep_merge_dict(base_item, next_item)) + continue + + if next_item: + result.append(next_item) + + return result diff --git a/shared/libretime_shared/config/_env.py b/shared/libretime_shared/config/_env.py new file mode 100644 index 000000000..d8a42a649 --- /dev/null +++ b/shared/libretime_shared/config/_env.py @@ -0,0 +1,266 @@ +from collections import ChainMap +from functools import reduce +from operator import getitem +from os import environ +from typing import Any, Dict, List, Optional, TypeVar + +__all__ = [ + "EnvLoader", +] + + +def filter_env(env: Dict[str, str], prefix: str) -> Dict[str, str]: + """ + Filter a environment variables dict by key prefix. + + Args: + env: Environment variables dict. + prefix: Environment variable key prefix. + + Returns: + Environment variables dict. + """ + return {k: v for k, v in env.items() if k.startswith(prefix)} + + +def guess_env_array_indexes(env: Dict[str, str], prefix: str) -> List[int]: + """ + Guess environment variables indexes from the environment variables keys. + + Args: + env: Environment variables dict. + prefix: Environment variable key prefix for all indexes. + + Returns: + A list of indexes. + """ + prefix_len = len(prefix) + + result = [] + for env_name in filter_env(env, prefix): + if not env_name[prefix_len].isdigit(): + continue + + index_str = env_name[prefix_len:] + index_str = index_str.partition("_")[0] + result.append(int(index_str)) + + return result + + +T = TypeVar("T") + + +def index_dict_to_none_list(base: Dict[int, T]) -> List[Optional[T]]: + """ + Convert a dict to a list by associating the dict keys to the list + indexes and filling the missing indexes with None. + + Args: + base: Dict to convert. + + Returns: + Converted dict. + """ + if not base: + return [] + + result: List[Optional[T]] = [None] * (max(base.keys()) + 1) + + for index, value in base.items(): + result[index] = value + + return result + + +# pylint: disable=too-few-public-methods +class EnvLoader: + schema: dict + + env_prefix: str + env_delimiter: str + + _env: Dict[str, str] + + def __init__( + self, + schema: dict, + env_prefix: Optional[str] = None, + env_delimiter: str = "_", + ) -> None: + self.schema = schema + self.env_prefix = env_prefix or "" + self.env_delimiter = env_delimiter + + self._env = environ.copy() + if self.env_prefix: + self._env = filter_env(self._env, self.env_prefix) + + def load(self) -> Dict[str, Any]: + if not self._env: + return {} + + return self._get(self.env_prefix, self.schema) + + def _resolve_ref( + self, + path: str, + ) -> Dict[str, Any]: + _, *parts = path.split("/") + return reduce(getitem, parts, self.schema) + + def _get_mapping( + self, + env_name: str, + *schemas: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Get a mapping of each subtypes with the data. + + This helps resolve conflicts after we have all the data. + + Args: + env_name: Environment variable name to get the data from. + + Returns: + Mapping of each subtypes, with associated data as value. + """ + mapping: Dict[str, Any] = {} + + for schema in schemas: + if "$ref" in schema: + schema = self._resolve_ref(schema["$ref"]) + + value = self._get(env_name, schema) + if not value: + continue + + key = "title" if "title" in schema else "type" + mapping[schema[key]] = value + + return mapping + + # pylint: disable=too-many-return-statements + def _get( + self, + env_name: str, + schema: Dict[str, Any], + ) -> Any: + """ + Get a value from the environment. + + Args: + env_name: Environment variable name. + schema: Schema for the value we are retrieving. + + Returns: + Value retrieved from the environment. + """ + + if "$ref" in schema: + schema = self._resolve_ref(schema["$ref"]) + + if "type" in schema: + if schema["type"] in ("string", "integer", "boolean"): + return self._env.get(env_name, None) + + if schema["type"] == "object": + return self._get_object(env_name, schema) + + if schema["type"] == "array": + return self._get_array(env_name, schema) + + # Get all the properties as we won't have typing conflicts + if "allOf" in schema: + all_of_mapping = self._get_mapping(env_name, *schema["allOf"]) + # Merging all subtypes data together + return dict(ChainMap(*all_of_mapping.values())) + + # Get all the properties and resolve conflicts after + if "anyOf" in schema: + any_of_mapping = self._get_mapping(env_name, *schema["anyOf"]) + if any_of_mapping: + any_of_values = list(any_of_mapping.values()) + + # If all subtypes are primary types, return the first subtype data + if all(isinstance(value, str) for value in any_of_values): + return any_of_values[0] + + # If all subtypes are dicts, merge the subtypes data in a single dict. + # Do not worry if subtypes share a field name, as the value is from a + # single environment variable and will have the same value. + if all(isinstance(value, dict) for value in any_of_values): + return dict(ChainMap(*any_of_values)) + + return None + + raise ValueError(f"{env_name}: unhandled schema {schema}") + + def _get_object( + self, + env_name: str, + schema: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Get an object from the environment. + + Args: + env_name: Environment variable name. + schema: Schema for the value we are retrieving. + + Returns: + Value retrieved from the environment. + """ + result: Dict[str, Any] = {} + + if env_name != "": + env_name += self.env_delimiter + + for child_key, child_schema in schema["properties"].items(): + child_env_name = (env_name + child_key).upper() + + value = self._get(child_env_name, child_schema) + if value: + result[child_key] = value + + return result + + # pylint: disable=too-many-branches + def _get_array( + self, + env_parent: str, + schema: Dict[str, Any], + ) -> Optional[List[Any]]: + """ + Get an array from the environment. + + Args: + env_name: Environment variable name. + schema: Schema for the value we are retrieving. + + Returns: + Value retrieved from the environment. + """ + result: Dict[int, Any] = {} + + schema_items = schema["items"] + if "$ref" in schema_items: + schema_items = self._resolve_ref(schema_items["$ref"]) + + # Found a environment variable without index suffix, try + # to extract CSV formatted array + if env_parent in self._env: + values = self._get(env_parent, schema_items) + if values: + for index, value in enumerate(values.split(",")): + result[index] = value.strip() + + indexes = guess_env_array_indexes(self._env, env_parent + self.env_delimiter) + if indexes: + for index in indexes: + env_name = env_parent + self.env_delimiter + str(index) + value = self._get(env_name, schema_items) + if value: + result[index] = value + + return index_dict_to_none_list(result) diff --git a/shared/tests/config/env_test.py b/shared/tests/config/env_test.py new file mode 100644 index 000000000..bd5fc62a3 --- /dev/null +++ b/shared/tests/config/env_test.py @@ -0,0 +1,189 @@ +# pylint: disable=protected-access +from os import environ +from typing import List, Union +from unittest import mock + +import pytest +from pydantic import BaseModel + +from libretime_shared.config import BaseConfig +from libretime_shared.config._env import EnvLoader + +ENV_SCHEMA_OBJ_WITH_STR = { + "type": "object", + "properties": {"a_str": {"type": "string"}}, +} + + +@pytest.mark.parametrize( + "env_parent, env, schema, expected", + [ + ( + "PRE", + {"PRE_A_STR": "found"}, + {"a_str": {"type": "string"}}, + {"a_str": "found"}, + ), + ( + "PRE", + {"PRE_OBJ_A_STR": "found"}, + {"obj": ENV_SCHEMA_OBJ_WITH_STR}, + {"obj": {"a_str": "found"}}, + ), + ( + "PRE", + {"PRE_ARR1": "one, two"}, + {"arr1": {"type": "array", "items": {"type": "string"}}}, + {"arr1": ["one", "two"]}, + ), + ( + "PRE", + { + "PRE_ARR2_0_A_STR": "one", + "PRE_ARR2_1_A_STR": "two", + "PRE_ARR2_3_A_STR": "ten", + }, + {"arr2": {"type": "array", "items": ENV_SCHEMA_OBJ_WITH_STR}}, + { + "arr2": [ + {"a_str": "one"}, + {"a_str": "two"}, + None, + {"a_str": "ten"}, + ] + }, + ), + ], +) +def test_env_config_loader_get_object( + env_parent, + env, + schema, + expected, +): + with mock.patch.dict(environ, env): + loader = EnvLoader(schema={}, env_prefix="PRE") + result = loader._get_object(env_parent, {"properties": schema}) + assert result == expected + + +class FirstChildConfig(BaseModel): + a_child_str: str + + +class SecondChildConfig(BaseModel): + a_child_str: str + a_child_int: int + + +# pylint: disable=too-few-public-methods +class FixtureConfig(BaseConfig): + a_str: str + a_list_of_str: List[str] + a_obj: FirstChildConfig + a_obj_with_default: FirstChildConfig = FirstChildConfig(a_child_str="default") + a_list_of_obj: List[FirstChildConfig] + a_union_str_or_int: Union[str, int] + a_union_obj: Union[FirstChildConfig, SecondChildConfig] + a_list_of_union_str_or_int: List[Union[str, int]] + a_list_of_union_obj: List[Union[FirstChildConfig, SecondChildConfig]] + + +ENV_SCHEMA = FixtureConfig.schema() + + +@pytest.mark.parametrize( + "env_name, env, schema, expected", + [ + ( + "PRE_A_STR", + {"PRE_A_STR": "found"}, + ENV_SCHEMA["properties"]["a_str"], + "found", + ), + ( + "PRE_A_LIST_OF_STR", + {"PRE_A_LIST_OF_STR": "one, two"}, + ENV_SCHEMA["properties"]["a_list_of_str"], + ["one", "two"], + ), + ( + "PRE_A_OBJ", + {"PRE_A_OBJ_A_CHILD_STR": "found"}, + ENV_SCHEMA["properties"]["a_obj"], + {"a_child_str": "found"}, + ), + ], +) +def test_env_config_loader_get( + env_name, + env, + schema, + expected, +): + with mock.patch.dict(environ, env): + loader = EnvLoader(schema=ENV_SCHEMA, env_prefix="PRE") + result = loader._get(env_name, schema) + assert result == expected + + +def test_env_config_loader_load_empty(): + with mock.patch.dict(environ, {}): + loader = EnvLoader(schema=ENV_SCHEMA, env_prefix="PRE") + result = loader.load() + assert not result + + +def test_env_config_loader_load(): + with mock.patch.dict( + environ, + { + "PRE_A_STR": "found", + "PRE_A_LIST_OF_STR": "one, two", + "PRE_A_OBJ": "invalid", + "PRE_A_OBJ_A_CHILD_STR": "found", + "PRE_A_OBJ_WITH_DEFAULT_A_CHILD_STR": "found", + "PRE_A_LIST_OF_OBJ": "invalid", + "PRE_A_LIST_OF_OBJ_0_A_CHILD_STR": "found", + "PRE_A_LIST_OF_OBJ_1_A_CHILD_STR": "found", + "PRE_A_LIST_OF_OBJ_3_A_CHILD_STR": "found", + "PRE_A_LIST_OF_OBJ_INVALID": "invalid", + "PRE_A_UNION_STR_OR_INT": "found", + "PRE_A_UNION_OBJ_A_CHILD_STR": "found", + "PRE_A_UNION_OBJ_A_CHILD_INT": "found", + "PRE_A_LIST_OF_UNION_STR_OR_INT": "one, two, 3", + "PRE_A_LIST_OF_UNION_STR_OR_INT_3": "4", + "PRE_A_LIST_OF_UNION_OBJ": "invalid", + "PRE_A_LIST_OF_UNION_OBJ_0_A_CHILD_STR": "found", + "PRE_A_LIST_OF_UNION_OBJ_1_A_CHILD_STR": "found", + "PRE_A_LIST_OF_UNION_OBJ_1_A_CHILD_INT": "found", + "PRE_A_LIST_OF_UNION_OBJ_3_A_CHILD_INT": "found", + "PRE_A_LIST_OF_UNION_OBJ_INVALID": "invalid", + }, + ): + loader = EnvLoader(schema=ENV_SCHEMA, env_prefix="PRE") + result = loader.load() + assert result == { + "a_str": "found", + "a_list_of_str": ["one", "two"], + "a_obj": {"a_child_str": "found"}, + "a_obj_with_default": {"a_child_str": "found"}, + "a_list_of_obj": [ + {"a_child_str": "found"}, + {"a_child_str": "found"}, + None, + {"a_child_str": "found"}, + ], + "a_union_str_or_int": "found", + "a_union_obj": { + "a_child_str": "found", + "a_child_int": "found", + }, + "a_list_of_union_str_or_int": ["one", "two", "3", "4"], + "a_list_of_union_obj": [ + {"a_child_str": "found"}, + {"a_child_str": "found", "a_child_int": "found"}, + None, + {"a_child_int": "found"}, + ], + }