From d0cddec45a93bc70ed3d81449491eaccc84d34d3 Mon Sep 17 00:00:00 2001 From: Kalzu Rekku Date: Fri, 17 Apr 2026 20:15:24 +0300 Subject: [PATCH] Begings for Manager. --- .gitignore | 3 + .../# Kattila Manager Implementation Plan.md | 51 +++ manager/app.py | 383 +++++++++++++++++- manager/db.py | 308 ++++++++++++++ manager/processor.py | 105 +++++ manager/requirements.txt | 1 + manager/security.py | 186 +++++++++ 7 files changed, 1033 insertions(+), 4 deletions(-) create mode 100644 .gitignore create mode 100644 manager/# Kattila Manager Implementation Plan.md create mode 100644 manager/db.py create mode 100644 manager/processor.py create mode 100644 manager/security.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..99a9735 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +examples +.env +.vscode diff --git a/manager/# Kattila Manager Implementation Plan.md b/manager/# Kattila Manager Implementation Plan.md new file mode 100644 index 0000000..04ad02d --- /dev/null +++ b/manager/# Kattila Manager Implementation Plan.md @@ -0,0 +1,51 @@ +# Kattila Manager Implementation Plan + +This document outlines the detailed architecture and implementation steps for the Python-based Kattila Manager. + +## Overview +The Manager is a Python/Flask application that maintains a centralized SQLite (WAL mode) database. It provides an HTTP API to receive pushed reports from the agents, securely verifies their HMAC-SHA256 signatures, prevents replay attacks using a nonce sliding window cache, and updates the local network topology and alarm states based on the received data. + +Look at the kattila.poc for ideas how to implement ip address anonymization. And tips on how the map chould be drawn. + +## Proposed Architecture / Modules + +### 1. Database Layer (`db.py`) +- Initializes an `sqlite3` connection with `PRAGMA journal_mode=WAL;`. +- Automatically executes the `CREATE TABLE` and `CREATE INDEX` SQL schemas defined in the DESIGN document on startup. 
+- Exposes structured data access methods for other modules (e.g., `upsert_agent`, `insert_report`, `update_interfaces`, `update_edges`, `create_alarm`). + +### 2. Security Layer (`security.py`) +- **Key Fetching**: A background thread or periodic polling function that utilizes Python's DNS resolver to get the Bootstrap PSK from the TXT record, keeping track of the current PSK and the two previous PSKs. +- **HMAC Verification**: Parses incoming JSON, re-serializes the `data` payload identically, and checks if the provided HMAC matches one of the known PSKs. +- **Nonce Cache**: Maintains a memory-bound cache (e.g., `collections.OrderedDict`) of the last 120 nonces to prevent replay attacks. +- **Time Skew**: Rejects reports whose `timestamp` deviates by more than 10 minutes from the Manager's local clock. + +### 3. Data Processor (`processor.py`) +This is the core business logic engine invoked whenever a valid `/status/updates` payload hits the API: +- **Agents**: Upsert the `agent_id` into the `agents` table and update the `last_seen_at` heartbeat. +- **Reports**: Store the raw envelope in `reports` for auditing. +- **Interfaces**: Compare the payload's `interfaces` against `agent_interfaces`. If new interfaces appear or old ones disappear, update the DB and potentially trigger an alarm (e.g., "Interface eth0 went down"). +- **Topology Edges**: Iterate over `wg_peers`. For each peer, create or update a link in `topology_edges` specifying `edge_type='wireguard'`. + +### 4. API Layer (`api.py` or `app.py`) +- A Flask Blueprint or App defining: + - `POST /status/updates`: Main ingress. Parses JSON -> Verifies HMAC & Nonce -> Calls Processor -> Returns OK. Unwraps `relay_path` envelopes iteratively if needed. + - `POST /status/register`: Allows new agents to announce their generated ID. + - `GET /status/healthcheck`: Returns `{status: ok}`. + - `GET /status/alarms`: JSON list of active alarms. + - `GET /status/agents`: JSON dump of the fleet matrix. 
+ - `POST /status/admin/reset`: Clears specific agent topology state. + +## User Review Required +> [!IMPORTANT] +> - Since Python's standard library doesn't organically support fetching DNS TXT records easily, I plan to add `dnspython` to `requirements.txt`. Is this acceptable? +> - The agent successfully generates its own secure hexadecimal `agent_id` locally. Instead of the Manager strictly mandating `/status/register` before everything else, is it acceptable for the Manager to dynamically "auto-register" (upsert) unknown `agent_id`s directly when they push a valid `/status/updates` report? (It simplifies bootstrapping considerably). +> - When generating alarms, should we just log simple messages like "Interface X disappeared" and keep the alarm `active` until a human clears it, or should the alarms auto-dismiss when the issue resolves (e.g., interface comes back)? + +## Verification Plan +### Automated testing +- Run basic `pytest` (if available) or dummy scripts pushing forged payloads and ensuring the security layer rejects invalid HMACs and duplicate Nonces. +### Manual Verification +- Start the Flask app. +- Hit `/status/healthcheck` with curl. +- Send a mock successful JSON representation of the `wg_peers` and `interfaces` using exactly the PSK from the test `.env`. Check that `kattila_manager.db` correctly generated the relational graph. diff --git a/manager/app.py b/manager/app.py index 68a1359..2d8ae1e 100644 --- a/manager/app.py +++ b/manager/app.py @@ -1,10 +1,385 @@ -from flask import Flask, jsonify +""" +Kattila Manager — Flask API and Web UI. +Receives agent reports, serves fleet status, and renders the network topology map. 
+""" + +import hashlib +import ipaddress +import json +import logging +import os +import time + +from flask import Flask, jsonify, request, render_template_string + +import db +import security +import processor + +# ── Logging ────────────────────────────────────────────────────────────────── + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s" +) +logger = logging.getLogger(__name__) + +# ── App init ───────────────────────────────────────────────────────────────── app = Flask(__name__) -@app.route('/status/healthcheck') + +def load_env(): + """Load .env file from parent directory.""" + paths = [".env", "../.env"] + for p in paths: + if os.path.isfile(p): + with open(p) as f: + for line in f: + line = line.strip() + if not line or line.startswith("#"): + continue + if "=" in line: + key, val = line.split("=", 1) + os.environ.setdefault(key, val) + break + + +load_env() +db.init_db() +security.start_key_poller() + + +# ── IP anonymization (from POC) ───────────────────────────────────────────── + + +def is_public_ip(ip_str: str) -> bool: + """Returns True if the IP is a global, public address.""" + try: + ip = ipaddress.ip_address(ip_str) + return ip.is_global and not ip.is_private and not ip.is_loopback + except ValueError: + return False + + +def anonymize_ip(ip: str) -> str: + """ + Private RFC1918 addresses: reveal subnet, hide host octet. + Public addresses: replace with a short stable SHA hash fingerprint. + """ + # Strip CIDR notation if present + ip_bare = ip.split("/")[0] + parts = ip_bare.split(".") + if len(parts) != 4: + return "???" + + try: + first = int(parts[0]) + except ValueError: + return "???" 
+ + if first == 10: + return f"10.{parts[1]}.*.*" + if first == 172 and 16 <= int(parts[1]) <= 31: + return f"172.{parts[1]}.*.*" + if first == 192 and parts[1] == "168": + return "192.168.*.*" + + token = hashlib.sha256(ip_bare.encode()).hexdigest()[:6] + return f"[pub:{token}]" + + +# ── API Endpoints ──────────────────────────────────────────────────────────── + + +@app.route("/status/healthcheck") def healthcheck(): return jsonify({"status": "ok"}) -if __name__ == '__main__': - app.run(host='0.0.0.0', port=5086) + +@app.route("/status/updates", methods=["POST"]) +def receive_update(): + """Main ingress point for agent reports (direct or relayed).""" + report = request.get_json(silent=True) + if not report: + return jsonify({"error": "invalid_json"}), 400 + + # Check if this is a relay envelope + if "relay_path" in report and "payload" in report: + inner = report["payload"] + valid, err = security.validate_report(inner) + if not valid: + logger.warning("Security check failed for relayed report: %s " + "(agent: %s, from: %s)", + err, inner.get("agent_id"), request.remote_addr) + return jsonify({"error": err}), 403 + processor.process_relay(report) + return jsonify({"status": "ok", "relayed": True}) + + # Direct report + valid, err = security.validate_report(report) + if not valid: + logger.warning("Security check failed: %s (agent: %s, from: %s)", + err, report.get("agent_id"), request.remote_addr) + status_code = 403 + if err == "hmac_invalid": + status_code = 401 + return jsonify({"error": err}), status_code + + processor.process_report(report) + return jsonify({"status": "ok"}) + + +@app.route("/status/register", methods=["POST"]) +def register_agent(): + """ + Agent registration endpoint. + Agents auto-register via /status/updates, but this allows explicit + first-contact registration as well. 
+ """ + data = request.get_json(silent=True) + if not data: + return jsonify({"error": "invalid_json"}), 400 + + agent_id = data.get("agent_id", "") + hostname = data.get("hostname", "unknown") + fleet_id = data.get("fleet_id", "") + + if not agent_id: + return jsonify({"error": "missing_agent_id"}), 400 + + db.upsert_agent(agent_id, hostname, data.get("agent_version", 1), + fleet_id, 0) + logger.info("Registered agent %s (%s)", agent_id, hostname) + return jsonify({"status": "registered", "agent_id": agent_id}) + + +@app.route("/status/alarms") +def get_alarms(): + """Return all active alarms.""" + alarms = db.get_active_alarms() + return jsonify(alarms) + + +@app.route("/status/alarms//dismiss", methods=["POST"]) +def dismiss_alarm(alarm_id): + """Dismiss a specific alarm.""" + db.dismiss_alarm(alarm_id) + return jsonify({"status": "dismissed"}) + + +@app.route("/status/agents") +def list_agents(): + """List all known agents and their status.""" + db.mark_stale_agents() + agents = db.get_all_agents() + + # Enrich with interface data + for agent in agents: + ifaces = db.get_agent_interfaces(agent["agent_id"]) + agent["interfaces"] = ifaces + + return jsonify(agents) + + +@app.route("/status/admin/reset", methods=["POST"]) +def admin_reset(): + """Reset a specific agent or the entire fleet.""" + data = request.get_json(silent=True) or {} + agent_id = data.get("agent_id") + + if agent_id: + db.reset_agent(agent_id) + return jsonify({"status": "reset", "agent_id": agent_id}) + else: + db.reset_all() + return jsonify({"status": "full_reset"}) + + +# ── Visualization data endpoint ───────────────────────────────────────────── + + +@app.route("/status/data") +def graph_data(): + """Return nodes and edges for the vis-network graph.""" + db.mark_stale_agents() + agents = db.get_all_agents() + edges = db.get_all_edges() + now = time.time() + + nodes = [] + for a in agents: + ifaces = db.get_agent_interfaces(a["agent_id"]) + + # Determine level: 0 = has public IP (hub), 1 = 
private only (spoke) + level = 1 + anon_ips = [] + for iface in ifaces: + try: + addrs = json.loads(iface.get("addresses_json", "[]")) + except (json.JSONDecodeError, TypeError): + addrs = [] + for addr in addrs: + ip_bare = addr.split("/")[0] + if is_public_ip(ip_bare): + level = 0 + anon_ips.append(anonymize_ip(addr)) + + age = int(now - a["last_seen_at"]) + is_alive = age < 300 + + nodes.append({ + "id": a["agent_id"], + "label": f"{a['hostname'].upper()}\n" + f"{', '.join(anon_ips[:3])}", + "level": level, + "color": "#2ecc71" if is_alive else "#e74c3c", + "title": f"Agent: {a['agent_id']}\n" + f"Last seen: {age}s ago\n" + f"Status: {a['status']}", + }) + + vis_edges = [] + seen_pairs: set[tuple] = set() + for e in edges: + pair = tuple(sorted([e["from_agent_id"], e["to_agent_id"]])) + if pair in seen_pairs: + continue + seen_pairs.add(pair) + + meta = json.loads(e.get("metadata", "{}")) + handshake = meta.get("latest_handshake", 0) + + if handshake == 0: + color = "#95a5a6" + elif (now - handshake) <= 120: + color = "#2ecc71" + elif (now - handshake) <= 86400: + color = "#f1c40f" + else: + color = "#e74c3c" + + vis_edges.append({ + "from": e["from_agent_id"], + "to": e["to_agent_id"], + "label": e["edge_type"], + "color": color, + "width": 3 if color == "#e74c3c" else 2, + "title": f"{e['edge_type']} link\n" + f"Handshake: {int(now - handshake)}s ago" + if handshake else f"{e['edge_type']} link\nNo handshake", + }) + + return jsonify({"nodes": nodes, "edges": vis_edges}) + + +# ── Web UI (vis-network topology map) ──────────────────────────────────────── + + +@app.route("/status") +def index(): + return render_template_string(HTML_TEMPLATE) + + +HTML_TEMPLATE = """ + + + Kattila — Network Map + + + + +
+ Kattila Network Map +
Active (< 2 min)
+
Stale (> 2 min)
+
Broken (> 24 h)
+
No handshake
+
+
+
+
+ + +""" + + +# ── Main ───────────────────────────────────────────────────────────────────── + +if __name__ == "__main__": + app.run(host="0.0.0.0", port=5086) diff --git a/manager/db.py b/manager/db.py new file mode 100644 index 0000000..5ca777e --- /dev/null +++ b/manager/db.py @@ -0,0 +1,308 @@ +""" +Database layer for Kattila Manager. +SQLite with WAL mode. All schemas from DESIGN.md are created on init. +""" + +import sqlite3 +import json +import time +import threading +import logging + +logger = logging.getLogger(__name__) + +DB_PATH = "kattila_manager.db" + +_local = threading.local() + + +def get_conn() -> sqlite3.Connection: + """Return a thread-local SQLite connection.""" + if not hasattr(_local, "conn") or _local.conn is None: + conn = sqlite3.connect(DB_PATH) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL;") + conn.execute("PRAGMA foreign_keys=ON;") + _local.conn = conn + return _local.conn + + +def init_db(): + """Create all tables and indexes if they don't exist.""" + conn = get_conn() + conn.executescript(""" + CREATE TABLE IF NOT EXISTS agents ( + agent_id TEXT PRIMARY KEY, + hostname TEXT NOT NULL, + agent_version INTEGER NOT NULL, + fleet_id TEXT NOT NULL, + registered_at INTEGER NOT NULL, + last_seen_at INTEGER NOT NULL, + last_tick INTEGER NOT NULL DEFAULT 0, + status TEXT NOT NULL DEFAULT 'online' + ); + CREATE INDEX IF NOT EXISTS idx_agents_last_seen ON agents(last_seen_at); + + CREATE TABLE IF NOT EXISTS reports ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + agent_id TEXT NOT NULL, + tick INTEGER NOT NULL, + timestamp INTEGER NOT NULL, + report_type TEXT NOT NULL, + report_json TEXT NOT NULL, + received_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')), + FOREIGN KEY (agent_id) REFERENCES agents(agent_id) ON DELETE CASCADE + ); + CREATE UNIQUE INDEX IF NOT EXISTS idx_reports_agent_tick ON reports(agent_id, tick); + + CREATE TABLE IF NOT EXISTS topology_edges ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + from_agent_id 
TEXT NOT NULL, + to_agent_id TEXT NOT NULL, + edge_type TEXT NOT NULL, + metadata TEXT DEFAULT '{}', + last_seen INTEGER NOT NULL, + is_active INTEGER NOT NULL DEFAULT 1 + ); + CREATE UNIQUE INDEX IF NOT EXISTS idx_edges_pair ON topology_edges(from_agent_id, to_agent_id, edge_type); + + CREATE TABLE IF NOT EXISTS agent_interfaces ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + agent_id TEXT NOT NULL, + interface_name TEXT NOT NULL, + mac_address TEXT, + addresses_json TEXT, + is_virtual INTEGER NOT NULL DEFAULT 0, + vpn_type TEXT, + last_seen_at INTEGER NOT NULL, + FOREIGN KEY (agent_id) REFERENCES agents(agent_id) ON DELETE CASCADE + ); + CREATE UNIQUE INDEX IF NOT EXISTS idx_agent_interfaces ON agent_interfaces(agent_id, interface_name); + + CREATE TABLE IF NOT EXISTS alarms ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + agent_id TEXT NOT NULL, + alarm_type TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'active', + details_json TEXT DEFAULT '{}', + created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')), + dismissed_at INTEGER, + FOREIGN KEY (agent_id) REFERENCES agents(agent_id) ON DELETE CASCADE + ); + CREATE INDEX IF NOT EXISTS idx_alarms_agent_status ON alarms(agent_id, status); + """) + conn.commit() + logger.info("Database initialized at %s", DB_PATH) + + +# ── Agent operations ───────────────────────────────────────────────────────── + + +def upsert_agent(agent_id: str, hostname: str, agent_version: int, + fleet_id: str, tick: int): + """Insert or update an agent record (auto-register on first valid report).""" + now = int(time.time()) + conn = get_conn() + conn.execute(""" + INSERT INTO agents (agent_id, hostname, agent_version, fleet_id, + registered_at, last_seen_at, last_tick, status) + VALUES (?, ?, ?, ?, ?, ?, ?, 'online') + ON CONFLICT(agent_id) DO UPDATE SET + hostname = excluded.hostname, + agent_version = excluded.agent_version, + fleet_id = excluded.fleet_id, + last_seen_at = excluded.last_seen_at, + last_tick = excluded.last_tick, + status = 
'online' + """, (agent_id, hostname, agent_version, fleet_id, now, now, tick)) + conn.commit() + + +def get_all_agents() -> list[dict]: + conn = get_conn() + rows = conn.execute("SELECT * FROM agents ORDER BY last_seen_at DESC").fetchall() + return [dict(r) for r in rows] + + +def get_agent(agent_id: str) -> dict | None: + conn = get_conn() + row = conn.execute("SELECT * FROM agents WHERE agent_id = ?", + (agent_id,)).fetchone() + return dict(row) if row else None + + +def delete_agent(agent_id: str): + conn = get_conn() + conn.execute("DELETE FROM agents WHERE agent_id = ?", (agent_id,)) + conn.commit() + + +def mark_stale_agents(timeout_seconds: int = 120): + """Mark agents as offline if not seen within timeout.""" + cutoff = int(time.time()) - timeout_seconds + conn = get_conn() + conn.execute(""" + UPDATE agents SET status = 'offline' + WHERE last_seen_at < ? AND status != 'offline' + """, (cutoff,)) + conn.commit() + + +# ── Report operations ──────────────────────────────────────────────────────── + + +def insert_report(agent_id: str, tick: int, timestamp: int, + report_type: str, report_json: str): + conn = get_conn() + try: + conn.execute(""" + INSERT INTO reports (agent_id, tick, timestamp, report_type, report_json) + VALUES (?, ?, ?, ?, ?) + """, (agent_id, tick, timestamp, report_type, report_json)) + conn.commit() + except sqlite3.IntegrityError: + logger.warning("Duplicate report from %s tick %d — skipping", agent_id, tick) + + +# ── Interface operations ───────────────────────────────────────────────────── + + +def update_interfaces(agent_id: str, interfaces: list[dict]) -> list[str]: + """ + Update agent_interfaces table. Returns list of change descriptions + (for alarm generation). 
+ """ + now = int(time.time()) + conn = get_conn() + changes = [] + + # Get current known interfaces for this agent + existing = conn.execute( + "SELECT interface_name FROM agent_interfaces WHERE agent_id = ?", + (agent_id,) + ).fetchall() + existing_names = {r["interface_name"] for r in existing} + + reported_names = set() + for iface in interfaces: + name = iface.get("name", "") + if not name: + continue + reported_names.add(name) + + addresses = json.dumps(iface.get("addresses", [])) + is_virtual = 1 if iface.get("is_virtual", False) else 0 + vpn_type = iface.get("vpn_type") + + conn.execute(""" + INSERT INTO agent_interfaces + (agent_id, interface_name, mac_address, addresses_json, + is_virtual, vpn_type, last_seen_at) + VALUES (?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(agent_id, interface_name) DO UPDATE SET + mac_address = excluded.mac_address, + addresses_json = excluded.addresses_json, + is_virtual = excluded.is_virtual, + vpn_type = excluded.vpn_type, + last_seen_at = excluded.last_seen_at + """, (agent_id, name, iface.get("mac", ""), addresses, + is_virtual, vpn_type, now)) + + # Detect disappeared interfaces + disappeared = existing_names - reported_names + for ifname in disappeared: + changes.append(f"Interface '{ifname}' disappeared") + + # Detect new interfaces + appeared = reported_names - existing_names + for ifname in appeared: + if existing_names: # only alarm if agent had prior interfaces + changes.append(f"New interface '{ifname}' appeared") + + conn.commit() + return changes + + +def get_agent_interfaces(agent_id: str) -> list[dict]: + conn = get_conn() + rows = conn.execute( + "SELECT * FROM agent_interfaces WHERE agent_id = ?", (agent_id,) + ).fetchall() + return [dict(r) for r in rows] + + +# ── Topology edge operations ──────────────────────────────────────────────── + + +def upsert_edge(from_agent_id: str, to_agent_id: str, edge_type: str, + metadata: dict | None = None): + now = int(time.time()) + meta_json = json.dumps(metadata or {}) + conn = 
get_conn() + conn.execute(""" + INSERT INTO topology_edges + (from_agent_id, to_agent_id, edge_type, metadata, last_seen, is_active) + VALUES (?, ?, ?, ?, ?, 1) + ON CONFLICT(from_agent_id, to_agent_id, edge_type) DO UPDATE SET + metadata = excluded.metadata, + last_seen = excluded.last_seen, + is_active = 1 + """, (from_agent_id, to_agent_id, edge_type, meta_json, now)) + conn.commit() + + +def get_all_edges() -> list[dict]: + conn = get_conn() + rows = conn.execute("SELECT * FROM topology_edges WHERE is_active = 1").fetchall() + return [dict(r) for r in rows] + + +# ── Alarm operations ──────────────────────────────────────────────────────── + + +def create_alarm(agent_id: str, alarm_type: str, details: dict | None = None): + conn = get_conn() + conn.execute(""" + INSERT INTO alarms (agent_id, alarm_type, details_json) + VALUES (?, ?, ?) + """, (agent_id, alarm_type, json.dumps(details or {}))) + conn.commit() + logger.info("ALARM [%s] %s: %s", agent_id, alarm_type, details) + + +def get_active_alarms() -> list[dict]: + conn = get_conn() + rows = conn.execute( + "SELECT * FROM alarms WHERE status = 'active' ORDER BY created_at DESC" + ).fetchall() + return [dict(r) for r in rows] + + +def dismiss_alarm(alarm_id: int): + now = int(time.time()) + conn = get_conn() + conn.execute(""" + UPDATE alarms SET status = 'dismissed', dismissed_at = ? + WHERE id = ? 
+ """, (now, alarm_id)) + conn.commit() + + +def reset_agent(agent_id: str): + """Wipe all state for a specific agent.""" + conn = get_conn() + conn.execute("DELETE FROM agents WHERE agent_id = ?", (agent_id,)) + conn.commit() + logger.info("Reset all state for agent %s", agent_id) + + +def reset_all(): + """Wipe the entire fleet state.""" + conn = get_conn() + conn.execute("DELETE FROM topology_edges") + conn.execute("DELETE FROM agent_interfaces") + conn.execute("DELETE FROM reports") + conn.execute("DELETE FROM alarms") + conn.execute("DELETE FROM agents") + conn.commit() + logger.info("Full fleet reset executed") diff --git a/manager/processor.py b/manager/processor.py new file mode 100644 index 0000000..77ba465 --- /dev/null +++ b/manager/processor.py @@ -0,0 +1,105 @@ +""" +Data processor for Kattila Manager. +Core business logic: processes validated agent reports and updates +the database state (agents, interfaces, topology, alarms). +""" + +import json +import logging + +import db + +logger = logging.getLogger(__name__) + + +def process_report(report: dict): + """ + Process a validated agent report. + Called after security checks have passed. + """ + agent_id = report["agent_id"] + tick = report["tick"] + timestamp = report["timestamp"] + agent_version = report.get("agent_version", 0) + fleet_id = report.get("fleet_id", "") + report_type = report.get("type", "report") + data = report.get("data", {}) + + hostname = data.get("hostname", "unknown") + + # 1. Upsert agent (auto-register) + db.upsert_agent(agent_id, hostname, agent_version, fleet_id, tick) + logger.info("Processed report from %s (%s) tick=%d", agent_id, hostname, tick) + + # 2. Store raw report for auditing + db.insert_report(agent_id, tick, timestamp, report_type, + json.dumps(report)) + + # 3. 
"""
Data processor for Kattila Manager.
Core business logic: applies validated agent reports to the database
state (agents, reports, interfaces, topology, alarms).
"""

import json
import logging

import db

logger = logging.getLogger(__name__)


def process_report(report: dict):
    """Apply a single validated agent report.

    Must only be called after security checks have passed; assumes the
    envelope carries agent_id/tick/timestamp (KeyError otherwise).
    """
    agent_id = report["agent_id"]
    tick = report["tick"]
    timestamp = report["timestamp"]
    data = report.get("data", {})
    hostname = data.get("hostname", "unknown")

    # 1. Upsert agent (auto-register on first contact) + heartbeat.
    db.upsert_agent(agent_id, hostname, report.get("agent_version", 0),
                    report.get("fleet_id", ""), tick)
    logger.info("Processed report from %s (%s) tick=%d",
                agent_id, hostname, tick)

    # 2. Store the raw envelope for auditing.
    db.insert_report(agent_id, tick, timestamp,
                     report.get("type", "report"), json.dumps(report))

    # 3. Interface diffing — each change becomes an alarm.
    interfaces = data.get("interfaces", [])
    if interfaces:
        for change in db.update_interfaces(agent_id, interfaces):
            db.create_alarm(agent_id, "interface_change",
                            {"description": change})

    # 4. Topology edges from wireguard peers.
    _update_topology(agent_id, data.get("wg_peers", []))


def _update_topology(agent_id: str, wg_peers: list[dict]):
    """Record one 'wireguard' edge per reported peer.

    The remote agent_id cannot be resolved yet (no pubkey->agent mapping
    is stored), so the peer's public-key prefix serves as a stable
    placeholder node id; the metadata keeps the full key for later
    cross-referencing. BUG FIX: the original also called
    db.get_all_agents() here and never used the result — a full-table
    query on every report; removed.
    """
    for peer in wg_peers:
        pubkey = peer.get("public_key", "")
        metadata = {
            "public_key": pubkey,
            "endpoint": peer.get("endpoint", ""),
            "interface": peer.get("interface", ""),
            "transfer_rx": peer.get("transfer_rx", 0),
            "transfer_tx": peer.get("transfer_tx", 0),
            "latest_handshake": peer.get("latest_handshake", 0),
        }
        target_id = f"pubkey:{pubkey[:16]}" if pubkey else "unknown"
        db.upsert_edge(agent_id, target_id, "wireguard", metadata)


def process_relay(envelope: dict):
    """Process a relayed report envelope.

    The caller is responsible for having validated the inner payload.
    Records each relay hop as a 'relay' topology edge.
    """
    relay_path = envelope.get("relay_path", [])
    payload = envelope.get("payload", {})
    if not payload:
        logger.warning("Empty relay payload received")
        return

    # Defensive: relay_path is expected to be a list of node-id strings.
    if not isinstance(relay_path, list):
        relay_path = []
    hops = [str(h) for h in relay_path]

    logger.info("Processing relayed report via path: %s", " -> ".join(hops))
    process_report(payload)

    # Record relay edges between consecutive hops.
    for src, dst in zip(hops, hops[1:]):
        db.upsert_edge(src, dst, "relay")
Uses dnspython if available, falls back to subprocess.""" + try: + import dns.resolver + answers = dns.resolver.resolve(dns_name, "TXT") + for rdata in answers: + txt = rdata.to_text().strip('"') + return txt + except ImportError: + # Fallback: use dig + import subprocess + try: + result = subprocess.run( + ["dig", "+short", "TXT", dns_name], + capture_output=True, text=True, timeout=5 + ) + val = result.stdout.strip().strip('"') + if val: + return val + except (subprocess.TimeoutExpired, FileNotFoundError): + pass + except Exception as e: + logger.error("DNS TXT lookup failed for %s: %s", dns_name, e) + return None + + +def fetch_psk(): + """Fetch the PSK from the DNS TXT record and update the history.""" + dns_name = os.environ.get("DNS", "") + if not dns_name: + logger.warning("security: No DNS configured for PSK lookup") + return + + key = _resolve_txt(dns_name) + if not key: + logger.warning("security: Could not resolve PSK from DNS") + return + + with _psk_lock: + if not _psk_history or _psk_history[0] != key: + _psk_history.insert(0, key) + # Keep at most 3 entries (current + 2 previous) + while len(_psk_history) > 3: + _psk_history.pop() + logger.info("security: PSK updated (history depth: %d)", len(_psk_history)) + + +def start_key_poller(): + """Fetch PSK immediately and then every hour in a background thread.""" + fetch_psk() + + def _poll(): + while True: + time.sleep(3600) + fetch_psk() + + t = threading.Thread(target=_poll, daemon=True) + t.start() + + +def get_current_psk() -> str: + with _psk_lock: + return _psk_history[0] if _psk_history else "" + + +def get_psk_history() -> list[str]: + with _psk_lock: + return list(_psk_history) + + +def fleet_id_for_psk(psk: str) -> str: + """Generate the fleet_id (SHA256 hash) for a given PSK.""" + return hashlib.sha256(psk.encode()).hexdigest() + + +def verify_hmac(data_payload, provided_hmac: str) -> tuple[bool, str | None]: + """ + Verify the HMAC against all known PSKs (current + 2 previous). 
+ Returns (valid, matched_psk) or (False, None). + """ + psks = get_psk_history() + if not psks: + logger.error("security: No PSKs available for verification") + return False, None + + data_bytes = json.dumps(data_payload, separators=(",", ":"), + sort_keys=True).encode() + + for psk in psks: + expected = hmac.new(psk.encode(), data_bytes, hashlib.sha256).hexdigest() + if hmac.compare_digest(expected, provided_hmac): + if psk != psks[0]: + logger.warning("security: Agent using old PSK (not current)") + return True, psk + + return False, None + + +def check_nonce(nonce: str) -> bool: + """ + Check if a nonce has been seen before (replay protection). + Returns True if the nonce is fresh (not seen), False if it's a replay. + """ + with _nonce_lock: + if nonce in _nonce_cache: + return False + + _nonce_cache[nonce] = time.time() + + # Evict oldest entries if over the limit + while len(_nonce_cache) > NONCE_CACHE_SIZE: + _nonce_cache.popitem(last=False) + + return True + + +def check_timestamp(timestamp: int) -> bool: + """Check if the timestamp is within the acceptable clock skew.""" + now = int(time.time()) + return abs(now - timestamp) <= MAX_CLOCK_SKEW + + +def validate_report(report: dict) -> tuple[bool, str]: + """ + Full validation pipeline for an incoming report. + Returns (valid, error_message). + """ + # 1. Timestamp check + ts = report.get("timestamp", 0) + if not check_timestamp(ts): + return False, "timestamp_skew" + + # 2. Nonce replay check + nonce = report.get("nonce", "") + if not nonce: + return False, "missing_nonce" + if not check_nonce(nonce): + return False, "replay_detected" + + # 3. HMAC verification + provided_hmac = report.get("hmac", "") + data = report.get("data") + if not data or not provided_hmac: + return False, "missing_hmac_or_data" + + valid, matched_psk = verify_hmac(data, provided_hmac) + if not valid: + return False, "hmac_invalid" + + # 4. 
Fleet ID consistency check + if matched_psk: + expected_fleet = fleet_id_for_psk(matched_psk) + if report.get("fleet_id") != expected_fleet: + return False, "fleet_id_mismatch" + + return True, "ok"