feat: enhance libretime shared (#1491)

* feat: allow custom delimiter in BaseConfig

* feat: add default config filepath constant
This commit is contained in:
Jonas L 2022-01-06 15:25:43 +01:00 committed by GitHub
parent 3a615cafa0
commit 48dea6e5d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 11 additions and 5 deletions

View File

@ -11,7 +11,7 @@ from pydantic.fields import ModelField
from yaml import YAMLError, safe_load from yaml import YAMLError, safe_load
DEFAULT_ENV_PREFIX = "LIBRETIME" DEFAULT_ENV_PREFIX = "LIBRETIME"
DEFAULT_CONFIG_FILEPATH = Path("/etc/libretime/config.yml")
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
class BaseConfig(BaseModel): class BaseConfig(BaseModel):
@ -27,10 +27,11 @@ class BaseConfig(BaseModel):
self, self,
*, *,
env_prefix: str = DEFAULT_ENV_PREFIX, env_prefix: str = DEFAULT_ENV_PREFIX,
env_delimiter: str = "_",
filepath: Optional[Path] = None, filepath: Optional[Path] = None,
) -> None: ) -> None:
file_values = self._load_file_values(filepath) file_values = self._load_file_values(filepath)
env_values = self._load_env_values(env_prefix) env_values = self._load_env_values(env_prefix, env_delimiter)
try: try:
super().__init__( super().__init__(
@ -43,22 +44,27 @@ class BaseConfig(BaseModel):
logger.critical(error) logger.critical(error)
sys.exit(1) sys.exit(1)
def _load_env_values(self, env_prefix: str) -> Dict[str, Any]: def _load_env_values(self, env_prefix: str, env_delimiter: str) -> Dict[str, Any]:
return self._get_fields_from_env(env_prefix, self.__fields__) return self._get_fields_from_env(env_prefix, env_delimiter, self.__fields__)
def _get_fields_from_env( def _get_fields_from_env(
self, self,
env_prefix: str, env_prefix: str,
env_delimiter: str,
fields: Dict[str, ModelField], fields: Dict[str, ModelField],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
result: Dict[str, Any] = {} result: Dict[str, Any] = {}
if env_prefix != "":
env_prefix += env_delimiter
for field in fields.values(): for field in fields.values():
env_name = (env_prefix + "_" + field.name).upper() env_name = (env_prefix + field.name).upper()
if field.is_complex(): if field.is_complex():
result[field.name] = self._get_fields_from_env( result[field.name] = self._get_fields_from_env(
env_name, env_name,
env_delimiter,
field.type_.__fields__, field.type_.__fields__,
) )
else: else: