# wireguard_peer_manager/wpm.py
#
# Listing metadata: 374 lines, 15 KiB, Python.

import http.server
import socketserver
import json
import re
import urllib.parse
import tomllib
import logging
import subprocess
import shutil
from datetime import datetime
from typing import Dict, List, Tuple
from pathlib import Path
import os
import base64
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.exceptions import InvalidSignature
# Built-in defaults, one dict per config.toml section; load_config() merges
# the user's config.toml over these.
DEFAULT_CONFIG = {
    "server": {"host": "0.0.0.0", "port": 8000},
    "wireguard": {"config_file": "/etc/wireguard/wg0.conf"},
    "logging": {"log_file": None, "debug": False},
}
def load_config(config_file: str) -> dict:
try:
with open(config_file, "rb") as f:
config = tomllib.load(f)
# Merge with default config
return {**DEFAULT_CONFIG, **config}
except FileNotFoundError:
logging.warning(f"Config file {config_file} not found. Using default configuration.")
return DEFAULT_CONFIG
except tomllib.TOMLDecodeError as e:
logging.error(f"Error parsing config file: {e}")
exit(1)
def load_client_public_key(config):
    """Return the client's PEM public key from ``config['client_keys']['public_key']``.

    This key is what request signatures are verified against. When the entry
    is missing or empty, an error is logged and ``None`` is returned.
    """
    pem_text = config.get('client_keys', {}).get('public_key')
    if pem_text:
        return serialization.load_pem_public_key(pem_text.encode(), backend=default_backend())
    logging.error("Client public key not found in config.toml")
    return None
# config.toml is looked up relative to the current working directory.
CONFIG = load_config("config.toml")

# Set up logging before the client key is loaded, so that load error lands in
# the configured log. force=True is required: load_config() may already have
# logged a warning through the root logger (e.g. config.toml missing). That
# makes the logging module auto-install a default stderr handler, after which
# a plain basicConfig() call is silently ignored.
logging.basicConfig(
    level=logging.DEBUG if CONFIG.get("logging", {}).get("debug") else logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    filename=CONFIG.get("logging", {}).get("log_file"),
    force=True,
)

CLIENT_PUBLIC_KEY = load_client_public_key(CONFIG)
def read_config() -> str:
    """Return the entire text of the WireGuard config file named in ``CONFIG``."""
    config_path = CONFIG["wireguard"]["config_file"]
    with open(config_path, 'r') as handle:
        contents = handle.read()
    return contents
def write_config(content: str) -> None:
    """Atomically replace the WireGuard config file with ``content``.

    The text goes to a temporary file in the same directory, which is fsynced
    and then renamed over the target with ``os.replace``. A crash or a full
    disk therefore never leaves a truncated config (the file holds the
    interface's private key). The existing file's permission bits are copied
    to the new file. If there is no existing file, the new one keeps
    ``mkstemp``'s owner-only mode. If the configured path is a symlink, the
    link itself is replaced (``os.replace`` semantics).
    """
    import tempfile

    target = Path(CONFIG["wireguard"]["config_file"])
    fd, tmp_path = tempfile.mkstemp(dir=target.parent, prefix=f".{target.name}.", suffix=".tmp")
    try:
        with os.fdopen(fd, 'w') as tmp_file:
            tmp_file.write(content)
            tmp_file.flush()
            os.fsync(tmp_file.fileno())
        if target.exists():
            shutil.copymode(target, tmp_path)
        os.replace(tmp_path, target)
    except BaseException:
        # Best-effort cleanup of the temp file; the original error is re-raised.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
