sliding_sqlite/SlidingSqlite.py
2025-04-05 22:45:18 +03:00

811 lines
30 KiB
Python

import sqlite3
import uuid
import threading
import datetime
import queue
import os
import logging
import time
from typing import (
Any,
Dict,
Optional,
Tuple,
List,
Union,
Set,
NamedTuple,
TypeVar,
Generic,
Callable,
)
# Type variable for the payload type carried by QueryResult (used as QueryResult(Generic[T])).
T = TypeVar("T")
class DatabaseError(Exception):
    """Root exception for database failures raised by this module."""
class QueryError(DatabaseError):
    """Signals that a query failed or its result could not be retrieved."""
class QueryResult(Generic[T]):
    """
    Outcome of a queued query.

    Carries either a payload in ``data`` or the exception that prevented it in
    ``error``. ``success`` (which is also the object's truth value) is True only
    when no error is set and ``data`` is not None.
    """

    def __init__(self, data: Optional[T] = None, error: Optional[Exception] = None):
        # An empty payload such as [] still counts as success; only None does not.
        self.success = not (error is not None or data is None)
        self.data = data
        self.error = error

    def __bool__(self) -> bool:
        return self.success
class DatabaseTimeframe(NamedTuple):
    """A database file together with the time window recorded for it in the metadata table."""

    db_file: str  # path of the SQLite database file
    start_time: float  # start of the window, seconds since the epoch
    end_time: float  # end of the window, seconds since the epoch
class QueryLockManager:
    """
    Context manager that holds ``sliding_sqlite.query_lock`` for the whole
    ``with`` body and reports whether a read query can still receive a result.

    ``__enter__`` returns True when *query_id* is present in both
    ``sliding_sqlite.read_queues`` and ``sliding_sqlite.active_queries``.
    Because the lock stays held until ``__exit__``, the body can put into
    ``read_queues[query_id]`` without a concurrent cleanup removing the entry
    between the check and the put.

    The body must not acquire ``query_lock`` again: in SlidingSQLite it is a
    plain (non-reentrant) ``threading.Lock``.
    """

    def __init__(self, sliding_sqlite, query_id):
        self.sliding_sqlite = sliding_sqlite
        self.query_id = query_id
        self.lock = sliding_sqlite.query_lock
        self.is_active = False

    def __enter__(self):
        self.lock.acquire()
        try:
            self.is_active = (
                self.query_id in self.sliding_sqlite.read_queues
                and self.query_id in self.sliding_sqlite.active_queries
            )
        except BaseException:
            # Never leave the lock held if the status check itself fails.
            self.lock.release()
            raise
        return self.is_active

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.lock.release()
        # Do not suppress exceptions raised inside the with-body.
        return False
