feat(api_client): load config using shared helpers

This commit is contained in:
jo 2022-02-22 18:19:16 +01:00 committed by Kyle Robbertze
parent ba0897a023
commit d42615eb6a
9 changed files with 88 additions and 179 deletions

View File

@ -2,7 +2,7 @@ all: lint test
include ../tools/python.mk include ../tools/python.mk
PIP_INSTALL := --editable . PIP_INSTALL := --editable .[dev]
PYLINT_ARG := libretime_api_client tests || true PYLINT_ARG := libretime_api_client tests || true
MYPY_ARG := libretime_api_client tests || true MYPY_ARG := libretime_api_client tests || true
BANDIT_ARG := libretime_api_client || true BANDIT_ARG := libretime_api_client || true

View File

@ -0,0 +1,5 @@
from libretime_shared.config import BaseConfig, GeneralConfig
class Config(BaseConfig):
general: GeneralConfig

View File

@ -1,26 +1,11 @@
import datetime import datetime
import json
import logging import logging
import socket
from time import sleep from time import sleep
import requests import requests
from requests.auth import AuthBase from requests.auth import AuthBase
def get_protocol(config):
positive_values = ["Yes", "yes", "True", "true", True]
port = config["general"].get("base_port", 80)
force_ssl = config["general"].get("force_ssl", False)
if force_ssl in positive_values:
protocol = "https"
else:
protocol = config["general"].get("protocol")
if not protocol:
protocol = str(("http", "https")[int(port) == 443])
return protocol
class UrlParamDict(dict): class UrlParamDict(dict):
def __missing__(self, key): def __missing__(self, key):
return "{" + key + "}" return "{" + key + "}"
@ -151,38 +136,21 @@ class ApiRequest:
class RequestProvider: class RequestProvider:
"""Creates the available ApiRequest instance that can be read from """
a config file""" Creates the available ApiRequest instance
"""
def __init__(self, cfg, endpoints): def __init__(self, base_url: str, api_key: str, endpoints: dict):
self.config = cfg
self.requests = {} self.requests = {}
if self.config["general"]["base_dir"].startswith("/"): self.url = ApcUrl(base_url + "/{action}")
self.config["general"]["base_dir"] = self.config["general"]["base_dir"][1:]
protocol = get_protocol(self.config)
base_port = self.config["general"]["base_port"]
base_url = self.config["general"]["base_url"]
base_dir = self.config["general"]["base_dir"]
api_base = self.config["api_base"]
api_url = "{protocol}://{base_url}:{base_port}/{base_dir}{api_base}/{action}".format_map(
UrlParamDict(
protocol=protocol,
base_url=base_url,
base_port=base_port,
base_dir=base_dir,
api_base=api_base,
)
)
self.url = ApcUrl(api_url)
# Now we must discover the possible actions # Now we must discover the possible actions
for action_name, action_value in endpoints.items(): for action_name, action_value in endpoints.items():
new_url = self.url.params(action=action_value) new_url = self.url.params(action=action_value)
if "{api_key}" in action_value: if "{api_key}" in action_value:
new_url = new_url.params(api_key=self.config["general"]["api_key"]) new_url = new_url.params(api_key=api_key)
self.requests[action_name] = ApiRequest( self.requests[action_name] = ApiRequest(
action_name, new_url, api_key=self.config["general"]["api_key"] action_name, new_url, api_key=api_key
) )
def available_requests(self): def available_requests(self):

View File