def create_backup():
    """Copy the WireGuard config into ``<config dir>/backups`` and return the copy's path.

    Backups are named ``<stem>_<YYYYmmdd_HHMMSS_ffffff>.conf``, so a name sort
    is chronological. The original per-second timestamp let two backups taken
    in the same second overwrite each other. Microseconds plus a collision
    counter guarantee that an existing backup is never replaced. The backup
    directory is created owner-only (0700) because the copies include the
    interface's private key.
    """
    config_file = Path(CONFIG["wireguard"]["config_file"])
    backup_dir = config_file.parent / "backups"
    backup_dir.mkdir(mode=0o700, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    backup_file = backup_dir / f"{config_file.stem}_{timestamp}.conf"
    counter = 1
    while backup_file.exists():
        backup_file = backup_dir / f"{config_file.stem}_{timestamp}_{counter}.conf"
        counter += 1

    shutil.copy2(config_file, backup_file)
    logging.info(f"Backup created: {backup_file}")
    return backup_file
def restore_from_backup():
    """Overwrite the WireGuard config with its most recent backup.

    Only backups of *this* config (``<stem>_<digits>....conf``) are
    considered. Previously any ``*.conf`` matched, so with several interfaces
    sharing a directory, wg1's backup could be restored into wg0.conf. The
    newest backup is chosen by file name, whose timestamp records when the
    backup was taken. The mtime is not used because ``shutil.copy2`` sets it
    to the config's mtime at backup time.

    Returns:
        True if a backup was restored, False if none exists.
    """
    config_file = Path(CONFIG["wireguard"]["config_file"])
    backup_dir = config_file.parent / "backups"
    # glob on a missing directory simply yields nothing.
    latest_backup = max(
        backup_dir.glob(f"{config_file.stem}_[0-9]*.conf"),
        key=lambda f: f.name,
        default=None,
    )
    if latest_backup is None:
        logging.error("No backups found")
        return False
    shutil.copy2(latest_backup, config_file)
    logging.info(f"Restored from backup: {latest_backup}")
    return True
def parse_peers(config: str) -> List[Dict[str, str]]:
    """Extract every ``[Peer]`` section of a WireGuard config as a dict.

    Each ``Key = Value`` line inside a ``[Peer]`` section becomes one dict
    entry. Keys and values are whitespace-stripped and split on the first
    ``=``. A repeated key keeps its last value. Empty ``[Peer]`` sections are
    omitted.

    Any section header closes the current peer. Before this change, keys from
    a section that followed a peer (e.g. a trailing ``[Interface]`` with its
    ``PrivateKey``) were merged into that peer.
    """
    peers: List[Dict[str, str]] = []
    current_peer = None
    for raw_line in config.split('\n'):
        line = raw_line.strip()
        if line.startswith('[') and line.endswith(']'):
            if current_peer:
                peers.append(current_peer)
            # Start collecting only for [Peer]; ignore lines of other sections.
            current_peer = {} if line == "[Peer]" else None
        elif current_peer is not None and '=' in line:
            key, value = line.split('=', 1)
            current_peer[key.strip()] = value.strip()
    if current_peer:
        peers.append(current_peer)
    return peers
def peer_to_string(peer: Dict[str, str]) -> str:
    """Render ``peer`` as a WireGuard ``[Peer]`` section (no trailing newline).

    Entries are emitted as ``Key = Value`` lines in dict order.

    Raises:
        ValueError: If ``peer`` is not a dict, if any key or value contains a
            line break, or if a key contains ``=``. Peer definitions come from
            API clients. An embedded newline would let a client add arbitrary
            lines or whole sections (e.g. a second ``[Interface]`` with hook
            commands) to the config. A key containing ``=`` would not survive
            a round trip through ``parse_peers``.
    """
    if not isinstance(peer, dict):
        raise ValueError("Peer must be a mapping of WireGuard keys to values")
    entries = []
    for key, value in peer.items():
        key_text, value_text = str(key), str(value)
        if any(ch in key_text or ch in value_text for ch in "\r\n") or "=" in key_text:
            raise ValueError(f"Illegal character in peer entry {key_text!r}")
        entries.append(f"{key_text} = {value_text}")
    return "[Peer]\n" + "\n".join(entries)
def reload_wireguard_service():
    """Apply the on-disk WireGuard config by (re)starting its interface.

    The interface name is the config file's stem (``wg0.conf`` -> ``wg0``).
    If ``wg show <interface>`` succeeds, the interface is taken down and
    brought back up; otherwise it is only brought up.

    Returns:
        ``(success, message)``. The tuple is always truthy, so callers must
        unpack it rather than test it directly.
    """
    # wg-quick expands a bare name to /etc/wireguard/<name>.conf. Passing the
    # absolute path makes it use the configured file wherever that lives.
    config_path = os.path.abspath(CONFIG["wireguard"]["config_file"])
    interface = Path(config_path).stem

    # shutil.which needs no external `which` binary; a missing one used to
    # escape as an uncaught FileNotFoundError.
    if shutil.which("wg-quick") is None:
        logging.warning("wg-quick not found. WireGuard might not be installed.")
        return False, "WireGuard (wg-quick) not found. Please ensure WireGuard is installed."

    def wg_quick(action):
        # Capture output so CalledProcessError.stderr is populated for the
        # error message below (without capture it was always None).
        subprocess.run(["wg-quick", action, config_path], check=True, capture_output=True, text=True)

    try:
        result = subprocess.run(["wg", "show", interface], capture_output=True, text=True)
        if result.returncode == 0:
            # Interface is up: restart it so the new config takes effect.
            wg_quick("down")
            wg_quick("up")
            logging.info(f"WireGuard service for interface {interface} restarted successfully")
        else:
            # Interface is down: just bring it up.
            wg_quick("up")
            logging.info(f"WireGuard service for interface {interface} started successfully")
        return True, f"WireGuard service for interface {interface} reloaded successfully"
    except subprocess.CalledProcessError as e:
        error_message = f"Failed to reload WireGuard service: {e}"
        if e.stderr:
            error_message += f"\nError details: {e.stderr.strip()}"
        logging.error(error_message)
        return False, error_message
    except OSError as e:
        # e.g. the `wg` binary is missing even though wg-quick exists.
        error_message = f"Failed to reload WireGuard service: {e}"
        logging.error(error_message)
        return False, error_message
# --- Request encryption helpers ---------------------------------------------
def generate_ecc_key_pair():
    """Create a fresh NIST P-256 (SECP256R1) key pair.

    Returns:
        ``(private_key, public_key)``.
    """
    key = ec.generate_private_key(ec.SECP256R1(), default_backend())
    return key, key.public_key()
def decrypt_symmetric_key(encrypted_symmetric_key, private_key):
    """Unwrap a symmetric key that a client encrypted to the server's EC public key.

    NOTE(review): the previous body called ``private_key.decrypt(...,
    ec.ECIES(...))``. ``cryptography`` provides neither ``ec.ECIES`` nor a
    ``decrypt`` method on EC private keys, so every request failed. This
    implements ECIES explicitly, and clients must produce exactly this layout
    (confirm with the client implementation):

        ephemeral public key  (X9.62 uncompressed point; 65 bytes on P-256)
        || nonce              (12 bytes)
        || ciphertext         (the wrapped symmetric key)
        || GCM tag            (16 bytes)

    The wrapping key is 32 bytes (AES-256-GCM), derived as
    ``HKDF-SHA256(ECDH(private_key, ephemeral), salt=None, info=b"wpm ecies v1")``.

    Args:
        encrypted_symmetric_key: The wrapped key bytes in the layout above.
        private_key: The server's EC private key. Its curve determines the
            point length.

    Returns:
        The unwrapped symmetric key bytes.

    Raises:
        ValueError: If the input is too short, the point is invalid, or
            authentication fails.
    """
    from cryptography.exceptions import InvalidTag

    curve = private_key.curve
    point_len = 1 + 2 * ((curve.key_size + 7) // 8)
    nonce_len, tag_len = 12, 16
    if len(encrypted_symmetric_key) < point_len + nonce_len + tag_len:
        raise ValueError("Encrypted symmetric key is too short")

    ephemeral_point = encrypted_symmetric_key[:point_len]
    nonce = encrypted_symmetric_key[point_len:point_len + nonce_len]
    ciphertext = encrypted_symmetric_key[point_len + nonce_len:-tag_len]
    tag = encrypted_symmetric_key[-tag_len:]

    # from_encoded_point raises ValueError for points not on the curve.
    ephemeral_public_key = ec.EllipticCurvePublicKey.from_encoded_point(curve, ephemeral_point)
    shared_secret = private_key.exchange(ec.ECDH(), ephemeral_public_key)
    wrapping_key = HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=None,
        info=b"wpm ecies v1",
        backend=default_backend(),
    ).derive(shared_secret)

    decryptor = Cipher(algorithms.AES(wrapping_key), modes.GCM(nonce, tag), backend=default_backend()).decryptor()
    try:
        return decryptor.update(ciphertext) + decryptor.finalize()
    except InvalidTag:
        raise ValueError("Symmetric key authentication failed") from None
def decrypt_data(encrypted_data, symmetric_key):
    """Decrypt ``IV (16 bytes) || AES-CFB ciphertext`` and return it as UTF-8 text.

    CFB mode provides no integrity protection by itself, so the plaintext must
    be authenticated separately.
    """
    iv, ciphertext = encrypted_data[:16], encrypted_data[16:]
    decryptor = Cipher(algorithms.AES(symmetric_key), modes.CFB(iv), backend=default_backend()).decryptor()
    plaintext = decryptor.update(ciphertext)
    plaintext += decryptor.finalize()
    return plaintext.decode('utf-8')
# Generate server keys on startup.
# The key pair exists only in memory and is regenerated every time the module
# is imported (i.e. on every process start), so clients must re-fetch it from
# GET /public_key after a restart.
# NOTE(review): GET /public_key is answered without any authentication, so
# nothing here lets a client verify it received the genuine server key —
# confirm the client pins or otherwise authenticates it.
SERVER_PRIVATE_KEY, SERVER_PUBLIC_KEY = generate_ecc_key_pair()
class SecureWireGuardHandler(http.server.SimpleHTTPRequestHandler):
    """HTTP handler for encrypted, client-signed WireGuard peer management.

    Every request except ``GET /public_key`` must carry:

    * an ``Authorization: Bearer <token>`` header whose second dot-separated
      segment is base64 JSON containing ``enc_sym_key`` (the symmetric key,
      wrapped for this server's EC key) and ``signature`` (the client's
      ECDSA-SHA256 signature over the decrypted body), and
    * a body of ``IV || AES-CFB ciphertext`` that decrypts to a JSON object
      with an ``action`` field.

    Responses are plain, unencrypted JSON. Mutating actions back up the config
    before writing and then reload the interface. File serving inherited from
    ``SimpleHTTPRequestHandler`` is disabled: GET is overridden and HEAD is
    rejected.
    """

    # Upper bound on an accepted request body, to bound per-request memory.
    MAX_BODY_BYTES = 1 << 20
    # Socket timeout (seconds) applied by StreamRequestHandler.setup(), so an
    # idle or stalled client cannot hold the connection open indefinitely.
    timeout = 30

    # ---- response / request plumbing ---------------------------------------

    def _send_response(self, status_code: int, data: object) -> None:
        """Send ``data`` serialized as JSON with the given status code."""
        body = json.dumps(data).encode()
        self.send_response(status_code)
        self.send_header('Content-type', 'application/json')
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    @staticmethod
    def _b64decode(data) -> bytes:
        """Decode standard or URL-safe base64, tolerating missing ``=`` padding.

        JWT/JWE segments are unpadded base64url, which plain
        ``base64.b64decode`` mis-decodes because it silently discards ``-``
        and ``_``. Padded standard-alphabet input decodes exactly as before.
        """
        if not isinstance(data, str):
            raise ValueError("Expected a base64 string")
        return base64.urlsafe_b64decode(data + "=" * (-len(data) % 4))

    def _read_body(self) -> bytes:
        """Read exactly ``Content-Length`` bytes of request body.

        Raises:
            ValueError: If the header is missing, non-numeric, negative, or
                larger than ``MAX_BODY_BYTES``.
        """
        raw_length = self.headers.get('Content-Length')
        if raw_length is None:
            raise ValueError("Missing Content-Length header")
        try:
            length = int(raw_length)
        except ValueError:
            raise ValueError("Invalid Content-Length header") from None
        # A negative length would turn into rfile.read(-1), which blocks until
        # the client closes the connection.
        if not 0 <= length <= self.MAX_BODY_BYTES:
            raise ValueError(f"Unacceptable Content-Length: {length}")
        return self.rfile.read(length)

    def _decrypt_and_verify_request(self) -> dict:
        """Decrypt the request body, verify the client signature, and return the JSON object.

        Raises:
            ValueError: For any client-side problem (malformed header or
                token, bad signature, body that is not a JSON object).
            RuntimeError: If no client public key is configured.
        """
        encrypted_data = self._read_body()

        auth_header = self.headers.get('Authorization', '')
        if not auth_header.startswith('Bearer '):
            raise ValueError("Invalid Authorization header")
        if CLIENT_PUBLIC_KEY is None:
            raise RuntimeError("Client public key is not configured")

        try:
            jwe_token = auth_header.split(' ')[1]
            payload = json.loads(self._b64decode(jwe_token.split('.')[1]))
            encrypted_symmetric_key = self._b64decode(payload['enc_sym_key'])
            signature = self._b64decode(payload['signature'])
        except (IndexError, KeyError, TypeError) as e:
            # Structural problems in client input are client errors (400), not 500s.
            raise ValueError(f"Malformed token: {e!r}") from e

        symmetric_key = decrypt_symmetric_key(encrypted_symmetric_key, SERVER_PRIVATE_KEY)
        decrypted_data = decrypt_data(encrypted_data, symmetric_key)

        try:
            CLIENT_PUBLIC_KEY.verify(signature, decrypted_data.encode(), ec.ECDSA(hashes.SHA256()))
        except InvalidSignature:
            raise ValueError("Invalid client signature") from None

        request = json.loads(decrypted_data)
        if not isinstance(request, dict):
            raise ValueError("Request body must be a JSON object")
        return request

    def _dispatch(self, actions) -> None:
        """Authenticate the request and run the handler registered for its ``action``.

        ``ValueError`` maps to 400 and any other exception to 500. The details
        are logged but not sent to the client.
        """
        try:
            request = self._decrypt_and_verify_request()
            action = request.get('action')
            handler = actions.get(action) if isinstance(action, str) else None
            if handler is None:
                self._send_response(400, {"error": "Invalid action"})
            else:
                handler(request)
        except ValueError as e:
            logging.error(f"Error processing request: {str(e)}")
            self._send_response(400, {"error": "Invalid request"})
        except Exception as e:
            logging.exception(f"Unexpected error: {str(e)}")
            self._send_response(500, {"error": "Internal server error"})

    # ---- validation and config helpers -------------------------------------

    @staticmethod
    def _require_peer(peer) -> dict:
        """Return ``peer`` if it is a dict with a non-empty string ``PublicKey``; else raise ValueError."""
        if not isinstance(peer, dict):
            raise ValueError("'peer' must be a JSON object")
        public_key = peer.get('PublicKey')
        if not isinstance(public_key, str) or not public_key.strip():
            raise ValueError("'peer' must include a non-empty PublicKey")
        return peer

    @staticmethod
    def _require_public_key(request) -> str:
        """Return ``request['public_key']`` if it is a non-empty string; else raise ValueError."""
        public_key = request.get('public_key')
        if not isinstance(public_key, str) or not public_key:
            raise ValueError("'public_key' must be a non-empty string")
        return public_key

    @staticmethod
    def _rebuild_config(config: str, peers) -> str:
        """Return ``config`` with all of its ``[Peer]`` sections replaced by ``peers``.

        Everything before the first ``[Peer]`` header line (the
        ``[Interface]`` section and any leading comments) is kept verbatim,
        except for trailing whitespace. This replaces an earlier regex that
        required a blank line right after ``[Interface]``. Without that blank
        line the regex silently changed nothing, and a blank line *inside*
        ``[Interface]`` caused the rest of that section to be discarded.
        """
        lines = config.splitlines(keepends=True)
        first_peer = next(
            (i for i, line in enumerate(lines) if line.strip() == "[Peer]"),
            len(lines),
        )
        head = "".join(lines[:first_peer]).rstrip()
        sections = ([head] if head else []) + [peer_to_string(p) for p in peers]
        return "\n\n".join(sections) + "\n"

    def _reload_and_respond(self, ok_status: int, ok_message: str, failure_prefix: str) -> None:
        """Reload WireGuard and report success, or report the reload error with status 500."""
        success, message = reload_wireguard_service()
        if success:
            self._send_response(ok_status, {"message": ok_message})
        else:
            self._send_response(500, {"error": f"{failure_prefix}: {message}"})

    # ---- actions -------------------------------------------------------------

    def _add_peer(self, request) -> None:
        """Append a new peer. Responds 409 if its PublicKey already exists."""
        new_peer = self._require_peer(request.get('peer'))
        config = read_config()
        new_key = new_peer['PublicKey'].strip()
        if any(p.get('PublicKey') == new_key for p in parse_peers(config)):
            self._send_response(409, {"error": "Peer already exists"})
            return
        # Render (and thereby validate) before taking a backup or writing anything.
        peer_block = peer_to_string(new_peer)
        create_backup()
        write_config(config.rstrip() + "\n\n" + peer_block + "\n")
        self._reload_and_respond(
            201,
            "Peer added successfully and service reloaded",
            "Peer added but failed to reload service",
        )

    def _restore(self, request) -> None:
        """Restore the latest backup and reload the interface."""
        if not restore_from_backup():
            self._send_response(500, {"error": "Failed to restore from backup"})
            return
        # reload_wireguard_service() returns a (bool, str) tuple, which is always
        # truthy. It used to be tested directly, so failures reported success.
        self._reload_and_respond(
            200,
            "Configuration restored from backup and service reloaded",
            "Configuration restored but failed to reload service",
        )

    def _update_peer(self, request) -> None:
        """Replace the first peer whose PublicKey equals ``request['public_key']``."""
        public_key = self._require_public_key(request)
        updated_peer = self._require_peer(request.get('peer'))
        config = read_config()
        peers = parse_peers(config)
        index = next((i for i, p in enumerate(peers) if p.get('PublicKey') == public_key), None)
        if index is None:
            self._send_response(404, {"error": "Peer not found"})
            return
        peers[index] = updated_peer
        new_config = self._rebuild_config(config, peers)
        create_backup()
        write_config(new_config)
        self._reload_and_respond(
            200,
            "Peer updated successfully and service reloaded",
            "Peer updated but failed to reload service",
        )

    def _delete_peer(self, request) -> None:
        """Remove the first peer whose PublicKey equals ``request['public_key']``."""
        public_key = self._require_public_key(request)
        config = read_config()
        peers = parse_peers(config)
        index = next((i for i, p in enumerate(peers) if p.get('PublicKey') == public_key), None)
        if index is None:
            self._send_response(404, {"error": "Peer not found"})
            return
        del peers[index]
        new_config = self._rebuild_config(config, peers)
        create_backup()
        write_config(new_config)
        self._reload_and_respond(
            200,
            "Peer deleted successfully and service reloaded",
            "Peer deleted but failed to reload service",
        )

    def _get_peers(self, request) -> None:
        """Respond with the parsed peer list."""
        # NOTE(review): responses are not encrypted, so any PresharedKey values
        # in the config reach the client in clear text. Consider redacting them.
        self._send_response(200, parse_peers(read_config()))

    # ---- HTTP verbs ----------------------------------------------------------

    def do_POST(self):
        self._dispatch({'add_peer': self._add_peer, 'restore': self._restore})

    def do_PUT(self):
        self._dispatch({'update_peer': self._update_peer})

    def do_DELETE(self):
        self._dispatch({'delete_peer': self._delete_peer})

    def do_GET(self):
        if self.path != '/public_key':
            self._dispatch({'get_peers': self._get_peers})
            return
        try:
            public_key_pem = SERVER_PUBLIC_KEY.public_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PublicFormat.SubjectPublicKeyInfo,
            )
            self._send_response(200, {"public_key": public_key_pem.decode()})
        except Exception as e:
            logging.error(f"Error sending public key: {str(e)}")
            self._send_response(500, {"error": "Internal server error"})

    def do_HEAD(self):
        """Reject HEAD requests.

        The inherited ``SimpleHTTPRequestHandler.do_HEAD`` would otherwise
        answer with the headers of files in the server's working directory.
        """
        self.send_response(405)
        self.send_header('Allow', 'GET, POST, PUT, DELETE')
        self.end_headers()
def run_server() -> None:
    """Serve ``SecureWireGuardHandler`` on the configured host and port until interrupted.

    Requests are handled one at a time by a plain (non-threading) TCPServer.
    This serializes the read-modify-write cycles on the WireGuard config, so
    keep it single-threaded unless locking is added. Missing ``[server]`` keys
    fall back to ``DEFAULT_CONFIG``.
    """
    server_config = {**DEFAULT_CONFIG["server"], **CONFIG.get("server", {})}
    host = server_config["host"]
    port = server_config["port"]

    class _ReusableTCPServer(socketserver.TCPServer):
        # Allow immediate rebinding after a restart instead of failing with
        # "Address already in use" while old connections sit in TIME_WAIT.
        allow_reuse_address = True

    with _ReusableTCPServer((host, port), SecureWireGuardHandler) as httpd:
        logging.info(f"Serving on {host}:{port}")
        try:
            httpd.serve_forever()
        except KeyboardInterrupt:
            logging.info("Shutting down")


if __name__ == "__main__":
    run_server()