276 lines
9.5 KiB
Python
276 lines
9.5 KiB
Python
"""
|
|
kattila Agent
|
|
- Reports node state to manager (direct or via WG peer relay)
|
|
- Runs a relay server for segregated peers
|
|
- Authenticates with a pre-shared key (file, env, or DNS TXT)
|
|
"""
|
|
|
|
import hmac
import json
import logging
import os
import socket
import subprocess
import threading
import time

import requests
from flask import Flask, request as freq, jsonify

# ── CONFIG ────────────────────────────────────────────────────────────────────
# Each setting may be overridden by an environment variable of the same name;
# the literals below are the built-in defaults.

# Endpoint that receives the status report (HTTP POST).
MANAGER_URL = os.getenv("MANAGER_URL", "http://10.37.11.2:5086/status/api/report")
# Port the local relay server binds to; also the port contacted on peer relays.
RELAY_PORT = int(os.getenv("RELAY_PORT", "5087"))
# Seconds to sleep between reporting cycles.
REPORT_INTERVAL = int(os.getenv("REPORT_INTERVAL", "60"))
# Pre-shared key file: read as a key source and written as the DNS-key cache.
PSK_FILE = os.getenv("PSK_FILE", "/etc/kattila.status/psk")
# Optional TXT record holding the key, e.g. "_kattila.homelab.local";
# an empty value disables the DNS lookup.
PSK_DNS_RECORD = os.getenv("PSK_DNS_RECORD", "")


# ── PSK LOADING ───────────────────────────────────────────────────────────────
|
|
|
|
def _load_psk_dns(record: str) -> str | None:
|
|
"""
|
|
Fetch PSK from a DNS TXT record.
|
|
|
|
Store the key in your local/private DNS zone as:
|
|
_kattila.homelab.local. TXT "psk=your-key-here"
|
|
|
|
This lets you rotate the key on all nodes without touching any config file —
|
|
just update the DNS record. Agents re-fetch on each startup (cached to disk
|
|
as fallback below). Only use a private/internal zone — never a public one.
|
|
"""
|
|
try:
|
|
out = subprocess.check_output(
|
|
["dig", "+short", "TXT", record],
|
|
stderr=subprocess.DEVNULL, timeout=5
|
|
).decode()
|
|
for line in out.splitlines():
|
|
clean = line.strip().strip('"')
|
|
if clean.startswith("psk="):
|
|
return clean[4:]
|
|
except Exception:
|
|
pass
|
|
return None
|
|
|
|
|
|
def _cache_psk(key: str) -> None:
    """Write *key* to ``PSK_FILE`` readable by the owner only (best effort)."""
    directory = os.path.dirname(PSK_FILE)
    try:
        # dirname() is "" for a bare filename, and makedirs("") would raise.
        if directory:
            os.makedirs(directory, exist_ok=True)
        # Create new files as 0600 so the secret is not world-readable under a
        # typical 022 umask; chmod also tightens a file that already existed
        # (it is truncated before the key is written).
        with open(
            PSK_FILE,
            "w",
            encoding="utf-8",
            opener=lambda path, flags: os.open(path, flags, 0o600),
        ) as f:
            os.chmod(PSK_FILE, 0o600)
            f.write(key)
    except OSError as e:
        print(f"[!] Could not cache PSK to {PSK_FILE}: {e}")


def load_psk() -> str:
    """
    Resolve the pre-shared key used to authenticate reports and relays.

    Sources are tried in order; the first non-empty one wins:

    1. The DNS TXT record named by ``PSK_DNS_RECORD`` (when set). A key found
       there is also cached to ``PSK_FILE`` so a later start works offline.
    2. The contents of ``PSK_FILE``, whitespace-stripped.
    3. The ``kattila_PSK`` environment variable.

    Returns:
        The key string.

    Raises:
        RuntimeError: if no source yields a key.
    """
    key: str | None = None

    # 1. DNS TXT (if configured) — good for zero-touch key rotation
    if PSK_DNS_RECORD:
        key = _load_psk_dns(PSK_DNS_RECORD)
        if key:
            _cache_psk(key)

    # 2. File (also the DNS cache fallback); an unreadable file falls through
    #    to the env var instead of aborting start-up.
    if not key and os.path.isfile(PSK_FILE):
        try:
            with open(PSK_FILE, encoding="utf-8") as f:
                key = f.read().strip()
        except (OSError, UnicodeDecodeError) as e:
            print(f"[!] Could not read PSK file {PSK_FILE}: {e}")

    # 3. Environment variable
    if not key:
        key = os.environ.get("kattila_PSK", "")

    if not key:
        raise RuntimeError(
            "No PSK found. Set kattila_PSK env var, write to "
            f"{PSK_FILE}, or configure PSK_DNS_RECORD."
        )
    return key


# Resolved once at import time; with no key available, load_psk() raises
# RuntimeError and the agent does not start.
PSK = load_psk()

# Attached to every outbound request: the shared key for authentication plus
# the JSON content type.
AUTH_HEADERS = {
    "X-kattila-PSK": PSK,
    "Content-Type": "application/json",
}


