1541 lines
62 KiB
Python
1541 lines
62 KiB
Python
import argparse
|
|
import asyncio
|
|
import base64
|
|
import hashlib
|
|
import hmac
|
|
import json
|
|
import multiprocessing
|
|
import os
|
|
import secrets
|
|
import signal
|
|
import socket
|
|
import sqlite3
|
|
import time
|
|
import uvicorn
|
|
import traceback
|
|
from contextlib import contextmanager
|
|
from typing import Dict, List, Optional, Tuple, Generator, Any
|
|
|
|
from fastapi import Depends, FastAPI, Header, HTTPException, Security, status
|
|
from fastapi.security import APIKeyHeader
|
|
from fastapi.responses import HTMLResponse
|
|
from pydantic import BaseModel, Field
|
|
from twisted.internet import reactor
|
|
from twisted.names import dns, server, common
|
|
from twisted.internet.defer import Deferred
|
|
from twisted.python.failure import Failure
|
|
|
|
# --- Constants ---
|
|
# Seconds a connection waits on a locked database; raised to tolerate
# concurrent writers.
SQLITE_TIMEOUT_SECONDS = 10

# File that stores the secret key used to HMAC API tokens.
HMAC_KEY_FILE = "minidiscovery.key"

# Environment variable read to create the first admin token on an empty DB.
ADMIN_TOKEN_ENV_VAR = "MINIDISCOVERY_ADMIN_TOKEN"

# Permissions granted to a new token when the request does not specify any.
DEFAULT_TOKEN_PERMISSIONS = ["read", "write"]

# Permissions of the bootstrap admin token: the defaults plus "admin".
ADMIN_PERMISSIONS = [*DEFAULT_TOKEN_PERMISSIONS, "admin"]

# TTL (seconds) advertised on DNS answers.
DNS_DEFAULT_TTL = 60

# Suffix a DNS query name must end with to be answered from the registry.
DNS_QUERY_SUFFIX = ".laiska.local"

# Name of the environment variable for the database path (its reader is not
# in the visible part of the file).
DB_PATH_ENV_VAR = "MINIDISCOVERY_DB_PATH"
|
|
|
|
# --- Database Schema ---
|
|
# One row per registered service instance, keyed by the instance ID.
# Tags and metadata are stored as JSON text; last_update is a float timestamp.
SQL_CREATE_SERVICES = """
CREATE TABLE IF NOT EXISTS services (
    id TEXT PRIMARY KEY,
    name TEXT NOT NULL,
    address TEXT NOT NULL,
    port INTEGER NOT NULL,
    health TEXT DEFAULT 'passing',
    tags_json TEXT DEFAULT '[]', -- Store tags as JSON text
    metadata_json TEXT DEFAULT '{}', -- Store metadata as JSON text
    last_update REAL NOT NULL -- Timestamp of last update/registration
);
"""

# One row per API token. Only the HMAC of the token is stored, never the
# plain token; names are unique so a duplicate name fails on insert.
SQL_CREATE_TOKENS = """
CREATE TABLE IF NOT EXISTS tokens (
    token_hash TEXT PRIMARY KEY, -- Store the HMAC hash of the token
    name TEXT NOT NULL UNIQUE, -- Token names should be unique
    created_at REAL NOT NULL,
    permissions_json TEXT NOT NULL -- Store permissions as JSON text
);
"""

# Generic key/value table; the visible code uses it for the 'initialized'
# flag written after the first admin token is created.
SQL_CREATE_META = """
CREATE TABLE IF NOT EXISTS meta (
    key TEXT PRIMARY KEY,
    value TEXT NOT NULL
);
"""  # For storing things like schema version or first-run status
|
|
|
|
|
|
# --- Pydantic Models ---
|
|
class ServiceInstance(BaseModel):
    """A single registered instance of a named service.

    Used as the request body of the register endpoint and as the element
    type of the catalog and health responses.
    """

    id: str = Field(..., description="Unique identifier for this service instance")
    name: str = Field(
        ..., description="Logical name of the service (e.g., 'web', 'db')"
    )
    address: str = Field(
        ..., description="IP address or hostname where the service listens"
    )
    # Exclusive bounds: accepted ports are 1..65535.
    port: int = Field(..., gt=0, lt=65536, description="Port number for the service")
    tags: List[str] = Field(
        default_factory=list, description="Optional list of tags for filtering"
    )
    metadata: Dict[str, str] = Field(
        default_factory=dict, description="Optional key-value metadata"
    )
    # Plain string: the values named in the description are not enforced.
    health: str = Field(
        default="passing", description="Health status ('passing', 'failing', 'unknown')"
    )
    # Defaults to the time the model object is created when not supplied.
    last_updated: float = Field(default_factory=time.time)
|
|
|
|
|
|
class TokenInfo(BaseModel):
    """Metadata about a stored API token.

    Carries neither the plain token nor its hash; this is what
    ``ServiceRegistry.validate_token`` returns for a valid token.
    """

    # Unique, human-readable token name.
    name: str
    # Creation time as a float timestamp (written with time.time()).
    created_at: float
    # Permission names granted to the token, e.g. "read", "write", "admin".
    permissions: List[str]
|
|
|
|
|
|
class ApiTokenCreateRequest(BaseModel):
    """Request body for creating a new API token (POST /v1/acl/token)."""

    name: str = Field(..., description="A descriptive name for the token")
    # Falls back to DEFAULT_TOKEN_PERMISSIONS when omitted. The model itself
    # does not restrict which permission strings may be requested.
    permissions: List[str] = Field(
        default=DEFAULT_TOKEN_PERMISSIONS, description="Permissions for the token"
    )
|
|
|
|
|
|
class ApiTokenCreateResponse(BaseModel):
    """Response body returned when a token has been created.

    This is the only place the plain token is exposed; the registry keeps
    just its HMAC.
    """

    token: str = Field(..., description="The generated API token (show only once!)")
    # Echo of the name and permissions the token was created with.
    name: str
    permissions: List[str]
|
|
|
|
|
|
# --- HMAC Key Management ---
|
|
def load_or_generate_hmac_key(key_file: str = HMAC_KEY_FILE) -> bytes:
    """Return the HMAC secret stored in *key_file*, creating it on first use.

    If the file exists it is read and must hold at least 32 bytes. Otherwise
    a fresh 32-byte key is generated and written to a new file created
    exclusively with mode 0o600.

    Args:
        key_file: Path of the key file.

    Returns:
        The key bytes.

    Raises:
        SystemExit: If the key cannot be read, is too short, or cannot be
            written.
    """
    if os.path.exists(key_file):
        try:
            with open(key_file, "rb") as key_handle:
                stored_key = key_handle.read()
            if len(stored_key) < 32:  # Basic sanity check
                raise ValueError("HMAC key file seems corrupted or too short.")
            print(f"Loaded HMAC key from {key_file}")
            return stored_key
        except Exception as load_error:
            print(
                f"Error loading HMAC key from {key_file}: {load_error}. Check file permissions and content."
            )
            raise SystemExit(1)

    # No key file yet: create one.
    print(f"HMAC key file ({key_file}) not found. Generating a new one.")
    fresh_key = secrets.token_bytes(32)  # 256 bits
    # O_EXCL makes creation fail if another process created the file first.
    exclusive_create = os.O_WRONLY | os.O_CREAT | os.O_EXCL
    try:
        descriptor = os.open(key_file, exclusive_create, 0o600)
        with os.fdopen(descriptor, "wb") as key_handle:
            key_handle.write(fresh_key)
        print(f"Generated and saved new HMAC key to {key_file}. PROTECT THIS FILE!")
        return fresh_key
    except FileExistsError:
        # Lost the creation race: use the key the other process wrote.
        return load_or_generate_hmac_key(key_file)
    except OSError as write_error:
        print(f"Error writing HMAC key file {key_file}: {write_error}")
        print("Please ensure the directory is writable by the process.")
        raise SystemExit(1)
    except Exception as unexpected_error:
        print(f"Unexpected error generating HMAC key: {unexpected_error}")
        raise SystemExit(1)
