374 lines
15 KiB
Python
374 lines
15 KiB
Python
import base64
import http.server
import json
import logging
import os
import re
import shutil
import socketserver
import subprocess
import sys
import tempfile
import tomllib
import urllib.parse
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple

from cryptography.exceptions import InvalidSignature, InvalidTag
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
|
|
|
# Baseline settings that load_config() combines with the contents of
# config.toml. "0.0.0.0" makes the API listen on every IPv4 interface.
DEFAULT_CONFIG = dict(
    server=dict(host="0.0.0.0", port=8000),
    wireguard=dict(config_file="/etc/wireguard/wg0.conf"),
    logging=dict(log_file=None, debug=False),
)
|
|
|
|
def load_config(config_file: str) -> dict:
|
|
try:
|
|
with open(config_file, "rb") as f:
|
|
config = tomllib.load(f)
|
|
# Merge with default config
|
|
return {**DEFAULT_CONFIG, **config}
|
|
except FileNotFoundError:
|
|
logging.warning(f"Config file {config_file} not found. Using default configuration.")
|
|
return DEFAULT_CONFIG
|
|
except tomllib.TOMLDecodeError as e:
|
|
logging.error(f"Error parsing config file: {e}")
|
|
exit(1)
|
|
|
|
def load_client_public_key(config):
    """Parse the client's PEM public key from ``[client_keys] public_key``.

    Returns the loaded key object. When the entry is absent or empty, an
    error is logged and ``None`` is returned. Exceptions raised by the PEM
    loader are not caught here.
    """
    pem_text = config.get('client_keys', {}).get('public_key')
    if pem_text:
        return serialization.load_pem_public_key(pem_text.encode(), backend=default_backend())
    logging.error("Client public key not found in config.toml")
    return None
|
|
|
|
# Settings are loaded once, at import time. The path is relative, so
# "config.toml" is looked up in the process's current working directory.
# NOTE(review): under a service manager the working directory may not be the
# script's directory — confirm the deployment sets it.
CONFIG = load_config("config.toml")
# None when config.toml has no usable [client_keys] public_key entry.
CLIENT_PUBLIC_KEY = load_client_public_key(CONFIG)
|
|
|
|
# Configure the root logger from the [logging] section.
# force=True is required. load_config() and load_client_public_key() run
# earlier and may already have logged. The first module-level logging call
# installs a default stderr handler, and without force this basicConfig()
# would then be a silent no-op that ignores log_file and debug.
logging.basicConfig(
    level=logging.DEBUG if CONFIG.get("logging", {}).get("debug", False) else logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    filename=CONFIG.get("logging", {}).get("log_file"),
    force=True,
)
|
|
|
|
def read_config() -> str:
|
|
with open(CONFIG["wireguard"]["config_file"], 'r') as file:
|
|
return file.read()
|
|
|
|
def write_config(content: str) -> None:
|
|
with open(CONFIG["wireguard"]["config_file"], 'w') as file:
|
|
file.write(content)
|
|
|
|
def create_backup(config_file=None):
    """Copy the WireGuard config into a ``backups`` directory beside it.

    Backups are named ``<stem>_<YYYYmmdd_HHMMSS_ffffff>.conf``. Microseconds
    are included, and a numeric suffix is appended on a name collision. This
    ensures two backups taken within the same second never overwrite each
    other. ``shutil.copy2`` also copies the file's metadata, including its
    modification time.

    Args:
        config_file: File to back up; defaults to
            ``CONFIG["wireguard"]["config_file"]``.

    Returns:
        Path of the newly written backup file.
    """
    source = Path(CONFIG["wireguard"]["config_file"] if config_file is None else config_file)
    backup_dir = source.parent / "backups"
    backup_dir.mkdir(parents=True, exist_ok=True)

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    backup_file = backup_dir / f"{source.stem}_{stamp}.conf"
    counter = 1
    while backup_file.exists():
        backup_file = backup_dir / f"{source.stem}_{stamp}_{counter}.conf"
        counter += 1

    shutil.copy2(source, backup_file)
    logging.info(f"Backup created: {backup_file}")
    return backup_file