class SlidingSQLite:
    """
    Thread-safe SQLite implementation with automatic time-based database rotation.

    Write queries are queued and executed by a single background worker thread
    against the database file of the current rotation window. Read queries run
    in a short-lived thread per query and are executed against every database
    file registered in ``metadata.db``; the rows from all files are concatenated.
    A cleanup thread periodically deletes databases older than the retention
    period and drops bookkeeping for queries whose results were collected.
    """

    def __init__(
        self,
        db_dir: str,
        schema: str,
        retention_period: int = 604800,
        rotation_interval: int = 3600,
        cleanup_interval: int = 3600,
        auto_delete_old_dbs: bool = True,
    ) -> None:
        """
        Initialize the SlidingSQLite instance and start its background threads.

        Args:
            db_dir: Directory to store database files (created if missing)
            schema: SQL script applied to every data database that has no schema objects yet
            retention_period: Seconds to keep databases before deletion (default: 7 days)
            rotation_interval: Seconds covered by each database file (default: 1 hour)
            cleanup_interval: Seconds between cleanup passes (default: 1 hour)
            auto_delete_old_dbs: Whether cleanup passes delete expired databases (default: True)

        Raises:
            DatabaseError: If the directory, metadata database or workers cannot be set up.
        """
        self.db_dir = db_dir
        self.schema = schema
        self.retention_period = retention_period  # In seconds
        self.rotation_interval = rotation_interval  # In seconds
        self.cleanup_interval = cleanup_interval  # In seconds
        self.auto_delete_old_dbs = auto_delete_old_dbs

        # Pending writes; ``None`` is the shutdown sentinel for the write worker.
        self.write_queue: queue.Queue[Optional[Tuple[str, Tuple[Any, ...], uuid.UUID]]] = (
            queue.Queue()
        )
        # Per-query result mailboxes, keyed by the UUID handed back to callers.
        self.result_queues: Dict[uuid.UUID, queue.Queue[QueryResult[bool]]] = {}
        self.read_queues: Dict[
            uuid.UUID, queue.Queue[QueryResult[List[Tuple[Any, ...]]]]
        ] = {}

        # Thread synchronization
        self.shutdown_flag = threading.Event()
        self.worker_thread: Optional[threading.Thread] = None
        self.cleanup_thread: Optional[threading.Thread] = None
        # The cleanup worker waits on this before its first pass.
        self._init_complete = threading.Event()

        # Data-database connections, shared by all threads.
        self.connections: Dict[str, sqlite3.Connection] = {}
        self.conn_lock = threading.Lock()
        # Each thread gets its own metadata-database connection.
        self._thread_local = threading.local()

        # Queries whose results have not been collected yet; guarded by query_lock.
        self.active_queries: Set[uuid.UUID] = set()
        self.query_lock = threading.Lock()

        # Current rotation window as one tuple (db path, start, end), replaced atomically.
        self._current_window: Optional[Tuple[str, float, float]] = None
        # Mirrors of _current_window kept in sync for backward compatibility.
        self._current_db_cache: Optional[str] = None
        self._current_db_expiry = 0
        # Last database path registered in metadata; guarded by _register_lock.
        self._registered_db: Optional[str] = None
        self._register_lock = threading.Lock()

        self._setup()
        self._init_complete.set()

    def _setup(self) -> None:
        """
        Create the database directory and metadata table, register the current
        database and start the worker threads.

        Raises:
            DatabaseError: If any step fails.
        """
        try:
            logging.debug("Creating database directory...")
            os.makedirs(self.db_dir, exist_ok=True)
            logging.debug("Initializing metadata database...")
            self._init_metadata()
            logging.debug("Registering current database...")
            self._register_current_db()
            logging.debug("Starting write worker thread...")
            self._start_worker()
            logging.debug("Starting cleanup worker thread...")
            self._start_cleanup_worker()
        except Exception as e:
            logging.error(f"Failed to initialize SlidingSQLite: {e}")
            raise DatabaseError(f"Failed to initialize SlidingSQLite: {e}") from e

    def _init_metadata(self) -> None:
        """
        Create the metadata table if it does not exist.

        Raises:
            DatabaseError: If the metadata database cannot be initialized.
        """
        try:
            with self._get_connection(self._get_metadata_db()) as conn:
                conn.execute(
                    """
                    CREATE TABLE IF NOT EXISTS metadata (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        db_file TEXT NOT NULL UNIQUE,
                        start_time REAL NOT NULL,
                        end_time REAL NOT NULL
                    )
                    """
                )
                conn.commit()
        except sqlite3.Error as e:
            logging.error(f"Failed to initialize metadata database: {e}")
            raise DatabaseError(f"Failed to initialize metadata database: {e}") from e

    @staticmethod
    def _open_connection(db_file: str) -> sqlite3.Connection:
        """Open *db_file* in autocommit mode with WAL journaling and a 5 s busy timeout."""
        conn = sqlite3.connect(db_file, isolation_level=None, check_same_thread=False)
        try:
            conn.execute("PRAGMA journal_mode=WAL;")
            conn.execute("PRAGMA busy_timeout=5000;")
        except sqlite3.Error:
            conn.close()
            raise
        return conn

    def _get_connection(self, db_file: str) -> sqlite3.Connection:
        """
        Return a connection to *db_file*, opening and caching it on first use.

        The metadata database gets one connection per thread (thread-local).
        Data databases share a single cached connection across threads. When a
        data database is opened and contains no schema objects yet,
        ``self.schema`` is applied to it.

        Args:
            db_file: Path to the database file

        Returns:
            SQLite connection object

        Raises:
            sqlite3.Error: If the metadata database cannot be opened.
            DatabaseError: If a data database cannot be opened or its schema cannot be applied.
        """
        if db_file == self._get_metadata_db():
            conn = getattr(self._thread_local, "metadata_conn", None)
            if conn is None:
                conn = self._open_connection(db_file)
                self._thread_local.metadata_conn = conn
            return conn

        with self.conn_lock:
            conn = self.connections.get(db_file)
            if conn is not None:
                return conn
            try:
                conn = self._open_connection(db_file)
            except sqlite3.Error as e:
                logging.error(f"Failed to connect to database {db_file}: {e}")
                raise DatabaseError(f"Failed to connect to database {db_file}: {e}") from e
            try:
                # sqlite3.connect() has already created the file, so "is this database
                # new?" is decided by whether it holds any schema objects yet.
                has_schema = (
                    conn.execute("SELECT 1 FROM sqlite_master LIMIT 1").fetchone() is not None
                )
                if not has_schema:
                    conn.executescript(self.schema)
            except sqlite3.Error as e:
                logging.error(f"Failed to apply schema to {db_file}: {e}")
                conn.close()
                raise DatabaseError(f"Failed to apply schema to {db_file}: {e}") from e
            self.connections[db_file] = conn
            return conn

    def _get_metadata_db(self) -> str:
        """Get the path to the metadata database"""
        return os.path.join(self.db_dir, "metadata.db")

    def _get_current_window(self) -> Tuple[str, float, float]:
        """
        Return ``(db_path, start_time, end_time)`` for the current rotation window.

        The tuple is cached until ``end_time`` passes and is replaced as a whole,
        so a caller never sees a path from one window paired with the bounds of another.
        """
        window = self._current_window
        now = time.time()
        if window is None or now >= window[2]:
            start = int(now // self.rotation_interval) * self.rotation_interval
            stamp = datetime.datetime.fromtimestamp(start).strftime("%Y%m%d_%H%M%S")
            window = (
                os.path.join(self.db_dir, f"data_{stamp}.db"),
                start,
                start + self.rotation_interval,
            )
            self._current_window = window
            self._current_db_cache = window[0]
            self._current_db_expiry = window[2]
        return window

    def _get_current_db(self) -> str:
        """Get the path to the current time-based database, using cache"""
        return self._get_current_window()[0]

    def _register_current_db(self, window: Optional[Tuple[str, float, float]] = None) -> None:
        """
        Record a rotation window's database in the metadata table.

        Args:
            window: ``(db_file, start_time, end_time)`` to register. Defaults to
                the current window. Passing the window the caller is about to
                use keeps the path and its time range consistent even if a
                rotation boundary passes in between.

        Failures are logged, not raised. On failure the database is not marked
        as registered, so the next write or query retries the registration.
        """
        db_file, interval_start, interval_end = (
            window if window is not None else self._get_current_window()
        )
        with self._register_lock:
            try:
                with self._get_connection(self._get_metadata_db()) as conn:
                    cursor = conn.execute(
                        "INSERT OR IGNORE INTO metadata (db_file, start_time, end_time) VALUES (?, ?, ?)",
                        (db_file, interval_start, interval_end),
                    )
                    conn.commit()
            except sqlite3.Error as e:
                logging.error(f"Failed to register current database: {e}")
                return
            if cursor.rowcount:
                logging.debug(f"Registered new database: {db_file}")
            else:
                logging.debug(f"Database {db_file} already registered")
            self._registered_db = db_file

    def _ensure_current_db_registered(self) -> str:
        """
        Register the current window's database unless that already happened.

        Every path that selects the current database calls this, including the
        write worker. Cache refreshes in any thread therefore cannot leave a
        window's database missing from metadata.

        Returns:
            Path of the current window's database.
        """
        window = self._get_current_window()
        if window[0] != self._registered_db:
            self._register_current_db(window)
        return window[0]

    def _delete_databases_where(self, condition: str, params: Tuple[Any, ...]) -> int:
        """
        Delete every database whose metadata row matches *condition*, then the rows.

        Args:
            condition: SQL boolean expression over the metadata columns. Only
                internal constants are passed here; the expression is
                interpolated into the SQL text, not bound as a parameter.
            params: Values bound to the ``?`` placeholders in *condition*.

        Returns:
            Number of database files actually removed from disk.

        Raises:
            sqlite3.Error: If a metadata query fails.
        """
        deleted = 0
        with self._get_connection(self._get_metadata_db()) as conn:
            rows = conn.execute(
                f"SELECT db_file FROM metadata WHERE {condition}", params
            ).fetchall()
            for (db_file,) in rows:
                if self._delete_database_file(db_file):
                    deleted += 1
            conn.execute(f"DELETE FROM metadata WHERE {condition}", params)
            conn.commit()
        return deleted

    def _rotate_databases(self) -> None:
        """
        Delete databases whose window ended more than ``retention_period`` seconds ago.

        Does nothing when ``auto_delete_old_dbs`` is False. Errors are logged, not raised.
        """
        if not self.auto_delete_old_dbs:
            return
        cutoff_time = time.time() - self.retention_period
        try:
            self._delete_databases_where("end_time < ?", (cutoff_time,))
        except sqlite3.Error as e:
            logging.error(f"Database rotation error: {e}")

    def _cleanup_stale_queries(self) -> None:
        """Drop result queues of queries that are no longer active, to bound memory use."""
        with self.query_lock:
            stale = (set(self.result_queues) | set(self.read_queues)) - self.active_queries
            for query_id in stale:
                self.result_queues.pop(query_id, None)
                self.read_queues.pop(query_id, None)

    def _close_thread_metadata_conn(self) -> None:
        """Close and forget the calling thread's metadata connection, if it has one."""
        conn = getattr(self._thread_local, "metadata_conn", None)
        if conn is None:
            return
        try:
            conn.close()
        except sqlite3.Error:
            pass
        del self._thread_local.metadata_conn

    def _delete_database_file(self, db_file: str) -> bool:
        """
        Close any cached connection to *db_file* and remove the file from disk.

        After the main file is removed, the ``-wal``/``-shm`` side files that
        WAL journaling keeps next to it are removed too.

        Returns:
            True if the main database file was removed, False if it did not
            exist or could not be removed.
        """
        with self.conn_lock:
            conn = self.connections.pop(db_file, None)
            if conn is not None:
                try:
                    conn.close()
                except sqlite3.Error:
                    pass
        if db_file == self._get_metadata_db():
            self._close_thread_metadata_conn()
        with self._register_lock:
            if db_file == self._registered_db:
                # Force re-registration if this window's database gets recreated.
                self._registered_db = None

        if not os.path.exists(db_file):
            return False
        try:
            os.remove(db_file)
        except OSError as e:
            logging.error(f"Failed to delete database {db_file}: {e}")
            return False
        logging.info(f"Deleted database: {db_file}")
        for side_file in (db_file + "-wal", db_file + "-shm"):
            try:
                os.remove(side_file)
            except FileNotFoundError:
                pass
            except OSError as e:
                logging.warning(f"Failed to delete {side_file}: {e}")
        return True

    def set_retention_period(self, seconds: int) -> None:
        """
        Set the retention period for databases.

        Args:
            seconds: Number of seconds to keep databases before deletion (negative values become 0)
        """
        self.retention_period = max(0, seconds)

    def set_auto_delete(self, enabled: bool) -> None:
        """
        Enable or disable automatic deletion of old databases.

        Args:
            enabled: Whether to automatically delete old databases
        """
        self.auto_delete_old_dbs = enabled

    def delete_databases_before(self, timestamp: float) -> int:
        """
        Delete all databases with end_time before the specified timestamp.

        Args:
            timestamp: Unix timestamp (seconds since epoch)

        Returns:
            Number of database files deleted

        Raises:
            DatabaseError: If the metadata database cannot be queried or updated.
        """
        try:
            return self._delete_databases_where("end_time < ?", (timestamp,))
        except sqlite3.Error as e:
            logging.error(f"Database deletion error: {e}")
            raise DatabaseError(f"Failed to delete databases: {e}") from e

    def delete_databases_in_range(self, start_time: float, end_time: float) -> int:
        """
        Delete all databases whose time range overlaps the specified period.

        A database overlaps when its ``start_time <= end_time`` argument and
        its ``end_time >= start_time`` argument (bounds inclusive). Databases
        that only partially overlap the period are therefore deleted too.

        Args:
            start_time: Start of time range (unix timestamp)
            end_time: End of time range (unix timestamp)

        Returns:
            Number of database files deleted

        Raises:
            DatabaseError: If the metadata database cannot be queried or updated.
        """
        try:
            return self._delete_databases_where(
                "start_time <= ? AND end_time >= ?", (end_time, start_time)
            )
        except sqlite3.Error as e:
            logging.error(f"Database deletion error: {e}")
            raise DatabaseError(f"Failed to delete databases: {e}") from e

    def get_databases_info(self) -> "List[DatabaseTimeframe]":
        """
        Get information about all registered databases, ordered by start time.

        Returns:
            List of DatabaseTimeframe objects containing database file paths and time ranges

        Raises:
            DatabaseError: If the metadata database cannot be read.
        """
        try:
            with self._get_connection(self._get_metadata_db()) as conn:
                rows = conn.execute(
                    "SELECT db_file, start_time, end_time FROM metadata ORDER BY start_time"
                ).fetchall()
        except sqlite3.Error as e:
            logging.error(f"Error retrieving database info: {e}")
            raise DatabaseError(f"Failed to retrieve database info: {e}") from e
        return [DatabaseTimeframe(db_file, start, end) for db_file, start, end in rows]

    def execute(self, query: str, params: Tuple[Any, ...] = ()) -> uuid.UUID:
        """
        Route a query to execute_read or execute_write based on its leading keyword.

        Queries starting with SELECT, PRAGMA or EXPLAIN are treated as reads and
        run against every registered database. Everything else is queued as a
        write to the current database. NOTE: this includes ``WITH ... SELECT``,
        whose rows are therefore not returned. A PRAGMA is run on every database.

        Args:
            query: SQL query to execute
            params: Parameters for the query

        Returns:
            UUID that can be used to retrieve the result
        """
        if query.strip().upper().startswith(("SELECT", "PRAGMA", "EXPLAIN")):
            return self.execute_read(query, params)
        return self.execute_write(query, params)

    def execute_write(self, query: str, params: Tuple[Any, ...] = ()) -> uuid.UUID:
        """
        Queue a write query for the background write worker.

        Args:
            query: SQL query to execute
            params: Parameters for the query

        Returns:
            UUID that can be passed to get_result
        """
        self._ensure_current_db_registered()
        query_id = uuid.uuid4()
        with self.query_lock:
            self.result_queues[query_id] = queue.Queue()
            self.active_queries.add(query_id)
        self.write_queue.put((query, params, query_id))
        return query_id

    def execute_write_sync(
        self, query: str, params: Tuple[Any, ...] = (), timeout: float = 5.0
    ) -> "QueryResult[bool]":
        """
        Execute a write query and wait for its outcome.

        Args:
            query: SQL query to execute
            params: Parameters for the query
            timeout: Maximum time to wait for the result

        Returns:
            QueryResult containing success status and any error
        """
        query_id = self.execute_write(query, params)
        return self.get_result(query_id, timeout)

    def execute_read(self, query: str, params: Tuple[Any, ...] = ()) -> uuid.UUID:
        """
        Start a read query across all registered databases in a new thread.

        Args:
            query: SQL query to execute
            params: Parameters for the query

        Returns:
            UUID that can be passed to get_read_result
        """
        self._ensure_current_db_registered()
        query_id = uuid.uuid4()
        with self.query_lock:
            self.read_queues[query_id] = queue.Queue()
            self.active_queries.add(query_id)
        threading.Thread(
            target=self._read_across_all_worker,
            args=(query, params, query_id),
            daemon=True,
        ).start()
        return query_id

    def _deliver_read_result(
        self, query_id: uuid.UUID, result: "QueryResult[List[Tuple[Any, ...]]]"
    ) -> bool:
        """
        Put *result* into the read queue of *query_id* if that query is still active.

        Returns:
            True if the result was delivered, False if the query is no longer tracked.
        """
        with QueryLockManager(self, query_id) as is_active:
            if is_active:
                self.read_queues[query_id].put(result)
            return is_active

    def _read_worker(
        self, query: str, params: Tuple[Any, ...], query_id: uuid.UUID
    ) -> None:
        """
        Run a read query against the current database only and deliver the rows.

        execute_read() does not use this; it reads across all databases instead.
        """
        db_file = self._get_current_db()
        try:
            with self._get_connection(db_file) as conn:
                results = conn.execute(query, params).fetchall()
            self._deliver_read_result(query_id, QueryResult(data=results))
        except Exception as e:
            error_msg = f"Read error: {e}"
            logging.error(error_msg)
            self._deliver_read_result(query_id, QueryResult(error=QueryError(error_msg)))

    def execute_read_sync(
        self, query: str, params: Tuple[Any, ...] = (), timeout: float = 5.0
    ) -> "QueryResult[List[Tuple[Any, ...]]]":
        """
        Execute a read query across all registered databases and wait for the rows.

        Args:
            query: SQL query to execute
            params: Parameters for the query
            timeout: Maximum time to wait for the result

        Returns:
            QueryResult containing combined query results and any error
        """
        query_id = self.execute_read(query, params)
        return self.get_read_result(query_id, timeout)

    def _read_across_all_worker(
        self, query: str, params: Tuple[Any, ...], query_id: uuid.UUID
    ) -> None:
        """
        Thread body for execute_read.

        Runs *query* on every registered database (newest window first) and
        delivers the concatenated rows. A database that cannot be opened or
        queried is logged and skipped. If the metadata lookup fails, an error
        result is delivered instead.
        """
        try:
            with self._get_connection(self._get_metadata_db()) as conn:
                db_files = conn.execute(
                    "SELECT db_file FROM metadata ORDER BY end_time DESC"
                ).fetchall()
            if not db_files:
                logging.warning("No database files found in metadata for read operation")
                self._deliver_read_result(query_id, QueryResult(data=[]))
                return

            all_results: List[Tuple[Any, ...]] = []
            for (db_file,) in db_files:
                if not isinstance(db_file, str):
                    logging.error(f"Invalid db_file in metadata: {db_file}")
                    continue
                if not os.path.exists(db_file):
                    # e.g. registered in metadata but no connection has created the file yet
                    continue
                try:
                    with self._get_connection(db_file) as conn:
                        all_results.extend(conn.execute(query, params).fetchall())
                except (sqlite3.Error, DatabaseError) as e:
                    logging.warning(f"Error reading from {db_file}: {e}")
                    # Continue with other databases

            if not self._deliver_read_result(query_id, QueryResult(data=all_results)):
                logging.warning(
                    f"Query ID {query_id} no longer active when trying to return results"
                )
        except Exception as e:
            error_msg = f"Failed to execute query across all databases: {e}"
            logging.error(error_msg)
            self._deliver_read_result(query_id, QueryResult(error=QueryError(error_msg)))
        finally:
            # This thread exists only for this query; release its metadata connection.
            self._close_thread_metadata_conn()

    def get_result(
        self, query_id: uuid.UUID, timeout: float = 5.0
    ) -> "QueryResult[bool]":
        """
        Wait for the result of a write query.

        Args:
            query_id: UUID returned by execute_write
            timeout: Maximum time to wait for the result

        Returns:
            QueryResult containing success status and any error. The result is
            a QueryError for an unknown ID or when the timeout expires.
        """
        # Grab the queue under the lock so a concurrent cleanup cannot remove it mid-lookup.
        with self.query_lock:
            result_queue = self.result_queues.get(query_id)
        if result_queue is None:
            return QueryResult(error=QueryError("Invalid query ID"))
        try:
            result = result_queue.get(timeout=timeout)
        except queue.Empty:
            return QueryResult(error=QueryError("Query timed out"))
        with self.query_lock:
            self.active_queries.discard(query_id)
        return result

    def get_read_result(
        self, query_id: uuid.UUID, timeout: float = 5.0
    ) -> "QueryResult[List[Tuple[Any, ...]]]":
        """
        Wait for the result of a read query.

        Args:
            query_id: UUID returned by execute_read
            timeout: Maximum time to wait for the result

        Returns:
            QueryResult containing query results and any error. The result is a
            QueryError for an unknown ID or when the timeout expires.
        """
        with self.query_lock:
            result_queue = self.read_queues.get(query_id)
            if result_queue is None:
                return QueryResult(error=QueryError("Invalid query ID"))
            # Keep the query marked active while waiting so cleanup cannot drop its queue.
            self.active_queries.add(query_id)
        try:
            result = result_queue.get(timeout=timeout)
        except queue.Empty:
            return QueryResult(error=QueryError("Query timed out"))
        with self.query_lock:
            self.active_queries.discard(query_id)
        return result

    def _start_worker(self) -> None:
        """Start the background worker thread for processing write operations."""
        if self.worker_thread and self.worker_thread.is_alive():
            return
        self.worker_thread = threading.Thread(target=self._write_worker, daemon=True)
        self.worker_thread.start()

    def _write_worker(self) -> None:
        """
        Thread body for the write worker.

        Processes queued writes in FIFO order until the ``None`` sentinel queued
        by shutdown() is reached, so writes queued before shutdown still run.
        """
        try:
            while True:
                try:
                    task = self.write_queue.get(timeout=1)
                except queue.Empty:
                    if self.shutdown_flag.is_set():
                        break
                    continue
                if task is None:
                    break
                try:
                    self._process_write_task(task)
                except Exception as e:
                    logging.error(f"Worker thread encountered an error: {e}")
        finally:
            self._close_thread_metadata_conn()

    def _start_cleanup_worker(self) -> None:
        """Start the cleanup worker thread for database rotation."""
        if self.cleanup_thread and self.cleanup_thread.is_alive():
            return
        self.cleanup_thread = threading.Thread(target=self._cleanup_worker, daemon=True)
        self.cleanup_thread.start()

    def _process_write_task(self, task: Tuple[str, Tuple[Any, ...], uuid.UUID]) -> None:
        """
        Execute one queued write against the current database and post its outcome.

        The current database is registered in metadata first if necessary. A
        rotation boundary may have passed since the write was queued, so the
        database chosen here may be a new one.
        """
        query, params, query_id = task
        try:
            db_file = self._ensure_current_db_registered()
            with self._get_connection(db_file) as conn:
                conn.execute(query, params)
                conn.commit()
            result = QueryResult(data=True)
        except Exception as e:
            logging.error(f"Write error: {e}")
            result = QueryResult(error=e)
        with self.query_lock:
            result_queue = self.result_queues.get(query_id)
        if result_queue is None:
            logging.warning(
                f"Query ID {query_id} no longer active when trying to return write result"
            )
        else:
            result_queue.put(result)

    def _cleanup_worker(self) -> None:
        """
        Thread body for periodic database rotation and query cleanup.

        Waits on ``shutdown_flag`` instead of sleeping, so shutdown() does not
        block for up to ``cleanup_interval`` seconds.
        """
        self._init_complete.wait()
        try:
            while not self.shutdown_flag.is_set():
                try:
                    self._rotate_databases()
                    self._cleanup_stale_queries()
                except Exception as e:
                    logging.error(f"Cleanup worker encountered an error: {e}")
                if self.shutdown_flag.wait(self.cleanup_interval):
                    break
        finally:
            self._close_thread_metadata_conn()

    def shutdown(self) -> None:
        """
        Stop the background workers and close connections.

        Writes queued before this call are executed before the write worker
        exits. Writes queued afterwards are not processed.
        """
        self.shutdown_flag.set()
        # Wake the write worker; it drains the tasks queued ahead of this sentinel first.
        self.write_queue.put(None)
        if self.worker_thread:
            self.worker_thread.join()
        if self.cleanup_thread:
            self.cleanup_thread.join()
        with self.conn_lock:
            for conn in self.connections.values():
                try:
                    conn.close()
                except sqlite3.Error:
                    pass
            self.connections.clear()
        self._close_thread_metadata_conn()
        logging.info("SlidingSQLite shutdown completed.")