"""
Security layer for Kattila Manager.

Handles PSK fetching via DNS TXT, HMAC verification, nonce caching,
and timestamp skew checking.
"""
|
|
|
|
import hashlib
import hmac
import json
import logging
import os
import threading
import time
from collections import OrderedDict

logger = logging.getLogger(__name__)

# PSK history: current + 2 previous, most recent first (max 3 entries).
# Guarded by _psk_lock because the hourly poller thread updates it while
# request handlers read it.
_psk_lock = threading.Lock()
_psk_history: list[str] = []

# Sliding window of recently-seen nonces for replay protection.
# Maps nonce -> time first seen; insertion-ordered, so the front entry
# is always the oldest. Guarded by _nonce_lock.
_nonce_lock = threading.Lock()
_nonce_cache: OrderedDict[str, float] = OrderedDict()
NONCE_CACHE_SIZE = 120

# Maximum accepted clock skew in seconds between sender and this host.
MAX_CLOCK_SKEW = 600  # 10 minutes
|
|
|
|
|
|
def _resolve_txt(dns_name: str) -> str | None:
    """Resolve a DNS TXT record and return its first value, or None.

    Prefers dnspython when installed; otherwise shells out to ``dig``.
    Returns None when the record cannot be resolved by either path.
    """
    try:
        import dns.resolver

        answers = dns.resolver.resolve(dns_name, "TXT")
        for rdata in answers:
            # to_text() wraps the value in quotes; strip the outer pair.
            return rdata.to_text().strip('"')
    except ImportError:
        # Fallback: shell out to dig. Argument-list form (shell=False),
        # so dns_name is never interpreted by a shell.
        import subprocess
        try:
            result = subprocess.run(
                ["dig", "+short", "TXT", dns_name],
                capture_output=True, text=True, timeout=5
            )
            # dig prints one quoted value per line when multiple TXT
            # records exist. The previous code stripped the whole blob at
            # once and could return a value with embedded quotes/newlines;
            # take only the first non-empty record instead.
            for line in result.stdout.splitlines():
                val = line.strip().strip('"')
                if val:
                    return val
        except (subprocess.TimeoutExpired, FileNotFoundError):
            pass
    except Exception as e:
        logger.error("DNS TXT lookup failed for %s: %s", dns_name, e)
    return None
|
|
|
|
|
|
def fetch_psk():
    """Resolve the PSK from the configured DNS TXT record and record it.

    A key not already at the front of the history is pushed to the front;
    at most three keys (current + 2 previous) are retained. No-op (with a
    warning) when DNS is unconfigured or resolution fails.
    """
    dns_name = os.environ.get("DNS", "")
    if not dns_name:
        logger.warning("security: No DNS configured for PSK lookup")
        return

    key = _resolve_txt(dns_name)
    if not key:
        logger.warning("security: Could not resolve PSK from DNS")
        return

    with _psk_lock:
        already_current = bool(_psk_history) and _psk_history[0] == key
        if not already_current:
            _psk_history.insert(0, key)
            # Trim to current + 2 previous.
            del _psk_history[3:]
            logger.info("security: PSK updated (history depth: %d)", len(_psk_history))
|
|
|
|
|
|
def start_key_poller():
    """Fetch PSK immediately and then every hour in a background thread."""
    fetch_psk()

    def _hourly_refresh():
        # Daemon loop: re-fetch the PSK once per hour, forever.
        while True:
            time.sleep(3600)
            fetch_psk()

    threading.Thread(target=_hourly_refresh, daemon=True).start()
|
|
|
|
|
|
def get_current_psk() -> str:
    """Return the most recently fetched PSK, or "" if none is known."""
    with _psk_lock:
        if _psk_history:
            return _psk_history[0]
        return ""
|
|
|
|
|
|
def get_psk_history() -> list[str]:
    """Return a snapshot copy of the PSK history (most recent first)."""
    with _psk_lock:
        return _psk_history.copy()
|
|
|
|
|
|
def fleet_id_for_psk(psk: str) -> str:
    """Generate the fleet_id (SHA256 hex digest) for a given PSK."""
    digest = hashlib.sha256()
    digest.update(psk.encode())
    return digest.hexdigest()
