import argparse
import asyncio
import base64
import hashlib
import hmac
import json
import multiprocessing
import os
import secrets
import signal
import socket
import sqlite3
import time
import traceback
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple

import uvicorn
from fastapi import Depends, FastAPI, Header, HTTPException, Security, status
from fastapi.security import APIKeyHeader
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field
from twisted.internet import reactor
from twisted.names import dns, server, common
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure

# --- Constants ---
SQLITE_TIMEOUT_SECONDS = 10  # Increased timeout for potential concurrent writes
HMAC_KEY_FILE = "minidiscovery.key"
ADMIN_TOKEN_ENV_VAR = "MINIDISCOVERY_ADMIN_TOKEN"
DEFAULT_TOKEN_PERMISSIONS = ["read", "write"]
ADMIN_PERMISSIONS = ["read", "write", "admin"]
DNS_DEFAULT_TTL = 60
DNS_QUERY_SUFFIX = ".laiska.local"  # Define a suffix for DNS lookups
DB_PATH_ENV_VAR = "MINIDISCOVERY_DB_PATH"

# Upper bound on "database is locked" retries during schema initialization.
DB_INIT_MAX_RETRIES = 5
DB_INIT_RETRY_DELAY_SECONDS = 0.5

# DNS response codes (RCODE) used in the JSON answers of /dns-query.
DNS_RCODE_NOERROR = 0
DNS_RCODE_NXDOMAIN = 3
DNS_RCODE_NOTIMP = 4

# --- Database Schema ---
SQL_CREATE_SERVICES = """
CREATE TABLE IF NOT EXISTS services (
    id TEXT PRIMARY KEY,
    name TEXT NOT NULL,
    address TEXT NOT NULL,
    port INTEGER NOT NULL,
    health TEXT DEFAULT 'passing',
    tags_json TEXT DEFAULT '[]', -- Store tags as JSON text
    metadata_json TEXT DEFAULT '{}', -- Store metadata as JSON text
    last_update REAL NOT NULL -- Timestamp of last update/registration
);
"""

SQL_CREATE_TOKENS = """
CREATE TABLE IF NOT EXISTS tokens (
    token_hash TEXT PRIMARY KEY, -- Store the HMAC hash of the token
    name TEXT NOT NULL UNIQUE, -- Token names should be unique
    created_at REAL NOT NULL,
    permissions_json TEXT NOT NULL -- Store permissions as JSON text
);
"""

# For storing things like schema version or first-run status
SQL_CREATE_META = """
CREATE TABLE IF NOT EXISTS meta (
    key TEXT PRIMARY KEY,
    value TEXT NOT NULL
);
"""


# --- Pydantic Models ---
class ServiceInstance(BaseModel):
    """One registered instance of a service, as exchanged over the HTTP API."""

    id: str = Field(..., description="Unique identifier for this service instance")
    name: str = Field(
        ..., description="Logical name of the service (e.g., 'web', 'db')"
    )
    address: str = Field(
        ..., description="IP address or hostname where the service listens"
    )
    port: int = Field(..., gt=0, lt=65536, description="Port number for the service")
    tags: List[str] = Field(
        default_factory=list, description="Optional list of tags for filtering"
    )
    metadata: Dict[str, str] = Field(
        default_factory=dict, description="Optional key-value metadata"
    )
    health: str = Field(
        default="passing", description="Health status ('passing', 'failing', 'unknown')"
    )
    last_updated: float = Field(default_factory=time.time)


class TokenInfo(BaseModel):
    """Information about a token (excluding the hash)"""

    name: str
    created_at: float
    permissions: List[str]


class ApiTokenCreateRequest(BaseModel):
    """Request body for creating a new API token."""

    name: str = Field(..., description="A descriptive name for the token")
    permissions: List[str] = Field(
        default=DEFAULT_TOKEN_PERMISSIONS, description="Permissions for the token"
    )


class ApiTokenCreateResponse(BaseModel):
    """Response body returned once, right after a token has been created."""

    token: str = Field(..., description="The generated API token (show only once!)")
    name: str
    permissions: List[str]


