"""HTTP management API for a WireGuard configuration file.

Clients send an encrypted JSON body plus a bearer token carrying the wrapped
symmetric key and an ECDSA signature.  The server decrypts the body, verifies
the signature against the configured client public key, edits the ``[Peer]``
sections of the WireGuard config file and restarts the interface through
``wg-quick``.

NOTE(review): responses are sent as plain JSON (see ``_send_response``) and
requests carry no nonce or timestamp, so this service should only be exposed
behind TLS or on a trusted network.
"""

import base64
import http.server
import json
import logging
import os
import re
import shutil
import socketserver
import subprocess
import tomllib
import urllib.parse
from datetime import datetime
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple

from cryptography.exceptions import InvalidSignature, InvalidTag
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF

# Default configuration
DEFAULT_CONFIG = {
    "server": {
        "host": "0.0.0.0",
        "port": 8000,
    },
    "wireguard": {
        "config_file": "/etc/wireguard/wg0.conf",
    },
    "logging": {
        "log_file": None,
        "debug": False,
    },
}

# Upper bound on the request body, which is read before authentication.
MAX_BODY_BYTES = 1024 * 1024

# Sizes of the AES-GCM nonce and tag inside a wrapped symmetric key.
_GCM_NONCE_LEN = 12
_GCM_TAG_LEN = 16
# HKDF context string binding derived keys to this key-wrapping purpose.
_KEY_WRAP_INFO = b"wireguard-api symmetric key wrap"

# Peer setting names are plain ASCII words (PublicKey, AllowedIPs, ...).
_PEER_KEY_PATTERN = re.compile(r"[A-Za-z]+")


def _merge_config(defaults: dict, overrides: dict) -> dict:
    """Return a new dict with *overrides* layered over *defaults*.

    Nested dicts are merged key by key, so a partially specified section
    keeps the defaults for the keys it does not mention.  Nested dicts are
    copied, so the result never aliases *defaults*.
    """
    merged = {
        key: _merge_config(value, {}) if isinstance(value, dict) else value
        for key, value in defaults.items()
    }
    for key, value in overrides.items():
        base = merged.get(key)
        if isinstance(base, dict) and isinstance(value, dict):
            merged[key] = _merge_config(base, value)
        else:
            merged[key] = value
    return merged


def load_config(config_file: str) -> dict:
    """Load the TOML file *config_file* layered over ``DEFAULT_CONFIG``.

    A missing file yields a copy of the defaults.  A file that is not valid
    TOML is logged and terminates the process with exit status 1.
    """
    try:
        with open(config_file, "rb") as f:
            overrides = tomllib.load(f)
    except FileNotFoundError:
        logging.warning(f"Config file {config_file} not found. Using default configuration.")
        return _merge_config(DEFAULT_CONFIG, {})
    except tomllib.TOMLDecodeError as e:
        logging.error(f"Error parsing config file: {e}")
        raise SystemExit(1) from e
    return _merge_config(DEFAULT_CONFIG, overrides)


def load_client_public_key(config):
    """Return the client public key from ``[client_keys] public_key``.

    Returns ``None`` (after logging an error) when the key is not configured.
    """
    client_public_key_pem = config.get('client_keys', {}).get('public_key')
    if not client_public_key_pem:
        logging.error("Client public key not found in config.toml")
        return None
    return serialization.load_pem_public_key(
        client_public_key_pem.encode(), backend=default_backend()
    )


CONFIG = load_config("config.toml")

# Set up logging.  load_config() may already have logged, which installs a
# default root handler; force=True replaces it so these settings apply.
logging.basicConfig(
    level=logging.DEBUG if CONFIG["logging"]["debug"] else logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    filename=CONFIG["logging"]["log_file"],
    force=True,
)

CLIENT_PUBLIC_KEY = load_client_public_key(CONFIG)


def read_config() -> str:
    """Return the full text of the configured WireGuard config file."""
    with open(CONFIG["wireguard"]["config_file"], 'r') as file:
        return file.read()


def write_config(content: str) -> None:
    """Overwrite the configured WireGuard config file with *content*."""
    with open(CONFIG["wireguard"]["config_file"], 'w') as file:
        file.write(content)


def create_backup():
    """Copy the WireGuard config into a sibling ``backups`` directory.

    Returns the path of the backup file that was written.
    """
    config_file = Path(CONFIG["wireguard"]["config_file"])
    backup_dir = config_file.parent / "backups"
    backup_dir.mkdir(exist_ok=True)
    # Microseconds keep two backups taken within one second from colliding.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    backup_file = backup_dir / f"{config_file.stem}_{timestamp}.conf"
    shutil.copy2(config_file, backup_file)
    logging.info(f"Backup created: {backup_file}")
    return backup_file