|
|
|
|
|
|
def verify_hmac(data_payload, provided_hmac: str) -> tuple[bool, str | None]:
    """Check *provided_hmac* against every known PSK (current + 2 previous).

    Returns (True, matched_psk) on the first match, otherwise (False, None).
    Logs a warning when the match is against a non-current PSK.
    """
    candidates = get_psk_history()
    if not candidates:
        logger.error("security: No PSKs available for verification")
        return False, None

    # Compact separators + sorted keys give a deterministic byte encoding
    # of the payload for HMAC computation.
    payload_bytes = json.dumps(
        data_payload, separators=(",", ":"), sort_keys=True
    ).encode()

    current = candidates[0]
    for candidate in candidates:
        computed = hmac.new(candidate.encode(), payload_bytes, hashlib.sha256).hexdigest()
        # compare_digest: constant-time comparison against timing attacks.
        if not hmac.compare_digest(computed, provided_hmac):
            continue
        if candidate != current:
            logger.warning("security: Agent using old PSK (not current)")
        return True, candidate

    return False, None
|
|
|
|
|
|
def check_nonce(nonce: str) -> bool:
    """
    Check if a nonce has been seen before (replay protection).

    Returns True if the nonce is fresh (not seen), False if it's a replay.

    Eviction policy: entries older than 2 * MAX_CLOCK_SKEW are expired by
    age first — a replay of such a nonce would fail check_timestamp()
    anyway, so dropping them is safe. The size cap then acts only as a
    memory backstop. The previous purely count-based eviction could drop a
    nonce whose timestamp was still valid whenever more than
    NONCE_CACHE_SIZE reports arrived within the skew window, reopening a
    replay opportunity.
    """
    now = time.time()
    with _nonce_lock:
        # _nonce_cache is insertion-ordered and values are insertion times,
        # so the front entry is always the oldest.
        cutoff = now - 2 * MAX_CLOCK_SKEW
        while _nonce_cache:
            _, seen_at = next(iter(_nonce_cache.items()))
            if seen_at >= cutoff:
                break
            _nonce_cache.popitem(last=False)

        if nonce in _nonce_cache:
            return False

        _nonce_cache[nonce] = now

        # Memory backstop: evict oldest entries if over the limit.
        while len(_nonce_cache) > NONCE_CACHE_SIZE:
            _nonce_cache.popitem(last=False)

        return True
|
|
|
|
|
|
def check_timestamp(timestamp: int) -> bool:
    """Check if the timestamp is within the acceptable clock skew."""
    delta = int(time.time()) - timestamp
    return -MAX_CLOCK_SKEW <= delta <= MAX_CLOCK_SKEW
|
|
|
|
|
|
def validate_report(report: dict) -> tuple[bool, str]:
    """
    Full validation pipeline for an incoming report.

    Returns (valid, error_message).

    Order matters: the HMAC is verified *before* the nonce is recorded.
    The previous ordering consumed the nonce first, so an unauthenticated
    attacker who observed a report in transit could submit it with a bogus
    HMAC, burn the nonce, and cause the genuine report to be rejected as a
    replay (denial of service).
    """
    # 1. Timestamp check (stateless and cheap, so it runs first).
    ts = report.get("timestamp", 0)
    if not check_timestamp(ts):
        return False, "timestamp_skew"

    # 2. HMAC verification — authenticate before mutating any state.
    provided_hmac = report.get("hmac", "")
    data = report.get("data")
    # NOTE(review): a present-but-falsy payload ({} or 0) is rejected here,
    # same as the original — confirm agents never send an empty data object.
    if not data or not provided_hmac:
        return False, "missing_hmac_or_data"

    valid, matched_psk = verify_hmac(data, provided_hmac)
    if not valid:
        return False, "hmac_invalid"

    # 3. Nonce replay check (this consumes the nonce).
    nonce = report.get("nonce", "")
    if not nonce:
        return False, "missing_nonce"
    if not check_nonce(nonce):
        return False, "replay_detected"

    # 4. Fleet ID consistency check.
    if matched_psk:
        expected_fleet = fleet_id_for_psk(matched_psk)
        if report.get("fleet_id") != expected_fleet:
            return False, "fleet_id_mismatch"

    return True, "ok"
|