feat(shared): add config trailing slash sanitizer (#1870)

This commit is contained in:
Jonas L 2022-06-11 18:18:34 +02:00 committed by GitHub
parent 8ab103ab44
commit 3705747132
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 4 deletions

View File

@ -1,19 +1,23 @@
import sys
from os import environ
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from loguru import logger
# pylint: disable=no-name-in-module
from pydantic import AnyHttpUrl, BaseModel, ValidationError
from pydantic import AnyHttpUrl, BaseModel, ValidationError, validator
from pydantic.fields import ModelField
from pydantic.utils import deep_update
from yaml import YAMLError, safe_load
if TYPE_CHECKING:
from pydantic.typing import AnyClassMethod
DEFAULT_ENV_PREFIX = "LIBRETIME"
DEFAULT_CONFIG_FILEPATH = Path("/etc/libretime/config.yml")
# pylint: disable=too-few-public-methods
class BaseConfig(BaseModel):
"""
@ -103,16 +107,30 @@ class BaseConfig(BaseModel):
return {}
def no_trailing_slash_validator(key: str) -> "AnyClassMethod":
# pylint: disable=unused-argument
def strip_trailing_slash(cls: Any, value: str) -> str:
return value.rstrip("/")
return validator(key, pre=True, allow_reuse=True)(strip_trailing_slash)
# pylint: disable=too-few-public-methods
class GeneralConfig(BaseModel):
public_url: AnyHttpUrl
api_key: str
# Validators
_public_url_no_trailing_slash = no_trailing_slash_validator("public_url")
# pylint: disable=too-few-public-methods
class StorageConfig(BaseModel):
path: str = "/srv/libretime"
# Validators
_path_no_trailing_slash = no_trailing_slash_validator("path")
# pylint: disable=too-few-public-methods
class DatabaseConfig(BaseModel):

View File

@ -3,21 +3,31 @@ from pathlib import Path
from typing import List
from unittest import mock
from pydantic import BaseModel
from pydantic import AnyHttpUrl, BaseModel
from pytest import mark, raises
from libretime_shared.config import BaseConfig, DatabaseConfig, RabbitMQConfig
from libretime_shared.config import (
BaseConfig,
DatabaseConfig,
RabbitMQConfig,
no_trailing_slash_validator,
)
# pylint: disable=too-few-public-methods
class FixtureConfig(BaseConfig):
public_url: AnyHttpUrl
api_key: str
allowed_hosts: List[str] = []
database: DatabaseConfig
rabbitmq: RabbitMQConfig = RabbitMQConfig()
# Validators
_public_url_no_trailing_slash = no_trailing_slash_validator("public_url")
FIXTURE_CONFIG_RAW = """
public_url: http://libretime.example.com/
api_key: "f3bf04fc"
allowed_hosts:
- example.com
@ -49,6 +59,7 @@ def test_base_config(tmp_path: Path):
):
config = FixtureConfig(filepath=config_filepath)
assert config.public_url == "http://libretime.example.com"
assert config.api_key == "f3bf04fc"
assert config.allowed_hosts == ["example.com", "sub.example.com"]
assert config.database.host == "localhost"