Beginnings for Manager.
This commit is contained in:
186
manager/security.py
Normal file
186
manager/security.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""
|
||||
Security layer for Kattila Manager.
|
||||
Handles PSK fetching via DNS TXT, HMAC verification, nonce caching,
|
||||
and timestamp skew checking.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# PSK history: current + 2 previous
|
||||
_psk_lock = threading.Lock()
|
||||
_psk_history: list[str] = [] # most recent first, max 3 entries
|
||||
|
||||
# Nonce sliding window
|
||||
_nonce_lock = threading.Lock()
|
||||
_nonce_cache: OrderedDict[str, float] = OrderedDict()
|
||||
NONCE_CACHE_SIZE = 120
|
||||
|
||||
# Maximum clock skew in seconds
|
||||
MAX_CLOCK_SKEW = 600 # 10 minutes
|
||||
|
||||
|
||||
def _resolve_txt(dns_name: str) -> str | None:
    """Resolve a DNS TXT record to its string value.

    Prefers dnspython when installed; otherwise shells out to ``dig``.
    Returns the first TXT value found, or None on any failure.
    """
    try:
        import dns.resolver

        answers = dns.resolver.resolve(dns_name, "TXT")
        for rdata in answers:
            # to_text() wraps the value in quotes; strip them off.
            return rdata.to_text().strip('"')
    except ImportError:
        # dnspython not installed -- fall back to the dig CLI.
        import subprocess
        try:
            result = subprocess.run(
                ["dig", "+short", "TXT", dns_name],
                capture_output=True, text=True, timeout=5
            )
            # dig prints one quoted value per line. Use the first
            # non-empty line so multi-record answers don't get fused
            # into a single newline-containing string (original took
            # the whole stdout).
            for line in result.stdout.splitlines():
                val = line.strip().strip('"')
                if val:
                    return val
        except (subprocess.TimeoutExpired, OSError):
            # OSError covers FileNotFoundError (dig missing) and any
            # other exec failure; all are treated as "no answer".
            pass
    except Exception as e:
        # dnspython raises resolver-specific errors (NXDOMAIN,
        # lifetime timeout, ...); log and fall through to None.
        logger.error("DNS TXT lookup failed for %s: %s", dns_name, e)
    return None
def fetch_psk():
    """Look up the PSK via the DNS TXT record and push it onto the history.

    No-op (with a warning) when no DNS name is configured or the lookup
    fails. Keeps at most the current key plus two predecessors.
    """
    dns_name = os.environ.get("DNS", "")
    if not dns_name:
        logger.warning("security: No DNS configured for PSK lookup")
        return

    key = _resolve_txt(dns_name)
    if not key:
        logger.warning("security: Could not resolve PSK from DNS")
        return

    with _psk_lock:
        unchanged = bool(_psk_history) and _psk_history[0] == key
        if not unchanged:
            _psk_history.insert(0, key)
            # Trim to at most 3 entries (current + 2 previous).
            del _psk_history[3:]
            logger.info("security: PSK updated (history depth: %d)", len(_psk_history))
def start_key_poller():
    """Fetch the PSK immediately, then refresh it hourly in a daemon thread."""
    fetch_psk()

    def _refresh_forever():
        # Daemon thread: hourly re-fetch until process exit.
        while True:
            time.sleep(3600)
            fetch_psk()

    threading.Thread(target=_refresh_forever, daemon=True).start()
def get_current_psk() -> str:
    """Return the newest known PSK, or the empty string if none has been fetched."""
    with _psk_lock:
        if not _psk_history:
            return ""
        return _psk_history[0]
def get_psk_history() -> list[str]:
    """Return a snapshot copy of the PSK history, newest first."""
    with _psk_lock:
        return _psk_history.copy()
def fleet_id_for_psk(psk: str) -> str:
    """Derive the fleet identifier for *psk*: the hex SHA-256 digest of its UTF-8 bytes."""
    digest = hashlib.sha256(psk.encode())
    return digest.hexdigest()
