feat(shared): create stream config models
This commit is contained in:
parent
12d2d4b15a
commit
d9920a1196
|
@ -1,8 +1,16 @@
|
|||
from ._base import DEFAULT_CONFIG_FILEPATH, DEFAULT_ENV_PREFIX, BaseConfig
|
||||
from ._models import (
|
||||
AudioChannels,
|
||||
AudioFormat,
|
||||
DatabaseConfig,
|
||||
GeneralConfig,
|
||||
HarborInput,
|
||||
IcecastOutput,
|
||||
RabbitMQConfig,
|
||||
ShoutcastOutput,
|
||||
StorageConfig,
|
||||
StreamConfig,
|
||||
SystemOutput,
|
||||
no_leading_slash_validator,
|
||||
no_trailing_slash_validator,
|
||||
)
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
from typing import TYPE_CHECKING, Any
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union
|
||||
|
||||
# pylint: disable=no-name-in-module
|
||||
from pydantic import AnyHttpUrl, BaseModel, validator
|
||||
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, validator
|
||||
from typing_extensions import Annotated, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.typing import AnyClassMethod
|
||||
|
@ -17,6 +19,20 @@ def no_trailing_slash_validator(key: str) -> "AnyClassMethod":
|
|||
return validator(key, pre=True, allow_reuse=True)(strip_trailing_slash)
|
||||
|
||||
|
||||
def no_leading_slash_validator(key: str) -> "AnyClassMethod":
|
||||
# pylint: disable=unused-argument
|
||||
def strip_leading_slash(cls: Any, value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
return value.lstrip("/")
|
||||
return value
|
||||
|
||||
return validator(key, pre=True, allow_reuse=True)(strip_leading_slash)
|
||||
|
||||
|
||||
# GeneralConfig
|
||||
########################################################################################
|
||||
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
class GeneralConfig(BaseModel):
|
||||
public_url: AnyHttpUrl
|
||||
|
@ -26,6 +42,9 @@ class GeneralConfig(BaseModel):
|
|||
_public_url_no_trailing_slash = no_trailing_slash_validator("public_url")
|
||||
|
||||
|
||||
# StorageConfig
|
||||
########################################################################################
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
class StorageConfig(BaseModel):
|
||||
path: str = "/srv/libretime"
|
||||
|
@ -34,6 +53,9 @@ class StorageConfig(BaseModel):
|
|||
_path_no_trailing_slash = no_trailing_slash_validator("path")
|
||||
|
||||
|
||||
# DatabaseConfig
|
||||
########################################################################################
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
class DatabaseConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
|
@ -50,6 +72,9 @@ class DatabaseConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
# RabbitMQConfig
|
||||
########################################################################################
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
class RabbitMQConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
|
@ -64,3 +89,163 @@ class RabbitMQConfig(BaseModel):
|
|||
f"amqp://{self.user}:{self.password}"
|
||||
f"@{self.host}:{self.port}/{self.vhost}"
|
||||
)
|
||||
|
||||
|
||||
# StreamConfig
|
||||
########################################################################################
|
||||
|
||||
|
||||
class BaseInput(BaseModel):
|
||||
enabled: bool = True
|
||||
public_url: Optional[AnyUrl] = None
|
||||
|
||||
|
||||
class InputKind(str, Enum):
|
||||
HARBOR = "harbor"
|
||||
|
||||
|
||||
class HarborInput(BaseInput):
|
||||
kind: Literal[InputKind.HARBOR] = InputKind.HARBOR
|
||||
mount: str
|
||||
port: int
|
||||
|
||||
_mount_no_leading_slash = no_leading_slash_validator("mount")
|
||||
|
||||
|
||||
class MainHarborInput(HarborInput):
|
||||
mount: str = "main"
|
||||
port: int = 8001
|
||||
|
||||
|
||||
class ShowHarborInput(HarborInput):
|
||||
mount: str = "show"
|
||||
port: int = 8002
|
||||
|
||||
|
||||
class Inputs(BaseModel):
|
||||
main: HarborInput = MainHarborInput()
|
||||
show: HarborInput = ShowHarborInput()
|
||||
|
||||
|
||||
class AudioChannels(str, Enum):
|
||||
STEREO = "stereo"
|
||||
MONO = "mono"
|
||||
|
||||
|
||||
class BaseAudio(BaseModel):
|
||||
channels: AudioChannels = AudioChannels.STEREO
|
||||
bitrate: int
|
||||
|
||||
@validator("bitrate")
|
||||
@classmethod
|
||||
def _validate_bitrate(cls, value: int) -> int:
|
||||
# Once the liquidsoap script generation supports it, fine tune
|
||||
# the bitrate validation for each format
|
||||
bitrates = (32, 48, 64, 96, 128, 160, 192, 224, 256, 320)
|
||||
if value not in bitrates:
|
||||
raise ValueError(f"invalid bitrate {value}, must be one of {bitrates}")
|
||||
return value
|
||||
|
||||
|
||||
class AudioFormat(str, Enum):
|
||||
AAC = "aac"
|
||||
MP3 = "mp3"
|
||||
OGG = "ogg"
|
||||
OPUS = "opus"
|
||||
|
||||
|
||||
class AudioAAC(BaseAudio):
|
||||
format: Literal[AudioFormat.AAC] = AudioFormat.AAC
|
||||
|
||||
|
||||
class AudioMP3(BaseAudio):
|
||||
format: Literal[AudioFormat.MP3] = AudioFormat.MP3
|
||||
|
||||
|
||||
class AudioOGG(BaseAudio):
|
||||
format: Literal[AudioFormat.OGG] = AudioFormat.OGG
|
||||
enable_metadata: Optional[bool] = False
|
||||
|
||||
|
||||
class AudioOpus(BaseAudio):
|
||||
format: Literal[AudioFormat.OPUS] = AudioFormat.OPUS
|
||||
|
||||
|
||||
class IcecastOutput(BaseModel):
|
||||
kind: Literal["icecast"] = "icecast"
|
||||
enabled: bool = False
|
||||
public_url: Optional[AnyUrl] = None
|
||||
|
||||
host: str = "localhost"
|
||||
port: int = 8000
|
||||
mount: str
|
||||
source_user: str = "source"
|
||||
source_password: str
|
||||
admin_user: str = "admin"
|
||||
admin_password: Optional[str] = None
|
||||
|
||||
audio: Annotated[
|
||||
Union[AudioAAC, AudioMP3, AudioOGG, AudioOpus],
|
||||
Field(discriminator="format"),
|
||||
]
|
||||
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
website: Optional[str] = None
|
||||
genre: Optional[str] = None
|
||||
|
||||
_mount_no_leading_slash = no_leading_slash_validator("mount")
|
||||
|
||||
|
||||
class ShoutcastOutput(BaseModel):
|
||||
kind: Literal["shoutcast"] = "shoutcast"
|
||||
enabled: bool = False
|
||||
public_url: Optional[AnyUrl] = None
|
||||
|
||||
host: str = "localhost"
|
||||
port: int = 8000
|
||||
source_user: str = "source"
|
||||
source_password: str
|
||||
admin_user: str = "admin"
|
||||
admin_password: Optional[str] = None
|
||||
|
||||
audio: Annotated[
|
||||
Union[AudioAAC, AudioMP3],
|
||||
Field(discriminator="format"),
|
||||
]
|
||||
|
||||
name: Optional[str] = None
|
||||
website: Optional[str] = None
|
||||
genre: Optional[str] = None
|
||||
|
||||
|
||||
class SystemOutputKind(str, Enum):
|
||||
ALSA = "alsa"
|
||||
AO = "ao"
|
||||
OSS = "oss"
|
||||
PORTAUDIO = "portaudio"
|
||||
PULSEAUDIO = "pulseaudio"
|
||||
|
||||
|
||||
class SystemOutput(BaseModel):
|
||||
enabled: bool = False
|
||||
kind: SystemOutputKind = SystemOutputKind.ALSA
|
||||
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
class Outputs(BaseModel):
|
||||
icecast: List[IcecastOutput] = Field([], max_items=3)
|
||||
shoutcast: List[ShoutcastOutput] = Field([], max_items=1)
|
||||
system: List[SystemOutput] = Field([], max_items=1)
|
||||
|
||||
@property
|
||||
def merged(self) -> Sequence[Union[IcecastOutput, ShoutcastOutput]]:
|
||||
return self.icecast + self.shoutcast # type: ignore
|
||||
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
class StreamConfig(BaseModel):
|
||||
"""Stream configuration model."""
|
||||
|
||||
inputs: Inputs = Inputs()
|
||||
outputs: Outputs = Outputs()
|
||||
|
|
|
@ -1,18 +1,25 @@
|
|||
from os import environ
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
from unittest import mock
|
||||
|
||||
from pydantic import AnyHttpUrl, BaseModel
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||
from pytest import mark, raises
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from libretime_shared.config import (
|
||||
BaseConfig,
|
||||
DatabaseConfig,
|
||||
IcecastOutput,
|
||||
RabbitMQConfig,
|
||||
ShoutcastOutput,
|
||||
no_trailing_slash_validator,
|
||||
)
|
||||
|
||||
AnyOutput = Annotated[
|
||||
Union[IcecastOutput, ShoutcastOutput],
|
||||
Field(discriminator="kind"),
|
||||
]
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
class FixtureConfig(BaseConfig):
|
||||
|
@ -21,6 +28,7 @@ class FixtureConfig(BaseConfig):
|
|||
allowed_hosts: List[str] = []
|
||||
database: DatabaseConfig
|
||||
rabbitmq: RabbitMQConfig = RabbitMQConfig()
|
||||
outputs: List[AnyOutput]
|
||||
|
||||
# Validators
|
||||
_public_url_no_trailing_slash = no_trailing_slash_validator("public_url")
|
||||
|
@ -39,6 +47,17 @@ database:
|
|||
port: 5432
|
||||
|
||||
ignored: "ignored"
|
||||
|
||||
outputs:
|
||||
- enabled: true
|
||||
kind: icecast
|
||||
host: localhost
|
||||
port: 8000
|
||||
mount: main.ogg
|
||||
source_password: hackme
|
||||
audio:
|
||||
format: ogg
|
||||
bitrate: 256
|
||||
"""
|
||||
|
||||
|
||||
|
@ -54,6 +73,8 @@ def test_base_config(tmp_path: Path):
|
|||
"LIBRETIME_DATABASE": "invalid",
|
||||
"LIBRETIME_RABBITMQ": "invalid",
|
||||
"LIBRETIME_RABBITMQ_HOST": "changed",
|
||||
"LIBRETIME_OUTPUTS_0_ENABLED": "false",
|
||||
"LIBRETIME_OUTPUTS_0_HOST": "changed",
|
||||
"WRONGPREFIX_API_KEY": "invalid",
|
||||
},
|
||||
):
|
||||
|
@ -66,6 +87,10 @@ def test_base_config(tmp_path: Path):
|
|||
assert config.database.port == 8888
|
||||
assert config.rabbitmq.host == "changed"
|
||||
assert config.rabbitmq.port == 5672
|
||||
assert config.outputs[0].enabled is False
|
||||
assert config.outputs[0].kind == "icecast"
|
||||
assert config.outputs[0].host == "changed"
|
||||
assert config.outputs[0].audio.format == "ogg"
|
||||
|
||||
# Optional model: loading default values (rabbitmq)
|
||||
with mock.patch.dict(environ, {}):
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from libretime_shared.config._models import (
|
||||
AudioAAC,
|
||||
AudioMP3,
|
||||
AudioOGG,
|
||||
AudioOpus,
|
||||
StreamConfig,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"audio",
|
||||
[
|
||||
(AudioAAC),
|
||||
(AudioMP3),
|
||||
(AudioOGG),
|
||||
(AudioOpus),
|
||||
],
|
||||
)
|
||||
def test_audio(audio):
|
||||
audio(bitrate=32)
|
||||
audio(bitrate=320)
|
||||
with pytest.raises(ValidationError):
|
||||
audio(bitrate=11)
|
||||
with pytest.raises(ValidationError):
|
||||
audio(bitrate=321)
|
||||
|
||||
|
||||
def test_stream_config():
|
||||
icecast_output = {
|
||||
"mount": "mount",
|
||||
"source_password": "hackme",
|
||||
"audio": {"format": "ogg", "bitrate": 256},
|
||||
}
|
||||
assert StreamConfig(outputs={"icecast": [icecast_output] * 3})
|
||||
with pytest.raises(ValidationError):
|
||||
StreamConfig(outputs={"icecast": [icecast_output] * 4})
|
||||
|
||||
shoutcast_output = {
|
||||
"source_password": "hackme",
|
||||
"audio": {"format": "mp3", "bitrate": 256},
|
||||
}
|
||||
assert StreamConfig(outputs={"shoutcast": [shoutcast_output]})
|
||||
with pytest.raises(ValidationError):
|
||||
StreamConfig(outputs={"shoutcast": [shoutcast_output] * 2})
|
||||
|
||||
system_output = {
|
||||
"kind": "alsa",
|
||||
}
|
||||
assert StreamConfig(outputs={"system": [system_output]})
|
||||
with pytest.raises(ValidationError):
|
||||
StreamConfig(outputs={"system": [system_output] * 2})
|
||||
|
||||
config = StreamConfig(
|
||||
outputs={
|
||||
"icecast": [icecast_output],
|
||||
"shoutcast": [shoutcast_output],
|
||||
"system": [system_output],
|
||||
}
|
||||
)
|
||||
assert len(config.outputs.icecast) == 1
|
||||
assert len(config.outputs.shoutcast) == 1
|
||||
assert len(config.outputs.system) == 1
|
Loading…
Reference in New Issue