import argparse
import asyncio
import base64
import hashlib
import hmac
import json
import multiprocessing
import os
import secrets
import signal
import socket
import sqlite3
import time
import uvicorn
import traceback
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Generator, Any
from fastapi import Depends, FastAPI, Header, HTTPException, Security, status
from fastapi.security import APIKeyHeader
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field
from twisted.internet import reactor
from twisted.names import dns, server, common
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
# --- Constants ---
# Seconds passed as `timeout` to sqlite3.connect(); how long a connection
# waits on a locked database before raising.
SQLITE_TIMEOUT_SECONDS = 10 # Increased timeout for potential concurrent writes
# Default path of the file holding the HMAC key used to hash API tokens.
HMAC_KEY_FILE = "minidiscovery.key"
# Environment variable read once, on first start with an empty token table,
# to seed the initial admin token.
ADMIN_TOKEN_ENV_VAR = "MINIDISCOVERY_ADMIN_TOKEN"
# Permissions given to a new token when the create request does not list any.
DEFAULT_TOKEN_PERMISSIONS = ["read", "write"]
# Permissions stored for the bootstrap "admin" token.
ADMIN_PERMISSIONS = ["read", "write", "admin"]
# TTL (seconds) attached to every DNS answer produced by this file.
DNS_DEFAULT_TTL = 60
# Query names must end with this suffix to be answered by the DNS front ends.
DNS_QUERY_SUFFIX = ".laiska.local" # Define a suffix for DNS lookups
# NOTE(review): not referenced in the visible part of the file; presumably the
# environment variable that overrides the database path - confirm at the caller.
DB_PATH_ENV_VAR = "MINIDISCOVERY_DB_PATH"
# --- Database Schema ---
# One row per registered service instance. Tags and metadata are stored as
# JSON text; `last_update` is a float timestamp written with time.time().
SQL_CREATE_SERVICES = """
CREATE TABLE IF NOT EXISTS services (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
address TEXT NOT NULL,
port INTEGER NOT NULL,
health TEXT DEFAULT 'passing',
tags_json TEXT DEFAULT '[]', -- Store tags as JSON text
metadata_json TEXT DEFAULT '{}', -- Store metadata as JSON text
last_update REAL NOT NULL -- Timestamp of last update/registration
);
"""
# One row per API token. Only the HMAC hash of the token is stored, and it is
# the primary key, so lookups by presented token are a single indexed read.
SQL_CREATE_TOKENS = """
CREATE TABLE IF NOT EXISTS tokens (
token_hash TEXT PRIMARY KEY, -- Store the HMAC hash of the token
name TEXT NOT NULL UNIQUE, -- Token names should be unique
created_at REAL NOT NULL,
permissions_json TEXT NOT NULL -- Store permissions as JSON text
);
"""
# Small key/value table; the visible code only uses the 'initialized' key to
# remember that the admin-token bootstrap has already happened.
SQL_CREATE_META = """
CREATE TABLE IF NOT EXISTS meta (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
);
""" # For storing things like schema version or first-run status
# --- Pydantic Models ---
# Request/response model for one registered service instance. It is the body
# of the register endpoint and the element type returned by the catalog and
# health endpoints.
class ServiceInstance(BaseModel):
# Primary key in the `services` table; registering an existing id replaces it.
id: str = Field(..., description="Unique identifier for this service instance")
name: str = Field(
..., description="Logical name of the service (e.g., 'web', 'db')"
)
address: str = Field(
..., description="IP address or hostname where the service listens"
)
# Validated to the range 1..65535 by the gt/lt bounds.
port: int = Field(..., gt=0, lt=65536, description="Port number for the service")
tags: List[str] = Field(
default_factory=list, description="Optional list of tags for filtering"
)
metadata: Dict[str, str] = Field(
default_factory=dict, description="Optional key-value metadata"
)
# Free-form string: the description lists the expected values, but the field
# itself does not restrict them.
health: str = Field(
default="passing", description="Health status ('passing', 'failing', 'unknown')"
)
# Defaults to the construction time. Note the database column is named
# `last_update` (no trailing "d").
last_updated: float = Field(default_factory=time.time)
# Result of a successful token validation: the columns of the `tokens` row
# other than the hash.
class TokenInfo(BaseModel):
"""Information about a token (excluding the hash)"""
# Unique, human-chosen label of the token.
name: str
# Creation timestamp as stored in the `created_at` column (float seconds).
created_at: float
# Decoded from the `permissions_json` column.
permissions: List[str]
# Request body of the token-creation endpoint.
class ApiTokenCreateRequest(BaseModel):
# Must be unique: the `tokens.name` column carries a UNIQUE constraint.
name: str = Field(..., description="A descriptive name for the token")
# Falls back to DEFAULT_TOKEN_PERMISSIONS when omitted. The list is stored
# as given; no check against a set of known permission names is visible here.
permissions: List[str] = Field(
default=DEFAULT_TOKEN_PERMISSIONS, description="Permissions for the token"
)
# Response body of the token-creation endpoint. This is the only place the
# plain-text token is ever returned; the database keeps just its hash.
class ApiTokenCreateResponse(BaseModel):
token: str = Field(..., description="The generated API token (show only once!)")
# Echo of the requested token name.
name: str
# Echo of the permissions stored for the token.
permissions: List[str]
# --- HMAC Key Management ---
def load_or_generate_hmac_key(key_file: str = HMAC_KEY_FILE) -> bytes:
    """Return the HMAC key stored in *key_file*, creating the file if absent.

    When the file does not exist, a fresh 32-byte key is generated and written
    with mode 0o600 using an exclusive create. If another process wins that
    race, the function calls itself again to read the winner's key.

    Args:
        key_file: Path of the key file.

    Returns:
        The raw key bytes.

    Raises:
        SystemExit: The file cannot be read, holds fewer than 32 bytes, or
            cannot be written.
    """
    if not os.path.exists(key_file):
        print(f"HMAC key file ({key_file}) not found. Generating a new one.")
        fresh_key = secrets.token_bytes(32)  # 256 bits
        try:
            # O_EXCL makes the create fail instead of overwriting a key that
            # appeared after the existence check above.
            descriptor = os.open(
                key_file, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600
            )
            with os.fdopen(descriptor, "wb") as sink:
                sink.write(fresh_key)
            print(
                f"Generated and saved new HMAC key to {key_file}. PROTECT THIS FILE!"
            )
            return fresh_key
        except FileExistsError:
            # Lost the race against another process: use its key instead.
            return load_or_generate_hmac_key(key_file)
        except OSError as err:
            print(f"Error writing HMAC key file {key_file}: {err}")
            print("Please ensure the directory is writable by the process.")
            raise SystemExit(1)
        except Exception as err:
            print(f"Unexpected error generating HMAC key: {err}")
            raise SystemExit(1)

    try:
        with open(key_file, "rb") as source:
            stored_key = source.read()
        if len(stored_key) < 32:  # Basic sanity check
            raise ValueError("HMAC key file seems corrupted or too short.")
        print(f"Loaded HMAC key from {key_file}")
        return stored_key
    except Exception as err:
        print(
            f"Error loading HMAC key from {key_file}: {err}. Check file permissions and content."
        )
        raise SystemExit(1)
