diff --git a/playout/libretime_playout/config.py b/playout/libretime_playout/config.py index 2690c4240..4877c1b89 100644 --- a/playout/libretime_playout/config.py +++ b/playout/libretime_playout/config.py @@ -7,7 +7,7 @@ from libretime_shared.config import ( RabbitMQConfig, StreamConfig, ) -from pydantic import BaseModel, root_validator +from pydantic import BaseModel, model_validator CACHE_DIR = Path.cwd() / "scheduler" RECORD_DIR = Path.cwd() / "recorder" @@ -37,18 +37,21 @@ class LiquidsoapConfig(BaseModel): harbor_ssl_private_key: Optional[str] = None harbor_ssl_password: Optional[str] = None - @root_validator - @classmethod - def _validate_harbor_ssl(cls, values: dict): - harbor_ssl_certificate = values.get("harbor_ssl_certificate") - harbor_ssl_private_key = values.get("harbor_ssl_private_key") - if harbor_ssl_certificate is not None and harbor_ssl_private_key is None: + @model_validator(mode="after") + def _validate_harbor_ssl(self): + if ( + self.harbor_ssl_certificate is not None + and self.harbor_ssl_private_key is None + ): raise ValueError("missing 'harbor_ssl_private_key' value") - if harbor_ssl_certificate is None and harbor_ssl_private_key is not None: + if ( + self.harbor_ssl_certificate is None + and self.harbor_ssl_private_key is not None + ): raise ValueError("missing 'harbor_ssl_certificate' value") - return values + return self class Config(BaseConfig): diff --git a/shared/libretime_shared/config/__init__.py b/shared/libretime_shared/config/__init__.py index c4a4510fb..05f89eee6 100644 --- a/shared/libretime_shared/config/__init__.py +++ b/shared/libretime_shared/config/__init__.py @@ -1,4 +1,5 @@ from ._base import DEFAULT_CONFIG_FILEPATH, DEFAULT_ENV_PREFIX, BaseConfig +from ._fields import AnyHttpUrlStr, AnyUrlStr, StrNoLeadingSlash, StrNoTrailingSlash from ._models import ( AudioChannels, AudioFormat, @@ -11,6 +12,4 @@ from ._models import ( StorageConfig, StreamConfig, SystemOutput, - no_leading_slash_validator, - no_trailing_slash_validator, ) diff --git a/shared/libretime_shared/config/_base.py b/shared/libretime_shared/config/_base.py index eb6d3da9e..95a5bd775 100644 --- a/shared/libretime_shared/config/_base.py +++ b/shared/libretime_shared/config/_base.py @@ -38,7 +38,7 @@ class BaseConfig(BaseModel): if _filepath is not None: _filepath = Path(_filepath) - env_loader = EnvLoader(_self.schema(), _env_prefix, _env_delimiter) + env_loader = EnvLoader(_self.model_json_schema(), _env_prefix, _env_delimiter) values = deep_merge_dict( kwargs, diff --git a/shared/libretime_shared/config/_env.py b/shared/libretime_shared/config/_env.py index 03f94b657..85df5fc1e 100644 --- a/shared/libretime_shared/config/_env.py +++ b/shared/libretime_shared/config/_env.py @@ -140,7 +140,7 @@ class EnvLoader: return mapping - # pylint: disable=too-many-return-statements + # pylint: disable=too-many-return-statements,too-many-branches def _get( self, env_name: str, @@ -160,7 +160,13 @@ class EnvLoader: if "$ref" in schema: schema = self._resolve_ref(schema["$ref"]) + if "const" in schema: + return self._env.get(env_name, None) + if "type" in schema: + if schema["type"] == "null": + return None + if schema["type"] in ("string", "integer", "boolean"): return self._env.get(env_name, None) diff --git a/shared/libretime_shared/config/_fields.py b/shared/libretime_shared/config/_fields.py new file mode 100644 index 000000000..fe3677783 --- /dev/null +++ b/shared/libretime_shared/config/_fields.py @@ -0,0 +1,69 @@ +from typing import Any, Optional + +from pydantic import ( + AfterValidator, + AnyHttpUrl, + AnyUrl, + GetCoreSchemaHandler, + GetJsonSchemaHandler, + TypeAdapter, +) +from pydantic.json_schema import JsonSchemaValue +from pydantic_core import Url +from pydantic_core.core_schema import CoreSchema, no_info_after_validator_function +from typing_extensions import Annotated + +StrNoTrailingSlash = Annotated[str, AfterValidator(lambda x: str(x).rstrip("/"))] +StrNoLeadingSlash = Annotated[str, AfterValidator(lambda x: str(x).lstrip("/"))] + + +class AnyUrlStr(str): + _type_adapter = TypeAdapter(AnyUrl) + obj: Url + + @classmethod + def __get_pydantic_core_schema__( + cls, + _: Any, + handler: GetCoreSchemaHandler, + ) -> CoreSchema: + return no_info_after_validator_function(cls, handler(str)) + + @classmethod + def __get_pydantic_json_schema__( + cls, + core_schema: CoreSchema, + handler: GetJsonSchemaHandler, + ) -> JsonSchemaValue: + field_schema = handler(core_schema) + field_schema.update(format="uri") + return field_schema + + def __new__(cls, value: str) -> "AnyUrlStr": + url_obj = cls._type_adapter.validate_strings(value) + self = str.__new__(cls, str(url_obj).rstrip("/")) + self.obj = url_obj + return self + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().__repr__()})" + + @property + def scheme(self) -> str: + return self.obj.scheme + + @property + def host(self) -> Optional[str]: + return self.obj.host + + @property + def port(self) -> Optional[int]: + return self.obj.port + + @property + def path(self) -> Optional[str]: + return self.obj.path + + +class AnyHttpUrlStr(AnyUrlStr): + _type_adapter = TypeAdapter(AnyHttpUrl) diff --git a/shared/libretime_shared/config/_models.py b/shared/libretime_shared/config/_models.py index 13e70927a..51a5961f5 100644 --- a/shared/libretime_shared/config/_models.py +++ b/shared/libretime_shared/config/_models.py @@ -1,59 +1,33 @@ import sys from enum import Enum -from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union +from typing import List, Literal, Optional, Union -# pylint: disable=no-name-in-module -from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, validator +from pydantic import BaseModel, Field, field_validator from typing_extensions import Annotated +from ._fields import AnyHttpUrlStr, AnyUrlStr, StrNoLeadingSlash, StrNoTrailingSlash + if sys.version_info < (3, 9): from backports.zoneinfo import ZoneInfo, ZoneInfoNotFoundError else: from zoneinfo import ZoneInfo, ZoneInfoNotFoundError -if TYPE_CHECKING: - from pydantic.typing import AnyClassMethod - - -def no_trailing_slash_validator(key: str) -> "AnyClassMethod": - # pylint: disable=unused-argument - def strip_trailing_slash(cls: Any, value: Any) -> Any: - if isinstance(value, str): - return value.rstrip("/") - return value - - return validator(key, pre=True, allow_reuse=True)(strip_trailing_slash) - - -def no_leading_slash_validator(key: str) -> "AnyClassMethod": - # pylint: disable=unused-argument - def strip_leading_slash(cls: Any, value: Any) -> Any: - if isinstance(value, str): - return value.lstrip("/") - return value - - return validator(key, pre=True, allow_reuse=True)(strip_leading_slash) - - # GeneralConfig ######################################################################################## # pylint: disable=too-few-public-methods class GeneralConfig(BaseModel): - public_url: AnyHttpUrl + public_url: AnyHttpUrlStr api_key: str secret_key: Optional[str] = None timezone: str = "UTC" - allowed_cors_origins: List[AnyHttpUrl] = [] + allowed_cors_origins: List[AnyHttpUrlStr] = [] - # Validators - _public_url_no_trailing_slash = no_trailing_slash_validator("public_url") - - @validator("timezone") + @field_validator("timezone") @classmethod def _validate_timezone(cls, value: str) -> str: try: @@ -70,10 +44,7 @@ class GeneralConfig(BaseModel): # pylint: disable=too-few-public-methods class StorageConfig(BaseModel): - path: str = "/srv/libretime" - - # Validators - _path_no_trailing_slash = no_trailing_slash_validator("path") + path: StrNoTrailingSlash = "/srv/libretime" # DatabaseConfig @@ -122,7 +93,7 @@ class RabbitMQConfig(BaseModel): class BaseInput(BaseModel): enabled: bool = True - public_url: Optional[AnyUrl] = None + public_url: Optional[AnyUrlStr] = None class InputKind(str, Enum): @@ -131,12 +102,10 @@ class InputKind(str, Enum): class HarborInput(BaseInput): kind: Literal[InputKind.HARBOR] = InputKind.HARBOR - mount: str + mount: StrNoLeadingSlash port: int secure: bool = False - _mount_no_leading_slash = no_leading_slash_validator("mount") - class MainHarborInput(HarborInput): mount: str = "main" @@ -162,7 +131,7 @@ class BaseAudio(BaseModel): channels: AudioChannels = AudioChannels.STEREO bitrate: int - @validator("bitrate") + @field_validator("bitrate") @classmethod def _validate_bitrate(cls, value: int) -> int: # Once the liquidsoap script generation supports it, fine tune @@ -200,11 +169,11 @@ class AudioOpus(BaseAudio): class IcecastOutput(BaseModel): kind: Literal["icecast"] = "icecast" enabled: bool = False - public_url: Optional[AnyUrl] = None + public_url: Optional[AnyUrlStr] = None host: str = "localhost" port: int = 8000 - mount: str + mount: StrNoLeadingSlash source_user: str = "source" source_password: str admin_user: str = "admin" @@ -222,13 +191,11 @@ class IcecastOutput(BaseModel): mobile: bool = False - _mount_no_leading_slash = no_leading_slash_validator("mount") - class ShoutcastOutput(BaseModel): kind: Literal["shoutcast"] = "shoutcast" enabled: bool = False - public_url: Optional[AnyUrl] = None + public_url: Optional[AnyUrlStr] = None host: str = "localhost" port: int = 8000 @@ -264,9 +231,9 @@ class SystemOutput(BaseModel): # pylint: disable=too-few-public-methods class Outputs(BaseModel): - icecast: List[IcecastOutput] = Field([], max_items=3) - shoutcast: List[ShoutcastOutput] = Field([], max_items=1) - system: List[SystemOutput] = Field([], max_items=1) + icecast: List[IcecastOutput] = Field([], max_length=3) + shoutcast: List[ShoutcastOutput] = Field([], max_length=1) + system: List[SystemOutput] = Field([], max_length=1) @property def merged(self) -> List[Union[IcecastOutput, ShoutcastOutput]]: diff --git a/shared/requirements.txt b/shared/requirements.txt index 41fb43e04..651a8c2a6 100644 --- a/shared/requirements.txt +++ b/shared/requirements.txt @@ -2,5 +2,5 @@ # This file is auto-generated by tools/extract_requirements.py. backports.zoneinfo>=0.2.1,<0.3;python_version<'3.9' click>=8.0.4,<8.2 -pydantic>=1.7.4,<1.11 +pydantic>=2.5.0,<2.6 pyyaml>=5.3.1,<6.1 diff --git a/shared/setup.py b/shared/setup.py index d777700b4..4c97be351 100644 --- a/shared/setup.py +++ b/shared/setup.py @@ -14,7 +14,7 @@ setup( install_requires=[ "backports.zoneinfo>=0.2.1,<0.3;python_version<'3.9'", "click>=8.0.4,<8.2", - "pydantic>=1.7.4,<1.11", + "pydantic>=2.5.0,<2.6", "pyyaml>=5.3.1,<6.1", ], extras_require={ diff --git a/shared/tests/config/base_test.py b/shared/tests/config/base_test.py index fc3d346e6..93e9b900f 100644 --- a/shared/tests/config/base_test.py +++ b/shared/tests/config/base_test.py @@ -3,17 +3,17 @@ from pathlib import Path from typing import List, Union from unittest import mock -from pydantic import AnyHttpUrl, BaseModel, Field +from pydantic import BaseModel, Field from pytest import mark, raises from typing_extensions import Annotated from libretime_shared.config import ( + AnyHttpUrlStr, BaseConfig, DatabaseConfig, IcecastOutput, RabbitMQConfig, ShoutcastOutput, - no_trailing_slash_validator, ) AnyOutput = Annotated[ @@ -24,15 +24,357 @@ AnyOutput = Annotated[ # pylint: disable=too-few-public-methods class FixtureConfig(BaseConfig): - public_url: AnyHttpUrl + public_url: AnyHttpUrlStr api_key: str allowed_hosts: List[str] = [] database: DatabaseConfig rabbitmq: RabbitMQConfig = RabbitMQConfig() outputs: List[AnyOutput] - # Validators - _public_url_no_trailing_slash = no_trailing_slash_validator("public_url") + +FIXTURE_CONFIG_JSON_SCHEMA = { + "$defs": { + "AudioAAC": { + "properties": { + "channels": { + "allOf": [{"$ref": "#/$defs/AudioChannels"}], + "default": "stereo", + }, + "bitrate": {"title": "Bitrate", "type": "integer"}, + "format": {"const": "aac", "default": "aac", "title": "Format"}, + }, + "required": ["bitrate"], + "title": "AudioAAC", + "type": "object", + }, + "AudioChannels": { + "enum": ["stereo", "mono"], + "title": "AudioChannels", + "type": "string", + }, + "AudioMP3": { + "properties": { + "channels": { + "allOf": [{"$ref": "#/$defs/AudioChannels"}], + "default": "stereo", + }, + "bitrate": {"title": "Bitrate", "type": "integer"}, + "format": {"const": "mp3", "default": "mp3", "title": "Format"}, + }, + "required": ["bitrate"], + "title": "AudioMP3", + "type": "object", + }, + "AudioOGG": { + "properties": { + "channels": { + "allOf": [{"$ref": "#/$defs/AudioChannels"}], + "default": "stereo", + }, + "bitrate": {"title": "Bitrate", "type": "integer"}, + "format": {"const": "ogg", "default": "ogg", "title": "Format"}, + "enable_metadata": { + "anyOf": [{"type": "boolean"}, {"type": "null"}], + "default": False, + "title": "Enable Metadata", + }, + }, + "required": ["bitrate"], + "title": "AudioOGG", + "type": "object", + }, + "AudioOpus": { + "properties": { + "channels": { + "allOf": [{"$ref": "#/$defs/AudioChannels"}], + "default": "stereo", + }, + "bitrate": {"title": "Bitrate", "type": "integer"}, + "format": { + "const": "opus", + "default": "opus", + "title": "Format", + }, + }, + "required": ["bitrate"], + "title": "AudioOpus", + "type": "object", + }, + "DatabaseConfig": { + "properties": { + "host": { + "default": "localhost", + "title": "Host", + "type": "string", + }, + "port": {"default": 5432, "title": "Port", "type": "integer"}, + "name": { + "default": "libretime", + "title": "Name", + "type": "string", + }, + "user": { + "default": "libretime", + "title": "User", + "type": "string", + }, + "password": { + "default": "libretime", + "title": "Password", + "type": "string", + }, + }, + "title": "DatabaseConfig", + "type": "object", + }, + "IcecastOutput": { + "properties": { + "kind": { + "const": "icecast", + "default": "icecast", + "title": "Kind", + }, + "enabled": { + "default": False, + "title": "Enabled", + "type": "boolean", + }, + "public_url": { + "anyOf": [ + {"type": "string", "format": "uri"}, + {"type": "null"}, + ], + "default": None, + "title": "Public Url", + }, + "host": { + "default": "localhost", + "title": "Host", + "type": "string", + }, + "port": {"default": 8000, "title": "Port", "type": "integer"}, + "mount": {"title": "Mount", "type": "string"}, + "source_user": { + "default": "source", + "title": "Source User", + "type": "string", + }, + "source_password": { + "title": "Source Password", + "type": "string", + }, + "admin_user": { + "default": "admin", + "title": "Admin User", + "type": "string", + }, + "admin_password": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "title": "Admin Password", + }, + "audio": { + "discriminator": { + "mapping": { + "aac": "#/$defs/AudioAAC", + "mp3": "#/$defs/AudioMP3", + "ogg": "#/$defs/AudioOGG", + "opus": "#/$defs/AudioOpus", + }, + "propertyName": "format", + }, + "oneOf": [ + {"$ref": "#/$defs/AudioAAC"}, + {"$ref": "#/$defs/AudioMP3"}, + {"$ref": "#/$defs/AudioOGG"}, + {"$ref": "#/$defs/AudioOpus"}, + ], + "title": "Audio", + }, + "name": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "title": "Name", + }, + "description": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "title": "Description", + }, + "website": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "title": "Website", + }, + "genre": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "title": "Genre", + }, + "mobile": { + "default": False, + "title": "Mobile", + "type": "boolean", + }, + }, + "required": ["mount", "source_password", "audio"], + "title": "IcecastOutput", + "type": "object", + }, + "RabbitMQConfig": { + "properties": { + "host": { + "default": "localhost", + "title": "Host", + "type": "string", + }, + "port": {"default": 5672, "title": "Port", "type": "integer"}, + "user": { + "default": "libretime", + "title": "User", + "type": "string", + }, + "password": { + "default": "libretime", + "title": "Password", + "type": "string", + }, + "vhost": { + "default": "/libretime", + "title": "Vhost", + "type": "string", + }, + }, + "title": "RabbitMQConfig", + "type": "object", + }, + "ShoutcastOutput": { + "properties": { + "kind": { + "const": "shoutcast", + "default": "shoutcast", + "title": "Kind", + }, + "enabled": { + "default": False, + "title": "Enabled", + "type": "boolean", + }, + "public_url": { + "anyOf": [ + {"type": "string", "format": "uri"}, + {"type": "null"}, + ], + "default": None, + "title": "Public Url", + }, + "host": { + "default": "localhost", + "title": "Host", + "type": "string", + }, + "port": {"default": 8000, "title": "Port", "type": "integer"}, + "source_user": { + "default": "source", + "title": "Source User", + "type": "string", + }, + "source_password": { + "title": "Source Password", + "type": "string", + }, + "admin_user": { + "default": "admin", + "title": "Admin User", + "type": "string", + }, + "admin_password": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "title": "Admin Password", + }, + "audio": { + "discriminator": { + "mapping": { + "aac": "#/$defs/AudioAAC", + "mp3": "#/$defs/AudioMP3", + }, + "propertyName": "format", + }, + "oneOf": [ + {"$ref": "#/$defs/AudioAAC"}, + {"$ref": "#/$defs/AudioMP3"}, + ], + "title": "Audio", + }, + "name": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "title": "Name", + }, + "website": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "title": "Website", + }, + "genre": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "title": "Genre", + }, + "mobile": { + "default": False, + "title": "Mobile", + "type": "boolean", + }, + }, + "required": ["source_password", "audio"], + "title": "ShoutcastOutput", + "type": "object", + }, + }, + "properties": { + "public_url": {"title": "Public Url", "type": "string", "format": "uri"}, + "api_key": {"title": "Api Key", "type": "string"}, + "allowed_hosts": { + "default": [], + "items": {"type": "string"}, + "title": "Allowed Hosts", + "type": "array", + }, + "database": {"$ref": "#/$defs/DatabaseConfig"}, + "rabbitmq": { + "allOf": [{"$ref": "#/$defs/RabbitMQConfig"}], + "default": { + "host": "localhost", + "port": 5672, + "user": "libretime", + "password": "libretime", + "vhost": "/libretime", + }, + }, + "outputs": { + "items": { + "discriminator": { + "mapping": { + "icecast": "#/$defs/IcecastOutput", + "shoutcast": "#/$defs/ShoutcastOutput", + }, + "propertyName": "kind", + }, + "oneOf": [ + {"$ref": "#/$defs/IcecastOutput"}, + {"$ref": "#/$defs/ShoutcastOutput"}, + ], + }, + "title": "Outputs", + "type": "array", + }, + }, + "required": ["public_url", "api_key", "database", "outputs"], + "title": "FixtureConfig", + "type": "object", +} FIXTURE_CONFIG_RAW = """ @@ -81,6 +423,8 @@ def test_base_config(tmp_path: Path): ): config = FixtureConfig(config_filepath) + assert config.model_json_schema() == FIXTURE_CONFIG_JSON_SCHEMA + assert config.public_url == "http://libretime.example.org" assert config.api_key == "f3bf04fc" assert config.allowed_hosts == ["example.com", "sub.example.com"] diff --git a/shared/tests/config/env_test.py b/shared/tests/config/env_test.py index bd5fc62a3..8c394adbd 100644 --- a/shared/tests/config/env_test.py +++ b/shared/tests/config/env_test.py @@ -89,7 +89,7 @@ class FixtureConfig(BaseConfig): a_list_of_union_obj: List[Union[FirstChildConfig, SecondChildConfig]] -ENV_SCHEMA = FixtureConfig.schema() +ENV_SCHEMA = FixtureConfig.model_json_schema() @pytest.mark.parametrize( diff --git a/shared/tests/config/fields_test.py b/shared/tests/config/fields_test.py new file mode 100644 index 000000000..cdbed1d2d --- /dev/null +++ b/shared/tests/config/fields_test.py @@ -0,0 +1,51 @@ +import pytest +from pydantic import TypeAdapter + +from libretime_shared.config._fields import ( + AnyHttpUrlStr, + StrNoLeadingSlash, + StrNoTrailingSlash, +) + + +@pytest.mark.parametrize( + "data, expected", + [ + ("something/", "something"), + ("something//", "something"), + ("something/keep", "something/keep"), + ("/something/", "/something"), + ], +) +def test_str_no_trailing_slash(data, expected): + found = TypeAdapter(StrNoTrailingSlash).validate_python(data) + assert found == expected + + +@pytest.mark.parametrize( + "data, expected", + [ + ("/something", "something"), + ("//something", "something"), + ("keep/something", "keep/something"), + ("/something/", "something/"), + ], +) +def test_str_no_leading_slash(data, expected): + found = TypeAdapter(StrNoLeadingSlash).validate_python(data) + assert found == expected + + +@pytest.mark.parametrize( + "data, expected", + [ + ("http://localhost:8080", "http://localhost:8080"), + ("http://localhost:8080/path/", "http://localhost:8080/path"), + ("https://example.com/", "https://example.com"), + ("https://example.com/keep", "https://example.com/keep"), + ("https://example.com/keep/", "https://example.com/keep"), + ], +) +def test_any_http_url_str(data, expected): + found = TypeAdapter(AnyHttpUrlStr).validate_python(data) + assert found == expected