# SqliteMirrorProxy/sqlitemirrorproxy.py
"""
Mirror SQLite write operations across databases stored on multiple USB devices.

Writes submitted through ``SQLiteMirrorProxy.execute`` are queued and applied
by a background thread to every connected database file. USB block-device
events received through pyudev add or remove connections, and a maintenance
thread periodically re-checks connections and runs ``PRAGMA integrity_check``.
Per-database events (connections, writes, integrity results) are appended as
JSON lines to ``<db_path>.log`` next to each database file.
"""
import json
import logging
import os
import queue
import sqlite3
import threading
import time
from contextlib import contextmanager
from datetime import datetime
from typing import Dict, Tuple, List, Optional
import pyudev
# Root-logger configuration applied when this module is imported. Per the
# stdlib contract, basicConfig does nothing if the root logger already has
# handlers attached.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
class SQLiteMirrorProxy:
    """
    Mirror SQLite write operations across databases on several USB devices.

    Writes submitted through :meth:`execute` are queued and applied by a
    background thread to every connected database, each in its own
    transaction. USB block-device events (received via pyudev) add and
    remove connections, and a maintenance thread periodically re-validates
    connections and runs integrity checks. Per-database events are appended
    as JSON lines to ``<db_path>.log``.
    """

    # Seconds the worker threads block on the write queue / udev monitor
    # before re-checking ``stop_event``; bounds how long shutdown() waits.
    _POLL_INTERVAL = 1.0

    def __init__(
        self,
        db_name: str = "data.sqlite",
        mount_path: str = "/mnt",
        max_retries: int = 3,
        maintenance_interval: float = 3600.0,
    ) -> None:
        """
        Initialize the SQLiteMirrorProxy.

        :param db_name: Name of the SQLite database file on each USB device
        :param mount_path: Root path where USB devices are mounted
        :param max_retries: Number of additional attempts for a write that
            succeeded on no database before it is recorded as failed
        :param maintenance_interval: Seconds between maintenance passes
            (connection checks followed by integrity checks)
        """
        self.db_name = db_name
        self.mount_path = mount_path
        self.max_retries = max_retries
        self.maintenance_interval = maintenance_interval
        # Open connections keyed by database file path; guarded by ``lock``.
        self.connections: Dict[str, sqlite3.Connection] = {}
        # Re-entrant because remove_storage()/add_storage() are called while
        # the lock is already held (from _execute_on_all/check_connections).
        self.lock = threading.RLock()
        # Pending writes: (query, params, retry_count[, failed_paths]).
        self.write_queue: queue.Queue = queue.Queue()
        self.stop_event = threading.Event()
        # Writes that exhausted their retries: (query, params, failed_paths).
        self.failed_writes: queue.Queue = queue.Queue()
        # Background threads; populated by start_monitoring().
        self.monitor_thread: Optional[threading.Thread] = None
        self.write_thread: Optional[threading.Thread] = None
        self.maintenance_thread: Optional[threading.Thread] = None

    def add_storage(self, db_path: str) -> None:
        """
        Open a connection to an existing database file and register it.

        Nothing is connected if ``db_path`` does not exist. If ``db_path`` is
        already registered, the previous connection is closed and replaced.
        An integrity check is run on the newly added database.

        :param db_path: Path of the SQLite database file
        """
        if not os.path.exists(db_path):
            logging.warning("Database %s does not exist; not connecting.", db_path)
            return
        try:
            conn = sqlite3.connect(db_path, check_same_thread=False)
        except sqlite3.Error as exc:
            logging.error("Failed to connect to %s: %s", db_path, exc)
            return
        with self.lock:
            previous = self.connections.get(db_path)
            if previous is not None:
                # Avoid leaking the old handle when the same path is re-added.
                previous.close()
            self.connections[db_path] = conn
        logging.info("Connected to %s", db_path)
        self._log_event(db_path, "connection_added")
        self._perform_integrity_check(db_path)

    def remove_storage(self, db_path: str) -> None:
        """
        Close and unregister the connection for ``db_path``, if any.

        :param db_path: Path of the SQLite database file
        """
        with self.lock:
            conn = self.connections.pop(db_path, None)
            if conn is None:
                logging.warning("Connection %s not found.", db_path)
                return
            conn.close()
        logging.info("Disconnected from %s", db_path)
        self._log_event(db_path, "connection_removed")

    def execute(self, query: str, params: Tuple = ()) -> None:
        """
        Queue a query to be executed on all connected databases.

        The query is applied asynchronously by the write thread started in
        :meth:`start_monitoring`; this method returns immediately.

        :param query: SQL query to execute
        :param params: Parameters for the SQL query
        """
        self.write_queue.put((query, params, 0))  # 0 is the initial retry count

    def _process_write_queue(self) -> None:
        """Consume the write queue until ``stop_event`` is set.

        Writes still queued when ``stop_event`` is set are not applied.
        """
        while not self.stop_event.is_set():
            try:
                item = self.write_queue.get(timeout=self._POLL_INTERVAL)
            except queue.Empty:
                continue
            try:
                self._handle_write(item)
            except Exception:  # thread boundary: keep the only consumer alive
                logging.exception("Unexpected error while processing write %r", item)
            finally:
                # Lets callers use write_queue.join() to wait for a flush.
                self.write_queue.task_done()

    def _handle_write(self, item: Tuple) -> None:
        """
        Attempt one queued write, then requeue it or record it as failed.

        :param item: ``(query, params, retry_count)`` optionally followed by
            the list of database paths that failed on earlier attempts
        """
        query, params, retry_count, *carried = item
        # Paths that failed earlier were disconnected at that point, so later
        # attempts cannot report them again; carry them along explicitly.
        earlier_failures: List[str] = list(carried[0]) if carried else []
        success, failures = self._execute_on_all(query, params)
        if success:
            return
        all_failures = list(dict.fromkeys(earlier_failures + failures))
        if retry_count < self.max_retries:
            self.write_queue.put((query, params, retry_count + 1, all_failures))
            return
        logging.error(
            "Write operation failed after %d attempts: %s", retry_count + 1, query
        )
        self.failed_writes.put((query, params, all_failures))
        self._log_failed_write(query, params, all_failures)

    def _execute_on_all(self, query: str, params: Tuple) -> Tuple[bool, List[str]]:
        """
        Run ``query`` in its own transaction on every connected database.

        Databases on which the query raises :class:`sqlite3.Error` are
        disconnected via :meth:`remove_storage`.

        :param query: SQL query to execute
        :param params: Parameters for the SQL query
        :return: ``(success, failures)``, where ``success`` is True if the
            query committed on at least one database and ``failures`` lists
            the paths on which it failed
        """
        failures: List[str] = []
        success = False
        with self.lock:
            # Iterate over a snapshot because remove_storage() mutates the dict.
            for db_path, conn in list(self.connections.items()):
                try:
                    with self._transaction(conn):
                        conn.execute(query, params)
                except sqlite3.Error as exc:
                    logging.error("Failed to write to %s: %s", db_path, exc)
                    failures.append(db_path)
                    # NOTE(review): any sqlite3.Error (including a malformed
                    # query) disconnects the device, not only I/O failures.
                    self.remove_storage(db_path)
                else:
                    success = True
                    self._log_event(db_path, "write_success", {"query": query})
        if failures:
            logging.error("Write failures on: %s", failures)
        return success, failures

    @contextmanager
    def _transaction(self, conn: sqlite3.Connection):
        """Commit on success; roll back and re-raise on any exception."""
        try:
            yield
            conn.commit()
        except sqlite3.Error as exc:
            conn.rollback()
            logging.error("Transaction failed: %s", exc)
            raise
        except Exception:
            # Do not leave an implicit transaction open on non-sqlite errors.
            conn.rollback()
            raise

    def _log_event(
        self, db_path: str, event_type: str, details: Optional[Dict] = None
    ) -> None:
        """
        Append one JSON-line event to ``<db_path>.log``.

        Write errors (e.g. the device was unplugged) are logged and swallowed
        so that event logging never aborts the operation that triggered it.

        :param db_path: Path to the database file
        :param event_type: Type of event being logged
        :param details: Additional details about the event
        """
        event = {
            "timestamp": datetime.now().isoformat(),
            "event_type": event_type,
            "details": details,
        }
        log_path = f"{db_path}.log"
        try:
            with open(log_path, "a", encoding="utf-8") as log_file:
                log_file.write(json.dumps(event) + "\n")
        except OSError as exc:
            logging.warning("Could not write event log %s: %s", log_path, exc)

    def _perform_integrity_check(self, db_path: str) -> None:
        """
        Run ``PRAGMA integrity_check`` on a connected database and log the result.

        :param db_path: Path of a registered database; ignored if not connected
        """
        with self.lock:
            conn = self.connections.get(db_path)
            if conn is None:
                return
            try:
                rows = conn.execute("PRAGMA integrity_check").fetchall()
            except sqlite3.Error as exc:
                logging.error("Error during integrity check for %s: %s", db_path, exc)
                self._log_event(db_path, "integrity_check_error", {"error": str(exc)})
                return
        # The pragma yields a single "ok" row, or one row per problem found.
        result = "; ".join(str(row[0]) for row in rows)
        self._log_event(db_path, "integrity_check", {"result": result})
        if result != "ok":
            logging.warning("Integrity check failed for %s: %s", db_path, result)

    def _log_failed_write(self, query: str, params: Tuple, failures: List[str]) -> None:
        """
        Log information about failed write operations.

        :param query: SQL query that failed
        :param params: Parameters for the failed query
        :param failures: List of database paths where the write failed
        """
        for db_path in failures:
            self._log_event(
                db_path,
                "write_failure",
                {
                    "query": query,
                    "params": str(params),
                    "reason": "Max retries reached",
                },
            )

    def monitor_usb(self) -> None:
        """
        Attach/detach databases in response to udev block-device events.

        Runs until ``stop_event`` is set, polling with a short timeout so the
        flag is re-checked even when no events arrive. Only ``add`` events for
        vfat filesystems attach storage; every ``remove`` event with a serial
        attempts a detach.
        """
        context = pyudev.Context()
        monitor = pyudev.Monitor.from_netlink(context)
        monitor.filter_by("block")
        while not self.stop_event.is_set():
            device = monitor.poll(timeout=self._POLL_INTERVAL)
            if device is None:
                continue  # timed out; re-check stop_event
            if self.stop_event.is_set():
                break
            serial = device.get("ID_SERIAL")
            if serial is None:
                # Without a serial there is no mount directory to derive.
                continue
            # NOTE(review): assumes each device is mounted at
            # <mount_path>/<ID_SERIAL>; confirm against the automount setup.
            # If the filesystem is not mounted yet when "add" arrives,
            # add_storage() finds no file and the device is not retried.
            db_path = os.path.join(self.mount_path, serial, self.db_name)
            if device.action == "add" and device.get("ID_FS_TYPE") == "vfat":
                logging.info("USB inserted: %s", serial)
                self.add_storage(db_path)
            elif device.action == "remove":
                logging.info("USB removed: %s", serial)
                self.remove_storage(db_path)

    def check_connections(self) -> None:
        """Probe every connection with ``SELECT 1`` and reconnect those that fail."""
        with self.lock:
            for db_path, conn in list(self.connections.items()):
                try:
                    conn.execute("SELECT 1")
                except sqlite3.Error:
                    logging.warning(
                        "Connection to %s lost. Attempting to reconnect...", db_path
                    )
                    self.remove_storage(db_path)
                    self.add_storage(db_path)

    def start_monitoring(self) -> None:
        """Start the USB monitoring, write processing, and maintenance threads."""
        self.monitor_thread = threading.Thread(target=self.monitor_usb, daemon=True)
        self.write_thread = threading.Thread(
            target=self._process_write_queue, daemon=True
        )
        self.maintenance_thread = threading.Thread(
            target=self._maintenance_loop, daemon=True
        )
        for thread in (self.monitor_thread, self.write_thread, self.maintenance_thread):
            thread.start()

    def _maintenance_loop(self) -> None:
        """Periodically check connections and run integrity checks until stopped."""
        while not self.stop_event.is_set():
            self.check_connections()
            # Snapshot under the lock: the monitor thread mutates the dict.
            with self.lock:
                db_paths = list(self.connections)
            for db_path in db_paths:
                self._perform_integrity_check(db_path)
            # Event.wait returns as soon as shutdown() sets stop_event,
            # unlike time.sleep, which would block join() for the full period.
            self.stop_event.wait(self.maintenance_interval)

    def shutdown(self) -> None:
        """
        Stop the background threads, wait for them, and close all connections.

        Safe to call even if :meth:`start_monitoring` was never called. Writes
        still queued when shutdown begins are not applied.
        """
        logging.info("Shutting down...")
        self.stop_event.set()
        for thread in (self.monitor_thread, self.write_thread, self.maintenance_thread):
            if thread is not None:
                thread.join()
        self.close()

    def close(self) -> None:
        """Close and unregister all active database connections."""
        with self.lock:
            for db_path, conn in self.connections.items():
                conn.close()
                logging.info("Closed connection to %s", db_path)
            # Closed handles must not be reported as "Connected" afterwards.
            self.connections.clear()

    def get_connection_status(self) -> Dict[str, str]:
        """
        Get the current status of all database connections.

        :return: Dictionary mapping database paths to their connection status
        """
        with self.lock:
            return {db_path: "Connected" for db_path in self.connections}

    def get_failed_writes_count(self) -> int:
        """Get the count of writes that failed after exhausting their retries."""
        return self.failed_writes.qsize()