|
|
|
|
|
|
def generate_token_hash(token: str, hmac_key: bytes) -> str:
    """Return the lowercase hex HMAC-SHA256 of *token* keyed with *hmac_key*.

    The token is encoded as UTF-8 before hashing.
    """
    message = token.encode("utf-8")
    mac = hmac.digest(hmac_key, message, hashlib.sha256)
    return mac.hex()
|
|
|
|
|
|
# --- Service Registry (SQLite Backend) ---
|
|
class ServiceRegistry:
    """SQLite-backed store for service instances and API tokens.

    Every operation opens its own short-lived connection, so one instance
    per process is enough and several processes can share a database file.
    Tokens are never stored in plain text, only as an HMAC keyed with the
    key passed to the constructor.
    """

    # Bounds for retrying initialisation while another process holds the lock.
    _INIT_MAX_ATTEMPTS = 5
    _INIT_RETRY_DELAY_SECONDS = 0.5

    def __init__(self, db_path: str, hmac_key: bytes):
        """Remember the database path and HMAC key, then initialise the schema.

        Raises:
            SystemExit: If the database cannot be initialised (see _init_db).
        """
        self.db_path = db_path
        self._hmac_key = hmac_key
        self._init_db()

    @contextmanager
    def _get_db_conn(self) -> Generator[sqlite3.Connection, None, None]:
        """Yield a connection that commits on success and is always closed.

        On a sqlite3.Error the error is logged, the transaction is rolled
        back and the exception is re-raised.
        """
        conn = None
        try:
            conn = sqlite3.connect(self.db_path, timeout=SQLITE_TIMEOUT_SECONDS)
            conn.row_factory = sqlite3.Row  # Access columns by name
            conn.execute("PRAGMA foreign_keys = ON;")
            # Write-Ahead Logging for better concurrency
            conn.execute("PRAGMA journal_mode = WAL;")
            yield conn
            conn.commit()
        except sqlite3.Error as e:
            print(f"SQLite error: {e} - {self.db_path}")
            if conn:
                conn.rollback()
            raise
        finally:
            if conn:
                conn.close()

    def _init_db(self) -> None:
        """Create the schema and bootstrap the admin token.

        A "database is locked" error is retried up to _INIT_MAX_ATTEMPTS
        times with a short pause between attempts.

        Raises:
            SystemExit: On any other initialisation failure, when the lock
                retries are exhausted, or when the database has no tokens
                and the admin-token environment variable is unset.
        """
        for attempt in range(1, self._INIT_MAX_ATTEMPTS + 1):
            try:
                self._create_schema_and_bootstrap()
                return
            except sqlite3.OperationalError as e:
                is_locked = "database is locked" in str(e)
                if is_locked and attempt < self._INIT_MAX_ATTEMPTS:
                    print(
                        f"Warning: Database {self.db_path} is locked during initialization. Retrying shortly..."
                    )
                    time.sleep(self._INIT_RETRY_DELAY_SECONDS)
                    continue
                print(f"Fatal error initializing database {self.db_path}: {e}")
                raise SystemExit(1)
            except Exception as e:
                print(f"Fatal error during database initialization: {e}")
                raise SystemExit(1)

    def _create_schema_and_bootstrap(self) -> None:
        """Run one initialisation attempt inside a single transaction."""
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute(SQL_CREATE_SERVICES)
            cursor.execute(SQL_CREATE_TOKENS)
            cursor.execute(SQL_CREATE_META)

            cursor.execute("SELECT COUNT(*) FROM tokens")
            token_count = cursor.fetchone()[0]

            cursor.execute("SELECT value FROM meta WHERE key = 'initialized'")
            initialized = cursor.fetchone()

            if initialized:
                return

            if token_count == 0:
                print("No tokens found in the database.")
                admin_token_plain = os.environ.get(ADMIN_TOKEN_ENV_VAR)
                if not admin_token_plain:
                    print(
                        "ERROR: Database is empty and initial admin token not provided."
                    )
                    print(
                        f"Please set the '{ADMIN_TOKEN_ENV_VAR}' environment variable with a secure token."
                    )
                    raise SystemExit(1)  # Critical configuration missing

                print(f"Creating initial admin token from '{ADMIN_TOKEN_ENV_VAR}'...")
                token_hash = generate_token_hash(admin_token_plain, self._hmac_key)
                cursor.execute(
                    "INSERT INTO tokens (token_hash, name, created_at, permissions_json) VALUES (?, ?, ?, ?)",
                    (
                        token_hash,
                        "admin",
                        time.time(),
                        json.dumps(ADMIN_PERMISSIONS),
                    ),
                )
                # Mark DB as initialized so the env var is not required again.
                cursor.execute(
                    "INSERT OR REPLACE INTO meta (key, value) VALUES ('initialized', 'true')"
                )
                print(
                    "Initial admin token created successfully. You can unset the environment variable now."
                )
            else:
                # Tokens exist but the flag is missing (e.g. older version).
                cursor.execute(
                    "INSERT OR REPLACE INTO meta (key, value) VALUES ('initialized', 'true')"
                )

    @staticmethod
    def _row_to_instance(row: sqlite3.Row) -> ServiceInstance:
        """Build a ServiceInstance from one row of the services table."""
        return ServiceInstance(
            id=row["id"],
            name=row["name"],
            address=row["address"],
            port=row["port"],
            health=row["health"],
            tags=json.loads(row["tags_json"]),
            metadata=json.loads(row["metadata_json"]),
            # Report the stored timestamp rather than the model's "now" default.
            last_updated=row["last_update"],
        )

    def _parse_rows(self, rows: List[sqlite3.Row]) -> List[ServiceInstance]:
        """Convert rows to instances, logging and skipping unparseable rows."""
        parsed: List[ServiceInstance] = []
        for row in rows:
            try:
                parsed.append(self._row_to_instance(row))
            except (ValueError, TypeError, KeyError, IndexError) as e:
                # ValueError covers json.JSONDecodeError and pydantic
                # validation errors; sqlite3.Row raises IndexError for a
                # missing column. sqlite3.Row has no .get(), so check keys().
                row_id = row["id"] if "id" in row.keys() else "N/A"
                print(f"Warning: Could not parse service data for ID {row_id}: {e}")
        return parsed

    def register(self, instance: ServiceInstance) -> bool:
        """Insert or replace a service instance, stamping it with the current time.

        Returns:
            True on success, False if SQLite reports an integrity error.
        """
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    """
                    INSERT OR REPLACE INTO services
                    (id, name, address, port, health, tags_json, metadata_json, last_update)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        instance.id,
                        instance.name,
                        instance.address,
                        instance.port,
                        instance.health,
                        json.dumps(instance.tags),
                        json.dumps(instance.metadata),
                        time.time(),
                    ),
                )
                return True
            except sqlite3.IntegrityError as e:
                # Should not happen with INSERT OR REPLACE unless DB issue.
                print(f"Error registering service {instance.id}: {e}")
                return False

    def deregister(self, service_id: str) -> bool:
        """Delete a service instance; return True if a row was removed."""
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM services WHERE id = ?", (service_id,))
            return cursor.rowcount > 0

    def get_service(
        self, name: str, only_passing: bool = False
    ) -> List[ServiceInstance]:
        """Return all instances registered under *name*.

        Args:
            name: Service name to look up (exact match).
            only_passing: When True, only instances whose health is
                'passing' are returned.
        """
        sql = "SELECT * FROM services WHERE name = ?"
        params: List[Any] = [name]
        if only_passing:
            sql += " AND health = 'passing'"

        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute(sql, tuple(params))
            rows = cursor.fetchall()
        return self._parse_rows(rows)

    def get_all_services(self) -> Dict[str, List[ServiceInstance]]:
        """Return every registered instance, grouped by service name."""
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM services ORDER BY name")
            rows = cursor.fetchall()

        services: Dict[str, List[ServiceInstance]] = {}
        for instance in self._parse_rows(rows):
            services.setdefault(instance.name, []).append(instance)
        return services

    def update_health(self, service_id: str, health: str) -> bool:
        """Set the health of one instance and refresh its timestamp.

        Returns:
            True if a matching row was updated.
        """
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute(
                "UPDATE services SET health = ?, last_update = ? WHERE id = ?",
                (health, time.time(), service_id),
            )
            return cursor.rowcount > 0

    def create_token(
        self, name: str, permissions: List[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        """Create a new API token and store its hash.

        Returns:
            (token, None) on success, (None, error_message) on failure.
        """
        token_plain = secrets.token_urlsafe(32)
        token_hash = generate_token_hash(token_plain, self._hmac_key)

        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    "INSERT INTO tokens (token_hash, name, created_at, permissions_json) VALUES (?, ?, ?, ?)",
                    (token_hash, name, time.time(), json.dumps(permissions)),
                )
                return token_plain, None  # Plain token is returned only here
            except sqlite3.IntegrityError:
                # This likely means the name is not unique.
                return None, f"Token name '{name}' already exists."
            except Exception as e:
                print(f"Error creating token '{name}': {e}")
                return None, "Database error during token creation."

    def revoke_token(self, token_plain: str) -> bool:
        """Delete the token matching *token_plain*; return True if one existed."""
        token_hash = generate_token_hash(token_plain, self._hmac_key)
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM tokens WHERE token_hash = ?", (token_hash,))
            return cursor.rowcount > 0

    def validate_token(self, token_plain: str) -> Optional[TokenInfo]:
        """Look up a plain token by its hash.

        Returns:
            The token's TokenInfo, or None if the token is unknown or its
            stored permissions are malformed.
        """
        token_hash = generate_token_hash(token_plain, self._hmac_key)
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT name, created_at, permissions_json FROM tokens WHERE token_hash = ?",
                (token_hash,),
            )
            row = cursor.fetchone()

        if row is None:
            return None  # Token hash not found

        try:
            permissions = json.loads(row["permissions_json"])
            if not isinstance(permissions, list):
                raise TypeError("Permissions format error")
            return TokenInfo(
                name=row["name"],
                created_at=row["created_at"],
                permissions=permissions,
            )
        except (json.JSONDecodeError, TypeError) as e:
            print(f"Error decoding permissions for token hash {token_hash}: {e}")
            return None  # Treat malformed permissions as invalid
|
|
|
|
|
|
# --- API Key Security ---
|
|
# Reads the token from the X-API-Token header. auto_error=False means a
# missing header arrives as None so the checker below can answer 401 itself.
api_key_header = APIKeyHeader(name="X-API-Token", auto_error=False)


def get_permission_checker(required_permission: str):
    """Build a FastAPI dependency that enforces *required_permission*.

    The returned coroutine answers 401 when the header is absent and 403
    when the token is unknown or lacks the permission. On success it returns
    the presented API key.
    """

    async def _check_permission(
        # The lambda defers the name lookup until a request is handled.
        registry: ServiceRegistry = Depends(lambda: get_service_registry()),
        api_key: Optional[str] = Security(api_key_header),
    ) -> str:
        if api_key is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="API token missing in X-API-Token header",
            )

        granted = registry.validate_token(api_key)
        if granted is None:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN, detail="Invalid API token"
            )

        has_permission = required_permission in granted.permissions
        if not has_permission:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Insufficient permissions. Required: '{required_permission}'",
            )

        # Hand the validated key back to the endpoint.
        return api_key

    return _check_permission
|
|
|
|
|
|
# --- FastAPI Application Factory ---
|
|
|
|
# Per-process ServiceRegistry singleton. Each worker process holds its own
# instance; they all talk to the same database file.
_registry_instance: Optional[ServiceRegistry] = None


def setup_registry(db_path: str, hmac_key: bytes):
    """Create this process's ServiceRegistry unless one already exists."""
    global _registry_instance
    pid = os.getpid()
    if _registry_instance is not None:
        print(f"[{pid}] ServiceRegistry already initialized for FastAPI.")
        return
    print(f"[{pid}] Initializing ServiceRegistry for FastAPI...")
    _registry_instance = ServiceRegistry(db_path, hmac_key)
