import argparse
import asyncio
import base64
import hashlib
import hmac
import json
import multiprocessing
import os
import secrets
import signal
import socket
import sqlite3
import time
import uvicorn
import traceback
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Generator, Any
from fastapi import Depends, FastAPI, Header, HTTPException, Security, status
from fastapi.security import APIKeyHeader
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field
from twisted.internet import reactor
from twisted.names import dns, server, common
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
# --- Constants ---
# SQLite busy timeout (seconds) passed to sqlite3.connect(). The API, DNS and
# health-checker processes each open their own connections to the same file,
# so writers wait this long for a lock instead of failing immediately.
SQLITE_TIMEOUT_SECONDS = 10  # Increased timeout for potential concurrent writes
# Default path of the file holding the secret key used to HMAC API tokens.
HMAC_KEY_FILE = "minidiscovery.key"
# Environment variable read on first start (no tokens, DB not yet marked
# initialized) to create the bootstrap admin token.
ADMIN_TOKEN_ENV_VAR = "MINIDISCOVERY_ADMIN_TOKEN"
# Permissions given to a new token when the create request omits them.
DEFAULT_TOKEN_PERMISSIONS = ["read", "write"]
# Permissions of the bootstrap admin token.
ADMIN_PERMISSIONS = ["read", "write", "admin"]
# TTL (seconds) attached to every DNS and DoH answer record.
DNS_DEFAULT_TTL = 60
# Only query names ending in this suffix are answered (DNS server and DoH).
DNS_QUERY_SUFFIX = ".laiska.local"  # Define a suffix for DNS lookups
# Environment variable for the SQLite database path.
# NOTE(review): not referenced in the visible code; presumably read by main().
DB_PATH_ENV_VAR = "MINIDISCOVERY_DB_PATH"
# --- Database Schema ---
# One row per registered service instance. tags/metadata are JSON text that
# ServiceRegistry encodes on write and decodes back into ServiceInstance.
SQL_CREATE_SERVICES = """
CREATE TABLE IF NOT EXISTS services (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
address TEXT NOT NULL,
port INTEGER NOT NULL,
health TEXT DEFAULT 'passing',
tags_json TEXT DEFAULT '[]', -- Store tags as JSON text
metadata_json TEXT DEFAULT '{}', -- Store metadata as JSON text
last_update REAL NOT NULL -- Timestamp of last update/registration
);
"""
# API tokens. Only the HMAC-SHA256 hash of each token is persisted, never the
# plain token itself.
SQL_CREATE_TOKENS = """
CREATE TABLE IF NOT EXISTS tokens (
token_hash TEXT PRIMARY KEY, -- Store the HMAC hash of the token
name TEXT NOT NULL UNIQUE, -- Token names should be unique
created_at REAL NOT NULL,
permissions_json TEXT NOT NULL -- Store permissions as JSON text
);
"""
# Small key/value table. ServiceRegistry stores key 'initialized' here once the
# bootstrap admin token has been handled.
SQL_CREATE_META = """
CREATE TABLE IF NOT EXISTS meta (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
);
"""  # For storing things like schema version or first-run status
# --- Pydantic Models ---
class ServiceInstance(BaseModel):
    """One registered instance of a logical service, as stored in the registry."""

    # Primary key of the services table; registering the same id again
    # replaces the existing row (INSERT OR REPLACE).
    id: str = Field(..., description="Unique identifier for this service instance")
    # Lookup key for the catalog/health endpoints and for DNS names.
    name: str = Field(
        ..., description="Logical name of the service (e.g., 'web', 'db')"
    )
    # Used as A-record data and as the TCP health-check target.
    address: str = Field(
        ..., description="IP address or hostname where the service listens"
    )
    port: int = Field(..., gt=0, lt=65536, description="Port number for the service")
    tags: List[str] = Field(
        default_factory=list, description="Optional list of tags for filtering"
    )
    # Free-form string: nothing here restricts it to the listed values. The
    # health checker writes 'passing'/'failing'; DNS only serves 'passing'.
    health: str = Field(
        default="passing", description="Health status ('passing', 'failing', 'unknown')"
    )
    metadata: Dict[str, str] = Field(
        default_factory=dict, description="Optional key-value metadata"
    )
class TokenInfo(BaseModel):
    """Information about a stored API token (never includes the token or its hash)."""

    # Unique human-readable name (tokens.name is UNIQUE).
    name: str
    # Creation time as produced by time.time() (seconds since the epoch).
    created_at: float
    # Permission strings checked by get_permission_checker, e.g. 'read', 'write', 'admin'.
    permissions: List[str]
class ApiTokenCreateRequest(BaseModel):
    """Request body for creating a new API token (POST /v1/acl/token)."""

    name: str = Field(..., description="A descriptive name for the token")
    # NOTE(review): the default is the shared module-level list; this relies on
    # pydantic copying mutable defaults per instance — confirm for the pinned
    # pydantic version. Permission strings are not validated against a known set.
    permissions: List[str] = Field(
        default=DEFAULT_TOKEN_PERMISSIONS, description="Permissions for the token"
    )
class ApiTokenCreateResponse(BaseModel):
    """Response for token creation; the plain token is returned only here, once."""

    # Plain token; only its HMAC hash is stored, so it cannot be shown again.
    token: str = Field(..., description="The generated API token (show only once!)")
    name: str
    permissions: List[str]
