From 7460ea7718b4ebbb30746bf1da9996109b8cc8fc Mon Sep 17 00:00:00 2001 From: kalzu rekku Date: Mon, 7 Oct 2024 23:59:54 +0300 Subject: [PATCH] Made sqlite mirroring tool with Claude and ChatGPT. --- sqlitemirrorproxy.py | 267 ++++++++++++++++++++++++++++++++++ sqlitemirrorproxy_wo_async.py | 161 ++++++++++++++++++++ 2 files changed, 428 insertions(+) create mode 100644 sqlitemirrorproxy.py create mode 100644 sqlitemirrorproxy_wo_async.py diff --git a/sqlitemirrorproxy.py b/sqlitemirrorproxy.py new file mode 100644 index 0000000..2e76e5d --- /dev/null +++ b/sqlitemirrorproxy.py @@ -0,0 +1,267 @@ +""" +This module mirrors SQLite operations across multiple USB devices. +Manages device connections, disconnections, and data integrity. +""" + +import json +import logging +import os +import queue +import sqlite3 +import threading +import time +from contextlib import contextmanager +from datetime import datetime +from typing import Dict, Tuple, List, Optional +import pyudev + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + + +class SQLiteMirrorProxy: + """ + A class to manage multiple SQLite database connections across USB devices. + + This class provides functionality to mirror SQLite operations across multiple + connected USB storage devices, handle device connections/disconnections, + and manage data integrity and write operations. + """ + + def __init__( + self, + db_name: str = "data.sqlite", + mount_path: str = "/mnt", + max_retries: int = 3, + ) -> None: + """ + Initialize the SQLiteMirrorProxy. + + :param db_name: Name of the SQLite database file on each USB device + :param mount_path: Root path where USB devices are mounted + :param max_retries: Maximum number of retry attempts for failed write operations + """ + self.db_name = db_name + self.mount_path = mount_path + self.connections = {} + self.lock = threading.RLock() + self.write_queue = queue.Queue() + self.stop_event = threading.Event() + self.failed_writes = queue.Queue() + self.max_retries = max_retries + + def add_storage(self, db_path: str) -> None: + """Add a new SQLite connection to the database on the specified path.""" + if os.path.exists(db_path): + try: + conn = sqlite3.connect(db_path, check_same_thread=False) + with self.lock: + self.connections[db_path] = conn + logging.info("Connected to %s", db_path) + self._log_event(db_path, "connection_added") + self._perform_integrity_check(db_path) + except sqlite3.Error as e: + logging.error("Failed to connect to %s: %s", db_path, e ) + + def remove_storage(self, db_path: str) -> None: + """Remove the connection for the specified database path.""" + with self.lock: + conn = self.connections.pop(db_path, None) + if conn: + conn.close() + logging.info("Disconnected from %s", db_path) + self._log_event(db_path, "connection_removed") + else: + logging.warning("Connection %s not found.", db_path) + + def execute(self, query: str, params: Tuple = ()) -> None: + """ + Execute a query on all connected databases. + + :param query: SQL query to execute + :param params: Parameters for the SQL query + """ + self.write_queue.put((query, params, 0)) # 0 is the initial retry count + + def _process_write_queue(self) -> None: + """Process the write queue, executing queries and handling retries.""" + while not self.stop_event.is_set(): + try: + query, params, retry_count = self.write_queue.get(timeout=1) + success, failures = self._execute_on_all(query, params) + if not success and retry_count < self.max_retries: + self.write_queue.put((query, params, retry_count + 1)) + elif not success: + logging.error( + "Write operation failed after %d attempts: %s", self.max_retries, query + ) + self._log_failed_write(query, params, failures) + except queue.Empty: + continue + + def _execute_on_all(self, query: str, params: Tuple) -> Tuple[bool, List[str]]: + """Execute a query on all connected databases.""" + failures = [] + success = False + with self.lock: + for db_path, conn in list(self.connections.items()): + try: + with self._transaction(conn): + conn.execute(query, params) + success = True + self._log_event(db_path, "write_success", {"query": query}) + except sqlite3.Error as e: + logging.error(f"Failed to write to {db_path}: {e}") + failures.append(db_path) + self.remove_storage(db_path) + if failures: + logging.error(f"Write failures on: {failures}") + return success, failures + + @contextmanager + def _transaction(self, conn: sqlite3.Connection): + """Context manager for handling transactions.""" + try: + yield + conn.commit() + except sqlite3.Error as e: + conn.rollback() + logging.error(f"Transaction failed: {e}") + raise e + + def _log_event( + self, db_path: str, event_type: str, details: Optional[Dict] = None + ) -> None: + """ + Log an event for a specific database. + + :param db_path: Path to the database file + :param event_type: Type of event being logged + :param details: Additional details about the event + """ + log_path = f"{db_path}.log" + event = { + "timestamp": datetime.now().isoformat(), + "event_type": event_type, + "details": details, + } + with open(log_path, "a", encoding="utf-8") as log_file: + json.dump(event, log_file) + log_file.write("\n") + + def _perform_integrity_check(self, db_path: str) -> None: + """Perform an integrity check on the specified database.""" + conn = self.connections.get(db_path) + if conn: + try: + cursor = conn.cursor() + cursor.execute("PRAGMA integrity_check") + result = cursor.fetchone()[0] + self._log_event(db_path, "integrity_check", {"result": result}) + if result != "ok": + logging.warning(f"Integrity check failed for {db_path}: {result}") + except sqlite3.Error as e: + logging.error(f"Error during integrity check for {db_path}: {e}") + self._log_event(db_path, "integrity_check_error", {"error": str(e)}) + + def _log_failed_write(self, query: str, params: Tuple, failures: List[str]) -> None: + """ + Log information about failed write operations. + + :param query: SQL query that failed + :param params: Parameters for the failed query + :param failures: List of database paths where the write failed + """ + for db_path in failures: + self._log_event( + db_path, + "write_failure", + { + "query": query, + "params": str(params), + "reason": "Max retries reached", + }, + ) + + def monitor_usb(self) -> None: + """Monitor USB devices and manage database connections accordingly.""" + context = pyudev.Context() + monitor = pyudev.Monitor.from_netlink(context) + monitor.filter_by("block") + + for device in iter(monitor.poll, None): + if self.stop_event.is_set(): + break + if device.action == "add" and device.get("ID_FS_TYPE") == "vfat": + mount_path = os.path.join(self.mount_path, device.get("ID_SERIAL")) + logging.info(f"USB inserted: {device.get('ID_SERIAL')}") + self.add_storage(os.path.join(mount_path, self.db_name)) + elif device.action == "remove": + mount_path = os.path.join(self.mount_path, device.get("ID_SERIAL")) + logging.info(f"USB removed: {device.get('ID_SERIAL')}") + self.remove_storage(os.path.join(mount_path, self.db_name)) + + def check_connections(self) -> None: + """Check all database connections and attempt to reconnect if necessary.""" + with self.lock: + for db_path, conn in list(self.connections.items()): + try: + conn.execute("SELECT 1") + except sqlite3.Error: + logging.warning( + f"Connection to {db_path} lost. Attempting to reconnect..." + ) + self.remove_storage(db_path) + self.add_storage(db_path) + + def start_monitoring(self) -> None: + """Start the USB monitoring, write processing, and maintenance threads.""" + self.monitor_thread = threading.Thread(target=self.monitor_usb) + self.monitor_thread.daemon = True + self.monitor_thread.start() + + self.write_thread = threading.Thread(target=self._process_write_queue) + self.write_thread.daemon = True + self.write_thread.start() + + self.maintenance_thread = threading.Thread(target=self._maintenance_loop) + self.maintenance_thread.daemon = True + self.maintenance_thread.start() + + def _maintenance_loop(self) -> None: + """Perform periodic maintenance tasks such as connection checks and integrity checks.""" + while not self.stop_event.is_set(): + self.check_connections() + for db_path in self.connections: + self._perform_integrity_check(db_path) + time.sleep(3600) # Run maintenance every hour + + def shutdown(self) -> None: + """Gracefully shut down the monitoring thread and close all connections.""" + logging.info("Shutting down...") + self.stop_event.set() + self.monitor_thread.join() + self.write_thread.join() + self.maintenance_thread.join() + self.close() + + def close(self) -> None: + """Close all active database connections.""" + with self.lock: + for db_path, conn in self.connections.items(): + conn.close() + logging.info(f"Closed connection to {db_path}") + + def get_connection_status(self) -> Dict[str, str]: + """ + Get the current status of all database connections. + + :return: Dictionary mapping database paths to their connection status + """ + with self.lock: + return {db_path: "Connected" for db_path in self.connections} + + def get_failed_writes_count(self) -> int: + """Get the count of failed write operations.""" + return self.failed_writes.qsize() diff --git a/sqlitemirrorproxy_wo_async.py b/sqlitemirrorproxy_wo_async.py new file mode 100644 index 0000000..ad263df --- /dev/null +++ b/sqlitemirrorproxy_wo_async.py @@ -0,0 +1,161 @@ +import sqlite3 +from contextlib import contextmanager +import pyudev +import os +import threading +import logging +import queue +import asyncio + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +class SQLiteMirrorProxy: + def __init__(self, db_name='data.sqlite', mount_path='/mnt'): + """ + Initializes the SQLiteMirrorProxy. + + :param db_name: The name of the database file on each USB device. + :param mount_path: The root path where USB devices are mounted. + """ + self.db_name = db_name + self.mount_path = mount_path + self.connections = {} + self.lock = threading.RLock() + self.write_queue = queue.Queue() + self.stop_event = threading.Event() + + def add_storage(self, db_path): + """ + Adds a new SQLite connection to the database on the specified path. + + :param db_path: The full path to the SQLite database file. + """ + if os.path.exists(db_path): + try: + conn = sqlite3.connect(db_path, check_same_thread=False) + with self.lock: + self.connections[db_path] = conn + logging.info(f"Connected to {db_path}") + except sqlite3.Error as e: + logging.error(f"Failed to connect to {db_path}: {e}") + + def remove_storage(self, db_path): + """Removes the connection for the specified database path.""" + with self.lock: + conn = self.connections.pop(db_path, None) + if conn: + conn.close() + logging.info(f"Disconnected from {db_path}") + else: + logging.warning(f"Connection {db_path} not found.") + + def execute(self, query, params=()): + """ + Executes a query on all connected databases. + + :param query: The SQL query to execute. + :param params: Parameters for the SQL query. + """ + self.write_queue.put((query, params)) + + def _process_write_queue(self): + while not self.stop_event.is_set(): + try: + query, params = self.write_queue.get(timeout=1) + self._execute_on_all(query, params) + except queue.Empty: + continue + + def _execute_on_all(self, query, params): + failures = [] + with self.lock: + for db_path, conn in self.connections.items(): + try: + with self._transaction(conn): + conn.execute(query, params) + except sqlite3.Error as e: + logging.error(f"Failed to write to {db_path}: {e}") + failures.append(db_path) + if failures: + logging.error(f"Write failures on: {failures}") + + async def execute_async(self, query, params=()): + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.execute, query, params) + + def commit(self): + """Commits the current transaction on all databases.""" + with self.lock: + for db_path, conn in self.connections.items(): + try: + conn.commit() + except sqlite3.Error as e: + logging.error(f"Failed to commit on {db_path}: {e}") + + def close(self): + """Closes all active database connections.""" + with self.lock: + for db_path, conn in self.connections.items(): + conn.close() + logging.info(f"Closed connection to {db_path}") + + @contextmanager + def _transaction(self, conn): + """Context manager for handling transactions.""" + try: + yield + conn.commit() + except sqlite3.Error as e: + conn.rollback() + logging.error(f"Transaction failed: {e}") + raise e + + def monitor_usb(self): + """Monitor USB devices and manage database connections accordingly.""" + context = pyudev.Context() + monitor = pyudev.Monitor.from_netlink(context) + monitor.filter_by('block') + + for device in iter(monitor.poll, None): + if self.stop_event.is_set(): + break + if device.action == 'add' and device.get('ID_FS_TYPE') == 'vfat': + # Use dynamic mount point for devices + mount_path = os.path.join(self.mount_path, device.get('ID_SERIAL')) + logging.info(f"USB inserted: {device.get('ID_SERIAL')}") + self.add_storage(os.path.join(mount_path, self.db_name)) + elif device.action == 'remove': + mount_path = os.path.join(self.mount_path, device.get('ID_SERIAL')) + logging.info(f"USB removed: {device.get('ID_SERIAL')}") + self.remove_storage(os.path.join(mount_path, self.db_name)) + + def start_monitoring(self): + self.monitor_thread = threading.Thread(target=self.monitor_usb) + self.monitor_thread.daemon = True + self.monitor_thread.start() + + self.write_thread = threading.Thread(target=self._process_write_queue) + self.write_thread.daemon = True + self.write_thread.start() + + def check_connections(self): + with self.lock: + for db_path, conn in list(self.connections.items()): + try: + conn.execute("SELECT 1") + except sqlite3.Error: + logging.warning(f"Connection to {db_path} lost. Attempting to reconnect...") + self.remove_storage(db_path) + self.add_storage(db_path) + + def shutdown(self): + """Gracefully shuts down the monitoring thread and closes connections.""" + logging.info("Shutting down...") + self.stop_event.set() + if hasattr(self, 'monitor_thread'): + self.monitor_thread.join() + if hasattr(self, 'write_thread'): + self.write_thread.join() + self.close() + +