|
|
|
|
def restore_from_backup(config_file=None):
    """Overwrite the WireGuard config with its most recent backup.

    Only backups named for *this* config are considered
    (``<stem>_<8 digits>_<6 digits>*.conf``). The ``backups`` directory sits
    beside the config and can hold backups of other interfaces' configs, and
    those must never be restored over this one. The newest candidate is
    chosen by modification time, with the file name as a deterministic
    tie-break.

    Args:
        config_file: File to restore; defaults to
            ``CONFIG["wireguard"]["config_file"]``.

    Returns:
        True if a backup was copied over the config, False if none exists.
    """
    target = Path(CONFIG["wireguard"]["config_file"] if config_file is None else config_file)
    backup_dir = target.parent / "backups"
    pattern = f"{target.stem}_{'[0-9]' * 8}_{'[0-9]' * 6}*.conf"
    candidates = list(backup_dir.glob(pattern)) if backup_dir.is_dir() else []

    if not candidates:
        logging.error("No backups found")
        return False

    latest_backup = max(candidates, key=lambda f: (f.stat().st_mtime, f.name))
    shutil.copy2(latest_backup, target)
    logging.info(f"Restored from backup: {latest_backup}")
    return True
|
|
|
|
def parse_peers(config: str) -> List[Dict[str, str]]:
    """Extract every ``[Peer]`` section of a WireGuard config as a dict.

    Inside a peer section, each line containing ``=`` becomes one entry. The
    line is split on the first ``=`` only, so base64 padding in values
    survives. Both sides are stripped, and a repeated key keeps its last
    value. Lines without ``=`` are ignored.

    Any other section header (e.g. ``[Interface]``) ends the current peer.
    Settings from a later non-peer section, such as ``PrivateKey``, are
    therefore never attributed to the preceding peer. Peer sections with no
    entries are dropped.
    """
    peers: List[Dict[str, str]] = []
    current_peer = None
    for raw_line in config.split('\n'):
        line = raw_line.strip()
        if line.startswith('[') and line.endswith(']'):
            if current_peer:
                peers.append(current_peer)
            # Start collecting only for [Peer]; any other section is skipped.
            current_peer = {} if line == "[Peer]" else None
        elif current_peer is not None and '=' in line:
            key, value = line.split('=', 1)
            current_peer[key.strip()] = value.strip()
    if current_peer:
        peers.append(current_peer)
    return peers
|
|
|
|
def peer_to_string(peer: Dict[str, str]) -> str:
    """Render *peer* as a ``[Peer]`` section of a WireGuard config.

    Each item becomes a ``Key = Value`` line, in the dict's iteration order.
    The result has no trailing newline after the last entry.

    Raises:
        ValueError: if *peer* is not a dict, if a key or value contains a
            line break, or if a key contains ``=``. Peer data arrives from
            HTTP requests, and an embedded newline would let it inject extra
            lines or whole sections into the config (e.g. an ``[Interface]``
            with ``PostUp``, which wg-quick executes).
    """
    if not isinstance(peer, dict):
        raise ValueError("Peer must be a JSON object")
    lines = []
    for key, value in peer.items():
        key_text, value_text = str(key), str(value)
        if any(ch in key_text or ch in value_text for ch in "\r\n") or "=" in key_text:
            raise ValueError(f"Invalid peer field: {key_text!r}")
        lines.append(f"{key_text} = {value_text}")
    return "[Peer]\n" + "\n".join(lines)
