feat(shared): add config trailing slash sanitizer (#1870)

This commit is contained in:
Jonas L 2022-06-11 18:18:34 +02:00 committed by GitHub
parent 8ab103ab44
commit 3705747132
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 4 deletions

View File

@ -1,19 +1,23 @@
import sys import sys
from os import environ from os import environ
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from loguru import logger from loguru import logger
# pylint: disable=no-name-in-module # pylint: disable=no-name-in-module
from pydantic import AnyHttpUrl, BaseModel, ValidationError from pydantic import AnyHttpUrl, BaseModel, ValidationError, validator
from pydantic.fields import ModelField from pydantic.fields import ModelField
from pydantic.utils import deep_update from pydantic.utils import deep_update
from yaml import YAMLError, safe_load from yaml import YAMLError, safe_load
if TYPE_CHECKING:
from pydantic.typing import AnyClassMethod
DEFAULT_ENV_PREFIX = "LIBRETIME" DEFAULT_ENV_PREFIX = "LIBRETIME"
DEFAULT_CONFIG_FILEPATH = Path("/etc/libretime/config.yml") DEFAULT_CONFIG_FILEPATH = Path("/etc/libretime/config.yml")
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
class BaseConfig(BaseModel): class BaseConfig(BaseModel):
""" """
@ -103,16 +107,30 @@ class BaseConfig(BaseModel):
return {} return {}
def no_trailing_slash_validator(key: str) -> "AnyClassMethod":
# pylint: disable=unused-argument
def strip_trailing_slash(cls: Any, value: str) -> str:
return value.rstrip("/")
return validator(key, pre=True, allow_reuse=True)(strip_trailing_slash)
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
class GeneralConfig(BaseModel): class GeneralConfig(BaseModel):
public_url: AnyHttpUrl public_url: AnyHttpUrl
api_key: str api_key: str
# Validators
_public_url_no_trailing_slash = no_trailing_slash_validator("public_url")
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
class StorageConfig(BaseModel): class StorageConfig(BaseModel):
path: str = "/srv/libretime" path: str = "/srv/libretime"
# Validators
_path_no_trailing_slash = no_trailing_slash_validator("path")
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
class DatabaseConfig(BaseModel): class DatabaseConfig(BaseModel):

View File

@ -3,21 +3,31 @@ from pathlib import Path
from typing import List from typing import List
from unittest import mock from unittest import mock
from pydantic import BaseModel from pydantic import AnyHttpUrl, BaseModel
from pytest import mark, raises from pytest import mark, raises
from libretime_shared.config import BaseConfig, DatabaseConfig, RabbitMQConfig from libretime_shared.config import (
BaseConfig,
DatabaseConfig,
RabbitMQConfig,
no_trailing_slash_validator,
)
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
class FixtureConfig(BaseConfig): class FixtureConfig(BaseConfig):
public_url: AnyHttpUrl
api_key: str api_key: str
allowed_hosts: List[str] = [] allowed_hosts: List[str] = []
database: DatabaseConfig database: DatabaseConfig
rabbitmq: RabbitMQConfig = RabbitMQConfig() rabbitmq: RabbitMQConfig = RabbitMQConfig()
# Validators
_public_url_no_trailing_slash = no_trailing_slash_validator("public_url")
FIXTURE_CONFIG_RAW = """ FIXTURE_CONFIG_RAW = """
public_url: http://libretime.example.com/
api_key: "f3bf04fc" api_key: "f3bf04fc"
allowed_hosts: allowed_hosts:
- example.com - example.com
@ -49,6 +59,7 @@ def test_base_config(tmp_path: Path):
): ):
config = FixtureConfig(filepath=config_filepath) config = FixtureConfig(filepath=config_filepath)
assert config.public_url == "http://libretime.example.com"
assert config.api_key == "f3bf04fc" assert config.api_key == "f3bf04fc"
assert config.allowed_hosts == ["example.com", "sub.example.com"] assert config.allowed_hosts == ["example.com", "sub.example.com"]
assert config.database.host == "localhost" assert config.database.host == "localhost"