def restore_from_backup():
    """Replace the WireGuard config with the most recently modified backup.

    Returns ``True`` on success, ``False`` when no backup exists.
    """
    config_file = Path(CONFIG["wireguard"]["config_file"])
    backup_dir = config_file.parent / "backups"
    backups = list(backup_dir.glob("*.conf")) if backup_dir.exists() else []
    if not backups:
        logging.error("No backups found")
        return False
    latest_backup = max(backups, key=lambda f: f.stat().st_mtime)
    shutil.copy2(latest_backup, config_file)
    logging.info(f"Restored from backup: {latest_backup}")
    return True


def _is_section_header(line: str) -> bool:
    """Return True when the stripped *line* looks like ``[Section]``."""
    return line.startswith("[") and line.endswith("]")


def parse_peers(config: str) -> List[Dict[str, str]]:
    """Extract every ``[Peer]`` section of *config* as a dict of settings.

    Blank lines and full-line ``#`` comments are ignored.  Any other section
    header ends the current peer, so settings of a later ``[Interface]``
    section are not attributed to the preceding peer.  Repeated
    ``AllowedIPs`` lines are joined with ``", "``; for any other repeated
    key the last value wins.  Peers without any setting are dropped.
    """
    peers: List[Dict[str, str]] = []
    current_peer: Optional[Dict[str, str]] = None
    for raw_line in config.split('\n'):
        line = raw_line.strip()
        if not line or line.startswith('#'):
            continue
        if _is_section_header(line):
            if current_peer:
                peers.append(current_peer)
            current_peer = {} if line == "[Peer]" else None
        elif current_peer is not None and '=' in line:
            key, value = (part.strip() for part in line.split('=', 1))
            if key == "AllowedIPs" and key in current_peer:
                current_peer[key] = f"{current_peer[key]}, {value}"
            else:
                current_peer[key] = value
    if current_peer:
        peers.append(current_peer)
    return peers


def peer_to_string(peer: Dict[str, str]) -> str:
    """Render *peer* as a ``[Peer]`` section, one ``key = value`` per line."""
    return "[Peer]\n" + "\n".join(f"{k} = {v}" for k, v in peer.items())


def _render_config(config: str, peers: List[Dict[str, str]]) -> str:
    """Return *config* with its ``[Peer]`` sections replaced by *peers*.

    Every line outside a ``[Peer]`` section is kept verbatim; lines inside
    the old ``[Peer]`` sections (including comments there) are discarded and
    the given peers are appended after the remaining text.
    """
    kept_lines = []
    in_peer_section = False
    for line in config.split('\n'):
        stripped = line.strip()
        if _is_section_header(stripped):
            in_peer_section = stripped == "[Peer]"
        if not in_peer_section:
            kept_lines.append(line)
    base = "\n".join(kept_lines).rstrip()
    sections = [base] if base else []
    sections.extend(peer_to_string(peer) for peer in peers)
    return "\n\n".join(sections) + "\n"


def _validate_peer(peer) -> Dict[str, str]:
    """Check a client-supplied peer and return it with string values.

    The peer is written into a file that ``wg-quick`` interprets, so setting
    names must be plain ASCII words and values must be single-line text;
    otherwise a value could smuggle extra lines or sections into the config.

    Raises:
        ValueError: if *peer* is not a non-empty dict, a name or value is
            malformed, or no ``PublicKey`` setting is present.
    """
    if not isinstance(peer, dict) or not peer:
        raise ValueError("Peer must be a non-empty JSON object")
    cleaned: Dict[str, str] = {}
    for key, value in peer.items():
        if not isinstance(key, str) or not _PEER_KEY_PATTERN.fullmatch(key):
            raise ValueError(f"Invalid peer setting name: {key!r}")
        if isinstance(value, bool) or not isinstance(value, (str, int)):
            raise ValueError(f"Invalid value type for peer setting {key}")
        text = str(value).strip()
        if not text or any(ord(ch) < 32 or ord(ch) == 127 for ch in text):
            raise ValueError(f"Invalid value for peer setting {key}")
        cleaned[key] = text
    if not any(key.lower() == "publickey" for key in cleaned):
        raise ValueError("Peer is missing PublicKey")
    return cleaned


