Beginnings for Manager.

This commit is contained in:
Kalzu Rekku
2026-04-17 20:15:24 +03:00
parent 99e0e0208c
commit d0cddec45a
7 changed files with 1033 additions and 4 deletions

.gitignore vendored Normal file

@@ -0,0 +1,3 @@
examples
.env
.vscode


@@ -0,0 +1,51 @@
# Kattila Manager Implementation Plan
This document outlines the detailed architecture and implementation steps for the Python-based Kattila Manager.
## Overview
The Manager is a Python/Flask application that maintains a centralized SQLite (WAL mode) database. It provides an HTTP API to receive pushed reports from the agents, securely verifies their HMAC-SHA256 signatures, prevents replay attacks using a nonce sliding window cache, and updates the local network topology and alarm states based on the received data.
Look at kattila.poc for ideas on how to implement IP address anonymization, and for tips on how the map should be drawn.
## Proposed Architecture / Modules
### 1. Database Layer (`db.py`)
- Initializes an `sqlite3` connection with `PRAGMA journal_mode=WAL;`.
- Automatically executes the `CREATE TABLE` and `CREATE INDEX` SQL schemas defined in the DESIGN document on startup.
- Exposes structured data access methods for other modules (e.g., `upsert_agent`, `insert_report`, `update_interfaces`, `update_edges`, `create_alarm`).
### 2. Security Layer (`security.py`)
- **Key Fetching**: A background thread or periodic polling function that utilizes Python's DNS resolver to get the Bootstrap PSK from the TXT record, keeping track of the current PSK and the two previous PSKs.
- **HMAC Verification**: Parses incoming JSON, re-serializes the `data` payload identically, and checks if the provided HMAC matches one of the known PSKs.
- **Nonce Cache**: Maintains a memory-bound cache (e.g., `collections.OrderedDict`) of the last 120 nonces to prevent replay attacks.
- **Time Skew**: Rejects reports whose `timestamp` deviates by more than 10 minutes from the Manager's local clock.
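
The "re-serializes identically" requirement above hinges on canonical JSON: both sides must produce byte-identical output before HMAC'ing, so separators and key order have to be pinned. A minimal sketch (PSK and payload values are invented):

```python
import hashlib
import hmac
import json

psk = b"example-psk"
data = {"hostname": "node-a", "tick": 7}

# Canonical form: no whitespace, keys sorted — the same settings the agent uses.
canonical = json.dumps(data, separators=(",", ":"), sort_keys=True).encode()
digest = hmac.new(psk, canonical, hashlib.sha256).hexdigest()

# A semantically equal dict with different key order serializes to the same
# bytes, so the HMAC still matches on the Manager side.
reordered = json.dumps({"tick": 7, "hostname": "node-a"},
                       separators=(",", ":"), sort_keys=True).encode()
assert canonical == reordered
```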
### 3. Data Processor (`processor.py`)
This is the core business logic engine invoked whenever a valid `/status/updates` payload hits the API:
- **Agents**: Upsert the `agent_id` into the `agents` table and update the `last_seen_at` heartbeat.
- **Reports**: Store the raw envelope in `reports` for auditing.
- **Interfaces**: Compare the payload's `interfaces` against `agent_interfaces`. If new interfaces appear or old ones disappear, update the DB and potentially trigger an alarm (e.g., "Interface eth0 went down").
- **Topology Edges**: Iterate over `wg_peers`. For each peer, create or update a link in `topology_edges` specifying `edge_type='wireguard'`.
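
For illustration, a report `data` section carrying the fields the processor reads (`interfaces`, `wg_peers`) might look like this; all values are invented:

```python
# Hypothetical example of the "data" section of an agent report.
# Field names mirror what the processor consumes; values are made up.
example_data = {
    "hostname": "node-a",
    "interfaces": [
        {"name": "eth0", "mac": "aa:bb:cc:dd:ee:ff",
         "addresses": ["192.168.1.10/24"], "is_virtual": False},
        {"name": "wg0", "addresses": ["10.8.0.2/24"],
         "is_virtual": True, "vpn_type": "wireguard"},
    ],
    "wg_peers": [
        {"interface": "wg0",
         "public_key": "f00dcafe0123456789",   # peer's WireGuard public key
         "endpoint": "203.0.113.7:51820",
         "transfer_rx": 1024, "transfer_tx": 2048,
         "latest_handshake": 1700000000},
    ],
}
```

Each `wg_peers` entry becomes one `topology_edges` row with `edge_type='wireguard'` and the peer details stored as edge metadata.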
### 4. API Layer (`api.py` or `app.py`)
- A Flask Blueprint or App defining:
- `POST /status/updates`: Main ingress. Parses JSON -> Verifies HMAC & Nonce -> Calls Processor -> Returns OK. Unwraps `relay_path` envelopes iteratively if needed.
- `POST /status/register`: Allows new agents to announce their generated ID.
- `GET /status/healthcheck`: Returns `{status: ok}`.
- `GET /status/alarms`: JSON list of active alarms.
- `GET /status/agents`: JSON dump of the fleet matrix.
- `POST /status/admin/reset`: Clears specific agent topology state.
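
A relayed report arriving at `POST /status/updates` is distinguished by an outer wrapper; a sketch of the envelope shape (field names from the code, values invented):

```python
# Hypothetical relay envelope: the outer wrapper carries the hop list,
# the inner payload is a normal signed report. All values are illustrative.
inner_report = {
    "agent_id": "aa11", "tick": 3, "timestamp": 1700000000,
    "nonce": "n-1", "hmac": "0123abcd",  # made-up placeholder signature
    "data": {"hostname": "edge-1"},
}
relay_envelope = {
    "relay_path": ["aa11", "bb22"],  # agents that forwarded the report
    "payload": inner_report,
}
# The Manager validates the inner payload, processes it as a normal report,
# and records a relay edge between each consecutive pair in relay_path.
```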
## User Review Required
> [!IMPORTANT]
> - Since Python's standard library doesn't natively support fetching DNS TXT records, I plan to add `dnspython` to `requirements.txt`. Is this acceptable?
> - The agent successfully generates its own secure hexadecimal `agent_id` locally. Instead of the Manager strictly mandating `/status/register` before everything else, is it acceptable for the Manager to dynamically "auto-register" (upsert) unknown `agent_id`s directly when they push a valid `/status/updates` report? (It simplifies bootstrapping considerably).
> - When generating alarms, should we just log simple messages like "Interface X disappeared" and keep the alarm `active` until a human clears it, or should the alarms auto-dismiss when the issue resolves (e.g., interface comes back)?
## Verification Plan
### Automated testing
- Run basic `pytest` (if available) or dummy scripts pushing forged payloads and ensuring the security layer rejects invalid HMACs and duplicate Nonces.
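
A forged-payload check could be sketched as a standalone pytest-style test; `verify` below is a minimal stand-in for the planned HMAC verification, not the real `security` module:

```python
import hashlib
import hmac
import json

def verify(data: dict, provided_hmac: str, psk: str) -> bool:
    """Minimal stand-in for the planned security.verify_hmac."""
    canonical = json.dumps(data, separators=(",", ":"), sort_keys=True).encode()
    expected = hmac.new(psk.encode(), canonical, hashlib.sha256).hexdigest()
    return hmac.compare_digest(expected, provided_hmac)

psk = "test-psk"
data = {"hostname": "node-a"}
good = hmac.new(psk.encode(),
                json.dumps(data, separators=(",", ":"), sort_keys=True).encode(),
                hashlib.sha256).hexdigest()

assert verify(data, good, psk)                        # valid signature accepted
assert not verify(data, good[::-1], psk)              # tampered signature rejected
assert not verify({"hostname": "evil"}, good, psk)    # tampered payload rejected
```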
### Manual Verification
- Start the Flask app.
- Hit `/status/healthcheck` with curl.
- Send a mock successful JSON representation of the `wg_peers` and `interfaces` using exactly the PSK from the test `.env`. Check that `kattila_manager.db` correctly generated the relational graph.
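
The mock push could be scripted roughly like this; the `PSK` environment variable name and all field values are illustrative, while the envelope fields follow the validation pipeline (timestamp, nonce, hmac, data, fleet_id):

```python
import hashlib
import hmac
import json
import os
import time
import uuid

# Assumed: the test PSK is exported from the test .env as PSK.
psk = os.environ.get("PSK", "test-psk")

data = {
    "hostname": "testbox",
    "interfaces": [{"name": "eth0", "addresses": ["192.168.1.5/24"]}],
    "wg_peers": [],
}
report = {
    "agent_id": "deadbeef01",          # made-up agent ID
    "agent_version": 1,
    "fleet_id": hashlib.sha256(psk.encode()).hexdigest(),
    "tick": 1,
    "timestamp": int(time.time()),
    "nonce": uuid.uuid4().hex,         # fresh nonce per push
    "data": data,
    "hmac": hmac.new(
        psk.encode(),
        json.dumps(data, separators=(",", ":"), sort_keys=True).encode(),
        hashlib.sha256).hexdigest(),
}
# Pipe to: curl -X POST -H 'Content-Type: application/json' \
#   -d @- http://localhost:5086/status/updates
print(json.dumps(report))
```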


@@ -1,10 +1,385 @@
"""
Kattila Manager — Flask API and Web UI.
Receives agent reports, serves fleet status, and renders the network topology map.
"""
import hashlib
import ipaddress
import json
import logging
import os
import time
from flask import Flask, jsonify, request, render_template_string
import db
import security
import processor
# ── Logging ──────────────────────────────────────────────────────────────────
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
logger = logging.getLogger(__name__)
# ── App init ─────────────────────────────────────────────────────────────────
app = Flask(__name__)
def load_env():
"""Load .env file from parent directory."""
paths = [".env", "../.env"]
for p in paths:
if os.path.isfile(p):
with open(p) as f:
for line in f:
line = line.strip()
if not line or line.startswith("#"):
continue
if "=" in line:
key, val = line.split("=", 1)
os.environ.setdefault(key, val)
break
load_env()
db.init_db()
security.start_key_poller()
# ── IP anonymization (from POC) ─────────────────────────────────────────────
def is_public_ip(ip_str: str) -> bool:
"""Returns True if the IP is a global, public address."""
try:
ip = ipaddress.ip_address(ip_str)
return ip.is_global and not ip.is_private and not ip.is_loopback
except ValueError:
return False
def anonymize_ip(ip: str) -> str:
"""
Private RFC1918 addresses: reveal subnet, hide host octet.
Public addresses: replace with a short stable SHA hash fingerprint.
"""
# Strip CIDR notation if present
ip_bare = ip.split("/")[0]
parts = ip_bare.split(".")
if len(parts) != 4:
return "???"
try:
first, second = int(parts[0]), int(parts[1])
except ValueError:
return "???"
if first == 10:
return f"10.{second}.*.*"
if first == 172 and 16 <= second <= 31:
return f"172.{second}.*.*"
if first == 192 and second == 168:
return "192.168.*.*"
token = hashlib.sha256(ip_bare.encode()).hexdigest()[:6]
return f"[pub:{token}]"
# ── API Endpoints ────────────────────────────────────────────────────────────
@app.route("/status/healthcheck")
def healthcheck():
return jsonify({"status": "ok"})
@app.route("/status/updates", methods=["POST"])
def receive_update():
"""Main ingress point for agent reports (direct or relayed)."""
report = request.get_json(silent=True)
if not report:
return jsonify({"error": "invalid_json"}), 400
# Check if this is a relay envelope
if "relay_path" in report and "payload" in report:
inner = report["payload"]
valid, err = security.validate_report(inner)
if not valid:
logger.warning("Security check failed for relayed report: %s "
"(agent: %s, from: %s)",
err, inner.get("agent_id"), request.remote_addr)
return jsonify({"error": err}), 403
processor.process_relay(report)
return jsonify({"status": "ok", "relayed": True})
# Direct report
valid, err = security.validate_report(report)
if not valid:
logger.warning("Security check failed: %s (agent: %s, from: %s)",
err, report.get("agent_id"), request.remote_addr)
status_code = 403
if err == "hmac_invalid":
status_code = 401
return jsonify({"error": err}), status_code
processor.process_report(report)
return jsonify({"status": "ok"})
@app.route("/status/register", methods=["POST"])
def register_agent():
"""
Agent registration endpoint.
Agents auto-register via /status/updates, but this allows explicit
first-contact registration as well.
"""
data = request.get_json(silent=True)
if not data:
return jsonify({"error": "invalid_json"}), 400
agent_id = data.get("agent_id", "")
hostname = data.get("hostname", "unknown")
fleet_id = data.get("fleet_id", "")
if not agent_id:
return jsonify({"error": "missing_agent_id"}), 400
db.upsert_agent(agent_id, hostname, data.get("agent_version", 1),
fleet_id, 0)
logger.info("Registered agent %s (%s)", agent_id, hostname)
return jsonify({"status": "registered", "agent_id": agent_id})
@app.route("/status/alarms")
def get_alarms():
"""Return all active alarms."""
alarms = db.get_active_alarms()
return jsonify(alarms)
@app.route("/status/alarms/<int:alarm_id>/dismiss", methods=["POST"])
def dismiss_alarm(alarm_id):
"""Dismiss a specific alarm."""
db.dismiss_alarm(alarm_id)
return jsonify({"status": "dismissed"})
@app.route("/status/agents")
def list_agents():
"""List all known agents and their status."""
db.mark_stale_agents()
agents = db.get_all_agents()
# Enrich with interface data
for agent in agents:
ifaces = db.get_agent_interfaces(agent["agent_id"])
agent["interfaces"] = ifaces
return jsonify(agents)
@app.route("/status/admin/reset", methods=["POST"])
def admin_reset():
"""Reset a specific agent or the entire fleet."""
data = request.get_json(silent=True) or {}
agent_id = data.get("agent_id")
if agent_id:
db.reset_agent(agent_id)
return jsonify({"status": "reset", "agent_id": agent_id})
else:
db.reset_all()
return jsonify({"status": "full_reset"})
# ── Visualization data endpoint ─────────────────────────────────────────────
@app.route("/status/data")
def graph_data():
"""Return nodes and edges for the vis-network graph."""
db.mark_stale_agents()
agents = db.get_all_agents()
edges = db.get_all_edges()
now = time.time()
nodes = []
for a in agents:
ifaces = db.get_agent_interfaces(a["agent_id"])
# Determine level: 0 = has public IP (hub), 1 = private only (spoke)
level = 1
anon_ips = []
for iface in ifaces:
try:
addrs = json.loads(iface.get("addresses_json", "[]"))
except (json.JSONDecodeError, TypeError):
addrs = []
for addr in addrs:
ip_bare = addr.split("/")[0]
if is_public_ip(ip_bare):
level = 0
anon_ips.append(anonymize_ip(addr))
age = int(now - a["last_seen_at"])
is_alive = age < 300
nodes.append({
"id": a["agent_id"],
"label": f"<b>{a['hostname'].upper()}</b>\n"
f"{', '.join(anon_ips[:3])}",
"level": level,
"color": "#2ecc71" if is_alive else "#e74c3c",
"title": f"Agent: {a['agent_id']}\n"
f"Last seen: {age}s ago\n"
f"Status: {a['status']}",
})
vis_edges = []
seen_pairs: set[tuple] = set()
for e in edges:
pair = tuple(sorted([e["from_agent_id"], e["to_agent_id"]]))
if pair in seen_pairs:
continue
seen_pairs.add(pair)
meta = json.loads(e.get("metadata", "{}"))
handshake = meta.get("latest_handshake", 0)
if handshake == 0:
color = "#95a5a6"
elif (now - handshake) <= 120:
color = "#2ecc71"
elif (now - handshake) <= 86400:
color = "#f1c40f"
else:
color = "#e74c3c"
vis_edges.append({
"from": e["from_agent_id"],
"to": e["to_agent_id"],
"label": e["edge_type"],
"color": color,
"width": 3 if color == "#e74c3c" else 2,
"title": f"{e['edge_type']} link\n"
f"Handshake: {int(now - handshake)}s ago"
if handshake else f"{e['edge_type']} link\nNo handshake",
})
return jsonify({"nodes": nodes, "edges": vis_edges})
# ── Web UI (vis-network topology map) ────────────────────────────────────────
@app.route("/status")
def index():
return render_template_string(HTML_TEMPLATE)
HTML_TEMPLATE = """<!DOCTYPE html>
<html>
<head>
<title>Kattila — Network Map</title>
<script src="https://unpkg.com/vis-network/standalone/umd/vis-network.min.js"></script>
<style>
*, *::before, *::after { box-sizing: border-box; }
body {
font-family: 'Segoe UI', system-ui, sans-serif;
background: #0d1117; color: #e6edf3; margin: 0;
}
#map { width: 100vw; height: 100vh; }
#hud {
position: absolute; top: 12px; left: 12px;
background: rgba(13,17,23,0.88); backdrop-filter: blur(6px);
border: 1px solid #30363d; border-radius: 8px;
padding: 12px 16px; font-size: 12px; z-index: 10; min-width: 160px;
}
#hud b {
display: block; margin-bottom: 6px;
font-size: 13px; color: #58a6ff;
}
.dot {
display: inline-block; width: 9px; height: 9px;
border-radius: 50%; margin-right: 6px;
}
.row { margin: 3px 0; }
#clock {
margin-top: 10px; font-size: 10px; color: #8b949e;
border-top: 1px solid #30363d; padding-top: 6px;
}
#stats {
margin-top: 6px; font-size: 10px; color: #8b949e;
}
</style>
</head>
<body>
<div id="hud">
<b>Kattila Network Map</b>
<div class="row"><span class="dot" style="background:#2ecc71"></span>Active (&lt; 2 min)</div>
<div class="row"><span class="dot" style="background:#f1c40f"></span>Stale (&gt; 2 min)</div>
<div class="row"><span class="dot" style="background:#e74c3c"></span>Broken (&gt; 24 h)</div>
<div class="row"><span class="dot" style="background:#95a5a6"></span>No handshake</div>
<div id="stats"></div>
<div id="clock"></div>
</div>
<div id="map"></div>
<script>
const nodes = new vis.DataSet();
const edges = new vis.DataSet();
const net = new vis.Network(
document.getElementById('map'),
{ nodes, edges },
{
layout: {
hierarchical: {
direction: 'UD',
nodeSpacing: 300,
levelSeparation: 200
}
},
physics: false,
nodes: {
shape: 'box',
font: { multi: 'html', color: '#e6edf3', size: 13 },
margin: 10,
color: { background: '#161b22', border: '#30363d' },
},
edges: {
arrows: 'to',
font: { size: 10, color: '#8b949e', strokeWidth: 0 },
smooth: { type: 'curvedCW', roundness: 0.15 },
},
}
);
async function refresh() {
try {
const res = await fetch('/status/data');
const data = await res.json();
nodes.update(data.nodes);
edges.update(data.edges);
document.getElementById('stats').textContent =
data.nodes.length + ' nodes, ' + data.edges.length + ' links';
document.getElementById('clock').textContent =
'Updated ' + new Date().toLocaleTimeString();
} catch (e) {
document.getElementById('clock').textContent = 'Fetch error';
}
}
setInterval(refresh, 10000);
refresh();
</script>
</body>
</html>"""
# ── Main ─────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
app.run(host="0.0.0.0", port=5086)

manager/db.py Normal file

@@ -0,0 +1,308 @@
"""
Database layer for Kattila Manager.
SQLite with WAL mode. All schemas from DESIGN.md are created on init.
"""
import sqlite3
import json
import time
import threading
import logging
logger = logging.getLogger(__name__)
DB_PATH = "kattila_manager.db"
_local = threading.local()
def get_conn() -> sqlite3.Connection:
"""Return a thread-local SQLite connection."""
if not hasattr(_local, "conn") or _local.conn is None:
conn = sqlite3.connect(DB_PATH)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL;")
conn.execute("PRAGMA foreign_keys=ON;")
_local.conn = conn
return _local.conn
def init_db():
"""Create all tables and indexes if they don't exist."""
conn = get_conn()
conn.executescript("""
CREATE TABLE IF NOT EXISTS agents (
agent_id TEXT PRIMARY KEY,
hostname TEXT NOT NULL,
agent_version INTEGER NOT NULL,
fleet_id TEXT NOT NULL,
registered_at INTEGER NOT NULL,
last_seen_at INTEGER NOT NULL,
last_tick INTEGER NOT NULL DEFAULT 0,
status TEXT NOT NULL DEFAULT 'online'
);
CREATE INDEX IF NOT EXISTS idx_agents_last_seen ON agents(last_seen_at);
CREATE TABLE IF NOT EXISTS reports (
id INTEGER PRIMARY KEY AUTOINCREMENT,
agent_id TEXT NOT NULL,
tick INTEGER NOT NULL,
timestamp INTEGER NOT NULL,
report_type TEXT NOT NULL,
report_json TEXT NOT NULL,
received_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')),
FOREIGN KEY (agent_id) REFERENCES agents(agent_id) ON DELETE CASCADE
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_reports_agent_tick ON reports(agent_id, tick);
CREATE TABLE IF NOT EXISTS topology_edges (
id INTEGER PRIMARY KEY AUTOINCREMENT,
from_agent_id TEXT NOT NULL,
to_agent_id TEXT NOT NULL,
edge_type TEXT NOT NULL,
metadata TEXT DEFAULT '{}',
last_seen INTEGER NOT NULL,
is_active INTEGER NOT NULL DEFAULT 1
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_edges_pair ON topology_edges(from_agent_id, to_agent_id, edge_type);
CREATE TABLE IF NOT EXISTS agent_interfaces (
id INTEGER PRIMARY KEY AUTOINCREMENT,
agent_id TEXT NOT NULL,
interface_name TEXT NOT NULL,
mac_address TEXT,
addresses_json TEXT,
is_virtual INTEGER NOT NULL DEFAULT 0,
vpn_type TEXT,
last_seen_at INTEGER NOT NULL,
FOREIGN KEY (agent_id) REFERENCES agents(agent_id) ON DELETE CASCADE
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_agent_interfaces ON agent_interfaces(agent_id, interface_name);
CREATE TABLE IF NOT EXISTS alarms (
id INTEGER PRIMARY KEY AUTOINCREMENT,
agent_id TEXT NOT NULL,
alarm_type TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'active',
details_json TEXT DEFAULT '{}',
created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')),
dismissed_at INTEGER,
FOREIGN KEY (agent_id) REFERENCES agents(agent_id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_alarms_agent_status ON alarms(agent_id, status);
""")
conn.commit()
logger.info("Database initialized at %s", DB_PATH)
# ── Agent operations ─────────────────────────────────────────────────────────
def upsert_agent(agent_id: str, hostname: str, agent_version: int,
fleet_id: str, tick: int):
"""Insert or update an agent record (auto-register on first valid report)."""
now = int(time.time())
conn = get_conn()
conn.execute("""
INSERT INTO agents (agent_id, hostname, agent_version, fleet_id,
registered_at, last_seen_at, last_tick, status)
VALUES (?, ?, ?, ?, ?, ?, ?, 'online')
ON CONFLICT(agent_id) DO UPDATE SET
hostname = excluded.hostname,
agent_version = excluded.agent_version,
fleet_id = excluded.fleet_id,
last_seen_at = excluded.last_seen_at,
last_tick = excluded.last_tick,
status = 'online'
""", (agent_id, hostname, agent_version, fleet_id, now, now, tick))
conn.commit()
def get_all_agents() -> list[dict]:
conn = get_conn()
rows = conn.execute("SELECT * FROM agents ORDER BY last_seen_at DESC").fetchall()
return [dict(r) for r in rows]
def get_agent(agent_id: str) -> dict | None:
conn = get_conn()
row = conn.execute("SELECT * FROM agents WHERE agent_id = ?",
(agent_id,)).fetchone()
return dict(row) if row else None
def delete_agent(agent_id: str):
conn = get_conn()
conn.execute("DELETE FROM agents WHERE agent_id = ?", (agent_id,))
conn.commit()
def mark_stale_agents(timeout_seconds: int = 120):
"""Mark agents as offline if not seen within timeout."""
cutoff = int(time.time()) - timeout_seconds
conn = get_conn()
conn.execute("""
UPDATE agents SET status = 'offline'
WHERE last_seen_at < ? AND status != 'offline'
""", (cutoff,))
conn.commit()
# ── Report operations ────────────────────────────────────────────────────────
def insert_report(agent_id: str, tick: int, timestamp: int,
report_type: str, report_json: str):
conn = get_conn()
try:
conn.execute("""
INSERT INTO reports (agent_id, tick, timestamp, report_type, report_json)
VALUES (?, ?, ?, ?, ?)
""", (agent_id, tick, timestamp, report_type, report_json))
conn.commit()
except sqlite3.IntegrityError:
logger.warning("Duplicate report from %s tick %d — skipping", agent_id, tick)
# ── Interface operations ─────────────────────────────────────────────────────
def update_interfaces(agent_id: str, interfaces: list[dict]) -> list[str]:
"""
Update agent_interfaces table. Returns list of change descriptions
(for alarm generation).
"""
now = int(time.time())
conn = get_conn()
changes = []
# Get current known interfaces for this agent
existing = conn.execute(
"SELECT interface_name FROM agent_interfaces WHERE agent_id = ?",
(agent_id,)
).fetchall()
existing_names = {r["interface_name"] for r in existing}
reported_names = set()
for iface in interfaces:
name = iface.get("name", "")
if not name:
continue
reported_names.add(name)
addresses = json.dumps(iface.get("addresses", []))
is_virtual = 1 if iface.get("is_virtual", False) else 0
vpn_type = iface.get("vpn_type")
conn.execute("""
INSERT INTO agent_interfaces
(agent_id, interface_name, mac_address, addresses_json,
is_virtual, vpn_type, last_seen_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(agent_id, interface_name) DO UPDATE SET
mac_address = excluded.mac_address,
addresses_json = excluded.addresses_json,
is_virtual = excluded.is_virtual,
vpn_type = excluded.vpn_type,
last_seen_at = excluded.last_seen_at
""", (agent_id, name, iface.get("mac", ""), addresses,
is_virtual, vpn_type, now))
# Detect disappeared interfaces
disappeared = existing_names - reported_names
for ifname in disappeared:
changes.append(f"Interface '{ifname}' disappeared")
# Detect new interfaces
appeared = reported_names - existing_names
for ifname in appeared:
if existing_names: # only alarm if agent had prior interfaces
changes.append(f"New interface '{ifname}' appeared")
conn.commit()
return changes
def get_agent_interfaces(agent_id: str) -> list[dict]:
conn = get_conn()
rows = conn.execute(
"SELECT * FROM agent_interfaces WHERE agent_id = ?", (agent_id,)
).fetchall()
return [dict(r) for r in rows]
# ── Topology edge operations ────────────────────────────────────────────────
def upsert_edge(from_agent_id: str, to_agent_id: str, edge_type: str,
metadata: dict | None = None):
now = int(time.time())
meta_json = json.dumps(metadata or {})
conn = get_conn()
conn.execute("""
INSERT INTO topology_edges
(from_agent_id, to_agent_id, edge_type, metadata, last_seen, is_active)
VALUES (?, ?, ?, ?, ?, 1)
ON CONFLICT(from_agent_id, to_agent_id, edge_type) DO UPDATE SET
metadata = excluded.metadata,
last_seen = excluded.last_seen,
is_active = 1
""", (from_agent_id, to_agent_id, edge_type, meta_json, now))
conn.commit()
def get_all_edges() -> list[dict]:
conn = get_conn()
rows = conn.execute("SELECT * FROM topology_edges WHERE is_active = 1").fetchall()
return [dict(r) for r in rows]
# ── Alarm operations ────────────────────────────────────────────────────────
def create_alarm(agent_id: str, alarm_type: str, details: dict | None = None):
conn = get_conn()
conn.execute("""
INSERT INTO alarms (agent_id, alarm_type, details_json)
VALUES (?, ?, ?)
""", (agent_id, alarm_type, json.dumps(details or {})))
conn.commit()
logger.info("ALARM [%s] %s: %s", agent_id, alarm_type, details)
def get_active_alarms() -> list[dict]:
conn = get_conn()
rows = conn.execute(
"SELECT * FROM alarms WHERE status = 'active' ORDER BY created_at DESC"
).fetchall()
return [dict(r) for r in rows]
def dismiss_alarm(alarm_id: int):
now = int(time.time())
conn = get_conn()
conn.execute("""
UPDATE alarms SET status = 'dismissed', dismissed_at = ?
WHERE id = ?
""", (now, alarm_id))
conn.commit()
def reset_agent(agent_id: str):
"""Wipe all state for a specific agent."""
conn = get_conn()
conn.execute("DELETE FROM agents WHERE agent_id = ?", (agent_id,))
conn.commit()
logger.info("Reset all state for agent %s", agent_id)
def reset_all():
"""Wipe the entire fleet state."""
conn = get_conn()
conn.execute("DELETE FROM topology_edges")
conn.execute("DELETE FROM agent_interfaces")
conn.execute("DELETE FROM reports")
conn.execute("DELETE FROM alarms")
conn.execute("DELETE FROM agents")
conn.commit()
logger.info("Full fleet reset executed")

manager/processor.py Normal file

@@ -0,0 +1,105 @@
"""
Data processor for Kattila Manager.
Core business logic: processes validated agent reports and updates
the database state (agents, interfaces, topology, alarms).
"""
import json
import logging
import db
logger = logging.getLogger(__name__)
def process_report(report: dict):
"""
Process a validated agent report.
Called after security checks have passed.
"""
agent_id = report["agent_id"]
tick = report["tick"]
timestamp = report["timestamp"]
agent_version = report.get("agent_version", 0)
fleet_id = report.get("fleet_id", "")
report_type = report.get("type", "report")
data = report.get("data", {})
hostname = data.get("hostname", "unknown")
# 1. Upsert agent (auto-register)
db.upsert_agent(agent_id, hostname, agent_version, fleet_id, tick)
logger.info("Processed report from %s (%s) tick=%d", agent_id, hostname, tick)
# 2. Store raw report for auditing
db.insert_report(agent_id, tick, timestamp, report_type,
json.dumps(report))
# 3. Update interfaces and detect changes
interfaces = data.get("interfaces", [])
if interfaces:
changes = db.update_interfaces(agent_id, interfaces)
for change in changes:
db.create_alarm(agent_id, "interface_change",
{"description": change})
# 4. Update topology edges from wg_peers
wg_peers = data.get("wg_peers", [])
_update_topology(agent_id, wg_peers)
def _update_topology(agent_id: str, wg_peers: list[dict]):
"""
Infer topology edges from wireguard peer data.
Cross-references peer public keys against known agent interfaces
to create edges.
"""
# Build a lookup of pubkey -> agent_id from all known agent interfaces
# This is done per-report for simplicity; could be cached for performance
all_agents = db.get_all_agents()
# For each wg peer, try to match the public key to another known agent
for peer in wg_peers:
pubkey = peer.get("public_key", "")
endpoint = peer.get("endpoint", "")
iface = peer.get("interface", "")
# Store the edge with metadata even if we can't resolve the target
# agent yet — the metadata (pubkey, endpoint) is still valuable
metadata = {
"public_key": pubkey,
"endpoint": endpoint,
"interface": iface,
"transfer_rx": peer.get("transfer_rx", 0),
"transfer_tx": peer.get("transfer_tx", 0),
"latest_handshake": peer.get("latest_handshake", 0),
}
# The target is unknown until we can cross-reference pubkeys
# For now, use the pubkey hash as a placeholder target ID
target_id = f"pubkey:{pubkey[:16]}" if pubkey else "unknown"
db.upsert_edge(agent_id, target_id, "wireguard", metadata)
def process_relay(envelope: dict):
"""
Process a relayed report envelope.
Extracts the inner payload and processes it as a normal report
after recording the relay path.
"""
relay_path = envelope.get("relay_path", [])
payload = envelope.get("payload", {})
if not payload:
logger.warning("Empty relay payload received")
return
logger.info("Processing relayed report via path: %s", " -> ".join(relay_path))
# Process the inner report normally
process_report(payload)
# Also record relay edges in the topology
for i in range(len(relay_path) - 1):
db.upsert_edge(relay_path[i], relay_path[i + 1], "relay")


@@ -1 +1,2 @@
Flask>=3.0.0
dnspython>=2.4.0

manager/security.py Normal file

@@ -0,0 +1,186 @@
"""
Security layer for Kattila Manager.
Handles PSK fetching via DNS TXT, HMAC verification, nonce caching,
and timestamp skew checking.
"""
import hashlib
import hmac
import json
import logging
import os
import threading
import time
from collections import OrderedDict
logger = logging.getLogger(__name__)
# PSK history: current + 2 previous
_psk_lock = threading.Lock()
_psk_history: list[str] = [] # most recent first, max 3 entries
# Nonce sliding window
_nonce_lock = threading.Lock()
_nonce_cache: OrderedDict[str, float] = OrderedDict()
NONCE_CACHE_SIZE = 120
# Maximum clock skew in seconds
MAX_CLOCK_SKEW = 600 # 10 minutes
def _resolve_txt(dns_name: str) -> str | None:
"""Resolve DNS TXT record. Uses dnspython if available, falls back to subprocess."""
try:
import dns.resolver
answers = dns.resolver.resolve(dns_name, "TXT")
for rdata in answers:
txt = rdata.to_text().strip('"')
return txt
except ImportError:
# Fallback: use dig
import subprocess
try:
result = subprocess.run(
["dig", "+short", "TXT", dns_name],
capture_output=True, text=True, timeout=5
)
val = result.stdout.strip().strip('"')
if val:
return val
except (subprocess.TimeoutExpired, FileNotFoundError):
pass
except Exception as e:
logger.error("DNS TXT lookup failed for %s: %s", dns_name, e)
return None
def fetch_psk():
"""Fetch the PSK from the DNS TXT record and update the history."""
dns_name = os.environ.get("DNS", "")
if not dns_name:
logger.warning("security: No DNS configured for PSK lookup")
return
key = _resolve_txt(dns_name)
if not key:
logger.warning("security: Could not resolve PSK from DNS")
return
with _psk_lock:
if not _psk_history or _psk_history[0] != key:
_psk_history.insert(0, key)
# Keep at most 3 entries (current + 2 previous)
while len(_psk_history) > 3:
_psk_history.pop()
logger.info("security: PSK updated (history depth: %d)", len(_psk_history))
def start_key_poller():
"""Fetch PSK immediately and then every hour in a background thread."""
fetch_psk()
def _poll():
while True:
time.sleep(3600)
fetch_psk()
t = threading.Thread(target=_poll, daemon=True)
t.start()
def get_current_psk() -> str:
with _psk_lock:
return _psk_history[0] if _psk_history else ""
def get_psk_history() -> list[str]:
with _psk_lock:
return list(_psk_history)
def fleet_id_for_psk(psk: str) -> str:
"""Generate the fleet_id (SHA256 hash) for a given PSK."""
return hashlib.sha256(psk.encode()).hexdigest()
def verify_hmac(data_payload, provided_hmac: str) -> tuple[bool, str | None]:
"""
Verify the HMAC against all known PSKs (current + 2 previous).
Returns (valid, matched_psk) or (False, None).
"""
psks = get_psk_history()
if not psks:
logger.error("security: No PSKs available for verification")
return False, None
data_bytes = json.dumps(data_payload, separators=(",", ":"),
sort_keys=True).encode()
for psk in psks:
expected = hmac.new(psk.encode(), data_bytes, hashlib.sha256).hexdigest()
if hmac.compare_digest(expected, provided_hmac):
if psk != psks[0]:
logger.warning("security: Agent using old PSK (not current)")
return True, psk
return False, None
def check_nonce(nonce: str) -> bool:
"""
Check if a nonce has been seen before (replay protection).
Returns True if the nonce is fresh (not seen), False if it's a replay.
"""
with _nonce_lock:
if nonce in _nonce_cache:
return False
_nonce_cache[nonce] = time.time()
# Evict oldest entries if over the limit
while len(_nonce_cache) > NONCE_CACHE_SIZE:
_nonce_cache.popitem(last=False)
return True
def check_timestamp(timestamp: int) -> bool:
"""Check if the timestamp is within the acceptable clock skew."""
now = int(time.time())
return abs(now - timestamp) <= MAX_CLOCK_SKEW
def validate_report(report: dict) -> tuple[bool, str]:
"""
Full validation pipeline for an incoming report.
Returns (valid, error_message).
"""
# 1. Timestamp check
ts = report.get("timestamp", 0)
if not check_timestamp(ts):
return False, "timestamp_skew"
# 2. Nonce replay check
nonce = report.get("nonce", "")
if not nonce:
return False, "missing_nonce"
if not check_nonce(nonce):
return False, "replay_detected"
# 3. HMAC verification
provided_hmac = report.get("hmac", "")
data = report.get("data")
if not data or not provided_hmac:
return False, "missing_hmac_or_data"
valid, matched_psk = verify_hmac(data, provided_hmac)
if not valid:
return False, "hmac_invalid"
# 4. Fleet ID consistency check
if matched_psk:
expected_fleet = fleet_id_for_psk(matched_psk)
if report.get("fleet_id") != expected_fleet:
return False, "fleet_id_mismatch"
return True, "ok"