"""
This module mirrors SQLite operations across multiple USB devices.

It manages device connections and disconnections, queues writes so they are
applied to every connected mirror, and performs periodic integrity checks.
"""

import json
import logging
import os
import queue
import sqlite3
import threading
import time
from contextlib import contextmanager
from datetime import datetime
from typing import Dict, Tuple, List, Optional

try:
    import pyudev
except ImportError:
    # pyudev is only needed for USB hot-plug monitoring (Linux/udev). Keep the
    # mirroring logic importable without it; monitor_usb() reports the missing
    # dependency when it is actually used.
    pyudev = None

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


class SQLiteMirrorProxy:
    """
    A class to manage multiple SQLite database connections across USB devices.

    This class provides functionality to mirror SQLite operations across
    multiple connected USB storage devices, handle device
    connections/disconnections, and manage data integrity and write operations.
    Writes are queued by ``execute`` and applied by a background writer thread
    started with ``start_monitoring``.
    """

    # Seconds to block on a udev poll or a queue read before re-checking
    # stop_event; bounds how long shutdown() waits for the worker threads.
    _POLL_INTERVAL = 1.0
    # Seconds between maintenance passes (connection + integrity checks).
    _MAINTENANCE_INTERVAL = 3600

    def __init__(
        self,
        db_name: str = "data.sqlite",
        mount_path: str = "/mnt",
        max_retries: int = 3,
    ) -> None:
        """
        Initialize the SQLiteMirrorProxy.

        :param db_name: Name of the SQLite database file on each USB device
        :param mount_path: Root path where USB devices are mounted; a device is
            looked up at ``<mount_path>/<ID_SERIAL>/<db_name>``
        :param max_retries: Maximum number of times a write that succeeded on
            no database is re-queued before it is recorded as failed
        """
        self.db_name = db_name
        self.mount_path = mount_path
        # Database file path -> open connection. Guarded by self.lock.
        self.connections: Dict[str, sqlite3.Connection] = {}
        # Reentrant: methods holding the lock call remove_storage/add_storage,
        # which acquire it again.
        self.lock = threading.RLock()
        # Items are (query, params, retry_count).
        self.write_queue: queue.Queue = queue.Queue()
        self.stop_event = threading.Event()
        # Items are (query, params, failed_db_paths) for writes that exhausted
        # their retries.
        self.failed_writes: queue.Queue = queue.Queue()
        self.max_retries = max_retries
        # Worker threads, created by start_monitoring(); None until then so
        # shutdown() is safe to call without having started monitoring.
        self.monitor_thread: Optional[threading.Thread] = None
        self.write_thread: Optional[threading.Thread] = None
        self.maintenance_thread: Optional[threading.Thread] = None

    def add_storage(self, db_path: str) -> None:
        """
        Open a connection to the database at ``db_path`` and start mirroring to it.

        Does nothing (besides logging) if the file does not exist or the path
        is already connected, so repeated "add" events do not leak connections.
        """
        if not os.path.exists(db_path):
            logging.warning("Database %s does not exist; not connecting.", db_path)
            return
        with self.lock:
            if db_path in self.connections:
                logging.info("Already connected to %s", db_path)
                return
            try:
                conn = sqlite3.connect(db_path, check_same_thread=False)
            except sqlite3.Error as e:
                logging.error("Failed to connect to %s: %s", db_path, e)
                return
            self.connections[db_path] = conn
        logging.info("Connected to %s", db_path)
        self._log_event(db_path, "connection_added")
        self._perform_integrity_check(db_path)

    def remove_storage(self, db_path: str) -> None:
        """Close and forget the connection for ``db_path``, if any."""
        with self.lock:
            conn = self.connections.pop(db_path, None)
            if conn is None:
                logging.warning("Connection %s not found.", db_path)
                return
            conn.close()
        logging.info("Disconnected from %s", db_path)
        self._log_event(db_path, "connection_removed")

    def execute(self, query: str, params: Tuple = ()) -> None:
        """
        Queue a query for execution on all connected databases.

        The query is applied asynchronously by the writer thread started in
        ``start_monitoring``; this method returns immediately.

        :param query: SQL query to execute
        :param params: Parameters for the SQL query
        """
        self.write_queue.put((query, params, 0))  # 0 is the initial retry count

    def _process_write_queue(self) -> None:
        """Writer-thread loop: apply queued writes until stop_event is set."""
        while not self.stop_event.is_set():
            try:
                query, params, retry_count = self.write_queue.get(
                    timeout=self._POLL_INTERVAL
                )
            except queue.Empty:
                continue
            try:
                self._handle_write(query, params, retry_count)
            finally:
                # Balance every get() so write_queue.join() works for callers.
                self.write_queue.task_done()

    def _handle_write(self, query: str, params: Tuple, retry_count: int) -> None:
        """
        Apply one queued write, re-queueing or recording it on total failure.

        A write counts as failed only when it succeeded on no database
        (including when no database is connected). Databases that failed
        individually are disconnected by ``_execute_on_all``.
        """
        success, failures = self._execute_on_all(query, params)
        if success:
            return
        if retry_count < self.max_retries:
            self.write_queue.put((query, params, retry_count + 1))
            return
        logging.error(
            "Write operation failed after %d attempts: %s", retry_count + 1, query
        )
        # Record it so get_failed_writes_count() reflects reality.
        self.failed_writes.put((query, params, failures))
        self._log_failed_write(query, params, failures)

    def _execute_on_all(self, query: str, params: Tuple) -> Tuple[bool, List[str]]:
        """
        Execute a query, each in its own transaction, on every connected database.

        Databases on which the query raises ``sqlite3.Error`` are disconnected.

        :return: ``(success, failures)`` where ``success`` is True if the query
            committed on at least one database and ``failures`` lists the paths
            on which it failed.
        """
        failures: List[str] = []
        success = False
        with self.lock:
            for db_path, conn in list(self.connections.items()):
                try:
                    with self._transaction(conn):
                        conn.execute(query, params)
                except sqlite3.Error as e:
                    # NOTE(review): this also disconnects a device when the
                    # query itself is invalid (e.g. "no such table").
                    logging.error("Failed to write to %s: %s", db_path, e)
                    failures.append(db_path)
                    self.remove_storage(db_path)
                else:
                    success = True
                    self._log_event(db_path, "write_success", {"query": query})
        if failures:
            logging.error("Write failures on: %s", failures)
        return success, failures

    @contextmanager
    def _transaction(self, conn: sqlite3.Connection):
        """Commit on success; roll back and re-raise on ``sqlite3.Error``."""
        try:
            yield
            conn.commit()
        except sqlite3.Error as e:
            try:
                conn.rollback()
            except sqlite3.Error as rollback_error:
                # Don't let a failed rollback (e.g. device gone) hide the
                # original error.
                logging.error("Rollback failed: %s", rollback_error)
            logging.error("Transaction failed: %s", e)
            raise

    def _log_event(
        self, db_path: str, event_type: str, details: Optional[Dict] = None
    ) -> None:
        """
        Append a JSON event record to ``<db_path>.log``.

        If the log cannot be written (e.g. the device was unplugged and its
        mount directory is gone), a warning is logged and the error is not
        propagated, so event logging never aborts the operation being logged.

        :param db_path: Path to the database file
        :param event_type: Type of event being logged
        :param details: Additional details about the event
        """
        log_path = f"{db_path}.log"
        event = {
            "timestamp": datetime.now().isoformat(),
            "event_type": event_type,
            "details": details,
        }
        try:
            with open(log_path, "a", encoding="utf-8") as log_file:
                json.dump(event, log_file)
                log_file.write("\n")
        except OSError as e:
            logging.warning("Could not write event log %s: %s", log_path, e)

    def _perform_integrity_check(self, db_path: str) -> None:
        """
        Run ``PRAGMA integrity_check`` on ``db_path`` and log the first result row.

        Holds the lock so the shared connection is not used concurrently by
        the writer thread; this blocks writes for the duration of the check.
        """
        with self.lock:
            conn = self.connections.get(db_path)
            if conn is None:
                return
            try:
                result = conn.execute("PRAGMA integrity_check").fetchone()[0]
            except sqlite3.Error as e:
                logging.error("Error during integrity check for %s: %s", db_path, e)
                self._log_event(db_path, "integrity_check_error", {"error": str(e)})
                return
        self._log_event(db_path, "integrity_check", {"result": result})
        if result != "ok":
            logging.warning("Integrity check failed for %s: %s", db_path, result)

    def _log_failed_write(self, query: str, params: Tuple, failures: List[str]) -> None:
        """
        Log information about failed write operations.

        :param query: SQL query that failed
        :param params: Parameters for the failed query
        :param failures: List of database paths where the write failed
        """
        for db_path in failures:
            self._log_event(
                db_path,
                "write_failure",
                {
                    "query": query,
                    "params": str(params),
                    "reason": "Max retries reached",
                },
            )

    def _db_path_for(self, serial: str) -> str:
        """Return the database path for the device with udev ``ID_SERIAL`` ``serial``."""
        return os.path.join(self.mount_path, serial, self.db_name)

    def monitor_usb(self) -> None:
        """
        Watch udev block-device events and add/remove databases accordingly.

        Polls with a timeout so the loop notices ``stop_event`` promptly and
        ``shutdown`` can join this thread.

        :raises ImportError: if pyudev is not installed
        """
        if pyudev is None:
            raise ImportError("pyudev is required for USB monitoring")
        context = pyudev.Context()
        monitor = pyudev.Monitor.from_netlink(context)
        monitor.filter_by("block")
        while not self.stop_event.is_set():
            # poll() starts the monitor implicitly and returns None on timeout.
            device = monitor.poll(timeout=self._POLL_INTERVAL)
            if device is None:
                continue
            serial = device.get("ID_SERIAL")
            if serial is None:
                # Without a serial there is no mount path to derive.
                continue
            if device.action == "add" and device.get("ID_FS_TYPE") == "vfat":
                logging.info("USB inserted: %s", serial)
                self.add_storage(self._db_path_for(serial))
            elif device.action == "remove":
                logging.info("USB removed: %s", serial)
                self.remove_storage(self._db_path_for(serial))

    def check_connections(self) -> None:
        """Check all database connections and attempt to reconnect if necessary."""
        with self.lock:
            for db_path, conn in list(self.connections.items()):
                try:
                    conn.execute("SELECT 1")
                except sqlite3.Error:
                    logging.warning(
                        "Connection to %s lost. Attempting to reconnect...", db_path
                    )
                    self.remove_storage(db_path)
                    self.add_storage(db_path)

    def start_monitoring(self) -> None:
        """Start the USB monitoring, write processing, and maintenance threads."""
        self.monitor_thread = threading.Thread(target=self.monitor_usb, daemon=True)
        self.monitor_thread.start()
        self.write_thread = threading.Thread(
            target=self._process_write_queue, daemon=True
        )
        self.write_thread.start()
        self.maintenance_thread = threading.Thread(
            target=self._maintenance_loop, daemon=True
        )
        self.maintenance_thread.start()

    def _maintenance_loop(self) -> None:
        """Periodically check connections and run integrity checks until stopped."""
        while not self.stop_event.is_set():
            self.check_connections()
            # Snapshot under the lock: other threads add/remove entries.
            with self.lock:
                db_paths = list(self.connections)
            for db_path in db_paths:
                if self.stop_event.is_set():
                    break
                self._perform_integrity_check(db_path)
            # Interruptible wait so shutdown() does not block for an hour.
            self.stop_event.wait(self._MAINTENANCE_INTERVAL)

    def shutdown(self) -> None:
        """Signal the worker threads to stop, wait for them, and close all connections."""
        logging.info("Shutting down...")
        self.stop_event.set()
        for thread in (self.monitor_thread, self.write_thread, self.maintenance_thread):
            if thread is not None:
                thread.join()
        self.close()

    def close(self) -> None:
        """Close all active database connections and forget them."""
        with self.lock:
            for db_path, conn in self.connections.items():
                conn.close()
                logging.info("Closed connection to %s", db_path)
            self.connections.clear()

    def get_connection_status(self) -> Dict[str, str]:
        """
        Get the current status of all database connections.

        :return: Dictionary mapping database paths to their connection status
        """
        with self.lock:
            return dict.fromkeys(self.connections, "Connected")

    def get_failed_writes_count(self) -> int:
        """Get the count of write operations that exhausted their retries."""
        return self.failed_writes.qsize()