diff --git a/transip_client/cli.py b/transip_client/cli.py index 2aa1c1b..861f4bc 100644 --- a/transip_client/cli.py +++ b/transip_client/cli.py @@ -3,8 +3,7 @@ import click from transip_client.main import detect -DEFAULT_DNS = "myip.opendns.com" -DEFAULT_DNS_NAME = "@resolver1.opendns.com" +DEFAULT_SERVICE = "https://api.ipify.org" DEFAULT_API_URL = "https://api.transip.nl/v6" @@ -13,11 +12,18 @@ DEFAULT_API_URL = "https://api.transip.nl/v6" @click.option("--token", envvar="TOKEN") @click.option("--login", envvar="LOGIN") @click.option("--private-key-path", envvar="PRIVATE_KEY_PATH") -@click.option("--dns", envvar="DNS", default=DEFAULT_DNS) -@click.option("--dns-name", envvar="DNS_NAME", default=DEFAULT_DNS_NAME) +@click.option("--service", envvar="SERVICE", default=DEFAULT_SERVICE) @click.option("--api-url", envvar="API_URL", default=DEFAULT_API_URL) @click.option("--read-only/--write", envvar="READ_ONLY", default=False) -def run(domains, token, login, private_key_path, dns, dns_name, api_url, read_only): +def run( + domains: list[str], + token: str, + login: str, + private_key_path: str, + service: str, + api_url: str, + read_only: bool, +) -> None: if not domains: raise ValueError("No domain(s) specified") @@ -40,7 +46,7 @@ def run(domains, token, login, private_key_path, dns, dns_name, api_url, read_on detect( domains, - (dns, dns_name), + service, (private_key_path, login), token, api_url, diff --git a/transip_client/main.py b/transip_client/main.py index e4f8d1c..8f1d2ed 100644 --- a/transip_client/main.py +++ b/transip_client/main.py @@ -1,10 +1,10 @@ import base64 import json import logging -import subprocess import time from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Generator import requests @@ -17,19 +17,17 @@ from cryptography.hazmat.primitives.hashes import SHA512 logger = logging.getLogger(__name__) -def _get_ip(resolvers): +def _get_ip(service: str) -> str: try: - output = subprocess.check_output( - ["dig", "+short", *resolvers], - stderr=subprocess.STDOUT, - ) - except subprocess.CalledProcessError as e: - raise OSError("Unable to retrieve current IP") from e + response = requests.get(service, timeout=10) + response.raise_for_status() + except requests.RequestException as e: + raise OSError(f"Unable to retrieve current IP from {service}") from e - return output.decode("utf-8").strip() + return response.text -def _get_token(private_key_path, login, api_url): +def _get_token(private_key_path: str, login: str, api_url: str) -> str: request = requests.Request( "POST", f"{api_url}/auth", @@ -66,13 +64,12 @@ def _get_token(private_key_path, login, api_url): return response_data["token"] -def _get_domain(domain, token, api_url): +def _get_domain(domain: str, token: str, api_url: str) -> requests.Response: headers = {"Authorization": f"Bearer {token}"} - return requests.get(f"{api_url}/domains/{domain}/dns", headers=headers) -def _get_domain_data(domains, token, api_url): +def _get_domain_data(domains: list[str], token: str, api_url: str) -> Generator[dict]: with ThreadPoolExecutor(max_workers=10) as executor: futures = { executor.submit(_get_domain, domain, token, api_url): domain @@ -80,19 +77,21 @@ def _get_domain_data(domains, token, api_url): } for future in as_completed(futures): - response = future.result() domain = futures[future] try: + response = future.result() response.raise_for_status() - except requests.HTTPError as e: + except requests.HTTPError: logger.exception(f"Failed retrieving information for {domain}") continue yield {"domain": domain, **response.json()} -def _update_domain(domain, payload, api_url, token): +def _update_domain( + domain: str, payload: dict, api_url: str, token: str +) -> requests.Response: headers = {"Authorization": f"Bearer {token}"} return requests.put( @@ -100,7 +99,9 @@ def _update_domain(domain, payload, api_url, token): ) -def _update_domains(updated_domains, api_url, token, read_only): +def _update_domains( + updated_domains: dict, api_url: str, token: str, read_only: bool +) -> None: if read_only: return @@ -123,14 +124,21 @@ def _update_domains(updated_domains, api_url, token, read_only): logger.info(f"Updated domain {domain}") -def detect(domains, resolvers, credentials, token, api_url, read_only): - ip = _get_ip(resolvers) - updated_domains = {} +def detect( + domains: list[str], + service: str, + credentials: tuple[str, str], + token: str, + api_url: str, + read_only: bool, +) -> None: + ip = _get_ip(service) if all(credentials): token = _get_token(*credentials, api_url) domain_data = _get_domain_data(domains, token, api_url) + updated_domains = {} for data in domain_data: dns_entries = data["dnsEntries"] diff --git a/transip_client/tests/tests.py b/transip_client/tests/tests.py index 604e471..fc911d8 100644 --- a/transip_client/tests/tests.py +++ b/transip_client/tests/tests.py @@ -1,21 +1,18 @@ import json import os -from unittest import TestCase -from unittest.mock import call, patch +from unittest import TestCase, skip +from unittest.mock import call, patch, Mock from pathlib import Path from click.testing import CliRunner from requests import HTTPError -from transip_client.cli import DEFAULT_API_URL, run +from transip_client.cli import DEFAULT_API_URL, DEFAULT_SERVICE, run class RunTestCase(TestCase): def setUp(self): - patcher = patch("transip_client.main.subprocess.check_output") - self.mocked_dns = patcher.start() - patcher = patch("transip_client.main.requests.get") self.mocked_get = patcher.start() @@ -28,30 +25,36 @@ class RunTestCase(TestCase): self.runner = CliRunner() def test_simple(self): - self.mocked_dns.return_value = b"111.420\n" - self.mocked_get.return_value.json.return_value = { - "dnsEntries": [ - { - "name": "@", - "expire": 60, - "type": "A", - "content": "111.421", + self.mocked_get.side_effect = [ + Mock(text="111.420"), + Mock( + json=lambda: { + "dnsEntries": [ + { + "name": "@", + "expire": 60, + "type": "A", + "content": "111.421", + } + ], + "_links": [ + { + "rel": "self", + "link": "https://api.transip.nl/v6/domains/foobar.com/dns", + }, + { + "rel": "domain", + "link": "https://api.transip.nl/v6/domains/foobar.com", + }, + ], } - ], - "_links": [ - { - "rel": "self", - "link": "https://api.transip.nl/v6/domains/foobar.com/dns", - }, - { - "rel": "domain", - "link": "https://api.transip.nl/v6/domains/foobar.com", - }, - ], - } + ), + ] with self.assertLogs("transip_client.main", level="INFO") as logger: - result = self.runner.invoke(run, ["foobar.com"], env={"TOKEN": "token"}) + result = self.runner.invoke( + run, "foobar.com", env={"TOKEN": "token"}, catch_exceptions=False + ) self.assertEqual( logger.output, ["INFO:transip_client.main:Updated domain foobar.com"] @@ -84,11 +87,12 @@ class RunTestCase(TestCase): ) def test_error_response(self): - self.mocked_dns.return_value = b"111.420\n" - self.mocked_get.return_value.raise_for_status.side_effect = HTTPError + self.mocked_get.side_effect = [Mock(text="111.420"), HTTPError] with self.assertLogs("transip_client.main", level="INFO") as logger: - result = self.runner.invoke(run, ["foobar.com"], env={"TOKEN": "token"}) + result = self.runner.invoke( + run, ["foobar.com"], env={"TOKEN": "token"}, catch_exceptions=False + ) error_log = logger.output[0] @@ -103,27 +107,31 @@ class RunTestCase(TestCase): self.mocked_put.assert_not_called() def test_matching_ip(self): - self.mocked_dns.return_value = b"111.420\n" - self.mocked_get.return_value.json.return_value = { - "dnsEntries": [ - { - "name": "@", - "expire": 60, - "type": "A", - "content": "111.420", + self.mocked_get.side_effect = [ + Mock(text="111.420"), + Mock( + json=lambda: { + "dnsEntries": [ + { + "name": "@", + "expire": 60, + "type": "A", + "content": "111.420", + } + ], + "_links": [ + { + "rel": "self", + "link": "https://api.transip.nl/v6/domains/foobar.com/dns", + }, + { + "rel": "domain", + "link": "https://api.transip.nl/v6/domains/foobar.com", + }, + ], } - ], - "_links": [ - { - "rel": "self", - "link": "https://api.transip.nl/v6/domains/foobar.com/dns", - }, - { - "rel": "domain", - "link": "https://api.transip.nl/v6/domains/foobar.com", - }, - ], - } + ), + ] result = self.runner.invoke(run, ["foobar.com"], env={"TOKEN": "token"}) @@ -137,27 +145,31 @@ class RunTestCase(TestCase): self.mocked_put.assert_not_called() def test_readonly(self): - self.mocked_dns.return_value = b"111.420\n" - self.mocked_get.return_value.json.return_value = { - "dnsEntries": [ - { - "name": "@", - "expire": 60, - "type": "A", - "content": "111.421", + self.mocked_get.side_effect = [ + Mock(text="111.420"), + Mock( + json=lambda: { + "dnsEntries": [ + { + "name": "@", + "expire": 60, + "type": "A", + "content": "111.421", + } + ], + "_links": [ + { + "rel": "self", + "link": "https://api.transip.nl/v6/domains/foobar.com/dns", + }, + { + "rel": "domain", + "link": "https://api.transip.nl/v6/domains/foobar.com", + }, + ], } - ], - "_links": [ - { - "rel": "self", - "link": "https://api.transip.nl/v6/domains/foobar.com/dns", - }, - { - "rel": "domain", - "link": "https://api.transip.nl/v6/domains/foobar.com", - }, - ], - } + ), + ] result = self.runner.invoke( run, ["foobar.com", "--read-only"], env={"TOKEN": "token"} @@ -173,27 +185,31 @@ class RunTestCase(TestCase): self.mocked_put.assert_not_called() def test_different_api_url(self): - self.mocked_dns.return_value = b"111.420\n" - self.mocked_get.return_value.json.return_value = { - "dnsEntries": [ - { - "name": "@", - "expire": 60, - "type": "A", - "content": "111.421", + self.mocked_get.side_effect = [ + Mock(text="111.420"), + Mock( + json=lambda: { + "dnsEntries": [ + { + "name": "@", + "expire": 60, + "type": "A", + "content": "111.421", + } + ], + "_links": [ + { + "rel": "self", + "link": "https://api.transip.nl/v6/domains/foobar.com/dns", + }, + { + "rel": "domain", + "link": "https://api.transip.nl/v6/domains/foobar.com", + }, + ], } - ], - "_links": [ - { - "rel": "self", - "link": "https://api.transip.nl/v6/domains/foobar.com/dns", - }, - { - "rel": "domain", - "link": "https://api.transip.nl/v6/domains/foobar.com", - }, - ], - } + ), + ] with self.assertLogs("transip_client.main", level="INFO") as logger: result = self.runner.invoke( @@ -233,27 +249,31 @@ class RunTestCase(TestCase): ) def test_env_var(self): - self.mocked_dns.return_value = b"111.420\n" - self.mocked_get.return_value.json.return_value = { - "dnsEntries": [ - { - "name": "@", - "expire": 60, - "type": "A", - "content": "111.421", + self.mocked_get.side_effect = [ + Mock(text="111.420"), + Mock( + json=lambda: { + "dnsEntries": [ + { + "name": "@", + "expire": 60, + "type": "A", + "content": "111.421", + } + ], + "_links": [ + { + "rel": "self", + "link": "https://api.transip.nl/v6/domains/foobar.com/dns", + }, + { + "rel": "domain", + "link": "https://api.transip.nl/v6/domains/foobar.com", + }, + ], } - ], - "_links": [ - { - "rel": "self", - "link": "https://api.transip.nl/v6/domains/foobar.com/dns", - }, - { - "rel": "domain", - "link": "https://api.transip.nl/v6/domains/foobar.com", - }, - ], - } + ), + ] with self.assertLogs("transip_client.main", level="INFO") as logger: result = self.runner.invoke( @@ -296,48 +316,52 @@ class RunTestCase(TestCase): ) def test_multi_arg_env_var(self): - self.mocked_dns.return_value = b"111.420\n" - self.mocked_get.return_value.json.side_effect = [ - { - "dnsEntries": [ - { - "name": "@", - "expire": 60, - "type": "A", - "content": "111.421", - } - ], - "_links": [ - { - "rel": "self", - "link": "https://api.transip.nl/v6/domains/foobar.com/dns", - }, - { - "rel": "domain", - "link": "https://api.transip.nl/v6/domains/foobar.com", - }, - ], - }, - { - "dnsEntries": [ - { - "name": "@", - "expire": 60, - "type": "A", - "content": "111.421", - } - ], - "_links": [ - { - "rel": "self", - "link": "https://api.transip.nl/v6/domains/foofoo.com/dns", - }, - { - "rel": "domain", - "link": "https://api.transip.nl/v6/domains/foofoo.com", - }, - ], - }, + self.mocked_get.side_effect = [ + Mock(text="111.420"), + Mock( + json=lambda: { + "dnsEntries": [ + { + "name": "@", + "expire": 60, + "type": "A", + "content": "111.421", + } + ], + "_links": [ + { + "rel": "self", + "link": "https://api.transip.nl/v6/domains/foobar.com/dns", + }, + { + "rel": "domain", + "link": "https://api.transip.nl/v6/domains/foobar.com", + }, + ], + }, + ), + Mock( + json=lambda: { + "dnsEntries": [ + { + "name": "@", + "expire": 60, + "type": "A", + "content": "111.421", + } + ], + "_links": [ + { + "rel": "self", + "link": "https://api.transip.nl/v6/domains/foofoo.com/dns", + }, + { + "rel": "domain", + "link": "https://api.transip.nl/v6/domains/foofoo.com", + }, + ], + } + ), ] with self.assertLogs("transip_client.main", level="INFO") as logger: @@ -357,18 +381,15 @@ class RunTestCase(TestCase): self.assertEqual(result.exit_code, 0) expected_calls = [ + call(DEFAULT_SERVICE, timeout=10), call( f"{DEFAULT_API_URL}/domains/foobar.com/dns", headers={"Authorization": "Bearer token"}, ), - call().raise_for_status(), - call().json(), call( f"{DEFAULT_API_URL}/domains/foofoo.com/dns", headers={"Authorization": "Bearer token"}, ), - call().raise_for_status(), - call().json(), ] # use any_order because of the asynchronous requests @@ -393,13 +414,11 @@ class RunTestCase(TestCase): data=expected_json, headers={"Authorization": "Bearer token"}, ), - call().raise_for_status(), call( f"{DEFAULT_API_URL}/domains/foofoo.com/dns", data=expected_json, headers={"Authorization": "Bearer token"}, ), - call().raise_for_status(), ] self.mocked_put.assert_has_calls(expected_calls, any_order=True) @@ -420,7 +439,7 @@ class RunTestCase(TestCase): self.assertEqual( str(result.exception), "Either a token or a login name with a path to a private key need" - " to be specified" + " to be specified", ) self.mocked_get.assert_not_called() @@ -432,34 +451,38 @@ class RunTestCase(TestCase): self.assertEqual(result.exit_code, 1) self.assertEqual( str(result.exception), - "Both a login name and the path to a private key need to be specified" + "Both a login name and the path to a private key need to be specified", ) self.mocked_get.assert_not_called() self.mocked_put.assert_not_called() def test_login_with_private_key_path(self): - self.mocked_dns.return_value = b"111.420\n" - self.mocked_get.return_value.json.return_value = { - "dnsEntries": [ - { - "name": "@", - "expire": 60, - "type": "A", - "content": "111.421", + self.mocked_get.side_effect = [ + Mock(text="111.420"), + Mock( + json=lambda: { + "dnsEntries": [ + { + "name": "@", + "expire": 60, + "type": "A", + "content": "111.421", + } + ], + "_links": [ + { + "rel": "self", + "link": "https://api.transip.nl/v6/domains/foobar.com/dns", + }, + { + "rel": "domain", + "link": "https://api.transip.nl/v6/domains/foobar.com", + }, + ], } - ], - "_links": [ - { - "rel": "self", - "link": "https://api.transip.nl/v6/domains/foobar.com/dns", - }, - { - "rel": "domain", - "link": "https://api.transip.nl/v6/domains/foobar.com", - }, - ], - } + ), + ] self.mocked_session.return_value.json.return_value = {"token": "FOOBAR"} @@ -471,7 +494,7 @@ class RunTestCase(TestCase): result = self.runner.invoke( run, ["foobar.com"], - env={"LOGIN": "foo", "PRIVATE_KEY_PATH": str(private_key_path)} + env={"LOGIN": "foo", "PRIVATE_KEY_PATH": str(private_key_path)}, ) self.assertEqual(