From 3705747132bae0d9c97e41b100593950f0367bd9 Mon Sep 17 00:00:00 2001 From: Jonas L Date: Sat, 11 Jun 2022 18:18:34 +0200 Subject: [PATCH] feat(shared): add config trailing slash sanitizer (#1870) --- shared/libretime_shared/config.py | 22 ++++++++++++++++++++-- shared/tests/config_test.py | 15 +++++++++++++-- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/shared/libretime_shared/config.py b/shared/libretime_shared/config.py index c1a2c6cc3..96d9234a0 100644 --- a/shared/libretime_shared/config.py +++ b/shared/libretime_shared/config.py @@ -1,19 +1,23 @@ import sys from os import environ from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from loguru import logger # pylint: disable=no-name-in-module -from pydantic import AnyHttpUrl, BaseModel, ValidationError +from pydantic import AnyHttpUrl, BaseModel, ValidationError, validator from pydantic.fields import ModelField from pydantic.utils import deep_update from yaml import YAMLError, safe_load +if TYPE_CHECKING: + from pydantic.typing import AnyClassMethod + DEFAULT_ENV_PREFIX = "LIBRETIME" DEFAULT_CONFIG_FILEPATH = Path("/etc/libretime/config.yml") + # pylint: disable=too-few-public-methods class BaseConfig(BaseModel): """ @@ -103,16 +107,30 @@ class BaseConfig(BaseModel): return {} +def no_trailing_slash_validator(key: str) -> "AnyClassMethod": + # pylint: disable=unused-argument + def strip_trailing_slash(cls: Any, value: str) -> str: + return value.rstrip("/") + + return validator(key, pre=True, allow_reuse=True)(strip_trailing_slash) + + # pylint: disable=too-few-public-methods class GeneralConfig(BaseModel): public_url: AnyHttpUrl api_key: str + # Validators + _public_url_no_trailing_slash = no_trailing_slash_validator("public_url") + # pylint: disable=too-few-public-methods class StorageConfig(BaseModel): path: str = "/srv/libretime" + # Validators + _path_no_trailing_slash = no_trailing_slash_validator("path") + # pylint: disable=too-few-public-methods class DatabaseConfig(BaseModel): diff --git a/shared/tests/config_test.py b/shared/tests/config_test.py index f83f811fd..e86dc2fe9 100644 --- a/shared/tests/config_test.py +++ b/shared/tests/config_test.py @@ -3,21 +3,31 @@ from pathlib import Path from typing import List from unittest import mock -from pydantic import BaseModel +from pydantic import AnyHttpUrl, BaseModel from pytest import mark, raises -from libretime_shared.config import BaseConfig, DatabaseConfig, RabbitMQConfig +from libretime_shared.config import ( + BaseConfig, + DatabaseConfig, + RabbitMQConfig, + no_trailing_slash_validator, +) # pylint: disable=too-few-public-methods class FixtureConfig(BaseConfig): + public_url: AnyHttpUrl api_key: str allowed_hosts: List[str] = [] database: DatabaseConfig rabbitmq: RabbitMQConfig = RabbitMQConfig() + # Validators + _public_url_no_trailing_slash = no_trailing_slash_validator("public_url") + FIXTURE_CONFIG_RAW = """ +public_url: http://libretime.example.com/ api_key: "f3bf04fc" allowed_hosts: - example.com @@ -49,6 +59,7 @@ def test_base_config(tmp_path: Path): ): config = FixtureConfig(filepath=config_filepath) + assert config.public_url == "http://libretime.example.com" assert config.api_key == "f3bf04fc" assert config.allowed_hosts == ["example.com", "sub.example.com"] assert config.database.host == "localhost"