Added auth and encryption to the api. Still some bugs, endpoints seem to give http/500 error codes.

This commit is contained in:
kalzu rekku 2024-10-21 23:28:53 +03:00
parent 9b15d5bdeb
commit 6db5290cca
4 changed files with 389 additions and 141 deletions

13
Pipfile Normal file
View File

@ -0,0 +1,13 @@
[[source]]
url = "https://pypi.org/simple"
verify_ssl = true
name = "pypi"
[packages]
requests = "*"
cryptography = "*"
[dev-packages]
[requires]
python_version = "3.12"

View File

@ -9,3 +9,10 @@ config_file = "../wireguard_example_server_config.conf"
log_file = "../wpm.log" # Optional: Log file location log_file = "../wpm.log" # Optional: Log file location
debug = true # Enable debug logging debug = true # Enable debug logging
[client_keys]
public_key = """
-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE0t+2/KJ7+YOHu4DmR20YtisenovD
tvJKywNdAWX5uqnA7UsYWPVKN827afMkgZuGKgZ5wtVM4DvQCq8MyRDHgw==
-----END PUBLIC KEY-----
""" # example public key

194
wpm.py
View File

@ -10,6 +10,15 @@ import shutil
from datetime import datetime from datetime import datetime
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
from pathlib import Path from pathlib import Path
import os
import base64
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.exceptions import InvalidSignature
# Default configuration # Default configuration
DEFAULT_CONFIG = { DEFAULT_CONFIG = {
@ -39,7 +48,15 @@ def load_config(config_file: str) -> dict:
logging.error(f"Error parsing config file: {e}") logging.error(f"Error parsing config file: {e}")
exit(1) exit(1)
def load_client_public_key(config):
client_public_key_pem = config.get('client_keys', {}).get('public_key')
if not client_public_key_pem:
logging.error("Client public key not found in config.toml")
return None
return serialization.load_pem_public_key(client_public_key_pem.encode(), backend=default_backend())
CONFIG = load_config("config.toml") CONFIG = load_config("config.toml")
CLIENT_PUBLIC_KEY = load_client_public_key(CONFIG)
# Set up logging # Set up logging
logging.basicConfig( logging.basicConfig(
@ -131,37 +148,90 @@ def reload_wireguard_service():
logging.error(error_message) logging.error(error_message)
return False, error_message return False, error_message
class WireGuardHandler(http.server.SimpleHTTPRequestHandler):
# Add these new functions for encryption and decryption
def generate_ecc_key_pair():
private_key = ec.generate_private_key(ec.SECP256R1(), default_backend())
public_key = private_key.public_key()
return private_key, public_key
def decrypt_symmetric_key(encrypted_symmetric_key, private_key):
return private_key.decrypt(
encrypted_symmetric_key,
ec.ECIES(hashes.SHA256())
)
def decrypt_data(encrypted_data, symmetric_key):
iv = encrypted_data[:16]
cipher = Cipher(algorithms.AES(symmetric_key), modes.CFB(iv), backend=default_backend())
decryptor = cipher.decryptor()
decrypted_data = decryptor.update(encrypted_data[16:]) + decryptor.finalize()
return decrypted_data.decode('utf-8')
# Generate server keys on startup
SERVER_PRIVATE_KEY, SERVER_PUBLIC_KEY = generate_ecc_key_pair()
class SecureWireGuardHandler(http.server.SimpleHTTPRequestHandler):
def _send_response(self, status_code: int, data: dict) -> None: def _send_response(self, status_code: int, data: dict) -> None:
self.send_response(status_code) self.send_response(status_code)
self.send_header('Content-type', 'application/json') self.send_header('Content-type', 'application/json')
self.end_headers() self.end_headers()
self.wfile.write(json.dumps(data).encode()) self.wfile.write(json.dumps(data).encode())
def do_GET(self) -> None: def _decrypt_and_verify_request(self):
if self.path == '/peers':
config = read_config()
peers = parse_peers(config)
self._send_response(200, peers)
else:
self._send_response(404, {"error": "Not found"})
def do_POST(self) -> None:
if self.path == '/peers':
create_backup() # Create a backup before making changes
content_length = int(self.headers['Content-Length']) content_length = int(self.headers['Content-Length'])
post_data = self.rfile.read(content_length) encrypted_data = self.rfile.read(content_length)
new_peer = json.loads(post_data.decode('utf-8'))
# Extract the JWE token from the Authorization header
auth_header = self.headers.get('Authorization', '')
if not auth_header.startswith('Bearer '):
raise ValueError("Invalid Authorization header")
jwe_token = auth_header.split(' ')[1]
payload = json.loads(base64.b64decode(jwe_token.split('.')[1]))
encrypted_symmetric_key = base64.b64decode(payload['enc_sym_key'])
# Decrypt the symmetric key
symmetric_key = decrypt_symmetric_key(encrypted_symmetric_key, SERVER_PRIVATE_KEY)
# Decrypt the data
decrypted_data = decrypt_data(encrypted_data, symmetric_key)
# Verify the client's signature
signature = base64.b64decode(payload['signature'])
try:
CLIENT_PUBLIC_KEY.verify(
signature,
decrypted_data.encode(),
ec.ECDSA(hashes.SHA256())
)
except InvalidSignature:
raise ValueError("Invalid client signature")
return json.loads(decrypted_data)
# Update all request handling methods (do_POST, do_PUT, do_DELETE, do_GET) to use _decrypt_and_verify_request
# For example:
def do_POST(self):
try:
decrypted_data = self._decrypt_and_verify_request()
action = decrypted_data.get('action')
if action == 'add_peer':
create_backup()
new_peer = decrypted_data.get('peer')
config = read_config() config = read_config()
config += "\n\n" + peer_to_string(new_peer) config += "\n\n" + peer_to_string(new_peer)
write_config(config) write_config(config)
if reload_wireguard_service(): success, message = reload_wireguard_service()
if success:
self._send_response(201, {"message": "Peer added successfully and service reloaded"}) self._send_response(201, {"message": "Peer added successfully and service reloaded"})
else: else:
self._send_response(500, {"error": "Peer added but failed to reload service"}) self._send_response(500, {"error": "Peer added but failed to reload service"})
elif self.path == '/restore': elif action == 'restore':
if restore_from_backup(): if restore_from_backup():
if reload_wireguard_service(): if reload_wireguard_service():
self._send_response(200, {"message": "Configuration restored from backup and service reloaded"}) self._send_response(200, {"message": "Configuration restored from backup and service reloaded"})
@ -170,16 +240,23 @@ class WireGuardHandler(http.server.SimpleHTTPRequestHandler):
else: else:
self._send_response(500, {"error": "Failed to restore from backup"}) self._send_response(500, {"error": "Failed to restore from backup"})
else: else:
self._send_response(404, {"error": "Not found"}) self._send_response(400, {"error": "Invalid action"})
except ValueError as e:
logging.error(f"Error processing request: {str(e)}")
self._send_response(400, {"error": "Invalid request"})
except Exception as e:
logging.error(f"Unexpected error: {str(e)}")
self._send_response(500, {"error": "Internal server error"})
def do_PUT(self) -> None: def do_PUT(self):
path_parts = self.path.split('/') try:
if len(path_parts) == 3 and path_parts[1] == 'peers': decrypted_data = self._decrypt_and_verify_request()
create_backup() # Create a backup before making changes action = decrypted_data.get('action')
public_key = urllib.parse.unquote(path_parts[2])
content_length = int(self.headers['Content-Length']) if action == 'update_peer':
put_data = self.rfile.read(content_length) create_backup()
updated_peer = json.loads(put_data.decode('utf-8')) public_key = decrypted_data.get('public_key')
updated_peer = decrypted_data.get('peer')
config = read_config() config = read_config()
peers = parse_peers(config) peers = parse_peers(config)
@ -188,7 +265,6 @@ class WireGuardHandler(http.server.SimpleHTTPRequestHandler):
for i, peer in enumerate(peers): for i, peer in enumerate(peers):
if peer.get('PublicKey') == public_key: if peer.get('PublicKey') == public_key:
peer_found = True peer_found = True
# Update the peer
peers[i] = updated_peer peers[i] = updated_peer
new_config = re.sub(r'(\[Interface\].*?\n\n)(.*)', new_config = re.sub(r'(\[Interface\].*?\n\n)(.*)',
r'\1' + '\n\n'.join(peer_to_string(p) for p in peers), r'\1' + '\n\n'.join(peer_to_string(p) for p in peers),
@ -196,7 +272,6 @@ class WireGuardHandler(http.server.SimpleHTTPRequestHandler):
flags=re.DOTALL) flags=re.DOTALL)
write_config(new_config) write_config(new_config)
# Reload WireGuard service and send the appropriate response
success, message = reload_wireguard_service() success, message = reload_wireguard_service()
if success: if success:
self._send_response(200, {"message": "Peer updated successfully and service reloaded"}) self._send_response(200, {"message": "Peer updated successfully and service reloaded"})
@ -205,16 +280,25 @@ class WireGuardHandler(http.server.SimpleHTTPRequestHandler):
break break
if not peer_found: if not peer_found:
# If no peer with the given public key was found, return a 404 response
self._send_response(404, {"error": "Peer not found"}) self._send_response(404, {"error": "Peer not found"})
else: else:
self._send_response(404, {"error": "Not found"}) self._send_response(400, {"error": "Invalid action"})
except ValueError as e:
logging.error(f"Error processing request: {str(e)}")
self._send_response(400, {"error": "Invalid request"})
except Exception as e:
logging.error(f"Unexpected error: {str(e)}")
self._send_response(500, {"error": "Internal server error"})
def do_DELETE(self) -> None:
path_parts = self.path.split('/') def do_DELETE(self):
if len(path_parts) == 3 and path_parts[1] == 'peers': try:
create_backup() # Create a backup before making changes decrypted_data = self._decrypt_and_verify_request()
public_key = urllib.parse.unquote(path_parts[2]) action = decrypted_data.get('action')
if action == 'delete_peer':
create_backup()
public_key = decrypted_data.get('public_key')
config = read_config() config = read_config()
peers = parse_peers(config) peers = parse_peers(config)
@ -223,7 +307,6 @@ class WireGuardHandler(http.server.SimpleHTTPRequestHandler):
for peer in peers: for peer in peers:
if peer.get('PublicKey') == public_key: if peer.get('PublicKey') == public_key:
peer_found = True peer_found = True
# Remove the peer
peers.remove(peer) peers.remove(peer)
new_config = re.sub(r'(\[Interface\].*?\n\n)(.*)', new_config = re.sub(r'(\[Interface\].*?\n\n)(.*)',
r'\1' + '\n\n'.join(peer_to_string(p) for p in peers), r'\1' + '\n\n'.join(peer_to_string(p) for p in peers),
@ -231,7 +314,6 @@ class WireGuardHandler(http.server.SimpleHTTPRequestHandler):
flags=re.DOTALL) flags=re.DOTALL)
write_config(new_config) write_config(new_config)
# Reload WireGuard service and send the appropriate response
success, message = reload_wireguard_service() success, message = reload_wireguard_service()
if success: if success:
self._send_response(200, {"message": "Peer deleted successfully and service reloaded"}) self._send_response(200, {"message": "Peer deleted successfully and service reloaded"})
@ -240,15 +322,49 @@ class WireGuardHandler(http.server.SimpleHTTPRequestHandler):
break break
if not peer_found: if not peer_found:
# If no peer with the given public key was found, return a 404 response
self._send_response(404, {"error": "Peer not found"}) self._send_response(404, {"error": "Peer not found"})
else: else:
self._send_response(404, {"error": "Not found"}) self._send_response(400, {"error": "Invalid action"})
except ValueError as e:
logging.error(f"Error processing request: {str(e)}")
self._send_response(400, {"error": "Invalid request"})
except Exception as e:
logging.error(f"Unexpected error: {str(e)}")
self._send_response(500, {"error": "Internal server error"})
def do_GET(self):
if self.path == '/public_key':
try:
public_key_pem = SERVER_PUBLIC_KEY.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
self._send_response(200, {"public_key": public_key_pem.decode()})
except Exception as e:
logging.error(f"Error sending public key: {str(e)}")
self._send_response(500, {"error": "Internal server error"})
else:
try:
decrypted_data = self._decrypt_and_verify_request()
action = decrypted_data.get('action')
if action == 'get_peers':
config = read_config()
peers = parse_peers(config)
self._send_response(200, peers)
else:
self._send_response(400, {"error": "Invalid action"})
except ValueError as e:
logging.error(f"Error processing request: {str(e)}")
self._send_response(400, {"error": "Invalid request"})
except Exception as e:
logging.error(f"Unexpected error: {str(e)}")
self._send_response(500, {"error": "Internal server error"})
def run_server() -> None: def run_server() -> None:
host = CONFIG["server"]["host"] host = CONFIG["server"]["host"]
port = CONFIG["server"]["port"] port = CONFIG["server"]["port"]
with socketserver.TCPServer((host, port), WireGuardHandler) as httpd: with socketserver.TCPServer((host, port), SecureWireGuardHandler) as httpd:
logging.info(f"Serving on {host}:{port}") logging.info(f"Serving on {host}:{port}")
httpd.serve_forever() httpd.serve_forever()

View File

@ -2,44 +2,117 @@ import argparse
import requests import requests
import json import json
import sys import sys
import os
import base64
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
BASE_URL = "http://localhost:8080" # Update this if your server is running on a different host or port BASE_URL = "http://localhost:8080" # Update this if your server is running on a different host or port
def create_or_update_peer(public_key, allowed_ips): class SecureApiClient:
url = f"{BASE_URL}/peers" def __init__(self, server_public_key, client_private_key):
self.server_public_key = server_public_key
self.client_private_key = client_private_key
def generate_symmetric_key(self):
return os.urandom(32) # AES 256-bit key
def encrypt_symmetric_key(self, symmetric_key):
encrypted_symmetric_key = self.server_public_key.encrypt(
symmetric_key,
ec.ECIES(hashes.SHA256())
)
return encrypted_symmetric_key
def encrypt_data(self, data, symmetric_key):
iv = os.urandom(16)
cipher = Cipher(algorithms.AES(symmetric_key), modes.CFB(iv), backend=default_backend())
encryptor = cipher.encryptor()
encrypted_data = encryptor.update(json.dumps(data).encode('utf-8')) + encryptor.finalize()
return iv + encrypted_data
def sign_data(self, data):
signature = self.client_private_key.sign(
data,
ec.ECDSA(hashes.SHA256())
)
return base64.b64encode(signature).decode('utf-8')
def create_jwe(self, encrypted_symmetric_key, encrypted_data, signature):
token = base64.b64encode(json.dumps({
'enc_sym_key': base64.b64encode(encrypted_symmetric_key).decode('utf-8'),
'data': base64.b64encode(encrypted_data).decode('utf-8'),
'signature': signature
}).encode()).decode()
return f"eyJhbGciOiJub25lIn0.{token}." # Add header and empty signature
def make_request(self, method, endpoint, data):
symmetric_key = self.generate_symmetric_key()
encrypted_symmetric_key = self.encrypt_symmetric_key(symmetric_key)
encrypted_data = self.encrypt_data(data, symmetric_key)
signature = self.sign_data(encrypted_data)
jwe_token = self.create_jwe(encrypted_symmetric_key, encrypted_data, signature)
headers = {
'Authorization': f'Bearer {jwe_token}',
'Content-Type': 'application/octet-stream'
}
response = requests.request(method, f"{BASE_URL}{endpoint}", headers=headers, data=encrypted_data)
return response
def get_server_public_key():
response = requests.get(f"{BASE_URL}/public_key")
if response.status_code == 200:
public_key_pem = response.json()['public_key']
return serialization.load_pem_public_key(public_key_pem.encode(), backend=default_backend())
else:
print(f"Error getting server public key: {response.status_code} - {response.text}")
sys.exit(1)
def load_client_private_key(key_path):
with open(key_path, "rb") as key_file:
return serialization.load_pem_private_key(
key_file.read(),
password=None,
backend=default_backend()
)
def create_or_update_peer(client, public_key, allowed_ips):
data = { data = {
"action": "add_peer" if not public_key else "update_peer",
"peer": {
"PublicKey": public_key, "PublicKey": public_key,
"AllowedIPs": allowed_ips "AllowedIPs": allowed_ips
} }
}
if public_key:
data["public_key"] = public_key
# First, try to update an existing peer response = client.make_request("POST" if not public_key else "PUT", "/peers", data)
response = requests.put(f"{url}/{public_key}", json=data)
if response.status_code == 404:
# If the peer doesn't exist, create a new one
response = requests.post(url, json=data)
if response.status_code in (200, 201): if response.status_code in (200, 201):
result = response.json() print(response.json()['message'])
if "warning" in result:
print(f"Warning: {result['warning']}")
else:
print(result['message'])
else: else:
print(f"Error: {response.status_code} - {response.text}") print(f"Error: {response.status_code} - {response.text}")
def delete_peer(public_key): def delete_peer(client, public_key):
url = f"{BASE_URL}/peers/{public_key}" data = {
response = requests.delete(url) "action": "delete_peer",
"public_key": public_key
}
response = client.make_request("DELETE", "/peers", data)
if response.status_code == 200: if response.status_code == 200:
print("Peer deleted successfully.") print(response.json()['message'])
else: else:
print(f"Error: {response.status_code} - {response.text}") print(f"Error: {response.status_code} - {response.text}")
def list_peers(): def list_peers(client):
url = f"{BASE_URL}/peers" response = requests.get(f"{BASE_URL}/peers")
response = requests.get(url)
if response.status_code == 200: if response.status_code == 200:
peers = response.json() peers = response.json()
@ -47,38 +120,77 @@ def list_peers():
else: else:
print(f"Error: {response.status_code} - {response.text}") print(f"Error: {response.status_code} - {response.text}")
def restore_config(): def restore_config(client):
url = f"{BASE_URL}/restore" data = {
response = requests.post(url) "action": "restore"
}
response = client.make_request("POST", "/restore", data)
if response.status_code == 200: if response.status_code == 200:
print("Configuration restored successfully.") print(response.json()['message'])
else: else:
print(f"Error: {response.status_code} - {response.text}") print(f"Error: {response.status_code} - {response.text}")
def generate_client_keys():
private_key = ec.generate_private_key(ec.SECP256R1(), default_backend())
public_key = private_key.public_key()
private_pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
)
public_pem = public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
with open("client_private_key.pem", "wb") as f:
f.write(private_pem)
with open("client_public_key.pem", "wb") as f:
f.write(public_pem)
print("Client keys generated successfully.")
print("Private key saved to: client_private_key.pem")
print("Public key saved to: client_public_key.pem")
print("\nIMPORTANT: Add the following to the server's config.toml file:")
print("\n[client_keys]")
print(f'public_key = """\n{public_pem.decode()}"""')
print("\nAfter adding the key, restart the server for the changes to take effect.")
def main(): def main():
parser = argparse.ArgumentParser(description="WireGuard Config Manager Client") parser = argparse.ArgumentParser(description="WireGuard Config Manager Client")
parser.add_argument("action", choices=["create", "update", "delete", "list", "restore"], help="Action to perform") parser.add_argument("action", choices=["create", "update", "delete", "list", "restore", "generate_keys"], help="Action to perform")
parser.add_argument("--public-key", help="Public key of the peer") parser.add_argument("--public-key", help="Public key of the peer")
parser.add_argument("--allowed-ips", help="Allowed IPs for the peer") parser.add_argument("--allowed-ips", help="Allowed IPs for the peer")
parser.add_argument("--private-key", default="client_private_key.pem", help="Path to client's private key file")
args = parser.parse_args() args = parser.parse_args()
if args.action == "generate_keys":
generate_client_keys()
return
server_public_key = get_server_public_key()
client_private_key = load_client_private_key(args.private_key)
client = SecureApiClient(server_public_key, client_private_key)
if args.action in ["create", "update"]: if args.action in ["create", "update"]:
if not args.public_key or not args.allowed_ips: if not args.public_key or not args.allowed_ips:
print("Error: Both --public-key and --allowed-ips are required for create/update actions.") print("Error: Both --public-key and --allowed-ips are required for create/update actions.")
sys.exit(1) sys.exit(1)
create_or_update_peer(args.public_key, args.allowed_ips) create_or_update_peer(client, args.public_key, args.allowed_ips)
elif args.action == "delete": elif args.action == "delete":
if not args.public_key: if not args.public_key:
print("Error: --public-key is required for delete action.") print("Error: --public-key is required for delete action.")
sys.exit(1) sys.exit(1)
delete_peer(args.public_key) delete_peer(client, args.public_key)
elif args.action == "list": elif args.action == "list":
list_peers() list_peers(client)
elif args.action == "restore": elif args.action == "restore":
restore_config() restore_config(client)
if __name__ == "__main__": if __name__ == "__main__":
main() main()