diff --git a/analyzer/libretime_analyzer/main.py b/analyzer/libretime_analyzer/main.py index aeb3b23bb..640fe4138 100644 --- a/analyzer/libretime_analyzer/main.py +++ b/analyzer/libretime_analyzer/main.py @@ -34,7 +34,7 @@ def cli( Run analyzer. """ setup_logger(level_from_name(log_level), log_filepath) - config = Config(filepath=config_filepath) + config = Config(config_filepath) # Start up the StatusReporter process StatusReporter.start_thread(retry_queue_filepath) diff --git a/api-client/libretime_api_client/v1.py b/api-client/libretime_api_client/v1.py index c4286abe5..593508323 100644 --- a/api-client/libretime_api_client/v1.py +++ b/api-client/libretime_api_client/v1.py @@ -75,7 +75,7 @@ class ApiClient: def __init__(self, logger=None, config_path="/etc/libretime/config.yml"): self.logger = logger or logging - config = Config(filepath=config_path) + config = Config(config_path) self.base_url = config.general.public_url self.api_key = config.general.api_key diff --git a/api/libretime_api/settings/prod.py b/api/libretime_api/settings/prod.py index cde7640e7..25312db67 100644 --- a/api/libretime_api/settings/prod.py +++ b/api/libretime_api/settings/prod.py @@ -22,7 +22,7 @@ from ._schema import Config LIBRETIME_LOG_FILEPATH = getenv("LIBRETIME_LOG_FILEPATH") LIBRETIME_CONFIG_FILEPATH = getenv("LIBRETIME_CONFIG_FILEPATH") -CONFIG = Config(filepath=LIBRETIME_CONFIG_FILEPATH) +CONFIG = Config(LIBRETIME_CONFIG_FILEPATH) SECRET_KEY = CONFIG.general.api_key diff --git a/playout/libretime_playout/main.py b/playout/libretime_playout/main.py index 0eedfa79c..f8d16de9d 100644 --- a/playout/libretime_playout/main.py +++ b/playout/libretime_playout/main.py @@ -106,7 +106,7 @@ def cli(log_level: str, log_filepath: Optional[Path], config_filepath: Optional[ Run playout. """ setup_logger(level_from_name(log_level), log_filepath) - config = Config(filepath=config_filepath) + config = Config(config_filepath) try: for dir_path in [CACHE_DIR, RECORD_DIR]: diff --git a/shared/README.md b/shared/README.md index 4658ae956..0fba73a36 100644 --- a/shared/README.md +++ b/shared/README.md @@ -29,7 +29,7 @@ class Config(BaseConfig): rabbitmq: RabbitMQConfig analyzer: AnalyzerConfig -config = Config(filepath="/etc/libretime/config.yml") +config = Config("/etc/libretime/config.yml") ``` > Don't instantiate a sub model if it has a required field, otherwise the `Config` class import will raise a `ValidationError`. diff --git a/shared/libretime_shared/config/_base.py b/shared/libretime_shared/config/_base.py index 40c0c652a..da0abfd97 100644 --- a/shared/libretime_shared/config/_base.py +++ b/shared/libretime_shared/config/_base.py @@ -24,20 +24,23 @@ class BaseConfig(BaseModel): :returns: configuration class """ + # pylint: disable=no-self-argument def __init__( - self, + _self, + _filepath: Optional[Union[Path, str]] = None, *, - env_prefix: str = DEFAULT_ENV_PREFIX, - env_delimiter: str = "_", - filepath: Optional[Union[Path, str]] = None, + _env_prefix: str = DEFAULT_ENV_PREFIX, + _env_delimiter: str = "_", + **kwargs: Any, ) -> None: - if filepath is not None: - filepath = Path(filepath) + if _filepath is not None: + _filepath = Path(_filepath) - env_loader = EnvLoader(self.schema(), env_prefix, env_delimiter) + env_loader = EnvLoader(_self.schema(), _env_prefix, _env_delimiter) values = deep_merge_dict( - self._load_file_values(filepath), + kwargs, + _self._load_file_values(_filepath), env_loader.load(), ) @@ -67,36 +70,39 @@ class BaseConfig(BaseModel): return {} -def deep_merge_dict(base: Dict[str, Any], next_: Dict[str, Any]) -> Dict[str, Any]: +def deep_merge_dict(base: Dict[str, Any], *elements: Dict[str, Any]) -> Dict[str, Any]: result = base.copy() - for key, value in next_.items(): - if key in result: - if isinstance(result[key], dict) and isinstance(value, dict): - result[key] = deep_merge_dict(result[key], value) - continue - if isinstance(result[key], list) and isinstance(value, list): - result[key] = deep_merge_list(result[key], value) - continue + for element in elements: + for key, value in element.items(): + if key in result: + if isinstance(result[key], dict) and isinstance(value, dict): + result[key] = deep_merge_dict(result[key], value) + continue - if value: - result[key] = value + if isinstance(result[key], list) and isinstance(value, list): + result[key] = deep_merge_list(result[key], value) + continue + + if value: + result[key] = value return result -def deep_merge_list(base: List[Any], next_: List[Any]) -> List[Any]: +def deep_merge_list(base: List[Any], *elements: List[Any]) -> List[Any]: result: List[Any] = [] - for base_item, next_item in zip_longest(base, next_): - if isinstance(base_item, list) and isinstance(next_item, list): - result.append(deep_merge_list(base_item, next_item)) - continue + for element in elements: + for base_item, next_item in zip_longest(base, element): + if isinstance(base_item, list) and isinstance(next_item, list): + result.append(deep_merge_list(base_item, next_item)) + continue - if isinstance(base_item, dict) and isinstance(next_item, dict): - result.append(deep_merge_dict(base_item, next_item)) - continue + if isinstance(base_item, dict) and isinstance(next_item, dict): + result.append(deep_merge_dict(base_item, next_item)) + continue - if next_item: - result.append(next_item) + if next_item: + result.append(next_item) return result diff --git a/shared/tests/config/base_test.py b/shared/tests/config/base_test.py index e86dc2fe9..d380c6ca1 100644 --- a/shared/tests/config/base_test.py +++ b/shared/tests/config/base_test.py @@ -57,7 +57,7 @@ def test_base_config(tmp_path: Path): "WRONGPREFIX_API_KEY": "invalid", }, ): - config = FixtureConfig(filepath=config_filepath) + config = FixtureConfig(config_filepath) assert config.public_url == "http://libretime.example.com" assert config.api_key == "f3bf04fc" @@ -69,7 +69,7 @@ def test_base_config(tmp_path: Path): # Optional model: loading default values (rabbitmq) with mock.patch.dict(environ, {}): - config = FixtureConfig(filepath=config_filepath) + config = FixtureConfig(config_filepath) assert config.allowed_hosts == ["example.com", "sub.example.com"] assert config.rabbitmq.host == "localhost" assert config.rabbitmq.port == 5672 @@ -82,7 +82,7 @@ def test_base_config(tmp_path: Path): "LIBRETIME_ALLOWED_HOSTS": "example.com, changed.example.com", }, ): - config = FixtureConfig(filepath=config_filepath) + config = FixtureConfig(config_filepath) assert config.allowed_hosts == ["example.com", "changed.example.com"] assert config.rabbitmq.host == "changed" assert config.rabbitmq.port == 5672 @@ -111,19 +111,19 @@ def test_base_config_required_submodel(tmp_path: Path): # With config file with mock.patch.dict(environ, {}): - config = FixtureWithRequiredSubmodelConfig(filepath=config_filepath) + config = FixtureWithRequiredSubmodelConfig(config_filepath) assert config.required.api_key == "test_key" assert config.required.with_default == "original" # With env variables with mock.patch.dict(environ, {"LIBRETIME_REQUIRED_API_KEY": "test_key"}): - config = FixtureWithRequiredSubmodelConfig(filepath=None) + config = FixtureWithRequiredSubmodelConfig(None) assert config.required.api_key == "test_key" assert config.required.with_default == "original" # With env variables override with mock.patch.dict(environ, {"LIBRETIME_REQUIRED_API_KEY": "changed"}): - config = FixtureWithRequiredSubmodelConfig(filepath=config_filepath) + config = FixtureWithRequiredSubmodelConfig(config_filepath) assert config.required.api_key == "changed" assert config.required.with_default == "original" @@ -135,14 +135,29 @@ def test_base_config_required_submodel(tmp_path: Path): "LIBRETIME_REQUIRED_WITH_DEFAULT": "changed", }, ): - config = FixtureWithRequiredSubmodelConfig(filepath=config_filepath) + config = FixtureWithRequiredSubmodelConfig(config_filepath) assert config.required.api_key == "changed" assert config.required.with_default == "changed" # Raise validation error with mock.patch.dict(environ, {}): with raises(SystemExit): - FixtureWithRequiredSubmodelConfig(filepath=None) + FixtureWithRequiredSubmodelConfig(None) + + +def test_base_config_from_init() -> None: + class FromInitFixtureConfig(BaseConfig): + found: str + override: str + + with mock.patch.dict(environ, {"LIBRETIME_OVERRIDE": "changed"}): + config = FromInitFixtureConfig( + found="changed", + override="invalid", + ) + + assert config.found == "changed" + assert config.override == "changed" FIXTURE_CONFIG_RAW_MISSING = """ @@ -169,4 +184,4 @@ def test_load_config_error(tmp_path: Path, raw, exception): with raises(exception): with mock.patch.dict(environ, {}): - FixtureConfig(filepath=config_filepath) + FixtureConfig(config_filepath) diff --git a/worker/libretime_worker/config.py b/worker/libretime_worker/config.py index 2c49682b7..d2a20ff8c 100644 --- a/worker/libretime_worker/config.py +++ b/worker/libretime_worker/config.py @@ -11,7 +11,7 @@ class Config(BaseConfig): LIBRETIME_CONFIG_FILEPATH = getenv("LIBRETIME_CONFIG_FILEPATH") -config = Config(filepath=LIBRETIME_CONFIG_FILEPATH) +config = Config(LIBRETIME_CONFIG_FILEPATH) # Celery amqp settings BROKER_URL = config.rabbitmq.url