def _run_wg_quick(command: str, interface: str) -> None:
    """Run ``wg-quick <command> <interface>``, capturing its output as text.

    Raises ``subprocess.CalledProcessError`` (with ``stderr`` populated) on a
    non-zero exit status.
    """
    subprocess.run(
        ["wg-quick", command, interface],
        check=True,
        capture_output=True,
        text=True,
    )


def reload_wireguard_service() -> Tuple[bool, str]:
    """Restart (or start) the WireGuard interface named after the config file.

    The interface name is the stem of the configured file name.  If
    ``wg show <interface>`` succeeds the interface is taken down and brought
    up again; otherwise it is only brought up.

    Returns:
        ``(success, message)``.
    """
    interface = Path(CONFIG["wireguard"]["config_file"]).stem

    if shutil.which("wg-quick") is None:
        logging.warning("wg-quick not found. WireGuard might not be installed.")
        return False, "WireGuard (wg-quick) not found. Please ensure WireGuard is installed."

    try:
        status = subprocess.run(["wg", "show", interface], capture_output=True, text=True)
        if status.returncode == 0:
            # Interface is up, so restart it.
            _run_wg_quick("down", interface)
            _run_wg_quick("up", interface)
            logging.info(f"WireGuard service for interface {interface} restarted successfully")
        else:
            # Interface is down, so just bring it up.
            _run_wg_quick("up", interface)
            logging.info(f"WireGuard service for interface {interface} started successfully")
    except subprocess.CalledProcessError as e:
        error_message = f"Failed to reload WireGuard service: {e}"
        if e.stderr:
            error_message += f"\nError details: {e.stderr.strip()}"
        logging.error(error_message)
        return False, error_message
    except OSError as e:
        # Raised when an executable (e.g. `wg`) cannot be launched at all.
        error_message = f"Failed to reload WireGuard service: {e}"
        logging.error(error_message)
        return False, error_message
    return True, f"WireGuard service for interface {interface} reloaded successfully"


def generate_ecc_key_pair():
    """Generate a fresh P-256 key pair and return ``(private, public)``."""
    private_key = ec.generate_private_key(ec.SECP256R1(), default_backend())
    return private_key, private_key.public_key()


def decrypt_symmetric_key(encrypted_symmetric_key, private_key):
    """Unwrap the request's symmetric key using ECDH + HKDF + AES-GCM.

    Elliptic-curve keys in ``cryptography`` offer no direct ``decrypt``
    operation, so the key is wrapped with a hybrid scheme.  The expected
    layout of *encrypted_symmetric_key* is::

        ephemeral public key (uncompressed X9.62 point on the server curve)
        || 12-byte AES-GCM nonce
        || ciphertext of the symmetric key
        || 16-byte AES-GCM tag

    The wrapping key is ``HKDF-SHA256(ECDH(server private, ephemeral public),
    length=32, salt=None, info=_KEY_WRAP_INFO)``.  Clients must produce the
    same layout.

    Raises:
        ValueError: if the blob is too short, the point is not on the curve,
            or the AES-GCM tag does not verify.
    """
    coordinate_len = (private_key.curve.key_size + 7) // 8
    point_len = 1 + 2 * coordinate_len
    header_len = point_len + _GCM_NONCE_LEN
    if len(encrypted_symmetric_key) < header_len + _GCM_TAG_LEN:
        raise ValueError("Encrypted symmetric key is too short")

    point = encrypted_symmetric_key[:point_len]
    nonce = encrypted_symmetric_key[point_len:header_len]
    ciphertext = encrypted_symmetric_key[header_len:-_GCM_TAG_LEN]
    tag = encrypted_symmetric_key[-_GCM_TAG_LEN:]

    ephemeral_public_key = ec.EllipticCurvePublicKey.from_encoded_point(
        private_key.curve, point
    )
    shared_secret = private_key.exchange(ec.ECDH(), ephemeral_public_key)
    wrapping_key = HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=None,
        info=_KEY_WRAP_INFO,
        backend=default_backend(),
    ).derive(shared_secret)

    decryptor = Cipher(
        algorithms.AES(wrapping_key),
        modes.GCM(nonce, tag),
        backend=default_backend(),
    ).decryptor()
    try:
        return decryptor.update(ciphertext) + decryptor.finalize()
    except InvalidTag:
        raise ValueError("Symmetric key authentication failed") from None


