import argparse
import asyncio
import base64
import hashlib
import hmac
import json
import multiprocessing
import os
import secrets
import signal
import socket
import sqlite3
import time
import uvicorn
import traceback
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Generator, Any
from fastapi import Depends, FastAPI, Header, HTTPException, Security, status
from fastapi.security import APIKeyHeader
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field
from twisted.internet import reactor
from twisted.names import dns, server, common
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
# --- Constants ---
# Busy timeout (seconds) passed to sqlite3.connect(): how long a connection
# waits for another writer's lock before failing with "database is locked".
SQLITE_TIMEOUT_SECONDS = 10  # Increased timeout for potential concurrent writes
# Default path of the secret key used to HMAC API tokens before storage.
HMAC_KEY_FILE = "minidiscovery.key"
# Environment variable read only when the database has no tokens and is not yet
# marked initialized; its value becomes the initial "admin" token.
ADMIN_TOKEN_ENV_VAR = "MINIDISCOVERY_ADMIN_TOKEN"
# Permission sets. These are shared module-level lists: treat them as read-only.
DEFAULT_TOKEN_PERMISSIONS = ["read", "write"]
ADMIN_PERMISSIONS = ["read", "write", "admin"]
# TTL (seconds) attached to DNS answers.
DNS_DEFAULT_TTL = 60
DNS_QUERY_SUFFIX = ".laiska.local"  # Define a suffix for DNS lookups
# Environment variable naming the database path.
# NOTE(review): no consumer is visible in this section; presumably read by the
# CLI/entry point. Verify.
DB_PATH_ENV_VAR = "MINIDISCOVERY_DB_PATH"
# --- Database Schema ---
# One row per service instance, keyed by the instance id. Tags and metadata are
# serialized with json.dumps by ServiceRegistry.register.
SQL_CREATE_SERVICES = """
CREATE TABLE IF NOT EXISTS services (
    id TEXT PRIMARY KEY,
    name TEXT NOT NULL,
    address TEXT NOT NULL,
    port INTEGER NOT NULL,
    health TEXT DEFAULT 'passing',
    tags_json TEXT DEFAULT '[]',       -- Store tags as JSON text
    metadata_json TEXT DEFAULT '{}',   -- Store metadata as JSON text
    last_update REAL NOT NULL          -- Timestamp of last update/registration
);
"""
# API tokens. Only the HMAC of each token is stored, never the plain text.
SQL_CREATE_TOKENS = """
CREATE TABLE IF NOT EXISTS tokens (
    token_hash TEXT PRIMARY KEY,       -- Store the HMAC hash of the token
    name TEXT NOT NULL UNIQUE,         -- Token names should be unique
    created_at REAL NOT NULL,
    permissions_json TEXT NOT NULL     -- Store permissions as JSON text
);
"""
# Generic key/value table. Currently holds the 'initialized' flag written on
# first start-up.
SQL_CREATE_META = """
CREATE TABLE IF NOT EXISTS meta (
    key TEXT PRIMARY KEY,
    value TEXT NOT NULL
);
"""  # For storing things like schema version or first-run status
# --- Pydantic Models ---
class ServiceInstance(BaseModel):
    """One registered instance of a logical service.

    Used as the request body for service registration and as the response
    model for the catalog and health endpoints. ``health`` is a free-form
    string: the values listed in its description are a convention and are not
    validated. Elsewhere in this file only ``"passing"`` is treated specially
    (passing-only queries, DNS answers, and the status page).
    """
    # Primary key in the services table; registering the same id replaces the row.
    id: str = Field(..., description="Unique identifier for this service instance")
    name: str = Field(
        ..., description="Logical name of the service (e.g., 'web', 'db')"
    )
    address: str = Field(
        ..., description="IP address or hostname where the service listens"
    )
    # Bounded to the valid TCP/UDP port range 1..65535.
    port: int = Field(..., gt=0, lt=65536, description="Port number for the service")
    tags: List[str] = Field(
        default_factory=list, description="Optional list of tags for filtering"
    )
    health: str = Field(
        default="passing", description="Health status ('passing', 'failing', 'unknown')"
    )
    metadata: Dict[str, str] = Field(
        default_factory=dict, description="Optional key-value metadata"
    )
class TokenInfo(BaseModel):
    """Information about a token (excluding the hash).

    Returned by ServiceRegistry.validate_token for a known token. It never
    contains the plain token or its HMAC.
    """
    name: str  # token name; unique across the tokens table
    created_at: float  # creation time as a time.time() epoch timestamp
    permissions: List[str]  # e.g. "read", "write", "admin"
class ApiTokenCreateRequest(BaseModel):
    """Request body for creating a new API token (POST /v1/acl/token)."""
    name: str = Field(..., description="A descriptive name for the token")
    # When omitted, defaults to DEFAULT_TOKEN_PERMISSIONS. Values are stored
    # as given; nothing here checks that they are known permission names.
    # NOTE(review): the default is the shared module-level list. pydantic is
    # expected to copy mutable defaults per instance; confirm for the pinned
    # pydantic version, or switch to default_factory.
    permissions: List[str] = Field(
        default=DEFAULT_TOKEN_PERMISSIONS, description="Permissions for the token"
    )
class ApiTokenCreateResponse(BaseModel):
    """Response to token creation: the only place the plain token is returned.

    The registry persists only the token's HMAC, so the token cannot be
    recovered later.
    """
    token: str = Field(..., description="The generated API token (show only once!)")
    name: str
    permissions: List[str]
