162 lines
5.8 KiB
Python
162 lines
5.8 KiB
Python
import sqlite3
|
|
from contextlib import contextmanager
|
|
import pyudev
|
|
import os
|
|
import threading
|
|
import logging
|
|
import queue
|
|
import asyncio
|
|
|
|
# Configure the root logger when this module is imported: INFO level and above,
# each record rendered as "<timestamp> - <level name> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
class SQLiteMirrorProxy:
    """Mirror SQL statements across several SQLite database files.

    One connection is kept per registered database path. Statements handed to
    :meth:`execute` are queued and later replayed on every open connection by
    a background writer thread. A second background thread listens for udev
    block-device events and registers/unregisters the database expected at
    ``<mount_path>/<ID_SERIAL>/<db_name>``. Both threads are started by
    :meth:`start_monitoring` and stopped by :meth:`shutdown`.
    """

    def __init__(self, db_name='data.sqlite', mount_path='/mnt'):
        """
        Initializes the SQLiteMirrorProxy.

        :param db_name: The name of the database file on each USB device.
        :param mount_path: The root path where USB devices are mounted.
        """
        self.db_name = db_name
        self.mount_path = mount_path
        # Maps database file path -> open sqlite3 connection.
        self.connections = {}
        # Re-entrant so check_connections() can call remove_storage() and
        # add_storage() while it already holds the lock.
        self.lock = threading.RLock()
        # Pending (query, params) tuples awaiting the writer thread.
        self.write_queue = queue.Queue()
        # Set by shutdown() to ask both background threads to finish.
        self.stop_event = threading.Event()

    @staticmethod
    def _close_connection(db_path, conn):
        """Close ``conn``, logging (not raising) any SQLite error.

        :param db_path: Path the connection belongs to (used for logging only).
        :param conn: The sqlite3 connection to close.
        :return: True if the connection closed cleanly, False otherwise.
        """
        try:
            conn.close()
        except sqlite3.Error as e:
            logging.warning(f"Error while closing {db_path}: {e}")
            return False
        return True

    def add_storage(self, db_path):
        """
        Adds a new SQLite connection to the database on the specified path.

        Nothing is registered when the file does not exist or cannot be
        opened. If a connection for the same path is already registered it is
        closed and replaced, so the old handle is not leaked.

        :param db_path: The full path to the SQLite database file.
        """
        if not os.path.exists(db_path):
            logging.warning(f"Database file {db_path} not found; not connecting.")
            return

        try:
            # check_same_thread=False: the connection is opened on one thread
            # and used by others; access is serialized through self.lock.
            conn = sqlite3.connect(db_path, check_same_thread=False)
        except sqlite3.Error as e:
            logging.error(f"Failed to connect to {db_path}: {e}")
            return

        with self.lock:
            previous = self.connections.get(db_path)
            self.connections[db_path] = conn
            if previous is not None:
                self._close_connection(db_path, previous)
        logging.info(f"Connected to {db_path}")

    def remove_storage(self, db_path):
        """Removes and closes the connection for the specified database path.

        Logs a warning when no connection is registered for ``db_path``.

        :param db_path: The full path to the SQLite database file.
        """
        with self.lock:
            conn = self.connections.pop(db_path, None)
            if conn is None:
                logging.warning(f"Connection {db_path} not found.")
                return
            # The backing device may already be gone, so a failing close is
            # logged instead of propagating into the caller.
            self._close_connection(db_path, conn)
            logging.info(f"Disconnected from {db_path}")

    def execute(self, query, params=()):
        """
        Queues a query for execution on all connected databases.

        The statement is not run by this call; it is placed on the write
        queue and applied by the writer thread (see start_monitoring).

        :param query: The SQL query to execute.
        :param params: Parameters for the SQL query.
        """
        self.write_queue.put((query, params))

    def _process_write_queue(self):
        """Writer-thread loop: apply queued statements until asked to stop.

        While running, waits up to one second for each item so the stop flag
        is re-checked regularly. Once the stop flag is set, the remaining
        queue content is drained without waiting and the loop exits when the
        queue is empty, so statements queued before shutdown are not dropped.
        """
        while True:
            stopping = self.stop_event.is_set()
            try:
                if stopping:
                    item = self.write_queue.get_nowait()
                else:
                    item = self.write_queue.get(timeout=1)
            except queue.Empty:
                if stopping:
                    break
                continue

            try:
                query, params = item
                self._execute_on_all(query, params)
            except Exception:
                # Thread boundary: one bad statement must not kill the writer
                # and silently stop all future mirroring.
                logging.exception("Unexpected error while processing queued write")
            finally:
                self.write_queue.task_done()

    def _execute_on_all(self, query, params):
        """Run one statement, in its own transaction, on every connection.

        A SQLite failure on one database is logged and does not prevent the
        statement from being applied to the others.

        :param query: The SQL query to execute.
        :param params: Parameters for the SQL query.
        :return: List of database paths on which the statement failed.
        """
        failures = []
        with self.lock:
            for db_path, conn in self.connections.items():
                try:
                    with self._transaction(conn):
                        conn.execute(query, params)
                except sqlite3.Error as e:
                    logging.error(f"Failed to write to {db_path}: {e}")
                    failures.append(db_path)
        if failures:
            logging.error(f"Write failures on: {failures}")
        return failures

    async def execute_async(self, query, params=()):
        """Awaitable wrapper that runs :meth:`execute` in the default executor.

        :param query: The SQL query to execute.
        :param params: Parameters for the SQL query.
        :return: The return value of :meth:`execute` (None).
        """
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self.execute, query, params)

    def commit(self):
        """Commits the current transaction on all databases.

        A commit failure on one database is logged and the remaining
        databases are still committed.
        """
        with self.lock:
            for db_path, conn in self.connections.items():
                try:
                    conn.commit()
                except sqlite3.Error as e:
                    logging.error(f"Failed to commit on {db_path}: {e}")

    def close(self):
        """Closes all active database connections and forgets them.

        Every connection is closed independently, and the registry is emptied
        so later calls never touch an already-closed handle. Calling this
        more than once is harmless.
        """
        with self.lock:
            for db_path, conn in self.connections.items():
                if self._close_connection(db_path, conn):
                    logging.info(f"Closed connection to {db_path}")
            self.connections.clear()

    @contextmanager
    def _transaction(self, conn):
        """Context manager for handling transactions.

        Commits ``conn`` when the body finishes normally. On any exception
        the transaction is rolled back, the failure is logged, and the
        original exception is re-raised unchanged.

        :param conn: The sqlite3 connection to commit or roll back.
        """
        try:
            yield
            conn.commit()
        except Exception as e:
            try:
                conn.rollback()
            except sqlite3.Error as rollback_error:
                # Keep the original error as the one that propagates.
                logging.error(f"Rollback failed: {rollback_error}")
            logging.error(f"Transaction failed: {e}")
            raise

    def monitor_usb(self):
        """Monitor USB devices and manage database connections accordingly.

        Listens for udev events of the 'block' subsystem. An 'add' event for
        a vfat filesystem registers the database at
        ``<mount_path>/<ID_SERIAL>/<db_name>``; a 'remove' event unregisters
        the database at the same derived path. Polling uses a one-second
        timeout so the stop flag is honoured even when no events arrive.
        """
        context = pyudev.Context()
        monitor = pyudev.Monitor.from_netlink(context)
        monitor.filter_by('block')

        while not self.stop_event.is_set():
            # poll() returns None when the timeout elapses without an event.
            device = monitor.poll(timeout=1)
            if device is None:
                continue

            action = device.action
            is_vfat_add = action == 'add' and device.get('ID_FS_TYPE') == 'vfat'
            if not is_vfat_add and action != 'remove':
                continue

            serial = device.get('ID_SERIAL')
            if serial is None:
                # Without a serial the mount directory cannot be derived.
                logging.debug(f"Ignoring '{action}' event for device without ID_SERIAL")
                continue

            # Use dynamic mount point for devices
            # NOTE(review): the 'add' event may arrive before the filesystem
            # is mounted at this path; add_storage() skips paths that do not
            # exist - confirm mounting is handled elsewhere.
            db_path = os.path.join(self.mount_path, serial, self.db_name)
            if is_vfat_add:
                logging.info(f"USB inserted: {serial}")
                self.add_storage(db_path)
            else:
                logging.info(f"USB removed: {serial}")
                self.remove_storage(db_path)

    def start_monitoring(self):
        """Starts the USB monitor thread and the writer thread as daemons."""
        self.monitor_thread = threading.Thread(target=self.monitor_usb)
        self.monitor_thread.daemon = True
        self.monitor_thread.start()

        self.write_thread = threading.Thread(target=self._process_write_queue)
        self.write_thread.daemon = True
        self.write_thread.start()

    def check_connections(self):
        """Probes every connection and reconnects those that fail.

        Each connection is tested with ``SELECT 1``; on a SQLite error the
        connection is removed and add_storage() is called for the same path.
        """
        with self.lock:
            # Iterate over a snapshot: remove/add mutate self.connections.
            for db_path, conn in list(self.connections.items()):
                try:
                    conn.execute("SELECT 1")
                except sqlite3.Error:
                    logging.warning(f"Connection to {db_path} lost. Attempting to reconnect...")
                    self.remove_storage(db_path)
                    self.add_storage(db_path)

    def shutdown(self):
        """Gracefully shuts down the background threads and closes connections.

        Sets the stop flag, waits for the monitor and writer threads (if they
        were started) to finish, then closes every connection.
        """
        logging.info("Shutting down...")
        self.stop_event.set()
        # The thread attributes only exist after start_monitoring().
        for attr in ('monitor_thread', 'write_thread'):
            thread = getattr(self, attr, None)
            if thread is not None:
                thread.join()
        self.close()
|
|
|
|
|