From 6db5290ccaa2492f01804a4f40312bc28cd366bf Mon Sep 17 00:00:00 2001 From: kalzu rekku Date: Mon, 21 Oct 2024 23:28:53 +0300 Subject: [PATCH] Added auth and encryption to the api. Still some bugs, endpoints seem to give http/500 error codes. --- Pipfile | 13 ++ config.toml | 7 ++ wpm.py | 322 ++++++++++++++++++++++++++++++++++---------------- wpm_client.py | 188 +++++++++++++++++++++++------ 4 files changed, 389 insertions(+), 141 deletions(-) create mode 100644 Pipfile diff --git a/Pipfile b/Pipfile new file mode 100644 index 0000000..3ca1cf1 --- /dev/null +++ b/Pipfile @@ -0,0 +1,13 @@ +[[source]] +url = "https://pypi.org/simple" +verify_ssl = true +name = "pypi" + +[packages] +requests = "*" +cryptography = "*" + +[dev-packages] + +[requires] +python_version = "3.12" diff --git a/config.toml b/config.toml index 5270910..acfa898 100644 --- a/config.toml +++ b/config.toml @@ -9,3 +9,10 @@ config_file = "../wireguard_example_server_config.conf" log_file = "../wpm.log" # Optional: Log file location debug = true # Enable debug logging +[client_keys] +public_key = """ +-----BEGIN PUBLIC KEY----- +MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE0t+2/KJ7+YOHu4DmR20YtisenovD +tvJKywNdAWX5uqnA7UsYWPVKN827afMkgZuGKgZ5wtVM4DvQCq8MyRDHgw== +-----END PUBLIC KEY----- +""" # example public key diff --git a/wpm.py b/wpm.py index e4c7383..8ce324e 100644 --- a/wpm.py +++ b/wpm.py @@ -10,6 +10,15 @@ import shutil from datetime import datetime from typing import Dict, List, Tuple from pathlib import Path +import os +import base64 +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.hkdf import HKDF +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.exceptions import InvalidSignature # Default configuration DEFAULT_CONFIG = { @@ -39,7 +48,15 @@ def load_config(config_file: str) -> dict: logging.error(f"Error parsing config file: {e}") exit(1) +def load_client_public_key(config): + client_public_key_pem = config.get('client_keys', {}).get('public_key') + if not client_public_key_pem: + logging.error("Client public key not found in config.toml") + return None + return serialization.load_pem_public_key(client_public_key_pem.encode(), backend=default_backend()) + CONFIG = load_config("config.toml") +CLIENT_PUBLIC_KEY = load_client_public_key(CONFIG) # Set up logging logging.basicConfig( @@ -131,124 +148,223 @@ def reload_wireguard_service(): logging.error(error_message) return False, error_message -class WireGuardHandler(http.server.SimpleHTTPRequestHandler): + +# Add these new functions for encryption and decryption +def generate_ecc_key_pair(): + private_key = ec.generate_private_key(ec.SECP256R1(), default_backend()) + public_key = private_key.public_key() + return private_key, public_key + +def decrypt_symmetric_key(encrypted_symmetric_key, private_key): + return private_key.decrypt( + encrypted_symmetric_key, + ec.ECIES(hashes.SHA256()) + ) + +def decrypt_data(encrypted_data, symmetric_key): + iv = encrypted_data[:16] + cipher = Cipher(algorithms.AES(symmetric_key), modes.CFB(iv), backend=default_backend()) + decryptor = cipher.decryptor() + decrypted_data = decryptor.update(encrypted_data[16:]) + decryptor.finalize() + return decrypted_data.decode('utf-8') + +# Generate server keys on startup +SERVER_PRIVATE_KEY, SERVER_PUBLIC_KEY = generate_ecc_key_pair() + +class SecureWireGuardHandler(http.server.SimpleHTTPRequestHandler): def _send_response(self, status_code: int, data: dict) -> None: self.send_response(status_code) self.send_header('Content-type', 'application/json') self.end_headers() self.wfile.write(json.dumps(data).encode()) - def do_GET(self) -> None: - if self.path == '/peers': - config = read_config() - peers = parse_peers(config) - self._send_response(200, peers) - else: - self._send_response(404, {"error": "Not found"}) - - def do_POST(self) -> None: - if self.path == '/peers': - create_backup() # Create a backup before making changes - content_length = int(self.headers['Content-Length']) - post_data = self.rfile.read(content_length) - new_peer = json.loads(post_data.decode('utf-8')) - - config = read_config() - config += "\n\n" + peer_to_string(new_peer) - write_config(config) - - if reload_wireguard_service(): - self._send_response(201, {"message": "Peer added successfully and service reloaded"}) - else: - self._send_response(500, {"error": "Peer added but failed to reload service"}) - elif self.path == '/restore': - if restore_from_backup(): - if reload_wireguard_service(): - self._send_response(200, {"message": "Configuration restored from backup and service reloaded"}) - else: - self._send_response(500, {"error": "Configuration restored but failed to reload service"}) - else: - self._send_response(500, {"error": "Failed to restore from backup"}) - else: - self._send_response(404, {"error": "Not found"}) - - def do_PUT(self) -> None: - path_parts = self.path.split('/') - if len(path_parts) == 3 and path_parts[1] == 'peers': - create_backup() # Create a backup before making changes - public_key = urllib.parse.unquote(path_parts[2]) - content_length = int(self.headers['Content-Length']) - put_data = self.rfile.read(content_length) - updated_peer = json.loads(put_data.decode('utf-8')) - - config = read_config() - peers = parse_peers(config) - - peer_found = False - for i, peer in enumerate(peers): - if peer.get('PublicKey') == public_key: - peer_found = True - # Update the peer - peers[i] = updated_peer - new_config = re.sub(r'(\[Interface\].*?\n\n)(.*)', - r'\1' + '\n\n'.join(peer_to_string(p) for p in peers), - config, - flags=re.DOTALL) - write_config(new_config) - - # Reload WireGuard service and send the appropriate response - success, message = reload_wireguard_service() - if success: - self._send_response(200, {"message": "Peer updated successfully and service reloaded"}) - else: - self._send_response(500, {"error": f"Peer updated but failed to reload service: {message}"}) - break + def _decrypt_and_verify_request(self): + content_length = int(self.headers['Content-Length']) + encrypted_data = self.rfile.read(content_length) - if not peer_found: - # If no peer with the given public key was found, return a 404 response - self._send_response(404, {"error": "Peer not found"}) - else: - self._send_response(404, {"error": "Not found"}) + # Extract the JWE token from the Authorization header + auth_header = self.headers.get('Authorization', '') + if not auth_header.startswith('Bearer '): + raise ValueError("Invalid Authorization header") + + jwe_token = auth_header.split(' ')[1] + payload = json.loads(base64.b64decode(jwe_token.split('.')[1])) + + encrypted_symmetric_key = base64.b64decode(payload['enc_sym_key']) + + # Decrypt the symmetric key + symmetric_key = decrypt_symmetric_key(encrypted_symmetric_key, SERVER_PRIVATE_KEY) + + # Decrypt the data + decrypted_data = decrypt_data(encrypted_data, symmetric_key) + + # Verify the client's signature + signature = base64.b64decode(payload['signature']) + try: + CLIENT_PUBLIC_KEY.verify( + signature, + decrypted_data.encode(), + ec.ECDSA(hashes.SHA256()) + ) + except InvalidSignature: + raise ValueError("Invalid client signature") + + return json.loads(decrypted_data) - def do_DELETE(self) -> None: - path_parts = self.path.split('/') - if len(path_parts) == 3 and path_parts[1] == 'peers': - create_backup() # Create a backup before making changes - public_key = urllib.parse.unquote(path_parts[2]) + # Update all request handling methods (do_POST, do_PUT, do_DELETE, do_GET) to use _decrypt_and_verify_request + # For example: - config = read_config() - peers = parse_peers(config) - - peer_found = False - for peer in peers: - if peer.get('PublicKey') == public_key: - peer_found = True - # Remove the peer - peers.remove(peer) - new_config = re.sub(r'(\[Interface\].*?\n\n)(.*)', - r'\1' + '\n\n'.join(peer_to_string(p) for p in peers), - config, - flags=re.DOTALL) - write_config(new_config) - - # Reload WireGuard service and send the appropriate response - success, message = reload_wireguard_service() - if success: - self._send_response(200, {"message": "Peer deleted successfully and service reloaded"}) - else: - self._send_response(500, {"error": f"Peer deleted but failed to reload service: {message}"}) - break + def do_POST(self): + try: + decrypted_data = self._decrypt_and_verify_request() + action = decrypted_data.get('action') - if not peer_found: - # If no peer with the given public key was found, return a 404 response - self._send_response(404, {"error": "Peer not found"}) + if action == 'add_peer': + create_backup() + new_peer = decrypted_data.get('peer') + config = read_config() + config += "\n\n" + peer_to_string(new_peer) + write_config(config) + + success, message = reload_wireguard_service() + if success: + self._send_response(201, {"message": "Peer added successfully and service reloaded"}) + else: + self._send_response(500, {"error": "Peer added but failed to reload service"}) + elif action == 'restore': + if restore_from_backup(): + if reload_wireguard_service(): + self._send_response(200, {"message": "Configuration restored from backup and service reloaded"}) + else: + self._send_response(500, {"error": "Configuration restored but failed to reload service"}) + else: + self._send_response(500, {"error": "Failed to restore from backup"}) + else: + self._send_response(400, {"error": "Invalid action"}) + except ValueError as e: + logging.error(f"Error processing request: {str(e)}") + self._send_response(400, {"error": "Invalid request"}) + except Exception as e: + logging.error(f"Unexpected error: {str(e)}") + self._send_response(500, {"error": "Internal server error"}) + + def do_PUT(self): + try: + decrypted_data = self._decrypt_and_verify_request() + action = decrypted_data.get('action') + + if action == 'update_peer': + create_backup() + public_key = decrypted_data.get('public_key') + updated_peer = decrypted_data.get('peer') + + config = read_config() + peers = parse_peers(config) + + peer_found = False + for i, peer in enumerate(peers): + if peer.get('PublicKey') == public_key: + peer_found = True + peers[i] = updated_peer + new_config = re.sub(r'(\[Interface\].*?\n\n)(.*)', + r'\1' + '\n\n'.join(peer_to_string(p) for p in peers), + config, + flags=re.DOTALL) + write_config(new_config) + + success, message = reload_wireguard_service() + if success: + self._send_response(200, {"message": "Peer updated successfully and service reloaded"}) + else: + self._send_response(500, {"error": f"Peer updated but failed to reload service: {message}"}) + break + + if not peer_found: + self._send_response(404, {"error": "Peer not found"}) + else: + self._send_response(400, {"error": "Invalid action"}) + except ValueError as e: + logging.error(f"Error processing request: {str(e)}") + self._send_response(400, {"error": "Invalid request"}) + except Exception as e: + logging.error(f"Unexpected error: {str(e)}") + self._send_response(500, {"error": "Internal server error"}) + + + def do_DELETE(self): + try: + decrypted_data = self._decrypt_and_verify_request() + action = decrypted_data.get('action') + + if action == 'delete_peer': + create_backup() + public_key = decrypted_data.get('public_key') + + config = read_config() + peers = parse_peers(config) + + peer_found = False + for peer in peers: + if peer.get('PublicKey') == public_key: + peer_found = True + peers.remove(peer) + new_config = re.sub(r'(\[Interface\].*?\n\n)(.*)', + r'\1' + '\n\n'.join(peer_to_string(p) for p in peers), + config, + flags=re.DOTALL) + write_config(new_config) + + success, message = reload_wireguard_service() + if success: + self._send_response(200, {"message": "Peer deleted successfully and service reloaded"}) + else: + self._send_response(500, {"error": f"Peer deleted but failed to reload service: {message}"}) + break + + if not peer_found: + self._send_response(404, {"error": "Peer not found"}) + else: + self._send_response(400, {"error": "Invalid action"}) + except ValueError as e: + logging.error(f"Error processing request: {str(e)}") + self._send_response(400, {"error": "Invalid request"}) + except Exception as e: + logging.error(f"Unexpected error: {str(e)}") + self._send_response(500, {"error": "Internal server error"}) + + def do_GET(self): + if self.path == '/public_key': + try: + public_key_pem = SERVER_PUBLIC_KEY.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo + ) + self._send_response(200, {"public_key": public_key_pem.decode()}) + except Exception as e: + logging.error(f"Error sending public key: {str(e)}") + self._send_response(500, {"error": "Internal server error"}) else: - self._send_response(404, {"error": "Not found"}) + try: + decrypted_data = self._decrypt_and_verify_request() + action = decrypted_data.get('action') + + if action == 'get_peers': + config = read_config() + peers = parse_peers(config) + self._send_response(200, peers) + else: + self._send_response(400, {"error": "Invalid action"}) + except ValueError as e: + logging.error(f"Error processing request: {str(e)}") + self._send_response(400, {"error": "Invalid request"}) + except Exception as e: + logging.error(f"Unexpected error: {str(e)}") + self._send_response(500, {"error": "Internal server error"}) def run_server() -> None: host = CONFIG["server"]["host"] port = CONFIG["server"]["port"] - with socketserver.TCPServer((host, port), WireGuardHandler) as httpd: + with socketserver.TCPServer((host, port), SecureWireGuardHandler) as httpd: logging.info(f"Serving on {host}:{port}") httpd.serve_forever() diff --git a/wpm_client.py b/wpm_client.py index d26367b..27fb3f1 100644 --- a/wpm_client.py +++ b/wpm_client.py @@ -2,44 +2,117 @@ import argparse import requests import json import sys +import os +import base64 +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization BASE_URL = "http://localhost:8080" # Update this if your server is running on a different host or port -def create_or_update_peer(public_key, allowed_ips): - url = f"{BASE_URL}/peers" +class SecureApiClient: + def __init__(self, server_public_key, client_private_key): + self.server_public_key = server_public_key + self.client_private_key = client_private_key + + def generate_symmetric_key(self): + return os.urandom(32) # AES 256-bit key + + def encrypt_symmetric_key(self, symmetric_key): + encrypted_symmetric_key = self.server_public_key.encrypt( + symmetric_key, + ec.ECIES(hashes.SHA256()) + ) + return encrypted_symmetric_key + + def encrypt_data(self, data, symmetric_key): + iv = os.urandom(16) + cipher = Cipher(algorithms.AES(symmetric_key), modes.CFB(iv), backend=default_backend()) + encryptor = cipher.encryptor() + encrypted_data = encryptor.update(json.dumps(data).encode('utf-8')) + encryptor.finalize() + return iv + encrypted_data + + def sign_data(self, data): + signature = self.client_private_key.sign( + data, + ec.ECDSA(hashes.SHA256()) + ) + return base64.b64encode(signature).decode('utf-8') + + def create_jwe(self, encrypted_symmetric_key, encrypted_data, signature): + token = base64.b64encode(json.dumps({ + 'enc_sym_key': base64.b64encode(encrypted_symmetric_key).decode('utf-8'), + 'data': base64.b64encode(encrypted_data).decode('utf-8'), + 'signature': signature + }).encode()).decode() + return f"eyJhbGciOiJub25lIn0.{token}." # Add header and empty signature + + def make_request(self, method, endpoint, data): + symmetric_key = self.generate_symmetric_key() + encrypted_symmetric_key = self.encrypt_symmetric_key(symmetric_key) + encrypted_data = self.encrypt_data(data, symmetric_key) + signature = self.sign_data(encrypted_data) + jwe_token = self.create_jwe(encrypted_symmetric_key, encrypted_data, signature) + + headers = { + 'Authorization': f'Bearer {jwe_token}', + 'Content-Type': 'application/octet-stream' + } + + response = requests.request(method, f"{BASE_URL}{endpoint}", headers=headers, data=encrypted_data) + return response + +def get_server_public_key(): + response = requests.get(f"{BASE_URL}/public_key") + if response.status_code == 200: + public_key_pem = response.json()['public_key'] + return serialization.load_pem_public_key(public_key_pem.encode(), backend=default_backend()) + else: + print(f"Error getting server public key: {response.status_code} - {response.text}") + sys.exit(1) + +def load_client_private_key(key_path): + with open(key_path, "rb") as key_file: + return serialization.load_pem_private_key( + key_file.read(), + password=None, + backend=default_backend() + ) + +def create_or_update_peer(client, public_key, allowed_ips): data = { - "PublicKey": public_key, - "AllowedIPs": allowed_ips + "action": "add_peer" if not public_key else "update_peer", + "peer": { + "PublicKey": public_key, + "AllowedIPs": allowed_ips + } } - - # First, try to update an existing peer - response = requests.put(f"{url}/{public_key}", json=data) - - if response.status_code == 404: - # If the peer doesn't exist, create a new one - response = requests.post(url, json=data) + if public_key: + data["public_key"] = public_key + + response = client.make_request("POST" if not public_key else "PUT", "/peers", data) if response.status_code in (200, 201): - result = response.json() - if "warning" in result: - print(f"Warning: {result['warning']}") - else: - print(result['message']) - else: - print(f"Error: {response.status_code} - {response.text}") - -def delete_peer(public_key): - url = f"{BASE_URL}/peers/{public_key}" - response = requests.delete(url) - - if response.status_code == 200: - print("Peer deleted successfully.") + print(response.json()['message']) else: print(f"Error: {response.status_code} - {response.text}") -def list_peers(): - url = f"{BASE_URL}/peers" - response = requests.get(url) +def delete_peer(client, public_key): + data = { + "action": "delete_peer", + "public_key": public_key + } + response = client.make_request("DELETE", "/peers", data) + + if response.status_code == 200: + print(response.json()['message']) + else: + print(f"Error: {response.status_code} - {response.text}") + +def list_peers(client): + response = requests.get(f"{BASE_URL}/peers") if response.status_code == 200: peers = response.json() @@ -47,38 +120,77 @@ def list_peers(): else: print(f"Error: {response.status_code} - {response.text}") -def restore_config(): - url = f"{BASE_URL}/restore" - response = requests.post(url) +def restore_config(client): + data = { + "action": "restore" + } + response = client.make_request("POST", "/restore", data) if response.status_code == 200: - print("Configuration restored successfully.") + print(response.json()['message']) else: print(f"Error: {response.status_code} - {response.text}") +def generate_client_keys(): + private_key = ec.generate_private_key(ec.SECP256R1(), default_backend()) + public_key = private_key.public_key() + + private_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption() + ) + + public_pem = public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo + ) + + with open("client_private_key.pem", "wb") as f: + f.write(private_pem) + + with open("client_public_key.pem", "wb") as f: + f.write(public_pem) + + print("Client keys generated successfully.") + print("Private key saved to: client_private_key.pem") + print("Public key saved to: client_public_key.pem") + print("\nIMPORTANT: Add the following to the server's config.toml file:") + print("\n[client_keys]") + print(f'public_key = """\n{public_pem.decode()}"""') + print("\nAfter adding the key, restart the server for the changes to take effect.") + def main(): parser = argparse.ArgumentParser(description="WireGuard Config Manager Client") - parser.add_argument("action", choices=["create", "update", "delete", "list", "restore"], help="Action to perform") + parser.add_argument("action", choices=["create", "update", "delete", "list", "restore", "generate_keys"], help="Action to perform") parser.add_argument("--public-key", help="Public key of the peer") parser.add_argument("--allowed-ips", help="Allowed IPs for the peer") + parser.add_argument("--private-key", default="client_private_key.pem", help="Path to client's private key file") args = parser.parse_args() + + if args.action == "generate_keys": + generate_client_keys() + return + + server_public_key = get_server_public_key() + client_private_key = load_client_private_key(args.private_key) + client = SecureApiClient(server_public_key, client_private_key) if args.action in ["create", "update"]: if not args.public_key or not args.allowed_ips: print("Error: Both --public-key and --allowed-ips are required for create/update actions.") sys.exit(1) - create_or_update_peer(args.public_key, args.allowed_ips) + create_or_update_peer(client, args.public_key, args.allowed_ips) elif args.action == "delete": if not args.public_key: print("Error: --public-key is required for delete action.") sys.exit(1) - delete_peer(args.public_key) + delete_peer(client, args.public_key) elif args.action == "list": - list_peers() + list_peers(client) elif args.action == "restore": - restore_config() + restore_config(client) if __name__ == "__main__": main() -