# ── RELAY SERVER ──────────────────────────────────────────────────────────────
# HTTP server that forwards reports on behalf of peers which cannot reach the
# manager directly.

relay_app = Flask("kattila-relay")

# Hide werkzeug's per-request access log; only errors are emitted.
_werkzeug_logger = logging.getLogger("werkzeug")
_werkzeug_logger.setLevel(logging.ERROR)


@relay_app.route("/relay", methods=["POST"])
def relay():
    """
    Forward a peer's report to the manager and echo the manager's reply.

    The caller must send the shared key in the ``X-kattila-PSK`` header.

    Returns:
        401 if the key does not match, 400 if the body is not JSON, 502 if
        the manager cannot be reached, otherwise the manager's reply and
        status code.
    """
    supplied = freq.headers.get("X-kattila-PSK", "")
    # Constant-time comparison: response timing must not reveal how much of a
    # guessed key was correct. Both sides are encoded so non-ASCII input
    # cannot raise TypeError.
    if not hmac.compare_digest(supplied.encode("utf-8"), PSK.encode("utf-8")):
        return jsonify({"error": "unauthorized"}), 401

    # Parse outside the upstream try block: a malformed body is the client's
    # error (400), not a gateway failure (502).
    payload = freq.get_json(silent=True)
    if payload is None:
        return jsonify({"error": "invalid JSON body"}), 400

    try:
        resp = requests.post(
            MANAGER_URL, json=payload, headers=AUTH_HEADERS, timeout=10
        )
    except requests.RequestException as e:
        return jsonify({"error": str(e)}), 502

    try:
        body = resp.json()
    except ValueError:
        # Preserve the manager's status code even when its reply isn't JSON.
        body = {"error": "manager returned a non-JSON response"}
    return jsonify(body), resp.status_code


@relay_app.route("/health")
def health():
    """Liveness probe: always replies with ``{"status": "ok"}``."""
    body = {"status": "ok"}
    return jsonify(body)


def _run_relay():
    """
    Serve ``relay_app`` on all IPv4 interfaces at ``RELAY_PORT``.

    The auto-reloader is turned off; the entry point runs this function in a
    daemon thread.
    """
    server_options = {"host": "0.0.0.0", "port": RELAY_PORT, "use_reloader": False}
    relay_app.run(**server_options)


# ── WIREGUARD HELPERS ─────────────────────────────────────────────────────────


def get_wg_interface_names() -> set[str]:
    """
    Return the names of the WireGuard interfaces on this host.

    Uses ``wg show interfaces``; any failure (``wg`` missing, insufficient
    privileges, non-zero exit) yields an empty set.
    """
    try:
        listing = subprocess.check_output(
            ["wg", "show", "interfaces"], stderr=subprocess.DEVNULL
        ).decode()
        # split() with no argument already ignores surrounding whitespace.
        return {ifname for ifname in listing.split()}
    except Exception:
        return set()


def parse_wg_dump() -> list[dict]:
|
|
"""
|
|
Parse 'wg show all dump'. Peer lines have 9 tab-separated fields:
|
|
ifname pubkey psk endpoint allowed_ips handshake rx tx keepalive
|
|
Interface lines have 5 fields (private_key, public_key, listen_port, fwmark)
|
|
and are skipped.
|
|
"""
|
|
peers = []
|
|
try:
|
|
raw = subprocess.check_output(
|
|
["wg", "show", "all", "dump"], stderr=subprocess.DEVNULL
|
|
).decode().strip()
|
|
for line in raw.splitlines():
|
|
parts = line.split("\t")
|
|
if len(parts) == 9: # peer line
|
|
ifname, pubkey, _psk, endpoint, allowed_ips, handshake_ts, *_ = parts
|
|
handshake_ts = int(handshake_ts)
|
|
idle = 0 if handshake_ts == 0 else int(time.time()) - handshake_ts
|
|
peers.append({
|
|
"ifname": ifname,
|
|
"pubkey": pubkey,
|
|
"allowed_ips": allowed_ips,
|
|
"handshake": idle,
|
|
"status": "ok" if 0 < idle < 120 else "stale",
|
|
})
|
|
except Exception as e:
|
|
print(f"[!] WireGuard dump error: {e}")
|
|
return peers
|
|
|
|
|
|
def get_peer_relay_ips() -> list[str]:
    """
    Return the IPv4 tunnel addresses of WireGuard peers.

    These are the relay candidates used when direct manager reporting fails.
    Only single-host ``/32`` IPv4 entries from each peer's allowed IPs are
    used. Addresses are de-duplicated, and peers whose status is ``"ok"``
    (recent handshake) come first, so reachable relays are tried before stale
    ones.
    """
    ips: list[str] = []
    seen: set[str] = set()
    # sorted() is stable: healthy peers first, original order otherwise kept.
    peers = sorted(parse_wg_dump(), key=lambda peer: peer["status"] != "ok")
    for peer in peers:
        for cidr in peer["allowed_ips"].split(","):
            cidr = cidr.strip()
            if "." in cidr and cidr.endswith("/32"):
                ip = cidr[:-3]
                if ip not in seen:
                    seen.add(ip)
                    ips.append(ip)
    return ips


