feat(shared): add config trailing slash sanitizer (#1870)
This commit is contained in:
parent
8ab103ab44
commit
3705747132
|
@ -1,19 +1,23 @@
|
||||||
import sys
|
import sys
|
||||||
from os import environ
|
from os import environ
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
# pylint: disable=no-name-in-module
|
# pylint: disable=no-name-in-module
|
||||||
from pydantic import AnyHttpUrl, BaseModel, ValidationError
|
from pydantic import AnyHttpUrl, BaseModel, ValidationError, validator
|
||||||
from pydantic.fields import ModelField
|
from pydantic.fields import ModelField
|
||||||
from pydantic.utils import deep_update
|
from pydantic.utils import deep_update
|
||||||
from yaml import YAMLError, safe_load
|
from yaml import YAMLError, safe_load
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pydantic.typing import AnyClassMethod
|
||||||
|
|
||||||
DEFAULT_ENV_PREFIX = "LIBRETIME"
|
DEFAULT_ENV_PREFIX = "LIBRETIME"
|
||||||
DEFAULT_CONFIG_FILEPATH = Path("/etc/libretime/config.yml")
|
DEFAULT_CONFIG_FILEPATH = Path("/etc/libretime/config.yml")
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-few-public-methods
|
# pylint: disable=too-few-public-methods
|
||||||
class BaseConfig(BaseModel):
|
class BaseConfig(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -103,16 +107,30 @@ class BaseConfig(BaseModel):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def no_trailing_slash_validator(key: str) -> "AnyClassMethod":
|
||||||
|
# pylint: disable=unused-argument
|
||||||
|
def strip_trailing_slash(cls: Any, value: str) -> str:
|
||||||
|
return value.rstrip("/")
|
||||||
|
|
||||||
|
return validator(key, pre=True, allow_reuse=True)(strip_trailing_slash)
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-few-public-methods
|
# pylint: disable=too-few-public-methods
|
||||||
class GeneralConfig(BaseModel):
|
class GeneralConfig(BaseModel):
|
||||||
public_url: AnyHttpUrl
|
public_url: AnyHttpUrl
|
||||||
api_key: str
|
api_key: str
|
||||||
|
|
||||||
|
# Validators
|
||||||
|
_public_url_no_trailing_slash = no_trailing_slash_validator("public_url")
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-few-public-methods
|
# pylint: disable=too-few-public-methods
|
||||||
class StorageConfig(BaseModel):
|
class StorageConfig(BaseModel):
|
||||||
path: str = "/srv/libretime"
|
path: str = "/srv/libretime"
|
||||||
|
|
||||||
|
# Validators
|
||||||
|
_path_no_trailing_slash = no_trailing_slash_validator("path")
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-few-public-methods
|
# pylint: disable=too-few-public-methods
|
||||||
class DatabaseConfig(BaseModel):
|
class DatabaseConfig(BaseModel):
|
||||||
|
|
|
@ -3,21 +3,31 @@ from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import AnyHttpUrl, BaseModel
|
||||||
from pytest import mark, raises
|
from pytest import mark, raises
|
||||||
|
|
||||||
from libretime_shared.config import BaseConfig, DatabaseConfig, RabbitMQConfig
|
from libretime_shared.config import (
|
||||||
|
BaseConfig,
|
||||||
|
DatabaseConfig,
|
||||||
|
RabbitMQConfig,
|
||||||
|
no_trailing_slash_validator,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-few-public-methods
|
# pylint: disable=too-few-public-methods
|
||||||
class FixtureConfig(BaseConfig):
|
class FixtureConfig(BaseConfig):
|
||||||
|
public_url: AnyHttpUrl
|
||||||
api_key: str
|
api_key: str
|
||||||
allowed_hosts: List[str] = []
|
allowed_hosts: List[str] = []
|
||||||
database: DatabaseConfig
|
database: DatabaseConfig
|
||||||
rabbitmq: RabbitMQConfig = RabbitMQConfig()
|
rabbitmq: RabbitMQConfig = RabbitMQConfig()
|
||||||
|
|
||||||
|
# Validators
|
||||||
|
_public_url_no_trailing_slash = no_trailing_slash_validator("public_url")
|
||||||
|
|
||||||
|
|
||||||
FIXTURE_CONFIG_RAW = """
|
FIXTURE_CONFIG_RAW = """
|
||||||
|
public_url: http://libretime.example.com/
|
||||||
api_key: "f3bf04fc"
|
api_key: "f3bf04fc"
|
||||||
allowed_hosts:
|
allowed_hosts:
|
||||||
- example.com
|
- example.com
|
||||||
|
@ -49,6 +59,7 @@ def test_base_config(tmp_path: Path):
|
||||||
):
|
):
|
||||||
config = FixtureConfig(filepath=config_filepath)
|
config = FixtureConfig(filepath=config_filepath)
|
||||||
|
|
||||||
|
assert config.public_url == "http://libretime.example.com"
|
||||||
assert config.api_key == "f3bf04fc"
|
assert config.api_key == "f3bf04fc"
|
||||||
assert config.allowed_hosts == ["example.com", "sub.example.com"]
|
assert config.allowed_hosts == ["example.com", "sub.example.com"]
|
||||||
assert config.database.host == "localhost"
|
assert config.database.host == "localhost"
|
||||||
|
|
Loading…
Reference in New Issue