# --- HMAC Key Management ---
def load_or_generate_hmac_key(key_file: str = HMAC_KEY_FILE) -> bytes:
    """Load the token-hashing HMAC key from *key_file*, creating it if absent.

    A new key is 32 bytes from :func:`secrets.token_bytes`. It is written to a
    file created exclusively (``O_CREAT | O_EXCL``) with mode ``0o600`` and
    fsync'ed, because losing the key invalidates every stored token hash. If
    the write fails part-way, the half-written file is removed. Otherwise it
    would be rejected as "too short" on every later start.

    Args:
        key_file: Path of the key file.

    Returns:
        The raw key bytes (at least 32 bytes long).

    Raises:
        SystemExit: with code 1 when an existing file cannot be read or is
            shorter than 32 bytes, or when a new key cannot be written.
    """
    if os.path.exists(key_file):
        try:
            with open(key_file, "rb") as f:
                key = f.read()
            if len(key) < 32:  # Basic sanity check
                raise ValueError("HMAC key file seems corrupted or too short.")
        except Exception as e:
            print(
                f"Error loading HMAC key from {key_file}: {e}. Check file permissions and content."
            )
            raise SystemExit(1) from e
        print(f"Loaded HMAC key from {key_file}")
        return key

    print(f"HMAC key file ({key_file}) not found. Generating a new one.")
    key = secrets.token_bytes(32)  # 256 bits
    try:
        # O_EXCL: fail rather than overwrite if another process created it first.
        fd = os.open(key_file, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
    except FileExistsError:
        # Race condition: another process created it between check and open
        return load_or_generate_hmac_key(key_file)
    except OSError as e:
        print(f"Error writing HMAC key file {key_file}: {e}")
        print("Please ensure the directory is writable by the process.")
        raise SystemExit(1) from e
    except Exception as e:
        print(f"Unexpected error generating HMAC key: {e}")
        raise SystemExit(1) from e

    # From here on the file exists; any failure must not leave it half-written.
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(key)
            f.flush()
            os.fsync(f.fileno())  # make the key durable before tokens depend on it
    except OSError as e:
        _discard_partial_key_file(key_file)
        print(f"Error writing HMAC key file {key_file}: {e}")
        print("Please ensure the directory is writable by the process.")
        raise SystemExit(1) from e
    except Exception as e:
        _discard_partial_key_file(key_file)
        print(f"Unexpected error generating HMAC key: {e}")
        raise SystemExit(1) from e
    print(f"Generated and saved new HMAC key to {key_file}. PROTECT THIS FILE!")
    return key


def _discard_partial_key_file(key_file: str) -> None:
    """Best-effort removal of a key file whose write did not complete."""
    try:
        os.remove(key_file)
    except OSError:
        # Deliberately ignored: the caller is about to exit with its own error.
        pass
def generate_token_hash(token: str, hmac_key: bytes) -> str:
    """Return the hex-encoded HMAC-SHA256 of *token* keyed with *hmac_key*.

    The token is UTF-8 encoded before hashing. Only this digest is stored in
    the database, so tokens can only be verified with the same key.
    """
    mac = hmac.new(hmac_key, digestmod=hashlib.sha256)
    mac.update(token.encode("utf-8"))
    return mac.hexdigest()
# --- Service Registry (SQLite Backend) ---
class ServiceRegistry:
    """SQLite-backed store for service instances and hashed API tokens.

    Every public method opens its own short-lived connection via
    ``_get_db_conn`` (commit on success, rollback on ``sqlite3.Error``). An
    instance therefore holds no open connection, only the database path and
    the HMAC key used to hash tokens. Plain-text tokens are never stored.
    """

    # A "database is locked" error during start-up is retried this many times,
    # with this delay between attempts, before giving up. (Previously this was
    # an unbounded recursive retry.)
    _INIT_DB_MAX_ATTEMPTS = 10
    _INIT_DB_RETRY_DELAY_SECONDS = 0.5

    def __init__(self, db_path: str, hmac_key: bytes):
        """Store configuration and ensure the schema exists.

        Args:
            db_path: Path of the SQLite database file.
            hmac_key: Secret used to HMAC tokens (see generate_token_hash).

        Raises:
            SystemExit: if the database cannot be initialised (see _init_db).
        """
        self.db_path = db_path
        self._hmac_key = hmac_key
        self._init_db()

    @contextmanager
    def _get_db_conn(self) -> Generator[sqlite3.Connection, None, None]:
        """Yield a connection and commit when the ``with`` body finishes.

        Rows come back as ``sqlite3.Row``, so columns are addressable by name.
        On ``sqlite3.Error`` the error is logged, the transaction is rolled
        back, and the error is re-raised. The connection is always closed.
        """
        conn = None
        try:
            conn = sqlite3.connect(self.db_path, timeout=SQLITE_TIMEOUT_SECONDS)
            conn.row_factory = sqlite3.Row  # Access columns by name
            conn.execute("PRAGMA foreign_keys = ON;")  # Enforce foreign keys if needed
            # Write-Ahead Logging lets readers proceed while a writer is active.
            conn.execute("PRAGMA journal_mode = WAL;")
            yield conn
            conn.commit()
        except sqlite3.Error as e:
            print(f"SQLite error: {e} - {self.db_path}")
            if conn is not None:
                conn.rollback()  # Rollback on error
            raise  # Re-raise the exception
        finally:
            if conn is not None:
                conn.close()

    def _init_db(self) -> None:
        """Create the schema and seed the initial admin token, retrying on lock.

        Raises:
            SystemExit: (code 1) on any unrecoverable error. This includes the
                database still being locked after ``_INIT_DB_MAX_ATTEMPTS``
                attempts, and an empty database without ADMIN_TOKEN_ENV_VAR set.
        """
        for attempt in range(1, self._INIT_DB_MAX_ATTEMPTS + 1):
            try:
                self._init_db_once()
                return
            except sqlite3.OperationalError as e:
                retryable = "database is locked" in str(e)
                if retryable and attempt < self._INIT_DB_MAX_ATTEMPTS:
                    print(
                        f"Warning: Database {self.db_path} is locked during initialization. Retrying shortly..."
                    )
                    time.sleep(self._INIT_DB_RETRY_DELAY_SECONDS)
                    continue
                print(f"Fatal error initializing database {self.db_path}: {e}")
                raise SystemExit(1) from e
            except Exception as e:
                print(f"Fatal error during database initialization: {e}")
                raise SystemExit(1) from e

    def _init_db_once(self) -> None:
        """Run a single initialisation attempt (schema + first-run admin token)."""
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute(SQL_CREATE_SERVICES)
            cursor.execute(SQL_CREATE_TOKENS)
            cursor.execute(SQL_CREATE_META)
            # Check if initial admin token needs to be created
            cursor.execute("SELECT COUNT(*) FROM tokens")
            token_count = cursor.fetchone()[0]
            cursor.execute("SELECT value FROM meta WHERE key = 'initialized'")
            initialized = cursor.fetchone()
            if token_count == 0 and not initialized:
                print("No tokens found in the database.")
                admin_token_plain = os.environ.get(ADMIN_TOKEN_ENV_VAR)
                if not admin_token_plain:
                    print(
                        f"ERROR: Database is empty and initial admin token not provided."
                    )
                    print(
                        f"Please set the '{ADMIN_TOKEN_ENV_VAR}' environment variable with a secure token."
                    )
                    raise SystemExit(1)  # Critical configuration missing
                print(f"Creating initial admin token from '{ADMIN_TOKEN_ENV_VAR}'...")
                token_hash = generate_token_hash(admin_token_plain, self._hmac_key)
                # OR IGNORE: if another process starting concurrently already
                # seeded the admin token, the conflict is harmless. Previously it
                # raised IntegrityError and aborted this process.
                cursor.execute(
                    "INSERT OR IGNORE INTO tokens (token_hash, name, created_at, permissions_json) VALUES (?, ?, ?, ?)",
                    (
                        token_hash,
                        "admin",
                        time.time(),
                        json.dumps(ADMIN_PERMISSIONS),
                    ),
                )
                inserted = cursor.rowcount > 0
                # Mark DB as initialized to prevent asking for env var again
                cursor.execute(
                    "INSERT OR REPLACE INTO meta (key, value) VALUES ('initialized', 'true')"
                )
                if inserted:
                    print(
                        "Initial admin token created successfully. You can unset the environment variable now."
                    )
                else:
                    print("Initial admin token was already created by another process.")
            elif token_count > 0 and not initialized:
                # Tokens exist, but not marked initialized (e.g., older version) - mark it now
                cursor.execute(
                    "INSERT OR REPLACE INTO meta (key, value) VALUES ('initialized', 'true')"
                )
            # Future schema migrations could go here based on a version in 'meta' table

    @staticmethod
    def _row_to_instance(row: sqlite3.Row) -> Optional[ServiceInstance]:
        """Convert a ``services`` row into a ServiceInstance, or None if corrupt.

        Corrupt rows (bad JSON, values the model rejects, missing columns) are
        logged and skipped, so one bad row does not fail the whole query.
        """
        try:
            return ServiceInstance(
                id=row["id"],
                name=row["name"],
                address=row["address"],
                port=row["port"],
                health=row["health"],
                tags=json.loads(row["tags_json"]),
                metadata=json.loads(row["metadata_json"]),
            )
        except (ValueError, TypeError, KeyError, IndexError) as e:
            # ValueError covers json.JSONDecodeError and pydantic's
            # ValidationError. sqlite3.Row has no .get() (the old handler
            # raised AttributeError here), so look the id up explicitly.
            row_id = row["id"] if "id" in row.keys() else "N/A"
            print(f"Warning: Could not parse service data for ID {row_id}: {e}")
            return None

    def register(self, instance: ServiceInstance) -> bool:
        """Insert or replace the instance keyed by ``instance.id``.

        ``last_update`` is set to the current time. Returns True on success,
        or False on IntegrityError (not expected with INSERT OR REPLACE).
        """
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    """
                    INSERT OR REPLACE INTO services
                    (id, name, address, port, health, tags_json, metadata_json, last_update)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        instance.id,
                        instance.name,
                        instance.address,
                        instance.port,
                        instance.health,
                        json.dumps(instance.tags),
                        json.dumps(instance.metadata),
                        time.time(),
                    ),
                )
                return True
            except sqlite3.IntegrityError as e:
                print(f"Error registering service {instance.id}: {e}")
                return False  # Should not happen with INSERT OR REPLACE unless DB issue

    def deregister(self, service_id: str) -> bool:
        """Delete the instance with ``service_id``; True if a row was removed."""
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM services WHERE id = ?", (service_id,))
            return cursor.rowcount > 0  # Returns true if a row was deleted

    def get_service(
        self, name: str, only_passing: bool = False
    ) -> List[ServiceInstance]:
        """Return all parseable instances named ``name``.

        Args:
            name: Logical service name (exact, case-sensitive SQL match).
            only_passing: If True, keep only instances whose health is 'passing'.
        """
        sql = "SELECT * FROM services WHERE name = ?"
        if only_passing:
            sql += " AND health = 'passing'"
        with self._get_db_conn() as conn:
            rows = conn.execute(sql, (name,)).fetchall()
            parsed = (self._row_to_instance(row) for row in rows)
            return [instance for instance in parsed if instance is not None]

    def get_all_services(self) -> Dict[str, List[ServiceInstance]]:
        """Return every parseable instance, grouped by service name (sorted by name)."""
        services: Dict[str, List[ServiceInstance]] = {}
        with self._get_db_conn() as conn:
            rows = conn.execute("SELECT * FROM services ORDER BY name").fetchall()
            for row in rows:
                instance = self._row_to_instance(row)
                if instance is not None:
                    services.setdefault(instance.name, []).append(instance)
        return services

    def update_health(self, service_id: str, health: str) -> bool:
        """Set ``health`` (and bump ``last_update``); True if the instance exists."""
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute(
                "UPDATE services SET health = ?, last_update = ? WHERE id = ?",
                (health, time.time(), service_id),
            )
            return cursor.rowcount > 0

    def create_token(
        self, name: str, permissions: List[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        """Creates a new API token, storing its hash. Returns (token, None) on success, (None, error_message) on failure."""
        token_plain = secrets.token_urlsafe(32)
        token_hash = generate_token_hash(token_plain, self._hmac_key)
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    "INSERT INTO tokens (token_hash, name, created_at, permissions_json) VALUES (?, ?, ?, ?)",
                    (token_hash, name, time.time(), json.dumps(permissions)),
                )
                return token_plain, None  # Return the plain token only on success
            except sqlite3.IntegrityError:
                # This likely means the name is not unique
                return None, f"Token name '{name}' already exists."
            except Exception as e:
                print(f"Error creating token '{name}': {e}")
                return None, "Database error during token creation."

    def revoke_token(self, token_plain: str) -> bool:
        """Delete the stored hash of ``token_plain``; True if it existed."""
        token_hash = generate_token_hash(token_plain, self._hmac_key)
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM tokens WHERE token_hash = ?", (token_hash,))
            return cursor.rowcount > 0

    def validate_token(self, token_plain: str) -> Optional[TokenInfo]:
        """
        Validates a plain text token against stored hashes and returns its info if valid.
        Returns None if the token is unknown or its stored permissions are malformed.
        """
        token_hash = generate_token_hash(token_plain, self._hmac_key)
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT name, created_at, permissions_json FROM tokens WHERE token_hash = ?",
                (token_hash,),
            )
            row = cursor.fetchone()
            if row is None:
                return None  # Token hash not found
            try:
                permissions = json.loads(row["permissions_json"])
                if not isinstance(permissions, list):
                    raise TypeError("Permissions format error")
                return TokenInfo(
                    name=row["name"],
                    created_at=row["created_at"],
                    permissions=permissions,
                )
            except (json.JSONDecodeError, TypeError) as e:
                print(f"Error decoding permissions for token hash {token_hash}: {e}")
                return None  # Treat malformed permissions as invalid
# --- API Key Security ---
# Reads the API token from the "X-API-Token" request header. With
# auto_error=False a missing header yields None instead of an automatic error
# response, which lets get_permission_checker return its own 401 message.
api_key_header = APIKeyHeader(name="X-API-Token", auto_error=False)
def get_permission_checker(required_permission: str):
    """Build a FastAPI dependency that enforces ``required_permission``.

    The returned coroutine reads the token from the ``X-API-Token`` header and:

    * raises 401 when the header is absent,
    * raises 403 when the registry does not recognise the token,
    * raises 403 when the token lacks ``required_permission``,

    and otherwise returns the validated plain token string.
    """

    def _registry_dependency():
        # Look the registry up at request time through the module accessor.
        return get_service_registry()

    async def _check_permission(
        registry: ServiceRegistry = Depends(_registry_dependency),
        api_key: Optional[str] = Security(api_key_header),
    ) -> str:
        if api_key is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="API token missing in X-API-Token header",
            )
        granted = registry.validate_token(api_key)
        if granted is None:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Invalid API token",
            )
        if required_permission in granted.permissions:
            return api_key
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Insufficient permissions. Required: '{required_permission}'",
        )

    return _check_permission
# --- FastAPI Application Factory ---
# Per-process ServiceRegistry singleton, set by setup_registry() and read by
# get_service_registry(). If the app runs in several worker processes, each
# process builds its own instance. That is safe because an instance only holds
# the DB path and HMAC key, and all instances share the same SQLite database.
_registry_instance: Optional[ServiceRegistry] = None
def setup_registry(db_path: str, hmac_key: bytes):
    """Create this process's ServiceRegistry singleton on first call.

    Later calls only log a message. The existing instance, with the
    ``db_path``/``hmac_key`` it was created with, is kept.
    """
    global _registry_instance
    pid = os.getpid()
    if _registry_instance is not None:
        print(f"[{pid}] ServiceRegistry already initialized for FastAPI.")
        return
    print(f"[{pid}] Initializing ServiceRegistry for FastAPI...")
    _registry_instance = ServiceRegistry(db_path, hmac_key)
def get_service_registry() -> ServiceRegistry:
    """FastAPI dependency returning this process's ServiceRegistry.

    Raises:
        RuntimeError: if setup_registry() has not run in this process.
    """
    registry = _registry_instance
    if registry is not None:
        return registry
    # Only reachable if the app is served before setup_registry() was called.
    raise RuntimeError("ServiceRegistry not initialized for this process.")
def generate_status_page_html(registry: "ServiceRegistry") -> str:
    """Render the public HTML status page.

    Every distinct address among the registered service instances is shown as
    a "node". A node is "Active" when at least one instance at that address
    reports health ``"passing"``, and "Inactive" otherwise. Nodes are laid out
    as an N x N grid whose diagonal shows each node's own status. Off-diagonal
    cells read "Unknown" because inter-node connectivity is not tracked.

    Addresses are supplied by whoever holds a ``write`` token, so they are
    HTML-escaped before being embedded in the page.

    Args:
        registry: Registry to read from; only ``get_all_services()`` is called.

    Returns:
        A complete, periodically self-refreshing HTML document.
    """
    import html  # only needed here, for escaping

    all_instances = [
        inst
        for instances in registry.get_all_services().values()
        for inst in instances
    ]
    unique_addresses = sorted({inst.address for inst in all_instances})
    # One pass over all instances, instead of rescanning them for every address.
    active_addresses = {
        inst.address for inst in all_instances if inst.health == "passing"
    }
    nodes = {
        addr: {
            "status": "Active" if addr in active_addresses else "Inactive",
            "label": f"Node {index}",
        }
        for index, addr in enumerate(unique_addresses, start=1)
    }

    head = """<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="utf-8">
  <!-- Reload periodically so the page follows registry changes. -->
  <meta http-equiv="refresh" content="10">
  <title>MiniDiscovery Status</title>
  <style>
    body { font-family: sans-serif; margin: 2em; }
    .grid { display: grid; gap: 4px; }
    .grid-cell { padding: 0.5em 1em; text-align: center; border: 1px solid #ccc; }
    .grid-header { font-weight: bold; background: #f0f0f0; }
    .status-active { background: #c8f7c5; }
    .status-inactive { background: #f7c5c5; }
    .status-unknown { background: #eeeeee; color: #777777; }
    .footer { margin-top: 1em; font-size: 0.9em; color: #555555; }
  </style>
</head>
<body>
  <h1>MiniDiscovery Status</h1>
"""
    parts = [head]

    def header_cell(addr: str) -> str:
        # Node label, with the (escaped) address as a hover tooltip.
        return (
            f'    <div class="grid-cell grid-header" title="{html.escape(addr)}">'
            f'{html.escape(nodes[addr]["label"])}</div>\n'
        )

    if not nodes:
        parts.append("  <p>No active nodes detected.</p>\n")
    else:
        # One label column plus one column per node. The count depends on the
        # data, so it is set as an inline style.
        columns = len(nodes) + 1
        parts.append(
            f'  <div class="grid" style="grid-template-columns: repeat({columns}, auto);">\n'
        )
        parts.append('    <div class="grid-cell grid-header"></div>\n')  # top-left corner
        parts.extend(header_cell(addr) for addr in unique_addresses)
        for row_addr in unique_addresses:
            parts.append(header_cell(row_addr))
            for col_addr in unique_addresses:
                if row_addr == col_addr:
                    node_status = nodes[row_addr]["status"]
                    parts.append(
                        f'    <div class="grid-cell status-{node_status.lower()}">'
                        f"{node_status}</div>\n"
                    )
                else:
                    # No inter-node connectivity data is tracked.
                    parts.append(
                        '    <div class="grid-cell status-unknown">Unknown</div>\n'
                    )
        parts.append("  </div>\n")

    # Footer and closing tags are emitted in both branches.
    parts.append(
        '  <p class="footer">Status based on registered service health. '
        "Page refreshes automatically.</p>\n"
        "</body>\n"
        "</html>\n"
    )
    return "".join(parts)
def create_fastapi_app() -> FastAPI:
    """Create the MiniDiscovery FastAPI application.

    Routes:
      * ``GET /status``: public HTML status page.
      * ``POST /v1/agent/service/register`` and
        ``PUT /v1/agent/service/deregister/{service_id}``: require ``write``.
      * ``GET /v1/catalog/services``, ``GET /v1/catalog/service/{name}`` and
        ``GET /v1/health/service/{name}``: require ``read``.
      * ``POST /v1/acl/token`` and ``DELETE /v1/acl/token/{token}``: require
        ``admin``.
      * ``GET /dns-query``: unauthenticated JSON DNS lookup.

    Handlers obtain the registry through ``get_service_registry``, so
    ``setup_registry`` must have run in this process beforehand.
    """
    app = FastAPI(
        title="MiniDiscovery",
        description="Minimal Service Discovery inspired by Consul",
        version="0.2.0",
    )

    # --- Status Page Endpoint (Public) ---
    @app.get("/status", response_class=HTMLResponse)
    async def get_status_page(
        registry: ServiceRegistry = Depends(get_service_registry),
    ):
        """Serves the anonymous HTML status page."""
        html_content = generate_status_page_html(registry)
        return HTMLResponse(content=html_content, status_code=200)

    # --- API Endpoints (Authenticated) ---
    @app.post("/v1/agent/service/register", status_code=status.HTTP_200_OK)
    async def register_service(
        instance: ServiceInstance,
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("write")),  # Require 'write'
    ):
        """Registers a service instance. Supports replace-if-exists based on ID."""
        # Pydantic has already validated the body against ServiceInstance.
        if registry.register(instance):
            return {"status": "registered", "service_id": instance.id}
        # This path might be less likely with INSERT OR REPLACE
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to register service due to a database error.",
        )

    @app.put(
        "/v1/agent/service/deregister/{service_id}", status_code=status.HTTP_200_OK
    )
    async def deregister_service(
        service_id: str,
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("write")),  # Require 'write'
    ):
        """Deregisters a service instance by its ID."""
        if registry.deregister(service_id):
            return {"status": "deregistered", "service_id": service_id}
        # deregister() returns False when no row matched the id.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Service ID '{service_id}' not found.",
        )

    @app.get("/v1/catalog/services", response_model=Dict[str, List[str]])
    async def get_catalog_services(
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("read")),  # Require 'read'
    ):
        """Lists all known service names and their tags."""
        services_data = registry.get_all_services()
        # Format matches Consul /v1/catalog/services. Tags are de-duplicated and
        # sorted so the response is deterministic (a bare set has no stable order).
        return {
            name: sorted({tag for inst in instances for tag in inst.tags})
            for name, instances in services_data.items()
        }

    @app.get("/v1/catalog/service/{service_name}", response_model=List[ServiceInstance])
    async def get_catalog_service(
        service_name: str,
        tag: Optional[str] = None,  # Allow filtering by tag
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("read")),  # Require 'read'
    ):
        """Lists instances for a specific service, optionally filtered by tag."""
        # An unknown service yields an empty list, matching Consul.
        instances = registry.get_service(service_name)
        if tag:
            instances = [inst for inst in instances if tag in inst.tags]
        return instances

    @app.get("/v1/health/service/{service_name}", response_model=List[ServiceInstance])
    async def get_health_service(
        service_name: str,
        tag: Optional[str] = None,
        passing: bool = False,  # Filter for only passing instances
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("read")),  # Require 'read'
    ):
        """Lists healthy instances for a specific service."""
        instances = registry.get_service(service_name, only_passing=passing)
        if tag:
            instances = [inst for inst in instances if tag in inst.tags]
        return instances

    # --- Token Management Endpoints ---
    @app.post(
        "/v1/acl/token",
        response_model=ApiTokenCreateResponse,
        status_code=status.HTTP_201_CREATED,
    )
    async def create_acl_token(
        token_request: ApiTokenCreateRequest,
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("admin")),  # Require 'admin'
    ):
        """Creates a new API token."""
        plain_token, error = registry.create_token(
            token_request.name, token_request.permissions
        )
        if error:
            # Duplicate names are a client error; anything else is a server error.
            error_status = (
                status.HTTP_400_BAD_REQUEST
                if "already exists" in error
                else status.HTTP_500_INTERNAL_SERVER_ERROR
            )
            raise HTTPException(status_code=error_status, detail=error)
        return ApiTokenCreateResponse(
            token=plain_token,  # Only show the token once upon creation!
            name=token_request.name,
            permissions=token_request.permissions,
        )

    @app.delete("/v1/acl/token/{token_to_revoke}", status_code=status.HTTP_200_OK)
    async def revoke_acl_token(
        token_to_revoke: str,  # The actual token value to revoke
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("admin")),  # Require 'admin'
    ):
        """Revokes an existing API token."""
        # NOTE: an admin may revoke the very token used for this request.
        if registry.revoke_token(token_to_revoke):
            # Don't echo back the token
            return {
                "status": "revoked",
                "token_info": "Token revoked successfully",
            }
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Token not found or already revoked.",
        )

    # --- DNS-over-HTTPS Endpoint (simplified JSON API, GET with name/type) ---
    @app.get("/dns-query", status_code=status.HTTP_200_OK)
    async def dns_query(
        name: str,
        type: str = "A",  # Query param 'type'
        registry: ServiceRegistry = Depends(get_service_registry),
        # No auth required for basic DNS queries for simplicity, add if needed
    ):
        """Handles DNS queries over HTTPS (DoH) with a JSON response.

        ``name`` must end with DNS_QUERY_SUFFIX. The label just before the
        suffix is used as the service name, and only passing instances are
        answered. Supported ``type`` values are ``A`` and ``SRV``.
        """
        # The response shape (Status/Question/Answer) follows the common
        # "DNS JSON" API of public DoH resolvers. RFC 8484 itself only defines
        # the binary application/dns-message format. "Status" is a DNS RCODE.
        rcode_nxdomain = 3  # NXDOMAIN (2 would be SERVFAIL)
        rcode_notimp = 4  # NotImp
        if not name.lower().endswith(DNS_QUERY_SUFFIX):
            return {"Status": rcode_nxdomain}
        # Get the last label before the suffix.
        service_name = name[: -len(DNS_QUERY_SUFFIX)].split(".")[-1]
        instances = registry.get_service(service_name, only_passing=True)
        if not instances:
            return {"Status": rcode_nxdomain}

        query_type_upper = type.upper()
        if query_type_upper == "A":
            type_code = 1  # DNS Type A
            record_data = [instance.address for instance in instances]
        elif query_type_upper == "SRV":
            type_code = 33  # DNS Type SRV
            # Format: Priority Weight Port Target (target should be resolvable ideally)
            record_data = [
                f"0 5 {instance.port} {instance.address}{DNS_QUERY_SUFFIX}"
                for instance in instances
            ]
        else:
            # Add more record types (AAAA, TXT for metadata?) if needed
            return {"Status": rcode_notimp}

        answers = [
            {"name": name, "type": type_code, "TTL": DNS_DEFAULT_TTL, "data": data}
            for data in record_data
        ]
        return {
            "Status": 0,  # NOERROR
            "TC": False,  # Truncated
            "RD": True,  # Recursion Desired - reflects the typical client flag
            "RA": False,  # Recursion Available (we are authoritative-like)
            "AD": False,  # Authenticated Data
            "CD": False,  # Checking Disabled
            "Question": [{"name": name, "type": type_code}],  # Reflect question
            "Answer": answers,
            # Authority and Additional sections omitted for simplicity
        }

    return app
# --- Twisted DNS Server ---
class MiniDiscoveryResolver(common.ResolverBase):
    """Twisted resolver that answers DNS queries from the service registry.

    Only names ending in ``DNS_QUERY_SUFFIX`` are answered; any other name gets
    an empty answer. Two name shapes are understood:

    * ``_<tag>._tcp.<...><suffix>`` (or ``_udp``): SRV queries return one SRV
      record per passing instance, of any service, whose tags contain
      ``<tag>``. Non-SRV queries for such names get an empty answer.
    * ``<...>.<service><suffix>``: the label immediately before the suffix is
      the service name. A queries return one A record per passing instance;
      SRV queries return one SRV record per passing instance.

    Every SRV record targets ``<instance id><suffix>`` and is accompanied by an
    A record for that target in the additional section.
    """

    def __init__(self, registry: ServiceRegistry):
        # Registry instance owned by the calling process; every lookup uses it.
        self.registry = registry
        super().__init__()

    @staticmethod
    def _parse_srv_tag(name_str: str) -> Optional[str]:
        """Return ``tag`` for a ``_tag._tcp.`` / ``_tag._udp.`` style name, else None."""
        parts = name_str.split(".")
        # Example: _ssh._tcp.laiska.local -> ['_ssh', '_tcp', 'laiska', 'local']
        if (
            len(parts) >= 4
            and parts[0].startswith("_")
            and parts[1] in ("_tcp", "_udp")
        ):
            return parts[0][1:]  # '_ssh' -> 'ssh'
        return None

    def _find_instances(self, name_str: str, service_tag: Optional[str]) -> list:
        """Return the passing instances that should answer *name_str*.

        With *service_tag* set, every registered service is scanned for passing
        instances carrying that tag. Otherwise the label just before
        DNS_QUERY_SUFFIX is looked up as a service name (passing instances only).
        """
        if service_tag is not None:
            print(f"DNS: Performing SRV type lookup for tag '{service_tag}'")
            matches = [
                instance
                for service_instances in self.registry.get_all_services().values()
                for instance in service_instances
                if instance.health == "passing" and service_tag in instance.tags
            ]
            if not matches:
                print(f"DNS: No passing instances found with tag '{service_tag}'")
            return matches
        service_name = name_str[: -len(DNS_QUERY_SUFFIX)].split(".")[-1]
        if not service_name:
            return []
        return self.registry.get_service(service_name, only_passing=True) or []

    @staticmethod
    def _build_records(
        name: bytes, cls: int, query_type: int, instances: list, is_srv_type_query: bool
    ) -> Tuple[list, list, list]:
        """Build ``(answers, authority, additional)`` RR lists for *instances*.

        A records are produced only for direct-name A queries; SRV records (plus
        A glue in *additional*) for any SRV query. Instances whose records cannot
        be built (e.g. an address Record_A rejects) are logged and skipped.
        """
        answers = []
        authority = []
        additional = []
        if query_type == dns.A and not is_srv_type_query:
            for instance in instances:
                try:
                    answers.append(
                        dns.RRHeader(
                            name=name,
                            type=dns.A,
                            cls=cls,
                            ttl=DNS_DEFAULT_TTL,
                            payload=dns.Record_A(
                                address=instance.address, ttl=DNS_DEFAULT_TTL
                            ),
                        )
                    )
                except Exception as record_e:
                    print(
                        f"Warning: Error creating A record for IP '{instance.address}' (Service ID {instance.id}): {record_e}. Skipping."
                    )
        elif query_type == dns.SRV:
            for instance in instances:
                try:
                    # NOTE(review): this per-instance target is only resolvable via
                    # the glue record below; a direct A query for it looks up a
                    # service named after the instance id.
                    target_name = f"{instance.id}{DNS_QUERY_SUFFIX}".encode("utf-8")
                    answers.append(
                        dns.RRHeader(
                            name=name,
                            type=dns.SRV,
                            cls=cls,
                            ttl=DNS_DEFAULT_TTL,
                            payload=dns.Record_SRV(
                                priority=0,
                                weight=10,
                                port=instance.port,
                                target=target_name,
                                ttl=DNS_DEFAULT_TTL,
                            ),
                        )
                    )
                    additional.append(
                        dns.RRHeader(
                            name=target_name,
                            type=dns.A,
                            cls=cls,
                            ttl=DNS_DEFAULT_TTL,
                            payload=dns.Record_A(
                                address=instance.address, ttl=DNS_DEFAULT_TTL
                            ),
                        )
                    )
                except Exception as record_e:
                    print(
                        f"Warning: Error creating SRV/additional record for IP '{instance.address}' (Service ID {instance.id}): {record_e}. Skipping."
                    )
        return answers, authority, additional

    def _resolve(
        self, name: bytes, name_str: str, cls: int, query_type: int
    ) -> Tuple[list, list, list]:
        """Compute the RR lists for an already-decoded, lower-cased query name."""
        if not name_str.endswith(DNS_QUERY_SUFFIX):
            return [], [], []
        if query_type not in (dns.A, dns.SRV):
            # No other record types are generated; skip the registry round-trip.
            return [], [], []
        service_tag = self._parse_srv_tag(name_str)
        is_srv_type_query = service_tag is not None
        if is_srv_type_query and query_type != dns.SRV:
            # A/other records for an SRV-style name are not meaningful here.
            return [], [], []
        instances = self._find_instances(name_str, service_tag)
        if not instances:
            return [], [], []
        return self._build_records(name, cls, query_type, instances, is_srv_type_query)

    def _lookup(
        self,
        name: bytes,
        cls: int,
        query_type: int,
        timeout: Optional[Tuple[int]] = None,
    ) -> Deferred:
        """Resolve *name*; return a Deferred firing with (answers, authority, additional).

        The Deferred errbacks (which the server reports as a failure) when the
        name is not valid UTF-8 or an unexpected error occurs; otherwise it fires,
        possibly with three empty lists.
        """
        d = Deferred()
        name_str_debug = repr(name)
        try:
            name_str = name.decode("utf-8").lower()
        except UnicodeDecodeError as decode_err:
            print(
                f"DNS lookup error: Cannot decode query name {name_str_debug} as UTF-8: {decode_err}"
            )
            d.errback(Failure(decode_err))
            return d
        try:
            records = self._resolve(name, name_str, cls, query_type)
        except Exception as e:
            print(
                f"!!! Unhandled exception during DNS lookup for query name {name_str_debug} !!!"
            )
            print(f"Exception Type: {type(e).__name__}")
            print(f"Exception Args: {e.args}")
            print("--- Traceback ---")
            traceback.print_exc()
            print("--- End Traceback ---")
            d.errback(Failure(e))  # Signal SERVFAIL
        else:
            # Fire outside the try so a callback can never be followed by errback.
            d.callback(records)
        return d
# --- Health Checker ---
def check_service_health(instance: "ServiceInstance", timeout: float = 2.0) -> str:
    """Probe *instance* by opening (and immediately closing) a TCP connection.

    Args:
        instance: Object exposing ``address``, ``port`` and ``id`` attributes.
        timeout: Seconds to wait for the connection to be established.

    Returns:
        ``"passing"`` if the connection succeeds, otherwise ``"failing"``.
    """
    try:
        with socket.create_connection(
            (instance.address, instance.port), timeout=timeout
        ):
            # Could add optional check for specific data/response here later
            return "passing"
    except OSError:
        # socket.timeout, ConnectionRefusedError and name-resolution errors
        # (socket.gaierror) are all OSError subclasses.
        return "failing"
    except Exception as e:
        # e.g. a malformed address/port on the instance record.
        print(
            f"Unexpected error checking health for {instance.id} ({instance.address}:{instance.port}): {e}"
        )
        return "failing"  # Treat unexpected errors as failure
# --- Process Runner Functions ---
def run_fastapi_server(db_path: str, hmac_key: bytes, host: str, port: int):
    """Entry point for the API child process.

    Sets up the registry for this process from *db_path*/*hmac_key*, builds the
    FastAPI application and serves it with uvicorn on *host*:*port*. Blocks in
    ``uvicorn.run`` until the server shuts down.
    """
    pid = os.getpid()
    print(f"[{pid}] Starting FastAPI process...")
    # Registry state is initialised per process, before the app is built.
    setup_registry(db_path, hmac_key)
    # Uvicorn handles SIGINT/SIGTERM itself; run() returns once it has stopped,
    # so the final message below is printed on a graceful shutdown.
    uvicorn.run(create_fastapi_app(), host=host, port=port)
    print(f"[{pid}] FastAPI process finished.")
def run_dns_server(db_path: str, hmac_key: bytes, port: int):
    """Entry point for the DNS child process.

    Serves MiniDiscoveryResolver answers over UDP and TCP on *port* using the
    Twisted reactor, and blocks until a SIGINT/SIGTERM stops the reactor.

    Raises:
        SystemExit: with status 1 when the listeners cannot be bound, so the
            failure is visible in the process exit code.
    """
    from twisted.internet.error import ReactorNotRunning

    print(f"[{os.getpid()}] Starting DNS server process...")
    # The DNS process needs its own registry instance.
    registry = ServiceRegistry(db_path, hmac_key)
    resolver = MiniDiscoveryResolver(registry)
    factory = server.DNSServerFactory(clients=[resolver])
    protocol = dns.DNSDatagramProtocol(controller=factory)
    # Listen on UDP and TCP
    try:
        reactor.listenUDP(port, protocol)
        reactor.listenTCP(port, factory)
        print(f"[{os.getpid()}] DNS Server listening on port {port} (UDP/TCP)")
    except Exception as e:
        print(f"[{os.getpid()}] Error starting DNS listeners on port {port}: {e}")
        # A plain `return` would end the process with status 0, i.e. report
        # success for a server that never started.
        raise SystemExit(1) from e

    def shutdown_dns():
        print(f"[{os.getpid()}] Shutting down DNS server...")

    reactor.addSystemEventTrigger("before", "shutdown", shutdown_dns)

    def stop_reactor():
        try:
            reactor.stop()
        except ReactorNotRunning:
            pass  # A repeated signal arrived while already stopping.

    def signal_handler(signum, frame):
        print(f"[{os.getpid()}] Received signal {signum}, stopping DNS reactor.")
        # callFromThread queues the call and wakes the reactor loop; it is the
        # mechanism Twisted's own signal handlers use to stop the reactor.
        # Queued before reactor.run() it still takes effect once the loop starts.
        reactor.callFromThread(stop_reactor)

    # NOTE(review): reactor.run() may install Twisted's own SIGTERM (and, if
    # still default, SIGINT) handler over these; both paths stop the reactor.
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)
    # Run the Twisted reactor (blocking call)
    reactor.run()
    print(f"[{os.getpid()}] DNS server process finished.")
def run_health_checker(db_path: str, hmac_key: bytes, check_interval: int):
    """Run the periodic health-check loop until SIGINT/SIGTERM is received.

    Each cycle fetches every registered instance, probes it with
    check_service_health() and writes back any status change. Cycles start
    roughly every *check_interval* seconds.

    Args:
        db_path: SQLite database path for this process's own ServiceRegistry.
        hmac_key: HMAC key passed through to ServiceRegistry.
        check_interval: Target number of seconds between cycle starts.
    """
    print(
        f"[{os.getpid()}] Starting Health Checker process (interval: {check_interval}s)..."
    )
    # The health checker needs its own registry instance.
    registry = ServiceRegistry(db_path, hmac_key)
    running = True
    # Upper bound on how long the inter-cycle sleep can delay shutdown.
    sleep_slice_seconds = 0.5

    def signal_handler(signum, frame):
        nonlocal running
        print(f"[{os.getpid()}] Received signal {signum}, stopping health checker...")
        running = False

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)
    while running:
        start_time = time.monotonic()
        print(f"[{os.getpid()}] Health Check Cycle Start")
        updated_count = 0
        checked_count = 0
        try:
            # Fetch all services directly from DB for the check cycle
            services_map = registry.get_all_services()
            all_instances = [
                inst for sublist in services_map.values() for inst in sublist
            ]
            checked_count = len(all_instances)
            for instance in all_instances:
                if not running:
                    break  # Exit early if shutdown signal received
                current_health = instance.health
                new_health = check_service_health(instance)
                if current_health == new_health:
                    continue
                print(
                    f"[{os.getpid()}] Health change for {instance.id} ({instance.name}): {current_health} -> {new_health}"
                )
                if registry.update_health(instance.id, new_health):
                    updated_count += 1
                else:
                    # The service may have been deregistered between get and update.
                    print(
                        f"[{os.getpid()}] Warning: Failed to update health for {instance.id} (maybe deregistered?)"
                    )
        except Exception as e:
            # Catch errors during the check cycle itself (e.g., DB connection)
            print(f"[{os.getpid()}] Error during health check cycle: {e}")
        if not running:
            break
        elapsed = time.monotonic() - start_time
        sleep_time = max(0, check_interval - elapsed)
        print(
            f"[{os.getpid()}] Health Check Cycle End. Checked: {checked_count}, Updated: {updated_count}. Took {elapsed:.2f}s. Sleeping for {sleep_time:.2f}s."
        )
        # Sleep in short slices: time.sleep() resumes after a signal handler
        # returns (PEP 475), so one long sleep would ignore a shutdown request
        # for up to check_interval seconds.
        wake_at = time.monotonic() + sleep_time
        while running:
            remaining = wake_at - time.monotonic()
            if remaining <= 0:
                break
            time.sleep(min(remaining, sleep_slice_seconds))
    print(f"[{os.getpid()}] Health checker process finished.")
# --- Main Execution ---
def _stop_children(
    processes: List[multiprocessing.Process],
    grace_seconds: float = 2.0,
    kill_after_seconds: float = 5.0,
) -> None:
    """Stop *processes*: wait up to *grace_seconds*, then SIGTERM, then SIGKILL."""
    # Children may already be stopping on their own (e.g. a terminal Ctrl+C
    # reaches the whole process group); give them a shared grace period.
    deadline = time.monotonic() + grace_seconds
    for p in processes:
        p.join(timeout=max(0.0, deadline - time.monotonic()))
    for p in processes:
        if p.is_alive():
            print(f"Terminating process {p.pid} ({p.name})...")
            p.terminate()  # Send SIGTERM
    for p in processes:
        p.join(timeout=kill_after_seconds)  # Wait for clean exit
        if p.is_alive():
            print(f"Process {p.pid} ({p.name}) did not exit cleanly, killing.")
            p.kill()  # Force kill if still alive after timeout
            p.join(timeout=1)


def main():
    """Parse CLI options, prepare storage and keys, and supervise the child processes.

    Starts the API server, DNS server and health checker as child processes and
    blocks until SIGINT/SIGTERM is received or any child exits. On shutdown the
    children get a short grace period, then SIGTERM, then SIGKILL.
    """
    parser = argparse.ArgumentParser(
        description="MiniDiscovery: A minimal service discovery tool.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,  # Show defaults
    )
    parser.add_argument(
        "--db-path",
        default="minidiscovery_data.db",
        help=f"Path to the SQLite database file. Overridden by ${DB_PATH_ENV_VAR} if set.",
    )
    parser.add_argument(
        "--api-host", default="0.0.0.0", help="Host address for the API server."
    )
    parser.add_argument(
        "--api-port", type=int, default=8500, help="Port for the API server."
    )
    parser.add_argument(
        "--dns-port",
        type=int,
        default=10053,
        help="Port for the DNS server (UDP/TCP). Use ports > 1024 for non-root.",
    )
    parser.add_argument(
        "--health-check-interval",
        type=int,
        default=15,
        help="Health check interval in seconds.",
    )
    parser.add_argument(
        "--hmac-key-file",
        default=HMAC_KEY_FILE,
        help="Path to the file storing the HMAC secret key.",
    )
    args = parser.parse_args()
    # The environment variable takes precedence over --db-path.
    db_path_from_env = os.environ.get(DB_PATH_ENV_VAR)
    if db_path_from_env:
        args.db_path = db_path_from_env
        print(
            f"Using database path from environment variable ${DB_PATH_ENV_VAR}: {args.db_path}"
        )
    elif not args.db_path:  # e.g. an explicit --db-path ""
        args.db_path = "minidiscovery_data.db"
        print(f"Database path not specified, using default: {args.db_path}")
    else:
        print(f"Using database path from --db-path argument: {args.db_path}")
    db_dir = os.path.dirname(os.path.abspath(args.db_path))
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)
    # Load or generate the HMAC key before starting processes. A SystemExit
    # raised here is allowed to propagate so the exit status reflects the failure.
    hmac_key = load_or_generate_hmac_key(args.hmac_key_file)
    # Early warning only; ServiceRegistry._init_db performs the authoritative check.
    conn_check = sqlite3.connect(args.db_path, timeout=SQLITE_TIMEOUT_SECONDS)
    try:
        token_count_check = conn_check.execute(
            "SELECT COUNT(*) FROM tokens"
        ).fetchone()[0]
        if token_count_check == 0 and not os.environ.get(ADMIN_TOKEN_ENV_VAR):
            print(
                f"WARNING: Database appears empty. Ensure '{ADMIN_TOKEN_ENV_VAR}' is set for the first run."
            )
    except sqlite3.DatabaseError:
        # Table might not exist yet, ServiceRegistry init will handle it
        pass
    finally:
        conn_check.close()

    # --- Process Management ---
    processes: List[multiprocessing.Process] = []
    # Plain flag rather than an Event: Event.set() takes a non-reentrant lock,
    # which could deadlock if the signal lands while the loop holds it.
    shutdown_requested = False

    def graceful_shutdown(signum, frame):
        nonlocal shutdown_requested
        print(f"Main process received signal {signum}. Initiating shutdown...")
        shutdown_requested = True

    process_specs = [
        (
            "FastAPI Process",
            run_fastapi_server,
            (args.db_path, hmac_key, args.api_host, args.api_port),
        ),
        ("DNS Process", run_dns_server, (args.db_path, hmac_key, args.dns_port)),
        (
            "Health Check Process",
            run_health_checker,
            (args.db_path, hmac_key, args.health_check_interval),
        ),
    ]
    try:
        for proc_name, proc_target, proc_args in process_specs:
            proc = multiprocessing.Process(
                target=proc_target, args=proc_args, name=proc_name
            )
            proc.start()
            # Track only started processes; join() on an unstarted one fails.
            processes.append(proc)
        # Installed after the children start so forked children do not inherit
        # the supervisor's handler.
        signal.signal(signal.SIGINT, graceful_shutdown)
        signal.signal(signal.SIGTERM, graceful_shutdown)
        print("-" * 30)
        print(f"MiniDiscovery Started (PID: {os.getpid()})")
        print(f"  API Server: http://{args.api_host}:{args.api_port}")
        print(f"  DNS Server: Port {args.dns_port} (UDP/TCP)")
        print(f"  Database:   {args.db_path}")
        print(f"  HMAC Key:   {args.hmac_key_file}")
        print(f"  Health Int: {args.health_check_interval}s")
        print("-" * 30)
        print("Press Ctrl+C to stop.")
        while not shutdown_requested:
            # All children are long-running services, so any exit before
            # shutdown was requested (even with status 0) is unexpected.
            for p in processes:
                if not p.is_alive():
                    print(
                        f"ERROR: Process {p.name} (PID: {p.pid}) exited unexpectedly with code {p.exitcode}."
                    )
                    shutdown_requested = True
                    break
            time.sleep(1)
    finally:
        print("Waiting for processes to join...")
        _stop_children(processes)
        print("MiniDiscovery shut down complete.")


if __name__ == "__main__":
    main()