commit 26e9454ba9542998342c81410877c93489e3bbf1 Author: Kalzu Rekku Date: Sat May 3 09:59:47 2025 +0300 Initial commit. DNS queries are broken. But most of the stuff seems to work. diff --git a/MiniDiscovery.py b/MiniDiscovery.py new file mode 100644 index 0000000..ae6b54c --- /dev/null +++ b/MiniDiscovery.py @@ -0,0 +1,1350 @@ +import argparse +import asyncio +import base64 +import hashlib +import hmac +import json +import multiprocessing +import os +import secrets +import signal +import socket +import sqlite3 +import time +import uvicorn +import traceback +from contextlib import contextmanager +from typing import Dict, List, Optional, Tuple, Generator, Any + +from fastapi import Depends, FastAPI, Header, HTTPException, Security, status +from fastapi.security import APIKeyHeader +from fastapi.responses import HTMLResponse +from pydantic import BaseModel, Field +from twisted.internet import reactor +from twisted.names import dns, server, common +from twisted.internet.defer import Deferred +from twisted.python.failure import Failure + +# --- Constants --- +SQLITE_TIMEOUT_SECONDS = 10 # Increased timeout for potential concurrent writes +HMAC_KEY_FILE = "minidiscovery.key" +ADMIN_TOKEN_ENV_VAR = "MINIDISCOVERY_ADMIN_TOKEN" +DEFAULT_TOKEN_PERMISSIONS = ["read", "write"] +ADMIN_PERMISSIONS = ["read", "write", "admin"] +DNS_DEFAULT_TTL = 60 +DNS_QUERY_SUFFIX = ".laiska.local" # Define a suffix for DNS lookups +DB_PATH_ENV_VAR = "MINIDISCOVERY_DB_PATH" + +# --- Database Schema --- +SQL_CREATE_SERVICES = """ +CREATE TABLE IF NOT EXISTS services ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + address TEXT NOT NULL, + port INTEGER NOT NULL, + health TEXT DEFAULT 'passing', + tags_json TEXT DEFAULT '[]', -- Store tags as JSON text + metadata_json TEXT DEFAULT '{}', -- Store metadata as JSON text + last_update REAL NOT NULL -- Timestamp of last update/registration +); +""" +SQL_CREATE_TOKENS = """ +CREATE TABLE IF NOT EXISTS tokens ( + token_hash TEXT PRIMARY KEY, -- Store the HMAC hash of the token + name TEXT NOT NULL UNIQUE, -- Token names should be unique + created_at REAL NOT NULL, + permissions_json TEXT NOT NULL -- Store permissions as JSON text +); +""" +SQL_CREATE_META = """ +CREATE TABLE IF NOT EXISTS meta ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL +); +""" # For storing things like schema version or first-run status + + +# --- Pydantic Models --- +class ServiceInstance(BaseModel): + id: str = Field(..., description="Unique identifier for this service instance") + name: str = Field( + ..., description="Logical name of the service (e.g., 'web', 'db')" + ) + address: str = Field( + ..., description="IP address or hostname where the service listens" + ) + port: int = Field(..., gt=0, lt=65536, description="Port number for the service") + tags: List[str] = Field( + default_factory=list, description="Optional list of tags for filtering" + ) + health: str = Field( + default="passing", description="Health status ('passing', 'failing', 'unknown')" + ) + metadata: Dict[str, str] = Field( + default_factory=dict, description="Optional key-value metadata" + ) + + +class TokenInfo(BaseModel): + """Information about a token (excluding the hash)""" + + name: str + created_at: float + permissions: List[str] + + +class ApiTokenCreateRequest(BaseModel): + name: str = Field(..., description="A descriptive name for the token") + permissions: List[str] = Field( + default=DEFAULT_TOKEN_PERMISSIONS, description="Permissions for the token" + ) + + +class ApiTokenCreateResponse(BaseModel): + token: str = Field(..., description="The generated API token (show only once!)") + name: str + permissions: List[str] + + +# --- HMAC Key Management --- +def load_or_generate_hmac_key(key_file: str = HMAC_KEY_FILE) -> bytes: + """Loads HMAC key from file or generates a new one.""" + if os.path.exists(key_file): + try: + with open(key_file, "rb") as f: + key = f.read() + if len(key) < 32: # Basic sanity check + raise ValueError("HMAC key file seems corrupted or too short.") + print(f"Loaded HMAC key from {key_file}") + return key + except Exception as e: + print( + f"Error loading HMAC key from {key_file}: {e}. Check file permissions and content." + ) + raise SystemExit(1) + else: + print(f"HMAC key file ({key_file}) not found. Generating a new one.") + key = secrets.token_bytes(32) # 256 bits + try: + # Attempt to write with restricted permissions + fd = os.open(key_file, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600) + with os.fdopen(fd, "wb") as f: + f.write(key) + print(f"Generated and saved new HMAC key to {key_file}. PROTECT THIS FILE!") + return key + except FileExistsError: + # Race condition: another process created it between check and open + return load_or_generate_hmac_key(key_file) + except OSError as e: + print(f"Error writing HMAC key file {key_file}: {e}") + print("Please ensure the directory is writable by the process.") + raise SystemExit(1) + except Exception as e: + print(f"Unexpected error generating HMAC key: {e}") + raise SystemExit(1) + + +def generate_token_hash(token: str, hmac_key: bytes) -> str: + """Generates HMAC-SHA256 hash of the token.""" + return hmac.new(hmac_key, token.encode("utf-8"), hashlib.sha256).hexdigest() + + +# --- Service Registry (SQLite Backend) --- +class ServiceRegistry: + def __init__(self, db_path: str, hmac_key: bytes): + self.db_path = db_path + self._hmac_key = hmac_key + self._init_db() + + @contextmanager + def _get_db_conn(self) -> Generator[sqlite3.Connection, None, None]: + """Provides a context-managed database connection.""" + conn = None + try: + conn = sqlite3.connect(self.db_path, timeout=SQLITE_TIMEOUT_SECONDS) + conn.row_factory = sqlite3.Row # Access columns by name + conn.execute("PRAGMA foreign_keys = ON;") # Enforce foreign keys if needed + conn.execute( + "PRAGMA journal_mode = WAL;" + ) # Write-Ahead Logging for better concurrency + yield conn + conn.commit() + except sqlite3.Error as e: + print(f"SQLite error: {e} - {self.db_path}") + if conn: + conn.rollback() # Rollback on error + raise # Re-raise the exception + finally: + if conn: + conn.close() + + def _init_db(self) -> None: + """Initializes the database schema if it doesn't exist.""" + try: + with self._get_db_conn() as conn: + cursor = conn.cursor() + cursor.execute(SQL_CREATE_SERVICES) + cursor.execute(SQL_CREATE_TOKENS) + cursor.execute(SQL_CREATE_META) + + # Check if initial admin token needs to be created + cursor.execute("SELECT COUNT(*) FROM tokens") + token_count = cursor.fetchone()[0] + + cursor.execute("SELECT value FROM meta WHERE key = 'initialized'") + initialized = cursor.fetchone() + + if token_count == 0 and not initialized: + print("No tokens found in the database.") + admin_token_plain = os.environ.get(ADMIN_TOKEN_ENV_VAR) + if not admin_token_plain: + print( + f"ERROR: Database is empty and initial admin token not provided." + ) + print( + f"Please set the '{ADMIN_TOKEN_ENV_VAR}' environment variable with a secure token." + ) + raise SystemExit(1) # Critical configuration missing + + print( + f"Creating initial admin token from '{ADMIN_TOKEN_ENV_VAR}'..." + ) + token_hash = generate_token_hash(admin_token_plain, self._hmac_key) + cursor.execute( + "INSERT INTO tokens (token_hash, name, created_at, permissions_json) VALUES (?, ?, ?, ?)", + ( + token_hash, + "admin", + time.time(), + json.dumps(ADMIN_PERMISSIONS), + ), + ) + # Mark DB as initialized to prevent asking for env var again + cursor.execute( + "INSERT OR REPLACE INTO meta (key, value) VALUES ('initialized', 'true')" + ) + print( + "Initial admin token created successfully. You can unset the environment variable now." + ) + elif token_count > 0 and not initialized: + # Tokens exist, but not marked initialized (e.g., older version) - mark it now + cursor.execute( + "INSERT OR REPLACE INTO meta (key, value) VALUES ('initialized', 'true')" + ) + + # Future schema migrations could go here based on a version in 'meta' table + except sqlite3.OperationalError as e: + if "database is locked" in str(e): + print( + f"Warning: Database {self.db_path} is locked during initialization. Retrying shortly..." + ) + time.sleep( + 0.5 + ) # Short delay before potential retry by caller/process start + self._init_db() # Recursive call - be cautious, maybe add max retries + else: + print(f"Fatal error initializing database {self.db_path}: {e}") + raise SystemExit(1) + except Exception as e: + print(f"Fatal error during database initialization: {e}") + raise SystemExit(1) + + def register(self, instance: ServiceInstance) -> bool: + """Registers or updates a service instance in the database.""" + with self._get_db_conn() as conn: + cursor = conn.cursor() + try: + cursor.execute( + """ + INSERT OR REPLACE INTO services + (id, name, address, port, health, tags_json, metadata_json, last_update) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + instance.id, + instance.name, + instance.address, + instance.port, + instance.health, + json.dumps(instance.tags), + json.dumps(instance.metadata), + time.time(), + ), + ) + return True + except sqlite3.IntegrityError as e: + print(f"Error registering service {instance.id}: {e}") + return False # Should not happen with INSERT OR REPLACE unless DB issue + + def deregister(self, service_id: str) -> bool: + """Removes a service instance from the database.""" + with self._get_db_conn() as conn: + cursor = conn.cursor() + cursor.execute("DELETE FROM services WHERE id = ?", (service_id,)) + return cursor.rowcount > 0 # Returns true if a row was deleted + + def get_service( + self, name: str, only_passing: bool = False + ) -> List[ServiceInstance]: + """Retrieves all instances for a given service name.""" + instances = [] + sql = "SELECT * FROM services WHERE name = ?" + params: List[Any] = [name] + if only_passing: + sql += " AND health = 'passing'" + + with self._get_db_conn() as conn: + cursor = conn.cursor() + cursor.execute(sql, tuple(params)) + rows = cursor.fetchall() + for row in rows: + try: + instances.append( + ServiceInstance( + id=row["id"], + name=row["name"], + address=row["address"], + port=row["port"], + health=row["health"], + tags=json.loads(row["tags_json"]), + metadata=json.loads(row["metadata_json"]), + ) + ) + except (json.JSONDecodeError, TypeError, KeyError) as e: + print( + f"Warning: Could not parse service data for ID {row.get('id', 'N/A')}: {e}" + ) + # Optionally skip or mark as unhealthy + return instances + + def get_all_services(self) -> Dict[str, List[ServiceInstance]]: + """Retrieves all registered services, grouped by name.""" + services: Dict[str, List[ServiceInstance]] = {} + with self._get_db_conn() as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM services ORDER BY name") + rows = cursor.fetchall() + for row in rows: + try: + instance = ServiceInstance( + id=row["id"], + name=row["name"], + address=row["address"], + port=row["port"], + health=row["health"], + tags=json.loads(row["tags_json"]), + metadata=json.loads(row["metadata_json"]), + ) + if instance.name not in services: + services[instance.name] = [] + services[instance.name].append(instance) + except (json.JSONDecodeError, TypeError, KeyError) as e: + print( + f"Warning: Could not parse service data for ID {row.get('id', 'N/A')}: {e}" + ) + return services + + def update_health(self, service_id: str, health: str) -> bool: + """Updates the health status of a specific service instance.""" + with self._get_db_conn() as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE services SET health = ?, last_update = ? WHERE id = ?", + (health, time.time(), service_id), + ) + return cursor.rowcount > 0 + + def create_token( + self, name: str, permissions: List[str] + ) -> Tuple[Optional[str], Optional[str]]: + """Creates a new API token, storing its hash. Returns (token, None) on success, (None, error_message) on failure.""" + token_plain = secrets.token_urlsafe(32) + token_hash = generate_token_hash(token_plain, self._hmac_key) + + with self._get_db_conn() as conn: + cursor = conn.cursor() + try: + cursor.execute( + "INSERT INTO tokens (token_hash, name, created_at, permissions_json) VALUES (?, ?, ?, ?)", + (token_hash, name, time.time(), json.dumps(permissions)), + ) + return token_plain, None # Return the plain token only on success + except sqlite3.IntegrityError: + # This likely means the name is not unique + return None, f"Token name '{name}' already exists." + except Exception as e: + print(f"Error creating token '{name}': {e}") + return None, "Database error during token creation." + + def revoke_token(self, token_plain: str) -> bool: + """Revokes an API token by deleting its hash from the database.""" + token_hash = generate_token_hash(token_plain, self._hmac_key) + with self._get_db_conn() as conn: + cursor = conn.cursor() + cursor.execute("DELETE FROM tokens WHERE token_hash = ?", (token_hash,)) + return cursor.rowcount > 0 + + def validate_token(self, token_plain: str) -> Optional[TokenInfo]: + """ + Validates a plain text token against stored hashes and returns its info if valid. + Returns None if the token is invalid. + """ + token_hash = generate_token_hash(token_plain, self._hmac_key) + with self._get_db_conn() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT name, created_at, permissions_json FROM tokens WHERE token_hash = ?", + (token_hash,), + ) + row = cursor.fetchone() + if row: + try: + permissions = json.loads(row["permissions_json"]) + if not isinstance(permissions, list): + raise TypeError("Permissions format error") + return TokenInfo( + name=row["name"], + created_at=row["created_at"], + permissions=permissions, + ) + except (json.JSONDecodeError, TypeError) as e: + print( + f"Error decoding permissions for token hash {token_hash}: {e}" + ) + return None # Treat malformed permissions as invalid + else: + return None # Token hash not found + + +# --- API Key Security --- +api_key_header = APIKeyHeader(name="X-API-Token", auto_error=False) + + +def get_permission_checker(required_permission: str): + """ + Dependency factory to create a dependency that checks for a specific permission. + """ + + async def _check_permission( + registry: ServiceRegistry = Depends( + lambda: get_service_registry() + ), # Get registry instance + api_key: Optional[str] = Security(api_key_header), + ) -> str: # Return API key on success + if api_key is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="API token missing in X-API-Token header", + ) + + token_info = registry.validate_token(api_key) + + if token_info is None: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail="Invalid API token" + ) + + if required_permission not in token_info.permissions: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Insufficient permissions. Required: '{required_permission}'", + ) + + return api_key # Return the validated key (or token info if needed later) + + return _check_permission + + +# --- FastAPI Application Factory --- + +# Global variable to hold the registry instance *within this process* +# This is okay because FastAPI runs in a single process (or multiple workers, +# each getting its own instance, which is fine as they all talk to the DB). +_registry_instance: Optional[ServiceRegistry] = None + + +def setup_registry(db_path: str, hmac_key: bytes): + """Initializes the registry instance for the current process.""" + global _registry_instance + if _registry_instance is None: + print(f"[{os.getpid()}] Initializing ServiceRegistry for FastAPI...") + _registry_instance = ServiceRegistry(db_path, hmac_key) + else: + print(f"[{os.getpid()}] ServiceRegistry already initialized for FastAPI.") + + +def get_service_registry() -> ServiceRegistry: + """Dependency to get the ServiceRegistry instance.""" + if _registry_instance is None: + # This should not happen if setup_registry is called correctly before app starts + raise RuntimeError("ServiceRegistry not initialized for this process.") + return _registry_instance + + +def generate_status_page_html(registry: ServiceRegistry) -> str: + """Generates the HTML for the status page.""" + + all_services_map = registry.get_all_services() + nodes = ( + {} + ) # Dictionary to store node info: {address: {"status": "Active|Inactive", "label": "Node N"}} + + # 1. Identify unique "nodes" (addresses) and determine their status + unique_addresses = sorted( + list( + set( + inst.address + for instances in all_services_map.values() + for inst in instances + ) + ) + ) + + node_label_counter = 1 + for addr in unique_addresses: + is_active = False + for service_name, instances in all_services_map.items(): + for instance in instances: + if instance.address == addr and instance.health == "passing": + is_active = True + break # Found one passing service at this address, node is active + if is_active: + break + nodes[addr] = { + "status": "Active" if is_active else "Inactive", + "label": f"Node {node_label_counter}", + } + node_label_counter += 1 + + grid_size = len(nodes) + + # 2. Build the HTML string dynamically + # Use inline style for grid-size as it's dynamic + html_content = f""" + + + + + + + MiniDiscovery Status + + + +