|
|
|
|
def reload_wireguard_service() -> Tuple[bool, str]:
    """Restart (or start) the WireGuard interface for the configured file.

    A zero exit status from ``wg show <interface>`` is taken to mean the
    interface is up. In that case it is taken down and brought back up;
    otherwise it is only brought up. The interface name is the config file's
    stem.

    ``wg-quick`` is given the config file's absolute path rather than the
    bare interface name. With a bare name it reads
    ``/etc/wireguard/<name>.conf``, so a config stored elsewhere would never
    be the one loaded.

    Returns:
        ``(success, message)``. On failure the message includes wg-quick's
        stderr when there is any.
    """
    config_path = Path(CONFIG["wireguard"]["config_file"]).absolute()
    interface = config_path.stem

    if shutil.which("wg-quick") is None:
        logging.warning("wg-quick not found. WireGuard might not be installed.")
        return False, "WireGuard (wg-quick) not found. Please ensure WireGuard is installed."

    def wg_quick(verb: str) -> None:
        # Capture output so stderr is available on CalledProcessError.
        subprocess.run(["wg-quick", verb, str(config_path)], check=True, capture_output=True, text=True)

    try:
        status = subprocess.run(["wg", "show", interface], capture_output=True, text=True)
        if status.returncode == 0:
            wg_quick("down")
            wg_quick("up")
            logging.info(f"WireGuard service for interface {interface} restarted successfully")
        else:
            wg_quick("up")
            logging.info(f"WireGuard service for interface {interface} started successfully")
        return True, f"WireGuard service for interface {interface} reloaded successfully"
    except subprocess.CalledProcessError as e:
        error_message = f"Failed to reload WireGuard service: {e}"
        if e.stderr:
            error_message += f"\nError details: {e.stderr}"
        logging.error(error_message)
        return False, error_message
    except OSError as e:
        # e.g. the `wg` executable itself is missing or not executable.
        error_message = f"Failed to reload WireGuard service: {e}"
        logging.error(error_message)
        return False, error_message
|
|
|
|
|
|
# ---- Encryption helpers ----------------------------------------------------

def generate_ecc_key_pair():
    """Create a fresh elliptic-curve key pair on SECP256R1 (NIST P-256).

    Returns:
        A ``(private_key, public_key)`` tuple.
    """
    private_key = ec.generate_private_key(ec.SECP256R1(), default_backend())
    return private_key, private_key.public_key()
|
|
|
|
def decrypt_symmetric_key(encrypted_symmetric_key, private_key):
    """Unwrap a per-request AES key that a client encrypted to *private_key*.

    ECIES is composed explicitly from ECDH + HKDF-SHA256 + AES-GCM. The
    client must produce *encrypted_symmetric_key* in exactly this layout
    (TODO: confirm it matches the client implementation)::

        ephemeral_public_key || nonce (12 bytes) || ciphertext || tag (16 bytes)

    * ``ephemeral_public_key``: the client's one-time public key on the same
      curve as *private_key*, as an uncompressed SEC1 point (65 bytes on
      SECP256R1).
    * The AES-256-GCM key is HKDF(SHA-256, length=32, salt=None,
      info=b"wireguard-api ecies v1") over the ECDH shared secret.
    * The ephemeral public key bytes are the GCM associated data.

    Returns:
        The decrypted symmetric key bytes.

    Raises:
        ValueError: if the blob is too short, the point is invalid, or
            authentication fails.
    """
    # NOTE: `cryptography` has no `ec.ECIES`, and EC private keys have no
    # `decrypt()` method, so ECIES must be assembled from these primitives.
    curve = private_key.curve
    point_len = 1 + 2 * ((curve.key_size + 7) // 8)
    nonce_len, tag_len = 12, 16
    if len(encrypted_symmetric_key) <= point_len + nonce_len + tag_len:
        raise ValueError("Encrypted symmetric key is too short")

    ephemeral_bytes = encrypted_symmetric_key[:point_len]
    nonce = encrypted_symmetric_key[point_len:point_len + nonce_len]
    ciphertext = encrypted_symmetric_key[point_len + nonce_len:-tag_len]
    tag = encrypted_symmetric_key[-tag_len:]

    # from_encoded_point rejects bytes that are not a valid point on the curve.
    ephemeral_public_key = ec.EllipticCurvePublicKey.from_encoded_point(curve, ephemeral_bytes)
    shared_secret = private_key.exchange(ec.ECDH(), ephemeral_public_key)
    wrapping_key = HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=None,
        info=b"wireguard-api ecies v1",
        backend=default_backend(),
    ).derive(shared_secret)

    decryptor = Cipher(
        algorithms.AES(wrapping_key), modes.GCM(nonce, tag), backend=default_backend()
    ).decryptor()
    decryptor.authenticate_additional_data(ephemeral_bytes)
    try:
        return decryptor.update(ciphertext) + decryptor.finalize()
    except InvalidTag as exc:
        raise ValueError("Encrypted symmetric key failed authentication") from exc