# --- HMAC Key Management ---
def load_or_generate_hmac_key(key_file: str = HMAC_KEY_FILE) -> bytes:
    """Load the HMAC key from ``key_file`` or generate and persist a new one.

    Args:
        key_file: Path of the key file to read or create.

    Returns:
        The raw key bytes (32 bytes when freshly generated).

    Raises:
        SystemExit: If the existing file cannot be read or holds fewer than
            32 bytes, or if a new key cannot be written.
    """
    if os.path.exists(key_file):
        try:
            with open(key_file, "rb") as f:
                key = f.read()
            if len(key) < 32:  # Basic sanity check
                raise ValueError("HMAC key file seems corrupted or too short.")
            print(f"Loaded HMAC key from {key_file}")
            return key
        except Exception as e:
            print(
                f"Error loading HMAC key from {key_file}: {e}. Check file permissions and content."
            )
            raise SystemExit(1)

    print(f"HMAC key file ({key_file}) not found. Generating a new one.")
    key = secrets.token_bytes(32)  # 256 bits
    try:
        # O_EXCL makes creation atomic; 0o600 restricts the file to the owner.
        fd = os.open(key_file, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
        with os.fdopen(fd, "wb") as f:
            f.write(key)
        print(f"Generated and saved new HMAC key to {key_file}. PROTECT THIS FILE!")
        return key
    except FileExistsError:
        # Race condition: another process created it between check and open
        return load_or_generate_hmac_key(key_file)
    except OSError as e:
        print(f"Error writing HMAC key file {key_file}: {e}")
        print("Please ensure the directory is writable by the process.")
        raise SystemExit(1)
    except Exception as e:
        print(f"Unexpected error generating HMAC key: {e}")
        raise SystemExit(1)


def generate_token_hash(token: str, hmac_key: bytes) -> str:
    """Return the hex HMAC-SHA256 digest of ``token`` keyed with ``hmac_key``."""
    return hmac.new(hmac_key, token.encode("utf-8"), hashlib.sha256).hexdigest()


# --- Service Registry (SQLite Backend) ---
class ServiceRegistry:
    """SQLite-backed store for service instances and hashed API tokens.

    Every operation opens its own short-lived connection through
    ``_get_db_conn``. Constructing the registry creates the schema and, on an
    empty database, the initial admin token.
    """

    def __init__(self, db_path: str, hmac_key: bytes):
        self.db_path = db_path
        self._hmac_key = hmac_key
        self._init_db()

    @contextmanager
    def _get_db_conn(self) -> Generator[sqlite3.Connection, None, None]:
        """Yield a connection; commit on success, roll back on ``sqlite3.Error``.

        The connection is always closed on exit. Exceptions other than
        ``sqlite3.Error`` propagate without an explicit commit.
        """
        conn = None
        try:
            conn = sqlite3.connect(self.db_path, timeout=SQLITE_TIMEOUT_SECONDS)
            conn.row_factory = sqlite3.Row  # Access columns by name
            conn.execute("PRAGMA foreign_keys = ON;")  # Enforce foreign keys if needed
            conn.execute(
                "PRAGMA journal_mode = WAL;"
            )  # Write-Ahead Logging for better concurrency
            yield conn
            conn.commit()
        except sqlite3.Error as e:
            print(f"SQLite error: {e} - {self.db_path}")
            if conn:
                conn.rollback()  # Rollback on error
            raise  # Re-raise the exception
        finally:
            if conn:
                conn.close()

    def _init_db(self, max_retries: int = DB_INIT_MAX_RETRIES) -> None:
        """Create the schema and bootstrap the admin token if needed.

        A "database is locked" error is retried up to ``max_retries`` times
        with a short pause between attempts; the retry is a bounded loop
        rather than open-ended recursion.

        Raises:
            SystemExit: On any unrecoverable initialization error, including a
                missing admin-token environment variable on an empty database.
        """
        retries_left = max_retries
        while True:
            try:
                with self._get_db_conn() as conn:
                    self._create_schema_and_bootstrap(conn)
                return
            except sqlite3.OperationalError as e:
                if "database is locked" not in str(e) or retries_left <= 0:
                    print(f"Fatal error initializing database {self.db_path}: {e}")
                    raise SystemExit(1)
                retries_left -= 1
                print(
                    f"Warning: Database {self.db_path} is locked during initialization. Retrying shortly..."
                )
                time.sleep(DB_INIT_RETRY_DELAY_SECONDS)
            except Exception as e:
                print(f"Fatal error during database initialization: {e}")
                raise SystemExit(1)

    def _create_schema_and_bootstrap(self, conn: sqlite3.Connection) -> None:
        """Create tables and, on first run, insert the initial admin token."""
        cursor = conn.cursor()
        cursor.execute(SQL_CREATE_SERVICES)
        cursor.execute(SQL_CREATE_TOKENS)
        cursor.execute(SQL_CREATE_META)

        # Check if initial admin token needs to be created
        cursor.execute("SELECT COUNT(*) FROM tokens")
        token_count = cursor.fetchone()[0]
        cursor.execute("SELECT value FROM meta WHERE key = 'initialized'")
        initialized = cursor.fetchone()

        if token_count == 0 and not initialized:
            print("No tokens found in the database.")
            admin_token_plain = os.environ.get(ADMIN_TOKEN_ENV_VAR)
            if not admin_token_plain:
                print("ERROR: Database is empty and initial admin token not provided.")
                print(
                    f"Please set the '{ADMIN_TOKEN_ENV_VAR}' environment variable with a secure token."
                )
                raise SystemExit(1)  # Critical configuration missing
            print(f"Creating initial admin token from '{ADMIN_TOKEN_ENV_VAR}'...")
            token_hash = generate_token_hash(admin_token_plain, self._hmac_key)
            cursor.execute(
                "INSERT INTO tokens (token_hash, name, created_at, permissions_json) VALUES (?, ?, ?, ?)",
                (
                    token_hash,
                    "admin",
                    time.time(),
                    json.dumps(ADMIN_PERMISSIONS),
                ),
            )
            # Mark DB as initialized to prevent asking for env var again
            cursor.execute(
                "INSERT OR REPLACE INTO meta (key, value) VALUES ('initialized', 'true')"
            )
            print(
                "Initial admin token created successfully. You can unset the environment variable now."
            )
        elif token_count > 0 and not initialized:
            # Tokens exist, but not marked initialized (e.g., older version) - mark it now
            cursor.execute(
                "INSERT OR REPLACE INTO meta (key, value) VALUES ('initialized', 'true')"
            )
        # Future schema migrations could go here based on a version in 'meta' table

    @staticmethod
    def _row_to_instance(row: sqlite3.Row) -> ServiceInstance:
        """Build a ``ServiceInstance`` from one row of the ``services`` table."""
        return ServiceInstance(
            id=row["id"],
            name=row["name"],
            address=row["address"],
            port=row["port"],
            health=row["health"],
            tags=json.loads(row["tags_json"]),
            metadata=json.loads(row["metadata_json"]),
            # Report the stored timestamp instead of the model's "now" default.
            last_updated=row["last_update"],
        )

    @classmethod
    def _rows_to_instances(cls, rows: Iterable[sqlite3.Row]) -> List[ServiceInstance]:
        """Convert rows to instances, logging and skipping rows that fail to parse."""
        instances: List[ServiceInstance] = []
        for row in rows:
            try:
                instances.append(cls._row_to_instance(row))
            except (ValueError, TypeError, KeyError, IndexError) as e:
                # sqlite3.Row has no .get(); look the id up defensively so the
                # warning itself cannot raise.
                try:
                    row_id = row["id"]
                except (IndexError, KeyError):
                    row_id = "N/A"
                print(f"Warning: Could not parse service data for ID {row_id}: {e}")
        return instances

    def register(self, instance: ServiceInstance) -> bool:
        """Insert or replace ``instance`` (keyed by id); return True on success."""
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    """
                    INSERT OR REPLACE INTO services
                    (id, name, address, port, health, tags_json, metadata_json, last_update)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        instance.id,
                        instance.name,
                        instance.address,
                        instance.port,
                        instance.health,
                        json.dumps(instance.tags),
                        json.dumps(instance.metadata),
                        time.time(),
                    ),
                )
                return True
            except sqlite3.IntegrityError as e:
                print(f"Error registering service {instance.id}: {e}")
                return False  # Should not happen with INSERT OR REPLACE unless DB issue

    def deregister(self, service_id: str) -> bool:
        """Delete the instance with ``service_id``; return True if a row was removed."""
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM services WHERE id = ?", (service_id,))
            return cursor.rowcount > 0

    def get_service(
        self, name: str, only_passing: bool = False
    ) -> List[ServiceInstance]:
        """Return all instances registered under ``name``.

        Args:
            name: Service name to look up.
            only_passing: When True, restrict to rows whose health is 'passing'.

        Returns:
            The matching instances; rows that cannot be parsed are skipped
            with a warning.
        """
        sql = "SELECT * FROM services WHERE name = ?"
        params: List[Any] = [name]
        if only_passing:
            sql += " AND health = 'passing'"
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute(sql, tuple(params))
            rows = cursor.fetchall()
        return self._rows_to_instances(rows)

    def get_all_services(self) -> Dict[str, List[ServiceInstance]]:
        """Return every registered instance grouped by service name.

        Rows that cannot be parsed are skipped with a warning.
        """
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM services ORDER BY name")
            rows = cursor.fetchall()
        services: Dict[str, List[ServiceInstance]] = {}
        for instance in self._rows_to_instances(rows):
            services.setdefault(instance.name, []).append(instance)
        return services

    def update_health(self, service_id: str, health: str) -> bool:
        """Set the health of ``service_id``; return True if a row was updated."""
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute(
                "UPDATE services SET health = ?, last_update = ? WHERE id = ?",
                (health, time.time(), service_id),
            )
            return cursor.rowcount > 0

    def create_token(
        self, name: str, permissions: List[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        """Create a new API token, storing only its HMAC hash.

        Returns:
            ``(token, None)`` on success, ``(None, error_message)`` on failure.
        """
        token_plain = secrets.token_urlsafe(32)
        token_hash = generate_token_hash(token_plain, self._hmac_key)
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    "INSERT INTO tokens (token_hash, name, created_at, permissions_json) VALUES (?, ?, ?, ?)",
                    (token_hash, name, time.time(), json.dumps(permissions)),
                )
                return token_plain, None  # Return the plain token only on success
            except sqlite3.IntegrityError:
                # This likely means the name is not unique
                return None, f"Token name '{name}' already exists."
            except Exception as e:
                print(f"Error creating token '{name}': {e}")
                return None, "Database error during token creation."

    def revoke_token(self, token_plain: str) -> bool:
        """Delete the stored hash of ``token_plain``; return True if it existed."""
        token_hash = generate_token_hash(token_plain, self._hmac_key)
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM tokens WHERE token_hash = ?", (token_hash,))
            return cursor.rowcount > 0

    def validate_token(self, token_plain: str) -> Optional[TokenInfo]:
        """Look up ``token_plain`` by its hash and return its info.

        Returns:
            The token's ``TokenInfo``, or None when the hash is unknown or the
            stored permissions are not a JSON list.
        """
        token_hash = generate_token_hash(token_plain, self._hmac_key)
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT name, created_at, permissions_json FROM tokens WHERE token_hash = ?",
                (token_hash,),
            )
            row = cursor.fetchone()
        if not row:
            return None  # Token hash not found
        try:
            permissions = json.loads(row["permissions_json"])
            if not isinstance(permissions, list):
                raise TypeError("Permissions format error")
            return TokenInfo(
                name=row["name"],
                created_at=row["created_at"],
                permissions=permissions,
            )
        except (json.JSONDecodeError, TypeError) as e:
            print(f"Error decoding permissions for token hash {token_hash}: {e}")
            return None  # Treat malformed permissions as invalid


# --- API Key Security ---
api_key_header = APIKeyHeader(name="X-API-Token", auto_error=False)


def get_permission_checker(required_permission: str):
    """Build a FastAPI dependency that enforces ``required_permission``.

    The returned dependency raises 401 when the ``X-API-Token`` header is
    missing and 403 when the token is unknown or lacks the permission. On
    success it returns the validated API key.
    """

    async def _check_permission(
        # get_service_registry is defined further down the module; it is only
        # resolved when this factory is called, which happens after import.
        registry: ServiceRegistry = Depends(get_service_registry),
        api_key: Optional[str] = Security(api_key_header),
    ) -> str:
        if api_key is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="API token missing in X-API-Token header",
            )
        token_info = registry.validate_token(api_key)
        if token_info is None:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN, detail="Invalid API token"
            )
        if required_permission not in token_info.permissions:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Insufficient permissions. Required: '{required_permission}'",
            )
        return api_key  # Return the validated key (or token info if needed later)

    return _check_permission


# --- FastAPI Application Factory ---
# Global variable to hold the registry instance *within this process*
# This is okay because FastAPI runs in a single process (or multiple workers,
# each getting its own instance, which is fine as they all talk to the DB).
_registry_instance: Optional[ServiceRegistry] = None


def setup_registry(db_path: str, hmac_key: bytes):
    """Create the process-wide ``ServiceRegistry`` if it does not exist yet."""
    global _registry_instance
    if _registry_instance is None:
        print(f"[{os.getpid()}] Initializing ServiceRegistry for FastAPI...")
        _registry_instance = ServiceRegistry(db_path, hmac_key)
    else:
        print(f"[{os.getpid()}] ServiceRegistry already initialized for FastAPI.")


def get_service_registry() -> ServiceRegistry:
    """Dependency returning the process-wide ``ServiceRegistry``.

    Raises:
        RuntimeError: If ``setup_registry`` has not been called in this process.
    """
    if _registry_instance is None:
        # This should not happen if setup_registry is called correctly before app starts
        raise RuntimeError("ServiceRegistry not initialized for this process.")
    return _registry_instance


def generate_status_page_html(registry: ServiceRegistry) -> str:
    """Build the text of the public status page.

    Each distinct instance address is treated as a "node". A node is "Active"
    when at least one instance at that address has health 'passing',
    otherwise "Inactive". Nodes are labelled "Node 1", "Node 2", ... in sorted
    address order.

    NOTE(review): the string literals below contain no HTML tags, and
    ``grid_size`` / ``status_class`` are computed but never interpolated. The
    markup appears to have been stripped from the text under review; confirm
    against the canonical file before relying on these literals.
    """
    all_services_map = registry.get_all_services()
    all_instances = [
        inst for instances in all_services_map.values() for inst in instances
    ]

    # 1. Identify unique "nodes" (addresses) and determine their status
    unique_addresses = sorted({inst.address for inst in all_instances})
    passing_addresses = {
        inst.address for inst in all_instances if inst.health == "passing"
    }
    # {address: {"status": "Active|Inactive", "label": "Node N"}}
    nodes = {
        addr: {
            "status": "Active" if addr in passing_addresses else "Inactive",
            "label": f"Node {index}",
        }
        for index, addr in enumerate(unique_addresses, start=1)
    }
    grid_size = len(nodes)

    # 2. Build the page piece by piece and join once at the end
    parts = [" MiniDiscovery Status\n\nNode Status\n\n"]
    if grid_size == 0:
        parts.append("\nNo active nodes detected.\n")
    else:
        parts.append("\n")  # Set CSS variable size
        parts.append("\n")  # Top-left corner
        # Column headers
        for addr in unique_addresses:
            parts.append(f"\n{nodes[addr]['label']}\n")
        # Grid rows
        for row_addr in unique_addresses:
            # Row label
            parts.append(f"\n{nodes[row_addr]['label']}\n")
            # Cells in the row
            for col_addr in unique_addresses:
                if row_addr == col_addr:
                    # Diagonal: Show node's own status
                    node_status = nodes[row_addr]["status"]
                    status_class = f"status-{node_status.lower()}"  # see NOTE(review)
                    parts.append(f"\n{node_status}\n")
                else:
                    # Off-diagonal: Show 'Unknown' as we don't track inter-node connectivity
                    parts.append("\nUnknown\n")
            # End of row implicitly handled by grid layout
        parts.append("\n")  # End grid-container

    # Add footer/closing tags
    parts.append(
        "\nStatus based on registered service health. Page refreshes automatically.\n"
    )
    return "".join(parts)


def create_fastapi_app() -> FastAPI:
    """Creates the FastAPI application."""
    app = FastAPI(
        title="MiniDiscovery",
        description="Minimal Service Discovery inspired by Consul",
        version="0.2.0",
    )

    # --- Status Page Endpoint (Public) ---
    @app.get("/status", response_class=HTMLResponse)
    async def get_status_page(
        registry: ServiceRegistry = Depends(get_service_registry),
    ):
        """Serves the anonymous HTML status page."""
        html_content = generate_status_page_html(registry)
        return HTMLResponse(content=html_content, status_code=200)

    # --- API Endpoints (Authenticated) ---
    @app.post("/v1/agent/service/register", status_code=status.HTTP_200_OK)
    async def register_service(
        instance: ServiceInstance,
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("write")),  # Require 'write'
    ):
        """Registers a service instance. Supports replace-if-exists based on ID."""
        if registry.register(instance):
            return {"status": "registered", "service_id": instance.id}
        # This path might be less likely with INSERT OR REPLACE
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to register service due to a database error.",
        )

    @app.put(
        "/v1/agent/service/deregister/{service_id}", status_code=status.HTTP_200_OK
    )
    async def deregister_service(
        service_id: str,
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("write")),  # Require 'write'
    ):
        """Deregisters a service instance by its ID."""
        if registry.deregister(service_id):
            return {"status": "deregistered", "service_id": service_id}
        # If deregister returns false, it means the service wasn't found
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Service ID '{service_id}' not found.",
        )

    @app.get("/v1/catalog/services", response_model=Dict[str, List[str]])
    async def get_catalog_services(
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("read")),  # Require 'read'
    ):
        """Lists all known service names and their tags."""
        services_data = registry.get_all_services()
        # Format matches Consul /v1/catalog/services
        catalog = {
            name: list(set(tag for inst in instances for tag in inst.tags))
            for name, instances in services_data.items()
        }
        return catalog

    @app.get("/v1/catalog/service/{service_name}", response_model=List[ServiceInstance])
    async def get_catalog_service(
        service_name: str,
        tag: Optional[str] = None,  # Allow filtering by tag
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("read")),  # Require 'read'
    ):
        """Lists instances for a specific service, optionally filtered by tag."""
        instances = registry.get_service(service_name)
        if not instances:
            # Return empty list if service name doesn't exist, matches Consul
            return []
        if tag:
            instances = [inst for inst in instances if tag in inst.tags]
        return instances

    @app.get("/v1/health/service/{service_name}", response_model=List[ServiceInstance])
    async def get_health_service(
        service_name: str,
        tag: Optional[str] = None,
        passing: bool = False,  # Filter for only passing instances
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("read")),  # Require 'read'
    ):
        """Lists healthy instances for a specific service."""
        instances = registry.get_service(service_name, only_passing=passing)
        if not instances:
            return []
        if tag:
            instances = [inst for inst in instances if tag in inst.tags]
        return instances

    # --- Token Management Endpoints ---
    @app.post(
        "/v1/acl/token",
        response_model=ApiTokenCreateResponse,
        status_code=status.HTTP_201_CREATED,
    )
    async def create_acl_token(
        token_request: ApiTokenCreateRequest,
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("admin")),  # Require 'admin'
    ):
        """Creates a new API token."""
        plain_token, error = registry.create_token(
            token_request.name, token_request.permissions
        )
        if error:
            # Distinguish between user error (duplicate name) and server error
            if "already exists" in error:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST, detail=error
                )
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=error
            )
        return ApiTokenCreateResponse(
            token=plain_token,  # Only show the token once upon creation!
            name=token_request.name,
            permissions=token_request.permissions,
        )

    @app.delete("/v1/acl/token/{token_to_revoke}", status_code=status.HTTP_200_OK)
    async def revoke_acl_token(
        token_to_revoke: str,  # The actual token value to revoke
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("admin")),  # Require 'admin'
    ):
        """Revokes an existing API token."""
        # NOTE: revoking the token used for this very request is currently allowed.
        if registry.revoke_token(token_to_revoke):
            # Don't echo back the token
            return {
                "status": "revoked",
                "token_info": "Token revoked successfully",
            }
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Token not found or already revoked.",
        )

    # --- DNS-over-HTTPS Endpoint (simplified JSON answers) ---
    @app.get("/dns-query", status_code=status.HTTP_200_OK)
    async def dns_query(
        name: str,
        type: str = "A",  # Query param 'type'
        registry: ServiceRegistry = Depends(get_service_registry),
        # No auth required for basic DNS queries for simplicity, add if needed
    ):
        """Handles DNS queries over HTTPS (DoH)."""
        # Basic parsing, assuming name ends with configured suffix
        if not name.lower().endswith(DNS_QUERY_SUFFIX):
            return {"Status": DNS_RCODE_NXDOMAIN}

        # The service name is the last label before the suffix
        service_name = name[: -len(DNS_QUERY_SUFFIX)].split(".")[-1]
        instances = registry.get_service(service_name, only_passing=True)
        if not instances:
            return {"Status": DNS_RCODE_NXDOMAIN}

        query_type_upper = type.upper()
        if query_type_upper == "A":
            type_code = 1  # DNS Type A
            answers = [
                {
                    "name": name,
                    "type": type_code,
                    "TTL": DNS_DEFAULT_TTL,
                    "data": instance.address,
                }
                for instance in instances
            ]
        elif query_type_upper == "SRV":
            type_code = 33  # DNS Type SRV
            # Data format: Priority Weight Port Target
            # Target should be resolvable ideally
            answers = [
                {
                    "name": name,
                    "type": type_code,
                    "TTL": DNS_DEFAULT_TTL,
                    "data": f"0 5 {instance.port} {instance.address}{DNS_QUERY_SUFFIX}",
                }
                for instance in instances
            ]
        # Add more record types (AAAA, TXT for metadata?) if needed
        else:
            # Unsupported type for this simple endpoint
            return {"Status": DNS_RCODE_NOTIMP}

        return {
            "Status": DNS_RCODE_NOERROR,
            "TC": False,  # Truncated
            "RD": True,  # Recursion Desired (client asked for it) - reflects typical client flag
            "RA": False,  # Recursion Available (we are authoritative-like)
            "AD": False,  # Authenticated Data
            "CD": False,  # Checking Disabled
            "Question": [{"name": name, "type": type_code}],  # Reflect question
            "Answer": answers,
            # Authority and Additional sections omitted for simplicity
        }

    return app


# --- Twisted DNS Server ---
class MiniDiscoveryResolver(common.ResolverBase):
    """Twisted resolver answering A, SRV and TXT queries from the registry."""

    def __init__(self, registry: ServiceRegistry):
        self.registry = registry
        super().__init__()

    def _lookup(
        self,
        name: bytes,
        cls: int,
        query_type: int,
        timeout: Optional[Tuple[int]] = None,
    ) -> Deferred:
        """Main lookup entry point.

        Decodes the name, checks the suffix, and dispatches to the handler for
        ``query_type``. The returned Deferred is always fired: with an
        ``(answers, authority, additional)`` tuple (all empty when the query
        is not ours or unsupported), or with a Failure if a handler raises.
        """
        d = Deferred()
        name_str_debug = repr(name)  # For logging errors
        try:
            # 1. Decode the incoming query name safely
            try:
                name_str = name.decode("utf-8").lower()
            except UnicodeDecodeError:
                print(f"DNS: Cannot decode query name {name_str_debug} as UTF-8.")
                # Fire with an empty result so the query does not hang.
                d.callback(([], [], []))
                return d

            # 2. Check for the expected suffix
            if not name_str.endswith(DNS_QUERY_SUFFIX):
                # Not a query for our domain, return empty.
                d.callback(([], [], []))
                return d

            # 3. Extract the base name (part before the suffix)
            base_name = name_str[: -len(DNS_QUERY_SUFFIX)]
            if not base_name:  # Query was just the suffix itself
                d.callback(([], [], []))
                return d

            # 4. Dispatch based on query type (add dns.AAAA here if needed)
            handlers = {
                dns.A: self._handle_a_query,
                dns.SRV: self._handle_srv_query,
                dns.TXT: self._handle_txt_query,
            }
            handler = handlers.get(query_type)
            if handler is None:
                # Unsupported query type for our resolver
                d.callback(([], [], []))
                return d

            # 5. The handlers are synchronous, so call directly and fire
            answers, authority, additional = handler(name, base_name, cls)
            d.callback((answers, authority, additional))
        except Exception as e:
            # Catch-all for unexpected errors during dispatch or handler execution
            print(
                f"!!! Unhandled exception during DNS lookup for {name_str_debug} "
                f"(Type: {query_type}) !!!"
            )
            print(f"Exception Type: {type(e).__name__}")
            print(f"Exception Args: {e.args}")
            print("--- Traceback ---")
            traceback.print_exc()
            print("--- End Traceback ---")
            # Signal DNS server failure (SERVFAIL)
            d.errback(Failure(e))
        return d

    # --- Helper Methods for Specific Record Types ---
    def _parse_srv_query(self, base_name: str) -> Tuple[Optional[str], Optional[str]]:
        """Parse SRV-style names: ``_tag._proto.service`` or ``_tag._proto``.

        Returns:
            ``(tag, service_name)``; ``service_name`` is None when no third
            label is present. ``(None, None)`` if the name is not SRV-style.
        """
        parts = base_name.split(".")
        if (
            len(parts) >= 2
            and parts[0].startswith("_")
            and parts[1] in ["_tcp", "_udp"]
        ):
            tag = parts[0][1:]  # Remove leading '_'
            # Service name is the part *after* _tag._proto, if it exists
            service_name = parts[2] if len(parts) > 2 else None
            # We currently ignore the rest of the parts (like datacenter in consul)
            return tag, service_name
        return None, None  # Not an SRV-style query name

    def _get_instances_for_query(
        self, base_name: str, is_srv_query: bool = False
    ) -> List[ServiceInstance]:
        """Fetch the passing instances that match the query name.

        For SRV queries the name is parsed as ``_tag._proto[.service]``. For
        other queries the last label of ``base_name`` is the service name.
        """
        instances = []
        tag_filter = None
        service_name_filter = None

        if is_srv_query:
            tag_filter, service_name_filter = self._parse_srv_query(base_name)
            if tag_filter is None:  # Not a valid _tag._proto... query
                return []  # Return empty, SRV query handler expects specific format
        else:
            # A or TXT query: service name is the last part
            service_name_filter = base_name.split(".")[-1]

        if service_name_filter:
            # Query targets a specific service (with potential tag filter for SRV)
            service_instances = self.registry.get_service(
                service_name_filter, only_passing=True
            )
            if tag_filter:
                # SRV query with tag and service
                instances = [
                    inst for inst in service_instances if tag_filter in inst.tags
                ]
            else:
                # A/TXT query, or SRV query without tag (using service name)
                instances = service_instances  # Already filtered for passing
        elif tag_filter:
            # SRV query for a tag across all services
            all_services = self.registry.get_all_services()
            for service_instances in all_services.values():
                for inst in service_instances:
                    if inst.health == "passing" and tag_filter in inst.tags:
                        instances.append(inst)
        else:
            # This case shouldn't be reached if initial checks are correct
            # (e.g., A/TXT query needs a service name part)
            print(f"Warning: Could not determine filter for query '{base_name}'")

        print(
            f"DNS Lookup: base_name='{base_name}', is_srv={is_srv_query}, "
            f"tag='{tag_filter}', service='{service_name_filter}'. Found {len(instances)} instances."
        )
        return instances

    def _handle_a_query(
        self, name: bytes, base_name: str, cls: int
    ) -> Tuple[List, List, List]:
        """Handle A lookups: one A record per matching passing instance."""
        answers = []
        instances = self._get_instances_for_query(base_name, is_srv_query=False)
        for instance in instances:
            try:
                # Twisted's Record_A expects the IP address string
                payload = dns.Record_A(address=instance.address, ttl=DNS_DEFAULT_TTL)
                rr = dns.RRHeader(
                    name=name,  # Respond with the original query name
                    type=dns.A,
                    cls=cls,
                    ttl=DNS_DEFAULT_TTL,
                    payload=payload,
                )
                answers.append(rr)
            except Exception as e:
                print(
                    f"Warning: Error creating A record for instance {instance.id} "
                    f"(IP: {instance.address}): {e}. Skipping."
                )
        return answers, [], []  # No authority or additional records for basic A

    def _handle_srv_query(
        self, name: bytes, base_name: str, cls: int
    ) -> Tuple[List, List, List]:
        """Handle SRV lookups (service or tag based).

        Each instance yields an SRV answer whose target is
        ``<instance id><suffix>``, plus a matching A record in the additional
        section. Only ``_tag._proto[.service]`` names are recognised; a plain
        ``service<suffix>`` SRV query returns no records.
        """
        answers = []
        additional = []
        instances = self._get_instances_for_query(base_name, is_srv_query=True)
        for instance in instances:
            try:
                # Using just instance ID + suffix as the target is simple and unique.
                target_name_str = f"{instance.id}{DNS_QUERY_SUFFIX}"
                target_name_bytes = target_name_str.encode("utf-8")
                srv_payload = dns.Record_SRV(
                    priority=0,  # Lower is more preferred
                    weight=10,  # Relative weight for same priority
                    port=instance.port,
                    target=target_name_bytes,  # Must be bytes
                    ttl=DNS_DEFAULT_TTL,  # TTL for the SRV record itself
                )
                srv_rr = dns.RRHeader(
                    name=name,  # Respond with the original query name
                    type=dns.SRV,
                    cls=cls,
                    ttl=DNS_DEFAULT_TTL,
                    payload=srv_payload,
                )
                answers.append(srv_rr)

                # Add corresponding A record for the target in the additional section
                a_payload = dns.Record_A(address=instance.address, ttl=DNS_DEFAULT_TTL)
                a_rr = dns.RRHeader(
                    name=target_name_bytes,  # Name matches SRV target
                    type=dns.A,
                    cls=cls,
                    ttl=DNS_DEFAULT_TTL,  # TTL for the additional A record
                    payload=a_payload,
                )
                additional.append(a_rr)
            except Exception as e:
                print(
                    f"Warning: Error creating SRV/A record for instance {instance.id} "
                    f"(Addr: {instance.address}:{instance.port}): {e}. Skipping."
                )
        return answers, [], additional

    def _handle_txt_query(
        self, name: bytes, base_name: str, cls: int
    ) -> Tuple[List, List, List]:
        """Handle TXT lookups, returning each piece of information as its own record.

        For every instance this emits one TXT record per tag, one
        ``key=value`` record per metadata entry, and one ``instance_id=<id>``
        record. Strings longer than 255 bytes are skipped.
        """
        answers = []
        instances = self._get_instances_for_query(base_name, is_srv_query=False)
        for instance in instances:
            instance_id_str = str(instance.id)
            print(f"DNS TXT: Processing instance {instance_id_str}")

            # Step 1: Collect all individual strings
            txt_strings = []
            if isinstance(instance.tags, list):
                txt_strings.extend([str(tag) for tag in instance.tags])
            else:
                print(
                    f"WARNING: Instance {instance_id_str} tags are not a list: {type(instance.tags)}"
                )
            if isinstance(instance.metadata, dict):
                txt_strings.extend(
                    [f"{str(k)}={str(v)}" for k, v in instance.metadata.items()]
                )
            else:
                print(
                    f"WARNING: Instance {instance_id_str} metadata is not a dict: {type(instance.metadata)}"
                )
            txt_strings.append(f"instance_id={instance_id_str}")

            # Step 2 & 3: Create a separate TXT record for each string
            for txt_string in txt_strings:
                try:
                    encoded_string = txt_string.encode("utf-8")
                    # Check length (TXT strings must be <= 255 bytes)
                    if len(encoded_string) > 255:
                        print(
                            f"WARNING: TXT string too long, skipping: {txt_string[:50]}..."
                        )
                        continue
                    txt_record = dns.Record_TXT(encoded_string, ttl=DNS_DEFAULT_TTL)
                    rr_header = dns.RRHeader(
                        name=name,
                        type=dns.TXT,
                        cls=cls,
                        ttl=DNS_DEFAULT_TTL,
                        payload=txt_record,
                    )
                    answers.append(rr_header)
                    print(f"DNS TXT: Added record: {txt_string}")
                except Exception as e:
                    print(f"ERROR: Failed to create TXT record for {txt_string}: {e}")

        print(f"DNS TXT: Generated {len(answers)} TXT records for '{base_name}'")
        return answers, [], []


# --- Health Checker ---
def check_service_health(instance: ServiceInstance) -> str:
    """Probe the instance with a TCP connect (2 second timeout).

    Returns:
        "passing" if the connection succeeds, otherwise "failing".
    """
    try:
        # Could add optional check for specific data/response here later
        with socket.create_connection((instance.address, instance.port), timeout=2):
            return "passing"
    except OSError:
        # Covers timeouts, refused connections and name-resolution failures.
        return "failing"
    except Exception as e:
        print(
            f"Unexpected error checking health for {instance.id} ({instance.address}:{instance.port}): {e}"
        )
        return "failing"  # Treat unexpected errors as failure


# --- Process Runner Functions ---
def run_fastapi_server(db_path: str, hmac_key: bytes, host: str, port: int):
    """Sets up and runs the FastAPI server."""
    print(f"[{os.getpid()}] Starting FastAPI process...")
    # Initialize registry for this specific process
    setup_registry(db_path, hmac_key)
    app = create_fastapi_app()
    # Setup signal handlers for graceful shutdown within Uvicorn if possible
    # Uvicorn handles SIGINT/SIGTERM by default
    uvicorn.run(app, host=host, port=port)
    print(
        f"[{os.getpid()}] FastAPI process finished."
    )  # Should not be reached normally


def run_dns_server(db_path: str, hmac_key: bytes, port: int):
    """Sets up and runs the Twisted DNS server."""
    print(f"[{os.getpid()}] Starting DNS server process...")
    registry = ServiceRegistry(
        db_path, hmac_key
    )  # DNS process needs its own registry instance
    resolver = MiniDiscoveryResolver(registry)
    factory = server.DNSServerFactory(
        clients=[resolver],
        # verbose=2 # Uncomment for very detailed Twisted DNS logging
    )
    protocol = dns.DNSDatagramProtocol(controller=factory)

    # Listen on UDP and TCP
    try:
        reactor.listenUDP(port, protocol)
        reactor.listenTCP(port, factory)
        print(f"[{os.getpid()}] DNS Server listening on port {port} (UDP/TCP)")
    except Exception as e:
        print(f"[{os.getpid()}] Error starting DNS listeners on port {port}: {e}")
        return  # Exit process if cannot bind

    # Graceful shutdown for Twisted reactor
    def shutdown_dns():
        print(f"[{os.getpid()}] Shutting down DNS server...")
        # reactor.stop() might be called by signal handlers below
        # Add cleanup here if needed (e.g., closing listeners explicitly)

    reactor.addSystemEventTrigger("before", "shutdown", shutdown_dns)

    # Handle signals to stop the reactor gracefully
    def signal_handler(signum, frame):
        print(f"[{os.getpid()}] Received signal {signum}, stopping DNS reactor.")
        # Important: Call reactor.stop() from the reactor thread if possible
        # reactor.callFromThread(reactor.stop) is safer if signals handled elsewhere
        # For simple cases, calling directly might be okay, but watch for deadlocks.
if reactor.running: reactor.callLater( 0, reactor.stop ) # Schedule stop in the next loop iteration signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) # Run the Twisted reactor (blocking call) reactor.run() print(f"[{os.getpid()}] DNS server process finished.") def run_health_checker(db_path: str, hmac_key: bytes, check_interval: int): """Runs the health checking loop.""" print( f"[{os.getpid()}] Starting Health Checker process (interval: {check_interval}s)..." ) registry = ServiceRegistry( db_path, hmac_key ) # Health checker needs its own registry instance running = True def signal_handler(signum, frame): nonlocal running print(f"[{os.getpid()}] Received signal {signum}, stopping health checker...") running = False signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) while running: start_time = time.monotonic() print(f"[{os.getpid()}] Health Check Cycle Start") updated_count = 0 checked_count = 0 try: # Fetch all services directly from DB for the check cycle services_map = registry.get_all_services() all_instances = [ inst for sublist in services_map.values() for inst in sublist ] checked_count = len(all_instances) for instance in all_instances: if not running: break # Exit early if shutdown signal received current_health = instance.health # Perform the actual health check new_health = check_service_health(instance) if current_health != new_health: print( f"[{os.getpid()}] Health change for {instance.id} ({instance.name}): {current_health} -> {new_health}" ) # Update health status in the database if registry.update_health(instance.id, new_health): updated_count += 1 else: # This might happen if the service was deregistered between get and update print( f"[{os.getpid()}] Warning: Failed to update health for {instance.id} (maybe deregistered?)" ) except Exception as e: # Catch errors during the check cycle itself (e.g., DB connection) print(f"[{os.getpid()}] Error during health check 
cycle: {e}") if not running: break # Check again after the loop body # Calculate sleep time to maintain interval elapsed = time.monotonic() - start_time sleep_time = max(0, check_interval - elapsed) print( f"[{os.getpid()}] Health Check Cycle End. Checked: {checked_count}, Updated: {updated_count}. Took {elapsed:.2f}s. Sleeping for {sleep_time:.2f}s." ) # Sleep interruptibly try: time.sleep(sleep_time) except InterruptedError: # Catch if signal interrupted sleep pass print(f"[{os.getpid()}] Health checker process finished.") # --- Main Execution --- def main(): parser = argparse.ArgumentParser( description="MiniDiscovery: A minimal service discovery tool.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, # Show defaults ) parser.add_argument( "--db-path", default="minidiscovery_data.db", help="Path to the SQLite database file. Overridden by ${DB_PATH_ENV_VAR} if set.", ) parser.add_argument( "--api-host", default="0.0.0.0", help="Host address for the API server." ) parser.add_argument( "--api-port", type=int, default=8500, help="Port for the API server." ) parser.add_argument( "--dns-port", type=int, default=10053, help="Port for the DNS server (UDP/TCP). 
Use ports > 1024 for non-root.", ) parser.add_argument( "--health-check-interval", type=int, default=15, help="Health check interval in seconds.", ) parser.add_argument( "--hmac-key-file", default=HMAC_KEY_FILE, help="Path to the file storing the HMAC secret key.", ) args = parser.parse_args() # Read db_file_path from ENV var db_path_from_env = os.environ.get(DB_PATH_ENV_VAR) if db_path_from_env: args.db_path = db_path_from_env # Override args print( f"Using database path from environment variable ${DB_PATH_ENV_VAR}: {args.db_path}" ) elif not args.db_path: # If env var not set AND command line arg not provided args.db_path = "minidiscovery_data.db" # Apply the default now print(f"Database path not specified, using default: {args.db_path}") else: # Using the path provided via --db-path argument print(f"Using database path from --db-path argument: {args.db_path}") # Ensure DB directory exists db_dir = os.path.dirname(os.path.abspath(args.db_path)) if db_dir: os.makedirs(db_dir, exist_ok=True) # Load or generate HMAC key *before* starting processes try: hmac_key = load_or_generate_hmac_key(args.hmac_key_file) except SystemExit: return # Exit if key generation/loading fails # Check for initial admin token env var *before* initializing registry in main process # The check inside ServiceRegistry._init_db is the authoritative one, # but this provides an earlier warning. conn_check = sqlite3.connect(args.db_path) cursor_check = conn_check.cursor() try: cursor_check.execute("SELECT COUNT(*) FROM tokens") token_count_check = cursor_check.fetchone()[0] if token_count_check == 0 and not os.environ.get(ADMIN_TOKEN_ENV_VAR): print( f"WARNING: Database appears empty. Ensure '{ADMIN_TOKEN_ENV_VAR}' is set for the first run." 
) except sqlite3.DatabaseError: # Table might not exist yet, ServiceRegistry init will handle it pass finally: conn_check.close() # --- Process Management --- processes: List[multiprocessing.Process] = [] stop_event = multiprocessing.Event() # Used to signal shutdown def graceful_shutdown(signum, frame): print(f"Main process received signal {signum}. Initiating shutdown...") stop_event.set() # Signal processes to stop (though they also have signal handlers) # Terminate processes forcefully after a grace period if they don't exit time.sleep(2) # Give processes a moment to react to their own signal handlers for p in processes: if p.is_alive(): print(f"Terminating process {p.pid} ({p.name})...") p.terminate() # Send SIGTERM signal.signal(signal.SIGINT, graceful_shutdown) signal.signal(signal.SIGTERM, graceful_shutdown) try: # Start API Server Process api_process = multiprocessing.Process( target=run_fastapi_server, args=(args.db_path, hmac_key, args.api_host, args.api_port), name="FastAPI Process", ) processes.append(api_process) api_process.start() # Start DNS Server Process dns_process = multiprocessing.Process( target=run_dns_server, args=(args.db_path, hmac_key, args.dns_port), name="DNS Process", ) processes.append(dns_process) dns_process.start() # Start Health Checker Process health_process = multiprocessing.Process( target=run_health_checker, args=(args.db_path, hmac_key, args.health_check_interval), name="Health Check Process", ) processes.append(health_process) health_process.start() print("-" * 30) print(f"MiniDiscovery Started (PID: {os.getpid()})") print(f" API Server: http://{args.api_host}:{args.api_port}") print(f" DNS Server: Port {args.dns_port} (UDP/TCP)") print(f" Database: {args.db_path}") print(f" HMAC Key: {args.hmac_key_file}") print(f" Health Int: {args.health_check_interval}s") print("-" * 30) print("Press Ctrl+C to stop.") # Wait for processes to complete or shutdown signal while not stop_event.is_set(): # Check if any process exited 
unexpectedly for p in processes: if not p.is_alive() and p.exitcode != 0: print( f"ERROR: Process {p.name} (PID: {p.pid}) exited unexpectedly with code {p.exitcode}." ) stop_event.set() # Trigger shutdown if a child crashes break time.sleep(1) # Wait efficiently finally: print("Waiting for processes to join...") for p in processes: p.join(timeout=5) # Wait for clean exit if p.is_alive(): print(f"Process {p.pid} ({p.name}) did not exit cleanly, killing.") p.kill() # Force kill if still alive after timeout print("MiniDiscovery shut down complete.") if __name__ == "__main__": main()