def decrypt_data(encrypted_data, symmetric_key):
    """Decrypt an AES-CFB body laid out as ``16-byte IV || ciphertext``.

    Returns the plaintext decoded as UTF-8.  CFB provides no integrity by
    itself; the caller verifies the client's signature over the plaintext.
    """
    iv = encrypted_data[:16]
    cipher = Cipher(algorithms.AES(symmetric_key), modes.CFB(iv), backend=default_backend())
    decryptor = cipher.decryptor()
    decrypted_data = decryptor.update(encrypted_data[16:]) + decryptor.finalize()
    return decrypted_data.decode('utf-8')


def _b64decode(text) -> bytes:
    """Decode standard or URL-safe base64, with or without ``=`` padding."""
    if isinstance(text, str):
        text = text.encode('ascii')
    return base64.urlsafe_b64decode(text + b'=' * (-len(text) % 4))


# Generate server keys on startup
SERVER_PRIVATE_KEY, SERVER_PUBLIC_KEY = generate_ecc_key_pair()


class SecureWireGuardHandler(http.server.BaseHTTPRequestHandler):
    """Request handler exposing peer management over encrypted JSON requests.

    Only the verbs defined here are served; the handler deliberately does not
    derive from ``SimpleHTTPRequestHandler``, whose inherited ``do_HEAD``
    would answer with metadata of files in the working directory.
    """

    def _send_response(self, status_code: int, data) -> None:
        """Send *data* (a dict or list) as a JSON response with *status_code*."""
        body = json.dumps(data).encode()
        self.send_response(status_code)
        self.send_header('Content-type', 'application/json')
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _decrypt_and_verify_request(self):
        """Decrypt the request body, verify its signature, return the JSON dict.

        The ``Authorization: Bearer <token>`` header carries a dot-separated
        token whose second segment is base64 JSON with the fields
        ``enc_sym_key`` and ``signature`` (both base64).

        Raises:
            ValueError: for any malformed, undecryptable or wrongly signed
                request (reported to the client as 400).
            RuntimeError: if no client public key is configured, so requests
                cannot be authenticated (reported as 500).
        """
        if CLIENT_PUBLIC_KEY is None:
            raise RuntimeError("Client public key is not configured")

        try:
            content_length = int(self.headers.get('Content-Length', ''))
        except ValueError:
            raise ValueError("Missing or invalid Content-Length header") from None
        if not 0 <= content_length <= MAX_BODY_BYTES:
            raise ValueError("Request body size is out of range")
        encrypted_data = self.rfile.read(content_length)

        # Extract the token from the Authorization header
        auth_header = self.headers.get('Authorization', '')
        if not auth_header.startswith('Bearer '):
            raise ValueError("Invalid Authorization header")
        segments = auth_header[len('Bearer '):].strip().split('.')
        if len(segments) < 2:
            raise ValueError("Malformed bearer token")
        payload = json.loads(_b64decode(segments[1]))
        if not isinstance(payload, dict):
            raise ValueError("Token payload must be a JSON object")
        try:
            encrypted_symmetric_key = _b64decode(payload['enc_sym_key'])
            signature = _b64decode(payload['signature'])
        except (KeyError, TypeError):
            raise ValueError("Token payload is missing required fields") from None

        symmetric_key = decrypt_symmetric_key(encrypted_symmetric_key, SERVER_PRIVATE_KEY)
        decrypted_data = decrypt_data(encrypted_data, symmetric_key)

        # Verify the client's signature over the decrypted plaintext
        try:
            CLIENT_PUBLIC_KEY.verify(
                signature,
                decrypted_data.encode(),
                ec.ECDSA(hashes.SHA256()),
            )
        except InvalidSignature:
            raise ValueError("Invalid client signature") from None

        request = json.loads(decrypted_data)
        if not isinstance(request, dict):
            raise ValueError("Request body must be a JSON object")
        return request

    def _dispatch(self, handlers: Dict[str, Callable[[dict], None]]) -> None:
        """Authenticate the request and run the handler for its ``action``.

        Unknown actions and ``ValueError`` produce a 400 response; any other
        exception is logged with its traceback and produces a 500 response.
        """
        try:
            request = self._decrypt_and_verify_request()
            action = request.get('action')
            handler = handlers.get(action) if isinstance(action, str) else None
            if handler is None:
                self._send_response(400, {"error": "Invalid action"})
                return
            handler(request)
        except ValueError as e:
            logging.error(f"Error processing request: {str(e)}")
            self._send_response(400, {"error": "Invalid request"})
        except Exception as e:
            logging.exception(f"Unexpected error: {str(e)}")
            self._send_response(500, {"error": "Internal server error"})

    @staticmethod
    def _requested_public_key(request: dict) -> str:
        """Return the request's ``public_key`` or raise ``ValueError``."""
        public_key = request.get('public_key')
        if not isinstance(public_key, str) or not public_key:
            raise ValueError("Request is missing public_key")
        return public_key

    def _add_peer(self, request: dict) -> None:
        """Append the request's ``peer`` to the config and reload."""
        new_peer = _validate_peer(request.get('peer'))
        config = read_config()
        create_backup()
        write_config(config.rstrip("\n") + "\n\n" + peer_to_string(new_peer) + "\n")
        success, _ = reload_wireguard_service()
        if success:
            self._send_response(201, {"message": "Peer added successfully and service reloaded"})
        else:
            self._send_response(500, {"error": "Peer added but failed to reload service"})

    def _restore(self, request: dict) -> None:
        """Restore the latest backup and reload."""
        if not restore_from_backup():
            self._send_response(500, {"error": "Failed to restore from backup"})
            return
        # reload_wireguard_service() returns a tuple; test its first element.
        success, _ = reload_wireguard_service()
        if success:
            self._send_response(200, {"message": "Configuration restored from backup and service reloaded"})
        else:
            self._send_response(500, {"error": "Configuration restored but failed to reload service"})

    def _update_peer(self, request: dict) -> None:
        """Replace the peer whose ``PublicKey`` matches ``public_key``."""
        public_key = self._requested_public_key(request)
        updated_peer = _validate_peer(request.get('peer'))
        config = read_config()
        peers = parse_peers(config)
        index = next(
            (i for i, peer in enumerate(peers) if peer.get('PublicKey') == public_key),
            None,
        )
        if index is None:
            self._send_response(404, {"error": "Peer not found"})
            return
        create_backup()
        peers[index] = updated_peer
        write_config(_render_config(config, peers))
        success, message = reload_wireguard_service()
        if success:
            self._send_response(200, {"message": "Peer updated successfully and service reloaded"})
        else:
            self._send_response(500, {"error": f"Peer updated but failed to reload service: {message}"})

    def _delete_peer(self, request: dict) -> None:
        """Remove the first peer whose ``PublicKey`` matches ``public_key``."""
        public_key = self._requested_public_key(request)
        config = read_config()
        peers = parse_peers(config)
        index = next(
            (i for i, peer in enumerate(peers) if peer.get('PublicKey') == public_key),
            None,
        )
        if index is None:
            self._send_response(404, {"error": "Peer not found"})
            return
        create_backup()
        del peers[index]
        write_config(_render_config(config, peers))
        success, message = reload_wireguard_service()
        if success:
            self._send_response(200, {"message": "Peer deleted successfully and service reloaded"})
        else:
            self._send_response(500, {"error": f"Peer deleted but failed to reload service: {message}"})

    def _get_peers(self, request: dict) -> None:
        """Respond with the list of peers parsed from the config file."""
        self._send_response(200, parse_peers(read_config()))

    def do_POST(self):
        """Handle the ``add_peer`` and ``restore`` actions."""
        self._dispatch({'add_peer': self._add_peer, 'restore': self._restore})

    def do_PUT(self):
        """Handle the ``update_peer`` action."""
        self._dispatch({'update_peer': self._update_peer})

    def do_DELETE(self):
        """Handle the ``delete_peer`` action."""
        self._dispatch({'delete_peer': self._delete_peer})

    def do_GET(self):
        """Serve the server public key, or handle the ``get_peers`` action.

        ``/public_key`` is unauthenticated; every other path requires an
        encrypted, signed request.
        """
        if urllib.parse.urlsplit(self.path).path == '/public_key':
            try:
                public_key_pem = SERVER_PUBLIC_KEY.public_bytes(
                    encoding=serialization.Encoding.PEM,
                    format=serialization.PublicFormat.SubjectPublicKeyInfo,
                )
                self._send_response(200, {"public_key": public_key_pem.decode()})
            except Exception as e:
                logging.exception(f"Error sending public key: {str(e)}")
                self._send_response(500, {"error": "Internal server error"})
            return
        self._dispatch({'get_peers': self._get_peers})


class _ReusableTCPServer(socketserver.TCPServer):
    """TCP server that can rebind its port immediately after a restart."""

    allow_reuse_address = True


def run_server() -> None:
    """Serve ``SecureWireGuardHandler`` on the configured host and port."""
    host = CONFIG["server"]["host"]
    port = CONFIG["server"]["port"]
    with _ReusableTCPServer((host, port), SecureWireGuardHandler) as httpd:
        logging.info(f"Serving on {host}:{port}")
        httpd.serve_forever()


if __name__ == "__main__":
    run_server()