# sliding_sqlite/SlidingSqlite.py
import sqlite3
import uuid
import threading
import datetime
import queue
import os
import logging
import time
from typing import (
Any,
Dict,
Optional,
Tuple,
List,
Union,
Set,
NamedTuple,
TypeVar,
Generic,
Callable,
)
T = TypeVar("T")
class DatabaseError(Exception):
"""Base exception for all database errors"""
pass
class QueryError(DatabaseError):
"""Exception raised when a query fails"""
pass
class QueryResult(Generic[T]):
"""Class to handle query results with proper error handling"""
def __init__(self, data: Optional[T] = None, error: Optional[Exception] = None):
self.data = data
self.error = error
self.success = error is None and data is not None
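        # Note: a None payload counts as failure even without an error.
        # Read paths return [] (not None) for empty result sets, so an
        # empty read still reports success.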
def __bool__(self) -> bool:
return self.success
class DatabaseTimeframe(NamedTuple):
"""Represents a database file and its time range"""
db_file: str
start_time: float
end_time: float
class QueryLockManager:
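    """Context manager that checks, while holding the shared query lock,
    whether a read query is still registered and active; __enter__ returns
    that boolean status."""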
def __init__(self, sliding_sqlite, query_id):
self.sliding_sqlite = sliding_sqlite
self.query_id = query_id
self.lock = sliding_sqlite.query_lock
self.is_active = False
def __enter__(self):
# Acquire the lock and check query status
with self.lock:
self.is_active = (
self.query_id in self.sliding_sqlite.read_queues
and self.query_id in self.sliding_sqlite.active_queries
)
return self.is_active
    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing to release here; returning None lets any exception propagate
        pass
class SlidingSQLite:
"""
Thread-safe SQLite implementation with automatic time-based database rotation.
This class provides a way to safely use SQLite in a multi-threaded environment
by queuing database operations and processing them in dedicated worker threads.
Databases are created based on the specified rotation interval and old databases
are automatically cleaned up based on the specified retention period.
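
    Example (illustrative; the `events` table comes from the caller-supplied
    schema, not from this class):

        db = SlidingSQLite(
            db_dir="./data",
            schema="CREATE TABLE IF NOT EXISTS events (id INTEGER PRIMARY KEY, msg TEXT);",
        )
        db.execute_write_sync("INSERT INTO events (msg) VALUES (?)", ("hello",))
        result = db.execute_read_sync("SELECT msg FROM events")
        if result:
            print(result.data)
        db.shutdown()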
"""
def __init__(
self,
db_dir: str,
schema: str,
retention_period: int = 604800,
rotation_interval: int = 3600,
cleanup_interval: int = 3600,
auto_delete_old_dbs: bool = True,
) -> None:
"""
Initialize the SlidingSQLite instance.
Args:
db_dir: Directory to store database files
            schema: SQL schema applied when a connection to a new data database
                is first opened; it should be idempotent (e.g. CREATE TABLE IF NOT EXISTS)
retention_period: Number of seconds to keep databases before deletion (default: 7 days)
rotation_interval: How often to rotate to a new database file in seconds (default: 1 hour)
cleanup_interval: How often to run the cleanup process in seconds (default: 1 hour)
auto_delete_old_dbs: Whether to automatically delete old databases (default: True)
"""
self.db_dir = db_dir
self.schema = schema
self.retention_period = retention_period # In seconds
self.rotation_interval = rotation_interval # In seconds
self.cleanup_interval = cleanup_interval # In seconds
        self.auto_delete_old_dbs = auto_delete_old_dbs  # Whether the cleanup worker may delete old databases
# Queues for operations
self.write_queue: queue.Queue[Tuple[str, Tuple[Any, ...], uuid.UUID]] = (
queue.Queue()
)
self.result_queues: Dict[uuid.UUID, queue.Queue[QueryResult[bool]]] = {}
self.read_queues: Dict[
uuid.UUID, queue.Queue[QueryResult[List[Tuple[Any, ...]]]]
] = {}
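        # Per-query result queues keyed by UUID; stale entries are pruned by
        # _cleanup_stale_queries() once a query is no longer active.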
# Thread synchronization
self.shutdown_flag = threading.Event()
        self.worker_thread: Optional[threading.Thread] = None
# Cache for database connections
self.connections: Dict[str, sqlite3.Connection] = {}
self.conn_lock = threading.Lock()
# Track active query IDs for cleanup
self.active_queries: Set[uuid.UUID] = set()
self.query_lock = threading.Lock()
# Initialize system
self._setup()
def _setup(self) -> None:
"""Setup the database directory and initialize workers"""
try:
os.makedirs(self.db_dir, exist_ok=True)
self._init_metadata()
# Start worker threads
self._start_worker()
self._start_cleanup_worker()
# Register current database
self._register_current_db()
except Exception as e:
logging.error(f"Failed to initialize SlidingSQLite: {e}")
raise DatabaseError(f"Failed to initialize SlidingSQLite: {e}")
def _init_metadata(self) -> None:
"""Initialize the metadata database"""
try:
with self._get_connection(self._get_metadata_db()) as conn:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS metadata (
id INTEGER PRIMARY KEY AUTOINCREMENT,
db_file TEXT NOT NULL UNIQUE,
start_time REAL NOT NULL,
end_time REAL NOT NULL
)
"""
)
conn.commit()
except sqlite3.Error as e:
logging.error(f"Failed to initialize metadata database: {e}")
raise DatabaseError(f"Failed to initialize metadata database: {e}")
def _get_connection(self, db_file: str) -> sqlite3.Connection:
"""
Get a connection to the specified database file.
Reuses existing connections when possible.
Args:
db_file: Path to the database file
Returns:
SQLite connection object
"""
with self.conn_lock:
if db_file not in self.connections or self.connections[db_file] is None:
try:
conn = sqlite3.connect(
db_file, isolation_level=None, check_same_thread=False
)
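                    # WAL allows concurrent readers while a writer is active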
conn.execute("PRAGMA journal_mode=WAL;")
conn.execute(
"PRAGMA busy_timeout=5000;"
) # Wait up to 5 seconds when database is locked
# Initialize schema if this is a data database (not metadata)
if db_file != self._get_metadata_db():
conn.executescript(self.schema)
self.connections[db_file] = conn
except sqlite3.Error as e:
logging.error(f"Failed to connect to database {db_file}: {e}")
raise DatabaseError(f"Failed to connect to database {db_file}: {e}")
return self.connections[db_file]
def _get_metadata_db(self) -> str:
"""Get the path to the metadata database"""
return os.path.join(self.db_dir, "metadata.db")
def _get_current_db(self) -> str:
"""Get the path to the current time-based database"""
# Generate timestamped DB name based on rotation interval
now = time.time()
interval_timestamp = int(now // self.rotation_interval) * self.rotation_interval
timestamp_str = datetime.datetime.fromtimestamp(interval_timestamp).strftime(
"%Y%m%d_%H%M%S"
)
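        # Example: with rotation_interval=3600, every call between 10:00:00
        # and 10:59:59 buckets to the same file, e.g. data_20250316_100000.db.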
return os.path.join(self.db_dir, f"data_{timestamp_str}.db")
def _register_current_db(self) -> None:
"""Register the current database in the metadata table"""
current_db = self._get_current_db()
now = time.time()
# Calculate time boundaries for the current database
interval_start = int(now // self.rotation_interval) * self.rotation_interval
interval_end = interval_start + self.rotation_interval
try:
with self._get_connection(self._get_metadata_db()) as conn:
# Check if this database is already registered
existing = conn.execute(
"SELECT id FROM metadata WHERE db_file = ?", (current_db,)
).fetchone()
if not existing:
conn.execute(
"INSERT INTO metadata (db_file, start_time, end_time) VALUES (?, ?, ?)",
(current_db, interval_start, interval_end),
)
conn.commit()
except sqlite3.Error as e:
logging.error(f"Failed to register current database: {e}")
# Continue execution as this is not critical
def _rotate_databases(self) -> None:
"""Delete databases that are older than the retention period"""
if not self.auto_delete_old_dbs:
return # Skip deletion if auto-delete is disabled
cutoff_time = time.time() - self.retention_period
try:
with self._get_connection(self._get_metadata_db()) as conn:
# Find databases older than the cutoff time
old_dbs = conn.execute(
"SELECT db_file FROM metadata WHERE end_time < ?", (cutoff_time,)
).fetchall()
# Delete old database files
for (db_file,) in old_dbs:
self._delete_database_file(db_file)
# Clean up metadata entries
conn.execute("DELETE FROM metadata WHERE end_time < ?", (cutoff_time,))
conn.commit()
except sqlite3.Error as e:
logging.error(f"Database rotation error: {e}")
def _cleanup_stale_queries(self) -> None:
"""Clean up stale query results to prevent memory leaks"""
with self.query_lock:
# Find completed queries to clean up
completed_queries = set()
for query_id in list(self.result_queues.keys()):
if query_id not in self.active_queries:
completed_queries.add(query_id)
for query_id in list(self.read_queues.keys()):
if query_id not in self.active_queries:
completed_queries.add(query_id)
# Remove completed queries from dictionaries
for query_id in completed_queries:
if query_id in self.result_queues:
del self.result_queues[query_id]
if query_id in self.read_queues:
del self.read_queues[query_id]
def _delete_database_file(self, db_file: str) -> bool:
"""
Delete a database file and clean up resources.
Args:
db_file: Path to the database file to delete
Returns:
True if deleted successfully, False otherwise
"""
# Close and remove connection if it exists
with self.conn_lock:
if db_file in self.connections:
try:
self.connections[db_file].close()
except sqlite3.Error:
pass # Ignore errors on close
del self.connections[db_file]
# Delete the file
if os.path.exists(db_file):
try:
os.remove(db_file)
logging.info(f"Deleted database: {db_file}")
return True
except OSError as e:
logging.error(f"Failed to delete database {db_file}: {e}")
return False
return False # File didn't exist
def set_retention_period(self, seconds: int) -> None:
"""
Set the retention period for databases.
Args:
seconds: Number of seconds to keep databases before deletion
"""
        self.retention_period = max(0, seconds)  # Clamp to a non-negative value
def set_auto_delete(self, enabled: bool) -> None:
"""
Enable or disable automatic deletion of old databases.
Args:
enabled: Whether to automatically delete old databases
"""
self.auto_delete_old_dbs = enabled
def delete_databases_before(self, timestamp: float) -> int:
"""
Delete all databases with end_time before the specified timestamp.
Args:
timestamp: Unix timestamp (seconds since epoch)
Returns:
Number of databases deleted
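
        Example (illustrative): delete everything older than 24 hours:
            deleted = db.delete_databases_before(time.time() - 86400)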
"""
count = 0
try:
with self._get_connection(self._get_metadata_db()) as conn:
# Find databases older than the specified time
old_dbs = conn.execute(
"SELECT db_file FROM metadata WHERE end_time < ?", (timestamp,)
).fetchall()
# Delete old database files
for (db_file,) in old_dbs:
if self._delete_database_file(db_file):
count += 1
# Clean up metadata entries
conn.execute("DELETE FROM metadata WHERE end_time < ?", (timestamp,))
conn.commit()
except sqlite3.Error as e:
logging.error(f"Database deletion error: {e}")
raise DatabaseError(f"Failed to delete databases: {e}")
return count
def delete_databases_in_range(self, start_time: float, end_time: float) -> int:
"""
        Delete all databases whose time ranges overlap the specified period.
Args:
start_time: Start of time range (unix timestamp)
end_time: End of time range (unix timestamp)
Returns:
Number of databases deleted
"""
count = 0
try:
with self._get_connection(self._get_metadata_db()) as conn:
                # A database is in range iff its window overlaps the request:
                # it starts before the range ends and ends after the range
                # starts. This single test also covers full containment.
                dbs = conn.execute(
                    """
                    SELECT db_file FROM metadata
                    WHERE start_time <= ? AND end_time >= ?
                    """,
                    (end_time, start_time),
                ).fetchall()
                # Delete database files
                for (db_file,) in dbs:
                    if self._delete_database_file(db_file):
                        count += 1
                # Clean up metadata entries
                conn.execute(
                    "DELETE FROM metadata WHERE start_time <= ? AND end_time >= ?",
                    (end_time, start_time),
                )
conn.commit()
except sqlite3.Error as e:
logging.error(f"Database deletion error: {e}")
raise DatabaseError(f"Failed to delete databases: {e}")
return count
def get_databases_info(self) -> List[DatabaseTimeframe]:
"""
Get information about all available databases.
Returns:
List of DatabaseTimeframe objects containing database file paths and time ranges
"""
databases = []
try:
with self._get_connection(self._get_metadata_db()) as conn:
rows = conn.execute(
"SELECT db_file, start_time, end_time FROM metadata ORDER BY start_time"
).fetchall()
for db_file, start_time, end_time in rows:
databases.append(DatabaseTimeframe(db_file, start_time, end_time))
except sqlite3.Error as e:
logging.error(f"Error retrieving database info: {e}")
raise DatabaseError(f"Failed to retrieve database info: {e}")
return databases
def execute(self, query: str, params: Tuple[Any, ...] = ()) -> uuid.UUID:
"""
Smart query executor that automatically determines if the query
is a read or write operation and routes accordingly.
Args:
query: SQL query to execute
params: Parameters for the query
Returns:
UUID that can be used to retrieve the result
"""
        # Ensure the database for the current time window is registered
self._register_current_db()
query_upper = query.strip().upper()
# Check if the query is a read operation
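        # Note: this is a prefix heuristic; e.g. a SELECT wrapped in a CTE
        # ("WITH ... SELECT ...") is routed to the write path instead.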
if (
query_upper.startswith("SELECT")
or query_upper.startswith("PRAGMA")
or query_upper.startswith("EXPLAIN")
):
return self.execute_read(query, params)
else:
return self.execute_write(query, params)
def execute_write(self, query: str, params: Tuple[Any, ...] = ()) -> uuid.UUID:
"""
Execute a write query asynchronously.
Args:
query: SQL query to execute
params: Parameters for the query
Returns:
UUID that can be used to retrieve the result
"""
        # Ensure the database for the current time window is registered
self._register_current_db()
query_id = uuid.uuid4()
with self.query_lock:
self.result_queues[query_id] = queue.Queue()
self.active_queries.add(query_id)
self.write_queue.put((query, params, query_id))
return query_id
def execute_write_sync(
self, query: str, params: Tuple[Any, ...] = (), timeout: float = 5.0
) -> QueryResult[bool]:
"""
Execute a write query synchronously.
Args:
query: SQL query to execute
params: Parameters for the query
timeout: Maximum time to wait for the result
Returns:
QueryResult containing success status and any error
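
        Example (illustrative; `events` is an assumed table):
            result = db.execute_write_sync(
                "INSERT INTO events (msg) VALUES (?)", ("hello",)
            )
            if not result:
                logging.error(f"Write failed: {result.error}")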
"""
query_id = self.execute_write(query, params)
return self.get_result(query_id, timeout)
def execute_read(self, query: str, params: Tuple[Any, ...] = ()) -> uuid.UUID:
"""
Execute a read query asynchronously across all relevant databases.
This provides transparent access to all time-windowed data.
Args:
query: SQL query to execute
params: Parameters for the query
Returns:
UUID that can be used to retrieve the result
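
        Example (illustrative; `events` is an assumed table):
            query_id = db.execute_read("SELECT * FROM events")
            # ... other work ...
            result = db.get_read_result(query_id, timeout=10.0)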
"""
        # Ensure the database for the current time window is registered
self._register_current_db()
query_id = uuid.uuid4()
with self.query_lock:
self.read_queues[query_id] = queue.Queue()
self.active_queries.add(query_id)
# Start the worker thread that will query across all databases
threading.Thread(
target=self._read_across_all_worker,
args=(query, params, query_id),
daemon=True,
).start()
return query_id
    def _read_worker(
        self, query: str, params: Tuple[Any, ...], query_id: uuid.UUID
    ) -> None:
        """Worker thread for a read against the current database only.

        Note: appears unused; execute_read() dispatches reads to
        _read_across_all_worker() instead.
        """
        db_file = self._get_current_db()
        try:
            with self._get_connection(db_file) as conn:
                results = conn.execute(query, params).fetchall()
            with QueryLockManager(self, query_id) as is_active:
                if is_active:
                    self.read_queues[query_id].put(QueryResult(data=results))
        except Exception as e:
            error_msg = f"Read error: {e}"
            logging.error(error_msg)
            with QueryLockManager(self, query_id) as is_active:
                if is_active:
                    self.read_queues[query_id].put(
                        QueryResult(error=QueryError(error_msg))
                    )
def execute_read_sync(
self, query: str, params: Tuple[Any, ...] = (), timeout: float = 5.0
) -> QueryResult[List[Tuple[Any, ...]]]:
"""
Execute a read query synchronously across all relevant databases.
Args:
query: SQL query to execute
params: Parameters for the query
timeout: Maximum time to wait for the result
Returns:
QueryResult containing combined query results and any error
"""
query_id = self.execute_read(query, params)
return self.get_read_result(query_id, timeout)
def _read_across_all_worker(
self, query: str, params: Tuple[Any, ...], query_id: uuid.UUID
) -> None:
"""Worker thread for processing read queries across all databases"""
try:
# Get all available databases from metadata
with self._get_connection(self._get_metadata_db()) as conn:
db_files = conn.execute(
"SELECT db_file FROM metadata ORDER BY end_time DESC"
).fetchall()
all_results: List[Tuple[Any, ...]] = []
for (db_file,) in db_files:
if os.path.exists(db_file):
try:
with self._get_connection(db_file) as conn:
results = conn.execute(query, params).fetchall()
all_results.extend(results)
except sqlite3.Error as e:
logging.warning(f"Error reading from {db_file}: {e}")
# Continue with other databases
# Use the context manager to safely check query status
with QueryLockManager(self, query_id) as is_active:
if is_active:
self.read_queues[query_id].put(QueryResult(data=all_results))
else:
logging.warning(
f"Query ID {query_id} no longer active when trying to return results"
)
except Exception as e:
error_msg = f"Failed to execute query across all databases: {e}"
logging.error(error_msg)
with QueryLockManager(self, query_id) as is_active:
if is_active:
self.read_queues[query_id].put(
QueryResult(error=QueryError(error_msg))
)
def get_result(
self, query_id: uuid.UUID, timeout: float = 5.0
) -> QueryResult[bool]:
"""
Get the result of a write query.
Args:
query_id: UUID returned by execute_write
timeout: Maximum time to wait for the result
Returns:
QueryResult containing success status and any error
"""
        with self.query_lock:
            if query_id not in self.result_queues:
                return QueryResult(error=QueryError("Invalid query ID"))
try:
result = self.result_queues[query_id].get(timeout=timeout)
with self.query_lock:
if query_id in self.active_queries:
self.active_queries.remove(query_id)
return result
except queue.Empty:
return QueryResult(error=QueryError("Query timed out"))
def get_read_result(
self, query_id: uuid.UUID, timeout: float = 5.0
) -> QueryResult[List[Tuple[Any, ...]]]:
"""
Get the result of a read query.
Args:
query_id: UUID returned by execute_read
timeout: Maximum time to wait for the result
Returns:
QueryResult containing query results and any error
"""
# Check if the query ID exists in read_queues
with self.query_lock:
if query_id not in self.read_queues:
return QueryResult(error=QueryError("Invalid query ID"))
if query_id not in self.active_queries:
self.active_queries.add(query_id) # Re-add if it was removed
try:
result = self.read_queues[query_id].get(timeout=timeout)
with self.query_lock:
if query_id in self.active_queries:
self.active_queries.remove(query_id)
return result
except queue.Empty:
return QueryResult(error=QueryError("Query timed out"))
def _start_worker(self) -> None:
"""Start the background worker thread for processing write operations."""
if self.worker_thread and self.worker_thread.is_alive():
return
def worker() -> None:
while not self.shutdown_flag.is_set():
try:
                    task = self.write_queue.get(timeout=1)  # Short timeout keeps the loop responsive to shutdown
if task:
self._process_write_task(task)
except queue.Empty:
continue
except Exception as e:
logging.error(f"Worker thread encountered an error: {e}")
self.worker_thread = threading.Thread(target=worker, daemon=True)
self.worker_thread.start()
def _start_cleanup_worker(self) -> None:
"""Start the cleanup worker thread for database rotation."""
threading.Thread(target=self._cleanup_worker, daemon=True).start()
    def _process_write_task(self, task: Tuple[str, Tuple[Any, ...], uuid.UUID]) -> None:
        """Process a single write task from the queue."""
        query, params, query_id = task
        db_file = self._get_current_db()
        try:
            with self._get_connection(db_file) as conn:
                conn.execute(query, params)
                conn.commit()
            result: QueryResult[bool] = QueryResult(data=True)
        except Exception as e:
            logging.error(f"Write error: {e}")
            result = QueryResult(error=QueryError(f"Write error: {e}"))
        # Look up the result queue under the lock; cleanup may have pruned it
        with self.query_lock:
            result_queue = self.result_queues.get(query_id)
        if result_queue is not None:
            result_queue.put(result)
    def _cleanup_worker(self) -> None:
        """Worker thread for handling database rotation and cleanup."""
        while not self.shutdown_flag.is_set():
            self._rotate_databases()
            self._cleanup_stale_queries()  # Also clean up stale query queues
            # wait() instead of sleep() so shutdown() is not blocked for a full interval
            self.shutdown_flag.wait(self.cleanup_interval)
    def shutdown(self) -> None:
        """Gracefully shut down the workers and close all cached connections."""
        self.shutdown_flag.set()
        if self.worker_thread:
            self.worker_thread.join()
        with self.conn_lock:
            for conn in self.connections.values():
                try:
                    conn.close()
                except sqlite3.Error:
                    pass
            self.connections.clear()
        logging.info("SlidingSQLite shutdown completed.")