feat: implement secure cluster authentication (issue #13)

Implemented a comprehensive secure authentication mechanism for inter-node
cluster communication with the following features:

1. Global Cluster Secret (GCS)
   - Auto-generated cryptographically secure random secret (256-bit)
   - Configurable via YAML config file
   - Shared across all cluster nodes for authentication

2. Cluster Authentication Middleware
   - Validates X-Cluster-Secret and X-Node-ID headers
   - Applied to all cluster endpoints (/members/*, /merkle_tree/*, /kv_range)
   - Comprehensive logging of authentication attempts

3. Authenticated HTTP Client
   - Custom HTTP client with cluster auth headers
   - TLS support with configurable certificate verification
   - Protocol-aware (http/https based on TLS settings)

4. Secure Bootstrap Endpoint
   - New /auth/cluster-bootstrap endpoint
   - Protected by JWT authentication with admin scope
   - Allows new nodes to securely obtain cluster secret

5. Updated Cluster Communication
   - All gossip protocol requests include auth headers
   - All Merkle tree sync requests include auth headers
   - All data replication requests include auth headers

6. Configuration
   - cluster_secret: Shared secret (auto-generated if not provided)
   - cluster_tls_enabled: Enable TLS for inter-node communication
   - cluster_tls_cert_file: Path to TLS certificate
   - cluster_tls_key_file: Path to TLS private key
   - cluster_tls_skip_verify: Skip TLS verification (testing only)

This implementation addresses the security vulnerability of unprotected
cluster endpoints and provides a flexible, secure approach to protecting
internal cluster communication while allowing for automated node bootstrapping.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-10-02 22:19:40 +03:00
parent 2431d3cfb0
commit c7dcebb894
28 changed files with 477 additions and 230 deletions

View File

@@ -142,4 +142,4 @@ func (s *BootstrapService) performGradualSync() {
}
s.logger.Info("Gradual sync completed")
}
}

View File

@@ -17,13 +17,13 @@ import (
// GossipService handles gossip protocol operations
type GossipService struct {
config *types.Config
members map[string]*types.Member
membersMu sync.RWMutex
logger *logrus.Logger
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
config *types.Config
members map[string]*types.Member
membersMu sync.RWMutex
logger *logrus.Logger
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
// NewGossipService creates a new gossip service
@@ -44,7 +44,7 @@ func (s *GossipService) Start() {
s.logger.Info("Clustering disabled, skipping gossip routine")
return
}
s.wg.Add(1)
go s.gossipRoutine()
}
@@ -181,11 +181,20 @@ func (s *GossipService) gossipWithPeer(peer *types.Member) error {
return err
}
// Send HTTP request to peer
client := &http.Client{Timeout: 5 * time.Second}
url := fmt.Sprintf("http://%s/members/gossip", peer.Address)
// Send HTTP request to peer with cluster authentication
client := NewAuthenticatedHTTPClient(s.config, 5*time.Second)
protocol := GetProtocol(s.config)
url := fmt.Sprintf("%s://%s/members/gossip", protocol, peer.Address)
resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonData))
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
s.logger.WithError(err).Error("Failed to create gossip request")
return err
}
req.Header.Set("Content-Type", "application/json")
AddClusterAuthHeaders(req, s.config)
resp, err := client.Do(req)
if err != nil {
s.logger.WithFields(logrus.Fields{
"peer": peer.Address,
@@ -300,4 +309,4 @@ func (s *GossipService) MergeMemberList(remoteMembers []types.Member, selfNodeID
func (s *GossipService) GetJoinedTimestamp() int64 {
// This should be implemented by the server that uses this service
return time.Now().UnixMilli()
}
}

43
cluster/http_client.go Normal file
View File

@@ -0,0 +1,43 @@
package cluster
import (
"crypto/tls"
"net/http"
"time"
"kvs/types"
)
// NewAuthenticatedHTTPClient creates an HTTP client configured for cluster authentication
func NewAuthenticatedHTTPClient(config *types.Config, timeout time.Duration) *http.Client {
client := &http.Client{
Timeout: timeout,
}
// Configure TLS if enabled
if config.ClusterTLSEnabled {
tlsConfig := &tls.Config{
InsecureSkipVerify: config.ClusterTLSSkipVerify,
}
client.Transport = &http.Transport{
TLSClientConfig: tlsConfig,
}
}
return client
}
// AddClusterAuthHeaders adds authentication headers to an HTTP request
func AddClusterAuthHeaders(req *http.Request, config *types.Config) {
req.Header.Set("X-Cluster-Secret", config.ClusterSecret)
req.Header.Set("X-Node-ID", config.NodeID)
}
// GetProtocol returns the appropriate protocol (http or https) based on TLS configuration
func GetProtocol(config *types.Config) string {
if config.ClusterTLSEnabled {
return "https"
}
return "http"
}

View File

@@ -170,7 +170,7 @@ func (s *MerkleService) BuildSubtreeForRange(startKey, endKey string) (*types.Me
if err != nil {
return nil, fmt.Errorf("failed to get KV pairs for subtree: %v", err)
}
filteredPairs := FilterPairsByRange(pairs, startKey, endKey)
return s.BuildMerkleTreeFromPairs(filteredPairs)
}
}

View File

@@ -51,11 +51,11 @@ func (s *SyncService) Start() {
s.logger.Info("Clustering disabled, skipping sync routines")
return
}
// Start sync routine
s.wg.Add(1)
go s.syncRoutine()
// Start Merkle tree rebuild routine
s.wg.Add(1)
go s.merkleTreeRebuildRoutine()
@@ -172,9 +172,9 @@ func (s *SyncService) performMerkleSync() {
// 2. Compare roots and start recursive diffing if they differ
if !bytes.Equal(localRoot.Hash, remoteRoot.Hash) {
s.logger.WithFields(logrus.Fields{
"peer": peer.Address,
"local_root": hex.EncodeToString(localRoot.Hash),
"remote_root": hex.EncodeToString(remoteRoot.Hash),
"peer": peer.Address,
"local_root": hex.EncodeToString(localRoot.Hash),
"remote_root": hex.EncodeToString(remoteRoot.Hash),
}).Info("Merkle roots differ, starting recursive diff")
s.diffMerkleTreesRecursive(peer.Address, localRoot, remoteRoot)
} else {
@@ -186,10 +186,17 @@ func (s *SyncService) performMerkleSync() {
// requestMerkleRoot requests the Merkle root from a peer
func (s *SyncService) requestMerkleRoot(peerAddress string) (*types.MerkleRootResponse, error) {
client := &http.Client{Timeout: 10 * time.Second}
url := fmt.Sprintf("http://%s/merkle_tree/root", peerAddress)
client := NewAuthenticatedHTTPClient(s.config, 10*time.Second)
protocol := GetProtocol(s.config)
url := fmt.Sprintf("%s://%s/merkle_tree/root", protocol, peerAddress)
resp, err := client.Get(url)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
AddClusterAuthHeaders(req, s.config)
resp, err := client.Do(req)
if err != nil {
return nil, err
}
@@ -216,7 +223,7 @@ func (s *SyncService) diffMerkleTreesRecursive(peerAddress string, localNode, re
// Hashes differ, need to go deeper.
// Request children from the remote peer for the current range.
req := types.MerkleTreeDiffRequest{
ParentNode: *remoteNode, // We are asking the remote peer about its children for this range
ParentNode: *remoteNode, // We are asking the remote peer about its children for this range
LocalHash: localNode.Hash, // Our hash for this range
}
@@ -294,10 +301,17 @@ func (s *SyncService) handleLeafLevelDiff(peerAddress string, keys []string, loc
// fetchSingleKVFromPeer fetches a single KV pair from a peer
func (s *SyncService) fetchSingleKVFromPeer(peerAddress, path string) (*types.StoredValue, error) {
client := &http.Client{Timeout: 5 * time.Second}
url := fmt.Sprintf("http://%s/kv/%s", peerAddress, path)
client := NewAuthenticatedHTTPClient(s.config, 5*time.Second)
protocol := GetProtocol(s.config)
url := fmt.Sprintf("%s://%s/kv/%s", protocol, peerAddress, path)
resp, err := client.Get(url)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
AddClusterAuthHeaders(req, s.config)
resp, err := client.Do(req)
if err != nil {
return nil, err
}
@@ -398,14 +412,14 @@ func (s *SyncService) resolveConflict(key string, local, remote *types.StoredVal
// Timestamps are equal - need sophisticated conflict resolution
s.logger.WithField("key", key).Info("Timestamp collision detected, applying oldest-node rule")
// Get cluster members to determine which node is older
members := s.gossipService.GetMembers()
// Find the local node and the remote node in membership
var localMember, remoteMember *types.Member
localNodeID := s.config.NodeID
for _, member := range members {
if member.ID == localNodeID {
localMember = member
@@ -414,16 +428,16 @@ func (s *SyncService) resolveConflict(key string, local, remote *types.StoredVal
remoteMember = member
}
}
// If we can't find membership info, fall back to UUID comparison for deterministic result
if localMember == nil || remoteMember == nil {
s.logger.WithFields(logrus.Fields{
"key": key,
"peerAddress": peerAddress,
"localNodeID": localNodeID,
"localMember": localMember != nil,
"remoteMember": remoteMember != nil,
"totalMembers": len(members),
"key": key,
"peerAddress": peerAddress,
"localNodeID": localNodeID,
"localMember": localMember != nil,
"remoteMember": remoteMember != nil,
"totalMembers": len(members),
}).Warn("Could not find membership info for conflict resolution, using UUID comparison")
if remote.UUID < local.UUID {
// Remote UUID lexically smaller (deterministic choice)
@@ -436,41 +450,49 @@ func (s *SyncService) resolveConflict(key string, local, remote *types.StoredVal
s.logger.WithField("key", key).Info("Conflict resolved: local data wins (UUID tie-breaker)")
return nil
}
// Apply oldest-node rule: node with earliest joined_timestamp wins
if remoteMember.JoinedTimestamp < localMember.JoinedTimestamp {
// Remote node is older, its data wins
err := s.storeReplicatedDataWithMetadata(key, remote)
if err == nil {
s.logger.WithFields(logrus.Fields{
"key": key,
"local_joined": localMember.JoinedTimestamp,
"remote_joined": remoteMember.JoinedTimestamp,
"key": key,
"local_joined": localMember.JoinedTimestamp,
"remote_joined": remoteMember.JoinedTimestamp,
}).Info("Conflict resolved: remote data wins (oldest-node rule)")
}
return err
}
// Local node is older or equal, keep local data
s.logger.WithFields(logrus.Fields{
"key": key,
"local_joined": localMember.JoinedTimestamp,
"remote_joined": remoteMember.JoinedTimestamp,
"key": key,
"local_joined": localMember.JoinedTimestamp,
"remote_joined": remoteMember.JoinedTimestamp,
}).Info("Conflict resolved: local data wins (oldest-node rule)")
return nil
}
// requestMerkleDiff requests children hashes or keys for a given node/range from a peer
func (s *SyncService) requestMerkleDiff(peerAddress string, req types.MerkleTreeDiffRequest) (*types.MerkleTreeDiffResponse, error) {
jsonData, err := json.Marshal(req)
func (s *SyncService) requestMerkleDiff(peerAddress string, reqData types.MerkleTreeDiffRequest) (*types.MerkleTreeDiffResponse, error) {
jsonData, err := json.Marshal(reqData)
if err != nil {
return nil, err
}
client := &http.Client{Timeout: 10 * time.Second}
url := fmt.Sprintf("http://%s/merkle_tree/diff", peerAddress)
client := NewAuthenticatedHTTPClient(s.config, 10*time.Second)
protocol := GetProtocol(s.config)
url := fmt.Sprintf("%s://%s/merkle_tree/diff", protocol, peerAddress)
resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonData))
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
AddClusterAuthHeaders(req, s.config)
resp, err := client.Do(req)
if err != nil {
return nil, err
}
@@ -525,20 +547,28 @@ func (s *SyncService) handleChildrenDiff(peerAddress string, children []types.Me
// fetchAndStoreRange fetches a range of KV pairs from a peer and stores them locally
func (s *SyncService) fetchAndStoreRange(peerAddress string, startKey, endKey string) error {
req := types.KVRangeRequest{
reqData := types.KVRangeRequest{
StartKey: startKey,
EndKey: endKey,
Limit: 0, // No limit
}
jsonData, err := json.Marshal(req)
jsonData, err := json.Marshal(reqData)
if err != nil {
return err
}
client := &http.Client{Timeout: 30 * time.Second} // Longer timeout for range fetches
url := fmt.Sprintf("http://%s/kv_range", peerAddress)
client := NewAuthenticatedHTTPClient(s.config, 30*time.Second) // Longer timeout for range fetches
protocol := GetProtocol(s.config)
url := fmt.Sprintf("%s://%s/kv_range", protocol, peerAddress)
resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonData))
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
AddClusterAuthHeaders(req, s.config)
resp, err := client.Do(req)
if err != nil {
return err
}
@@ -568,4 +598,4 @@ func (s *SyncService) fetchAndStoreRange(peerAddress string, startKey, endKey st
}
}
return nil
}
}