""" Security layer for Kattila Manager. Handles PSK fetching via DNS TXT, HMAC verification, nonce caching, and timestamp skew checking. """ import hashlib import hmac import json import logging import os import threading import time from collections import OrderedDict logger = logging.getLogger(__name__) # PSK history: current + 2 previous _psk_lock = threading.Lock() _psk_history: list[str] = [] # most recent first, max 3 entries # Nonce sliding window _nonce_lock = threading.Lock() _nonce_cache: OrderedDict[str, float] = OrderedDict() NONCE_CACHE_SIZE = 120 # Maximum clock skew in seconds MAX_CLOCK_SKEW = 600 # 10 minutes def _resolve_txt(dns_name: str) -> str | None: """Resolve DNS TXT record. Uses dnspython if available, falls back to subprocess.""" try: import dns.resolver answers = dns.resolver.resolve(dns_name, "TXT") for rdata in answers: txt = rdata.to_text().strip('"') return txt except ImportError: # Fallback: use dig import subprocess try: result = subprocess.run( ["dig", "+short", "TXT", dns_name], capture_output=True, text=True, timeout=5 ) val = result.stdout.strip().strip('"') if val: return val except (subprocess.TimeoutExpired, FileNotFoundError): pass except Exception as e: logger.error("DNS TXT lookup failed for %s: %s", dns_name, e) return None def fetch_psk(): """Fetch the PSK from the DNS TXT record and update the history.""" dns_name = os.environ.get("DNS", "") if not dns_name: logger.warning("security: No DNS configured for PSK lookup") return key = _resolve_txt(dns_name) if not key: logger.warning("security: Could not resolve PSK from DNS") return with _psk_lock: if not _psk_history or _psk_history[0] != key: _psk_history.insert(0, key) # Keep at most 3 entries (current + 2 previous) while len(_psk_history) > 3: _psk_history.pop() logger.info("security: PSK updated (history depth: %d)", len(_psk_history)) def start_key_poller(): """Fetch PSK immediately and then every hour in a background thread.""" fetch_psk() def _poll(): while True: time.sleep(3600) fetch_psk() t = threading.Thread(target=_poll, daemon=True) t.start() def get_current_psk() -> str: with _psk_lock: return _psk_history[0] if _psk_history else "" def get_psk_history() -> list[str]: with _psk_lock: return list(_psk_history) def fleet_id_for_psk(psk: str) -> str: """Generate the fleet_id (SHA256 hash) for a given PSK.""" return hashlib.sha256(psk.encode()).hexdigest() def verify_hmac(data_payload, provided_hmac: str) -> tuple[bool, str | None]: """ Verify the HMAC against all known PSKs (current + 2 previous). Returns (valid, matched_psk) or (False, None). """ psks = get_psk_history() if not psks: logger.error("security: No PSKs available for verification") return False, None data_bytes = json.dumps(data_payload, separators=(",", ":"), sort_keys=True).encode() for psk in psks: expected = hmac.new(psk.encode(), data_bytes, hashlib.sha256).hexdigest() if hmac.compare_digest(expected, provided_hmac): if psk != psks[0]: logger.warning("security: Agent using old PSK (not current)") return True, psk return False, None def check_nonce(nonce: str) -> bool: """ Check if a nonce has been seen before (replay protection). Returns True if the nonce is fresh (not seen), False if it's a replay. """ with _nonce_lock: if nonce in _nonce_cache: return False _nonce_cache[nonce] = time.time() # Evict oldest entries if over the limit while len(_nonce_cache) > NONCE_CACHE_SIZE: _nonce_cache.popitem(last=False) return True def check_timestamp(timestamp: int) -> bool: """Check if the timestamp is within the acceptable clock skew.""" now = int(time.time()) return abs(now - timestamp) <= MAX_CLOCK_SKEW def validate_report(report: dict) -> tuple[bool, str]: """ Full validation pipeline for an incoming report. Returns (valid, error_message). """ # 1. Timestamp check ts = report.get("timestamp", 0) if not check_timestamp(ts): return False, "timestamp_skew" # 2. Nonce replay check nonce = report.get("nonce", "") if not nonce: return False, "missing_nonce" if not check_nonce(nonce): return False, "replay_detected" # 3. HMAC verification provided_hmac = report.get("hmac", "") data = report.get("data") if not data or not provided_hmac: return False, "missing_hmac_or_data" valid, matched_psk = verify_hmac(data, provided_hmac) if not valid: return False, "hmac_invalid" # 4. Fleet ID consistency check if matched_psk: expected_fleet = fleet_id_for_psk(matched_psk) if report.get("fleet_id") != expected_fleet: return False, "fleet_id_mismatch" return True, "ok"