# File: SqliteMirrorProxy/sqlitemirrorproxy_wo_async.py
# (Source viewer metadata: 162 lines, 5.8 KiB, Python)

import sqlite3
from contextlib import contextmanager
import pyudev
import os
import threading
import logging
import queue
import asyncio
# Configure the root logger at import time: INFO and above, each record
# prefixed with a timestamp and its level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
class SQLiteMirrorProxy:
    """Mirror SQLite writes to a database file on every attached USB drive.

    Each drive's database is expected at ``<mount_path>/<ID_SERIAL>/<db_name>``.
    Writes submitted through :meth:`execute` are queued and applied, one at a
    time and each in its own transaction, to every open connection by a
    background writer thread. A second background thread listens for udev
    block-device events and opens/closes connections as drives come and go.
    Both threads are started by :meth:`start_monitoring` and stopped by
    :meth:`shutdown`.

    Connections are shared across threads (``check_same_thread=False``); all
    access to ``self.connections`` happens while holding ``self.lock``.
    """

    def __init__(self, db_name='data.sqlite', mount_path='/mnt'):
        """
        Initializes the SQLiteMirrorProxy.

        :param db_name: The name of the database file on each USB device.
        :param mount_path: The root path where USB devices are mounted.
        """
        self.db_name = db_name
        self.mount_path = mount_path
        # Maps database file path -> open sqlite3.Connection.
        self.connections = {}
        # Re-entrant because check_connections() calls remove_storage() and
        # add_storage() while it already holds the lock.
        self.lock = threading.RLock()
        # Pending (query, params) writes consumed by _process_write_queue().
        self.write_queue = queue.Queue()
        # Set by shutdown() to ask both background threads to exit.
        self.stop_event = threading.Event()

    def add_storage(self, db_path):
        """
        Opens a SQLite connection to the database at *db_path* and starts
        mirroring writes to it.

        Does nothing (and logs why) if the file does not exist, if a
        connection to that path is already open, or if connecting fails.

        :param db_path: The full path to the SQLite database file.
        """
        if not os.path.exists(db_path):
            logging.warning(f"Database {db_path} does not exist; not connecting.")
            return
        with self.lock:
            # Repeated 'add' events (e.g. disk and partition sharing one
            # serial) must not replace, and thereby leak, an open connection.
            if db_path in self.connections:
                logging.info(f"Already connected to {db_path}")
                return
            try:
                conn = sqlite3.connect(db_path, check_same_thread=False)
            except sqlite3.Error as e:
                logging.error(f"Failed to connect to {db_path}: {e}")
                return
            self.connections[db_path] = conn
        logging.info(f"Connected to {db_path}")

    def remove_storage(self, db_path):
        """
        Closes and forgets the connection for *db_path*.

        Logs a warning if no connection for that path is open.

        :param db_path: The full path used when the storage was added.
        """
        with self.lock:
            conn = self.connections.pop(db_path, None)
            if conn:
                conn.close()
                logging.info(f"Disconnected from {db_path}")
            else:
                logging.warning(f"Connection {db_path} not found.")

    def execute(self, query, params=()):
        """
        Queues a write to be applied to all connected databases.

        The statement is not run here: the background writer thread started
        by start_monitoring() later executes it on every open connection,
        each in its own transaction. Nothing is written until that thread
        is running.

        :param query: The SQL statement to execute.
        :param params: Parameters bound to the statement's placeholders.
        """
        self.write_queue.put((query, params))

    def _process_write_queue(self):
        """Writer-thread loop: apply queued writes until shutdown, then drain the rest."""
        while not self.stop_event.is_set():
            try:
                # Bounded wait so the loop re-checks stop_event at least once a second.
                query, params = self.write_queue.get(timeout=1)
            except queue.Empty:
                continue
            self._apply_queued_write(query, params)
        # Writes accepted before shutdown() must not be silently dropped.
        while True:
            try:
                query, params = self.write_queue.get_nowait()
            except queue.Empty:
                break
            self._apply_queued_write(query, params)

    def _apply_queued_write(self, query, params):
        """Run one dequeued write on all databases without letting an error kill the writer thread."""
        try:
            self._execute_on_all(query, params)
        except Exception:
            # Top-level boundary of the worker thread: log and keep serving.
            logging.exception(f"Unexpected error while applying queued write: {query}")
        finally:
            self.write_queue.task_done()

    def _execute_on_all(self, query, params):
        """
        Executes *query* with *params* on every open connection.

        Each connection gets its own transaction; a failure on one database
        is logged and does not stop the write on the others.
        """
        failures = []
        with self.lock:
            for db_path, conn in self.connections.items():
                try:
                    with self._transaction(conn):
                        conn.execute(query, params)
                except sqlite3.Error as e:
                    logging.error(f"Failed to write to {db_path}: {e}")
                    failures.append(db_path)
        if failures:
            logging.error(f"Write failures on: {failures}")

    async def execute_async(self, query, params=()):
        """
        Awaitable variant of execute().

        Only the enqueue step runs in the loop's default executor. The write
        itself still happens later on the writer thread, so completion of
        this coroutine does not mean the write has been applied.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.execute, query, params)

    def commit(self):
        """Commits the current transaction on all databases, logging per-database failures."""
        with self.lock:
            for db_path, conn in self.connections.items():
                try:
                    conn.commit()
                except sqlite3.Error as e:
                    logging.error(f"Failed to commit on {db_path}: {e}")

    def close(self):
        """Closes all active database connections and removes them from the proxy."""
        with self.lock:
            open_connections = list(self.connections.items())
            # Forget the handles so later execute/commit calls skip them
            # instead of failing on closed databases.
            self.connections.clear()
            for db_path, conn in open_connections:
                try:
                    conn.close()
                except sqlite3.Error as e:
                    logging.error(f"Failed to close {db_path}: {e}")
                else:
                    logging.info(f"Closed connection to {db_path}")

    @contextmanager
    def _transaction(self, conn):
        """Commit *conn* if the body succeeds; roll back, log and re-raise on sqlite3.Error."""
        try:
            yield
            conn.commit()
        except sqlite3.Error as e:
            conn.rollback()
            logging.error(f"Transaction failed: {e}")
            raise

    def monitor_usb(self):
        """
        Monitors udev block-device events and adds/removes database
        connections accordingly, until stop_event is set.
        """
        context = pyudev.Context()
        monitor = pyudev.Monitor.from_netlink(context)
        monitor.filter_by('block')
        while not self.stop_event.is_set():
            # poll() without a timeout blocks until the next udev event,
            # which would make shutdown() hang in join(); a timeout of None
            # here means "no event within 1 s".
            device = monitor.poll(timeout=1)
            if device is None:
                continue
            self._handle_device_event(device)

    def _db_path_for(self, device):
        """Return the database path for *device*, or None if it has no ID_SERIAL."""
        serial = device.get('ID_SERIAL')
        if not serial:
            return None
        return os.path.join(self.mount_path, serial, self.db_name)

    def _handle_device_event(self, device):
        """Open the connection for a vfat 'add' event, or close it for a 'remove' event."""
        action = device.action
        is_vfat_add = action == 'add' and device.get('ID_FS_TYPE') == 'vfat'
        if not (is_vfat_add or action == 'remove'):
            return
        db_path = self._db_path_for(device)
        if db_path is None:
            # Without a serial there is no per-device mount directory to use.
            logging.debug(f"Ignoring '{action}' event for block device without ID_SERIAL")
            return
        if is_vfat_add:
            # NOTE(review): 'add' may arrive before the filesystem is mounted
            # under mount_path; add_storage() skips paths that do not exist
            # yet, so confirm something mounts the drive first.
            logging.info(f"USB inserted: {device.get('ID_SERIAL')}")
            self.add_storage(db_path)
        else:
            logging.info(f"USB removed: {device.get('ID_SERIAL')}")
            self.remove_storage(db_path)

    def start_monitoring(self):
        """Starts the USB monitor thread and the background writer thread (both daemons)."""
        self.monitor_thread = threading.Thread(target=self.monitor_usb)
        self.monitor_thread.daemon = True
        self.monitor_thread.start()
        self.write_thread = threading.Thread(target=self._process_write_queue)
        self.write_thread.daemon = True
        self.write_thread.start()

    def check_connections(self):
        """Probes every connection with ``SELECT 1`` and reopens any that fail."""
        with self.lock:
            # Iterate a snapshot: remove/add below mutate self.connections.
            for db_path, conn in list(self.connections.items()):
                try:
                    conn.execute("SELECT 1")
                except sqlite3.Error:
                    logging.warning(f"Connection to {db_path} lost. Attempting to reconnect...")
                    self.remove_storage(db_path)
                    self.add_storage(db_path)

    def shutdown(self):
        """
        Gracefully stops the background threads and closes all connections.

        Both threads re-check stop_event at least once a second, and the
        writer flushes already-queued writes before exiting.
        """
        logging.info("Shutting down...")
        self.stop_event.set()
        if hasattr(self, 'monitor_thread'):
            self.monitor_thread.join()
        if hasattr(self, 'write_thread'):
            self.write_thread.join()
        self.close()