Replace dig usage

This commit is contained in:
sonny 2025-01-05 17:08:28 +01:00
parent 400792922b
commit 780dd5a624
3 changed files with 246 additions and 207 deletions

View file

@@ -1,10 +1,10 @@
import base64
import json
import logging
import subprocess
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Generator
import requests
@@ -17,19 +17,17 @@ from cryptography.hazmat.primitives.hashes import SHA512
logger = logging.getLogger(__name__)
def _get_ip(resolvers):
def _get_ip(service: str) -> str:
try:
output = subprocess.check_output(
["dig", "+short", *resolvers],
stderr=subprocess.STDOUT,
)
except subprocess.CalledProcessError as e:
raise OSError("Unable to retrieve current IP") from e
response = requests.get(service, timeout=10)
response.raise_for_status()
except requests.RequestException as e:
raise OSError(f"Unable to retrieve current IP from {service}") from e
return output.decode("utf-8").strip()
return response.text
def _get_token(private_key_path, login, api_url):
def _get_token(private_key_path: str, login: str, api_url: str) -> str:
request = requests.Request(
"POST",
f"{api_url}/auth",
@@ -66,13 +64,14 @@ def _get_token(private_key_path, login, api_url):
return response_data["token"]
def _get_domain(domain, token, api_url):
def _get_domain(domain: str, token: str, api_url: str) -> requests.Response:
headers = {"Authorization": f"Bearer {token}"}
return requests.get(f"{api_url}/domains/{domain}/dns", headers=headers)
def _get_domain_data(domains, token, api_url):
def _get_domain_data(
domains: list[str], token: str, api_url: str
) -> Generator[dict, None, None]:
with ThreadPoolExecutor(max_workers=10) as executor:
futures = {
executor.submit(_get_domain, domain, token, api_url): domain
@@ -80,19 +79,21 @@ def _get_domain_data(domains, token, api_url):
}
for future in as_completed(futures):
response = future.result()
domain = futures[future]
try:
response = future.result()
response.raise_for_status()
except requests.HTTPError as e:
except requests.HTTPError:
logger.exception(f"Failed retrieving information for {domain}")
continue
yield {"domain": domain, **response.json()}
def _update_domain(domain, payload, api_url, token):
def _update_domain(
domain: str, payload: dict, api_url: str, token: str
) -> requests.Response:
headers = {"Authorization": f"Bearer {token}"}
return requests.put(
@@ -100,7 +101,9 @@ def _update_domain(domain, payload, api_url, token):
)
def _update_domains(updated_domains, api_url, token, read_only):
def _update_domains(
updated_domains: dict, api_url: str, token: str, read_only: bool
) -> None:
if read_only:
return
@@ -123,14 +126,21 @@ def _update_domains(updated_domains, api_url, token, read_only):
logger.info(f"Updated domain {domain}")
def detect(domains, resolvers, credentials, token, api_url, read_only):
ip = _get_ip(resolvers)
updated_domains = {}
def detect(
domains: list[str],
service: str,
credentials: tuple[str, str],
token: str,
api_url: str,
read_only: bool,
) -> None:
ip = _get_ip(service)
if all(credentials):
token = _get_token(*credentials, api_url)
domain_data = _get_domain_data(domains, token, api_url)
updated_domains = {}
for data in domain_data:
dns_entries = data["dnsEntries"]