# --- HMAC Key Management ---
def load_or_generate_hmac_key(key_file: str = HMAC_KEY_FILE) -> bytes:
    """Load the HMAC key from *key_file*, or generate and persist a new one.

    The key hashes API tokens before they are stored, so it must stay stable
    across restarts: replacing it invalidates every existing token.

    Args:
        key_file: Path of the key file. A new file is created with mode 0o600.

    Returns:
        The raw key bytes (at least 32 bytes).

    Raises:
        SystemExit: If an existing file cannot be read or is shorter than
            32 bytes, or if a new key cannot be written.
    """
    if os.path.exists(key_file):
        try:
            with open(key_file, "rb") as f:
                key = f.read()
            if len(key) < 32:  # Basic sanity check
                raise ValueError("HMAC key file seems corrupted or too short.")
            print(f"Loaded HMAC key from {key_file}")
            return key
        except Exception as e:
            print(
                f"Error loading HMAC key from {key_file}: {e}. Check file permissions and content."
            )
            raise SystemExit(1)

    print(f"HMAC key file ({key_file}) not found. Generating a new one.")
    key = secrets.token_bytes(32)  # 256 bits
    try:
        # O_EXCL makes creation atomic: exactly one racing process wins.
        fd = os.open(key_file, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
    except FileExistsError:
        # Race: another process created the file between the exists() check
        # and os.open(). Give it a moment to finish writing the key so a
        # still-empty file is not misreported as corrupted.
        for _ in range(20):
            try:
                if os.path.getsize(key_file) >= 32:
                    break
            except OSError:
                break  # File vanished (its writer failed); the call below regenerates it.
            time.sleep(0.05)
        return load_or_generate_hmac_key(key_file)
    except OSError as e:
        print(f"Error writing HMAC key file {key_file}: {e}")
        print("Please ensure the directory is writable by the process.")
        raise SystemExit(1)
    except Exception as e:
        print(f"Unexpected error generating HMAC key: {e}")
        raise SystemExit(1)

    try:
        with os.fdopen(fd, "wb") as f:
            f.write(key)
            f.flush()
            os.fsync(f.fileno())  # the key must survive a crash right after startup
    except Exception as e:
        # Never leave an empty/partial key file behind: every later start
        # would fail the length check above instead of generating a new key.
        try:
            os.remove(key_file)
        except OSError:
            pass  # best effort; the original error below is what matters
        if isinstance(e, OSError):
            print(f"Error writing HMAC key file {key_file}: {e}")
            print("Please ensure the directory is writable by the process.")
        else:
            print(f"Unexpected error generating HMAC key: {e}")
        raise SystemExit(1)

    print(f"Generated and saved new HMAC key to {key_file}. PROTECT THIS FILE!")
    return key
def generate_token_hash(token: str, hmac_key: bytes) -> str:
    """Return the hex HMAC-SHA256 digest of *token* (UTF-8 encoded) under *hmac_key*.

    Only this digest is stored in the database, so tokens are compared by
    hashing the presented value with the same key.
    """
    mac = hmac.new(hmac_key, digestmod=hashlib.sha256)
    mac.update(token.encode("utf-8"))
    return mac.hexdigest()
# --- Service Registry (SQLite Backend) ---
class ServiceRegistry:
    """SQLite-backed registry of service instances and hashed API tokens.

    Every public method opens a short-lived connection via ``_get_db_conn``
    (committed on success, rolled back on SQLite errors, always closed), so
    several processes can each hold their own instance for the same file.
    Tokens are never stored in plain text, only their HMAC-SHA256 hash under
    the key given to the constructor.
    """

    # Total attempts _init_db makes while another process holds the DB lock,
    # sleeping _INIT_DB_RETRY_DELAY_SECONDS between attempts.
    _INIT_DB_MAX_ATTEMPTS = 10
    _INIT_DB_RETRY_DELAY_SECONDS = 0.5

    def __init__(self, db_path: str, hmac_key: bytes):
        """Store settings and create/bootstrap the schema.

        Raises:
            SystemExit: If the database cannot be initialized.
        """
        self.db_path = db_path
        self._hmac_key = hmac_key
        self._init_db()

    @contextmanager
    def _get_db_conn(self) -> Generator[sqlite3.Connection, None, None]:
        """Yield a connection that is committed on success and always closed.

        On ``sqlite3.Error`` the transaction is rolled back and the error is
        re-raised. Any other exception skips the commit; closing the
        connection then discards the uncommitted changes.
        """
        conn = None
        try:
            conn = sqlite3.connect(self.db_path, timeout=SQLITE_TIMEOUT_SECONDS)
            conn.row_factory = sqlite3.Row  # Access columns by name
            conn.execute("PRAGMA foreign_keys = ON;")  # Enforce foreign keys if needed
            # Write-Ahead Logging for better concurrency between processes.
            conn.execute("PRAGMA journal_mode = WAL;")
            yield conn
            conn.commit()
        except sqlite3.Error as e:
            print(f"SQLite error: {e} - {self.db_path}")
            if conn is not None:
                conn.rollback()
            raise
        finally:
            if conn is not None:
                conn.close()

    def _init_db(self) -> None:
        """Create the schema and bootstrap the admin token, retrying while locked.

        Uses a bounded loop instead of recursion so a permanently locked
        database ends in a clear fatal error rather than a RecursionError.

        Raises:
            SystemExit: On an unrecoverable error, or if the database is still
                locked after ``_INIT_DB_MAX_ATTEMPTS`` attempts.
        """
        for attempt in range(1, self._INIT_DB_MAX_ATTEMPTS + 1):
            try:
                self._create_schema_and_bootstrap()
                return
            except sqlite3.OperationalError as e:
                locked = "database is locked" in str(e)
                if locked and attempt < self._INIT_DB_MAX_ATTEMPTS:
                    print(
                        f"Warning: Database {self.db_path} is locked during initialization. Retrying shortly..."
                    )
                    time.sleep(self._INIT_DB_RETRY_DELAY_SECONDS)
                    continue
                print(f"Fatal error initializing database {self.db_path}: {e}")
                raise SystemExit(1)
            except Exception as e:
                print(f"Fatal error during database initialization: {e}")
                raise SystemExit(1)

    def _create_schema_and_bootstrap(self) -> None:
        """Single attempt: create tables and seed the first admin token if needed.

        Raises:
            SystemExit: If the database has no tokens, is not yet marked
                initialized, and ADMIN_TOKEN_ENV_VAR is not set.
            sqlite3.Error: Propagated to ``_init_db`` for retry/handling.
        """
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute(SQL_CREATE_SERVICES)
            cursor.execute(SQL_CREATE_TOKENS)
            cursor.execute(SQL_CREATE_META)

            # Check if initial admin token needs to be created
            cursor.execute("SELECT COUNT(*) FROM tokens")
            token_count = cursor.fetchone()[0]
            cursor.execute("SELECT value FROM meta WHERE key = 'initialized'")
            initialized = cursor.fetchone()

            if token_count == 0 and not initialized:
                print("No tokens found in the database.")
                admin_token_plain = os.environ.get(ADMIN_TOKEN_ENV_VAR)
                if not admin_token_plain:
                    print("ERROR: Database is empty and initial admin token not provided.")
                    print(
                        f"Please set the '{ADMIN_TOKEN_ENV_VAR}' environment variable with a secure token."
                    )
                    raise SystemExit(1)  # Critical configuration missing
                print(f"Creating initial admin token from '{ADMIN_TOKEN_ENV_VAR}'...")
                token_hash = generate_token_hash(admin_token_plain, self._hmac_key)
                cursor.execute(
                    "INSERT INTO tokens (token_hash, name, created_at, permissions_json) VALUES (?, ?, ?, ?)",
                    (token_hash, "admin", time.time(), json.dumps(ADMIN_PERMISSIONS)),
                )
                # Mark DB as initialized to prevent asking for env var again
                cursor.execute(
                    "INSERT OR REPLACE INTO meta (key, value) VALUES ('initialized', 'true')"
                )
                print(
                    "Initial admin token created successfully. You can unset the environment variable now."
                )
            elif token_count > 0 and not initialized:
                # Tokens exist but DB not marked initialized (e.g. older version): mark it now.
                cursor.execute(
                    "INSERT OR REPLACE INTO meta (key, value) VALUES ('initialized', 'true')"
                )
            # Future schema migrations could go here based on a version in 'meta' table

    @staticmethod
    def _row_to_instance(row: sqlite3.Row) -> ServiceInstance:
        """Build a ServiceInstance from a ``services`` row, decoding JSON columns."""
        return ServiceInstance(
            id=row["id"],
            name=row["name"],
            address=row["address"],
            port=row["port"],
            health=row["health"],
            tags=json.loads(row["tags_json"]),
            metadata=json.loads(row["metadata_json"]),
        )

    @classmethod
    def _rows_to_instances(cls, rows: List[sqlite3.Row]) -> List[ServiceInstance]:
        """Convert rows in order, skipping (and logging) rows that fail to decode."""
        instances = []
        for row in rows:
            try:
                instances.append(cls._row_to_instance(row))
            except (ValueError, TypeError, KeyError, IndexError) as e:
                # ValueError covers json.JSONDecodeError and pydantic's
                # ValidationError; sqlite3.Row raises IndexError for a missing
                # column. sqlite3.Row has no .get(), so look the id up explicitly.
                row_id = row["id"] if "id" in row.keys() else "N/A"
                print(f"Warning: Could not parse service data for ID {row_id}: {e}")
        return instances

    def register(self, instance: ServiceInstance) -> bool:
        """Insert or replace *instance* (keyed by id) and stamp last_update.

        Returns:
            True on success; False on an IntegrityError (not expected with
            INSERT OR REPLACE).
        """
        with self._get_db_conn() as conn:
            try:
                conn.execute(
                    """
                    INSERT OR REPLACE INTO services
                    (id, name, address, port, health, tags_json, metadata_json, last_update)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        instance.id,
                        instance.name,
                        instance.address,
                        instance.port,
                        instance.health,
                        json.dumps(instance.tags),
                        json.dumps(instance.metadata),
                        time.time(),
                    ),
                )
            except sqlite3.IntegrityError as e:
                print(f"Error registering service {instance.id}: {e}")
                return False
        return True

    def deregister(self, service_id: str) -> bool:
        """Delete the instance with *service_id*; return True if a row was removed."""
        with self._get_db_conn() as conn:
            cursor = conn.execute("DELETE FROM services WHERE id = ?", (service_id,))
            return cursor.rowcount > 0

    def get_service(
        self, name: str, only_passing: bool = False
    ) -> List[ServiceInstance]:
        """Return all instances named *name* (only health='passing' if requested).

        Rows that cannot be decoded are logged and skipped.
        """
        sql = "SELECT * FROM services WHERE name = ?"
        if only_passing:
            sql += " AND health = 'passing'"
        with self._get_db_conn() as conn:
            rows = conn.execute(sql, (name,)).fetchall()
        return self._rows_to_instances(rows)

    def get_all_services(self) -> Dict[str, List[ServiceInstance]]:
        """Return every registered instance grouped by service name (names sorted)."""
        with self._get_db_conn() as conn:
            rows = conn.execute("SELECT * FROM services ORDER BY name").fetchall()
        services: Dict[str, List[ServiceInstance]] = {}
        for instance in self._rows_to_instances(rows):
            services.setdefault(instance.name, []).append(instance)
        return services

    def update_health(self, service_id: str, health: str) -> bool:
        """Set health and last_update for *service_id*; True if the row existed."""
        with self._get_db_conn() as conn:
            cursor = conn.execute(
                "UPDATE services SET health = ?, last_update = ? WHERE id = ?",
                (health, time.time(), service_id),
            )
            return cursor.rowcount > 0

    def create_token(
        self, name: str, permissions: List[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        """Create a new random API token and store only its hash.

        Returns:
            ``(token, None)`` on success, ``(None, error_message)`` on failure.
            The plain token is returned only here and is never stored.
        """
        token_plain = secrets.token_urlsafe(32)
        token_hash = generate_token_hash(token_plain, self._hmac_key)
        with self._get_db_conn() as conn:
            try:
                conn.execute(
                    "INSERT INTO tokens (token_hash, name, created_at, permissions_json) VALUES (?, ?, ?, ?)",
                    (token_hash, name, time.time(), json.dumps(permissions)),
                )
                return token_plain, None
            except sqlite3.IntegrityError:
                # This likely means the name is not unique
                return None, f"Token name '{name}' already exists."
            except Exception as e:
                print(f"Error creating token '{name}': {e}")
                return None, "Database error during token creation."

    def revoke_token(self, token_plain: str) -> bool:
        """Delete the token whose hash matches *token_plain*; True if one was removed."""
        token_hash = generate_token_hash(token_plain, self._hmac_key)
        with self._get_db_conn() as conn:
            cursor = conn.execute("DELETE FROM tokens WHERE token_hash = ?", (token_hash,))
            return cursor.rowcount > 0

    def validate_token(self, token_plain: str) -> Optional[TokenInfo]:
        """Return the TokenInfo for *token_plain*, or None if unknown or malformed."""
        token_hash = generate_token_hash(token_plain, self._hmac_key)
        with self._get_db_conn() as conn:
            row = conn.execute(
                "SELECT name, created_at, permissions_json FROM tokens WHERE token_hash = ?",
                (token_hash,),
            ).fetchone()
        if row is None:
            return None  # Token hash not found
        try:
            permissions = json.loads(row["permissions_json"])
            if not isinstance(permissions, list):
                raise TypeError("Permissions format error")
            return TokenInfo(
                name=row["name"],
                created_at=row["created_at"],
                permissions=permissions,
            )
        except (ValueError, TypeError) as e:
            # ValueError covers JSON decode errors and pydantic validation errors.
            print(f"Error decoding permissions for token hash {token_hash}: {e}")
            return None  # Treat malformed permissions as invalid
# --- API Key Security ---
# Reads the API token from the X-API-Token header. auto_error=False makes a
# missing header arrive as None, so the checker can return its own 401 message.
api_key_header = APIKeyHeader(name="X-API-Token", auto_error=False)


def get_permission_checker(required_permission: str):
    """Build a FastAPI dependency that requires *required_permission*.

    The returned dependency validates the ``X-API-Token`` header against the
    registry and returns the token string on success.

    Args:
        required_permission: Permission the token must carry, e.g. ``"read"``,
            ``"write"`` or ``"admin"``.

    Returns:
        A dependency callable for use with ``Depends(...)``. At request time it
        raises HTTPException 401 if the header is missing, and 403 if the token
        is unknown or lacks the permission.
    """

    # Deliberately a plain ``def``: validate_token does blocking SQLite I/O
    # (with a multi-second busy timeout). FastAPI runs sync dependencies in its
    # threadpool, whereas an ``async def`` would block the event loop.
    def _check_permission(
        # The lambda looks up this process's registry at request time.
        registry: ServiceRegistry = Depends(lambda: get_service_registry()),
        api_key: Optional[str] = Security(api_key_header),
    ) -> str:
        if api_key is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="API token missing in X-API-Token header",
            )
        token_info = registry.validate_token(api_key)
        if token_info is None:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN, detail="Invalid API token"
            )
        if required_permission not in token_info.permissions:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Insufficient permissions. Required: '{required_permission}'",
            )
        return api_key  # The validated key

    return _check_permission
# --- FastAPI Application Factory ---
# Global variable to hold the registry instance *within this process*
# This is okay because FastAPI runs in a single process (or multiple workers,
# each getting its own instance, which is fine as they all talk to the DB).
# Per-process singleton: assigned once by setup_registry() and returned by
# get_service_registry(), which raises RuntimeError while it is still None.
_registry_instance: Optional[ServiceRegistry] = None
def setup_registry(db_path: str, hmac_key: bytes):
    """Create this process's ServiceRegistry on the first call; later calls only log.

    Args:
        db_path: SQLite database path handed to ServiceRegistry.
        hmac_key: Key the registry uses to hash API tokens.
    """
    global _registry_instance
    pid = os.getpid()
    if _registry_instance is not None:
        print(f"[{pid}] ServiceRegistry already initialized for FastAPI.")
        return
    print(f"[{pid}] Initializing ServiceRegistry for FastAPI...")
    _registry_instance = ServiceRegistry(db_path, hmac_key)
def get_service_registry() -> ServiceRegistry:
    """FastAPI dependency returning this process's ServiceRegistry.

    Raises:
        RuntimeError: If setup_registry() has not run in this process yet.
    """
    registry = _registry_instance
    if registry is not None:
        return registry
    # Reaching here means setup_registry() was not called before the app started.
    raise RuntimeError("ServiceRegistry not initialized for this process.")
def generate_status_page_html(registry: "ServiceRegistry") -> str:
    """Render the public status page as a complete HTML document.

    Each distinct service address is a "node". A node is Active when at least
    one instance at that address has health ``"passing"``, otherwise Inactive.
    Nodes form an N x N grid: the diagonal shows each node's status, and
    off-diagonal cells read "Unknown" because inter-node connectivity is not
    tracked.

    Args:
        registry: Source of registered services (only ``get_all_services()``
            is called).

    Returns:
        The HTML page as a string.
    """
    import html  # local import: only this function needs HTML escaping

    all_services_map = registry.get_all_services()
    all_instances = [inst for instances in all_services_map.values() for inst in instances]

    # 1. Identify unique "nodes" (addresses) and determine their status in one pass.
    unique_addresses = sorted({inst.address for inst in all_instances})
    active_addresses = {inst.address for inst in all_instances if inst.health == "passing"}
    nodes = {
        addr: {
            "status": "Active" if addr in active_addresses else "Inactive",
            "label": f"Node {index}",
        }
        for index, addr in enumerate(unique_addresses, start=1)
    }
    grid_size = len(nodes)

    # Addresses come from API registrations, but this page is public and
    # unauthenticated: escape every value interpolated into markup.
    def esc(value) -> str:
        return html.escape(str(value), quote=True)

    # 2. Build the HTML document.
    parts = [
        "<!DOCTYPE html>",
        '<html lang="en">',
        "<head>",
        '<meta charset="utf-8">',
        '<meta http-equiv="refresh" content="10">',
        "<title>MiniDiscovery Status</title>",
        "<style>",
        "body { font-family: sans-serif; margin: 2em; }",
        ".grid-container { display: grid; gap: 4px;"
        " grid-template-columns: auto repeat(var(--grid-size), minmax(6em, 1fr)); }",
        ".grid-cell { padding: 0.5em; text-align: center; border-radius: 4px; }",
        ".grid-header { font-weight: bold; }",
        ".status-active { background: #c8e6c9; }",
        ".status-inactive { background: #ffcdd2; }",
        ".status-unknown { background: #eeeeee; color: #777777; }",
        ".footer { margin-top: 1em; font-size: 0.9em; color: #555555; }",
        "</style>",
        "</head>",
        "<body>",
        "<h1>MiniDiscovery Status</h1>",
    ]

    if grid_size == 0:
        parts.append("<p>No active nodes detected.</p>")
    else:
        # Inline CSS variable sizes the grid: one label column + N node columns.
        parts.append(f'<div class="grid-container" style="--grid-size: {grid_size};">')
        parts.append('<div class="grid-cell"></div>')  # top-left corner
        for addr in unique_addresses:  # column headers
            parts.append(
                f'<div class="grid-cell grid-header" title="{esc(addr)}">'
                f'{esc(nodes[addr]["label"])}</div>'
            )
        for row_addr in unique_addresses:
            # Row label
            parts.append(
                f'<div class="grid-cell grid-header" title="{esc(row_addr)}">'
                f'{esc(nodes[row_addr]["label"])}</div>'
            )
            for col_addr in unique_addresses:
                if row_addr == col_addr:
                    # Diagonal: the node's own status.
                    node_status = nodes[row_addr]["status"]
                    parts.append(
                        f'<div class="grid-cell status-{node_status.lower()}">{node_status}</div>'
                    )
                else:
                    # Off-diagonal: inter-node connectivity is not tracked.
                    parts.append('<div class="grid-cell status-unknown">Unknown</div>')
        parts.append("</div>")  # end grid-container

    parts.append(
        '<p class="footer">Status based on registered service health. '
        "Page refreshes automatically.</p>"
    )
    parts.append("</body>")
    parts.append("</html>")
    return "\n".join(parts)
def create_fastapi_app() -> FastAPI:
    """Build the MiniDiscovery FastAPI application with all routes registered.

    Route handlers are plain ``def`` functions on purpose. Every handler does
    blocking SQLite I/O through ServiceRegistry (busy timeout
    SQLITE_TIMEOUT_SECONDS), and FastAPI runs sync handlers in a threadpool.
    As ``async def`` they would make those calls on the event loop, stalling
    every other in-flight request while one waits on the database.

    Returns:
        The configured FastAPI instance.
    """
    app = FastAPI(
        title="MiniDiscovery",
        description="Minimal Service Discovery inspired by Consul",
        version="0.2.0",
    )

    # DNS RCODE values used by the DoH endpoint (RFC 1035, section 4.1.1).
    rcode_noerror = 0
    rcode_nxdomain = 3  # NXDOMAIN is 3 (2 would be SERVFAIL)
    rcode_notimp = 4
    # Query types supported by the DoH endpoint, mapped to DNS TYPE codes.
    dns_type_codes = {"A": 1, "SRV": 33}

    # --- Status Page Endpoint (Public) ---
    @app.get("/status", response_class=HTMLResponse)
    def get_status_page(
        registry: ServiceRegistry = Depends(get_service_registry),
    ):
        """Serves the anonymous HTML status page."""
        html_content = generate_status_page_html(registry)
        return HTMLResponse(content=html_content, status_code=200)

    # --- API Endpoints (Authenticated) ---
    @app.post("/v1/agent/service/register", status_code=status.HTTP_200_OK)
    def register_service(
        instance: ServiceInstance,
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("write")),  # Require 'write'
    ):
        """Registers a service instance. Supports replace-if-exists based on ID."""
        if not registry.register(instance):
            # Only on an IntegrityError, which INSERT OR REPLACE makes unlikely.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to register service due to a database error.",
            )
        return {"status": "registered", "service_id": instance.id}

    @app.put(
        "/v1/agent/service/deregister/{service_id}", status_code=status.HTTP_200_OK
    )
    def deregister_service(
        service_id: str,
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("write")),  # Require 'write'
    ):
        """Deregisters a service instance by its ID."""
        if not registry.deregister(service_id):
            # deregister() returns False when no row matched the ID.
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Service ID '{service_id}' not found.",
            )
        return {"status": "deregistered", "service_id": service_id}

    @app.get("/v1/catalog/services", response_model=Dict[str, List[str]])
    def get_catalog_services(
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("read")),  # Require 'read'
    ):
        """Lists all known service names and their tags."""
        services_data = registry.get_all_services()
        # Format matches Consul /v1/catalog/services. Tags are de-duplicated and
        # sorted so the response is deterministic.
        return {
            name: sorted({tag for inst in instances for tag in inst.tags})
            for name, instances in services_data.items()
        }

    @app.get("/v1/catalog/service/{service_name}", response_model=List[ServiceInstance])
    def get_catalog_service(
        service_name: str,
        tag: Optional[str] = None,  # Allow filtering by tag
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("read")),  # Require 'read'
    ):
        """Lists instances for a specific service, optionally filtered by tag."""
        # An unknown service name yields an empty list, matching Consul.
        instances = registry.get_service(service_name)
        if tag:
            instances = [inst for inst in instances if tag in inst.tags]
        return instances

    @app.get("/v1/health/service/{service_name}", response_model=List[ServiceInstance])
    def get_health_service(
        service_name: str,
        tag: Optional[str] = None,
        passing: bool = False,  # Filter for only passing instances
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("read")),  # Require 'read'
    ):
        """Lists instances for a specific service; passing=true keeps only healthy ones."""
        instances = registry.get_service(service_name, only_passing=passing)
        if tag:
            instances = [inst for inst in instances if tag in inst.tags]
        return instances

    # --- Token Management Endpoints ---
    @app.post(
        "/v1/acl/token",
        response_model=ApiTokenCreateResponse,
        status_code=status.HTTP_201_CREATED,
    )
    def create_acl_token(
        token_request: ApiTokenCreateRequest,
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("admin")),  # Require 'admin'
    ):
        """Creates a new API token."""
        plain_token, error = registry.create_token(
            token_request.name, token_request.permissions
        )
        if error:
            # Distinguish between user error (duplicate name) and server error.
            status_code = (
                status.HTTP_400_BAD_REQUEST
                if "already exists" in error
                else status.HTTP_500_INTERNAL_SERVER_ERROR
            )
            raise HTTPException(status_code=status_code, detail=error)
        return ApiTokenCreateResponse(
            token=plain_token,  # Only shown once, upon creation
            name=token_request.name,
            permissions=token_request.permissions,
        )

    @app.delete("/v1/acl/token/{token_to_revoke}", status_code=status.HTTP_200_OK)
    def revoke_acl_token(
        token_to_revoke: str,  # The actual token value to revoke
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("admin")),  # Require 'admin'
    ):
        """Revokes an existing API token."""
        # NOTE(review): the secret travels in the URL path, so it may appear in
        # access logs. The caller may also revoke the token it is using.
        if not registry.revoke_token(token_to_revoke):
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Token not found or already revoked.",
            )
        return {
            "status": "revoked",
            "token_info": "Token revoked successfully",
        }  # Don't echo back the token

    # --- DNS-over-HTTPS Endpoint (simplified JSON API) ---
    # The JSON shape follows the de-facto "application/dns-json" DoH API;
    # RFC 8484 itself only standardizes the binary wire format.
    @app.get("/dns-query", status_code=status.HTTP_200_OK)
    def dns_query(
        name: str,
        type: str = "A",  # Query param 'type' (shadows the builtin only here)
        registry: ServiceRegistry = Depends(get_service_registry),
        # No auth required for basic DNS queries for simplicity, add if needed
    ):
        """Handles DNS queries over HTTPS (DoH)."""
        # Only names inside our zone are answered.
        if not name.lower().endswith(DNS_QUERY_SUFFIX):
            return {"Status": rcode_nxdomain}
        # The service name is the last label before the suffix.
        service_name = name[: -len(DNS_QUERY_SUFFIX)].split(".")[-1]
        instances = registry.get_service(service_name, only_passing=True)
        if not instances:
            return {"Status": rcode_nxdomain}

        type_code = dns_type_codes.get(type.upper())
        if type_code is None:
            # Unsupported type for this simple endpoint
            return {"Status": rcode_notimp}

        if type_code == dns_type_codes["A"]:
            record_data = [instance.address for instance in instances]
        else:  # SRV — data format: "Priority Weight Port Target"
            # NOTE(review): targets "<address><suffix>", while the Twisted
            # resolver targets "<instance id><suffix>"; confirm which is intended.
            record_data = [
                f"0 5 {instance.port} {instance.address}{DNS_QUERY_SUFFIX}"
                for instance in instances
            ]
        answers = [
            {"name": name, "type": type_code, "TTL": DNS_DEFAULT_TTL, "data": data}
            for data in record_data
        ]

        return {
            "Status": rcode_noerror,
            "TC": False,  # Truncated
            "RD": True,  # Recursion Desired (reflects the typical client flag)
            "RA": False,  # Recursion Available (we are authoritative-like)
            "AD": False,  # Authenticated Data
            "CD": False,  # Checking Disabled
            "Question": [{"name": name, "type": type_code}],  # Reflect question
            "Answer": answers,
            # Authority and Additional sections omitted for simplicity
        }

    return app
# --- Twisted DNS Server ---
class MiniDiscoveryResolver(common.ResolverBase):
    """Twisted resolver answering A and SRV queries from the service registry.

    Supported names (all must end with DNS_QUERY_SUFFIX):

    * ``<...>.<service><suffix>``: A records for the passing instances of
      ``<service>`` (the last label before the suffix), or SRV records plus
      additional A records when queried as SRV.
    * ``_<tag>._tcp.<...><suffix>`` / ``_<tag>._udp...`` queried as SRV: SRV
      records for every passing instance carrying ``<tag>``, plus A records
      for the SRV targets in the additional section.

    Other names, or names with no matching passing instance, get an empty
    answer.

    NOTE: registry lookups are synchronous SQLite calls made on the reactor
    thread, so a locked database delays all DNS traffic.
    """

    def __init__(self, registry: ServiceRegistry):
        self.registry = registry
        super().__init__()

    @staticmethod
    def _parse_srv_tag(name_str: str) -> Optional[str]:
        """Return ``tag`` for names shaped ``_tag._tcp.*`` / ``_tag._udp.*``, else None."""
        parts = name_str.split(".")
        # Example: _ssh._tcp.laiska.local -> 'ssh'
        if len(parts) >= 4 and parts[0].startswith("_") and parts[1] in ("_tcp", "_udp"):
            return parts[0][1:]
        return None

    def _passing_instances_with_tag(self, tag: str) -> List[ServiceInstance]:
        """All passing instances (across every service) whose tags include *tag*."""
        return [
            instance
            for service_instances in self.registry.get_all_services().values()
            for instance in service_instances
            if instance.health == "passing" and tag in instance.tags
        ]

    def _lookup(
        self,
        name: bytes,
        cls: int,
        query_type: int,
        timeout: Optional[Tuple[int, ...]] = None,
    ) -> Deferred:
        """Resolve *name* for *query_type*.

        Returns:
            A Deferred firing with ``(answers, authority, additional)`` lists of
            ``dns.RRHeader`` (all empty when there is nothing to answer). It
            errbacks when the name is not valid UTF-8 or on any unexpected
            error; the DNS server then reports a failure (presumably SERVFAIL)
            to the client.
        """
        d = Deferred()
        name_str_debug = repr(name)
        try:
            # 1. Decode the incoming query name safely.
            try:
                name_str = name.decode("utf-8").lower()
            except UnicodeDecodeError as decode_err:
                print(
                    f"DNS lookup error: Cannot decode query name {name_str_debug} as UTF-8: {decode_err}"
                )
                d.errback(Failure(decode_err))
                return d

            # 2. Only names inside our zone are answered.
            if not name_str.endswith(DNS_QUERY_SUFFIX):
                d.callback(([], [], []))
                return d

            # 3. SRV service-type name (_tag._proto...) vs. direct service name.
            service_tag_to_find = self._parse_srv_tag(name_str)
            is_srv_type_query = service_tag_to_find is not None
            if is_srv_type_query and query_type != dns.SRV:
                # A/other records for an SRV-style name make no sense here.
                d.callback(([], [], []))
                return d

            # 4. Fetch the instances to answer with.
            if is_srv_type_query:
                print(f"DNS: Performing SRV type lookup for tag '{service_tag_to_find}'")
                instances_to_process = self._passing_instances_with_tag(service_tag_to_find)
                if not instances_to_process:
                    print(f"DNS: No passing instances found with tag '{service_tag_to_find}'")
            else:
                service_name = name_str[: -len(DNS_QUERY_SUFFIX)].split(".")[-1]
                instances_to_process = (
                    self.registry.get_service(service_name, only_passing=True)
                    if service_name
                    else []
                )
            if not instances_to_process:
                d.callback(([], [], []))
                return d

            # 5. Build DNS records for the query types we can answer.
            answers = []
            authority = []
            additional = []
            if query_type == dns.A and not is_srv_type_query:
                # A records only for a direct service-name query.
                for instance in instances_to_process:
                    try:
                        answers.append(
                            dns.RRHeader(
                                name=name,
                                type=dns.A,
                                cls=cls,
                                ttl=DNS_DEFAULT_TTL,
                                # Pass the string, Twisted handles conversion
                                payload=dns.Record_A(
                                    address=instance.address, ttl=DNS_DEFAULT_TTL
                                ),
                            )
                        )
                    except Exception as record_e:
                        # e.g. an address that is not a valid IPv4 literal: skip this instance only.
                        print(
                            f"Warning: Error creating A record for IP '{instance.address}' (Service ID {instance.id}): {record_e}. Skipping."
                        )
            elif query_type == dns.SRV:
                # SRV for a direct name query or an SRV type query.
                for instance in instances_to_process:
                    try:
                        target_name = f"{instance.id}{DNS_QUERY_SUFFIX}".encode("utf-8")
                        answers.append(
                            dns.RRHeader(
                                name=name,
                                type=dns.SRV,
                                cls=cls,
                                ttl=DNS_DEFAULT_TTL,
                                payload=dns.Record_SRV(
                                    priority=0,
                                    weight=10,
                                    port=instance.port,
                                    target=target_name,
                                    ttl=DNS_DEFAULT_TTL,
                                ),
                            )
                        )
                        # Glue A record so clients need not resolve the target.
                        additional.append(
                            dns.RRHeader(
                                name=target_name,
                                type=dns.A,
                                cls=cls,
                                ttl=DNS_DEFAULT_TTL,
                                payload=dns.Record_A(
                                    address=instance.address, ttl=DNS_DEFAULT_TTL
                                ),
                            )
                        )
                    except Exception as record_e:
                        print(
                            f"Warning: Error creating SRV/additional record for IP '{instance.address}' (Service ID {instance.id}): {record_e}. Skipping."
                        )
            d.callback((answers, authority, additional))
        except Exception as e:
            # Any other unexpected error: log it and fail the lookup.
            print(
                f"!!! Unhandled exception during DNS lookup for query name {name_str_debug} !!!"
            )
            print(f"Exception Type: {type(e).__name__}")
            print(f"Exception Args: {e.args}")
            print("--- Traceback ---")
            traceback.print_exc()
            print("--- End Traceback ---")
            d.errback(Failure(e))  # Signal SERVFAIL
        return d
# --- Health Checker ---
def check_service_health(instance: ServiceInstance) -> str:
    """Classify *instance* by trying a TCP connection to its address and port.

    Returns:
        ``"passing"`` if a connection is established within 2 seconds,
        otherwise ``"failing"``. Unexpected (non-socket) errors are logged and
        also reported as ``"failing"``.
    """
    target = (instance.address, instance.port)
    try:
        # Connect-and-close is the whole check; no data is exchanged.
        with socket.create_connection(target, timeout=2):
            pass
    except OSError:  # includes socket.timeout and ConnectionRefusedError
        return "failing"
    except Exception as exc:
        print(
            f"Unexpected error checking health for {instance.id} ({instance.address}:{instance.port}): {exc}"
        )
        return "failing"
    return "passing"
# --- Process Runner Functions ---
def run_fastapi_server(db_path: str, hmac_key: bytes, host: str, port: int):
    """Entry point of the API process: set up the registry and app, then serve.

    Blocks inside uvicorn.run() until the server shuts down. Per the original
    note, Uvicorn handles SIGINT/SIGTERM itself.
    """
    pid = os.getpid()
    print(f"[{pid}] Starting FastAPI process...")
    # The dependencies read this process's registry via get_service_registry().
    setup_registry(db_path, hmac_key)
    uvicorn.run(create_fastapi_app(), host=host, port=port)
    print(f"[{pid}] FastAPI process finished.")
def run_dns_server(db_path: str, hmac_key: bytes, port: int):
    """Entry point of the DNS process: serve MiniDiscoveryResolver on UDP and TCP.

    Blocks in the Twisted reactor until SIGINT/SIGTERM. Returns early (after
    logging) if the listeners cannot be bound.
    """
    print(f"[{os.getpid()}] Starting DNS server process...")
    # The DNS process needs its own registry instance (own DB connections).
    registry = ServiceRegistry(db_path, hmac_key)
    resolver = MiniDiscoveryResolver(registry)
    factory = server.DNSServerFactory(clients=[resolver])
    protocol = dns.DNSDatagramProtocol(controller=factory)

    # Listen on UDP and TCP
    try:
        reactor.listenUDP(port, protocol)
        reactor.listenTCP(port, factory)
        print(f"[{os.getpid()}] DNS Server listening on port {port} (UDP/TCP)")
    except Exception as e:
        print(f"[{os.getpid()}] Error starting DNS listeners on port {port}: {e}")
        return  # Exit process if cannot bind

    def shutdown_dns():
        print(f"[{os.getpid()}] Shutting down DNS server...")
        # Add cleanup here if needed (e.g., closing listeners explicitly)

    reactor.addSystemEventTrigger("before", "shutdown", shutdown_dns)

    def signal_handler(signum, frame):
        print(f"[{os.getpid()}] Received signal {signum}, stopping DNS reactor.")
        if reactor.running:
            # callFromThread also wakes the reactor's waker, so a reactor blocked
            # in select()/poll() stops promptly; a bare callLater(0, ...) from a
            # signal handler is not noticed until the poll times out.
            reactor.callFromThread(reactor.stop)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # By default reactor.run() installs Twisted's own SIGINT/SIGTERM handlers,
    # silently replacing the ones above; disable that so ours stay in effect.
    reactor.run(installSignalHandlers=False)
    print(f"[{os.getpid()}] DNS server process finished.")
def run_health_checker(db_path: str, hmac_key: bytes, check_interval: int):
    """Entry point of the health-check process: periodically TCP-probe instances.

    Each cycle reads all services, runs check_service_health() on every
    instance, and writes a new status only when it differs from the stored
    one. Cycles start every *check_interval* seconds, or immediately if the
    previous cycle took longer. Runs until SIGINT/SIGTERM.
    """
    pid = os.getpid()
    print(f"[{pid}] Starting Health Checker process (interval: {check_interval}s)...")
    # The health checker needs its own registry instance (own DB connections).
    registry = ServiceRegistry(db_path, hmac_key)
    running = True

    def signal_handler(signum, frame):
        nonlocal running
        print(f"[{os.getpid()}] Received signal {signum}, stopping health checker...")
        running = False

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    def sleep_while_running(duration: float) -> None:
        """Sleep up to *duration* seconds, returning early once running is False.

        A signal whose handler returns normally does not cut time.sleep()
        short (PEP 475 resumes it), so one long sleep would delay shutdown by
        up to a full interval. Short slices bound that delay.
        """
        deadline = time.monotonic() + duration
        while running:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return
            time.sleep(min(remaining, 0.5))

    while running:
        start_time = time.monotonic()
        print(f"[{pid}] Health Check Cycle Start")
        updated_count = 0
        checked_count = 0
        try:
            # Fetch all services directly from the DB for this cycle.
            services_map = registry.get_all_services()
            all_instances = [inst for sublist in services_map.values() for inst in sublist]
            checked_count = len(all_instances)
            for instance in all_instances:
                if not running:
                    break  # Exit early if a shutdown signal was received
                current_health = instance.health
                new_health = check_service_health(instance)
                if current_health == new_health:
                    continue
                print(
                    f"[{pid}] Health change for {instance.id} ({instance.name}): {current_health} -> {new_health}"
                )
                if registry.update_health(instance.id, new_health):
                    updated_count += 1
                else:
                    # The service may have been deregistered between read and update.
                    print(
                        f"[{pid}] Warning: Failed to update health for {instance.id} (maybe deregistered?)"
                    )
        except Exception as e:
            # Errors in the cycle itself (e.g. DB connection) must not kill the loop.
            print(f"[{pid}] Error during health check cycle: {e}")
        if not running:
            break

        # Sleep for the rest of the interval.
        elapsed = time.monotonic() - start_time
        sleep_time = max(0, check_interval - elapsed)
        print(
            f"[{pid}] Health Check Cycle End. Checked: {checked_count}, Updated: {updated_count}. Took {elapsed:.2f}s. Sleeping for {sleep_time:.2f}s."
        )
        sleep_while_running(sleep_time)
    print(f"[{pid}] Health checker process finished.")
# --- Main Execution ---
def _stop_child_processes(
    processes: List[multiprocessing.Process],
    grace_seconds: float = 2.0,
    join_timeout: float = 5.0,
) -> None:
    """Stop started child processes: grace period, then SIGTERM, then SIGKILL.

    Children that already received a signal (Ctrl+C reaches the whole
    foreground process group) get ``grace_seconds`` to exit on their own.
    Survivors are sent SIGTERM via ``terminate()`` and joined for up to
    ``join_timeout`` seconds each. Anything still alive after that is killed.
    """
    deadline = time.monotonic() + grace_seconds
    for p in processes:
        p.join(timeout=max(0.0, deadline - time.monotonic()))
    for p in processes:
        if p.is_alive():
            print(f"Terminating process {p.pid} ({p.name})...")
            p.terminate()  # Send SIGTERM
    for p in processes:
        p.join(timeout=join_timeout)  # Wait for clean exit
        if p.is_alive():
            print(f"Process {p.pid} ({p.name}) did not exit cleanly, killing.")
            p.kill()  # Force kill if still alive after timeout
            p.join(timeout=1)  # Reap it so no zombie is left behind


def main():
    """Parse CLI options, then run and supervise the MiniDiscovery processes.

    Three child processes are started: the FastAPI server
    (``run_fastapi_server``), the DNS server (``run_dns_server``) and the
    health checker (``run_health_checker``). The main process waits until it
    receives SIGINT/SIGTERM or any child exits. It then stops all children.
    """
    parser = argparse.ArgumentParser(
        description="MiniDiscovery: A minimal service discovery tool.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,  # Show defaults
    )
    parser.add_argument(
        "--db-path",
        default="minidiscovery_data.db",
        # f-string so the help shows the actual variable name rather than the
        # literal text "${DB_PATH_ENV_VAR}".
        help=f"Path to the SQLite database file. Overridden by ${DB_PATH_ENV_VAR} if set.",
    )
    parser.add_argument(
        "--api-host", default="0.0.0.0", help="Host address for the API server."
    )
    parser.add_argument(
        "--api-port", type=int, default=8500, help="Port for the API server."
    )
    parser.add_argument(
        "--dns-port",
        type=int,
        default=10053,
        help="Port for the DNS server (UDP/TCP). Use ports > 1024 for non-root.",
    )
    parser.add_argument(
        "--health-check-interval",
        type=int,
        default=15,
        help="Health check interval in seconds.",
    )
    parser.add_argument(
        "--hmac-key-file",
        default=HMAC_KEY_FILE,
        help="Path to the file storing the HMAC secret key.",
    )
    args = parser.parse_args()

    # The environment variable takes precedence over --db-path.
    db_path_from_env = os.environ.get(DB_PATH_ENV_VAR)
    if db_path_from_env:
        args.db_path = db_path_from_env
        print(
            f"Using database path from environment variable ${DB_PATH_ENV_VAR}: {args.db_path}"
        )
    elif not args.db_path:  # e.g. an explicit empty --db-path ""
        args.db_path = "minidiscovery_data.db"
        print(f"Database path not specified, using default: {args.db_path}")
    else:
        print(f"Using database path from --db-path argument: {args.db_path}")

    # Ensure the DB directory exists.
    db_dir = os.path.dirname(os.path.abspath(args.db_path))
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)

    # Load or generate the HMAC key *before* starting processes, so every
    # child receives the same key.
    try:
        hmac_key = load_or_generate_hmac_key(args.hmac_key_file)
    except SystemExit:
        return  # Key generation/loading failed; nothing to start.

    # Early warning about a missing initial admin token. The check inside
    # ServiceRegistry._init_db is the authoritative one.
    conn_check = sqlite3.connect(args.db_path)
    try:
        cursor_check = conn_check.cursor()
        cursor_check.execute("SELECT COUNT(*) FROM tokens")
        token_count_check = cursor_check.fetchone()[0]
        if token_count_check == 0 and not os.environ.get(ADMIN_TOKEN_ENV_VAR):
            print(
                f"WARNING: Database appears empty. Ensure '{ADMIN_TOKEN_ENV_VAR}' is set for the first run."
            )
    except sqlite3.DatabaseError:
        # The table might not exist yet; ServiceRegistry init will create it.
        pass
    finally:
        conn_check.close()

    # --- Process Management ---
    processes: List[multiprocessing.Process] = []
    shutdown_requested = False

    def graceful_shutdown(signum, frame):
        """Request shutdown; the `finally` below does the actual stopping.

        Kept minimal (no sleeping, no locks) because it runs as a signal
        handler and may be re-entered by repeated signals.
        """
        nonlocal shutdown_requested
        if not shutdown_requested:
            print(f"Main process received signal {signum}. Initiating shutdown...")
        shutdown_requested = True

    signal.signal(signal.SIGINT, graceful_shutdown)
    signal.signal(signal.SIGTERM, graceful_shutdown)

    process_specs = [
        (
            "FastAPI Process",
            run_fastapi_server,
            (args.db_path, hmac_key, args.api_host, args.api_port),
        ),
        ("DNS Process", run_dns_server, (args.db_path, hmac_key, args.dns_port)),
        (
            "Health Check Process",
            run_health_checker,
            (args.db_path, hmac_key, args.health_check_interval),
        ),
    ]

    try:
        for proc_name, proc_target, proc_args in process_specs:
            proc = multiprocessing.Process(
                target=proc_target, args=proc_args, name=proc_name
            )
            proc.start()
            # Track only started processes: join() asserts on unstarted ones.
            processes.append(proc)

        print("-" * 30)
        print(f"MiniDiscovery Started (PID: {os.getpid()})")
        print(f" API Server: http://{args.api_host}:{args.api_port}")
        print(f" DNS Server: Port {args.dns_port} (UDP/TCP)")
        print(f" Database: {args.db_path}")
        print(f" HMAC Key: {args.hmac_key_file}")
        print(f" Health Int: {args.health_check_interval}s")
        print("-" * 30)
        print("Press Ctrl+C to stop.")

        # Supervise until a shutdown signal arrives or any child exits. None
        # of the children is expected to exit on its own, so even a zero exit
        # code (e.g. a child that returns early) triggers a full shutdown.
        while not shutdown_requested:
            exited = next((p for p in processes if not p.is_alive()), None)
            if exited is not None:
                if not shutdown_requested:
                    print(
                        f"ERROR: Process {exited.name} (PID: {exited.pid}) exited unexpectedly with code {exited.exitcode}."
                    )
                break
            time.sleep(1)
    finally:
        print("Waiting for processes to join...")
        _stop_child_processes(processes)
        print("MiniDiscovery shut down complete.")
if __name__ == "__main__":
    # Script entry point. The guard also matters for multiprocessing: with the
    # "spawn" start method each child re-imports this module, and without the
    # guard it would launch another full server stack. main() returns None,
    # so SystemExit(None) exits with status 0, the same as falling off the end.
    raise SystemExit(main())