|
|
|
|
|
|
def get_service_registry() -> ServiceRegistry:
    """Return this process's ServiceRegistry (used as a FastAPI dependency).

    Raises:
        RuntimeError: If setup_registry() has not been called in this process.
    """
    registry = _registry_instance
    if registry is not None:
        return registry
    # Reaching here means setup_registry was not called before the app started.
    raise RuntimeError("ServiceRegistry not initialized for this process.")
|
|
|
|
|
|
def generate_status_page_html(registry: ServiceRegistry) -> str:
    """Render the status page as an HTML string.

    Each distinct service address is treated as a "node". A node is Active
    when at least one instance at that address has health 'passing',
    otherwise Inactive. Nodes are laid out on a square grid: the diagonal
    shows each node's own status and every other cell reads 'Unknown',
    because inter-node connectivity is not tracked.
    """
    services_by_name = registry.get_all_services()

    # 1. One pass over all instances: collect addresses and the passing ones.
    all_addresses = set()
    passing_addresses = set()
    for instance_list in services_by_name.values():
        for inst in instance_list:
            all_addresses.add(inst.address)
            if inst.health == "passing":
                passing_addresses.add(inst.address)

    ordered_addresses = sorted(all_addresses)
    # {address: {"status": "Active|Inactive", "label": "Node N"}}
    node_info = {
        address: {
            "status": "Active" if address in passing_addresses else "Inactive",
            "label": f"Node {position}",
        }
        for position, address in enumerate(ordered_addresses, start=1)
    }

    node_count = len(node_info)
    # CSS repeat() needs at least one track, even when there are no nodes.
    track_count = node_count if node_count > 0 else 1

    # 2. Build the page piece by piece and join once at the end.
    page_head = f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <meta http-equiv="refresh" content="15"> <!-- Auto-refresh every 15 seconds -->
        <title>MiniDiscovery Status</title>
        <style>
            /* Dark Theme Color Scheme (same as provided) */
            :root {{
                --black: #1C2526; --white: #C7C7C7; --green: #4A5E4D;
                --yellow: #7A6F4F; --red: #6B3F3F; --blue: #3F506B;
                --gray: #5C6366; --orange: #7A5B3F;
                --gap-size: 5px;
            }}
            body {{ background-color: var(--black); color: var(--white); font-family: Arial, sans-serif; margin: 0; padding: 20px; display: flex; flex-direction: column; align-items: center; min-height: 100vh; box-sizing: border-box; }}
            .header h1 {{ color: var(--green); text-align: center; margin-bottom: 3vh; }}
            .grid-container {{
                display: grid;
                grid-template-columns: auto repeat({track_count}, 1fr); /* Handle empty case */
                grid-template-rows: auto repeat({track_count}, 1fr);
                gap: var(--gap-size);
                margin: 20px 0;
                width: 100%; max-width: 70vw; max-height: 70vh;
                aspect-ratio: 1 / 1;
                border: 2px solid var(--gray); padding: var(--gap-size);
                background-color: var(--black); border-radius: 8px; box-sizing: border-box;
            }}
            .grid-corner, .grid-header, .grid-label, .grid-cell {{
                display: flex; align-items: center; justify-content: center; text-align: center;
                border-radius: 5px; padding: 5px; box-sizing: border-box; overflow: hidden;
                word-break: break-word; font-size: clamp(10px, 2vmin, 16px);
                border: 1px solid var(--black);
            }}
            .grid-corner {{ background-color: transparent; border: none; }}
            .grid-header, .grid-label {{
                background-color: var(--black); color: var(--green); font-weight: bold;
                border: 1px solid var(--gray); min-width: 60px; min-height: 60px;
            }}
            .grid-header {{ min-height: 40px; }} .grid-label {{ min-width: 40px; }}
            .grid-cell {{ background-color: var(--gray); color: var(--white); transition: background-color 0.3s ease; }}
            /* Status Styles */
            .grid-cell.status-active {{ background-color: var(--green); color: var(--white); font-weight: bold; }}
            .grid-cell.status-inactive {{ background-color: var(--red); color: var(--white); font-weight: bold; }}
            /* Add a style for the 'unknown' off-diagonal state */
            .grid-cell.status-unknown {{ background-color: var(--gray); color: #a0a0a0; }} /* Dimmer text */
            .grid-cell:hover {{ opacity: 0.8; cursor: default; }}
            .status-message {{ margin-top: 20px; color: var(--yellow); }}
        </style>
    </head>
    <body>
        <div class="header"><h1>Node Status</h1></div>
    """

    pieces = [page_head]

    if node_count == 0:
        pieces.append("<div class='status-message'>No active nodes detected.</div>")
    else:
        # The inline style exposes the grid size as a CSS variable.
        pieces.append(
            f"<div class='grid-container' style='--grid-size: {node_count};'>"
        )
        # Top-left corner, then one column header per node.
        pieces.append("<div class='grid-corner'></div>")
        pieces.extend(
            f"<div class='grid-header'>{node_info[address]['label']}</div>"
            for address in ordered_addresses
        )

        for row_address in ordered_addresses:
            row_node = node_info[row_address]
            pieces.append(f"<div class='grid-label'>{row_node['label']}</div>")
            for col_address in ordered_addresses:
                if row_address != col_address:
                    # Off-diagonal: inter-node connectivity is not tracked.
                    pieces.append(
                        "<div class='grid-cell status-unknown'>Unknown</div>"
                    )
                    continue
                # Diagonal: the node's own status.
                node_status = row_node["status"]
                css_class = f"status-{node_status.lower()}"
                pieces.append(
                    f"<div class='grid-cell {css_class}'>{node_status}</div>"
                )

        pieces.append("</div>")  # End grid-container

    # Footer and closing tags.
    pieces.append(
        """
        <div class="status-message">Status based on registered service health. Page refreshes automatically.</div>
    </body>
    </html>
    """
    )
    return "".join(pieces)
|
|
|
|
|
|
def create_fastapi_app() -> FastAPI:
    """Create the FastAPI application and register all HTTP endpoints.

    Endpoints:
        GET    /status                          public HTML status page
        POST   /v1/agent/service/register       requires 'write'
        PUT    /v1/agent/service/deregister/..  requires 'write'
        GET    /v1/catalog/services             requires 'read'
        GET    /v1/catalog/service/{name}       requires 'read'
        GET    /v1/health/service/{name}        requires 'read'
        POST   /v1/acl/token                    requires 'admin'
        DELETE /v1/acl/token/{token}            requires 'admin'
        GET    /dns-query                       public, JSON DNS answers

    The registry is obtained per request through get_service_registry, so
    setup_registry() must have been called in the serving process.
    """
    # Standard DNS response codes used in the JSON "Status" field.
    rcode_noerror = 0
    rcode_nxdomain = 3  # 2 would be SERVFAIL
    rcode_notimp = 4

    # Supported query types and their DNS record type numbers.
    dns_type_a = 1
    dns_type_srv = 33
    supported_record_types = {"A": dns_type_a, "SRV": dns_type_srv}

    app = FastAPI(
        title="MiniDiscovery",
        description="Minimal Service Discovery inspired by Consul",
        version="0.2.0",
    )

    # --- Status Page Endpoint (Public) ---
    @app.get("/status", response_class=HTMLResponse)
    async def get_status_page(
        registry: ServiceRegistry = Depends(get_service_registry),
    ):
        """Serves the anonymous HTML status page."""
        html_content = generate_status_page_html(registry)
        return HTMLResponse(content=html_content, status_code=200)

    # --- API Endpoints (Authenticated) ---
    @app.post("/v1/agent/service/register", status_code=status.HTTP_200_OK)
    async def register_service(
        instance: ServiceInstance,
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("write")),  # Require 'write'
    ):
        """Registers a service instance. Supports replace-if-exists based on ID."""
        if not registry.register(instance):
            # Unlikely with INSERT OR REPLACE, but reported by the registry.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to register service due to a database error.",
            )
        return {"status": "registered", "service_id": instance.id}

    @app.put(
        "/v1/agent/service/deregister/{service_id}", status_code=status.HTTP_200_OK
    )
    async def deregister_service(
        service_id: str,
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("write")),  # Require 'write'
    ):
        """Deregisters a service instance by its ID."""
        if not registry.deregister(service_id):
            # False from deregister means no row matched the ID.
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Service ID '{service_id}' not found.",
            )
        return {"status": "deregistered", "service_id": service_id}

    @app.get("/v1/catalog/services", response_model=Dict[str, List[str]])
    async def get_catalog_services(
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("read")),  # Require 'read'
    ):
        """Lists all known service names and their tags."""
        services_data = registry.get_all_services()
        # Shape follows Consul's /v1/catalog/services: name -> distinct tags.
        return {
            name: list({tag for inst in instances for tag in inst.tags})
            for name, instances in services_data.items()
        }

    @app.get("/v1/catalog/service/{service_name}", response_model=List[ServiceInstance])
    async def get_catalog_service(
        service_name: str,
        tag: Optional[str] = None,  # Allow filtering by tag
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("read")),  # Require 'read'
    ):
        """Lists instances for a specific service, optionally filtered by tag."""
        instances = registry.get_service(service_name)
        if not instances:
            # Unknown service names yield an empty list, matching Consul.
            return []
        if tag:
            instances = [inst for inst in instances if tag in inst.tags]
        return instances

    @app.get("/v1/health/service/{service_name}", response_model=List[ServiceInstance])
    async def get_health_service(
        service_name: str,
        tag: Optional[str] = None,
        passing: bool = False,  # Filter for only passing instances
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("read")),  # Require 'read'
    ):
        """Lists instances for a service; only passing ones when passing=true."""
        instances = registry.get_service(service_name, only_passing=passing)
        if not instances:
            return []
        if tag:
            instances = [inst for inst in instances if tag in inst.tags]
        return instances

    # --- Token Management Endpoints ---
    @app.post(
        "/v1/acl/token",
        response_model=ApiTokenCreateResponse,
        status_code=status.HTTP_201_CREATED,
    )
    async def create_acl_token(
        token_request: ApiTokenCreateRequest,
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("admin")),  # Require 'admin'
    ):
        """Creates a new API token."""
        plain_token, error = registry.create_token(
            token_request.name, token_request.permissions
        )
        if error:
            # A duplicate name is a client error; anything else is ours.
            if "already exists" in error:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST, detail=error
                )
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=error
            )

        return ApiTokenCreateResponse(
            token=plain_token,  # Only show the token once upon creation!
            name=token_request.name,
            permissions=token_request.permissions,
        )

    @app.delete("/v1/acl/token/{token_to_revoke}", status_code=status.HTTP_200_OK)
    async def revoke_acl_token(
        token_to_revoke: str,  # The actual token value to revoke
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("admin")),  # Require 'admin'
    ):
        """Revokes an existing API token."""
        # NOTE: nothing prevents an admin from revoking the token used for
        # this very request.
        if not registry.revoke_token(token_to_revoke):
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Token not found or already revoked.",
            )
        # Don't echo back the token.
        return {
            "status": "revoked",
            "token_info": "Token revoked successfully",
        }

    # --- DNS-over-HTTPS style endpoint (simplified JSON answers) ---
    @app.get("/dns-query", status_code=status.HTTP_200_OK)
    async def dns_query(
        name: str,
        type: str = "A",  # Query param 'type'
        registry: ServiceRegistry = Depends(get_service_registry),
        # No auth required for basic DNS queries for simplicity, add if needed
    ):
        """Answers A and SRV queries for passing instances as JSON.

        The response uses the JSON layout popularised by public DoH
        resolvers (Status/Question/Answer); RFC 8484 itself only defines
        the binary wire format. Status is a standard DNS RCODE:
        0 NOERROR, 3 NXDOMAIN (name outside the suffix or no passing
        instance), 4 NOTIMP (unsupported query type).
        """
        if not name.lower().endswith(DNS_QUERY_SUFFIX):
            return {"Status": rcode_nxdomain}

        # The service name is the last label before the suffix.
        service_name = name[: -len(DNS_QUERY_SUFFIX)].split(".")[-1]

        instances = registry.get_service(service_name, only_passing=True)
        if not instances:
            return {"Status": rcode_nxdomain}

        type_code = supported_record_types.get(type.upper())
        if type_code is None:
            # Unsupported type for this simple endpoint.
            return {"Status": rcode_notimp}

        if type_code == dns_type_a:
            record_data = [instance.address for instance in instances]
        else:
            # SRV format: Priority Weight Port Target
            record_data = [
                f"0 5 {instance.port} {instance.address}{DNS_QUERY_SUFFIX}"
                for instance in instances
            ]

        answers = [
            {
                "name": name,
                "type": type_code,
                "TTL": DNS_DEFAULT_TTL,
                "data": data,
            }
            for data in record_data
        ]

        return {
            "Status": rcode_noerror,
            "TC": False,  # Truncated
            "RD": True,  # Recursion Desired - reflects typical client flag
            "RA": False,  # Recursion Available (we are authoritative-like)
            "AD": False,  # Authenticated Data
            "CD": False,  # Checking Disabled
            "Question": [{"name": name, "type": type_code}],  # Reflect question
            "Answer": answers,
            # Authority and Additional sections omitted for simplicity
        }

    return app
|
|
|
|
|
|
# --- Twisted DNS Server ---
|
|
class MiniDiscoveryResolver(common.ResolverBase):
    """Twisted resolver that answers A, SRV and TXT queries from the registry.

    Only names ending in ``DNS_QUERY_SUFFIX`` are served. Names outside the
    suffix, undecodable names and unsupported record types all produce an
    empty ``(answers, authority, additional)`` result.
    """

    # A TXT character-string is prefixed by a single length octet, so one
    # string can carry at most 255 bytes; longer strings cannot be encoded.
    _MAX_TXT_STRING_BYTES = 255

    def __init__(self, registry: ServiceRegistry):
        """Remember the registry used to look up service instances."""
        self.registry = registry
        super().__init__()

    def _lookup(
        self,
        name: bytes,
        cls: int,
        query_type: int,
        timeout: Optional[Tuple[int]] = None,
    ) -> Deferred:
        """
        Main lookup entry point.

        Resolves the query synchronously via ``_resolve`` and returns a
        Deferred that has already fired: with the
        ``(answers, authority, additional)`` tuple on success, or with a
        Failure if anything unexpected was raised. ``timeout`` is accepted
        for interface compatibility and is not used.
        """
        d = Deferred()
        try:
            result = self._resolve(name, cls, query_type)
        except Exception as e:
            # Catch-all for unexpected errors during dispatch or handler execution
            name_str_debug = repr(name)
            print(
                f"!!! Unhandled exception during DNS lookup for {name_str_debug} "
                f"(Type: {query_type}) !!!"
            )
            print(f"Exception Type: {type(e).__name__}")
            print(f"Exception Args: {e.args}")
            print("--- Traceback ---")
            traceback.print_exc()
            print("--- End Traceback ---")
            # Signal failure to the DNS server layer
            d.errback(Failure(e))
        else:
            d.callback(result)
        return d

    def _resolve(
        self, name: bytes, cls: int, query_type: int
    ) -> Tuple[List, List, List]:
        """Decode the name, check the suffix and dispatch on the query type."""
        # 1. Decode the incoming query name safely. An undecodable name is
        #    answered with an empty result rather than left unanswered.
        try:
            name_str = name.decode("utf-8").lower()
        except UnicodeDecodeError:
            name_str_debug = repr(name)
            print(f"DNS: Cannot decode query name {name_str_debug} as UTF-8.")
            return [], [], []

        # 2. Only names under our suffix are served.
        if not name_str.endswith(DNS_QUERY_SUFFIX):
            return [], [], []

        # 3. Extract the part before the suffix; empty means the query was
        #    for the suffix itself.
        base_name = name_str[: -len(DNS_QUERY_SUFFIX)]
        if not base_name:
            return [], [], []

        # 4. Dispatch based on query type (add AAAA here if ever needed).
        handlers = {
            dns.A: self._handle_a_query,
            dns.SRV: self._handle_srv_query,
            dns.TXT: self._handle_txt_query,
        }
        handler = handlers.get(query_type)
        if handler is None:
            # Unsupported query type for our resolver
            return [], [], []

        # 5. Handlers are synchronous and return the three record lists.
        return handler(name, base_name, cls)

    # --- Helper Methods for Specific Record Types ---

    def _parse_srv_query(self, base_name: str) -> Tuple[Optional[str], Optional[str]]:
        """
        Parses SRV-style queries: _tag._proto.service or _tag._proto

        Returns (tag, service_name), where service_name is None when only
        ``_tag._proto`` was given, or (None, None) if the name is not
        SRV-style. Labels after the service name are ignored.
        """
        parts = base_name.split(".")
        if (
            len(parts) >= 2
            and parts[0].startswith("_")
            and parts[1] in ["_tcp", "_udp"]
        ):
            tag = parts[0][1:]  # Remove leading '_'
            service_name = parts[2] if len(parts) > 2 else None
            return tag, service_name
        return None, None  # Not an SRV-style query name

    def _get_instances_for_query(
        self, base_name: str, is_srv_query: bool = False
    ) -> List[ServiceInstance]:
        """Fetches relevant, passing service instances based on the query name.

        For SRV queries the name must be ``_tag._proto[.service]``; anything
        else yields an empty list. For A/TXT queries the last label of
        ``base_name`` is used as the service name.
        """
        instances = []
        tag_filter = None
        service_name_filter = None

        if is_srv_query:
            tag_filter, service_name_filter = self._parse_srv_query(base_name)
            if tag_filter is None:  # Not a valid _tag._proto... query
                return []
        else:  # A or TXT query: service name is the last part
            service_name_filter = base_name.split(".")[-1]

        if service_name_filter:
            # Query targets a specific service (with potential tag filter for SRV)
            service_instances = self.registry.get_service(
                service_name_filter, only_passing=True
            )
            if tag_filter:  # SRV query with tag and service
                instances = [
                    inst for inst in service_instances if tag_filter in inst.tags
                ]
            else:  # A/TXT query, or SRV query with an empty tag
                instances = service_instances
        elif tag_filter:  # SRV query for a tag across all services
            all_services = self.registry.get_all_services()
            for service_instances in all_services.values():
                for inst in service_instances:
                    if inst.health == "passing" and tag_filter in inst.tags:
                        instances.append(inst)
        else:
            # Reached when neither a service name nor a tag could be derived,
            # e.g. an A/TXT name whose last label is empty.
            print(f"Warning: Could not determine filter for query '{base_name}'")

        print(
            f"DNS Lookup: base_name='{base_name}', is_srv={is_srv_query}, "
            f"tag='{tag_filter}', service='{service_name_filter}'. Found {len(instances)} instances."
        )
        return instances

    def _handle_a_query(
        self, name: bytes, base_name: str, cls: int
    ) -> Tuple[List, List, List]:
        """Handles A record lookups: one A record per passing instance."""
        answers = []
        instances = self._get_instances_for_query(base_name, is_srv_query=False)

        for instance in instances:
            try:
                payload = dns.Record_A(address=instance.address, ttl=DNS_DEFAULT_TTL)
                rr = dns.RRHeader(
                    name=name,  # Respond with the original query name
                    type=dns.A,
                    cls=cls,
                    ttl=DNS_DEFAULT_TTL,
                    payload=payload,
                )
                answers.append(rr)
            except Exception as e:
                # One bad instance must not spoil the records of the others.
                print(
                    f"Warning: Error creating A record for instance {instance.id} "
                    f"(IP: {instance.address}): {e}. Skipping."
                )

        return answers, [], []  # No authority or additional records for basic A

    def _handle_srv_query(
        self, name: bytes, base_name: str, cls: int
    ) -> Tuple[List, List, List]:
        """Handles SRV record lookups (service or tag based).

        Each matching instance yields one SRV answer plus a matching A record
        in the additional section. Names that are not ``_tag._proto...`` get
        no answers.
        """
        answers = []
        additional = []
        instances = self._get_instances_for_query(base_name, is_srv_query=True)

        for instance in instances:
            try:
                # The SRV target is "<instance_id><suffix>"; its address is
                # supplied in the additional section below.
                # NOTE(review): a follow-up A query for this target is treated
                # by _get_instances_for_query as a *service name* lookup —
                # confirm clients can rely on the additional-section record.
                target_name_str = f"{instance.id}{DNS_QUERY_SUFFIX}"
                target_name_bytes = target_name_str.encode("utf-8")

                srv_payload = dns.Record_SRV(
                    priority=0,
                    weight=10,
                    port=instance.port,
                    target=target_name_bytes,
                    ttl=DNS_DEFAULT_TTL,
                )
                srv_rr = dns.RRHeader(
                    name=name,  # Respond with the original query name
                    type=dns.SRV,
                    cls=cls,
                    ttl=DNS_DEFAULT_TTL,
                    payload=srv_payload,
                )
                answers.append(srv_rr)

                # Add corresponding A record for the target in the additional section
                a_payload = dns.Record_A(address=instance.address, ttl=DNS_DEFAULT_TTL)
                a_rr = dns.RRHeader(
                    name=target_name_bytes,  # Name matches SRV target
                    type=dns.A,
                    cls=cls,
                    ttl=DNS_DEFAULT_TTL,
                    payload=a_payload,
                )
                additional.append(a_rr)
            except Exception as e:
                print(
                    f"Warning: Error creating SRV/A record for instance {instance.id} "
                    f"(Addr: {instance.address}:{instance.port}): {e}. Skipping."
                )

        return answers, [], additional

    def _build_txt_strings(
        self, instance: ServiceInstance, instance_id_str: str
    ) -> List[bytes]:
        """Encode an instance's tags, metadata and id as TXT strings.

        Tags are emitted as-is, metadata as ``key=value`` and the id as
        ``instance_id=<id>``. Items that cannot be encoded, or that exceed the
        255-byte limit of a single TXT string, are reported and skipped.
        """
        txt_data: List[bytes] = []

        def add(text: str) -> None:
            encoded = text.encode("utf-8")
            if len(encoded) > self._MAX_TXT_STRING_BYTES:
                print(
                    f"WARNING: TXT string for instance {instance_id_str} is "
                    f"{len(encoded)} bytes (max {self._MAX_TXT_STRING_BYTES}), skipping."
                )
                return
            txt_data.append(encoded)

        # --- Process Tags ---
        if isinstance(instance.tags, list):
            for tag in instance.tags:
                try:
                    add(str(tag))
                except Exception as tag_enc_err:
                    print(
                        f"ERROR encoding tag '{repr(tag)}' (type: {type(tag)}) for instance {instance_id_str}: {tag_enc_err}"
                    )
        else:
            print(
                f"WARNING: Instance {instance_id_str} tags are not a list: {type(instance.tags)}"
            )

        # --- Process Metadata ---
        if isinstance(instance.metadata, dict):
            for k, v in instance.metadata.items():
                try:
                    add(f"{str(k)}={str(v)}")
                except Exception as meta_enc_err:
                    print(
                        f"ERROR encoding metadata item '{repr(k)}':'{repr(v)}' (types: {type(k)}/{type(v)}) for instance {instance_id_str}: {meta_enc_err}"
                    )
        else:
            print(
                f"WARNING: Instance {instance_id_str} metadata is not a dict: {type(instance.metadata)}"
            )

        # --- Process Instance ID ---
        try:
            add(f"instance_id={instance_id_str}")
        except Exception as id_enc_err:
            print(
                f"ERROR encoding instance ID for {instance_id_str}: {id_enc_err}"
            )

        return txt_data

    def _handle_txt_query(
        self, name: bytes, base_name: str, cls: int
    ) -> Tuple[List, List, List]:
        """Handles TXT record lookups, returning service metadata.

        One TXT record is produced per passing instance; an instance whose
        record cannot be built is logged and skipped.
        """
        answers = []
        instances = self._get_instances_for_query(base_name, is_srv_query=False)

        for instance in instances:
            instance_id_str = str(instance.id)
            try:
                txt_data = self._build_txt_strings(instance, instance_id_str)
                if not txt_data:
                    print(
                        f"DNS TXT: No valid TXT data generated for instance {instance_id_str}, skipping."
                    )
                    continue

                # Record_TXT takes the strings as separate positional
                # arguments; passing the list itself would store it as one
                # (unencodable) element.
                payload = dns.Record_TXT(*txt_data, ttl=DNS_DEFAULT_TTL)
                rr = dns.RRHeader(
                    name=name,
                    type=dns.TXT,
                    cls=cls,
                    ttl=DNS_DEFAULT_TTL,
                    payload=payload,
                )
                answers.append(rr)
            except Exception as e_dns:
                print(
                    f"ERROR creating TXT DNS objects for instance {instance_id_str}: {e_dns.__class__.__name__}: {e_dns}"
                )
                traceback.print_exc()

        # Log the final result before returning
        print(
            f"DNS TXT: Finished processing query for '{base_name}'. Found {len(instances)} instances, generated {len(answers)} TXT records."
        )
        return answers, [], []
|
|
|
|
|
|
# --- Health Checker ---
|
|
def check_service_health(instance: ServiceInstance, timeout: float = 2.0) -> str:
    """Performs a simple TCP connection check.

    Args:
        instance: The service instance whose ``address``/``port`` is probed.
        timeout: Seconds to wait for the TCP connection to be established.

    Returns:
        ``"passing"`` if the connection succeeds, ``"failing"`` otherwise.
        This function never raises; unexpected errors are logged and
        reported as ``"failing"``.
    """
    try:
        # Establishing the connection is the whole check; the socket is
        # closed again as soon as the ``with`` block exits.
        with socket.create_connection(
            (instance.address, instance.port), timeout=timeout
        ):
            return "passing"
    except OSError:
        # Covers timeouts and refused connections, which are OSError subclasses.
        return "failing"
    except Exception as e:
        print(
            f"Unexpected error checking health for {instance.id} ({instance.address}:{instance.port}): {e}"
        )
        return "failing"  # Treat unexpected errors as failure
|
|
|
|
|
|
# --- Process Runner Functions ---
|
|
def run_fastapi_server(db_path: str, hmac_key: bytes, host: str, port: int):
    """Child-process entry point: build the FastAPI app and serve it with Uvicorn.

    Blocks inside ``uvicorn.run`` and returns only once it returns.
    """
    print(f"[{os.getpid()}] Starting FastAPI process...")

    # The registry is initialised here, inside the process that serves the API.
    setup_registry(db_path, hmac_key)

    # No signal handlers are installed here; SIGINT/SIGTERM handling is left
    # to Uvicorn's defaults.
    uvicorn.run(create_fastapi_app(), host=host, port=port)

    # Reached once uvicorn.run() has returned.
    print(f"[{os.getpid()}] FastAPI process finished.")
|
|
|
|
|
|
def run_dns_server(db_path: str, hmac_key: bytes, port: int):
    """Sets up and runs the Twisted DNS server.

    Creates a registry and resolver for this process, listens on ``port``
    over both UDP and TCP, and blocks in the reactor until it is stopped.

    Raises:
        SystemExit: with status 1 if a listener cannot be started, so the
            parent process sees a non-zero exit code instead of a clean exit.
    """
    print(f"[{os.getpid()}] Starting DNS server process...")
    # The DNS process uses its own registry instance.
    registry = ServiceRegistry(db_path, hmac_key)
    resolver = MiniDiscoveryResolver(registry)
    factory = server.DNSServerFactory(
        clients=[resolver],
        # verbose=2  # Uncomment for very detailed Twisted DNS logging
    )
    protocol = dns.DNSDatagramProtocol(controller=factory)

    # Listen on UDP and TCP
    try:
        reactor.listenUDP(port, protocol)
        reactor.listenTCP(port, factory)
    except Exception as e:
        print(f"[{os.getpid()}] Error starting DNS listeners on port {port}: {e}")
        # A plain return would end the process with exit code 0, which the
        # parent's crash monitor treats as a normal exit.
        raise SystemExit(1) from e
    print(f"[{os.getpid()}] DNS Server listening on port {port} (UDP/TCP)")

    def shutdown_dns():
        # Hook for cleanup work that must happen before the reactor shuts down.
        print(f"[{os.getpid()}] Shutting down DNS server...")

    reactor.addSystemEventTrigger("before", "shutdown", shutdown_dns)

    def signal_handler(signum, frame):
        print(f"[{os.getpid()}] Received signal {signum}, stopping DNS reactor.")
        if reactor.running:
            # callFromThread schedules the stop on the reactor loop and wakes
            # the reactor, so it is safe to use from outside the loop.
            reactor.callFromThread(reactor.stop)

    # NOTE(review): reactor.run() installs Twisted's own signal handlers by
    # default and may replace some of these — confirm which handler wins.
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Run the Twisted reactor (blocking call)
    reactor.run()
    print(f"[{os.getpid()}] DNS server process finished.")
|
|
|
|
|
|
def run_health_checker(db_path: str, hmac_key: bytes, check_interval: int):
    """Runs the health checking loop.

    Every ``check_interval`` seconds all registered instances are probed with
    ``check_service_health`` and any changed status is written back through
    the registry. The loop ends once SIGINT or SIGTERM has been received.
    """
    print(
        f"[{os.getpid()}] Starting Health Checker process (interval: {check_interval}s)..."
    )
    # The health checker uses its own registry instance.
    registry = ServiceRegistry(db_path, hmac_key)

    running = True
    # Upper bound on how long a shutdown request can go unnoticed while idle.
    poll_step = 0.5

    def signal_handler(signum, frame):
        nonlocal running
        print(f"[{os.getpid()}] Received signal {signum}, stopping health checker...")
        running = False

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    while running:
        start_time = time.monotonic()
        print(f"[{os.getpid()}] Health Check Cycle Start")
        updated_count = 0
        checked_count = 0
        try:
            services_map = registry.get_all_services()
            all_instances = [
                inst for sublist in services_map.values() for inst in sublist
            ]
            checked_count = len(all_instances)

            for instance in all_instances:
                if not running:
                    break  # Exit early if shutdown signal received

                current_health = instance.health
                new_health = check_service_health(instance)

                if current_health != new_health:
                    print(
                        f"[{os.getpid()}] Health change for {instance.id} ({instance.name}): {current_health} -> {new_health}"
                    )
                    if registry.update_health(instance.id, new_health):
                        updated_count += 1
                    else:
                        # This might happen if the service was deregistered between get and update
                        print(
                            f"[{os.getpid()}] Warning: Failed to update health for {instance.id} (maybe deregistered?)"
                        )

        except Exception as e:
            # Keep the checker alive if a cycle fails (e.g. a database error).
            print(f"[{os.getpid()}] Error during health check cycle: {e}")

        if not running:
            break

        # Calculate sleep time to maintain interval
        elapsed = time.monotonic() - start_time
        sleep_time = max(0, check_interval - elapsed)
        print(
            f"[{os.getpid()}] Health Check Cycle End. Checked: {checked_count}, Updated: {updated_count}. Took {elapsed:.2f}s. Sleeping for {sleep_time:.2f}s."
        )

        # time.sleep() resumes after a signal handler that does not raise, so
        # one long sleep would delay shutdown by up to a full interval. Sleep
        # in short slices and re-check the flag between them instead.
        wake_at = time.monotonic() + sleep_time
        while running:
            remaining = wake_at - time.monotonic()
            if remaining <= 0:
                break
            time.sleep(min(remaining, poll_step))

    print(f"[{os.getpid()}] Health checker process finished.")
|
|
|
|
|
|
# --- Main Execution ---
|
|
def main():
    """Parse arguments, start the API, DNS and health-check processes, and supervise them.

    The database path is taken from the environment variable named by
    ``DB_PATH_ENV_VAR`` when set, otherwise from ``--db-path``. The function
    blocks until a shutdown signal arrives or a child exits with a non-zero
    code, then stops all children.
    """
    parser = argparse.ArgumentParser(
        description="MiniDiscovery: A minimal service discovery tool.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,  # Show defaults
    )
    parser.add_argument(
        "--db-path",
        default="minidiscovery_data.db",
        # f-string so the help shows the real variable name, not the placeholder.
        help=f"Path to the SQLite database file. Overridden by ${DB_PATH_ENV_VAR} if set.",
    )
    parser.add_argument(
        "--api-host", default="0.0.0.0", help="Host address for the API server."
    )
    parser.add_argument(
        "--api-port", type=int, default=8500, help="Port for the API server."
    )
    parser.add_argument(
        "--dns-port",
        type=int,
        default=10053,
        help="Port for the DNS server (UDP/TCP). Use ports > 1024 for non-root.",
    )
    parser.add_argument(
        "--health-check-interval",
        type=int,
        default=15,
        help="Health check interval in seconds.",
    )
    parser.add_argument(
        "--hmac-key-file",
        default=HMAC_KEY_FILE,
        help="Path to the file storing the HMAC secret key.",
    )

    args = parser.parse_args()

    # The environment variable takes precedence over the command line.
    db_path_from_env = os.environ.get(DB_PATH_ENV_VAR)
    if db_path_from_env:
        args.db_path = db_path_from_env
        print(
            f"Using database path from environment variable ${DB_PATH_ENV_VAR}: {args.db_path}"
        )
    elif not args.db_path:  # e.g. an explicit empty --db-path
        args.db_path = "minidiscovery_data.db"
        print(f"Database path not specified, using default: {args.db_path}")
    else:
        print(f"Using database path from --db-path argument: {args.db_path}")

    # Ensure DB directory exists
    db_dir = os.path.dirname(os.path.abspath(args.db_path))
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)

    # Load or generate HMAC key *before* starting processes
    try:
        hmac_key = load_or_generate_hmac_key(args.hmac_key_file)
    except SystemExit:
        return  # Exit if key generation/loading fails

    # Early, best-effort warning when no tokens exist and no admin token is
    # configured. A missing table is not an error at this point.
    conn_check = sqlite3.connect(args.db_path)
    cursor_check = conn_check.cursor()
    try:
        cursor_check.execute("SELECT COUNT(*) FROM tokens")
        token_count_check = cursor_check.fetchone()[0]
        if token_count_check == 0 and not os.environ.get(ADMIN_TOKEN_ENV_VAR):
            print(
                f"WARNING: Database appears empty. Ensure '{ADMIN_TOKEN_ENV_VAR}' is set for the first run."
            )
    except sqlite3.DatabaseError:
        pass
    finally:
        conn_check.close()

    # --- Process Management ---
    processes: List[multiprocessing.Process] = []
    stop_event = multiprocessing.Event()  # Ends the supervision loop below

    def graceful_shutdown(signum, frame):
        print(f"Main process received signal {signum}. Initiating shutdown...")
        stop_event.set()
        # Give the children a moment to react to their own signal handlers,
        # then send SIGTERM to any that are still running.
        time.sleep(2)
        for p in processes:
            if p.is_alive():
                print(f"Terminating process {p.pid} ({p.name})...")
                p.terminate()

    signal.signal(signal.SIGINT, graceful_shutdown)
    signal.signal(signal.SIGTERM, graceful_shutdown)

    child_specs = [
        (
            "FastAPI Process",
            run_fastapi_server,
            (args.db_path, hmac_key, args.api_host, args.api_port),
        ),
        (
            "DNS Process",
            run_dns_server,
            (args.db_path, hmac_key, args.dns_port),
        ),
        (
            "Health Check Process",
            run_health_checker,
            (args.db_path, hmac_key, args.health_check_interval),
        ),
    ]

    try:
        for proc_name, target, proc_args in child_specs:
            proc = multiprocessing.Process(
                target=target, args=proc_args, name=proc_name
            )
            proc.start()
            # Track only started processes: joining an unstarted one fails.
            processes.append(proc)

        print("-" * 30)
        print(f"MiniDiscovery Started (PID: {os.getpid()})")
        print(f"  API Server: http://{args.api_host}:{args.api_port}")
        print(f"  DNS Server: Port {args.dns_port} (UDP/TCP)")
        print(f"  Database:   {args.db_path}")
        print(f"  HMAC Key:   {args.hmac_key_file}")
        print(f"  Health Int: {args.health_check_interval}s")
        print("-" * 30)
        print("Press Ctrl+C to stop.")

        # Wait for a shutdown signal, or for a child to exit with an error.
        while not stop_event.is_set():
            for p in processes:
                if not p.is_alive() and p.exitcode != 0:
                    print(
                        f"ERROR: Process {p.name} (PID: {p.pid}) exited unexpectedly with code {p.exitcode}."
                    )
                    stop_event.set()  # Trigger shutdown if a child crashes
                    break
            time.sleep(1)

    finally:
        # When shutdown was triggered by a crashed child (not a signal), the
        # surviving children have not been told to stop yet; ask them now so
        # they can exit cleanly instead of being killed after the timeout.
        for p in processes:
            if p.is_alive():
                p.terminate()

        print("Waiting for processes to join...")
        for p in processes:
            p.join(timeout=5)  # Wait for clean exit
            if p.is_alive():
                print(f"Process {p.pid} ({p.name}) did not exit cleanly, killing.")
                p.kill()  # Force kill if still alive after timeout

    print("MiniDiscovery shut down complete.")
|
|
|
|
|
|
# Start the supervisor only when this file is executed as a script.
if __name__ == "__main__":
    main()
|