feat(shared): load env config using jsonschema

The env loader is now capable of loading lists of objects, union types,
or lists of union types from environment variables.
There are some limitations: for example, it does not support unions of
different shapes such as `list | dict` or `str | dict`.
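For illustration (the model and field names below are made up, but they
mirror the test fixtures added in this commit, and the default
`LIBRETIME` prefix comes from the shared config module), a nested
config such as:

    from typing import List
    from pydantic import BaseModel
    from libretime_shared.config import BaseConfig

    class ChildConfig(BaseModel):
        a_child_str: str

    class Config(BaseConfig):
        a_list_of_str: List[str]
        a_list_of_obj: List[ChildConfig]

can be filled from environment variables, using a comma separated value
for simple lists and an index suffix for lists of objects:

    LIBRETIME_A_LIST_OF_STR="one, two"
    LIBRETIME_A_LIST_OF_OBJ_0_A_CHILD_STR="first"
    LIBRETIME_A_LIST_OF_OBJ_1_A_CHILD_STR="second"
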
jo 2022-07-30 21:15:39 +02:00 committed by Kyle Robbertze
parent 6c449e3019
commit bcd877266f
3 changed files with 500 additions and 47 deletions

View File

@@ -1,16 +1,14 @@
import sys
from os import environ
from itertools import zip_longest
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from loguru import logger
# pylint: disable=no-name-in-module
from pydantic import BaseModel, ValidationError
from pydantic.fields import ModelField
from pydantic.utils import deep_update
from yaml import YAMLError, safe_load
from ._env import EnvLoader
DEFAULT_ENV_PREFIX = "LIBRETIME"
DEFAULT_CONFIG_FILEPATH = Path("/etc/libretime/config.yml")
@@ -36,54 +34,19 @@ class BaseConfig(BaseModel):
if filepath is not None:
filepath = Path(filepath)
file_values = self._load_file_values(filepath)
env_values = self._load_env_values(env_prefix, env_delimiter)
env_loader = EnvLoader(self.schema(), env_prefix, env_delimiter)
values = deep_merge_dict(
self._load_file_values(filepath),
env_loader.load(),
)
try:
super().__init__(**deep_update(file_values, env_values))
super().__init__(**values)
except ValidationError as error:
logger.critical(error)
sys.exit(1)
def _load_env_values(self, env_prefix: str, env_delimiter: str) -> Dict[str, Any]:
return self._get_fields_from_env(env_prefix, env_delimiter, self.__fields__)
def _get_fields_from_env(
self,
env_prefix: str,
env_delimiter: str,
fields: Dict[str, ModelField],
) -> Dict[str, Any]:
result: Dict[str, Any] = {}
if env_prefix != "":
env_prefix += env_delimiter
for field in fields.values():
env_name = (env_prefix + field.name).upper()
if field.is_complex():
children: Union[List[Any], Dict[str, Any]] = []
if field.sub_fields:
if env_name in environ:
children = [v.strip() for v in environ[env_name].split(",")]
else:
children = self._get_fields_from_env(
env_name,
env_delimiter,
field.type_.__fields__,
)
if len(children) != 0:
result[field.name] = children
else:
if env_name in environ:
result[field.name] = environ[env_name]
return result
def _load_file_values(
self,
filepath: Optional[Path] = None,
@@ -102,3 +65,38 @@ class BaseConfig(BaseModel):
logger.error(f"config file '{filepath}' is not a valid yaml file: {error}")
return {}
def deep_merge_dict(base: Dict[str, Any], next_: Dict[str, Any]) -> Dict[str, Any]:
result = base.copy()
for key, value in next_.items():
if key in result:
if isinstance(result[key], dict) and isinstance(value, dict):
result[key] = deep_merge_dict(result[key], value)
continue
if isinstance(result[key], list) and isinstance(value, list):
result[key] = deep_merge_list(result[key], value)
continue
if value:
result[key] = value
return result
def deep_merge_list(base: List[Any], next_: List[Any]) -> List[Any]:
result: List[Any] = []
for base_item, next_item in zip_longest(base, next_):
if isinstance(base_item, list) and isinstance(next_item, list):
result.append(deep_merge_list(base_item, next_item))
continue
if isinstance(base_item, dict) and isinstance(next_item, dict):
result.append(deep_merge_dict(base_item, next_item))
continue
if next_item:
result.append(next_item)
return result
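A rough sketch of the merge semantics above (the values are made up):
the second dict acts as the override, nested dicts are merged
recursively, and truthy override values win.

    base = {"api": {"key": "file", "port": 8080}, "hosts": ["a", "b"]}
    override = {"api": {"key": "env"}, "hosts": ["x"]}
    assert deep_merge_dict(base, override) == {
        "api": {"key": "env", "port": 8080},
        "hosts": ["x"],
    }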

View File

@@ -0,0 +1,266 @@
from collections import ChainMap
from functools import reduce
from operator import getitem
from os import environ
from typing import Any, Dict, List, Optional, TypeVar
__all__ = [
"EnvLoader",
]
def filter_env(env: Dict[str, str], prefix: str) -> Dict[str, str]:
"""
Filter an environment variables dict by key prefix.
Args:
env: Environment variables dict.
prefix: Environment variable key prefix.
Returns:
Environment variables dict.
"""
return {k: v for k, v in env.items() if k.startswith(prefix)}
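# Illustrative example (not part of the original module): only keys
# starting with the prefix are kept.
# filter_env({"PRE_A": "1", "OTHER": "2"}, "PRE") == {"PRE_A": "1"}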
def guess_env_array_indexes(env: Dict[str, str], prefix: str) -> List[int]:
"""
Guess array indexes from the environment variable keys.
Args:
env: Environment variables dict.
prefix: Environment variable key prefix for all indexes.
Returns:
A list of indexes.
"""
prefix_len = len(prefix)
result = []
for env_name in filter_env(env, prefix):
if not env_name[prefix_len].isdigit():
continue
index_str = env_name[prefix_len:]
index_str = index_str.partition("_")[0]
result.append(int(index_str))
return result
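# Illustrative example (not part of the original module): with the prefix
# "PRE_ARR_", the index is the digit segment right after the prefix.
# guess_env_array_indexes(
#     {"PRE_ARR_0_A_STR": "x", "PRE_ARR_3_A_STR": "y"}, "PRE_ARR_"
# ) == [0, 3]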
T = TypeVar("T")
def index_dict_to_none_list(base: Dict[int, T]) -> List[Optional[T]]:
"""
Convert a dict to a list by mapping the dict keys to list indexes
and filling the missing indexes with None.
Args:
base: Dict to convert.
Returns:
Converted list, with None for missing indexes.
"""
if not base:
return []
result: List[Optional[T]] = [None] * (max(base.keys()) + 1)
for index, value in base.items():
result[index] = value
return result
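# Illustrative example (not part of the original module):
# index_dict_to_none_list({0: "a", 2: "c"}) == ["a", None, "c"]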
# pylint: disable=too-few-public-methods
class EnvLoader:
schema: dict
env_prefix: str
env_delimiter: str
_env: Dict[str, str]
def __init__(
self,
schema: dict,
env_prefix: Optional[str] = None,
env_delimiter: str = "_",
) -> None:
self.schema = schema
self.env_prefix = env_prefix or ""
self.env_delimiter = env_delimiter
self._env = environ.copy()
if self.env_prefix:
self._env = filter_env(self._env, self.env_prefix)
def load(self) -> Dict[str, Any]:
if not self._env:
return {}
return self._get(self.env_prefix, self.schema)
def _resolve_ref(
self,
path: str,
) -> Dict[str, Any]:
_, *parts = path.split("/")
return reduce(getitem, parts, self.schema)
def _get_mapping(
self,
env_name: str,
*schemas: Dict[str, Any],
) -> Dict[str, Any]:
"""
Get a mapping of each subtype to its resolved data.
This helps resolve conflicts after all the data has been gathered.
Args:
env_name: Environment variable name to get the data from.
schemas: Subtype schemas to try.
Returns:
Mapping of each subtype, with the associated data as value.
"""
mapping: Dict[str, Any] = {}
for schema in schemas:
if "$ref" in schema:
schema = self._resolve_ref(schema["$ref"])
value = self._get(env_name, schema)
if not value:
continue
key = "title" if "title" in schema else "type"
mapping[schema[key]] = value
return mapping
# pylint: disable=too-many-return-statements
def _get(
self,
env_name: str,
schema: Dict[str, Any],
) -> Any:
"""
Get a value from the environment.
Args:
env_name: Environment variable name.
schema: Schema for the value we are retrieving.
Returns:
Value retrieved from the environment.
"""
if "$ref" in schema:
schema = self._resolve_ref(schema["$ref"])
if "type" in schema:
if schema["type"] in ("string", "integer", "boolean"):
return self._env.get(env_name, None)
if schema["type"] == "object":
return self._get_object(env_name, schema)
if schema["type"] == "array":
return self._get_array(env_name, schema)
# Get all the properties as we won't have typing conflicts
if "allOf" in schema:
all_of_mapping = self._get_mapping(env_name, *schema["allOf"])
# Merging all subtypes data together
return dict(ChainMap(*all_of_mapping.values()))
# Get all the properties and resolve conflicts after
if "anyOf" in schema:
any_of_mapping = self._get_mapping(env_name, *schema["anyOf"])
if any_of_mapping:
any_of_values = list(any_of_mapping.values())
# If all subtypes are primitive types, return the first subtype's data
if all(isinstance(value, str) for value in any_of_values):
return any_of_values[0]
# If all subtypes are dicts, merge the subtypes data in a single dict.
# Do not worry if subtypes share a field name, as the value is from a
# single environment variable and will have the same value.
if all(isinstance(value, dict) for value in any_of_values):
return dict(ChainMap(*any_of_values))
return None
raise ValueError(f"{env_name}: unhandled schema {schema}")
def _get_object(
self,
env_name: str,
schema: Dict[str, Any],
) -> Dict[str, Any]:
"""
Get an object from the environment.
Args:
env_name: Environment variable name.
schema: Schema for the value we are retrieving.
Returns:
Value retrieved from the environment.
"""
result: Dict[str, Any] = {}
if env_name != "":
env_name += self.env_delimiter
for child_key, child_schema in schema["properties"].items():
child_env_name = (env_name + child_key).upper()
value = self._get(child_env_name, child_schema)
if value:
result[child_key] = value
return result
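# Illustrative example (not part of the original module): with
# PRE_OBJ_A_STR=found in the environment,
# _get_object("PRE_OBJ", {"properties": {"a_str": {"type": "string"}}})
# returns {"a_str": "found"}.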
# pylint: disable=too-many-branches
def _get_array(
self,
env_parent: str,
schema: Dict[str, Any],
) -> Optional[List[Any]]:
"""
Get an array from the environment.
Args:
env_parent: Environment variable name prefix, without any index suffix.
schema: Schema for the value we are retrieving.
Returns:
Value retrieved from the environment.
"""
result: Dict[int, Any] = {}
schema_items = schema["items"]
if "$ref" in schema_items:
schema_items = self._resolve_ref(schema_items["$ref"])
# Found an environment variable without an index suffix, try
# to extract a CSV formatted array
if env_parent in self._env:
values = self._get(env_parent, schema_items)
if values:
for index, value in enumerate(values.split(",")):
result[index] = value.strip()
indexes = guess_env_array_indexes(self._env, env_parent + self.env_delimiter)
if indexes:
for index in indexes:
env_name = env_parent + self.env_delimiter + str(index)
value = self._get(env_name, schema_items)
if value:
result[index] = value
return index_dict_to_none_list(result)
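A hedged, minimal usage sketch of the loader above (the model and
variable names are made up, mirroring the test fixtures below): given a
pydantic model's JSON schema, flat environment variables are resolved
into nested values, including indexed list entries.

    from os import environ
    from typing import List
    from unittest import mock

    from pydantic import BaseModel

    from libretime_shared.config._env import EnvLoader

    class Child(BaseModel):
        a_child_str: str

    class Config(BaseModel):
        a_str: str
        a_list_of_obj: List[Child]

    with mock.patch.dict(
        environ,
        {
            "PRE_A_STR": "found",
            "PRE_A_LIST_OF_OBJ_0_A_CHILD_STR": "first",
            "PRE_A_LIST_OF_OBJ_1_A_CHILD_STR": "second",
        },
    ):
        loader = EnvLoader(schema=Config.schema(), env_prefix="PRE")
        assert loader.load() == {
            "a_str": "found",
            "a_list_of_obj": [
                {"a_child_str": "first"},
                {"a_child_str": "second"},
            ],
        }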

View File

@@ -0,0 +1,189 @@
# pylint: disable=protected-access
from os import environ
from typing import List, Union
from unittest import mock
import pytest
from pydantic import BaseModel
from libretime_shared.config import BaseConfig
from libretime_shared.config._env import EnvLoader
ENV_SCHEMA_OBJ_WITH_STR = {
"type": "object",
"properties": {"a_str": {"type": "string"}},
}
@pytest.mark.parametrize(
"env_parent, env, schema, expected",
[
(
"PRE",
{"PRE_A_STR": "found"},
{"a_str": {"type": "string"}},
{"a_str": "found"},
),
(
"PRE",
{"PRE_OBJ_A_STR": "found"},
{"obj": ENV_SCHEMA_OBJ_WITH_STR},
{"obj": {"a_str": "found"}},
),
(
"PRE",
{"PRE_ARR1": "one, two"},
{"arr1": {"type": "array", "items": {"type": "string"}}},
{"arr1": ["one", "two"]},
),
(
"PRE",
{
"PRE_ARR2_0_A_STR": "one",
"PRE_ARR2_1_A_STR": "two",
"PRE_ARR2_3_A_STR": "ten",
},
{"arr2": {"type": "array", "items": ENV_SCHEMA_OBJ_WITH_STR}},
{
"arr2": [
{"a_str": "one"},
{"a_str": "two"},
None,
{"a_str": "ten"},
]
},
),
],
)
def test_env_config_loader_get_object(
env_parent,
env,
schema,
expected,
):
with mock.patch.dict(environ, env):
loader = EnvLoader(schema={}, env_prefix="PRE")
result = loader._get_object(env_parent, {"properties": schema})
assert result == expected
class FirstChildConfig(BaseModel):
a_child_str: str
class SecondChildConfig(BaseModel):
a_child_str: str
a_child_int: int
# pylint: disable=too-few-public-methods
class FixtureConfig(BaseConfig):
a_str: str
a_list_of_str: List[str]
a_obj: FirstChildConfig
a_obj_with_default: FirstChildConfig = FirstChildConfig(a_child_str="default")
a_list_of_obj: List[FirstChildConfig]
a_union_str_or_int: Union[str, int]
a_union_obj: Union[FirstChildConfig, SecondChildConfig]
a_list_of_union_str_or_int: List[Union[str, int]]
a_list_of_union_obj: List[Union[FirstChildConfig, SecondChildConfig]]
ENV_SCHEMA = FixtureConfig.schema()
@pytest.mark.parametrize(
"env_name, env, schema, expected",
[
(
"PRE_A_STR",
{"PRE_A_STR": "found"},
ENV_SCHEMA["properties"]["a_str"],
"found",
),
(
"PRE_A_LIST_OF_STR",
{"PRE_A_LIST_OF_STR": "one, two"},
ENV_SCHEMA["properties"]["a_list_of_str"],
["one", "two"],
),
(
"PRE_A_OBJ",
{"PRE_A_OBJ_A_CHILD_STR": "found"},
ENV_SCHEMA["properties"]["a_obj"],
{"a_child_str": "found"},
),
],
)
def test_env_config_loader_get(
env_name,
env,
schema,
expected,
):
with mock.patch.dict(environ, env):
loader = EnvLoader(schema=ENV_SCHEMA, env_prefix="PRE")
result = loader._get(env_name, schema)
assert result == expected
def test_env_config_loader_load_empty():
with mock.patch.dict(environ, {}):
loader = EnvLoader(schema=ENV_SCHEMA, env_prefix="PRE")
result = loader.load()
assert not result
def test_env_config_loader_load():
with mock.patch.dict(
environ,
{
"PRE_A_STR": "found",
"PRE_A_LIST_OF_STR": "one, two",
"PRE_A_OBJ": "invalid",
"PRE_A_OBJ_A_CHILD_STR": "found",
"PRE_A_OBJ_WITH_DEFAULT_A_CHILD_STR": "found",
"PRE_A_LIST_OF_OBJ": "invalid",
"PRE_A_LIST_OF_OBJ_0_A_CHILD_STR": "found",
"PRE_A_LIST_OF_OBJ_1_A_CHILD_STR": "found",
"PRE_A_LIST_OF_OBJ_3_A_CHILD_STR": "found",
"PRE_A_LIST_OF_OBJ_INVALID": "invalid",
"PRE_A_UNION_STR_OR_INT": "found",
"PRE_A_UNION_OBJ_A_CHILD_STR": "found",
"PRE_A_UNION_OBJ_A_CHILD_INT": "found",
"PRE_A_LIST_OF_UNION_STR_OR_INT": "one, two, 3",
"PRE_A_LIST_OF_UNION_STR_OR_INT_3": "4",
"PRE_A_LIST_OF_UNION_OBJ": "invalid",
"PRE_A_LIST_OF_UNION_OBJ_0_A_CHILD_STR": "found",
"PRE_A_LIST_OF_UNION_OBJ_1_A_CHILD_STR": "found",
"PRE_A_LIST_OF_UNION_OBJ_1_A_CHILD_INT": "found",
"PRE_A_LIST_OF_UNION_OBJ_3_A_CHILD_INT": "found",
"PRE_A_LIST_OF_UNION_OBJ_INVALID": "invalid",
},
):
loader = EnvLoader(schema=ENV_SCHEMA, env_prefix="PRE")
result = loader.load()
assert result == {
"a_str": "found",
"a_list_of_str": ["one", "two"],
"a_obj": {"a_child_str": "found"},
"a_obj_with_default": {"a_child_str": "found"},
"a_list_of_obj": [
{"a_child_str": "found"},
{"a_child_str": "found"},
None,
{"a_child_str": "found"},
],
"a_union_str_or_int": "found",
"a_union_obj": {
"a_child_str": "found",
"a_child_int": "found",
},
"a_list_of_union_str_or_int": ["one", "two", "3", "4"],
"a_list_of_union_obj": [
{"a_child_str": "found"},
{"a_child_str": "found", "a_child_int": "found"},
None,
{"a_child_int": "found"},
],
}
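As a final hedged sketch (mirroring the `a_union_str_or_int` fixture
above, not an officially documented API): for a union of primitive
types, every subtype reads the same environment variable, and the anyOf
handling returns the first resolved value, leaving validation to
pydantic.

    from os import environ
    from typing import Union
    from unittest import mock

    from pydantic import BaseModel

    from libretime_shared.config._env import EnvLoader

    class Config(BaseModel):
        a_union_str_or_int: Union[str, int]

    with mock.patch.dict(environ, {"PRE_A_UNION_STR_OR_INT": "42"}):
        loader = EnvLoader(schema=Config.schema(), env_prefix="PRE")
        # Both subtypes resolve to the raw string "42"; the loader returns
        # the first candidate and lets pydantic coerce it later.
        assert loader.load() == {"a_union_str_or_int": "42"}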