# ── DATA COLLECTION ───────────────────────────────────────────────────────────

# Interface-name prefixes excluded from the report (bridges / container links).
_SKIPPED_IFACE_PREFIXES = ("br-", "docker")


def _wg_public_key(ifname: str) -> str | None:
    """Return the public key of WireGuard interface *ifname*, or None on any failure."""
    try:
        raw = subprocess.check_output(
            ["wg", "show", ifname, "public-key"], stderr=subprocess.DEVNULL
        )
        return raw.decode().strip()
    except Exception:
        return None


def get_data() -> dict:
    """
    Build this node's status report.

    Returns:
        ``{"hostname", "interfaces", "wg_peers"}``. Each interface entry holds
        ``name`` and its first IPv4 address as ``ip``, plus ``public_key``
        for WireGuard interfaces. ``lo``, ``br-*`` and ``docker*`` interfaces,
        and those without an IPv4 address, are omitted. ``wg_peers`` is the
        output of ``parse_wg_dump()``.
    """
    hostname = socket.gethostname()
    wg_names = get_wg_interface_names()
    interfaces = []

    try:
        links = json.loads(
            subprocess.check_output(["ip", "-j", "addr"], stderr=subprocess.DEVNULL).decode()
        )
        for link in links:
            ifname = link["ifname"]
            if ifname == "lo" or ifname.startswith(_SKIPPED_IFACE_PREFIXES):
                continue

            # First IPv4 ("inet") address, scanning lazily like next() would.
            ipv4 = None
            for info in link.get("addr_info", []):
                if info["family"] == "inet":
                    ipv4 = info["local"]
                    break
            if not ipv4:
                continue

            entry: dict = {"name": ifname, "ip": ipv4}
            # Query keys only for real WireGuard interfaces — avoids
            # "Unable to access interface: Operation not supported" spam.
            if ifname in wg_names:
                pubkey = _wg_public_key(ifname)
                if pubkey is not None:
                    entry["public_key"] = pubkey
            interfaces.append(entry)
    except Exception as e:
        print(f"[!] Interface collection error: {e}")

    return {
        "hostname": hostname,
        "interfaces": interfaces,
        "wg_peers": parse_wg_dump(),
    }


# ── REPORTING ─────────────────────────────────────────────────────────────────


def _report_direct(payload: dict, hostname: str) -> bool:
    """POST *payload* straight to the manager; True only on HTTP 200."""
    try:
        resp = requests.post(MANAGER_URL, json=payload, headers=AUTH_HEADERS, timeout=10)
        if resp.status_code == 200:
            print(f"[+] {hostname} → manager (direct)")
            return True
        print(f"[!] Manager returned {resp.status_code}")
    except Exception as exc:
        print(f"[!] Direct report failed: {exc}")
    return False


def _report_via_relay(ip: str, payload: dict, hostname: str) -> bool:
    """POST *payload* to the relay server on peer *ip*; True only on HTTP 200."""
    relay_url = f"http://{ip}:{RELAY_PORT}/relay"
    try:
        resp = requests.post(relay_url, json=payload, headers=AUTH_HEADERS, timeout=8)
        if resp.status_code == 200:
            print(f"[+] {hostname} → relay {ip}")
            return True
        print(f"[!] Relay {ip} returned {resp.status_code}")
    except Exception as exc:
        print(f"[!] Relay {ip} unreachable: {exc}")
    return False


def report(payload: dict) -> bool:
    """
    Deliver *payload* to the manager, falling back to WireGuard peer relays.

    The manager is tried directly first. If that fails, each address from
    ``get_peer_relay_ips()`` is tried in turn on ``RELAY_PORT`` until one
    answers HTTP 200.

    Returns:
        True if some path accepted the report, False otherwise.
    """
    hostname = payload["hostname"]
    if _report_direct(payload, hostname):
        return True

    relay_ips = get_peer_relay_ips()
    if not relay_ips:
        print("[-] No relay candidates found")
        return False

    # any() short-circuits, so relays after the first success are not contacted.
    if any(_report_via_relay(ip, payload, hostname) for ip in relay_ips):
        return True

    print(f"[-] All reporting paths failed for {hostname}")
    return False


# ── MAIN ──────────────────────────────────────────────────────────────────────


def _main() -> None:
    """Start the relay server in a daemon thread, then report forever."""
    relay_thread = threading.Thread(target=_run_relay, daemon=True)
    relay_thread.start()
    print(f"[*] Relay server listening on :{RELAY_PORT}")
    print(f"[*] Reporting to {MANAGER_URL} every {REPORT_INTERVAL}s")

    while True:
        # A failing cycle must not stop the agent: log it and retry next time.
        try:
            report(get_data())
        except Exception as e:
            print(f"[!] Unexpected error: {e}")
        time.sleep(REPORT_INTERVAL)


if __name__ == "__main__":
    _main()