|
|
|
|
def decrypt_data(encrypted_data, symmetric_key):
    """Decrypt an AES-CFB payload and return it decoded as UTF-8 text.

    *encrypted_data* is a 16-byte IV immediately followed by the ciphertext.
    CFB mode carries no integrity check of its own.
    """
    iv, ciphertext = encrypted_data[:16], encrypted_data[16:]
    decryptor = Cipher(
        algorithms.AES(symmetric_key), modes.CFB(iv), backend=default_backend()
    ).decryptor()
    plaintext = decryptor.update(ciphertext)
    plaintext += decryptor.finalize()
    return plaintext.decode('utf-8')
|
|
|
|
# Server key pair, freshly generated every time this module is loaded.
# Clients must re-fetch SERVER_PUBLIC_KEY (served by GET /public_key) after
# each restart. NOTE(review): that endpoint is unauthenticated, so a client
# cannot tell whether the key it receives is genuine — consider a persisted,
# pinned server key.
SERVER_PRIVATE_KEY, SERVER_PUBLIC_KEY = generate_ecc_key_pair()
|
|
|
|
class SecureWireGuardHandler(http.server.BaseHTTPRequestHandler):
    """HTTP handler exposing peer management for one WireGuard config.

    Every request except ``GET /public_key`` must carry:

    * an AES-CFB-encrypted JSON body (IV followed by ciphertext), and
    * an ``Authorization: Bearer <token>`` header. The token's second
      dot-separated segment is base64/base64url JSON holding ``enc_sym_key``
      (the encrypted AES key) and ``signature`` (the client's ECDSA-SHA256
      signature over the decrypted body).

    The decrypted body's ``action`` selects the operation: POST ``add_peer`` /
    ``restore``, PUT ``update_peer``, DELETE ``delete_peer``, GET ``get_peers``.

    Derives from ``BaseHTTPRequestHandler``, not ``SimpleHTTPRequestHandler``.
    The latter's inherited ``do_HEAD`` would expose files from the working
    directory.

    NOTE(review): the signed data has no nonce/timestamp, so a captured
    request can be replayed.
    """

    # Requests declaring a larger body are rejected before anything is read.
    MAX_BODY_BYTES = 1024 * 1024
    # Socket timeout in seconds, applied by StreamRequestHandler.setup(). It
    # stops a client that under-delivers its Content-Length from blocking the
    # single-threaded server indefinitely.
    timeout = 30

    def log_message(self, format, *args):
        """Send the per-request access log to ``logging`` instead of stderr."""
        logging.info("%s - %s", self.address_string(), format % args)

    def _send_response(self, status_code: int, data) -> None:
        """Send *data* serialized as JSON with the given status code."""
        body = json.dumps(data).encode()
        self.send_response(status_code)
        self.send_header('Content-type', 'application/json')
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _read_body(self) -> bytes:
        """Read exactly Content-Length bytes of request body.

        Raises:
            ValueError: if the header is missing, is not an integer, is
                negative, or exceeds ``MAX_BODY_BYTES``.
        """
        raw_length = self.headers.get('Content-Length')
        try:
            length = int(raw_length)
        except (TypeError, ValueError):
            raise ValueError("Missing or invalid Content-Length header") from None
        if not 0 <= length <= self.MAX_BODY_BYTES:
            raise ValueError(f"Unacceptable Content-Length: {length}")
        return self.rfile.read(length)

    @staticmethod
    def _decode_token_segment(segment: str) -> bytes:
        """Decode one token segment given in standard or URL-safe base64.

        ``=`` padding is optional, because JWT/JWE segments are unpadded
        base64url.
        """
        normalized = segment.replace('-', '+').replace('_', '/')
        normalized += '=' * (-len(normalized) % 4)
        return base64.b64decode(normalized)

    def _decrypt_and_verify_request(self) -> dict:
        """Decrypt the request body and verify the client's signature over it.

        Returns:
            The decrypted body parsed as a JSON object.

        Raises:
            ValueError: malformed header, token, or body, or a bad signature.
            RuntimeError: no client public key is configured.
        """
        encrypted_data = self._read_body()

        auth_header = self.headers.get('Authorization', '')
        if not auth_header.startswith('Bearer '):
            raise ValueError("Invalid Authorization header")
        segments = auth_header[len('Bearer '):].strip().split('.')
        if len(segments) < 2:
            raise ValueError("Malformed bearer token")
        try:
            payload = json.loads(self._decode_token_segment(segments[1]))
            encrypted_symmetric_key = base64.b64decode(payload['enc_sym_key'])
            signature = base64.b64decode(payload['signature'])
        except (KeyError, TypeError) as exc:
            raise ValueError("Bearer token payload is missing required fields") from exc

        if CLIENT_PUBLIC_KEY is None:
            raise RuntimeError("No client public key configured; cannot verify requests")

        symmetric_key = decrypt_symmetric_key(encrypted_symmetric_key, SERVER_PRIVATE_KEY)
        decrypted_data = decrypt_data(encrypted_data, symmetric_key)
        try:
            CLIENT_PUBLIC_KEY.verify(signature, decrypted_data.encode(), ec.ECDSA(hashes.SHA256()))
        except InvalidSignature:
            raise ValueError("Invalid client signature") from None

        request = json.loads(decrypted_data)
        if not isinstance(request, dict):
            raise ValueError("Request body must be a JSON object")
        return request

    @staticmethod
    def _require_public_key(value) -> str:
        """Return *value* if it is a non-empty string, else raise ValueError."""
        if not isinstance(value, str) or not value:
            raise ValueError("A non-empty public key is required")
        return value

    @staticmethod
    def _find_peer_index(peers, public_key):
        """Return the index of the first peer whose PublicKey matches, or None."""
        return next((i for i, peer in enumerate(peers) if peer.get('PublicKey') == public_key), None)

    @staticmethod
    def _replace_peer_sections(config: str, peers) -> str:
        """Return *config* with every peer section replaced by *peers*.

        Everything before the first ``[Peer]`` header (the ``[Interface]``
        section plus any comments) is kept verbatim. Everything from that
        header on is replaced by *peers* rendered with ``peer_to_string``.
        Plain string assembly is used, so peer text is never interpreted as a
        regex replacement template, and no blank-line layout is assumed.
        """
        first_peer = re.search(r'^[ \t]*\[Peer\][ \t]*\r?$', config, flags=re.MULTILINE)
        head = config if first_peer is None else config[:first_peer.start()]
        sections = "\n\n".join(peer_to_string(p) for p in peers)
        parts = [part for part in (head.rstrip(), sections) if part]
        return "\n\n".join(parts) + "\n"

    def _reload_and_respond(self, ok_status: int, ok_message: str, failure_prefix: str) -> None:
        """Reload WireGuard after a config change and report the outcome."""
        # reload_wireguard_service() returns a (success, message) tuple; the
        # tuple itself is always truthy, so it must be unpacked.
        success, message = reload_wireguard_service()
        if success:
            self._send_response(ok_status, {"message": ok_message})
        else:
            self._send_response(500, {"error": f"{failure_prefix}: {message}"})

    def _dispatch(self, handlers: dict) -> None:
        """Decrypt the request and run the handler registered for its action.

        Client errors (ValueError) map to 400 and anything else to 500. Details
        are written to the log only, never to the response.
        """
        try:
            request = self._decrypt_and_verify_request()
            action = request.get('action')
            handler = handlers.get(action) if isinstance(action, str) else None
            if handler is None:
                self._send_response(400, {"error": "Invalid action"})
                return
            handler(request)
        except ValueError as e:
            logging.error(f"Error processing request: {str(e)}")
            self._send_response(400, {"error": "Invalid request"})
        except Exception as e:
            logging.exception(f"Unexpected error: {str(e)}")
            self._send_response(500, {"error": "Internal server error"})

    def _add_peer(self, request: dict) -> None:
        """Append ``request['peer']`` as a new section; 409 if its key exists."""
        new_peer = request.get('peer')
        if not isinstance(new_peer, dict):
            raise ValueError("'peer' must be a JSON object")
        public_key = self._require_public_key(new_peer.get('PublicKey'))
        config = read_config()
        if self._find_peer_index(parse_peers(config), public_key) is not None:
            self._send_response(409, {"error": "Peer already exists"})
            return
        new_section = peer_to_string(new_peer)  # validates before anything is written
        create_backup()
        write_config(config.rstrip() + "\n\n" + new_section + "\n")
        self._reload_and_respond(
            201,
            "Peer added successfully and service reloaded",
            "Peer added but failed to reload service",
        )

    def _restore(self, _request: dict) -> None:
        """Restore the latest backup of the config and reload WireGuard."""
        if not restore_from_backup():
            self._send_response(500, {"error": "Failed to restore from backup"})
            return
        self._reload_and_respond(
            200,
            "Configuration restored from backup and service reloaded",
            "Configuration restored but failed to reload service",
        )

    def _update_peer(self, request: dict) -> None:
        """Replace the peer matching ``request['public_key']`` with ``request['peer']``."""
        public_key = self._require_public_key(request.get('public_key'))
        updated_peer = request.get('peer')
        if not isinstance(updated_peer, dict):
            raise ValueError("'peer' must be a JSON object")
        config = read_config()
        peers = parse_peers(config)
        index = self._find_peer_index(peers, public_key)
        if index is None:
            self._send_response(404, {"error": "Peer not found"})
            return
        peers[index] = updated_peer
        new_config = self._replace_peer_sections(config, peers)
        create_backup()  # only once the change is known to be valid
        write_config(new_config)
        self._reload_and_respond(
            200,
            "Peer updated successfully and service reloaded",
            "Peer updated but failed to reload service",
        )

    def _delete_peer(self, request: dict) -> None:
        """Remove the peer matching ``request['public_key']``."""
        public_key = self._require_public_key(request.get('public_key'))
        config = read_config()
        peers = parse_peers(config)
        index = self._find_peer_index(peers, public_key)
        if index is None:
            self._send_response(404, {"error": "Peer not found"})
            return
        del peers[index]
        new_config = self._replace_peer_sections(config, peers)
        create_backup()
        write_config(new_config)
        self._reload_and_respond(
            200,
            "Peer deleted successfully and service reloaded",
            "Peer deleted but failed to reload service",
        )

    def _get_peers(self, _request: dict) -> None:
        """Respond with every peer parsed from the current config."""
        self._send_response(200, parse_peers(read_config()))

    def do_POST(self):
        self._dispatch({'add_peer': self._add_peer, 'restore': self._restore})

    def do_PUT(self):
        self._dispatch({'update_peer': self._update_peer})

    def do_DELETE(self):
        self._dispatch({'delete_peer': self._delete_peer})

    def do_GET(self):
        # Compare only the path component so a query string does not matter.
        if urllib.parse.urlsplit(self.path).path == '/public_key':
            try:
                public_key_pem = SERVER_PUBLIC_KEY.public_bytes(
                    encoding=serialization.Encoding.PEM,
                    format=serialization.PublicFormat.SubjectPublicKeyInfo,
                )
                self._send_response(200, {"public_key": public_key_pem.decode()})
            except Exception as e:
                logging.exception(f"Error sending public key: {str(e)}")
                self._send_response(500, {"error": "Internal server error"})
            return
        self._dispatch({'get_peers': self._get_peers})
|
|
|
|
def run_server() -> None:
    """Serve the API on the configured host and port until interrupted.

    Uses ``http.server.HTTPServer``, a single-threaded ``TCPServer`` subclass
    that enables ``allow_reuse_address``. The service can therefore be
    restarted immediately instead of failing with "Address already in use"
    while the old socket lingers. Requests are handled one at a time, which
    keeps the read-modify-write cycles on the WireGuard config from
    interleaving.
    """
    host = CONFIG["server"]["host"]
    port = int(CONFIG["server"]["port"])
    with http.server.HTTPServer((host, port), SecureWireGuardHandler) as httpd:
        logging.info(f"Serving on {host}:{port}")
        try:
            httpd.serve_forever()
        except KeyboardInterrupt:
            logging.info("Shutting down")
|
|
|
|
# Start the HTTP server only when executed as a script. The module-level
# setup above (config load, logging setup, key generation) runs on import too.
if __name__ == "__main__":
    run_server()
|
|
|