Files
kattila.status/manager/security.py
2026-04-17 20:15:24 +03:00

187 lines
5.2 KiB
Python

"""
Security layer for Kattila Manager.
Handles PSK fetching via DNS TXT, HMAC verification, nonce caching,
and timestamp skew checking.
"""
import hashlib
import hmac
import json
import logging
import os
import threading
import time
from collections import OrderedDict
logger = logging.getLogger(__name__)

# --- Pre-shared key state ---------------------------------------------------
# Known PSKs, newest first: the current key followed by up to two previous
# ones (fetch_psk trims the list to 3 entries).
_psk_lock = threading.Lock()
_psk_history: list[str] = []

# --- Replay protection ------------------------------------------------------
# Recently seen nonces mapped to the wall-clock time they were first seen,
# kept in insertion order so the oldest can be evicted first.
_nonce_lock = threading.Lock()
_nonce_cache: OrderedDict[str, float] = OrderedDict()
# NOTE(review): if more than this many nonces arrive within
# 2 * MAX_CLOCK_SKEW seconds, the oldest is evicted while its report could
# still pass the timestamp check — confirm 120 covers the fleet's volume.
NONCE_CACHE_SIZE = 120

# Largest accepted difference, in seconds, between a report's timestamp and
# the manager's clock (10 minutes).
MAX_CLOCK_SKEW = 10 * 60
def _resolve_txt(dns_name: str) -> str | None:
    """
    Resolve the TXT record at *dns_name* and return its text.

    Uses dnspython when it is importable, otherwise falls back to the
    ``dig`` command-line tool. A TXT record may consist of several
    character-strings (each at most 255 bytes); they are concatenated.
    Only the first record returned is used.

    Returns:
        The record text, or None if the lookup failed or returned nothing.
        Lookup errors are logged rather than raised, so a failing lookup
        cannot terminate the background key poller.
    """
    try:
        import dns.resolver
    except ImportError:
        return _resolve_txt_with_dig(dns_name)
    try:
        answers = dns.resolver.resolve(dns_name, "TXT")
        for rdata in answers:
            # ``strings`` holds the raw character-strings as bytes. Joining
            # them avoids the quoting/escaping added by ``to_text()``, which
            # left '" "' inside multi-string records.
            return b"".join(rdata.strings).decode()
    except Exception as e:
        logger.error("DNS TXT lookup failed for %s: %s", dns_name, e)
    return None


def _resolve_txt_with_dig(dns_name: str) -> str | None:
    """Fallback TXT lookup via ``dig +short``; returns None on any failure."""
    import re
    import subprocess

    try:
        result = subprocess.run(
            ["dig", "+short", "TXT", dns_name],
            capture_output=True, text=True, timeout=5,
        )
    except (subprocess.TimeoutExpired, OSError) as e:
        # OSError covers a missing binary (FileNotFoundError) as well as
        # permission problems; none of these may propagate to the poller.
        logger.error("DNS TXT lookup via dig failed for %s: %s", dns_name, e)
        return None
    if result.returncode != 0:
        # dig can report failures as ';;'-prefixed text on stdout; without
        # this check that text would be returned as the key.
        logger.error("dig exited with status %d for %s",
                     result.returncode, dns_name)
        return None
    for line in result.stdout.splitlines():
        line = line.strip()
        if not line or line.startswith(";"):
            continue
        # One record per line, each character-string quoted: "part1" "part2"
        parts = re.findall(r'"((?:[^"\\]|\\.)*)"', line)
        return "".join(parts) if parts else line
    return None
def fetch_psk():
    """
    Look up the PSK in DNS and record it as the newest history entry.

    The record name is read from the ``DNS`` environment variable. Nothing
    changes when the variable is unset or the lookup yields nothing, and a
    key equal to the current one is not added again. At most three keys
    (current + 2 previous) are retained.
    """
    record_name = os.environ.get("DNS", "")
    if not record_name:
        logger.warning("security: No DNS configured for PSK lookup")
        return
    new_key = _resolve_txt(record_name)
    if not new_key:
        logger.warning("security: Could not resolve PSK from DNS")
        return
    with _psk_lock:
        already_current = bool(_psk_history) and _psk_history[0] == new_key
        if already_current:
            return
        _psk_history.insert(0, new_key)
        # Keep at most 3 entries (current + 2 previous), trimming in place.
        del _psk_history[3:]
        logger.info("security: PSK updated (history depth: %d)", len(_psk_history))
def start_key_poller(*, interval: float = 3600):
    """
    Fetch the PSK immediately, then refresh it every *interval* seconds.

    The refresh loop runs in a daemon thread. An exception raised by a
    refresh is logged and the loop keeps going, so one failed lookup cannot
    stop key rotation for the rest of the process lifetime.

    Args:
        interval: Seconds between refreshes (default: one hour).
    """
    fetch_psk()

    def _poll():
        while True:
            time.sleep(interval)
            try:
                fetch_psk()
            except Exception:
                # An uncaught exception would terminate this thread and
                # freeze the PSK history at its current contents.
                logger.exception("security: PSK refresh failed; will retry")

    t = threading.Thread(target=_poll, name="psk-poller", daemon=True)
    t.start()
def get_current_psk() -> str:
    """Return the newest known PSK, or an empty string if none is loaded yet."""
    with _psk_lock:
        if not _psk_history:
            return ""
        return _psk_history[0]