def verify_hmac(data_payload, provided_hmac: str) -> tuple[bool, str | None]:
    """Check *provided_hmac* against every known PSK (current + 2 previous).

    The payload is canonicalised as compact, key-sorted JSON before
    hashing. Returns (True, matching_psk) on success, (False, None)
    otherwise.
    """
    candidates = get_psk_history()
    if not candidates:
        logger.error("security: No PSKs available for verification")
        return False, None

    canonical = json.dumps(data_payload, separators=(",", ":"),
                           sort_keys=True).encode()

    for idx, candidate in enumerate(candidates):
        digest = hmac.new(candidate.encode(), canonical, hashlib.sha256).hexdigest()
        if not hmac.compare_digest(digest, provided_hmac):
            continue
        if idx != 0:
            # Matched a previous key from the rotation window.
            logger.warning("security: Agent using old PSK (not current)")
        return True, candidate

    return False, None
def check_nonce(nonce: str) -> bool:
    """Replay protection: record *nonce* and report whether it was fresh.

    Returns True the first time a nonce is seen, False on a replay.

    Entries older than twice MAX_CLOCK_SKEW are purged first: a report
    replayed that late is rejected by check_timestamp() anyway, so such
    entries only waste cache slots. This stops stale entries from
    pushing *fresh* nonces out of the fixed-size cache (the original
    size-only eviction let old junk evict in-window nonces, re-enabling
    their replay within the valid timestamp window).

    NOTE(review): under very heavy traffic the NONCE_CACHE_SIZE
    backstop can still evict in-window nonces; consider raising the cap.
    """
    now = time.time()
    with _nonce_lock:
        # Time-based purge: cache is an OrderedDict in insertion order,
        # so the oldest entry is always first.
        while _nonce_cache:
            _, seen_at = next(iter(_nonce_cache.items()))
            if now - seen_at <= 2 * MAX_CLOCK_SKEW:
                break
            _nonce_cache.popitem(last=False)

        if nonce in _nonce_cache:
            return False

        _nonce_cache[nonce] = now

        # Hard size backstop so memory stays bounded.
        while len(_nonce_cache) > NONCE_CACHE_SIZE:
            _nonce_cache.popitem(last=False)

        return True
def check_timestamp(timestamp: int) -> bool:
    """Return True when *timestamp* is within MAX_CLOCK_SKEW seconds of local time."""
    skew = abs(int(time.time()) - timestamp)
    return skew <= MAX_CLOCK_SKEW
def validate_report(report: dict) -> tuple[bool, str]:
    """Full validation pipeline for an incoming report.

    Checks, in order: timestamp skew, nonce presence, HMAC authenticity,
    nonce freshness, and fleet-id consistency.
    Returns (valid, error_message); error_message is "ok" on success.

    The nonce is recorded only AFTER the HMAC verifies. The original
    consumed the nonce before authentication, so unauthenticated
    garbage requests could fill the nonce cache and evict legitimate
    nonces, enabling their replay.
    """
    # 1. Timestamp check
    ts = report.get("timestamp", 0)
    if not check_timestamp(ts):
        return False, "timestamp_skew"

    # 2. Nonce presence (cheap structural check; recorded later).
    nonce = report.get("nonce", "")
    if not nonce:
        return False, "missing_nonce"

    # 3. HMAC verification
    provided_hmac = report.get("hmac", "")
    data = report.get("data")
    if not data or not provided_hmac:
        return False, "missing_hmac_or_data"

    valid, matched_psk = verify_hmac(data, provided_hmac)
    if not valid:
        return False, "hmac_invalid"

    # 4. Nonce replay check -- only authenticated reports consume a slot.
    if not check_nonce(nonce):
        return False, "replay_detected"

    # 5. Fleet ID consistency check
    if matched_psk:
        expected_fleet = fleet_id_for_psk(matched_psk)
        if report.get("fleet_id") != expected_fleet:
            return False, "fleet_id_mismatch"

    return True, "ok"
Reference in New Issue
Block a user