Node Status

+ """ + + if grid_size == 0: + html_content += "
No active nodes detected.
" + else: + html_content += f"
" # Set CSS variable size + # Top-left corner + html_content += "
" + # Column headers + for addr in unique_addresses: + html_content += f"
{nodes[addr]['label']}
" + + # Grid rows + for row_addr in unique_addresses: + # Row label + html_content += f"
{nodes[row_addr]['label']}
" + # Cells in the row + for col_addr in unique_addresses: + if row_addr == col_addr: + # Diagonal: Show node's own status + status = nodes[row_addr]["status"] + status_class = f"status-{status.lower()}" + html_content += ( + f"
{status}
" + ) + else: + # Off-diagonal: Show 'Unknown' as we don't track inter-node connectivity + html_content += ( + "
Unknown
" + ) + # End of row implicitly handled by grid layout + html_content += "
" # End grid-container + + # Add footer/closing tags + html_content += """ +
Status based on registered service health. Page refreshes automatically.
+ + + """ + return html_content + + +def create_fastapi_app() -> FastAPI: + """Creates the FastAPI application.""" + app = FastAPI( + title="MiniDiscovery", + description="Minimal Service Discovery inspired by Consul", + version="0.2.0", + ) + + # --- Status Page Endpoint (Public) --- + @app.get("/status", response_class=HTMLResponse) + async def get_status_page( + registry: ServiceRegistry = Depends(get_service_registry), + ): + """Serves the anonymous HTML status page.""" + html_content = generate_status_page_html(registry) + return HTMLResponse(content=html_content, status_code=200) + + # --- API Endpoints (Authenticated) --- + @app.post("/v1/agent/service/register", status_code=status.HTTP_200_OK) + async def register_service( + instance: ServiceInstance, + registry: ServiceRegistry = Depends(get_service_registry), + token: str = Depends(get_permission_checker("write")), # Require 'write' + ): + """Registers a service instance. Supports replace-if-exists based on ID.""" + # Pydantic performs validation based on the model + if registry.register(instance): + # Return 200 OK on success + return {"status": "registered", "service_id": instance.id} + else: + # This path might be less likely with INSERT OR REPLACE + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to register service due to a database error.", + ) + + @app.put( + "/v1/agent/service/deregister/{service_id}", status_code=status.HTTP_200_OK + ) + async def deregister_service( + service_id: str, + registry: ServiceRegistry = Depends(get_service_registry), + token: str = Depends(get_permission_checker("write")), # Require 'write' + ): + """Deregisters a service instance by its ID.""" + if registry.deregister(service_id): + # Returns 200 OK on success + return {"status": "deregistered", "service_id": service_id} + else: + # If deregister returns false, it means the service wasn't found + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Service ID '{service_id}' not found.", + ) + + @app.get("/v1/catalog/services", response_model=Dict[str, List[str]]) + async def get_catalog_services( + registry: ServiceRegistry = Depends(get_service_registry), + token: str = Depends(get_permission_checker("read")), # Require 'read' + ): + """Lists all known service names and their tags.""" + services_data = registry.get_all_services() + # Format matches Consul /v1/catalog/services + catalog = { + name: list(set(tag for inst in instances for tag in inst.tags)) + for name, instances in services_data.items() + } + return catalog + + @app.get("/v1/catalog/service/{service_name}", response_model=List[ServiceInstance]) + async def get_catalog_service( + service_name: str, + tag: Optional[str] = None, # Allow filtering by tag + registry: ServiceRegistry = Depends(get_service_registry), + token: str = Depends(get_permission_checker("read")), # Require 'read' + ): + """Lists instances for a specific service, optionally filtered by tag.""" + instances = registry.get_service(service_name) + if not instances: + # Return empty list if service name doesn't exist, matches Consul + return [] + + if tag: + instances = [inst for inst in instances if tag in inst.tags] + + # Pydantic validation handles the response model structure + return instances + + @app.get("/v1/health/service/{service_name}", response_model=List[ServiceInstance]) + async def get_health_service( + service_name: str, + tag: Optional[str] = None, + passing: bool = False, # Filter for only passing instances + registry: ServiceRegistry = Depends(get_service_registry), + token: str = Depends(get_permission_checker("read")), # Require 'read' + ): + """Lists healthy instances for a specific service.""" + instances = registry.get_service(service_name, only_passing=passing) + if not instances: + return [] + + if tag: + instances = [inst for inst in instances if tag in inst.tags] + + return instances + + # --- Token Management Endpoints --- + @app.post( + "/v1/acl/token", + response_model=ApiTokenCreateResponse, + status_code=status.HTTP_201_CREATED, + ) + async def create_acl_token( + token_request: ApiTokenCreateRequest, + registry: ServiceRegistry = Depends(get_service_registry), + token: str = Depends(get_permission_checker("admin")), # Require 'admin' + ): + """Creates a new API token.""" + plain_token, error = registry.create_token( + token_request.name, token_request.permissions + ) + if error: + # Distinguish between user error (duplicate name) and server error + if "already exists" in error: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=error + ) + else: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=error + ) + + return ApiTokenCreateResponse( + token=plain_token, # Only show the token once upon creation! + name=token_request.name, + permissions=token_request.permissions, + ) + + @app.delete("/v1/acl/token/{token_to_revoke}", status_code=status.HTTP_200_OK) + async def revoke_acl_token( + token_to_revoke: str, # The actual token value to revoke + registry: ServiceRegistry = Depends(get_service_registry), + token: str = Depends(get_permission_checker("admin")), # Require 'admin' + ): + """Revokes an existing API token.""" + # Prevent revoking the token currently being used for the request? Maybe not necessary. + # current_token_info = registry.validate_token(token) + # if registry.generate_token_hash(token_to_revoke, registry._hmac_key) == registry.generate_token_hash(token, registry._hmac_key): + # raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot revoke the token used for this request.") + + if registry.revoke_token(token_to_revoke): + return { + "status": "revoked", + "token_info": "Token revoked successfully", + } # Don't echo back the token + else: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Token not found or already revoked.", + ) + + # --- DNS-over-HTTPS Endpoint (Simplified, follows RFC 8484 GET format) --- + @app.get("/dns-query", status_code=status.HTTP_200_OK) + async def dns_query( + name: str, + type: str = "A", # Query param 'type' + registry: ServiceRegistry = Depends(get_service_registry), + # No auth required for basic DNS queries for simplicity, add if needed + ): + """Handles DNS queries over HTTPS (DoH).""" + # Basic parsing, assuming name ends with configured suffix + if not name.lower().endswith(DNS_QUERY_SUFFIX): + return { + "Status": 2 + } # NXDOMAIN according to RFC8484 JSON format for DNS Messages + + service_name = name[: -len(DNS_QUERY_SUFFIX)].split(".")[ + -1 + ] # Get the last part before the suffix + + instances = registry.get_service(service_name, only_passing=True) + if not instances: + return {"Status": 2} # NXDOMAIN + + answers = [] + query_type_upper = type.upper() + + if query_type_upper == "A": + type_code = 1 # DNS Type A + for instance in instances: + answers.append( + { + "name": name, + "type": type_code, + "TTL": DNS_DEFAULT_TTL, + "data": instance.address, + } + ) + elif query_type_upper == "SRV": + type_code = 33 # DNS Type SRV + for instance in instances: + # Format: Priority Weight Port Target + srv_data = f"0 5 {instance.port} {instance.address}{DNS_QUERY_SUFFIX}" # Target should be resolvable ideally + answers.append( + { + "name": name, + "type": type_code, + "TTL": DNS_DEFAULT_TTL, + "data": srv_data, + } + ) + # Add more record types (AAAA, TXT for metadata?) if needed + else: + # Unsupported type for this simple endpoint + return {"Status": 4} # NotImp - Not Implemented + + # Construct response according to RFC 8484 JSON format + response = { + "Status": 0, # NOERROR + "TC": False, # Truncated + "RD": True, # Recursion Desired (client asked for it) - reflects typical client flag + "RA": False, # Recursion Available (we are authoritative-like) + "AD": False, # Authenticated Data + "CD": False, # Checking Disabled + "Question": [ + { + "name": name, + "type": type_code if "type_code" in locals() else 1, + } # Reflect question + ], + "Answer": answers, + # Authority and Additional sections omitted for simplicity + } + return response + + return app + + +# --- Twisted DNS Server --- +class MiniDiscoveryResolver(common.ResolverBase): + + def __init__(self, registry: ServiceRegistry): + self.registry = registry + super().__init__() + + def _lookup( + self, + name: bytes, + cls: int, + query_type: int, + timeout: Optional[Tuple[int]] = None, + ) -> Deferred: + d = Deferred() + name_str_debug = repr(name) + try: + # 1. Decode the incoming query name safely + try: + name_str = name.decode("utf-8").lower() + except UnicodeDecodeError as decode_err: + print( + f"DNS lookup error: Cannot decode query name {name_str_debug} as UTF-8: {decode_err}" + ) + d.errback(Failure(decode_err)) + return d + + # 2. Check suffix + if not name_str.endswith(DNS_QUERY_SUFFIX): + d.callback(([], [], [])) + return d + + # --- SRV Service Type Lookup Logic --- + is_srv_type_query = False + service_tag_to_find = None + if name_str.startswith("_") and ( + "_tcp." in name_str or "_udp." in name_str + ): + parts = name_str.split(".") + # Example: _ssh._tcp.laiska.local + if ( + len(parts) >= 4 + and parts[0].startswith("_") + and parts[1] in ["_tcp", "_udp"] + ): + # Assume format _service._proto.domain.suffix... + service_tag_to_find = parts[0][1:] # Extract 'ssh' from '_ssh' + # Ensure the query type is actually SRV + if query_type == dns.SRV: + is_srv_type_query = True + else: + # Requesting A/other records for SRV-style name doesn't make sense here + d.callback(([], [], [])) + return d + + # --- Instance Fetching --- + instances_to_process = [] + if is_srv_type_query: + print( + f"DNS: Performing SRV type lookup for tag '{service_tag_to_find}'" + ) + all_services = registry.get_all_services() + for service_name, service_instances in all_services.items(): + for instance in service_instances: + # Check health AND tag presence + if ( + instance.health == "passing" + and service_tag_to_find in instance.tags + ): + instances_to_process.append(instance) + if not instances_to_process: + print( + f"DNS: No passing instances found with tag '{service_tag_to_find}'" + ) + d.callback(([], [], [])) # No matching instances found + return d + else: + service_name = name_str[: -len(DNS_QUERY_SUFFIX)].split(".")[-1] + if not service_name: + d.callback(([], [], [])) + return d + instances_to_process = registry.get_service( + service_name, only_passing=True + ) + if not instances_to_process: + d.callback(([], [], [])) + return d + + # --- Build DNS Records --- + answers = [] + authority = [] + additional = [] + + # Check query_type against the records we can generate + if ( + query_type == dns.A and not is_srv_type_query + ): # Only generate A for direct name query + for instance in instances_to_process: + instance_addr_debug = instance.address + try: + ip_address_string = instance.address + answers.append( + dns.RRHeader( + name=name, + type=dns.A, + cls=cls, + ttl=DNS_DEFAULT_TTL, + # Pass the string, Twisted handles conversion + payload=dns.Record_A( + address=ip_address_string, ttl=DNS_DEFAULT_TTL + ), + ) + ) + except ( + Exception + ) as record_e: # Catch potential errors during record creation itself + print( + f"Warning: Error creating A record for IP '{instance.address}' (Service ID {instance.id}): {record_e}. Skipping." + ) + + elif ( + query_type == dns.SRV + ): # Generate SRV for direct name query OR SRV type query + for instance in instances_to_process: + try: + ip_address_string = instance.address + # SRV target should still be bytes + target_name = f"{instance.id}{DNS_QUERY_SUFFIX}".encode("utf-8") + answers.append( + dns.RRHeader( + name=name, + type=dns.SRV, + cls=cls, + ttl=DNS_DEFAULT_TTL, + payload=dns.Record_SRV( + priority=0, + weight=10, + port=instance.port, + target=target_name, + ttl=DNS_DEFAULT_TTL, + ), + ) + ) + additional.append( + dns.RRHeader( + name=target_name, + type=dns.A, + cls=cls, + ttl=DNS_DEFAULT_TTL, + payload=dns.Record_A( + address=ip_address_string, ttl=DNS_DEFAULT_TTL + ), + ) + ) + except ( + Exception + ) as record_e: # Catch potential errors during record creation itself + print( + f"Warning: Error creating SRV/additional record for IP '{instance.address}' (Service ID {instance.id}): {record_e}. Skipping." + ) + + # If we successfully built records (or correctly skipped bad IPs) + d.callback((answers, authority, additional)) + + # --- Catch ANY OTHER unexpected error during the lookup process --- + except Exception as e: + # Log the errors + print( + f"!!! Unhandled exception during DNS lookup for query name {name_str_debug} !!!" + ) + print(f"Exception Type: {type(e).__name__}") + print(f"Exception Args: {e.args}") + print("--- Traceback ---") + traceback.print_exc() + print("--- End Traceback ---") + d.errback(Failure(e)) # Signal SERVFAIL + + return d + + +# --- Health Checker --- +def check_service_health(instance: ServiceInstance) -> str: + """Performs a simple TCP connection check.""" + try: + with socket.create_connection( + (instance.address, instance.port), timeout=2 + ) as sock: + # Could add optional check for specific data/response here later + return "passing" + except (socket.timeout, ConnectionRefusedError, OSError): + return "failing" + except Exception as e: + print( + f"Unexpected error checking health for {instance.id} ({instance.address}:{instance.port}): {e}" + ) + return "failing" # Treat unexpected errors as failure + + +# --- Process Runner Functions --- +def run_fastapi_server(db_path: str, hmac_key: bytes, host: str, port: int): + """Sets up and runs the FastAPI server.""" + print(f"[{os.getpid()}] Starting FastAPI process...") + # Initialize registry for this specific process + setup_registry(db_path, hmac_key) + app = create_fastapi_app() + # Setup signal handlers for graceful shutdown within Uvicorn if possible + # Uvicorn handles SIGINT/SIGTERM by default + uvicorn.run(app, host=host, port=port) + print( + f"[{os.getpid()}] FastAPI process finished." + ) # Should not be reached normally + + +def run_dns_server(db_path: str, hmac_key: bytes, port: int): + """Sets up and runs the Twisted DNS server.""" + print(f"[{os.getpid()}] Starting DNS server process...") + registry = ServiceRegistry( + db_path, hmac_key + ) # DNS process needs its own registry instance + resolver = MiniDiscoveryResolver(registry) + factory = server.DNSServerFactory(clients=[resolver]) + protocol = dns.DNSDatagramProtocol(controller=factory) + + # Listen on UDP and TCP + try: + reactor.listenUDP(port, protocol) + reactor.listenTCP(port, factory) + print(f"[{os.getpid()}] DNS Server listening on port {port} (UDP/TCP)") + except Exception as e: + print(f"[{os.getpid()}] Error starting DNS listeners on port {port}: {e}") + return # Exit process if cannot bind + + # Graceful shutdown for Twisted reactor + def shutdown_dns(): + print(f"[{os.getpid()}] Shutting down DNS server...") + # reactor.stop() might be called by signal handlers below + # Add cleanup here if needed (e.g., closing listeners explicitly) + + reactor.addSystemEventTrigger("before", "shutdown", shutdown_dns) + + # Handle signals to stop the reactor gracefully + def signal_handler(signum, frame): + print(f"[{os.getpid()}] Received signal {signum}, stopping DNS reactor.") + # Important: Call reactor.stop() from the reactor thread if possible + # reactor.callFromThread(reactor.stop) is safer if signals handled elsewhere + # For simple cases, calling directly might be okay, but watch for deadlocks. + if reactor.running: + reactor.callLater( + 0, reactor.stop + ) # Schedule stop in the next loop iteration + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + # Run the Twisted reactor (blocking call) + reactor.run() + print(f"[{os.getpid()}] DNS server process finished.") + + +def run_health_checker(db_path: str, hmac_key: bytes, check_interval: int): + """Runs the health checking loop.""" + print( + f"[{os.getpid()}] Starting Health Checker process (interval: {check_interval}s)..." + ) + registry = ServiceRegistry( + db_path, hmac_key + ) # Health checker needs its own registry instance + + running = True + + def signal_handler(signum, frame): + nonlocal running + print(f"[{os.getpid()}] Received signal {signum}, stopping health checker...") + running = False + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + while running: + start_time = time.monotonic() + print(f"[{os.getpid()}] Health Check Cycle Start") + updated_count = 0 + checked_count = 0 + try: + # Fetch all services directly from DB for the check cycle + services_map = registry.get_all_services() + all_instances = [ + inst for sublist in services_map.values() for inst in sublist + ] + checked_count = len(all_instances) + + for instance in all_instances: + if not running: + break # Exit early if shutdown signal received + + current_health = instance.health + # Perform the actual health check + new_health = check_service_health(instance) + + if current_health != new_health: + print( + f"[{os.getpid()}] Health change for {instance.id} ({instance.name}): {current_health} -> {new_health}" + ) + # Update health status in the database + if registry.update_health(instance.id, new_health): + updated_count += 1 + else: + # This might happen if the service was deregistered between get and update + print( + f"[{os.getpid()}] Warning: Failed to update health for {instance.id} (maybe deregistered?)" + ) + + except Exception as e: + # Catch errors during the check cycle itself (e.g., DB connection) + print(f"[{os.getpid()}] Error during health check cycle: {e}") + + if not running: + break # Check again after the loop body + + # Calculate sleep time to maintain interval + elapsed = time.monotonic() - start_time + sleep_time = max(0, check_interval - elapsed) + print( + f"[{os.getpid()}] Health Check Cycle End. Checked: {checked_count}, Updated: {updated_count}. Took {elapsed:.2f}s. Sleeping for {sleep_time:.2f}s." + ) + + # Sleep interruptibly + try: + time.sleep(sleep_time) + except InterruptedError: # Catch if signal interrupted sleep + pass + + print(f"[{os.getpid()}] Health checker process finished.") + + +# --- Main Execution --- +def main(): + parser = argparse.ArgumentParser( + description="MiniDiscovery: A minimal service discovery tool.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, # Show defaults + ) + parser.add_argument( + "--db-path", + default="minidiscovery_data.db", + help="Path to the SQLite database file. Overridden by ${DB_PATH_ENV_VAR} if set.", + ) + parser.add_argument( + "--api-host", default="0.0.0.0", help="Host address for the API server." + ) + parser.add_argument( + "--api-port", type=int, default=8500, help="Port for the API server." + ) + parser.add_argument( + "--dns-port", + type=int, + default=10053, + help="Port for the DNS server (UDP/TCP). Use ports > 1024 for non-root.", + ) + parser.add_argument( + "--health-check-interval", + type=int, + default=15, + help="Health check interval in seconds.", + ) + parser.add_argument( + "--hmac-key-file", + default=HMAC_KEY_FILE, + help="Path to the file storing the HMAC secret key.", + ) + + args = parser.parse_args() + + # Read db_file_path from ENV var + db_path_from_env = os.environ.get(DB_PATH_ENV_VAR) + if db_path_from_env: + args.db_path = db_path_from_env # Override args + print( + f"Using database path from environment variable ${DB_PATH_ENV_VAR}: {args.db_path}" + ) + elif not args.db_path: # If env var not set AND command line arg not provided + args.db_path = "minidiscovery_data.db" # Apply the default now + print(f"Database path not specified, using default: {args.db_path}") + else: + # Using the path provided via --db-path argument + print(f"Using database path from --db-path argument: {args.db_path}") + + # Ensure DB directory exists + db_dir = os.path.dirname(os.path.abspath(args.db_path)) + if db_dir: + os.makedirs(db_dir, exist_ok=True) + + # Load or generate HMAC key *before* starting processes + try: + hmac_key = load_or_generate_hmac_key(args.hmac_key_file) + except SystemExit: + return # Exit if key generation/loading fails + + # Check for initial admin token env var *before* initializing registry in main process + # The check inside ServiceRegistry._init_db is the authoritative one, + # but this provides an earlier warning. + conn_check = sqlite3.connect(args.db_path) + cursor_check = conn_check.cursor() + try: + cursor_check.execute("SELECT COUNT(*) FROM tokens") + token_count_check = cursor_check.fetchone()[0] + if token_count_check == 0 and not os.environ.get(ADMIN_TOKEN_ENV_VAR): + print( + f"WARNING: Database appears empty. Ensure '{ADMIN_TOKEN_ENV_VAR}' is set for the first run." + ) + except sqlite3.DatabaseError: + # Table might not exist yet, ServiceRegistry init will handle it + pass + finally: + conn_check.close() + + # --- Process Management --- + processes: List[multiprocessing.Process] = [] + stop_event = multiprocessing.Event() # Used to signal shutdown + + def graceful_shutdown(signum, frame): + print(f"Main process received signal {signum}. Initiating shutdown...") + stop_event.set() # Signal processes to stop (though they also have signal handlers) + # Terminate processes forcefully after a grace period if they don't exit + time.sleep(2) # Give processes a moment to react to their own signal handlers + for p in processes: + if p.is_alive(): + print(f"Terminating process {p.pid} ({p.name})...") + p.terminate() # Send SIGTERM + + signal.signal(signal.SIGINT, graceful_shutdown) + signal.signal(signal.SIGTERM, graceful_shutdown) + + try: + # Start API Server Process + api_process = multiprocessing.Process( + target=run_fastapi_server, + args=(args.db_path, hmac_key, args.api_host, args.api_port), + name="FastAPI Process", + ) + processes.append(api_process) + api_process.start() + + # Start DNS Server Process + dns_process = multiprocessing.Process( + target=run_dns_server, + args=(args.db_path, hmac_key, args.dns_port), + name="DNS Process", + ) + processes.append(dns_process) + dns_process.start() + + # Start Health Checker Process + health_process = multiprocessing.Process( + target=run_health_checker, + args=(args.db_path, hmac_key, args.health_check_interval), + name="Health Check Process", + ) + processes.append(health_process) + health_process.start() + + print("-" * 30) + print(f"MiniDiscovery Started (PID: {os.getpid()})") + print(f" API Server: http://{args.api_host}:{args.api_port}") + print(f" DNS Server: Port {args.dns_port} (UDP/TCP)") + print(f" Database: {args.db_path}") + print(f" HMAC Key: {args.hmac_key_file}") + print(f" Health Int: {args.health_check_interval}s") + print("-" * 30) + print("Press Ctrl+C to stop.") + + # Wait for processes to complete or shutdown signal + while not stop_event.is_set(): + # Check if any process exited unexpectedly + for p in processes: + if not p.is_alive() and p.exitcode != 0: + print( + f"ERROR: Process {p.name} (PID: {p.pid}) exited unexpectedly with code {p.exitcode}." + ) + stop_event.set() # Trigger shutdown if a child crashes + break + time.sleep(1) # Wait efficiently + + finally: + print("Waiting for processes to join...") + for p in processes: + p.join(timeout=5) # Wait for clean exit + if p.is_alive(): + print(f"Process {p.pid} ({p.name}) did not exit cleanly, killing.") + p.kill() # Force kill if still alive after timeout + + print("MiniDiscovery shut down complete.") + + +if __name__ == "__main__": + main() diff --git a/Pipfile b/Pipfile new file mode 100644 index 0000000..11fa86f --- /dev/null +++ b/Pipfile @@ -0,0 +1,19 @@ +[[source]] +url = "https://pypi.org/simple" +verify_ssl = true +name = "pypi" + +[packages] +fastapi = "*" +twisted = "*" +pytest = "*" +pydantic = "*" +uvicorn = "*" +black = "*" +pytest-mock = "*" +requests = "*" + +[dev-packages] + +[requires] +python_version = "3.13" diff --git a/demo_client.py b/demo_client.py new file mode 100644 index 0000000..da0403d --- /dev/null +++ b/demo_client.py @@ -0,0 +1,213 @@ +import requests +import json +import uuid +import argparse +import sys + +# --- Configuration --- +DEFAULT_BASE_URL = "http://localhost:8500" # Adjust if running elsewhere + +# --- Helper Functions --- + + +def make_request( + method: str, + endpoint: str, + api_token: str, + base_url: str = DEFAULT_BASE_URL, + params: dict = None, + json_data: dict = None, +): + """Helper function to make authenticated requests.""" + headers = { + "X-API-Token": api_token, + "Accept": "application/json", # Indicate we prefer JSON responses + } + if json_data: + headers["Content-Type"] = "application/json" + + url = f"{base_url}{endpoint}" + + try: + response = requests.request( + method, + url, + headers=headers, + params=params, + json=json_data, + timeout=10, # Add a timeout + ) + response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx) + # Try to parse JSON, but handle cases where response might be empty (like 200 OK on deregister) + try: + return response.json() + except json.JSONDecodeError: + return { + "status_code": response.status_code, + "content": response.text, + } # Return status/text if no JSON body + except requests.exceptions.RequestException as e: + print(f"Error making request to {method} {url}: {e}", file=sys.stderr) + return None + except Exception as e: + print(f"An unexpected error occurred: {e}", file=sys.stderr) + return None + + +def register_service( + api_token: str, + service_name: str, + address: str, + port: int, + tags: list = None, + metadata: dict = None, + service_id: str = None, +): + """Registers a service instance.""" + if service_id is None: + service_id = f"{service_name}-{uuid.uuid4()}" # Generate a unique ID + + instance_data = { + "id": service_id, + "name": service_name, + "address": address, + "port": port, + "tags": tags if tags is not None else [], + "metadata": metadata if metadata is not None else {}, + # Health defaults to 'passing' on the server side if not provided + } + + print(f"[*] Registering service: {json.dumps(instance_data)}") + result = make_request( + "POST", + "/v1/agent/service/register", + api_token=api_token, + json_data=instance_data, + ) + if result: + print(f"[+] Registration successful: {result}") + return service_id # Return the ID used + else: + print(f"[-] Registration failed.") + return None + + +def deregister_service(api_token: str, service_id: str): + """Deregisters a service instance.""" + print(f"[*] Deregistering service ID: {service_id}") + endpoint = f"/v1/agent/service/deregister/{service_id}" + result = make_request( + "PUT", endpoint, api_token=api_token # Note: Consul uses PUT for deregister + ) + if result: + # Check status code as PUT might return 200 OK with no body on success + if isinstance(result, dict) and result.get("status_code") == 200: + print(f"[+] Deregistration successful (Status Code 200).") + return True + else: + print(f"[+] Deregistration response: {result}") + return True # Assume success if no exception raised + else: + print(f"[-] Deregistration failed.") + return False + + +def get_service_instances( + api_token: str, service_name: str, passing_only: bool = False +): + """Gets instances for a specific service.""" + print(f"[*] Querying service '{service_name}' (Passing only: {passing_only})") + if passing_only: + endpoint = f"/v1/health/service/{service_name}" + params = {"passing": "true"} + else: + endpoint = f"/v1/catalog/service/{service_name}" + params = None + + result = make_request("GET", endpoint, api_token=api_token, params=params) + + if result is not None: # Result could be an empty list [] which is valid + print(f"[+] Found {len(result)} instance(s):") + print(json.dumps(result, indent=2)) + return result + else: + print(f"[-] Failed to query service '{service_name}'.") + return None + + +def list_all_services(api_token: str): + """Lists all known service names.""" + print("[*] Listing all known service names...") + endpoint = "/v1/catalog/services" + result = make_request("GET", endpoint, api_token=api_token) + if result: + print("[+] Known services:") + print(json.dumps(result, indent=2)) + return result + else: + print("[-] Failed to list services.") + return None + + +# --- Main Execution --- +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="MiniDiscovery Client Example") + parser.add_argument( + "-t", + "--token", + required=True, + help="API Token for MiniDiscovery authentication", + ) + parser.add_argument( + "--base-url", + default=DEFAULT_BASE_URL, + help=f"Base URL of the MiniDiscovery API (default: {DEFAULT_BASE_URL})", + ) + args = parser.parse_args() + + print(f"--- MiniDiscovery Client Demo (Using API at {args.base_url}) ---") + + # 1. Register a couple of services + print("\n--- Step 1: Register Services ---") + web_id_1 = register_service( + args.token, + "web", + "192.168.1.10", + 8080, + tags=["frontend", "prod"], + metadata={"version": "1.2"}, + ) + web_id_2 = register_service( + args.token, + "web", + "192.168.1.11", + 8080, + tags=["frontend", "prod"], + metadata={"version": "1.2"}, + ) + db_id_1 = register_service( + args.token, "database", "10.0.0.5", 5432, tags=["backend", "prod-db"] + ) + + # 2. List all service names + print("\n--- Step 2: List All Service Names ---") + list_all_services(args.token) + + # 3. Query specific services + print("\n--- Step 3: Query Services ---") + get_service_instances(args.token, "web") + get_service_instances( + args.token, "database", passing_only=True + ) # Assume health check runs + get_service_instances(args.token, "nonexistent-service") + + # 4. Deregister one instance + print("\n--- Step 4: Deregister an Instance ---") + if web_id_1: + deregister_service(args.token, web_id_1) + + # 5. Query again to see the change + print("\n--- Step 5: Query 'web' Service Again ---") + get_service_instances(args.token, "web") + + print("\n--- Demo Complete ---") diff --git a/scripts/README.md b/scripts/README.md new file mode 100644 index 0000000..be5d14c --- /dev/null +++ b/scripts/README.md @@ -0,0 +1,71 @@ +# How to Use These Scripts + +1. Register SSH Server + +The first script registers an SSH server with the MiniDiscovery service: + +bash + +chmod +x register_ssh_server.sh +./register_ssh_server.sh -t YOUR_API_TOKEN + +Options: + + -t, --token TOKEN - API token for authentication (required) + -u, --url URL - Base URL of the API (default: http://localhost:8500) + -n, --name NAME - Server name (default: ssh-server) + -a, --address ADDRESS - Server address (default: 127.0.0.1) + -p, --port PORT - SSH port (default: 22) + --tags TAGS - Comma-separated tags (default: ssh,secure) + +Example with custom values: + +bash + +./register_ssh_server.sh -t YOUR_API_TOKEN -n prod-ssh -a 192.168.1.50 -p 2222 --tags "ssh,secure,production" + +2. List Services + +The second script lists all services or instances of a specific service: + +bash + +chmod +x list_services.sh +./list_services.sh -t YOUR_API_TOKEN + +Options: + + -t, --token TOKEN - API token for authentication (required) + -u, --url URL - Base URL of the API + -s, --service NAME - List instances of a specific service + -p, --passing - Only show instances passing health checks + +Example to list specific service: + +bash + +./list_services.sh -t YOUR_API_TOKEN -s ssh-server + +3. Deregister Service + +The third script removes a service from the discovery service: + +bash + +chmod +x deregister_service.sh +./deregister_service.sh -t YOUR_API_TOKEN -i SERVICE_ID + +Options: + + -t, --token TOKEN - API token for authentication (required) + -u, --url URL - Base URL of the API + -i, --id SERVICE_ID - Service ID to deregister + --last - Deregister the last registered service (uses stored ID) + +Example using the last registered service: + +bash + +./deregister_service.sh -t YOUR_API_TOKEN --last + +Note: The registration script saves the service ID to a file called .last_registered_service_id, which the deregistration script can use with the --last option. diff --git a/scripts/deregister_service.sh b/scripts/deregister_service.sh new file mode 100644 index 0000000..281c104 --- /dev/null +++ b/scripts/deregister_service.sh @@ -0,0 +1,88 @@ +#!/bin/bash +# Script to deregister a service from MiniDiscovery service + +# Default values +BASE_URL="http://localhost:8500" +SERVICE_ID="" + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + -t|--token) + API_TOKEN="$2" + shift 2 + ;; + -u|--url) + BASE_URL="$2" + shift 2 + ;; + -i|--id) + SERVICE_ID="$2" + shift 2 + ;; + --last) + # Try to read the service ID from the file created during registration + if [ -f .last_registered_service_id ]; then + SERVICE_ID=$(cat .last_registered_service_id) + else + echo "Error: No previously registered service ID found." + exit 1 + fi + shift + ;; + -h|--help) + echo "Usage: $0 -t API_TOKEN (-i SERVICE_ID | --last) [options]" + echo "Options:" + echo " -t, --token TOKEN API token for authentication (required)" + echo " -u, --url URL Base URL of the API (default: $BASE_URL)" + echo " -i, --id SERVICE_ID Service ID to deregister (required unless --last is used)" + echo " --last Deregister the last registered service" + echo " -h, --help Show this help message" + exit 0 + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + +# Check if API token is provided +if [ -z "$API_TOKEN" ]; then + echo "Error: API token is required" + echo "Use -t or --token to provide the API token" + exit 1 +fi + +# Check if service ID is provided +if [ -z "$SERVICE_ID" ]; then + echo "Error: Service ID is required" + echo "Use -i or --id to provide the service ID, or --last to use the last registered service" + exit 1 +fi + +# Make API request to deregister the service +echo "Deregistering service ID: $SERVICE_ID" + +RESPONSE=$(curl -s -X PUT \ + -H "X-API-Token: $API_TOKEN" \ + -H "Accept: application/json" \ + "${BASE_URL}/v1/agent/service/deregister/${SERVICE_ID}") + +# Check if the request was successful +HTTP_CODE=$(curl -s -o /dev/null -w "%{http_code}" -X PUT \ + -H "X-API-Token: $API_TOKEN" \ + -H "Accept: application/json" \ + "${BASE_URL}/v1/agent/service/deregister/${SERVICE_ID}") + +if [ "$HTTP_CODE" -eq 200 ]; then + echo "Deregistration successful!" + # Remove the stored service ID if --last was used + if [ -f .last_registered_service_id ] && [ "$(cat .last_registered_service_id)" = "$SERVICE_ID" ]; then + rm .last_registered_service_id + fi +else + echo "Deregistration failed. HTTP code: $HTTP_CODE" + echo "Response: $RESPONSE" + exit 1 +fi diff --git a/scripts/list_services.sh b/scripts/list_services.sh new file mode 100755 index 0000000..5cd83ef --- /dev/null +++ b/scripts/list_services.sh @@ -0,0 +1,91 @@ +#!/bin/bash +# Script to list all services from MiniDiscovery service + +# Default values +BASE_URL="http://localhost:8500" +SERVICE_NAME="" +PASSING_ONLY=false + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + -t|--token) + API_TOKEN="$2" + shift 2 + ;; + -u|--url) + BASE_URL="$2" + shift 2 + ;; + -s|--service) + SERVICE_NAME="$2" + shift 2 + ;; + -p|--passing) + PASSING_ONLY=true + shift + ;; + -h|--help) + echo "Usage: $0 -t API_TOKEN [options]" + echo "Options:" + echo " -t, --token TOKEN API token for authentication (required)" + echo " -u, --url URL Base URL of the API (default: $BASE_URL)" + echo " -s, --service NAME List instances of a specific service" + echo " -p, --passing Only show instances passing health checks" + echo " -h, --help Show this help message" + exit 0 + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + +# Check if API token is provided +if [ -z "$API_TOKEN" ]; then + echo "Error: API token is required" + echo "Use -t or --token to provide the API token" + exit 1 +fi + +# Function to format JSON output +format_json() { + # Check if jq is available + if command -v jq &> /dev/null; then + echo "$1" | jq . + else + echo "$1" | python3 -m json.tool 2>/dev/null || echo "$1" + fi +} + +# If service name is provided, list instances of that service +if [ -n "$SERVICE_NAME" ]; then + echo "Listing instances of service: $SERVICE_NAME" + + if [ "$PASSING_ONLY" = true ]; then + echo "Showing only instances passing health checks..." + ENDPOINT="/v1/health/service/${SERVICE_NAME}?passing=true" + else + ENDPOINT="/v1/catalog/service/${SERVICE_NAME}" + fi + + RESPONSE=$(curl -s -X GET \ + -H "X-API-Token: $API_TOKEN" \ + -H "Accept: application/json" \ + "${BASE_URL}${ENDPOINT}") + + echo "Found $(echo "$RESPONSE" | grep -o '"ID"' | wc -l) instance(s):" + format_json "$RESPONSE" +else + # List all services + echo "Listing all services..." + + RESPONSE=$(curl -s -X GET \ + -H "X-API-Token: $API_TOKEN" \ + -H "Accept: application/json" \ + "${BASE_URL}/v1/catalog/services") + + echo "Services found:" + format_json "$RESPONSE" +fi diff --git a/scripts/register_ssh_server.sh b/scripts/register_ssh_server.sh new file mode 100755 index 0000000..f15f455 --- /dev/null +++ b/scripts/register_ssh_server.sh @@ -0,0 +1,110 @@ +#!/bin/bash +# Script to register an SSH server with MiniDiscovery service + +# Default values +BASE_URL="http://localhost:8500" +SERVER_NAME="ssh-server" +SERVER_ADDRESS="127.0.0.1" +SERVER_PORT=22 +TAGS="ssh,secure" + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + -t|--token) + API_TOKEN="$2" + shift 2 + ;; + -u|--url) + BASE_URL="$2" + shift 2 + ;; + -n|--name) + SERVER_NAME="$2" + shift 2 + ;; + -a|--address) + SERVER_ADDRESS="$2" + shift 2 + ;; + -p|--port) + SERVER_PORT="$2" + shift 2 + ;; + --tags) + TAGS="$2" + shift 2 + ;; + -h|--help) + echo "Usage: $0 -t API_TOKEN [options]" + echo "Options:" + echo " -t, --token TOKEN API token for authentication (required)" + echo " -u, --url URL Base URL of the API (default: $BASE_URL)" + echo " -n, --name NAME Server name (default: $SERVER_NAME)" + echo " -a, --address ADDRESS Server address (default: $SERVER_ADDRESS)" + echo " -p, --port PORT SSH port (default: $SERVER_PORT)" + echo " --tags TAGS Comma-separated tags (default: $TAGS)" + echo " -h, --help Show this help message" + exit 0 + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + +# Check if API token is provided +if [ -z "$API_TOKEN" ]; then + echo "Error: API token is required" + echo "Use -t or --token to provide the API token" + exit 1 +fi + +# Generate a unique service ID +SERVICE_ID="${SERVER_NAME}-$(uuidgen || cat /proc/sys/kernel/random/uuid)" + +# Convert comma-separated tags to JSON array +IFS=',' read -ra TAG_ARRAY <<< "$TAGS" +TAGS_JSON=$(printf '"%s",' "${TAG_ARRAY[@]}" | sed 's/,$//') +TAGS_JSON="[$TAGS_JSON]" + +# Prepare JSON payload +JSON_PAYLOAD=$(cat < .last_registered_service_id +else + echo "Registration failed." + echo "Response: $RESPONSE" + exit 1 +fi