import asyncio
import logging
import os
import queue
import sqlite3
import threading
from contextlib import contextmanager, suppress

try:
    import pyudev
except ImportError:  # pyudev wraps libudev (Linux); keep the DB logic usable without it.
    pyudev = None

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


class SQLiteMirrorProxy:
    """Mirror SQLite writes onto a database file on every attached USB device.

    Writes submitted through :meth:`execute` are queued and applied by a
    background thread to every currently open connection, each in its own
    transaction. A second background thread listens for udev block-device
    events and opens/closes connections as devices are added and removed.
    """

    def __init__(self, db_name='data.sqlite', mount_path='/mnt'):
        """
        Initializes the SQLiteMirrorProxy.

        :param db_name: The name of the database file on each USB device.
        :param mount_path: The root path where USB devices are mounted.
        """
        self.db_name = db_name
        self.mount_path = mount_path
        # db_path -> sqlite3.Connection; every access is guarded by self.lock.
        self.connections = {}
        # Re-entrant so check_connections() can call add/remove_storage while holding it.
        self.lock = threading.RLock()
        # (query, params) tuples waiting to be applied by the write thread.
        self.write_queue = queue.Queue()
        # Tells the background threads to stop.
        self.stop_event = threading.Event()

    def add_storage(self, db_path):
        """
        Adds a new SQLite connection to the database on the specified path.

        Logs a warning and does nothing if the file does not exist. If a
        connection to ``db_path`` is already open, it is closed and replaced
        so the old handle does not leak.

        :param db_path: The full path to the SQLite database file.
        """
        if not os.path.exists(db_path):
            logging.warning("Database file not found: %s", db_path)
            return
        try:
            conn = sqlite3.connect(db_path, check_same_thread=False)
        except sqlite3.Error as e:
            logging.error("Failed to connect to %s: %s", db_path, e)
            return
        with self.lock:
            previous = self.connections.pop(db_path, None)
            if previous is not None:
                # Best effort: the stale handle is being replaced either way.
                with suppress(sqlite3.Error):
                    previous.close()
            self.connections[db_path] = conn
        logging.info("Connected to %s", db_path)

    def remove_storage(self, db_path):
        """Removes and closes the connection for the specified database path."""
        with self.lock:
            conn = self.connections.pop(db_path, None)
            if conn:
                conn.close()
                logging.info("Disconnected from %s", db_path)
            else:
                logging.warning("Connection %s not found.", db_path)

    def execute(self, query, params=()):
        """
        Queues a write to be executed on all connected databases.

        The statement is not run here. It is applied later by the write
        thread (see start_monitoring), or during shutdown() if still pending.

        :param query: The SQL query to execute.
        :param params: Parameters for the SQL query.
        """
        self.write_queue.put((query, params))

    def _apply_write(self, query, params):
        """Apply one dequeued write; never let an error escape the worker."""
        try:
            self._execute_on_all(query, params)
        except Exception:  # thread boundary: log and keep the writer alive
            logging.exception("Unexpected error while applying queued write")
        finally:
            self.write_queue.task_done()

    def _drain_write_queue(self):
        """Apply every write still queued, without blocking."""
        while True:
            try:
                query, params = self.write_queue.get_nowait()
            except queue.Empty:
                return
            self._apply_write(query, params)

    def _process_write_queue(self):
        """Write-thread loop: apply queued writes until stop_event is set, then drain."""
        while not self.stop_event.is_set():
            try:
                # Timeout so the loop re-checks stop_event regularly.
                query, params = self.write_queue.get(timeout=1)
            except queue.Empty:
                continue
            self._apply_write(query, params)
        # Don't drop writes accepted before shutdown was requested.
        self._drain_write_queue()

    def _execute_on_all(self, query, params):
        """Run one statement on every connection, each in its own transaction."""
        failures = []
        with self.lock:
            for db_path, conn in self.connections.items():
                try:
                    with self._transaction(conn):
                        conn.execute(query, params)
                except sqlite3.Error as e:
                    logging.error("Failed to write to %s: %s", db_path, e)
                    failures.append(db_path)
        if failures:
            logging.error("Write failures on: %s", failures)

    async def execute_async(self, query, params=()):
        """Queue a write from async code; returns what execute() returns (None)."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.execute, query, params)

    def commit(self):
        """Commits the current transaction on all databases."""
        with self.lock:
            for db_path, conn in self.connections.items():
                try:
                    conn.commit()
                except sqlite3.Error as e:
                    logging.error("Failed to commit on %s: %s", db_path, e)

    def close(self):
        """Closes all active database connections and forgets them."""
        with self.lock:
            for db_path, conn in self.connections.items():
                try:
                    conn.close()
                except sqlite3.Error as e:
                    logging.error("Failed to close %s: %s", db_path, e)
                else:
                    logging.info("Closed connection to %s", db_path)
            # Without this, later writes/health checks would hit closed handles.
            self.connections.clear()

    @contextmanager
    def _transaction(self, conn):
        """Commit ``conn`` on success; roll back, log and re-raise on sqlite3.Error."""
        try:
            yield
            conn.commit()
        except sqlite3.Error as e:
            # A failing rollback must not mask the original error.
            with suppress(sqlite3.Error):
                conn.rollback()
            logging.error("Transaction failed: %s", e)
            raise

    def _db_path_for(self, device):
        """Return <mount_path>/<ID_SERIAL>/<db_name> for a device, or None if it has no serial."""
        serial = device.get('ID_SERIAL')
        if not serial:
            return None
        return os.path.join(self.mount_path, serial, self.db_name)

    def _handle_device(self, device):
        """Open or close the mirror connection for a single udev event."""
        serial = device.get('ID_SERIAL')
        if device.action == 'add' and device.get('ID_FS_TYPE') == 'vfat':
            db_path = self._db_path_for(device)
            if db_path is None:
                logging.warning("Ignoring USB device without ID_SERIAL")
                return
            logging.info("USB inserted: %s", serial)
            # NOTE(review): assumes something else mounts the device at
            # <mount_path>/<ID_SERIAL>; add_storage skips it if the database
            # file is not present yet at the time of the 'add' event.
            self.add_storage(db_path)
        elif device.action == 'remove':
            db_path = self._db_path_for(device)
            if db_path is None:
                return
            logging.info("USB removed: %s", serial)
            self.remove_storage(db_path)

    def monitor_usb(self, poll_timeout=1.0):
        """Monitor USB devices and manage database connections accordingly.

        Runs until ``stop_event`` is set. Polling with a timeout lets the loop
        notice ``stop_event`` even when no udev events arrive. A blocking poll
        would make shutdown() hang on ``monitor_thread.join()``.

        :param poll_timeout: Seconds to wait for an event before re-checking stop_event.
        :raises RuntimeError: If pyudev is not installed.
        """
        if pyudev is None:
            raise RuntimeError("pyudev is required for USB monitoring")
        context = pyudev.Context()
        monitor = pyudev.Monitor.from_netlink(context)
        monitor.filter_by('block')
        while not self.stop_event.is_set():
            # poll() returns None when the timeout expires without an event.
            device = monitor.poll(timeout=poll_timeout)
            if device is not None:
                self._handle_device(device)

    def start_monitoring(self):
        """Start the USB-monitor and write-queue daemon threads."""
        self.monitor_thread = threading.Thread(target=self.monitor_usb, daemon=True)
        self.monitor_thread.start()
        self.write_thread = threading.Thread(target=self._process_write_queue, daemon=True)
        self.write_thread.start()

    def check_connections(self):
        """Probe every connection with ``SELECT 1`` and reconnect those that fail."""
        with self.lock:
            # Snapshot: remove/add_storage mutate self.connections during the loop.
            for db_path, conn in list(self.connections.items()):
                try:
                    conn.execute("SELECT 1")
                except sqlite3.Error:
                    logging.warning("Connection to %s lost. Attempting to reconnect...", db_path)
                    self.remove_storage(db_path)
                    self.add_storage(db_path)

    def shutdown(self):
        """Stop the threads, apply any pending writes, then close all connections."""
        logging.info("Shutting down...")
        self.stop_event.set()
        if hasattr(self, 'monitor_thread'):
            self.monitor_thread.join()
        if hasattr(self, 'write_thread'):
            self.write_thread.join()
        # Covers writes queued when the write thread was never started.
        self._drain_write_queue()
        self.close()