From d42615eb6a70ffe45e40849b134590ec292d479a Mon Sep 17 00:00:00 2001 From: jo Date: Tue, 22 Feb 2022 18:19:16 +0100 Subject: [PATCH] feat(api_client): load config using shared helpers --- api_client/Makefile | 2 +- api_client/libretime_api_client/_config.py | 5 ++ api_client/libretime_api_client/utils.py | 46 ++---------- api_client/libretime_api_client/version1.py | 80 +++++++-------------- api_client/libretime_api_client/version2.py | 30 ++++---- api_client/setup.py | 6 +- api_client/tests/requestprovider_test.py | 34 ++++----- api_client/tests/utils_test.py | 32 --------- api_client/tests/version2_test.py | 32 +++++---- 9 files changed, 88 insertions(+), 179 deletions(-) create mode 100644 api_client/libretime_api_client/_config.py diff --git a/api_client/Makefile b/api_client/Makefile index 83b229d64..35f8155fb 100644 --- a/api_client/Makefile +++ b/api_client/Makefile @@ -2,7 +2,7 @@ all: lint test include ../tools/python.mk -PIP_INSTALL := --editable . +PIP_INSTALL := --editable .[dev] PYLINT_ARG := libretime_api_client tests || true MYPY_ARG := libretime_api_client tests || true BANDIT_ARG := libretime_api_client || true diff --git a/api_client/libretime_api_client/_config.py b/api_client/libretime_api_client/_config.py new file mode 100644 index 000000000..b2aeecf01 --- /dev/null +++ b/api_client/libretime_api_client/_config.py @@ -0,0 +1,5 @@ +from libretime_shared.config import BaseConfig, GeneralConfig + + +class Config(BaseConfig): + general: GeneralConfig diff --git a/api_client/libretime_api_client/utils.py b/api_client/libretime_api_client/utils.py index 9659a3416..ad20784aa 100644 --- a/api_client/libretime_api_client/utils.py +++ b/api_client/libretime_api_client/utils.py @@ -1,26 +1,11 @@ import datetime -import json import logging -import socket from time import sleep import requests from requests.auth import AuthBase -def get_protocol(config): - positive_values = ["Yes", "yes", "True", "true", True] - port = config["general"].get("base_port", 80) - force_ssl = config["general"].get("force_ssl", False) - if force_ssl in positive_values: - protocol = "https" - else: - protocol = config["general"].get("protocol") - if not protocol: - protocol = str(("http", "https")[int(port) == 443]) - return protocol - - class UrlParamDict(dict): def __missing__(self, key): return "{" + key + "}" @@ -151,38 +136,21 @@ class ApiRequest: class RequestProvider: - """Creates the available ApiRequest instance that can be read from - a config file""" + """ + Creates the available ApiRequest instance + """ - def __init__(self, cfg, endpoints): - self.config = cfg + def __init__(self, base_url: str, api_key: str, endpoints: dict): self.requests = {} - if self.config["general"]["base_dir"].startswith("/"): - self.config["general"]["base_dir"] = self.config["general"]["base_dir"][1:] - - protocol = get_protocol(self.config) - base_port = self.config["general"]["base_port"] - base_url = self.config["general"]["base_url"] - base_dir = self.config["general"]["base_dir"] - api_base = self.config["api_base"] - api_url = "{protocol}://{base_url}:{base_port}/{base_dir}{api_base}/{action}".format_map( - UrlParamDict( - protocol=protocol, - base_url=base_url, - base_port=base_port, - base_dir=base_dir, - api_base=api_base, - ) - ) - self.url = ApcUrl(api_url) + self.url = ApcUrl(base_url + "/{action}") # Now we must discover the possible actions for action_name, action_value in endpoints.items(): new_url = self.url.params(action=action_value) if "{api_key}" in action_value: - new_url = new_url.params(api_key=self.config["general"]["api_key"]) + new_url = new_url.params(api_key=api_key) self.requests[action_name] = ApiRequest( - action_name, new_url, api_key=self.config["general"]["api_key"] + action_name, new_url, api_key=api_key ) def available_requests(self): diff --git a/api_client/libretime_api_client/version1.py b/api_client/libretime_api_client/version1.py index e113b564c..f277c3f6d 100644 --- a/api_client/libretime_api_client/version1.py +++ b/api_client/libretime_api_client/version1.py @@ -9,20 +9,18 @@ import base64 import json import logging -import sys import time import traceback import urllib.parse import requests -from configobj import ConfigObj -from .utils import ApiRequest, RequestProvider, get_protocol +from ._config import Config +from .utils import ApiRequest, RequestProvider AIRTIME_API_VERSION = "1.1" -api_config = {} api_endpoints = {} # URL to get the version number of the server API @@ -67,8 +65,6 @@ api_endpoints[ # show-recorder api_endpoints["show_schedule_url"] = "recorded-shows/format/json/api_key/{api_key}" api_endpoints["upload_file_url"] = "rest/media" -api_endpoints["upload_retries"] = "3" -api_endpoints["upload_wait"] = "60" # pypo api_endpoints["export_url"] = "schedule/api_key/{api_key}" api_endpoints["get_media_url"] = "get-media/file/{file}/api_key/{api_key}" @@ -119,28 +115,28 @@ api_endpoints[ api_endpoints[ "update_metadata_on_tunein" ] = "update-metadata-on-tunein/api_key/{api_key}" -api_config["api_base"] = "api" -api_config["bin_dir"] = "/usr/lib/airtime/api_clients/" ################################################################################ # Airtime API Version 1 Client ################################################################################ class AirtimeApiClient: - def __init__(self, logger=None, config_path="/etc/airtime/airtime.conf"): - if logger is None: - self.logger = logging - else: - self.logger = logger + API_BASE = "/api" + UPLOAD_RETRIES = 3 + UPLOAD_WAIT = 60 - # loading config file - try: - self.config = ConfigObj(config_path) - self.config.update(api_config) - self.services = RequestProvider(self.config, api_endpoints) - except Exception as e: - self.logger.exception("Error loading config file: %s", config_path) - sys.exit(1) + def __init__(self, logger=None, config_path="/etc/airtime/airtime.conf"): + self.logger = logger or logging + + config = Config(filepath=config_path) + self.base_url = config.general.public_url + self.api_key = config.general.api_key + + self.services = RequestProvider( + base_url=self.base_url + self.API_BASE, + api_key=self.api_key, + endpoints=api_endpoints, + ) def __get_airtime_version(self): try: @@ -213,8 +209,8 @@ class AirtimeApiClient: logger = self.logger response = "" - retries = int(self.config["upload_retries"]) - retries_wait = int(self.config["upload_wait"]) + retries = self.UPLOAD_RETRIES + retries_wait = self.UPLOAD_WAIT url = self.construct_rest_url("upload_file_url") @@ -276,37 +272,13 @@ class AirtimeApiClient: self.logger.exception(e) return {} - def construct_url(self, config_action_key): - """Constructs the base url for every request""" - # TODO : Make other methods in this class use this this method. - if self.config["general"]["base_dir"].startswith("/"): - self.config["general"]["base_dir"] = self.config["general"]["base_dir"][1:] - protocol = get_protocol(self.config) - url = "{}://{}:{}/{}{}/{}".format( - protocol, - self.config["general"]["base_url"], - str(self.config["general"]["base_port"]), - self.config["general"]["base_dir"], - self.config["api_base"], - self.config[config_action_key], - ) - url = url.replace("%%api_key%%", self.config["general"]["api_key"]) - return url - - def construct_rest_url(self, config_action_key): - """Constructs the base url for RESTful requests""" - if self.config["general"]["base_dir"].startswith("/"): - self.config["general"]["base_dir"] = self.config["general"]["base_dir"][1:] - protocol = get_protocol(self.config) - url = "{}://{}:@{}:{}/{}/{}".format( - protocol, - self.config["general"]["api_key"], - self.config["general"]["base_url"], - str(self.config["general"]["base_port"]), - self.config["general"]["base_dir"], - self.config[config_action_key], - ) - return url + def construct_rest_url(self, action_key): + """ + Constructs the base url for RESTful requests + """ + url = urllib.parse.urlsplit(self.base_url) + url.username = self.api_key + return f"{url.geturl()}/{api_endpoints[action_key]}" """ Caller of this method needs to catch any exceptions such as diff --git a/api_client/libretime_api_client/version2.py b/api_client/libretime_api_client/version2.py index e2826338b..12bd702bd 100644 --- a/api_client/libretime_api_client/version2.py +++ b/api_client/libretime_api_client/version2.py @@ -7,17 +7,15 @@ # schedule a playlist one minute from the current time. ############################################################################### import logging -import sys from datetime import datetime, timedelta -from configobj import ConfigObj from dateutil.parser import isoparse +from ._config import Config from .utils import RequestProvider, fromisoformat, time_in_milliseconds, time_in_seconds LIBRETIME_API_VERSION = "2.0" -api_config = {} api_endpoints = {} api_endpoints["version_url"] = "version/" @@ -27,23 +25,23 @@ api_endpoints["show_instance_url"] = "show-instances/{id}/" api_endpoints["show_url"] = "shows/{id}/" api_endpoints["file_url"] = "files/{id}/" api_endpoints["file_download_url"] = "files/{id}/download/" -api_config["api_base"] = "api/v2" class AirtimeApiClient: - def __init__(self, logger=None, config_path="/etc/airtime/airtime.conf"): - if logger is None: - self.logger = logging - else: - self.logger = logger + API_BASE = "/api/v2" - try: - self.config = ConfigObj(config_path) - self.config.update(api_config) - self.services = RequestProvider(self.config, api_endpoints) - except Exception as e: - self.logger.exception("Error loading config file: %s", config_path) - sys.exit(1) + def __init__(self, logger=None, config_path="/etc/airtime/airtime.conf"): + self.logger = logger or logging + + config = Config(filepath=config_path) + self.base_url = config.general.public_url + self.api_key = config.general.api_key + + self.services = RequestProvider( + base_url=self.base_url + self.API_BASE, + api_key=self.api_key, + endpoints=api_endpoints, + ) def get_schedule(self): current_time = datetime.utcnow() diff --git a/api_client/setup.py b/api_client/setup.py index 3cc75c57f..817c443c9 100644 --- a/api_client/setup.py +++ b/api_client/setup.py @@ -22,9 +22,13 @@ setup( packages=["libretime_api_client"], python_requires=">=3.6", install_requires=[ - "configobj", "python-dateutil>=2.7.0", "requests", ], + extras_require={ + "dev": [ + f"libretime-shared @ file://localhost{here.parent / 'shared'}", + ], + }, zip_safe=False, ) diff --git a/api_client/tests/requestprovider_test.py b/api_client/tests/requestprovider_test.py index f8afc9b47..0bcd2fae5 100644 --- a/api_client/tests/requestprovider_test.py +++ b/api_client/tests/requestprovider_test.py @@ -1,34 +1,26 @@ -import pytest - from libretime_api_client.utils import RequestProvider -from libretime_api_client.version1 import api_config -@pytest.fixture() -def config(): - return { - **api_config, - "general": { - "base_dir": "/test", - "base_port": 80, - "base_url": "localhost", - "api_key": "TEST_KEY", - }, - "api_base": "api", - } - - -def test_request_provider_init(config): - request_provider = RequestProvider(config, {}) +def test_request_provider_init(): + request_provider = RequestProvider( + base_url="http://localhost/test", + api_key="test_key", + endpoints={}, + ) assert len(request_provider.available_requests()) == 0 -def test_request_provider_contains(config): +def test_request_provider_contains(): endpoints = { "upload_recorded": "/1/", "update_media_url": "/2/", "list_all_db_files": "/3/", } - request_provider = RequestProvider(config, endpoints) + request_provider = RequestProvider( + base_url="http://localhost/test", + api_key="test_key", + endpoints=endpoints, + ) + for endpoint in endpoints: assert endpoint in request_provider.requests diff --git a/api_client/tests/utils_test.py b/api_client/tests/utils_test.py index 61d4d840e..b8e39931c 100644 --- a/api_client/tests/utils_test.py +++ b/api_client/tests/utils_test.py @@ -1,5 +1,4 @@ import datetime -from configparser import ConfigParser import pytest @@ -16,37 +15,6 @@ def test_time_in_milliseconds(): assert utils.time_in_milliseconds(time) == 500 -@pytest.mark.parametrize( - "payload, expected", - [({}, "http"), ({"base_port": 80}, "http"), ({"base_port": 443}, "https")], -) -@pytest.mark.parametrize( - "use_config", - [False, True], -) -def test_get_protocol(payload, use_config, expected): - config = ConfigParser() if use_config else {} - config["general"] = {**payload} - - assert utils.get_protocol(config) == expected - - -@pytest.mark.parametrize("payload", [{}, {"base_port": 80}]) -@pytest.mark.parametrize("use_config", [False, True]) -@pytest.mark.parametrize( - "values, expected", - [ - (["yes", "Yes", "True", "true", True], "https"), - (["no", "No", "False", "false", False], "http"), - ], -) -def test_get_protocol_force_https(payload, use_config, values, expected): - for value in values: - config = ConfigParser() if use_config else {} - config["general"] = {**payload, "force_ssl": value} - assert utils.get_protocol(config) == expected - - @pytest.mark.parametrize( "payload, expected", [ diff --git a/api_client/tests/version2_test.py b/api_client/tests/version2_test.py index 81162fe48..86de06b8c 100644 --- a/api_client/tests/version2_test.py +++ b/api_client/tests/version2_test.py @@ -1,21 +1,23 @@ +from pathlib import Path + import pytest -from libretime_api_client.utils import RequestProvider -from libretime_api_client.version2 import AirtimeApiClient, api_config +from libretime_api_client.version2 import AirtimeApiClient @pytest.fixture() -def config(): - return { - **api_config, - "general": { - "base_dir": "/test", - "base_port": 80, - "base_url": "localhost", - "api_key": "TEST_KEY", - }, - "api_base": "api", - } +def config_filepath(tmp_path: Path): + filepath = tmp_path / "airtime.conf" + filepath.write_text( + """ +[general] +api_key = TEST_KEY +base_dir = /test +base_port = 80 +base_url = localhost +""" + ) + return filepath class MockRequestProvider: @@ -82,8 +84,8 @@ class MockRequestProvider: } -def test_get_schedule(monkeypatch, config): - client = AirtimeApiClient(None, config) +def test_get_schedule(monkeypatch, config_filepath): + client = AirtimeApiClient(config_path=config_filepath) client.services = MockRequestProvider() schedule = client.get_schedule() assert schedule == {