def generate_token_hash(token: str, hmac_key: bytes) -> str:
    """Return the hex HMAC-SHA256 digest of *token* keyed with *hmac_key*.

    The token is encoded as UTF-8 before hashing.
    """
    mac = hmac.new(hmac_key, digestmod=hashlib.sha256)
    mac.update(token.encode("utf-8"))
    return mac.hexdigest()
# --- Service Registry (SQLite Backend) ---
class ServiceRegistry:
    """SQLite-backed registry of service instances and API tokens.

    Every operation opens its own short-lived connection through
    ``_get_db_conn``; no connection is kept on the instance. API tokens are
    never stored in plain text, only as HMAC-SHA256 hashes keyed with the
    ``hmac_key`` given to the constructor.
    """

    # Maximum number of schema-initialization attempts while SQLite reports
    # "database is locked". Bounds what would otherwise be unbounded retrying.
    _INIT_MAX_ATTEMPTS = 10
    # Pause between two initialization attempts, in seconds.
    _INIT_RETRY_DELAY_SECONDS = 0.5

    def __init__(self, db_path: str, hmac_key: bytes):
        """Remember the database path and HMAC key, then ensure the schema exists.

        Raises:
            SystemExit: Propagated from ``_init_db`` when initialization fails.
        """
        self.db_path = db_path
        self._hmac_key = hmac_key
        self._init_db()

    @contextmanager
    def _get_db_conn(self) -> Generator[sqlite3.Connection, None, None]:
        """Yield a new connection; commit on success, roll back on SQLite errors.

        The connection uses ``sqlite3.Row`` rows (column access by name) and is
        always closed on exit. ``sqlite3.Error`` is logged and re-raised.
        """
        conn = None
        try:
            conn = sqlite3.connect(self.db_path, timeout=SQLITE_TIMEOUT_SECONDS)
            conn.row_factory = sqlite3.Row  # Access columns by name
            conn.execute("PRAGMA foreign_keys = ON;")
            # Write-Ahead Logging for better concurrency
            conn.execute("PRAGMA journal_mode = WAL;")
            yield conn
            conn.commit()
        except sqlite3.Error as e:
            print(f"SQLite error: {e} - {self.db_path}")
            if conn:
                conn.rollback()
            raise
        finally:
            if conn:
                conn.close()

    def _init_db(self) -> None:
        """Create the schema and bootstrap the admin token, retrying on lock.

        A "database is locked" error is retried up to ``_INIT_MAX_ATTEMPTS``
        times with a short sleep in between; any other failure is fatal.

        Raises:
            SystemExit: The database cannot be initialized, stays locked, or
                the initial admin token is required but not provided.
        """
        for attempt in range(1, self._INIT_MAX_ATTEMPTS + 1):
            try:
                self._create_schema_and_admin_token()
                return
            except sqlite3.OperationalError as e:
                is_locked = "database is locked" in str(e)
                if is_locked and attempt < self._INIT_MAX_ATTEMPTS:
                    print(
                        f"Warning: Database {self.db_path} is locked during initialization. Retrying shortly..."
                    )
                    time.sleep(self._INIT_RETRY_DELAY_SECONDS)
                    continue
                print(f"Fatal error initializing database {self.db_path}: {e}")
                raise SystemExit(1)
            except Exception as e:
                print(f"Fatal error during database initialization: {e}")
                raise SystemExit(1)

    def _create_schema_and_admin_token(self) -> None:
        """Run one initialization pass inside a single transaction.

        Creates the three tables if missing. When there are no tokens and the
        database is not yet marked initialized, the admin token is taken from
        the ``ADMIN_TOKEN_ENV_VAR`` environment variable and stored hashed.

        Raises:
            SystemExit: The admin token is needed but the variable is unset.
        """
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute(SQL_CREATE_SERVICES)
            cursor.execute(SQL_CREATE_TOKENS)
            cursor.execute(SQL_CREATE_META)

            cursor.execute("SELECT COUNT(*) FROM tokens")
            token_count = cursor.fetchone()[0]
            cursor.execute("SELECT value FROM meta WHERE key = 'initialized'")
            initialized = cursor.fetchone()
            if initialized:
                return

            if token_count == 0:
                print("No tokens found in the database.")
                admin_token_plain = os.environ.get(ADMIN_TOKEN_ENV_VAR)
                if not admin_token_plain:
                    print(
                        "ERROR: Database is empty and initial admin token not provided."
                    )
                    print(
                        f"Please set the '{ADMIN_TOKEN_ENV_VAR}' environment variable with a secure token."
                    )
                    raise SystemExit(1)  # Critical configuration missing
                print(f"Creating initial admin token from '{ADMIN_TOKEN_ENV_VAR}'...")
                cursor.execute(
                    "INSERT INTO tokens (token_hash, name, created_at, permissions_json) VALUES (?, ?, ?, ?)",
                    (
                        generate_token_hash(admin_token_plain, self._hmac_key),
                        "admin",
                        time.time(),
                        json.dumps(ADMIN_PERMISSIONS),
                    ),
                )
                # Mark DB as initialized to prevent asking for env var again
                cursor.execute(
                    "INSERT OR REPLACE INTO meta (key, value) VALUES ('initialized', 'true')"
                )
                print(
                    "Initial admin token created successfully. You can unset the environment variable now."
                )
            else:
                # Tokens exist but the marker is missing (e.g. a database
                # written by an older version): just record the marker.
                cursor.execute(
                    "INSERT OR REPLACE INTO meta (key, value) VALUES ('initialized', 'true')"
                )

    @staticmethod
    def _row_to_instance(row: sqlite3.Row) -> Optional[ServiceInstance]:
        """Convert a ``services`` row to a ``ServiceInstance``, or ``None`` if malformed.

        The stored ``last_update`` timestamp is carried over into
        ``last_updated`` so callers see when the instance was last written,
        not when the row was read.
        """
        try:
            return ServiceInstance(
                id=row["id"],
                name=row["name"],
                address=row["address"],
                port=row["port"],
                health=row["health"],
                tags=json.loads(row["tags_json"]),
                metadata=json.loads(row["metadata_json"]),
                last_updated=row["last_update"],
            )
        except (ValueError, TypeError, KeyError, IndexError) as e:
            # ValueError covers json.JSONDecodeError and pydantic validation
            # errors; sqlite3.Row raises IndexError for an unknown column.
            # sqlite3.Row has no .get(), so probe keys() for the id instead.
            row_id = row["id"] if "id" in row.keys() else "N/A"
            print(f"Warning: Could not parse service data for ID {row_id}: {e}")
            return None

    def register(self, instance: ServiceInstance) -> bool:
        """Insert *instance*, replacing any existing row with the same id.

        The stored timestamp is the time of this call, not
        ``instance.last_updated``.

        Returns:
            True on success, False if SQLite reports an integrity error.
        """
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    """
                    INSERT OR REPLACE INTO services
                    (id, name, address, port, health, tags_json, metadata_json, last_update)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        instance.id,
                        instance.name,
                        instance.address,
                        instance.port,
                        instance.health,
                        json.dumps(instance.tags),
                        json.dumps(instance.metadata),
                        time.time(),
                    ),
                )
                return True
            except sqlite3.IntegrityError as e:
                # Should not happen with INSERT OR REPLACE unless DB issue
                print(f"Error registering service {instance.id}: {e}")
                return False

    def deregister(self, service_id: str) -> bool:
        """Delete the instance with *service_id*; return True if a row was removed."""
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM services WHERE id = ?", (service_id,))
            return cursor.rowcount > 0

    def get_service(
        self, name: str, only_passing: bool = False
    ) -> List[ServiceInstance]:
        """Return all instances registered under *name*.

        Args:
            name: Exact service name to look up.
            only_passing: When True, restrict to rows whose health is 'passing'.

        Returns:
            The matching instances; rows that cannot be parsed are skipped
            with a warning.
        """
        sql = "SELECT * FROM services WHERE name = ?"
        if only_passing:
            sql += " AND health = 'passing'"
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute(sql, (name,))
            parsed = [self._row_to_instance(row) for row in cursor.fetchall()]
        return [instance for instance in parsed if instance is not None]

    def get_all_services(self) -> Dict[str, List[ServiceInstance]]:
        """Return every registered instance, grouped by service name.

        Names appear in the order produced by ``ORDER BY name``; rows that
        cannot be parsed are skipped with a warning.
        """
        services: Dict[str, List[ServiceInstance]] = {}
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM services ORDER BY name")
            for row in cursor.fetchall():
                instance = self._row_to_instance(row)
                if instance is not None:
                    services.setdefault(instance.name, []).append(instance)
        return services

    def update_health(self, service_id: str, health: str) -> bool:
        """Set the health of *service_id* and refresh its timestamp.

        Returns:
            True if a row was updated, False if the id is unknown.
        """
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute(
                "UPDATE services SET health = ?, last_update = ? WHERE id = ?",
                (health, time.time(), service_id),
            )
            return cursor.rowcount > 0

    def create_token(
        self, name: str, permissions: List[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        """Create a new random API token and store its hash.

        Returns:
            ``(token, None)`` on success, where ``token`` is the plain-text
            value (not recoverable later), or ``(None, error_message)`` on
            failure.
        """
        token_plain = secrets.token_urlsafe(32)
        token_hash = generate_token_hash(token_plain, self._hmac_key)
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    "INSERT INTO tokens (token_hash, name, created_at, permissions_json) VALUES (?, ?, ?, ?)",
                    (token_hash, name, time.time(), json.dumps(permissions)),
                )
                return token_plain, None
            except sqlite3.IntegrityError:
                # This likely means the name is not unique
                return None, f"Token name '{name}' already exists."
            except Exception as e:
                print(f"Error creating token '{name}': {e}")
                return None, "Database error during token creation."

    def revoke_token(self, token_plain: str) -> bool:
        """Delete the token whose hash matches *token_plain*; True if one existed."""
        token_hash = generate_token_hash(token_plain, self._hmac_key)
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM tokens WHERE token_hash = ?", (token_hash,))
            return cursor.rowcount > 0

    def validate_token(self, token_plain: str) -> Optional[TokenInfo]:
        """Look up *token_plain* by its hash.

        Returns:
            The token's ``TokenInfo`` when the hash is known and its stored
            permissions decode to a list; otherwise ``None``.
        """
        token_hash = generate_token_hash(token_plain, self._hmac_key)
        with self._get_db_conn() as conn:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT name, created_at, permissions_json FROM tokens WHERE token_hash = ?",
                (token_hash,),
            )
            row = cursor.fetchone()
        if row is None:
            return None  # Token hash not found
        try:
            permissions = json.loads(row["permissions_json"])
            if not isinstance(permissions, list):
                raise TypeError("Permissions format error")
            return TokenInfo(
                name=row["name"],
                created_at=row["created_at"],
                permissions=permissions,
            )
        except (json.JSONDecodeError, TypeError) as e:
            print(f"Error decoding permissions for token hash {token_hash}: {e}")
            return None  # Treat malformed permissions as invalid
# --- API Key Security ---
# Reads the API token from the "X-API-Token" request header. auto_error=False
# makes a missing header arrive as None so the permission checker can raise
# its own 401 instead of FastAPI's default error.
api_key_header = APIKeyHeader(name="X-API-Token", auto_error=False)
def get_permission_checker(required_permission: str):
    """Build a FastAPI dependency that enforces *required_permission*.

    The returned coroutine resolves the registry, reads the token from the
    ``X-API-Token`` header and returns that token when it is valid and carries
    the permission. It raises 401 when the header is missing and 403 when the
    token is unknown or lacks the permission.
    """

    async def _check_permission(
        # The lambda defers the lookup of get_service_registry to call time.
        registry: ServiceRegistry = Depends(lambda: get_service_registry()),
        api_key: Optional[str] = Security(api_key_header),
    ) -> str:
        if api_key is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="API token missing in X-API-Token header",
            )

        granted = registry.validate_token(api_key)
        if granted is None:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN, detail="Invalid API token"
            )

        if required_permission in granted.permissions:
            # Hand the validated key to the endpoint.
            return api_key

        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Insufficient permissions. Required: '{required_permission}'",
        )

    return _check_permission
# --- FastAPI Application Factory ---
# Per-process ServiceRegistry singleton. It is assigned by setup_registry()
# and read by get_service_registry(). Each process (or worker) that calls
# setup_registry() gets its own instance; the instances share state only
# through the SQLite database they all open.
_registry_instance: Optional[ServiceRegistry] = None
def setup_registry(db_path: str, hmac_key: bytes):
    """Create this process's ServiceRegistry once; later calls only log.

    Args:
        db_path: Path of the SQLite database file.
        hmac_key: Key used to hash API tokens.
    """
    global _registry_instance
    if _registry_instance is not None:
        print(f"[{os.getpid()}] ServiceRegistry already initialized for FastAPI.")
        return
    print(f"[{os.getpid()}] Initializing ServiceRegistry for FastAPI...")
    _registry_instance = ServiceRegistry(db_path, hmac_key)
def get_service_registry() -> ServiceRegistry:
    """Return this process's ServiceRegistry (used as a FastAPI dependency).

    Raises:
        RuntimeError: ``setup_registry`` has not been called in this process.
    """
    registry = _registry_instance
    if registry is not None:
        return registry
    # Reaching this point means the app started without setup_registry().
    raise RuntimeError("ServiceRegistry not initialized for this process.")
def generate_status_page_html(registry: ServiceRegistry) -> str:
"""Generates the HTML for the status page."""
# Overview: every distinct service address is treated as a "node". A node is
# "Active" when at least one instance registered at that address has health
# "passing", otherwise "Inactive". The page is a node-by-node grid whose
# diagonal shows each node's own status and whose other cells read "Unknown".
# NOTE(review): several string literals below are empty or break across
# lines; the HTML markup they once held appears to have been lost from this
# copy of the file. Restore it from version control before relying on this.
all_services_map = registry.get_all_services()
nodes = (
{}
) # Dictionary to store node info: {address: {"status": "Active|Inactive", "label": "Node N"}}
# 1. Identify unique "nodes" (addresses) and determine their status
# Addresses are sorted, so "Node N" labels follow that sort order.
unique_addresses = sorted(
list(
set(
inst.address
for instances in all_services_map.values()
for inst in instances
)
)
)
node_label_counter = 1
for addr in unique_addresses:
is_active = False
# `service_name` is not used in the loop body; only the instances matter.
for service_name, instances in all_services_map.items():
for instance in instances:
if instance.address == addr and instance.health == "passing":
is_active = True
break # Found one passing service at this address, node is active
if is_active:
break
nodes[addr] = {
"status": "Active" if is_active else "Inactive",
"label": f"Node {node_label_counter}",
}
node_label_counter += 1
# Number of distinct addresses; zero selects the "no nodes" message below.
grid_size = len(nodes)
# 2. Build the HTML string dynamically
# Use inline style for grid-size as it's dynamic
html_content = f"""
MiniDiscovery Status
"""
if grid_size == 0:
html_content += "No active nodes detected.
"
else:
html_content += f"" # Set CSS variable size
# Top-left corner
html_content += "
"
# Column headers
for addr in unique_addresses:
html_content += f""
# Grid rows
for row_addr in unique_addresses:
# Row label
html_content += f"
{nodes[row_addr]['label']}
"
# Cells in the row
for col_addr in unique_addresses:
if row_addr == col_addr:
# Diagonal: Show node's own status
# The local name `status` shadows the `status` object imported from
# fastapi at the top of the file, inside this function only.
status = nodes[row_addr]["status"]
# NOTE(review): status_class is not used in the text visible here;
# presumably it was a CSS class in the lost markup - confirm.
status_class = f"status-{status.lower()}"
html_content += (
f"
{status}
"
)
else:
# Off-diagonal: Show 'Unknown' as we don't track inter-node connectivity
html_content += (
"
Unknown
"
)
# End of row implicitly handled by grid layout
html_content += "
" # End grid-container
# Add footer/closing tags
html_content += """
Status based on registered service health. Page refreshes automatically.
"""
return html_content
def create_fastapi_app() -> FastAPI:
    """Build the FastAPI application and register all HTTP endpoints.

    Endpoints:
        * ``GET /status``: public HTML status page.
        * ``/v1/agent/service/*``: register / deregister (needs 'write').
        * ``/v1/catalog/*`` and ``/v1/health/*``: lookups (need 'read').
        * ``/v1/acl/token``: token creation / revocation (needs 'admin').
        * ``GET /dns-query``: unauthenticated JSON DNS lookup for A and SRV.

    Returns:
        The configured application. The registry itself is resolved per
        request through ``get_service_registry``.
    """
    # DNS response codes (RCODE) used in the "Status" field of /dns-query.
    rcode_noerror = 0
    rcode_nxdomain = 3
    rcode_notimp = 4
    # DNS record type numbers for the supported query types.
    dns_type_codes = {"A": 1, "SRV": 33}

    app = FastAPI(
        title="MiniDiscovery",
        description="Minimal Service Discovery inspired by Consul",
        version="0.2.0",
    )

    # --- Status Page Endpoint (Public) ---
    @app.get("/status", response_class=HTMLResponse)
    async def get_status_page(
        registry: ServiceRegistry = Depends(get_service_registry),
    ):
        """Serves the anonymous HTML status page."""
        html_content = generate_status_page_html(registry)
        return HTMLResponse(content=html_content, status_code=200)

    # --- API Endpoints (Authenticated) ---
    @app.post("/v1/agent/service/register", status_code=status.HTTP_200_OK)
    async def register_service(
        instance: ServiceInstance,
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("write")),  # Require 'write'
    ):
        """Registers a service instance. Supports replace-if-exists based on ID."""
        if not registry.register(instance):
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to register service due to a database error.",
            )
        return {"status": "registered", "service_id": instance.id}

    @app.put(
        "/v1/agent/service/deregister/{service_id}", status_code=status.HTTP_200_OK
    )
    async def deregister_service(
        service_id: str,
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("write")),  # Require 'write'
    ):
        """Deregisters a service instance by its ID."""
        if not registry.deregister(service_id):
            # Nothing was deleted, so the id was not registered.
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Service ID '{service_id}' not found.",
            )
        return {"status": "deregistered", "service_id": service_id}

    @app.get("/v1/catalog/services", response_model=Dict[str, List[str]])
    async def get_catalog_services(
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("read")),  # Require 'read'
    ):
        """Lists all known service names and their tags."""
        services_data = registry.get_all_services()
        # Union of the tags of all instances per service; sorted so the
        # response is stable between calls.
        return {
            name: sorted({tag for inst in instances for tag in inst.tags})
            for name, instances in services_data.items()
        }

    @app.get("/v1/catalog/service/{service_name}", response_model=List[ServiceInstance])
    async def get_catalog_service(
        service_name: str,
        tag: Optional[str] = None,  # Allow filtering by tag
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("read")),  # Require 'read'
    ):
        """Lists instances for a specific service, optionally filtered by tag."""
        # An unknown service name yields an empty list rather than a 404.
        instances = registry.get_service(service_name)
        if tag:
            instances = [inst for inst in instances if tag in inst.tags]
        return instances

    @app.get("/v1/health/service/{service_name}", response_model=List[ServiceInstance])
    async def get_health_service(
        service_name: str,
        tag: Optional[str] = None,
        passing: bool = False,  # Filter for only passing instances
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("read")),  # Require 'read'
    ):
        """Lists instances for a service; ``passing=true`` keeps only healthy ones."""
        instances = registry.get_service(service_name, only_passing=passing)
        if tag:
            instances = [inst for inst in instances if tag in inst.tags]
        return instances

    # --- Token Management Endpoints ---
    @app.post(
        "/v1/acl/token",
        response_model=ApiTokenCreateResponse,
        status_code=status.HTTP_201_CREATED,
    )
    async def create_acl_token(
        token_request: ApiTokenCreateRequest,
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("admin")),  # Require 'admin'
    ):
        """Creates a new API token."""
        plain_token, error = registry.create_token(
            token_request.name, token_request.permissions
        )
        if error:
            # A duplicate name is the caller's mistake (400); anything else
            # is reported as a server error (500).
            error_status = (
                status.HTTP_400_BAD_REQUEST
                if "already exists" in error
                else status.HTTP_500_INTERNAL_SERVER_ERROR
            )
            raise HTTPException(status_code=error_status, detail=error)
        return ApiTokenCreateResponse(
            token=plain_token,  # Only show the token once upon creation!
            name=token_request.name,
            permissions=token_request.permissions,
        )

    @app.delete("/v1/acl/token/{token_to_revoke}", status_code=status.HTTP_200_OK)
    async def revoke_acl_token(
        token_to_revoke: str,  # The actual token value to revoke
        registry: ServiceRegistry = Depends(get_service_registry),
        token: str = Depends(get_permission_checker("admin")),  # Require 'admin'
    ):
        """Revokes an existing API token."""
        # NOTE: nothing stops a caller from revoking the very token used to
        # authenticate this request.
        if not registry.revoke_token(token_to_revoke):
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Token not found or already revoked.",
            )
        # Don't echo back the token
        return {"status": "revoked", "token_info": "Token revoked successfully"}

    # --- DNS lookup over HTTP (JSON answer, "Status" carries the DNS RCODE) ---
    @app.get("/dns-query", status_code=status.HTTP_200_OK)
    async def dns_query(
        name: str,
        type: str = "A",  # Public query parameter name, so it must stay `type`
        registry: ServiceRegistry = Depends(get_service_registry),
        # No auth required for basic DNS queries for simplicity, add if needed
    ):
        """Answers A and SRV lookups for names under ``DNS_QUERY_SUFFIX``.

        The service name is the last label before the suffix and only
        'passing' instances are returned. Names outside the suffix or without
        passing instances yield NXDOMAIN (Status 3); unsupported record types
        yield NOTIMP (Status 4).
        """
        if not name.lower().endswith(DNS_QUERY_SUFFIX):
            return {"Status": rcode_nxdomain}

        # Last label before the suffix
        service_name = name[: -len(DNS_QUERY_SUFFIX)].split(".")[-1]
        instances = registry.get_service(service_name, only_passing=True)
        if not instances:
            return {"Status": rcode_nxdomain}

        type_code = dns_type_codes.get(type.upper())
        if type_code is None:
            # Unsupported type for this simple endpoint
            return {"Status": rcode_notimp}

        if type_code == dns_type_codes["A"]:
            record_data = [instance.address for instance in instances]
        else:
            # SRV data format: Priority Weight Port Target
            record_data = [
                f"0 5 {instance.port} {instance.address}{DNS_QUERY_SUFFIX}"
                for instance in instances
            ]

        answers = [
            {"name": name, "type": type_code, "TTL": DNS_DEFAULT_TTL, "data": data}
            for data in record_data
        ]
        return {
            "Status": rcode_noerror,
            "TC": False,  # Truncated
            "RD": True,  # Recursion Desired - reflects typical client flag
            "RA": False,  # Recursion Available
            "AD": False,  # Authenticated Data
            "CD": False,  # Checking Disabled
            "Question": [{"name": name, "type": type_code}],  # Reflect question
            "Answer": answers,
            # Authority and Additional sections omitted for simplicity
        }

    return app
# --- Twisted DNS Server ---
class MiniDiscoveryResolver(common.ResolverBase):
def __init__(self, registry: ServiceRegistry):
    """Initialise the base resolver and keep the registry used for lookups."""
    super().__init__()
    self.registry = registry
def _lookup(
    self,
    name: bytes,
    cls: int,
    query_type: int,
    timeout: Optional[Tuple[int]] = None,
) -> Deferred:
    """Resolve one DNS question against the service registry.

    The name is decoded and lower-cased, checked against
    ``DNS_QUERY_SUFFIX`` and dispatched to the A, SRV or TXT handler.

    Args:
        name: Raw query name.
        cls: DNS class of the question; echoed into the answer records.
        query_type: DNS record type of the question.
        timeout: Accepted for interface compatibility; not used here.

    Returns:
        A Deferred that has always fired by the time it is returned: with
        ``(answers, authority, additional)`` (all empty for undecodable
        names, names outside the suffix, the bare suffix and unsupported
        types), or with a Failure if a handler raised.
    """
    d = Deferred()
    name_str_debug = repr(name)  # For logging errors
    try:
        # 1. Decode the incoming query name safely
        try:
            name_str = name.decode("utf-8").lower()
        except UnicodeDecodeError:
            print(f"DNS: Cannot decode query name {name_str_debug} as UTF-8.")
            # Fire with an empty result: returning the Deferred unfired
            # would leave the query pending until the client times out.
            d.callback(([], [], []))
            return d

        # 2. Only names under our suffix, with something in front of it
        base_name = ""
        if name_str.endswith(DNS_QUERY_SUFFIX):
            base_name = name_str[: -len(DNS_QUERY_SUFFIX)]

        # 3. Pick the handler for the record type (None if unsupported)
        handler = {
            dns.A: self._handle_a_query,
            dns.SRV: self._handle_srv_query,
            dns.TXT: self._handle_txt_query,
        }.get(query_type)

        if not base_name or handler is None:
            d.callback(([], [], []))
            return d

        # 4. Handlers are synchronous, so the result is available at once
        answers, authority, additional = handler(name, base_name, cls)
        d.callback((answers, authority, additional))
    except Exception as e:
        # Catch-all for unexpected errors during dispatch or handler execution
        print(
            f"!!! Unhandled exception during DNS lookup for {name_str_debug} "
            f"(Type: {query_type}) !!!"
        )
        print(f"Exception Type: {type(e).__name__}")
        print(f"Exception Args: {e.args}")
        print("--- Traceback ---")
        traceback.print_exc()
        print("--- End Traceback ---")
        # Signal DNS server failure (SERVFAIL)
        d.errback(Failure(e))
    return d
# --- Helper Methods for Specific Record Types ---
def _parse_srv_query(self, base_name: str) -> Tuple[Optional[str], Optional[str]]:
"""
Parses SRV-style queries: _tag._proto.service or _tag._proto
Returns (tag, service_name) or (None, None) if not SRV-style.
"""
parts = base_name.split(".")
if (
len(parts) >= 2
and parts[0].startswith("_")
and parts[1] in ["_tcp", "_udp"]
):
tag = parts[0][1:] # Remove leading '_'
# Service name is the part *after* _tag._proto, if it exists
service_name = parts[2] if len(parts) > 2 else None
# We currently ignore the rest of the parts (like datacenter in consul)
return tag, service_name
return None, None # Not an SRV-style query name
def _get_instances_for_query(
    self, base_name: str, is_srv_query: bool = False
) -> List[ServiceInstance]:
    """Return the passing instances that a DNS query name refers to.

    For A/TXT queries the service name is the last label of *base_name*.
    For SRV queries the name must parse as ``_tag._proto[.service]``: with a
    service label the service's instances are filtered by tag, without one
    every passing instance carrying the tag is returned.

    Returns:
        The matching instances; an empty list (without the debug log line)
        when an SRV name is not in the expected format.
    """
    wanted_tag = None
    if is_srv_query:
        wanted_tag, wanted_service = self._parse_srv_query(base_name)
        if wanted_tag is None:
            # Not a valid _tag._proto... name
            return []
    else:
        # A or TXT query: service name is the last label
        wanted_service = base_name.rpartition(".")[2]

    matches = []
    if wanted_service:
        candidates = self.registry.get_service(wanted_service, only_passing=True)
        if wanted_tag:
            matches = [inst for inst in candidates if wanted_tag in inst.tags]
        else:
            matches = candidates
    elif wanted_tag:
        # Tag-only SRV query: scan every service for passing, tagged instances
        for group in self.registry.get_all_services().values():
            matches.extend(
                inst
                for inst in group
                if inst.health == "passing" and wanted_tag in inst.tags
            )
    else:
        # Neither a service name nor a tag could be derived from the name
        print(f"Warning: Could not determine filter for query '{base_name}'")

    print(
        f"DNS Lookup: base_name='{base_name}', is_srv={is_srv_query}, "
        f"tag='{wanted_tag}', service='{wanted_service}'. Found {len(matches)} instances."
    )
    return matches
def _handle_a_query(
    self, name: bytes, base_name: str, cls: int
) -> Tuple[List, List, List]:
    """Build one A record per passing instance of the queried service.

    Returns:
        ``(answers, [], [])``. Instances whose record cannot be built are
        skipped with a warning.
    """
    records = []
    for inst in self._get_instances_for_query(base_name, is_srv_query=False):
        try:
            records.append(
                dns.RRHeader(
                    name=name,  # Answer under the name that was asked for
                    type=dns.A,
                    cls=cls,
                    ttl=DNS_DEFAULT_TTL,
                    payload=dns.Record_A(
                        address=inst.address, ttl=DNS_DEFAULT_TTL
                    ),
                )
            )
        except Exception as err:
            print(
                f"Warning: Error creating A record for instance {inst.id} "
                f"(IP: {inst.address}): {err}. Skipping."
            )
    return records, [], []
def _handle_srv_query(
    self, name: bytes, base_name: str, cls: int
) -> Tuple[List, List, List]:
    """Build SRV answers plus matching A records for an SRV-style query.

    Only names of the ``_tag._proto[.service]`` form produce results. Each
    SRV record targets ``<instance id><DNS_QUERY_SUFFIX>`` and an A record
    for that target is placed in the additional section.

    Returns:
        ``(answers, [], additional)``. A failure while building the records
        for one instance is logged and processing continues with the next.
    """
    srv_records = []
    glue_records = []
    for inst in self._get_instances_for_query(base_name, is_srv_query=True):
        try:
            # The target is the instance id under our suffix, as bytes.
            target = f"{inst.id}{DNS_QUERY_SUFFIX}".encode("utf-8")
            srv_records.append(
                dns.RRHeader(
                    name=name,  # Answer under the name that was asked for
                    type=dns.SRV,
                    cls=cls,
                    ttl=DNS_DEFAULT_TTL,
                    payload=dns.Record_SRV(
                        priority=0,  # Lower is more preferred
                        weight=10,  # Relative weight for same priority
                        port=inst.port,
                        target=target,
                        ttl=DNS_DEFAULT_TTL,
                    ),
                )
            )
            # The A record is built after the SRV record has been appended,
            # so a failure here keeps the SRV answer and drops only the glue.
            glue_records.append(
                dns.RRHeader(
                    name=target,  # Name matches SRV target
                    type=dns.A,
                    cls=cls,
                    ttl=DNS_DEFAULT_TTL,
                    payload=dns.Record_A(
                        address=inst.address, ttl=DNS_DEFAULT_TTL
                    ),
                )
            )
        except Exception as err:
            print(
                f"Warning: Error creating SRV/A record for instance {inst.id} "
                f"(Addr: {inst.address}:{inst.port}): {err}. Skipping."
            )
    return srv_records, [], glue_records
def _handle_txt_query(
    self, name: bytes, base_name: str, cls: int
) -> Tuple[List, List, List]:
    """Answer a TXT query with one TXT record per matching instance.

    Each record carries, in order: every tag, every metadata entry rendered
    as ``key=value``, and a final ``instance_id=<id>`` entry. Entries are
    UTF-8 encoded; an entry longer than 255 bytes is split into consecutive
    chunks of at most 255 bytes, each sent as its own character-string.

    Args:
        name: The query name as received; used as the owner name of every
            TXT answer.
        base_name: Name passed to ``_get_instances_for_query`` (with
            ``is_srv_query=False``) to select the instances to describe.
        cls: DNS class of the query, copied onto every generated record.

    Returns:
        ``(answers, authority, additional)`` where only ``answers`` is ever
        populated. Instances whose record cannot be built are logged and
        skipped.
    """
    # Upper bound for one character-string inside a TXT record.
    max_txt_string_len = 255
    answers = []
    instances = self._get_instances_for_query(base_name, is_srv_query=False)
    for instance in instances:
        instance_id_str = str(instance.id)
        try:
            print(f"DNS TXT: Processing instance {instance_id_str}")

            # --- 1. Gather the logical strings to publish ---
            logical_strings = []
            if isinstance(instance.tags, list):
                logical_strings.extend(str(tag) for tag in instance.tags)
            else:
                print(
                    f"WARNING: Instance {instance_id_str} tags are not a list: {type(instance.tags)}"
                )
            if isinstance(instance.metadata, dict):
                logical_strings.extend(
                    f"{str(k)}={str(v)}" for k, v in instance.metadata.items()
                )
            else:
                print(
                    f"WARNING: Instance {instance_id_str} metadata is not a dict: {type(instance.metadata)}"
                )
            logical_strings.append(f"instance_id={instance_id_str}")

            # --- 2. Encode each string and split anything over the limit ---
            segments = []
            for logical_string in logical_strings:
                try:
                    encoded = logical_string.encode("utf-8")
                except UnicodeError as enc_err:
                    print(
                        f"ERROR encoding/splitting item '{logical_string}' for {instance_id_str}: {enc_err}. Skipping this item."
                    )
                    continue
                segments.extend(
                    encoded[start : start + max_txt_string_len]
                    for start in range(0, len(encoded), max_txt_string_len)
                )
            print(
                f"DNS TXT DEBUG: FINAL payload segments count for {instance_id_str}: {len(segments)}"
            )
            if not segments:
                print(
                    f"DNS TXT: Skipping record creation for {instance_id_str} due to empty payload."
                )
                continue

            # --- 3. Build the record ---
            # Record_TXT takes its strings as *separate positional arguments*.
            # Passing the list itself would store a list inside the record's
            # data, which cannot be serialised when the response is encoded,
            # so the segments are always unpacked, whatever their count.
            payload = dns.Record_TXT(*segments, ttl=DNS_DEFAULT_TTL)
            answers.append(
                dns.RRHeader(
                    name=name,
                    type=dns.TXT,
                    cls=cls,
                    ttl=DNS_DEFAULT_TTL,
                    payload=payload,
                )
            )
            print(
                f"DNS TXT: RRHeader created and added for instance {instance_id_str}."
            )
        except Exception as e_dns:
            # Best effort: one broken instance must not fail the whole query.
            print(
                f"ERROR creating TXT DNS objects for {instance_id_str}: {e_dns.__class__.__name__}: {e_dns}"
            )
            traceback.print_exc()
    print(
        f"DNS TXT: Finished processing query for '{base_name}'. Found {len(instances)} instances, generated {len(answers)} TXT records."
    )
    return answers, [], []
# --- Health Checker ---
def check_service_health(instance: ServiceInstance) -> str:
    """Probe an instance with a plain TCP connect.

    Args:
        instance: Object providing ``address``, ``port`` and ``id``.

    Returns:
        ``"passing"`` when a TCP connection to ``address:port`` is established
        within two seconds, otherwise ``"failing"``. No exception escapes.
    """
    try:
        endpoint = (instance.address, instance.port)
        # Establishing the connection is the entire check; the socket is
        # closed again as soon as the block is left.
        with socket.create_connection(endpoint, timeout=2):
            pass
    except OSError:
        # Timeouts and refused connections are OSError subclasses, so this
        # single clause covers the ordinary "unreachable" outcomes.
        return "failing"
    except Exception as e:
        print(
            f"Unexpected error checking health for {instance.id} ({instance.address}:{instance.port}): {e}"
        )
        # Anything unexpected is still reported as an unhealthy instance.
        return "failing"
    return "passing"
# --- Process Runner Functions ---
def run_fastapi_server(db_path: str, hmac_key: bytes, host: str, port: int):
    """Process entry point for the HTTP API.

    Calls ``setup_registry(db_path, hmac_key)``, builds the application with
    ``create_fastapi_app()`` and serves it through ``uvicorn.run`` on
    ``host:port``. The call blocks until uvicorn returns.

    Args:
        db_path: Forwarded to ``setup_registry``.
        hmac_key: Forwarded to ``setup_registry``.
        host: Interface uvicorn binds to.
        port: TCP port uvicorn binds to.
    """
    pid = os.getpid()
    print(f"[{pid}] Starting FastAPI process...")
    # The registry is set up inside this process, before the app is built.
    setup_registry(db_path, hmac_key)
    # No signal handlers are installed here; per the original author's note,
    # uvicorn deals with SIGINT/SIGTERM itself.
    uvicorn.run(create_fastapi_app(), host=host, port=port)
    # Only printed once uvicorn.run() has returned.
    print(f"[{pid}] FastAPI process finished.")
def run_dns_server(db_path: str, hmac_key: bytes, port: int):
    """Process entry point for the Twisted DNS server.

    Builds a ``ServiceRegistry`` and a ``MiniDiscoveryResolver`` for this
    process, listens on ``port`` over both UDP and TCP, and runs the Twisted
    reactor until it is stopped.

    Args:
        db_path: Forwarded to ``ServiceRegistry``.
        hmac_key: Forwarded to ``ServiceRegistry``.
        port: Port number used for both the UDP and the TCP listener.

    Raises:
        SystemExit: With status 1 when a listener cannot be started, so the
            process ends with a non-zero exit code that a supervising parent
            can detect.
    """
    pid = os.getpid()
    print(f"[{pid}] Starting DNS server process...")
    # The DNS process works on its own registry instance.
    registry = ServiceRegistry(db_path, hmac_key)
    resolver = MiniDiscoveryResolver(registry)
    factory = server.DNSServerFactory(
        clients=[resolver],
        # verbose=2 # Uncomment for very detailed Twisted DNS logging
    )
    protocol = dns.DNSDatagramProtocol(controller=factory)

    try:
        reactor.listenUDP(port, protocol)
        reactor.listenTCP(port, factory)
    except Exception as e:
        print(f"[{pid}] Error starting DNS listeners on port {port}: {e}")
        # A plain `return` would end the process with exit code 0, which is
        # indistinguishable from a clean shutdown. Exit non-zero instead.
        raise SystemExit(1) from e
    else:
        print(f"[{pid}] DNS Server listening on port {port} (UDP/TCP)")

    def shutdown_dns():
        # Hook for explicit cleanup; currently it only logs.
        print(f"[{os.getpid()}] Shutting down DNS server...")

    reactor.addSystemEventTrigger("before", "shutdown", shutdown_dns)

    def signal_handler(signum, frame):
        print(f"[{os.getpid()}] Received signal {signum}, stopping DNS reactor.")
        if reactor.running:
            # callFromThread both queues the call and wakes the reactor, so
            # the stop is processed even when the reactor is idle. A call
            # queued with callLater from a signal handler may sit unnoticed
            # until the next I/O event.
            reactor.callFromThread(reactor.stop)

    # NOTE(review): reactor.run() installs Twisted's own signal handling by
    # default and may replace some of these handlers - confirm which ones
    # remain in effect on the deployed Twisted version.
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Blocks until the reactor is stopped.
    reactor.run()
    print(f"[{pid}] DNS server process finished.")
def run_health_checker(db_path: str, hmac_key: bytes, check_interval: int):
    """Process entry point for the periodic health-check loop.

    Every cycle loads all services from a process-local ``ServiceRegistry``,
    probes each instance with ``check_service_health`` and writes the result
    back through ``registry.update_health`` whenever it differs from the
    stored value. The loop ends after SIGINT or SIGTERM.

    Args:
        db_path: Forwarded to ``ServiceRegistry``.
        hmac_key: Forwarded to ``ServiceRegistry``.
        check_interval: Target number of seconds between the start of two
            consecutive cycles.
    """
    # Longest uninterrupted sleep; bounds how long a shutdown request can go
    # unnoticed while the loop is idle.
    sleep_slice_seconds = 0.5

    pid = os.getpid()
    print(f"[{pid}] Starting Health Checker process (interval: {check_interval}s)...")
    # The health checker works on its own registry instance.
    registry = ServiceRegistry(db_path, hmac_key)
    running = True

    def signal_handler(signum, frame):
        nonlocal running
        print(f"[{os.getpid()}] Received signal {signum}, stopping health checker...")
        running = False

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    while running:
        start_time = time.monotonic()
        print(f"[{pid}] Health Check Cycle Start")
        updated_count = 0
        checked_count = 0
        try:
            # Fetch all services directly from the DB for this cycle.
            services_map = registry.get_all_services()
            all_instances = [
                inst for sublist in services_map.values() for inst in sublist
            ]
            checked_count = len(all_instances)
            for instance in all_instances:
                if not running:
                    break  # Shutdown requested mid-cycle
                current_health = instance.health
                new_health = check_service_health(instance)
                if current_health == new_health:
                    continue
                print(
                    f"[{pid}] Health change for {instance.id} ({instance.name}): {current_health} -> {new_health}"
                )
                if registry.update_health(instance.id, new_health):
                    updated_count += 1
                else:
                    # Can happen if the service was deregistered between the
                    # read above and this update.
                    print(
                        f"[{pid}] Warning: Failed to update health for {instance.id} (maybe deregistered?)"
                    )
        except Exception as e:
            # Keep the checker alive across cycle-level failures (e.g. DB errors).
            print(f"[{pid}] Error during health check cycle: {e}")
        if not running:
            break

        # Sleep only for what is left of the interval.
        elapsed = time.monotonic() - start_time
        sleep_time = max(0, check_interval - elapsed)
        print(
            f"[{pid}] Health Check Cycle End. Checked: {checked_count}, Updated: {updated_count}. Took {elapsed:.2f}s. Sleeping for {sleep_time:.2f}s."
        )
        # time.sleep() is resumed after a signal handler returns (PEP 475), so
        # a single long sleep would ignore a shutdown request until the whole
        # interval had passed. Sleep in short slices and re-check the flag.
        deadline = time.monotonic() + sleep_time
        while running:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            time.sleep(min(remaining, sleep_slice_seconds))
    print(f"[{pid}] Health checker process finished.")
# --- Main Execution ---
def main():
    """Parse the command line, start the three worker processes and supervise them.

    Steps:
      1. Parse CLI options; the database path may be overridden through the
         environment variable named by ``DB_PATH_ENV_VAR``.
      2. Make sure the database directory exists and obtain the HMAC key via
         ``load_or_generate_hmac_key``.
      3. Print an early warning when the ``tokens`` table is empty and the
         variable named by ``ADMIN_TOKEN_ENV_VAR`` is not set.
      4. Start the API, DNS and health-check processes and poll them once per
         second until SIGINT/SIGTERM arrives or a child exits with a non-zero
         status.
      5. Shut the children down: short grace period, then ``terminate()``,
         then ``kill()`` for anything still alive after a further timeout.
    """
    # Seconds children get to exit on their own before SIGTERM is sent.
    grace_period_seconds = 2
    # Seconds to wait after SIGTERM before a child is force-killed.
    join_timeout_seconds = 5

    parser = argparse.ArgumentParser(
        description="MiniDiscovery: A minimal service discovery tool.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,  # Show defaults
    )
    parser.add_argument(
        "--db-path",
        default="minidiscovery_data.db",
        # f-string so the help shows the real variable name, not a placeholder.
        help=f"Path to the SQLite database file. Overridden by ${DB_PATH_ENV_VAR} if set.",
    )
    parser.add_argument(
        "--api-host", default="0.0.0.0", help="Host address for the API server."
    )
    parser.add_argument(
        "--api-port", type=int, default=8500, help="Port for the API server."
    )
    parser.add_argument(
        "--dns-port",
        type=int,
        default=10053,
        help="Port for the DNS server (UDP/TCP). Use ports > 1024 for non-root.",
    )
    parser.add_argument(
        "--health-check-interval",
        type=int,
        default=15,
        help="Health check interval in seconds.",
    )
    parser.add_argument(
        "--hmac-key-file",
        default=HMAC_KEY_FILE,
        help="Path to the file storing the HMAC secret key.",
    )
    args = parser.parse_args()

    # The environment variable wins over the command line.
    db_path_from_env = os.environ.get(DB_PATH_ENV_VAR)
    if db_path_from_env:
        args.db_path = db_path_from_env
        print(
            f"Using database path from environment variable ${DB_PATH_ENV_VAR}: {args.db_path}"
        )
    elif not args.db_path:
        # Only reachable when an empty --db-path was passed explicitly.
        args.db_path = "minidiscovery_data.db"
        print(f"Database path not specified, using default: {args.db_path}")
    else:
        print(f"Using database path from --db-path argument: {args.db_path}")

    # Ensure the DB directory exists.
    db_dir = os.path.dirname(os.path.abspath(args.db_path))
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)

    # Load or generate the HMAC key *before* starting any process.
    try:
        hmac_key = load_or_generate_hmac_key(args.hmac_key_file)
    except SystemExit:
        # NOTE(review): this turns the helper's exit into a normal return, so
        # the process status is 0 even on failure - confirm this is intended.
        return

    # Early, advisory check only. A missing table just means the schema has
    # not been created yet, which is left to the registry.
    conn_check = sqlite3.connect(args.db_path)
    try:
        cursor_check = conn_check.cursor()
        cursor_check.execute("SELECT COUNT(*) FROM tokens")
        token_count_check = cursor_check.fetchone()[0]
        if token_count_check == 0 and not os.environ.get(ADMIN_TOKEN_ENV_VAR):
            print(
                f"WARNING: Database appears empty. Ensure '{ADMIN_TOKEN_ENV_VAR}' is set for the first run."
            )
    except sqlite3.DatabaseError:
        pass
    finally:
        conn_check.close()

    # --- Process Management ---
    processes: List[multiprocessing.Process] = []
    shutdown_requested = False

    def graceful_shutdown(signum, frame):
        # Keep the handler minimal: sleeping or managing processes here would
        # block the interrupted main thread, and taking a lock (as a
        # multiprocessing.Event does) could deadlock against the code that
        # was interrupted. The real work happens in the `finally` below.
        nonlocal shutdown_requested
        print(f"Main process received signal {signum}. Initiating shutdown...")
        shutdown_requested = True

    signal.signal(signal.SIGINT, graceful_shutdown)
    signal.signal(signal.SIGTERM, graceful_shutdown)

    process_specs = [
        (
            "FastAPI Process",
            run_fastapi_server,
            (args.db_path, hmac_key, args.api_host, args.api_port),
        ),
        (
            "DNS Process",
            run_dns_server,
            (args.db_path, hmac_key, args.dns_port),
        ),
        (
            "Health Check Process",
            run_health_checker,
            (args.db_path, hmac_key, args.health_check_interval),
        ),
    ]
    try:
        for proc_name, target, proc_args in process_specs:
            proc = multiprocessing.Process(
                target=target, args=proc_args, name=proc_name
            )
            proc.start()
            # Track a process only once it has started: joining a process
            # that was never started raises.
            processes.append(proc)

        print("-" * 30)
        print(f"MiniDiscovery Started (PID: {os.getpid()})")
        print(f" API Server: http://{args.api_host}:{args.api_port}")
        print(f" DNS Server: Port {args.dns_port} (UDP/TCP)")
        print(f" Database: {args.db_path}")
        print(f" HMAC Key: {args.hmac_key_file}")
        print(f" Health Int: {args.health_check_interval}s")
        print("-" * 30)
        print("Press Ctrl+C to stop.")

        # Supervise until a signal arrives or a child dies with an error.
        while not shutdown_requested:
            for p in processes:
                if not p.is_alive() and p.exitcode != 0:
                    print(
                        f"ERROR: Process {p.name} (PID: {p.pid}) exited unexpectedly with code {p.exitcode}."
                    )
                    shutdown_requested = True  # A crashed child stops everything
                    break
            time.sleep(1)
    finally:
        print("Waiting for processes to join...")
        # 1. Grace period: children that received the signal themselves
        #    (e.g. Ctrl+C in a terminal) get a moment to exit on their own.
        grace_deadline = time.monotonic() + grace_period_seconds
        for p in processes:
            p.join(timeout=max(0, grace_deadline - time.monotonic()))
        # 2. Ask the remaining ones to stop. This also covers the case where
        #    shutdown was triggered by a crashed sibling, not by a signal.
        for p in processes:
            if p.is_alive():
                print(f"Terminating process {p.pid} ({p.name})...")
                p.terminate()  # Send SIGTERM
        # 3. Wait for a clean exit, force-kill stragglers, and reap them.
        for p in processes:
            p.join(timeout=join_timeout_seconds)
            if p.is_alive():
                print(f"Process {p.pid} ({p.name}) did not exit cleanly, killing.")
                p.kill()
                p.join()
        print("MiniDiscovery shut down complete.")
# Entry-point guard: run main() only when this file is executed as a script,
# not when it is imported. main() launches multiprocessing children, which may
# re-import this module depending on the start method, so the guard also keeps
# them from running main() again.
if __name__ == "__main__":
    main()