def get_psk_history() -> list[str]:
    """
    Return a copy of the known PSKs, newest first.

    A copy is returned so callers can iterate it without holding the lock.
    """
    with _psk_lock:
        snapshot = _psk_history[:]
    return snapshot
def fleet_id_for_psk(psk: str) -> str:
    """Return the fleet_id for *psk*: the hex SHA-256 digest of its UTF-8 bytes."""
    digest = hashlib.sha256()
    digest.update(psk.encode("utf-8"))
    return digest.hexdigest()
def verify_hmac(data_payload, provided_hmac: str) -> tuple[bool, str | None]:
    """
    Verify the HMAC against all known PSKs (current + 2 previous).

    The payload is serialized as compact, key-sorted JSON and authenticated
    with HMAC-SHA256; *provided_hmac* is compared against the hex digest.

    Returns:
        (True, matched_psk) on a match, otherwise (False, None). A
        non-string or non-ASCII *provided_hmac* is treated as a mismatch
        instead of raising.
    """
    psks = get_psk_history()
    if not psks:
        logger.error("security: No PSKs available for verification")
        return False, None
    # provided_hmac comes from the request; hmac.compare_digest raises
    # TypeError for non-str values and for str containing non-ASCII chars.
    if not isinstance(provided_hmac, str) or not provided_hmac.isascii():
        return False, None
    data_bytes = json.dumps(data_payload, separators=(",", ":"),
                            sort_keys=True).encode()
    for psk in psks:
        expected = hmac.new(psk.encode(), data_bytes, hashlib.sha256).hexdigest()
        # Constant-time comparison to avoid leaking digest prefixes.
        if hmac.compare_digest(expected, provided_hmac):
            if psk != psks[0]:
                logger.warning("security: Agent using old PSK (not current)")
            return True, psk
    return False, None
def check_nonce(nonce: str) -> bool:
    """
    Check if a nonce has been seen before (replay protection).

    Returns True if the nonce is fresh (not seen), in which case it is
    recorded, and False if it is a replay.

    The cache keeps at most NONCE_CACHE_SIZE nonces and evicts the oldest
    first. A report's timestamp can pass check_timestamp for up to
    2 * MAX_CLOCK_SKEW seconds, so evicting a nonce first seen less than that
    long ago reopens a replay window for its report. Such evictions are
    logged instead of happening silently.
    """
    with _nonce_lock:
        if nonce in _nonce_cache:
            return False
        now = time.time()
        _nonce_cache[nonce] = now
        # Nonces first seen after this instant may still be replayable.
        replay_horizon = now - 2 * MAX_CLOCK_SKEW
        # Evict oldest entries if over the limit
        while len(_nonce_cache) > NONCE_CACHE_SIZE:
            _, first_seen = _nonce_cache.popitem(last=False)
            if first_seen > replay_horizon:
                logger.warning(
                    "security: nonce cache full (%d entries); evicted a nonce "
                    "still inside the replay window", NONCE_CACHE_SIZE)
        return True
def check_timestamp(timestamp: int) -> bool:
"""Check if the timestamp is within the acceptable clock skew."""
now = int(time.time())
return abs(now - timestamp) <= MAX_CLOCK_SKEW
def validate_report(report: dict) -> tuple[bool, str]:
    """
    Full validation pipeline for an incoming report.

    Steps, in order:
      1. timestamp skew,
      2. nonce presence,
      3. HMAC over ``data``,
      4. fleet_id consistency with the matched PSK,
      5. nonce replay.

    The nonce is recorded only once the report has authenticated. Forged or
    malformed requests therefore cannot fill the bounded nonce cache and
    evict legitimate nonces.

    Returns:
        (True, "ok"), or (False, error_code) where error_code is one of
        "invalid_report", "timestamp_skew", "missing_nonce",
        "missing_hmac_or_data", "hmac_invalid", "fleet_id_mismatch",
        "replay_detected".

    NOTE(review): the HMAC is computed over ``data`` only. Unless the agent
    also embeds the nonce/timestamp inside ``data``, a captured report can
    be replayed with a fresh top-level nonce and timestamp. Confirm against
    the agent protocol.
    """
    if not isinstance(report, dict):
        return False, "invalid_report"
    # 1. Timestamp check
    ts = report.get("timestamp", 0)
    if not check_timestamp(ts):
        return False, "timestamp_skew"
    # 2. Nonce must be present (the replay check itself runs last)
    nonce = report.get("nonce", "")
    if not nonce:
        return False, "missing_nonce"
    # 3. HMAC verification
    provided_hmac = report.get("hmac", "")
    data = report.get("data")
    if not data or not provided_hmac:
        return False, "missing_hmac_or_data"
    valid, matched_psk = verify_hmac(data, provided_hmac)
    if not valid:
        return False, "hmac_invalid"
    # 4. Fleet ID consistency check
    if matched_psk:
        expected_fleet = fleet_id_for_psk(matched_psk)
        if report.get("fleet_id") != expected_fleet:
            return False, "fleet_id_mismatch"
    # 5. Nonce replay check. This records the nonce, so it runs only for
    #    authenticated reports and they alone consume cache slots.
    if not check_nonce(nonce):
        return False, "replay_detected"
    return True, "ok"