@ -9,20 +9,18 @@
import base64 import base64
import json import json
import logging import logging
import sys
import time import time
import traceback import traceback
import urllib.parse import urllib.parse
import requests import requests
from configobj import ConfigObj
from .utils import ApiRequest, RequestProvider, get_protocol from ._config import Config
from .utils import ApiRequest, RequestProvider
AIRTIME_API_VERSION = "1.1" AIRTIME_API_VERSION = "1.1"
api_config = {}
api_endpoints = {} api_endpoints = {}
# URL to get the version number of the server API # URL to get the version number of the server API
@ -67,8 +65,6 @@ api_endpoints[
# show-recorder # show-recorder
api_endpoints["show_schedule_url"] = "recorded-shows/format/json/api_key/{api_key}" api_endpoints["show_schedule_url"] = "recorded-shows/format/json/api_key/{api_key}"
api_endpoints["upload_file_url"] = "rest/media" api_endpoints["upload_file_url"] = "rest/media"
api_endpoints["upload_retries"] = "3"
api_endpoints["upload_wait"] = "60"
# pypo # pypo
api_endpoints["export_url"] = "schedule/api_key/{api_key}" api_endpoints["export_url"] = "schedule/api_key/{api_key}"
api_endpoints["get_media_url"] = "get-media/file/{file}/api_key/{api_key}" api_endpoints["get_media_url"] = "get-media/file/{file}/api_key/{api_key}"
@ -119,28 +115,28 @@ api_endpoints[
api_endpoints[ api_endpoints[
"update_metadata_on_tunein" "update_metadata_on_tunein"
] = "update-metadata-on-tunein/api_key/{api_key}" ] = "update-metadata-on-tunein/api_key/{api_key}"
api_config["api_base"] = "api"
api_config["bin_dir"] = "/usr/lib/airtime/api_clients/"
################################################################################ ################################################################################
# Airtime API Version 1 Client # Airtime API Version 1 Client
################################################################################ ################################################################################
class AirtimeApiClient: class AirtimeApiClient:
def __init__(self, logger=None, config_path="/etc/airtime/airtime.conf"): API_BASE = "/api"
if logger is None: UPLOAD_RETRIES = 3
self.logger = logging UPLOAD_WAIT = 60
else:
self.logger = logger
# loading config file def __init__(self, logger=None, config_path="/etc/airtime/airtime.conf"):
try: self.logger = logger or logging
self.config = ConfigObj(config_path)
self.config.update(api_config) config = Config(filepath=config_path)
self.services = RequestProvider(self.config, api_endpoints) self.base_url = config.general.public_url
except Exception as e: self.api_key = config.general.api_key
self.logger.exception("Error loading config file: %s", config_path)
sys.exit(1) self.services = RequestProvider(
base_url=self.base_url + self.API_BASE,
api_key=self.api_key,
endpoints=api_endpoints,
)
def __get_airtime_version(self): def __get_airtime_version(self):
try: try:
@ -213,8 +209,8 @@ class AirtimeApiClient:
logger = self.logger logger = self.logger
response = "" response = ""
retries = int(self.config["upload_retries"]) retries = self.UPLOAD_RETRIES
retries_wait = int(self.config["upload_wait"]) retries_wait = self.UPLOAD_WAIT
url = self.construct_rest_url("upload_file_url") url = self.construct_rest_url("upload_file_url")
@ -276,37 +272,13 @@ class AirtimeApiClient:
self.logger.exception(e) self.logger.exception(e)
return {} return {}
def construct_url(self, config_action_key): def construct_rest_url(self, action_key):
"""Constructs the base url for every request""" """
# TODO : Make other methods in this class use this this method. Constructs the base url for RESTful requests
if self.config["general"]["base_dir"].startswith("/"): """
self.config["general"]["base_dir"] = self.config["general"]["base_dir"][1:] url = urllib.parse.urlsplit(self.base_url)
protocol = get_protocol(self.config) url.username = self.api_key
url = "{}://{}:{}/{}{}/{}".format( return f"{url.geturl()}/{api_endpoints[action_key]}"
protocol,
self.config["general"]["base_url"],
str(self.config["general"]["base_port"]),
self.config["general"]["base_dir"],
self.config["api_base"],
self.config[config_action_key],
)
url = url.replace("%%api_key%%", self.config["general"]["api_key"])
return url
def construct_rest_url(self, config_action_key):
"""Constructs the base url for RESTful requests"""
if self.config["general"]["base_dir"].startswith("/"):
self.config["general"]["base_dir"] = self.config["general"]["base_dir"][1:]
protocol = get_protocol(self.config)
url = "{}://{}:@{}:{}/{}/{}".format(
protocol,
self.config["general"]["api_key"],
self.config["general"]["base_url"],
str(self.config["general"]["base_port"]),
self.config["general"]["base_dir"],
self.config[config_action_key],
)
return url
""" """
Caller of this method needs to catch any exceptions such as Caller of this method needs to catch any exceptions such as

View File

@ -7,17 +7,15 @@
# schedule a playlist one minute from the current time. # schedule a playlist one minute from the current time.
############################################################################### ###############################################################################
import logging import logging
import sys
from datetime import datetime, timedelta from datetime import datetime, timedelta
from configobj import ConfigObj
from dateutil.parser import isoparse from dateutil.parser import isoparse
from ._config import Config
from .utils import RequestProvider, fromisoformat, time_in_milliseconds, time_in_seconds from .utils import RequestProvider, fromisoformat, time_in_milliseconds, time_in_seconds
LIBRETIME_API_VERSION = "2.0" LIBRETIME_API_VERSION = "2.0"
api_config = {}
api_endpoints = {} api_endpoints = {}
api_endpoints["version_url"] = "version/" api_endpoints["version_url"] = "version/"
@ -27,23 +25,23 @@ api_endpoints["show_instance_url"] = "show-instances/{id}/"
api_endpoints["show_url"] = "shows/{id}/" api_endpoints["show_url"] = "shows/{id}/"
api_endpoints["file_url"] = "files/{id}/" api_endpoints["file_url"] = "files/{id}/"
api_endpoints["file_download_url"] = "files/{id}/download/" api_endpoints["file_download_url"] = "files/{id}/download/"
api_config["api_base"] = "api/v2"
class AirtimeApiClient: class AirtimeApiClient:
def __init__(self, logger=None, config_path="/etc/airtime/airtime.conf"): API_BASE = "/api/v2"
if logger is None:
self.logger = logging
else:
self.logger = logger
try: def __init__(self, logger=None, config_path="/etc/airtime/airtime.conf"):
self.config = ConfigObj(config_path) self.logger = logger or logging
self.config.update(api_config)
self.services = RequestProvider(self.config, api_endpoints) config = Config(filepath=config_path)
except Exception as e: self.base_url = config.general.public_url
self.logger.exception("Error loading config file: %s", config_path) self.api_key = config.general.api_key
sys.exit(1)
self.services = RequestProvider(
base_url=self.base_url + self.API_BASE,
api_key=self.api_key,
endpoints=api_endpoints,
)
def get_schedule(self): def get_schedule(self):
current_time = datetime.utcnow() current_time = datetime.utcnow()

View File

@ -22,9 +22,13 @@ setup(
packages=["libretime_api_client"], packages=["libretime_api_client"],
python_requires=">=3.6", python_requires=">=3.6",
install_requires=[ install_requires=[
"configobj",
"python-dateutil>=2.7.0", "python-dateutil>=2.7.0",
"requests", "requests",
], ],
extras_require={
"dev": [
f"libretime-shared @ file://localhost{here.parent / 'shared'}",
],
},
zip_safe=False, zip_safe=False,
) )

View File

@ -1,34 +1,26 @@
import pytest
from libretime_api_client.utils import RequestProvider from libretime_api_client.utils import RequestProvider
from libretime_api_client.version1 import api_config
@pytest.fixture() def test_request_provider_init():
def config(): request_provider = RequestProvider(
return { base_url="http://localhost/test",
**api_config, api_key="test_key",
"general": { endpoints={},
"base_dir": "/test", )
"base_port": 80,
"base_url": "localhost",
"api_key": "TEST_KEY",
},
"api_base": "api",
}
def test_request_provider_init(config):
request_provider = RequestProvider(config, {})
assert len(request_provider.available_requests()) == 0 assert len(request_provider.available_requests()) == 0
def test_request_provider_contains(config): def test_request_provider_contains():
endpoints = { endpoints = {
"upload_recorded": "/1/", "upload_recorded": "/1/",
"update_media_url": "/2/", "update_media_url": "/2/",
"list_all_db_files": "/3/", "list_all_db_files": "/3/",
} }
request_provider = RequestProvider(config, endpoints) request_provider = RequestProvider(
base_url="http://localhost/test",
api_key="test_key",
endpoints=endpoints,
)
for endpoint in endpoints: for endpoint in endpoints:
assert endpoint in request_provider.requests assert endpoint in request_provider.requests

View File

@ -1,5 +1,4 @@
import datetime import datetime
from configparser import ConfigParser
import pytest import pytest
@ -16,37 +15,6 @@ def test_time_in_milliseconds():
assert utils.time_in_milliseconds(time) == 500 assert utils.time_in_milliseconds(time) == 500
@pytest.mark.parametrize(
"payload, expected",
[({}, "http"), ({"base_port": 80}, "http"), ({"base_port": 443}, "https")],
)
@pytest.mark.parametrize(
"use_config",
[False, True],
)
def test_get_protocol(payload, use_config, expected):
config = ConfigParser() if use_config else {}
config["general"] = {**payload}
assert utils.get_protocol(config) == expected
@pytest.mark.parametrize("payload", [{}, {"base_port": 80}])
@pytest.mark.parametrize("use_config", [False, True])
@pytest.mark.parametrize(
"values, expected",
[
(["yes", "Yes", "True", "true", True], "https"),
(["no", "No", "False", "false", False], "http"),
],
)
def test_get_protocol_force_https(payload, use_config, values, expected):
for value in values:
config = ConfigParser() if use_config else {}
config["general"] = {**payload, "force_ssl": value}
assert utils.get_protocol(config) == expected
@pytest.mark.parametrize( @pytest.mark.parametrize(
"payload, expected", "payload, expected",
[ [

View File

@ -1,21 +1,23 @@
from pathlib import Path
import pytest import pytest
from libretime_api_client.utils import RequestProvider from libretime_api_client.version2 import AirtimeApiClient
from libretime_api_client.version2 import AirtimeApiClient, api_config
@pytest.fixture() @pytest.fixture()
def config(): def config_filepath(tmp_path: Path):
return { filepath = tmp_path / "airtime.conf"
**api_config, filepath.write_text(
"general": { """
"base_dir": "/test", [general]
"base_port": 80, api_key = TEST_KEY
"base_url": "localhost", base_dir = /test
"api_key": "TEST_KEY", base_port = 80
}, base_url = localhost
"api_base": "api", """
} )
return filepath
class MockRequestProvider: class MockRequestProvider:
@ -82,8 +84,8 @@ class MockRequestProvider:
} }
def test_get_schedule(monkeypatch, config): def test_get_schedule(monkeypatch, config_filepath):
client = AirtimeApiClient(None, config) client = AirtimeApiClient(config_path=config_filepath)
client.services = MockRequestProvider() client.services = MockRequestProvider()
schedule = client.get_schedule() schedule = client.get_schedule()
assert schedule == { assert schedule == {