268 lines
10 KiB
Python
268 lines
10 KiB
Python
|
"""
|
||
|
This module mirrors SQLite operations across multiple USB devices.
|
||
|
Manages device connections, disconnections, and data integrity.
|
||
|
"""
|
||
|
|
||
|
import json
|
||
|
import logging
|
||
|
import os
|
||
|
import queue
|
||
|
import sqlite3
|
||
|
import threading
|
||
|
import time
|
||
|
from contextlib import contextmanager
|
||
|
from datetime import datetime
|
||
|
from typing import Dict, Tuple, List, Optional
|
||
|
import pyudev
|
||
|
|
||
|
# Root logger setup applied at import time: INFO level, timestamped lines.
# basicConfig does nothing if the root logger already has handlers, so an
# application that configured logging before importing this module keeps it.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
|
||
|
|
||
|
|
||
|
class SQLiteMirrorProxy:
    """
    A class to manage multiple SQLite database connections across USB devices.

    This class provides functionality to mirror SQLite operations across multiple
    connected USB storage devices, handle device connections/disconnections,
    and manage data integrity and write operations.

    Writes submitted through :meth:`execute` are queued and applied by a
    background thread (started by :meth:`start_monitoring`) to every connected
    database. The connections are opened with ``check_same_thread=False`` and
    shared between the monitor, writer and maintenance threads, so every use of
    ``self.connections`` and of the connection objects goes through the
    re-entrant ``self.lock``.
    """

    def __init__(
        self,
        db_name: str = "data.sqlite",
        mount_path: str = "/mnt",
        max_retries: int = 3,
        maintenance_interval: float = 3600.0,
    ) -> None:
        """
        Initialize the SQLiteMirrorProxy.

        :param db_name: Name of the SQLite database file on each USB device
        :param mount_path: Root path where USB devices are mounted; a device's
            database is expected at ``<mount_path>/<ID_SERIAL>/<db_name>``
        :param max_retries: Maximum number of retry attempts for failed write operations
        :param maintenance_interval: Seconds between maintenance passes
            (connection checks and integrity checks); defaults to one hour
        """
        self.db_name = db_name
        self.mount_path = mount_path
        self.connections: Dict[str, sqlite3.Connection] = {}
        self.lock = threading.RLock()
        # Pending writes as (query, params, retry_count) tuples.
        self.write_queue = queue.Queue()
        self.stop_event = threading.Event()
        # Writes that exhausted their retries, as (query, params, failed_paths).
        self.failed_writes = queue.Queue()
        self.max_retries = max_retries
        self.maintenance_interval = maintenance_interval
        # Worker threads are created by start_monitoring(); None until then so
        # shutdown() is safe to call even if monitoring was never started.
        self.monitor_thread: Optional[threading.Thread] = None
        self.write_thread: Optional[threading.Thread] = None
        self.maintenance_thread: Optional[threading.Thread] = None

    def add_storage(self, db_path: str) -> None:
        """
        Open a connection to the database at ``db_path`` and start mirroring to it.

        The file must already exist; a missing file is logged and ignored.
        If ``db_path`` is already connected, the old connection is closed and
        replaced. An integrity check runs on the newly attached database.

        :param db_path: Full path to the SQLite database file
        """
        if not os.path.exists(db_path):
            # e.g. the device has not been mounted yet; nothing to mirror to.
            logging.warning("Database %s not found; not connecting.", db_path)
            return
        try:
            conn = sqlite3.connect(db_path, check_same_thread=False)
        except sqlite3.Error as e:
            logging.error("Failed to connect to %s: %s", db_path, e)
            return
        with self.lock:
            previous = self.connections.get(db_path)
            self.connections[db_path] = conn
            if previous is not None:
                # Replacing an existing entry must not leak the old handle.
                previous.close()
        logging.info("Connected to %s", db_path)
        self._log_event(db_path, "connection_added")
        self._perform_integrity_check(db_path)

    def remove_storage(self, db_path: str) -> None:
        """
        Stop mirroring to ``db_path`` and close its connection.

        Unknown paths are logged as a warning and otherwise ignored.

        :param db_path: Full path previously passed to :meth:`add_storage`
        """
        with self.lock:
            conn = self.connections.pop(db_path, None)
            if conn is None:
                logging.warning("Connection %s not found.", db_path)
                return
            try:
                conn.close()
            except sqlite3.Error as e:
                logging.error("Error while closing %s: %s", db_path, e)
        logging.info("Disconnected from %s", db_path)
        self._log_event(db_path, "connection_removed")

    def execute(self, query: str, params: Tuple = ()) -> None:
        """
        Queue a query to be executed on all connected databases.

        Returns immediately; the writer thread applies the query later.

        :param query: SQL query to execute
        :param params: Parameters for the SQL query
        """
        self.write_queue.put((query, params, 0))  # 0 is the initial retry count

    def _process_write_queue(self) -> None:
        """
        Apply queued writes until ``stop_event`` is set (writer-thread body).

        A write that committed on no database is re-queued until it has been
        retried ``max_retries`` times; after that it is recorded through
        :meth:`_log_failed_write`. Writes still queued at shutdown are dropped.
        """
        while not self.stop_event.is_set():
            try:
                query, params, retry_count = self.write_queue.get(timeout=1)
            except queue.Empty:
                continue
            try:
                success, failures = self._execute_on_all(query, params)
                if success:
                    continue
                if retry_count < self.max_retries:
                    self.write_queue.put((query, params, retry_count + 1))
                else:
                    logging.error(
                        "Write operation failed after %d attempts: %s",
                        retry_count + 1,
                        query,
                    )
                    self._log_failed_write(query, params, failures)
            except Exception:
                # Thread boundary: an error that is not a sqlite3.Error must
                # not silently kill the writer thread.
                logging.exception("Unexpected error while writing: %s", query)
                self._log_failed_write(query, params, [], reason="Unexpected error")
            finally:
                self.write_queue.task_done()

    def _execute_on_all(self, query: str, params: Tuple) -> Tuple[bool, List[str]]:
        """
        Execute ``query`` on every connected database, each in its own transaction.

        A database whose write fails is disconnected via :meth:`remove_storage`.

        :param query: SQL query to execute
        :param params: Parameters for the SQL query
        :return: ``(success, failures)``: ``success`` is True if the write
            committed on at least one database, ``failures`` lists the paths
            where it failed. With no connected databases this is ``(False, [])``.
        """
        failures: List[str] = []
        success = False
        with self.lock:
            for db_path, conn in list(self.connections.items()):
                try:
                    with self._transaction(conn):
                        conn.execute(query, params)
                except sqlite3.Error as e:
                    logging.error("Failed to write to %s: %s", db_path, e)
                    failures.append(db_path)
                    self.remove_storage(db_path)  # RLock: re-entry is safe
                else:
                    success = True
                    self._log_event(db_path, "write_success", {"query": query})
        if failures:
            logging.error("Write failures on: %s", failures)
        return success, failures

    @contextmanager
    def _transaction(self, conn: sqlite3.Connection):
        """
        Commit ``conn`` if the body (and the commit) succeed; otherwise roll back and re-raise.

        Any exception, not only ``sqlite3.Error``, triggers the rollback so no
        half-finished transaction is left open on the shared connection.
        """
        try:
            yield
            conn.commit()
        except Exception as e:
            try:
                conn.rollback()
            except sqlite3.Error as rollback_error:
                # A failing rollback must not mask the original error.
                logging.error("Rollback failed: %s", rollback_error)
            logging.error("Transaction failed: %s", e)
            raise

    def _log_event(
        self, db_path: str, event_type: str, details: Optional[Dict] = None
    ) -> None:
        """
        Append one JSON line describing an event to ``<db_path>.log``.

        The log file sits next to the database. Writing it is best-effort:
        an ``OSError`` (for example because the device was just unplugged) is
        reported through ``logging`` and never propagates to the caller.

        :param db_path: Path to the database file
        :param event_type: Type of event being logged
        :param details: Additional details about the event
        """
        log_path = f"{db_path}.log"
        event = {
            "timestamp": datetime.now().isoformat(),
            "event_type": event_type,
            "details": details,
        }
        try:
            with open(log_path, "a", encoding="utf-8") as log_file:
                log_file.write(json.dumps(event) + "\n")
        except OSError as e:
            logging.warning(
                "Could not write event %s to %s: %s", event_type, log_path, e
            )

    def _perform_integrity_check(self, db_path: str) -> None:
        """
        Run ``PRAGMA integrity_check`` on the database at ``db_path`` and log the outcome.

        Does nothing if ``db_path`` is not connected. The connection is used
        under ``self.lock`` because it is shared with the writer thread.
        """
        with self.lock:
            conn = self.connections.get(db_path)
            if conn is None:
                return
            try:
                rows = conn.execute("PRAGMA integrity_check").fetchall()
            except sqlite3.Error as e:
                logging.error("Error during integrity check for %s: %s", db_path, e)
                self._log_event(db_path, "integrity_check_error", {"error": str(e)})
                return
        # The pragma yields a single "ok" row, or one row per problem found.
        result = "; ".join(str(row[0]) for row in rows)
        self._log_event(db_path, "integrity_check", {"result": result})
        if result != "ok":
            logging.warning("Integrity check failed for %s: %s", db_path, result)

    def _log_failed_write(
        self,
        query: str,
        params: Tuple,
        failures: List[str],
        reason: str = "Max retries reached",
    ) -> None:
        """
        Record a write that was given up on.

        The write is added to ``failed_writes`` (counted by
        :meth:`get_failed_writes_count`) and a ``write_failure`` event is
        logged for each database path in ``failures``.

        :param query: SQL query that failed
        :param params: Parameters for the failed query
        :param failures: List of database paths where the write failed
        :param reason: Human-readable reason stored in the event details
        """
        self.failed_writes.put((query, params, list(failures)))
        for db_path in failures:
            self._log_event(
                db_path,
                "write_failure",
                {
                    "query": query,
                    "params": str(params),
                    "reason": reason,
                },
            )

    def monitor_usb(self) -> None:
        """
        Watch udev block-device events and attach/detach databases accordingly.

        Polls with a one-second timeout so the loop notices ``stop_event``
        even when no device events arrive.
        """
        context = pyudev.Context()
        monitor = pyudev.Monitor.from_netlink(context)
        monitor.filter_by("block")

        while not self.stop_event.is_set():
            device = monitor.poll(timeout=1)
            if device is None:  # timed out without an event
                continue
            self._handle_device_event(device)

    def _handle_device_event(self, device) -> None:
        """
        Connect or disconnect the database for a single udev ``device`` event.

        The database path is ``<mount_path>/<ID_SERIAL>/<db_name>``. Events
        without an ``ID_SERIAL`` property are ignored because no path can be
        derived for them. On ``add`` only ``vfat`` filesystems are attached;
        any ``remove`` event detaches the matching database.
        """
        serial = device.get("ID_SERIAL")
        if not serial:
            return
        db_path = os.path.join(self.mount_path, serial, self.db_name)
        if device.action == "add" and device.get("ID_FS_TYPE") == "vfat":
            logging.info("USB inserted: %s", serial)
            self.add_storage(db_path)
        elif device.action == "remove":
            logging.info("USB removed: %s", serial)
            self.remove_storage(db_path)

    def check_connections(self) -> None:
        """
        Probe every connection with ``SELECT 1`` and reconnect those that fail.

        Reconnection goes through remove_storage/add_storage, so a database
        whose file no longer exists is simply dropped.
        """
        with self.lock:
            for db_path, conn in list(self.connections.items()):
                try:
                    conn.execute("SELECT 1")
                except sqlite3.Error:
                    logging.warning(
                        "Connection to %s lost. Attempting to reconnect...", db_path
                    )
                    self.remove_storage(db_path)
                    self.add_storage(db_path)

    def start_monitoring(self) -> None:
        """Start the USB monitoring, write processing, and maintenance threads (all daemon threads)."""
        self.monitor_thread = threading.Thread(
            target=self.monitor_usb, name="usb-monitor", daemon=True
        )
        self.write_thread = threading.Thread(
            target=self._process_write_queue, name="mirror-writer", daemon=True
        )
        self.maintenance_thread = threading.Thread(
            target=self._maintenance_loop, name="mirror-maintenance", daemon=True
        )
        for thread in (self.monitor_thread, self.write_thread, self.maintenance_thread):
            thread.start()

    def _maintenance_loop(self) -> None:
        """
        Periodically check connections and run integrity checks until stopped.

        Waits on ``stop_event`` between passes, so shutdown interrupts the
        wait immediately instead of after ``maintenance_interval`` seconds.
        """
        while not self.stop_event.is_set():
            self.check_connections()
            # Snapshot under the lock: other threads add/remove entries.
            with self.lock:
                db_paths = list(self.connections)
            for db_path in db_paths:
                if self.stop_event.is_set():
                    break
                self._perform_integrity_check(db_path)
            self.stop_event.wait(self.maintenance_interval)

    def shutdown(self) -> None:
        """
        Stop the worker threads, wait for them, and close all connections.

        Safe to call when :meth:`start_monitoring` was never called.
        """
        logging.info("Shutting down...")
        self.stop_event.set()
        current = threading.current_thread()
        for thread in (self.monitor_thread, self.write_thread, self.maintenance_thread):
            # Joining the calling thread itself would raise RuntimeError.
            if thread is not None and thread is not current:
                thread.join()
        self.close()

    def close(self) -> None:
        """Close all active database connections and forget them."""
        with self.lock:
            for db_path, conn in self.connections.items():
                conn.close()
                logging.info("Closed connection to %s", db_path)
            # Drop closed handles so later writes/status calls don't see them.
            self.connections.clear()

    def get_connection_status(self) -> Dict[str, str]:
        """
        Get the current status of all database connections.

        :return: Dictionary mapping database paths to their connection status
        """
        with self.lock:
            return {db_path: "Connected" for db_path in self.connections}

    def get_failed_writes_count(self) -> int:
        """Get the number of writes recorded as permanently failed."""
        return self.failed_writes.qsize()
|