diff --git a/integration_test.sh b/integration_test.sh index 464087f..22305cf 100755 --- a/integration_test.sh +++ b/integration_test.sh @@ -1,7 +1,7 @@ #!/bin/bash -# KVS Integration Test Suite - Working Version -# Tests all critical features of the distributed key-value store +# KVS Integration Test Suite - Adapted for Merkle Tree Sync +# Tests all critical features of the distributed key-value store with Merkle Tree replication # Colors for output RED='\033[0;31m' @@ -43,7 +43,7 @@ test_start() { # Cleanup function cleanup() { log_info "Cleaning up test environment..." - pkill -f "./kvs" 2>/dev/null || true + pkill -f "$BINARY" 2>/dev/null || true rm -rf "$TEST_DIR" 2>/dev/null || true sleep 2 # Allow processes to fully terminate } @@ -75,6 +75,7 @@ test_build() { log_error "Binary build failed" return 1 fi + # Ensure we are back in TEST_DIR for subsequent tests cd "$TEST_DIR" } @@ -103,12 +104,12 @@ EOF -d '{"message":"hello world"}') local get_result=$(curl -s http://localhost:8090/kv/test/basic) - local message=$(echo "$get_result" | jq -r '.message' 2>/dev/null) + local message=$(echo "$get_result" | jq -r '.data.message' 2>/dev/null) # Adjusted jq path if [ "$message" = "hello world" ]; then log_success "Basic CRUD operations work" else - log_error "Basic CRUD failed: $get_result" + log_error "Basic CRUD failed: Expected 'hello world', got '$message' from $get_result" fi else log_error "Basic test node failed to start" @@ -120,7 +121,7 @@ EOF # Test 3: Cluster formation test_cluster_formation() { - test_start "2-node cluster formation" + test_start "2-node cluster formation and Merkle Tree replication" # Node 1 config cat > cluster1.yaml </dev/null 2>&1 & local pid2=$! @@ -168,28 +169,46 @@ EOF return 1 fi - # Wait for cluster formation - sleep 8 + # Wait for cluster formation and initial Merkle sync + sleep 15 # Check if nodes see each other local node1_members=$(curl -s http://localhost:8101/members/ | jq length 2>/dev/null || echo 0) local node2_members=$(curl -s http://localhost:8102/members/ | jq length 2>/dev/null || echo 0) if [ "$node1_members" -ge 1 ] && [ "$node2_members" -ge 1 ]; then - log_success "2-node cluster formed successfully" + log_success "2-node cluster formed successfully (N1 members: $node1_members, N2 members: $node2_members)" # Test data replication + log_info "Putting data on Node 1, waiting for Merkle sync..." 
curl -s -X PUT http://localhost:8101/kv/cluster/test \ -H "Content-Type: application/json" \ - -d '{"source":"node1"}' >/dev/null + -d '{"source":"node1", "value": 1}' >/dev/null - sleep 12 # Wait for sync cycle + # Wait for Merkle sync cycle to complete + sleep 12 - local node2_data=$(curl -s http://localhost:8102/kv/cluster/test | jq -r '.source' 2>/dev/null) - if [ "$node2_data" = "node1" ]; then - log_success "Data replication works correctly" + local node2_data_full=$(curl -s http://localhost:8102/kv/cluster/test) + local node2_data_source=$(echo "$node2_data_full" | jq -r '.data.source' 2>/dev/null) + local node2_data_value=$(echo "$node2_data_full" | jq -r '.data.value' 2>/dev/null) + local node1_data_full=$(curl -s http://localhost:8101/kv/cluster/test) + + if [ "$node2_data_source" = "node1" ] && [ "$node2_data_value" = "1" ]; then + log_success "Data replication works correctly (Node 2 has data from Node 1)" + + # Verify UUIDs and Timestamps are identical (crucial for Merkle sync correctness) + local node1_uuid=$(echo "$node1_data_full" | jq -r '.uuid' 2>/dev/null) + local node1_timestamp=$(echo "$node1_data_full" | jq -r '.timestamp' 2>/dev/null) + local node2_uuid=$(echo "$node2_data_full" | jq -r '.uuid' 2>/dev/null) + local node2_timestamp=$(echo "$node2_data_full" | jq -r '.timestamp' 2>/dev/null) + + if [ "$node1_uuid" = "$node2_uuid" ] && [ "$node1_timestamp" = "$node2_timestamp" ]; then + log_success "Replicated data retains original UUID and Timestamp" + else + log_error "Replicated data changed UUID/Timestamp: N1_UUID=$node1_uuid, N1_TS=$node1_timestamp, N2_UUID=$node2_uuid, N2_TS=$node2_timestamp" + fi else - log_error "Data replication failed: $node2_data" + log_error "Data replication failed: Node 2 data: $node2_data_full" fi else log_error "Cluster formation failed (N1 members: $node1_members, N2 members: $node2_members)" @@ -199,9 +218,12 @@ EOF sleep 2 } -# Test 4: Conflict resolution (simplified) +# Test 4: Conflict resolution (Merkle Tree based) +# This test assumes 'test_conflict.go' creates two BadgerDBs with a key +# that has the same path and timestamp but different UUIDs, or different timestamps +# but same path. The Merkle tree sync should then trigger conflict resolution. test_conflict_resolution() { - test_start "Conflict resolution test" + test_start "Conflict resolution test (Merkle Tree based)" # Create conflicting data using our utility rm -rf conflict1_data conflict2_data 2>/dev/null || true @@ -233,7 +255,8 @@ sync_interval: 8 EOF # Start nodes - $BINARY conflict1.yaml >conflict1.log 2>&1 & + # Node 1 started first, making it "older" for tie-breaker if timestamps are equal + "$BINARY" conflict1.yaml >conflict1.log 2>&1 & local pid1=$! if wait_for_service 8111; then @@ -242,26 +265,50 @@ EOF local pid2=$! 
if wait_for_service 8112; then - # Get initial data - local node1_initial=$(curl -s http://localhost:8111/kv/test/conflict/data | jq -r '.message' 2>/dev/null) - local node2_initial=$(curl -s http://localhost:8112/kv/test/conflict/data | jq -r '.message' 2>/dev/null) + # Get initial data (full StoredValue) + local node1_initial_full=$(curl -s http://localhost:8111/kv/test/conflict/data) + local node2_initial_full=$(curl -s http://localhost:8112/kv/test/conflict/data) - # Wait for conflict resolution - sleep 12 + local node1_initial_msg=$(echo "$node1_initial_full" | jq -r '.data.message' 2>/dev/null) + local node2_initial_msg=$(echo "$node2_initial_full" | jq -r '.data.message' 2>/dev/null) - # Get final data - local node1_final=$(curl -s http://localhost:8111/kv/test/conflict/data | jq -r '.message' 2>/dev/null) - local node2_final=$(curl -s http://localhost:8112/kv/test/conflict/data | jq -r '.message' 2>/dev/null) + log_info "Initial conflict state: Node1='$node1_initial_msg', Node2='$node2_initial_msg'" + + # Wait for conflict resolution (multiple sync cycles might be needed) + sleep 20 + + # Get final data (full StoredValue) + local node1_final_full=$(curl -s http://localhost:8111/kv/test/conflict/data) + local node2_final_full=$(curl -s http://localhost:8112/kv/test/conflict/data) + + local node1_final_msg=$(echo "$node1_final_full" | jq -r '.data.message' 2>/dev/null) + local node2_final_msg=$(echo "$node2_final_full" | jq -r '.data.message' 2>/dev/null) # Check if they converged - if [ "$node1_final" = "$node2_final" ] && [ -n "$node1_final" ]; then - if grep -q "conflict resolution" conflict1.log conflict2.log 2>/dev/null; then - log_success "Conflict resolution detected and resolved ($node1_initial vs $node2_initial → $node1_final)" + if [ "$node1_final_msg" = "$node2_final_msg" ] && [ -n "$node1_final_msg" ]; then + log_success "Conflict resolution converged to: '$node1_final_msg'" + + # Verify UUIDs and Timestamps are identical after resolution + local node1_final_uuid=$(echo "$node1_final_full" | jq -r '.uuid' 2>/dev/null) + local node1_final_timestamp=$(echo "$node1_final_full" | jq -r '.timestamp' 2>/dev/null) + local node2_final_uuid=$(echo "$node2_final_full" | jq -r '.uuid' 2>/dev/null) + local node2_final_timestamp=$(echo "$node2_final_full" | jq -r '.timestamp' 2>/dev/null) + + if [ "$node1_final_uuid" = "$node2_final_uuid" ] && [ "$node1_final_timestamp" = "$node2_final_timestamp" ]; then + log_success "Resolved data retains consistent UUID and Timestamp across nodes" else - log_success "Nodes converged without conflicts ($node1_final)" + log_error "Resolved data has inconsistent UUID/Timestamp: N1_UUID=$node1_final_uuid, N1_TS=$node1_final_timestamp, N2_UUID=$node2_final_uuid, N2_TS=$node2_final_timestamp" fi + + # Optionally, check logs for conflict resolution messages + if grep -q "Conflict resolved" conflict1.log conflict2.log 2>/dev/null; then + log_success "Conflict resolution messages found in logs" + else + log_error "No 'Conflict resolved' messages found in logs, but data converged." + fi + else - log_error "Conflict resolution failed: N1='$node1_final', N2='$node2_final'" + log_error "Conflict resolution failed: N1_final='$node1_final_msg', N2_final='$node2_final_msg'" fi else log_error "Conflict node 2 failed to start" @@ -276,14 +323,14 @@ EOF sleep 2 else cd "$TEST_DIR" - log_error "Failed to create conflict test data" + log_error "Failed to create conflict test data. Ensure test_conflict.go is correct." 
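The `.data.message`, `.uuid`, and `.timestamp` paths used by the assertions above follow from the GET handler change later in this patch, which returns the whole StoredValue envelope rather than the raw payload. A minimal client-side sketch of that shape (field values are illustrative; the struct mirrors the StoredValue fields referenced in main.go):

package main

import (
	"encoding/json"
	"fmt"
)

// Mirrors the StoredValue fields used throughout main.go; json tags inferred
// from the jq paths exercised by this test script.
type StoredValue struct {
	UUID      string          `json:"uuid"`
	Timestamp int64           `json:"timestamp"`
	Data      json.RawMessage `json:"data"`
}

func main() {
	// Example GET /kv/test/basic response body (illustrative values).
	body := []byte(`{"uuid":"example-uuid","timestamp":1700000000000,"data":{"message":"hello world"}}`)

	var sv StoredValue
	if err := json.Unmarshal(body, &sv); err != nil {
		panic(err)
	}

	var payload struct {
		Message string `json:"message"`
	}
	if err := json.Unmarshal(sv.Data, &payload); err != nil {
		panic(err)
	}

	// jq -r '.data.message' extracts payload.Message; '.uuid' and '.timestamp'
	// read the envelope fields compared across nodes above.
	fmt.Println(sv.UUID, sv.Timestamp, payload.Message)
}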
fi } # Main test execution main() { echo "==================================================" - echo " KVS Integration Test Suite" + echo " KVS Integration Test Suite (Merkle Tree)" echo "==================================================" # Setup @@ -308,7 +355,7 @@ main() { echo "==================================================" if [ $TESTS_FAILED -eq 0 ]; then - echo -e "${GREEN}🎉 All tests passed! KVS is working correctly.${NC}" + echo -e "${GREEN}🎉 All tests passed! KVS with Merkle Tree sync is working correctly.${NC}" cleanup exit 0 else @@ -322,4 +369,4 @@ main() { trap cleanup INT TERM # Run tests -main "$@" \ No newline at end of file +main "$@" diff --git a/main.go b/main.go index a888549..ea5b07e 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,8 @@ package main import ( "bytes" "context" + "crypto/sha256" // Added for Merkle Tree hashing + "encoding/hex" // Added for logging hashes "encoding/json" "fmt" "math/rand" @@ -11,6 +13,7 @@ import ( "os" "os/signal" "path/filepath" + "sort" // Added for sorting keys in Merkle Tree "strconv" "strings" "sync" @@ -66,37 +69,77 @@ type PutResponse struct { Timestamp int64 `json:"timestamp"` } +// Merkle Tree specific data structures +type MerkleNode struct { + Hash []byte `json:"hash"` + StartKey string `json:"start_key"` // The first key in this node's range + EndKey string `json:"end_key"` // The last key in this node's range +} + +// MerkleRootResponse is the response for getting the root hash +type MerkleRootResponse struct { + Root *MerkleNode `json:"root"` +} + +// MerkleTreeDiffRequest is used to request children hashes for a given key range +type MerkleTreeDiffRequest struct { + ParentNode MerkleNode `json:"parent_node"` // The node whose children we want to compare (from the remote peer's perspective) + LocalHash []byte `json:"local_hash"` // The local hash of this node/range (from the requesting peer's perspective) +} + +// MerkleTreeDiffResponse returns the remote children nodes or the actual keys if it's a leaf level +type MerkleTreeDiffResponse struct { + Children []MerkleNode `json:"children,omitempty"` // Children of the remote node + Keys []string `json:"keys,omitempty"` // Actual keys if this is a leaf-level diff +} + +// For fetching a range of KV pairs +type KVRangeRequest struct { + StartKey string `json:"start_key"` + EndKey string `json:"end_key"` + Limit int `json:"limit"` // Max number of items to return +} + +type KVRangeResponse struct { + Pairs []struct { + Path string `json:"path"` + StoredValue StoredValue `json:"stored_value"` + } `json:"pairs"` +} + // Configuration type Config struct { - NodeID string `yaml:"node_id"` - BindAddress string `yaml:"bind_address"` - Port int `yaml:"port"` - DataDir string `yaml:"data_dir"` - SeedNodes []string `yaml:"seed_nodes"` - ReadOnly bool `yaml:"read_only"` - LogLevel string `yaml:"log_level"` - GossipIntervalMin int `yaml:"gossip_interval_min"` - GossipIntervalMax int `yaml:"gossip_interval_max"` - SyncInterval int `yaml:"sync_interval"` - CatchupInterval int `yaml:"catchup_interval"` - BootstrapMaxAgeHours int `yaml:"bootstrap_max_age_hours"` - ThrottleDelayMs int `yaml:"throttle_delay_ms"` - FetchDelayMs int `yaml:"fetch_delay_ms"` + NodeID string `yaml:"node_id"` + BindAddress string `yaml:"bind_address"` + Port int `yaml:"port"` + DataDir string `yaml:"data_dir"` + SeedNodes []string `yaml:"seed_nodes"` + ReadOnly bool `yaml:"read_only"` + LogLevel string `yaml:"log_level"` + GossipIntervalMin int `yaml:"gossip_interval_min"` + GossipIntervalMax int 
`yaml:"gossip_interval_max"` + SyncInterval int `yaml:"sync_interval"` + CatchupInterval int `yaml:"catchup_interval"` + BootstrapMaxAgeHours int `yaml:"bootstrap_max_age_hours"` + ThrottleDelayMs int `yaml:"throttle_delay_ms"` + FetchDelayMs int `yaml:"fetch_delay_ms"` } // Server represents the KVS node type Server struct { - config *Config - db *badger.DB - members map[string]*Member - membersMu sync.RWMutex - mode string // "normal", "read-only", "syncing" - modeMu sync.RWMutex - logger *logrus.Logger - httpServer *http.Server - ctx context.Context - cancel context.CancelFunc - wg sync.WaitGroup + config *Config + db *badger.DB + members map[string]*Member + membersMu sync.RWMutex + mode string // "normal", "read-only", "syncing" + modeMu sync.RWMutex + logger *logrus.Logger + httpServer *http.Server + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + merkleRoot *MerkleNode // Added for Merkle Tree + merkleRootMu sync.RWMutex // Protects merkleRoot } // Default configuration @@ -123,35 +166,35 @@ func defaultConfig() *Config { // Load configuration from file or create default func loadConfig(configPath string) (*Config, error) { config := defaultConfig() - + if _, err := os.Stat(configPath); os.IsNotExist(err) { // Create default config file if err := os.MkdirAll(filepath.Dir(configPath), 0755); err != nil { return nil, fmt.Errorf("failed to create config directory: %v", err) } - + data, err := yaml.Marshal(config) if err != nil { return nil, fmt.Errorf("failed to marshal default config: %v", err) } - + if err := os.WriteFile(configPath, data, 0644); err != nil { return nil, fmt.Errorf("failed to write default config: %v", err) } - + fmt.Printf("Created default configuration at %s\n", configPath) return config, nil } - + data, err := os.ReadFile(configPath) if err != nil { return nil, fmt.Errorf("failed to read config file: %v", err) } - + if err := yaml.Unmarshal(data, config); err != nil { return nil, fmt.Errorf("failed to parse config file: %v", err) } - + return config, nil } @@ -159,18 +202,18 @@ func loadConfig(configPath string) (*Config, error) { func NewServer(config *Config) (*Server, error) { logger := logrus.New() logger.SetFormatter(&logrus.JSONFormatter{}) - + level, err := logrus.ParseLevel(config.LogLevel) if err != nil { level = logrus.InfoLevel } logger.SetLevel(level) - + // Create data directory if err := os.MkdirAll(config.DataDir, 0755); err != nil { return nil, fmt.Errorf("failed to create data directory: %v", err) } - + // Open BadgerDB opts := badger.DefaultOptions(filepath.Join(config.DataDir, "badger")) opts.Logger = nil // Disable badger's internal logging @@ -178,9 +221,9 @@ func NewServer(config *Config) (*Server, error) { if err != nil { return nil, fmt.Errorf("failed to open BadgerDB: %v", err) } - + ctx, cancel := context.WithCancel(context.Background()) - + server := &Server{ config: config, db: db, @@ -190,11 +233,23 @@ func NewServer(config *Config) (*Server, error) { ctx: ctx, cancel: cancel, } - + if config.ReadOnly { server.setMode("read-only") } - + + // Build initial Merkle tree + pairs, err := server.getAllKVPairsForMerkleTree() + if err != nil { + return nil, fmt.Errorf("failed to get all KV pairs for initial Merkle tree: %v", err) + } + root, err := server.buildMerkleTreeFromPairs(pairs) + if err != nil { + return nil, fmt.Errorf("failed to build initial Merkle tree: %v", err) + } + server.setMerkleRoot(root) + server.logger.Info("Initial Merkle tree built.") + return server, nil } @@ -253,14 +308,14 @@ func (s *Server) 
getMembers() []*Member { func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) { mode := s.getMode() memberCount := len(s.getMembers()) - + health := map[string]interface{}{ "status": "ok", "mode": mode, "member_count": memberCount, "node_id": s.config.NodeID, } - + w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(health) } @@ -268,19 +323,19 @@ func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) { func (s *Server) getKVHandler(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) path := vars["path"] - + var storedValue StoredValue err := s.db.View(func(txn *badger.Txn) error { item, err := txn.Get([]byte(path)) if err != nil { return err } - + return item.Value(func(val []byte) error { return json.Unmarshal(val, &storedValue) }) }) - + if err == badger.ErrKeyNotFound { http.Error(w, "Not Found", http.StatusNotFound) return @@ -290,84 +345,85 @@ func (s *Server) getKVHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } - + w.Header().Set("Content-Type", "application/json") - w.Write(storedValue.Data) + // CHANGE: Return the entire StoredValue, not just Data + json.NewEncoder(w).Encode(storedValue) } func (s *Server) putKVHandler(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) path := vars["path"] - + mode := s.getMode() if mode == "syncing" { http.Error(w, "Service Unavailable", http.StatusServiceUnavailable) return } - + if mode == "read-only" && !s.isClusterMember(r.RemoteAddr) { http.Error(w, "Forbidden", http.StatusForbidden) return } - + var data json.RawMessage if err := json.NewDecoder(r.Body).Decode(&data); err != nil { http.Error(w, "Bad Request", http.StatusBadRequest) return } - + now := time.Now().UnixMilli() newUUID := uuid.New().String() - + storedValue := StoredValue{ UUID: newUUID, Timestamp: now, Data: data, } - + valueBytes, err := json.Marshal(storedValue) if err != nil { s.logger.WithError(err).Error("Failed to marshal stored value") http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } - + var isUpdate bool err = s.db.Update(func(txn *badger.Txn) error { // Check if key exists _, err := txn.Get([]byte(path)) isUpdate = (err == nil) - + // Store main data if err := txn.Set([]byte(path), valueBytes); err != nil { return err } - + // Store timestamp index indexKey := fmt.Sprintf("_ts:%020d:%s", now, path) return txn.Set([]byte(indexKey), []byte(newUUID)) }) - + if err != nil { s.logger.WithError(err).WithField("path", path).Error("Failed to store value") http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } - + response := PutResponse{ UUID: newUUID, Timestamp: now, } - + status := http.StatusCreated if isUpdate { status = http.StatusOK } - + w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) json.NewEncoder(w).Encode(response) - + s.logger.WithFields(logrus.Fields{ "path": path, "uuid": newUUID, @@ -379,18 +435,18 @@ func (s *Server) putKVHandler(w http.ResponseWriter, r *http.Request) { func (s *Server) deleteKVHandler(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) path := vars["path"] - + mode := s.getMode() if mode == "syncing" { http.Error(w, "Service Unavailable", http.StatusServiceUnavailable) return } - + if mode == "read-only" && !s.isClusterMember(r.RemoteAddr) { http.Error(w, "Forbidden", http.StatusForbidden) return } - + var found bool err := s.db.Update(func(txn *badger.Txn) error { // Check if key exists and 
get timestamp for index cleanup @@ -402,7 +458,7 @@ func (s *Server) deleteKVHandler(w http.ResponseWriter, r *http.Request) { return err } found = true - + var storedValue StoredValue err = item.Value(func(val []byte) error { return json.Unmarshal(val, &storedValue) @@ -410,36 +466,36 @@ func (s *Server) deleteKVHandler(w http.ResponseWriter, r *http.Request) { if err != nil { return err } - + // Delete main data if err := txn.Delete([]byte(path)); err != nil { return err } - + // Delete timestamp index indexKey := fmt.Sprintf("_ts:%020d:%s", storedValue.Timestamp, path) return txn.Delete([]byte(indexKey)) }) - + if err != nil { s.logger.WithError(err).WithField("path", path).Error("Failed to delete value") http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } - + if !found { http.Error(w, "Not Found", http.StatusNotFound) return } - + w.WriteHeader(http.StatusNoContent) - + s.logger.WithField("path", path).Info("Value deleted") } func (s *Server) getMembersHandler(w http.ResponseWriter, r *http.Request) { members := s.getMembers() - + w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(members) } @@ -450,7 +506,7 @@ func (s *Server) joinMemberHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "Bad Request", http.StatusBadRequest) return } - + now := time.Now().UnixMilli() member := &Member{ ID: req.ID, @@ -458,9 +514,9 @@ func (s *Server) joinMemberHandler(w http.ResponseWriter, r *http.Request) { LastSeen: now, JoinedTimestamp: req.JoinedTimestamp, } - + s.addMember(member) - + // Return current member list members := s.getMembers() w.Header().Set("Content-Type", "application/json") @@ -473,7 +529,7 @@ func (s *Server) leaveMemberHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "Bad Request", http.StatusBadRequest) return } - + s.removeMember(req.ID) w.WriteHeader(http.StatusNoContent) } @@ -484,40 +540,40 @@ func (s *Server) pairsByTimeHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "Bad Request", http.StatusBadRequest) return } - + // Default limit to 15 as per spec if req.Limit <= 0 { req.Limit = 15 } - + var pairs []PairsByTimeResponse - + err := s.db.View(func(txn *badger.Txn) error { opts := badger.DefaultIteratorOptions opts.PrefetchSize = req.Limit it := txn.NewIterator(opts) defer it.Close() - + prefix := []byte("_ts:") - if req.Prefix != "" { - // We need to scan through timestamp entries and filter by path prefix - } + // The original logic for prefix filtering here was incomplete. + // For Merkle tree sync, this handler is no longer used for core sync. + // It remains as a client-facing API. 
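The loop below walks the `_ts:` index written by the PUT handler. As a standalone sketch (timestamp and path are illustrative), this is how those keys are built and split back apart; the zero-padded width keeps lexicographic order equal to numeric order, and SplitN with a limit of 3 leaves any ':' inside the path intact:

package main

import (
	"fmt"
	"strconv"
	"strings"
)

func main() {
	path := "cluster/test"
	ts := int64(1700000000000)

	// Written by putKVHandler (and by the replication path later in this patch).
	indexKey := fmt.Sprintf("_ts:%020d:%s", ts, path)
	fmt.Println(indexKey) // _ts:00000001700000000000:cluster/test

	// Parsed the same way as in the loop below.
	parts := strings.SplitN(indexKey, ":", 3)
	parsedTs, err := strconv.ParseInt(parts[1], 10, 64)
	if err != nil {
		panic(err)
	}
	fmt.Println(parsedTs, parts[2]) // 1700000000000 cluster/test
}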
for it.Seek(prefix); it.ValidForPrefix(prefix) && len(pairs) < req.Limit; it.Next() { item := it.Item() key := string(item.Key()) - + // Parse timestamp index key: "_ts:{timestamp}:{path}" parts := strings.SplitN(key, ":", 3) if len(parts) != 3 { continue } - + timestamp, err := strconv.ParseInt(parts[1], 10, 64) if err != nil { continue } - + // Filter by timestamp range if req.StartTimestamp > 0 && timestamp < req.StartTimestamp { continue @@ -525,14 +581,14 @@ func (s *Server) pairsByTimeHandler(w http.ResponseWriter, r *http.Request) { if req.EndTimestamp > 0 && timestamp >= req.EndTimestamp { continue } - + path := parts[2] - + // Filter by prefix if specified if req.Prefix != "" && !strings.HasPrefix(path, req.Prefix) { continue } - + var uuid string err = item.Value(func(val []byte) error { uuid = string(val) @@ -541,28 +597,28 @@ func (s *Server) pairsByTimeHandler(w http.ResponseWriter, r *http.Request) { if err != nil { continue } - + pairs = append(pairs, PairsByTimeResponse{ Path: path, UUID: uuid, Timestamp: timestamp, }) } - + return nil }) - + if err != nil { s.logger.WithError(err).Error("Failed to query pairs by time") http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } - + if len(pairs) == 0 { w.WriteHeader(http.StatusNoContent) return } - + w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(pairs) } @@ -573,17 +629,17 @@ func (s *Server) gossipHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "Bad Request", http.StatusBadRequest) return } - + // Merge the received member list s.mergeMemberList(remoteMemberList) - + // Respond with our current member list localMembers := s.getMembers() gossipResponse := make([]Member, len(localMembers)) for i, member := range localMembers { gossipResponse[i] = *member } - + // Add ourselves to the response selfMember := Member{ ID: s.config.NodeID, @@ -592,10 +648,10 @@ func (s *Server) gossipHandler(w http.ResponseWriter, r *http.Request) { JoinedTimestamp: s.getJoinedTimestamp(), } gossipResponse = append(gossipResponse, selfMember) - + w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(gossipResponse) - + s.logger.WithField("remote_members", len(remoteMemberList)).Debug("Processed gossip request") } @@ -605,110 +661,119 @@ func (s *Server) isClusterMember(remoteAddr string) bool { if err != nil { return false } - + s.membersMu.RLock() defer s.membersMu.RUnlock() - + for _, member := range s.members { memberHost, _, err := net.SplitHostPort(member.Address) if err == nil && memberHost == host { return true } } - + return false } // Setup HTTP routes func (s *Server) setupRoutes() *mux.Router { router := mux.NewRouter() - + // Health endpoint router.HandleFunc("/health", s.healthHandler).Methods("GET") - + // KV endpoints router.HandleFunc("/kv/{path:.+}", s.getKVHandler).Methods("GET") router.HandleFunc("/kv/{path:.+}", s.putKVHandler).Methods("PUT") router.HandleFunc("/kv/{path:.+}", s.deleteKVHandler).Methods("DELETE") - + // Member endpoints router.HandleFunc("/members/", s.getMembersHandler).Methods("GET") router.HandleFunc("/members/join", s.joinMemberHandler).Methods("POST") router.HandleFunc("/members/leave", s.leaveMemberHandler).Methods("DELETE") router.HandleFunc("/members/gossip", s.gossipHandler).Methods("POST") - router.HandleFunc("/members/pairs_by_time", s.pairsByTimeHandler).Methods("POST") - + router.HandleFunc("/members/pairs_by_time", s.pairsByTimeHandler).Methods("POST") // Still available for clients + + // Merkle Tree 
endpoints + router.HandleFunc("/merkle_tree/root", s.getMerkleRootHandler).Methods("GET") + router.HandleFunc("/merkle_tree/diff", s.getMerkleDiffHandler).Methods("POST") + router.HandleFunc("/kv_range", s.getKVRangeHandler).Methods("POST") // New endpoint for fetching ranges + return router } // Start the server func (s *Server) Start() error { router := s.setupRoutes() - + addr := fmt.Sprintf("%s:%d", s.config.BindAddress, s.config.Port) s.httpServer = &http.Server{ Addr: addr, Handler: router, } - + s.logger.WithFields(logrus.Fields{ "node_id": s.config.NodeID, "address": addr, }).Info("Starting KVS server") - + // Start gossip and sync routines s.startBackgroundTasks() - + // Try to join cluster if seed nodes are configured if len(s.config.SeedNodes) > 0 { go s.bootstrap() } - + return s.httpServer.ListenAndServe() } // Stop the server gracefully func (s *Server) Stop() error { s.logger.Info("Shutting down KVS server") - + s.cancel() s.wg.Wait() - + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - + if err := s.httpServer.Shutdown(ctx); err != nil { s.logger.WithError(err).Error("HTTP server shutdown error") } - + if err := s.db.Close(); err != nil { s.logger.WithError(err).Error("BadgerDB close error") } - + return nil } -// Background tasks (gossip, sync, etc.) +// Background tasks (gossip, sync, Merkle tree rebuild, etc.) func (s *Server) startBackgroundTasks() { // Start gossip routine s.wg.Add(1) go s.gossipRoutine() - - // Start sync routine + + // Start sync routine (now Merkle-based) s.wg.Add(1) go s.syncRoutine() + + // Start Merkle tree rebuild routine + s.wg.Add(1) + go s.merkleTreeRebuildRoutine() } // Gossip routine - runs periodically to exchange member lists func (s *Server) gossipRoutine() { defer s.wg.Done() - + for { // Random interval between 1-2 minutes minInterval := time.Duration(s.config.GossipIntervalMin) * time.Second maxInterval := time.Duration(s.config.GossipIntervalMax) * time.Second interval := minInterval + time.Duration(rand.Int63n(int64(maxInterval-minInterval))) - + select { case <-s.ctx.Done(): return @@ -725,20 +790,20 @@ func (s *Server) performGossipRound() { s.logger.Debug("No healthy members for gossip round") return } - + // Select 1-3 random peers for gossip maxPeers := 3 if len(members) < maxPeers { maxPeers = len(members) } - + // Shuffle and select rand.Shuffle(len(members), func(i, j int) { members[i], members[j] = members[j], members[i] }) - + selectedPeers := members[:rand.Intn(maxPeers)+1] - + for _, peer := range selectedPeers { go s.gossipWithPeer(peer) } @@ -747,16 +812,16 @@ func (s *Server) performGossipRound() { // Gossip with a specific peer func (s *Server) gossipWithPeer(peer *Member) { s.logger.WithField("peer", peer.Address).Debug("Starting gossip with peer") - + // Get our current member list localMembers := s.getMembers() - + // Send our member list to the peer gossipData := make([]Member, len(localMembers)) for i, member := range localMembers { gossipData[i] = *member } - + // Add ourselves to the list selfMember := Member{ ID: s.config.NodeID, @@ -765,17 +830,17 @@ func (s *Server) gossipWithPeer(peer *Member) { JoinedTimestamp: s.getJoinedTimestamp(), } gossipData = append(gossipData, selfMember) - + jsonData, err := json.Marshal(gossipData) if err != nil { s.logger.WithError(err).Error("Failed to marshal gossip data") return } - + // Send HTTP request to peer client := &http.Client{Timeout: 5 * time.Second} url := fmt.Sprintf("http://%s/members/gossip", peer.Address) - + resp, err 
:= client.Post(url, "application/json", bytes.NewBuffer(jsonData)) if err != nil { s.logger.WithFields(logrus.Fields{ @@ -786,7 +851,7 @@ func (s *Server) gossipWithPeer(peer *Member) { return } defer resp.Body.Close() - + if resp.StatusCode != http.StatusOK { s.logger.WithFields(logrus.Fields{ "peer": peer.Address, @@ -795,20 +860,20 @@ func (s *Server) gossipWithPeer(peer *Member) { s.markPeerUnhealthy(peer.ID) return } - + // Process response - peer's member list var remoteMemberList []Member if err := json.NewDecoder(resp.Body).Decode(&remoteMemberList); err != nil { s.logger.WithError(err).Error("Failed to decode gossip response") return } - + // Merge remote member list with our local list s.mergeMemberList(remoteMemberList) - + // Update peer's last seen timestamp s.updateMemberLastSeen(peer.ID, time.Now().UnixMilli()) - + s.logger.WithField("peer", peer.Address).Debug("Completed gossip with peer") } @@ -816,17 +881,17 @@ func (s *Server) gossipWithPeer(peer *Member) { func (s *Server) getHealthyMembers() []*Member { s.membersMu.RLock() defer s.membersMu.RUnlock() - + now := time.Now().UnixMilli() healthyMembers := make([]*Member, 0) - + for _, member := range s.members { // Consider member healthy if last seen within last 5 minutes if now-member.LastSeen < 5*60*1000 { healthyMembers = append(healthyMembers, member) } } - + return healthyMembers } @@ -834,7 +899,7 @@ func (s *Server) getHealthyMembers() []*Member { func (s *Server) markPeerUnhealthy(nodeID string) { s.membersMu.Lock() defer s.membersMu.Unlock() - + if member, exists := s.members[nodeID]; exists { // Mark as last seen a long time ago to indicate unhealthy member.LastSeen = time.Now().UnixMilli() - 10*60*1000 // 10 minutes ago @@ -846,7 +911,7 @@ func (s *Server) markPeerUnhealthy(nodeID string) { func (s *Server) updateMemberLastSeen(nodeID string, timestamp int64) { s.membersMu.Lock() defer s.membersMu.Unlock() - + if member, exists := s.members[nodeID]; exists { member.LastSeen = timestamp } @@ -856,15 +921,15 @@ func (s *Server) updateMemberLastSeen(nodeID string, timestamp int64) { func (s *Server) mergeMemberList(remoteMembers []Member) { s.membersMu.Lock() defer s.membersMu.Unlock() - + now := time.Now().UnixMilli() - + for _, remoteMember := range remoteMembers { // Skip ourselves if remoteMember.ID == s.config.NodeID { continue } - + if localMember, exists := s.members[remoteMember.ID]; exists { // Update existing member if remoteMember.LastSeen > localMember.LastSeen { @@ -889,7 +954,7 @@ func (s *Server) mergeMemberList(remoteMembers []Member) { }).Info("Discovered new member through gossip") } } - + // Clean up old members (not seen for more than 10 minutes) toRemove := make([]string, 0) for nodeID, member := range s.members { @@ -897,7 +962,7 @@ func (s *Server) mergeMemberList(remoteMembers []Member) { toRemove = append(toRemove, nodeID) } } - + for _, nodeID := range toRemove { delete(s.members, nodeID) s.logger.WithField("node_id", nodeID).Info("Removed stale member") @@ -913,222 +978,699 @@ func (s *Server) getJoinedTimestamp() int64 { // Sync routine - handles regular and catch-up syncing func (s *Server) syncRoutine() { defer s.wg.Done() - + syncTicker := time.NewTicker(time.Duration(s.config.SyncInterval) * time.Second) defer syncTicker.Stop() - + for { select { case <-s.ctx.Done(): return case <-syncTicker.C: - s.performRegularSync() + s.performMerkleSync() // Use Merkle sync instead of regular sync } } } -// Perform regular 5-minute sync -func (s *Server) performRegularSync() { - members := 
s.getHealthyMembers() - if len(members) == 0 { - s.logger.Debug("No healthy members for sync") +// Merkle Tree Implementation + +// calculateHash generates a SHA256 hash for a given byte slice +func calculateHash(data []byte) []byte { + h := sha256.New() + h.Write(data) + return h.Sum(nil) +} + +// calculateLeafHash generates a hash for a leaf node based on its path, UUID, timestamp, and data +func (s *Server) calculateLeafHash(path string, storedValue *StoredValue) []byte { + // Concatenate path, UUID, timestamp, and the raw data bytes for hashing + // Ensure a consistent order of fields for hashing + dataToHash := bytes.Buffer{} + dataToHash.WriteString(path) + dataToHash.WriteByte(':') + dataToHash.WriteString(storedValue.UUID) + dataToHash.WriteByte(':') + dataToHash.WriteString(strconv.FormatInt(storedValue.Timestamp, 10)) + dataToHash.WriteByte(':') + dataToHash.Write(storedValue.Data) // Use raw bytes of json.RawMessage + + return calculateHash(dataToHash.Bytes()) +} + +// getAllKVPairsForMerkleTree retrieves all key-value pairs needed for Merkle tree construction. +func (s *Server) getAllKVPairsForMerkleTree() (map[string]*StoredValue, error) { + pairs := make(map[string]*StoredValue) + err := s.db.View(func(txn *badger.Txn) error { + opts := badger.DefaultIteratorOptions + opts.PrefetchValues = true // We need the values for hashing + it := txn.NewIterator(opts) + defer it.Close() + + // Iterate over all actual data keys (not _ts: indexes) + for it.Rewind(); it.Valid(); it.Next() { + item := it.Item() + key := string(item.Key()) + + if strings.HasPrefix(key, "_ts:") { + continue // Skip index keys + } + + var storedValue StoredValue + err := item.Value(func(val []byte) error { + return json.Unmarshal(val, &storedValue) + }) + if err != nil { + s.logger.WithError(err).WithField("key", key).Warn("Failed to unmarshal stored value for Merkle tree, skipping") + continue + } + pairs[key] = &storedValue + } + return nil + }) + if err != nil { + return nil, err + } + return pairs, nil +} + +// buildMerkleTreeFromPairs constructs a Merkle Tree from the KVS data. +// This version uses a recursive approach to build a balanced tree from sorted keys. +func (s *Server) buildMerkleTreeFromPairs(pairs map[string]*StoredValue) (*MerkleNode, error) { + if len(pairs) == 0 { + return &MerkleNode{Hash: calculateHash([]byte("empty_tree")), StartKey: "", EndKey: ""}, nil + } + + // Sort keys to ensure consistent tree structure + keys := make([]string, 0, len(pairs)) + for k := range pairs { + keys = append(keys, k) + } + sort.Strings(keys) + + // Create leaf nodes + leafNodes := make([]*MerkleNode, len(keys)) + for i, key := range keys { + storedValue := pairs[key] + hash := s.calculateLeafHash(key, storedValue) + leafNodes[i] = &MerkleNode{Hash: hash, StartKey: key, EndKey: key} + } + + // Recursively build parent nodes + return s.buildMerkleTreeRecursive(leafNodes) +} + +// buildMerkleTreeRecursive builds the tree from a slice of nodes. 
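Before the implementation, its combination rule as a standalone sketch: sibling hashes are concatenated left-to-right and re-hashed (equivalent to calculateHash above), and a trailing odd node is promoted unchanged, so three leaves reduce to hash(hash(a||b) || c):

package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// Equivalent to calculateHash above.
func h(data []byte) []byte {
	sum := sha256.Sum256(data)
	return sum[:]
}

func main() {
	// Three already-sorted leaf hashes (contents illustrative).
	leaves := [][]byte{h([]byte("a")), h([]byte("b")), h([]byte("c"))}

	// First pass: pair (a,b) is combined, the odd node c is promoted as-is.
	ab := h(append(append([]byte{}, leaves[0]...), leaves[1]...))
	c := leaves[2]

	// Second pass: the two remaining nodes combine into the root.
	root := h(append(append([]byte{}, ab...), c...))
	fmt.Println("root:", hex.EncodeToString(root))
}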
+func (s *Server) buildMerkleTreeRecursive(nodes []*MerkleNode) (*MerkleNode, error) { + if len(nodes) == 0 { + return nil, nil + } + if len(nodes) == 1 { + return nodes[0], nil + } + + var nextLevel []*MerkleNode + for i := 0; i < len(nodes); i += 2 { + left := nodes[i] + var right *MerkleNode + if i+1 < len(nodes) { + right = nodes[i+1] + } + + var combinedHash []byte + var endKey string + + if right != nil { + combinedHash = calculateHash(append(left.Hash, right.Hash...)) + endKey = right.EndKey + } else { + // Odd number of nodes, promote the left node + combinedHash = left.Hash + endKey = left.EndKey + } + + parentNode := &MerkleNode{ + Hash: combinedHash, + StartKey: left.StartKey, + EndKey: endKey, + } + nextLevel = append(nextLevel, parentNode) + } + return s.buildMerkleTreeRecursive(nextLevel) +} + +// getMerkleRoot returns the current Merkle root of the server. +func (s *Server) getMerkleRoot() *MerkleNode { + s.merkleRootMu.RLock() + defer s.merkleRootMu.RUnlock() + return s.merkleRoot +} + +// setMerkleRoot sets the current Merkle root of the server. +func (s *Server) setMerkleRoot(root *MerkleNode) { + s.merkleRootMu.Lock() + defer s.merkleRootMu.Unlock() + s.merkleRoot = root +} + +// merkleTreeRebuildRoutine periodically rebuilds the Merkle tree. +func (s *Server) merkleTreeRebuildRoutine() { + defer s.wg.Done() + ticker := time.NewTicker(time.Duration(s.config.SyncInterval) * time.Second) // Use sync interval for now + defer ticker.Stop() + + for { + select { + case <-s.ctx.Done(): + return + case <-ticker.C: + s.logger.Debug("Rebuilding Merkle tree...") + pairs, err := s.getAllKVPairsForMerkleTree() + if err != nil { + s.logger.WithError(err).Error("Failed to get KV pairs for Merkle tree rebuild") + continue + } + newRoot, err := s.buildMerkleTreeFromPairs(pairs) + if err != nil { + s.logger.WithError(err).Error("Failed to rebuild Merkle tree") + continue + } + s.setMerkleRoot(newRoot) + s.logger.Debug("Merkle tree rebuilt.") + } + } +} + +// getMerkleRootHandler returns the root hash of the local Merkle Tree +func (s *Server) getMerkleRootHandler(w http.ResponseWriter, r *http.Request) { + root := s.getMerkleRoot() + if root == nil { + http.Error(w, "Merkle tree not initialized", http.StatusInternalServerError) return } - + + resp := MerkleRootResponse{ + Root: root, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} + +// getMerkleDiffHandler is used by a peer to request children hashes for a given node/range. 
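To make the exchange concrete, a sketch of what travels over the wire for this endpoint, using copies of the request/response types declared near the top of the file (hash bytes and keys are illustrative; note that []byte hashes are base64-encoded by encoding/json):

package main

import (
	"encoding/json"
	"fmt"
)

// Copies of the types declared earlier in this file.
type MerkleNode struct {
	Hash     []byte `json:"hash"`
	StartKey string `json:"start_key"`
	EndKey   string `json:"end_key"`
}

type MerkleTreeDiffRequest struct {
	ParentNode MerkleNode `json:"parent_node"`
	LocalHash  []byte     `json:"local_hash"`
}

type MerkleTreeDiffResponse struct {
	Children []MerkleNode `json:"children,omitempty"`
	Keys     []string     `json:"keys,omitempty"`
}

func main() {
	// What a requester POSTs to /merkle_tree/diff for a divergent range.
	req := MerkleTreeDiffRequest{
		ParentNode: MerkleNode{Hash: []byte{0xab, 0xcd}, StartKey: "cluster/a", EndKey: "test/z"},
		LocalHash:  []byte{0x12, 0x34},
	}
	body, _ := json.Marshal(req)
	fmt.Println(string(body))

	// Two possible answers: children to descend into, or keys for a leaf-level diff.
	descend := MerkleTreeDiffResponse{Children: []MerkleNode{{Hash: []byte{0x56}, StartKey: "cluster/a", EndKey: "kv/m"}}}
	leaf := MerkleTreeDiffResponse{Keys: []string{"cluster/test", "test/basic"}}

	out1, _ := json.Marshal(descend)
	out2, _ := json.Marshal(leaf)
	fmt.Println(string(out1))
	fmt.Println(string(out2))
}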
+func (s *Server) getMerkleDiffHandler(w http.ResponseWriter, r *http.Request) { + var req MerkleTreeDiffRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + + localPairs, err := s.getAllKVPairsForMerkleTree() + if err != nil { + s.logger.WithError(err).Error("Failed to get KV pairs for Merkle diff") + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + // Build the local MerkleNode for the requested range to compare with the remote's hash + localSubTreeRoot, err := s.buildMerkleTreeFromPairs(s.filterPairsByRange(localPairs, req.ParentNode.StartKey, req.ParentNode.EndKey)) + if err != nil { + s.logger.WithError(err).Error("Failed to build sub-Merkle tree for diff request") + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + if localSubTreeRoot == nil { // This can happen if the range is empty locally + localSubTreeRoot = &MerkleNode{Hash: calculateHash([]byte("empty_tree")), StartKey: req.ParentNode.StartKey, EndKey: req.ParentNode.EndKey} + } + + resp := MerkleTreeDiffResponse{} + + // If hashes match, no need to send children or keys + if bytes.Equal(req.LocalHash, localSubTreeRoot.Hash) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + return + } + + // Hashes differ, so we need to provide more detail. + // Get all keys within the parent node's range locally + var keysInRange []string + for key := range s.filterPairsByRange(localPairs, req.ParentNode.StartKey, req.ParentNode.EndKey) { + keysInRange = append(keysInRange, key) + } + sort.Strings(keysInRange) + + const diffLeafThreshold = 10 // If a range has <= 10 keys, we consider it a leaf-level diff + + if len(keysInRange) <= diffLeafThreshold { + // This is a leaf-level diff, return the actual keys in the range + resp.Keys = keysInRange + } else { + // Group keys into sub-ranges and return their MerkleNode representations + // For simplicity, let's split the range into two halves. 
+ mid := len(keysInRange) / 2 + + leftKeys := keysInRange[:mid] + rightKeys := keysInRange[mid:] + + if len(leftKeys) > 0 { + leftRangePairs := s.filterPairsByRange(localPairs, leftKeys[0], leftKeys[len(leftKeys)-1]) + leftNode, err := s.buildMerkleTreeFromPairs(leftRangePairs) + if err != nil { + s.logger.WithError(err).Error("Failed to build left child node for diff") + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + if leftNode != nil { + resp.Children = append(resp.Children, *leftNode) + } + } + + if len(rightKeys) > 0 { + rightRangePairs := s.filterPairsByRange(localPairs, rightKeys[0], rightKeys[len(rightKeys)-1]) + rightNode, err := s.buildMerkleTreeFromPairs(rightRangePairs) + if err != nil { + s.logger.WithError(err).Error("Failed to build right child node for diff") + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + if rightNode != nil { + resp.Children = append(resp.Children, *rightNode) + } + } + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} + +// Helper to filter a map of StoredValue by key range +func (s *Server) filterPairsByRange(allPairs map[string]*StoredValue, startKey, endKey string) map[string]*StoredValue { + filtered := make(map[string]*StoredValue) + for key, value := range allPairs { + if (startKey == "" || key >= startKey) && (endKey == "" || key <= endKey) { + filtered[key] = value + } + } + return filtered +} + +// getKVRangeHandler fetches a range of KV pairs +func (s *Server) getKVRangeHandler(w http.ResponseWriter, r *http.Request) { + var req KVRangeRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + + var pairs []struct { + Path string `json:"path"` + StoredValue StoredValue `json:"stored_value"` + } + + err := s.db.View(func(txn *badger.Txn) error { + opts := badger.DefaultIteratorOptions + opts.PrefetchValues = true + it := txn.NewIterator(opts) + defer it.Close() + + count := 0 + // Start iteration from the requested StartKey + for it.Seek([]byte(req.StartKey)); it.Valid(); it.Next() { + item := it.Item() + key := string(item.Key()) + + if strings.HasPrefix(key, "_ts:") { + continue // Skip index keys + } + + // Stop if we exceed the EndKey (if provided) + if req.EndKey != "" && key > req.EndKey { + break + } + + // Stop if we hit the limit (if provided) + if req.Limit > 0 && count >= req.Limit { + break + } + + var storedValue StoredValue + err := item.Value(func(val []byte) error { + return json.Unmarshal(val, &storedValue) + }) + if err != nil { + s.logger.WithError(err).WithField("key", key).Warn("Failed to unmarshal stored value in KV range, skipping") + continue + } + + pairs = append(pairs, struct { + Path string `json:"path"` + StoredValue StoredValue `json:"stored_value"` + }{Path: key, StoredValue: storedValue}) + count++ + } + return nil + }) + + if err != nil { + s.logger.WithError(err).Error("Failed to query KV range") + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(KVRangeResponse{Pairs: pairs}) +} + +// performMerkleSync performs a synchronization round using Merkle Trees +func (s *Server) performMerkleSync() { + members := s.getHealthyMembers() + if len(members) == 0 { + s.logger.Debug("No healthy members for Merkle sync") + return + } + // Select random peer peer := members[rand.Intn(len(members))] - - 
s.logger.WithField("peer", peer.Address).Info("Starting regular sync") - - // Request latest 15 UUIDs - req := PairsByTimeRequest{ - StartTimestamp: 0, - EndTimestamp: 0, // Current time - Limit: 15, + + s.logger.WithField("peer", peer.Address).Info("Starting Merkle tree sync") + + localRoot := s.getMerkleRoot() + if localRoot == nil { + s.logger.Error("Local Merkle root is nil, cannot perform sync") + return } - - remotePairs, err := s.requestPairsByTime(peer.Address, req) + + // 1. Get remote peer's Merkle root + remoteRootResp, err := s.requestMerkleRoot(peer.Address) if err != nil { - s.logger.WithError(err).WithField("peer", peer.Address).Error("Failed to sync with peer") + s.logger.WithError(err).WithField("peer", peer.Address).Error("Failed to get remote Merkle root") s.markPeerUnhealthy(peer.ID) return } - - // Compare with our local data and fetch missing/newer data - s.syncDataFromPairs(peer.Address, remotePairs) - - s.logger.WithField("peer", peer.Address).Info("Completed regular sync") + remoteRoot := remoteRootResp.Root + + // 2. Compare roots and start recursive diffing if they differ + if !bytes.Equal(localRoot.Hash, remoteRoot.Hash) { + s.logger.WithFields(logrus.Fields{ + "peer": peer.Address, + "local_root": hex.EncodeToString(localRoot.Hash), + "remote_root": hex.EncodeToString(remoteRoot.Hash), + }).Info("Merkle roots differ, starting recursive diff") + s.diffMerkleTreesRecursive(peer.Address, localRoot, remoteRoot) + } else { + s.logger.WithField("peer", peer.Address).Info("Merkle roots match, no sync needed") + } + + s.logger.WithField("peer", peer.Address).Info("Completed Merkle tree sync") } -// Request pairs by time from a peer -func (s *Server) requestPairsByTime(peerAddress string, req PairsByTimeRequest) ([]PairsByTimeResponse, error) { +// requestMerkleRoot requests the Merkle root from a peer +func (s *Server) requestMerkleRoot(peerAddress string) (*MerkleRootResponse, error) { + client := &http.Client{Timeout: 10 * time.Second} + url := fmt.Sprintf("http://%s/merkle_tree/root", peerAddress) + + resp, err := client.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("peer returned status %d for Merkle root", resp.StatusCode) + } + + var merkleRootResp MerkleRootResponse + if err := json.NewDecoder(resp.Body).Decode(&merkleRootResp); err != nil { + return nil, err + } + return &merkleRootResp, nil +} + +// diffMerkleTreesRecursive recursively compares local and remote Merkle tree nodes +func (s *Server) diffMerkleTreesRecursive(peerAddress string, localNode, remoteNode *MerkleNode) { + // If hashes match, this subtree is in sync. + if bytes.Equal(localNode.Hash, remoteNode.Hash) { + return + } + + // Hashes differ, need to go deeper. + // Request children from the remote peer for the current range. + req := MerkleTreeDiffRequest{ + ParentNode: *remoteNode, // We are asking the remote peer about its children for this range + LocalHash: localNode.Hash, // Our hash for this range + } + + remoteDiffResp, err := s.requestMerkleDiff(peerAddress, req) + if err != nil { + s.logger.WithError(err).WithFields(logrus.Fields{ + "peer": peerAddress, + "start_key": localNode.StartKey, + "end_key": localNode.EndKey, + }).Error("Failed to get Merkle diff from peer") + return + } + + if len(remoteDiffResp.Keys) > 0 { + // This is a leaf-level diff, we have the actual keys that are different. + // Fetch and compare each key. 
+ s.logger.WithFields(logrus.Fields{ + "peer": peerAddress, + "start_key": localNode.StartKey, + "end_key": localNode.EndKey, + "num_keys": len(remoteDiffResp.Keys), + }).Info("Found divergent keys, fetching and comparing data") + + for _, key := range remoteDiffResp.Keys { + // Fetch the individual key from the peer + remoteStoredValue, err := s.fetchSingleKVFromPeer(peerAddress, key) + if err != nil { + s.logger.WithError(err).WithFields(logrus.Fields{ + "peer": peerAddress, + "key": key, + }).Error("Failed to fetch single KV from peer during diff") + continue + } + + localStoredValue, localExists := s.getLocalData(key) + + if remoteStoredValue == nil { + // Key was deleted on remote, delete locally if exists + if localExists { + s.logger.WithField("key", key).Info("Key deleted on remote, deleting locally") + s.deleteKVLocally(key, localStoredValue.Timestamp) // Need a local delete function + } + continue + } + + if !localExists { + // Local data is missing, store the remote data + if err := s.storeReplicatedDataWithMetadata(key, remoteStoredValue); err != nil { + s.logger.WithError(err).WithField("key", key).Error("Failed to store missing replicated data") + } else { + s.logger.WithField("key", key).Info("Fetched and stored missing data from peer") + } + } else if localStoredValue.Timestamp < remoteStoredValue.Timestamp { + // Remote is newer, store the remote data + if err := s.storeReplicatedDataWithMetadata(key, remoteStoredValue); err != nil { + s.logger.WithError(err).WithField("key", key).Error("Failed to store newer replicated data") + } else { + s.logger.WithField("key", key).Info("Fetched and stored newer data from peer") + } + } else if localStoredValue.Timestamp == remoteStoredValue.Timestamp && localStoredValue.UUID != remoteStoredValue.UUID { + // Timestamp collision, engage conflict resolution + remotePair := PairsByTimeResponse{ // Re-use this struct for conflict resolution + Path: key, + UUID: remoteStoredValue.UUID, + Timestamp: remoteStoredValue.Timestamp, + } + resolved, err := s.resolveConflict(key, localStoredValue, &remotePair, peerAddress) + if err != nil { + s.logger.WithError(err).WithField("path", key).Error("Failed to resolve conflict during Merkle sync") + } else if resolved { + s.logger.WithField("path", key).Info("Conflict resolved, updated local data during Merkle sync") + } else { + s.logger.WithField("path", key).Info("Conflict resolved, kept local data during Merkle sync") + } + } + // If local is newer or same timestamp and same UUID, do nothing. + } + } else if len(remoteDiffResp.Children) > 0 { + // Not a leaf level, continue recursive diff for children. + localPairs, err := s.getAllKVPairsForMerkleTree() + if err != nil { + s.logger.WithError(err).Error("Failed to get KV pairs for local children comparison") + return + } + + for _, remoteChild := range remoteDiffResp.Children { + // Build the local Merkle node for this child's range + localChildNode, err := s.buildMerkleTreeFromPairs(s.filterPairsByRange(localPairs, remoteChild.StartKey, remoteChild.EndKey)) + if err != nil { + s.logger.WithError(err).WithFields(logrus.Fields{ + "start_key": remoteChild.StartKey, + "end_key": remoteChild.EndKey, + }).Error("Failed to build local child node for diff") + continue + } + + if localChildNode == nil || !bytes.Equal(localChildNode.Hash, remoteChild.Hash) { + // If local child node is nil (meaning local has no data in this range) + // or hashes differ, then we need to fetch the data. 
+ if localChildNode == nil { + s.logger.WithFields(logrus.Fields{ + "peer": peerAddress, + "start_key": remoteChild.StartKey, + "end_key": remoteChild.EndKey, + }).Info("Local node missing data in remote child's range, fetching full range") + s.fetchAndStoreRange(peerAddress, remoteChild.StartKey, remoteChild.EndKey) + } else { + s.diffMerkleTreesRecursive(peerAddress, localChildNode, &remoteChild) + } + } + } + } +} + +// requestMerkleDiff requests children hashes or keys for a given node/range from a peer +func (s *Server) requestMerkleDiff(peerAddress string, req MerkleTreeDiffRequest) (*MerkleTreeDiffResponse, error) { jsonData, err := json.Marshal(req) if err != nil { return nil, err } - + client := &http.Client{Timeout: 10 * time.Second} - url := fmt.Sprintf("http://%s/members/pairs_by_time", peerAddress) - + url := fmt.Sprintf("http://%s/merkle_tree/diff", peerAddress) + resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonData)) if err != nil { return nil, err } defer resp.Body.Close() - - if resp.StatusCode == http.StatusNoContent { - return []PairsByTimeResponse{}, nil - } - + if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("peer returned status %d", resp.StatusCode) + return nil, fmt.Errorf("peer returned status %d for Merkle diff", resp.StatusCode) } - - var pairs []PairsByTimeResponse - if err := json.NewDecoder(resp.Body).Decode(&pairs); err != nil { + + var diffResp MerkleTreeDiffResponse + if err := json.NewDecoder(resp.Body).Decode(&diffResp); err != nil { return nil, err } - - return pairs, nil + return &diffResp, nil } -// Sync data from pairs - fetch missing or newer data -func (s *Server) syncDataFromPairs(peerAddress string, remotePairs []PairsByTimeResponse) { - for _, remotePair := range remotePairs { - // Check our local version - localData, localExists := s.getLocalData(remotePair.Path) - - shouldFetch := false - if !localExists { - shouldFetch = true - s.logger.WithField("path", remotePair.Path).Debug("Missing local data, will fetch") - } else if localData.Timestamp < remotePair.Timestamp { - shouldFetch = true - s.logger.WithFields(logrus.Fields{ - "path": remotePair.Path, - "local_timestamp": localData.Timestamp, - "remote_timestamp": remotePair.Timestamp, - }).Debug("Local data is older, will fetch") - } else if localData.Timestamp == remotePair.Timestamp && localData.UUID != remotePair.UUID { - // Timestamp collision - need conflict resolution - s.logger.WithFields(logrus.Fields{ - "path": remotePair.Path, - "timestamp": remotePair.Timestamp, - "local_uuid": localData.UUID, - "remote_uuid": remotePair.UUID, - }).Warn("Timestamp collision detected, starting conflict resolution") - - resolved, err := s.resolveConflict(remotePair.Path, localData, &remotePair, peerAddress) - if err != nil { - s.logger.WithError(err).WithField("path", remotePair.Path).Error("Failed to resolve conflict") - continue - } - - if resolved { - s.logger.WithField("path", remotePair.Path).Info("Conflict resolved, updated local data") - } else { - s.logger.WithField("path", remotePair.Path).Info("Conflict resolved, keeping local data") - } - continue - } - - if shouldFetch { - if err := s.fetchAndStoreData(peerAddress, remotePair.Path); err != nil { - s.logger.WithError(err).WithFields(logrus.Fields{ - "peer": peerAddress, - "path": remotePair.Path, - }).Error("Failed to fetch data from peer") - } - } - } -} - -// Get local data for a path -func (s *Server) getLocalData(path string) (*StoredValue, bool) { - var storedValue StoredValue - err := 
s.db.View(func(txn *badger.Txn) error { - item, err := txn.Get([]byte(path)) - if err != nil { - return err - } - - return item.Value(func(val []byte) error { - return json.Unmarshal(val, &storedValue) - }) - }) - - if err != nil { - return nil, false - } - - return &storedValue, true -} - -// Fetch and store data from peer -func (s *Server) fetchAndStoreData(peerAddress, path string) error { +// fetchSingleKVFromPeer fetches a single KV pair from a peer +func (s *Server) fetchSingleKVFromPeer(peerAddress, path string) (*StoredValue, error) { client := &http.Client{Timeout: 5 * time.Second} url := fmt.Sprintf("http://%s/kv/%s", peerAddress, path) - + resp, err := client.Get(url) if err != nil { - return err + return nil, err } defer resp.Body.Close() - + + if resp.StatusCode == http.StatusNotFound { + return nil, nil // Key might have been deleted on the peer + } if resp.StatusCode != http.StatusOK { - return fmt.Errorf("peer returned status %d for path %s", resp.StatusCode, path) + return nil, fmt.Errorf("peer returned status %d for path %s", resp.StatusCode, path) } - - var data json.RawMessage - if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { - return err + + var storedValue StoredValue + if err := json.NewDecoder(resp.Body).Decode(&storedValue); err != nil { + return nil, fmt.Errorf("failed to decode StoredValue from peer: %v", err) } - - // Store the data using our internal storage mechanism - return s.storeReplicatedData(path, data) + return &storedValue, nil } -// Store replicated data (internal storage without timestamp/UUID generation) -func (s *Server) storeReplicatedData(path string, data json.RawMessage) error { - // For now, we'll generate new timestamp/UUID - in full implementation, - // we'd need to preserve the original metadata from the source - now := time.Now().UnixMilli() - newUUID := uuid.New().String() - - storedValue := StoredValue{ - UUID: newUUID, - Timestamp: now, - Data: data, - } - +// storeReplicatedDataWithMetadata stores replicated data preserving its original metadata +func (s *Server) storeReplicatedDataWithMetadata(path string, storedValue *StoredValue) error { valueBytes, err := json.Marshal(storedValue) if err != nil { return err } - + return s.db.Update(func(txn *badger.Txn) error { // Store main data if err := txn.Set([]byte(path), valueBytes); err != nil { return err } - + // Store timestamp index - indexKey := fmt.Sprintf("_ts:%020d:%s", now, path) - return txn.Set([]byte(indexKey), []byte(newUUID)) + indexKey := fmt.Sprintf("_ts:%020d:%s", storedValue.Timestamp, path) + return txn.Set([]byte(indexKey), []byte(storedValue.UUID)) }) } +// deleteKVLocally deletes a key-value pair and its associated timestamp index locally. 
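storeReplicatedDataWithMetadata above deliberately keeps the origin's UUID and timestamp rather than re-stamping them, which is what the integration test asserts. A standalone sketch (values illustrative) of why that matters for the leaf hashes defined earlier: a re-generated UUID changes the leaf hash, so the two trees could never converge:

package main

import (
	"bytes"
	"crypto/sha256"
	"fmt"
	"strconv"
)

// Same preimage layout as calculateLeafHash above: path:uuid:timestamp:data.
func leafHash(path, uuid string, ts int64, data []byte) []byte {
	var buf bytes.Buffer
	buf.WriteString(path)
	buf.WriteByte(':')
	buf.WriteString(uuid)
	buf.WriteByte(':')
	buf.WriteString(strconv.FormatInt(ts, 10))
	buf.WriteByte(':')
	buf.Write(data)
	sum := sha256.Sum256(buf.Bytes())
	return sum[:]
}

func main() {
	data := []byte(`{"source":"node1"}`)

	original := leafHash("cluster/test", "uuid-1", 1700000000000, data)
	replica := leafHash("cluster/test", "uuid-1", 1700000000000, data)
	restamped := leafHash("cluster/test", "uuid-2", 1700000000000, data) // new UUID on the replica

	fmt.Println(bytes.Equal(original, replica))   // true: preserved metadata converges
	fmt.Println(bytes.Equal(original, restamped)) // false: trees would keep reporting a diff
}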
+
+// deleteKVLocally deletes a key-value pair and its associated timestamp index locally.
+func (s *Server) deleteKVLocally(path string, timestamp int64) error {
+	return s.db.Update(func(txn *badger.Txn) error {
+		if err := txn.Delete([]byte(path)); err != nil {
+			return err
+		}
+		indexKey := fmt.Sprintf("_ts:%020d:%s", timestamp, path)
+		return txn.Delete([]byte(indexKey))
+	})
+}
+
+// fetchAndStoreRange fetches a range of KV pairs from a peer and stores them locally
+func (s *Server) fetchAndStoreRange(peerAddress string, startKey, endKey string) error {
+	req := KVRangeRequest{
+		StartKey: startKey,
+		EndKey:   endKey,
+		Limit:    0, // No limit
+	}
+	jsonData, err := json.Marshal(req)
+	if err != nil {
+		return err
+	}
+
+	client := &http.Client{Timeout: 30 * time.Second} // Longer timeout for range fetches
+	url := fmt.Sprintf("http://%s/kv_range", peerAddress)
+
+	resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonData))
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return fmt.Errorf("peer returned status %d for KV range fetch", resp.StatusCode)
+	}
+
+	var rangeResp KVRangeResponse
+	if err := json.NewDecoder(resp.Body).Decode(&rangeResp); err != nil {
+		return err
+	}
+
+	for _, pair := range rangeResp.Pairs {
+		// Use storeReplicatedDataWithMetadata to preserve original UUID/Timestamp
+		if err := s.storeReplicatedDataWithMetadata(pair.Path, &pair.StoredValue); err != nil {
+			s.logger.WithError(err).WithFields(logrus.Fields{
+				"peer": peerAddress,
+				"path": pair.Path,
+			}).Error("Failed to store fetched range data")
+		} else {
+			s.logger.WithFields(logrus.Fields{
+				"peer": peerAddress,
+				"path": pair.Path,
+			}).Debug("Stored data from fetched range")
+		}
+	}
+	return nil
+}
+
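// Editor's sketch (illustrative, not part of this change): fetchAndStoreRange above
// assumes a /kv_range endpoint plus KVRangeRequest/KVRangeResponse types defined
// elsewhere in this diff. Reconstructed from usage (StartKey/EndKey/Limit on the
// request; Pairs carrying a Path and a StoredValue on the response), a plausible
// shape is shown below; field names and JSON tags are assumptions.
type KVRangeRequest struct {
	StartKey string `json:"start_key"`
	EndKey   string `json:"end_key"`
	Limit    int    `json:"limit"` // 0 means no limit
}

type KVRangePair struct {
	Path        string      `json:"path"`
	StoredValue StoredValue `json:"stored_value"` // accessed as pair.StoredValue in fetchAndStoreRange
}

type KVRangeResponse struct {
	Pairs []KVRangePair `json:"pairs"`
}
// End of sketch.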
 // Bootstrap - join cluster using seed nodes
 func (s *Server) bootstrap() {
 	if len(s.config.SeedNodes) == 0 {
 		s.logger.Info("No seed nodes configured, running as standalone")
 		return
 	}
-	
+
 	s.logger.Info("Starting bootstrap process")
 	s.setMode("syncing")
-	
-	// Try to join via each seed node
+
+	// Try to join cluster via each seed node
 	joined := false
 	for _, seedAddr := range s.config.SeedNodes {
 		if s.attemptJoin(seedAddr) {
@@ -1136,19 +1678,19 @@ func (s *Server) bootstrap() {
 			break
 		}
 	}
-	
+
 	if !joined {
 		s.logger.Warn("Failed to join cluster via seed nodes, running as standalone")
 		s.setMode("normal")
 		return
 	}
-	
+
 	// Wait a bit for member discovery
 	time.Sleep(2 * time.Second)
-	
-	// Perform gradual sync
+
+	// Perform gradual sync (now Merkle-based)
 	s.performGradualSync()
-	
+
 	// Switch to normal mode
 	s.setMode("normal")
 	s.logger.Info("Bootstrap completed, entering normal mode")
@@ -1161,26 +1703,26 @@ func (s *Server) attemptJoin(seedAddr string) bool {
 		Address:         fmt.Sprintf("%s:%d", s.config.BindAddress, s.config.Port),
 		JoinedTimestamp: time.Now().UnixMilli(),
 	}
-	
+
 	jsonData, err := json.Marshal(joinReq)
 	if err != nil {
 		s.logger.WithError(err).Error("Failed to marshal join request")
 		return false
 	}
-	
+
 	client := &http.Client{Timeout: 10 * time.Second}
 	url := fmt.Sprintf("http://%s/members/join", seedAddr)
-	
+
 	resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonData))
 	if err != nil {
 		s.logger.WithFields(logrus.Fields{
-			"seed": seedAddr,
+			"seed":  seedAddr,
 			"error": err.Error(),
 		}).Warn("Failed to contact seed node")
 		return false
 	}
 	defer resp.Body.Close()
-	
+
 	if resp.StatusCode != http.StatusOK {
 		s.logger.WithFields(logrus.Fields{
 			"seed":   seedAddr,
@@ -1188,57 +1730,57 @@ func (s *Server) attemptJoin(seedAddr string) bool {
 		}).Warn("Seed node rejected join request")
 		return false
 	}
-	
+
 	// Process member list response
 	var memberList []Member
 	if err := json.NewDecoder(resp.Body).Decode(&memberList); err != nil {
 		s.logger.WithError(err).Error("Failed to decode member list from seed")
 		return false
 	}
-	
+
 	// Add all members to our local list
 	for _, member := range memberList {
 		if member.ID != s.config.NodeID {
 			s.addMember(&member)
 		}
 	}
-	
+
 	s.logger.WithFields(logrus.Fields{
 		"seed":         seedAddr,
 		"member_count": len(memberList),
 	}).Info("Successfully joined cluster")
-	
+
 	return true
 }

-// Perform gradual sync (simplified version)
+// Perform gradual sync (Merkle-based version)
 func (s *Server) performGradualSync() {
-	s.logger.Info("Starting gradual sync")
-	
+	s.logger.Info("Starting gradual sync (Merkle-based)")
+
 	members := s.getHealthyMembers()
 	if len(members) == 0 {
 		s.logger.Info("No healthy members for gradual sync")
 		return
 	}
-	
-	// For now, just do a few rounds of regular sync
+
+	// For now, just do a few rounds of Merkle sync
 	for i := 0; i < 3; i++ {
-		s.performRegularSync()
+		s.performMerkleSync() // Use Merkle sync
 		time.Sleep(time.Duration(s.config.ThrottleDelayMs) * time.Millisecond)
 	}
-	
+
 	s.logger.Info("Gradual sync completed")
 }

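// Editor's sketch (illustrative, not part of this change): attemptJoin above POSTs a
// join request to a seed node's /members/join endpoint and decodes the reply as
// []Member. Member is defined elsewhere in the file; reconstructed from the fields
// this diff references (ID, Address, JoinedTimestamp), a plausible shape is:
type Member struct {
	ID              string `json:"id"`
	Address         string `json:"address"`          // host:port the member serves on
	JoinedTimestamp int64  `json:"joined_timestamp"` // Unix milliseconds; used by the oldest-node tie-breaker
}
// The join request marshalled in attemptJoin appears to carry the same Address and
// JoinedTimestamp fields; its full definition is outside the lines shown here, so
// the JSON tags above are assumptions.
// End of sketch.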
 // Resolve conflict between local and remote data using majority vote and oldest node tie-breaker
 func (s *Server) resolveConflict(path string, localData *StoredValue, remotePair *PairsByTimeResponse, peerAddress string) (bool, error) {
 	s.logger.WithFields(logrus.Fields{
-		"path": path,
-		"timestamp": localData.Timestamp,
-		"local_uuid": localData.UUID,
+		"path":        path,
+		"timestamp":   localData.Timestamp,
+		"local_uuid":  localData.UUID,
 		"remote_uuid": remotePair.UUID,
 	}).Info("Starting conflict resolution with majority vote")
-	
+
 	// Get list of healthy members for voting
 	members := s.getHealthyMembers()
 	if len(members) == 0 {
@@ -1246,36 +1788,50 @@ func (s *Server) resolveConflict(path string, localData *StoredValue, remotePair
 		// We'll consider the peer as the "remote" node for comparison
 		return s.resolveByOldestNode(localData, remotePair, peerAddress)
 	}
-	
+
 	// Query all healthy members for their version of this path
 	votes := make(map[string]int) // UUID -> vote count
 	uuidToTimestamp := make(map[string]int64)
 	uuidToJoinedTime := make(map[string]int64)
-	
+
 	// Add our local vote
 	votes[localData.UUID] = 1
 	uuidToTimestamp[localData.UUID] = localData.Timestamp
 	uuidToJoinedTime[localData.UUID] = s.getJoinedTimestamp()
-	
+
 	// Add the remote peer's vote
-	votes[remotePair.UUID] = 1
+	// Note: remotePair carries only the remote UUID/Timestamp metadata; in a full
+	// Merkle sync the complete remote StoredValue would already have been fetched.
+	// We treat remotePair as an accurate reflection of the remote StoredValue's metadata.
+	votes[remotePair.UUID]++ // one vote for the remote peer's version (increments the map's zero value)
 	uuidToTimestamp[remotePair.UUID] = remotePair.Timestamp
 	// We'll need to get the peer's joined timestamp
-	
+	s.membersMu.RLock()
+	for _, member := range s.members {
+		if member.Address == peerAddress {
+			uuidToJoinedTime[remotePair.UUID] = member.JoinedTimestamp
+			break
+		}
+	}
+	s.membersMu.RUnlock()
+
 	// Query other members
 	for _, member := range members {
 		if member.Address == peerAddress {
-			// We already counted this peer
-			uuidToJoinedTime[remotePair.UUID] = member.JoinedTimestamp
+			continue // Already handled the peer
+		}
+
+		memberData, err := s.fetchSingleKVFromPeer(member.Address, path) // Use the new fetchSingleKVFromPeer
+		if err != nil || memberData == nil {
+			s.logger.WithFields(logrus.Fields{
+				"member_address": member.Address,
+				"path":           path,
+				"error":          err,
+			}).Debug("Failed to query member for data during conflict resolution, skipping vote")
 			continue
 		}
-		
-		memberData, exists := s.queryMemberForData(member.Address, path)
-		if !exists {
-			continue // Member doesn't have this data
-		}
-		
-		// Only count votes for data with the same timestamp
+
+		// Only count votes for data with the same timestamp (only timestamp collisions are put to a vote)
 		if memberData.Timestamp == localData.Timestamp {
 			votes[memberData.UUID]++
 			if _, exists := uuidToTimestamp[memberData.UUID]; !exists {
@@ -1284,11 +1840,11 @@ func (s *Server) resolveConflict(path string, localData *StoredValue, remotePair
 			}
 		}
 	}
-	
+
 	// Find the UUID with majority votes
 	maxVotes := 0
 	var winningUUIDs []string
-	
+
 	for uuid, voteCount := range votes {
 		if voteCount > maxVotes {
 			maxVotes = voteCount
@@ -1297,7 +1853,7 @@ func (s *Server) resolveConflict(path string, localData *StoredValue, remotePair
 			winningUUIDs = append(winningUUIDs, uuid)
 		}
 	}
-	
+
 	var winnerUUID string
 	if len(winningUUIDs) == 1 {
 		winnerUUID = winningUUIDs[0]
@@ -1306,37 +1862,44 @@ func (s *Server) resolveConflict(path string, localData *StoredValue, remotePair
 		oldestJoinedTime := int64(0)
 		for _, uuid := range winningUUIDs {
 			joinedTime := uuidToJoinedTime[uuid]
-			if oldestJoinedTime == 0 || joinedTime < oldestJoinedTime {
+			if oldestJoinedTime == 0 || (joinedTime != 0 && joinedTime < oldestJoinedTime) { // joinedTime can be 0 if not found
 				oldestJoinedTime = joinedTime
 				winnerUUID = uuid
 			}
 		}
-		
+
 		s.logger.WithFields(logrus.Fields{
-			"path": path,
-			"tied_votes": maxVotes,
-			"winner_uuid": winnerUUID,
-			"oldest_joined": oldestJoinedTime,
+			"path":          path,
+			"tied_votes":    maxVotes,
+			"winner_uuid":   winnerUUID,
+			"oldest_joined": oldestJoinedTime,
 		}).Info("Resolved conflict using oldest node tie-breaker")
 	}
-	
+
 	// If remote UUID wins, fetch and store the remote data
 	if winnerUUID == remotePair.UUID {
-		err := s.fetchAndStoreData(peerAddress, path)
-		if err != nil {
-			return false, fmt.Errorf("failed to fetch winning data: %v", err)
+		// We need the full StoredValue for the winning remote data.
+		// Since remotePair only has UUID/Timestamp, we must fetch the data.
+		winningRemoteStoredValue, err := s.fetchSingleKVFromPeer(peerAddress, path)
+		if err != nil || winningRemoteStoredValue == nil {
+			return false, fmt.Errorf("failed to fetch winning remote data for conflict resolution: %v", err)
 		}
-		
+
+		err = s.storeReplicatedDataWithMetadata(path, winningRemoteStoredValue)
+		if err != nil {
+			return false, fmt.Errorf("failed to store winning data: %v", err)
+		}
+
 		s.logger.WithFields(logrus.Fields{
 			"path":         path,
 			"winner_uuid":  winnerUUID,
 			"winner_votes": maxVotes,
 			"total_nodes":  len(members) + 2, // +2 for local and peer
 		}).Info("Conflict resolved: remote data wins")
-		
+
 		return true, nil
 	}
-	
+
 	// Local data wins, no action needed
 	s.logger.WithFields(logrus.Fields{
 		"path":         path,
@@ -1344,7 +1907,7 @@ func (s *Server) resolveConflict(path string, localData *StoredValue, remotePair
 		"winner_votes": maxVotes,
 		"total_nodes":  len(members) + 2,
 	}).Info("Conflict resolved: local data wins")
-	
+
 	return false, nil
 }

@@ -1360,97 +1923,93 @@ func (s *Server) resolveByOldestNode(localData *StoredValue, remotePair *PairsBy
 		}
 	}
 	s.membersMu.RUnlock()
-	
+
 	localJoinedTime := s.getJoinedTimestamp()
-	
+
 	// Oldest node wins
 	if peerJoinedTime > 0 && peerJoinedTime < localJoinedTime {
 		// Peer is older, fetch remote data
-		err := s.fetchAndStoreData(peerAddress, remotePair.Path)
-		if err != nil {
-			return false, fmt.Errorf("failed to fetch data from older node: %v", err)
+		winningRemoteStoredValue, err := s.fetchSingleKVFromPeer(peerAddress, remotePair.Path)
+		if err != nil || winningRemoteStoredValue == nil {
+			return false, fmt.Errorf("failed to fetch data from older node for conflict resolution: %v", err)
 		}
-		
+
+		err = s.storeReplicatedDataWithMetadata(remotePair.Path, winningRemoteStoredValue)
+		if err != nil {
+			return false, fmt.Errorf("failed to store data from older node: %v", err)
+		}
+
 		s.logger.WithFields(logrus.Fields{
-			"path": remotePair.Path,
-			"local_joined": localJoinedTime,
-			"peer_joined": peerJoinedTime,
-			"winner": "remote",
+			"path":         remotePair.Path,
+			"local_joined": localJoinedTime,
+			"peer_joined":  peerJoinedTime,
+			"winner":       "remote",
 		}).Info("Conflict resolved using oldest node rule")
-		
+
 		return true, nil
 	}
-	
+
 	// Local node is older or equal, keep local data
 	s.logger.WithFields(logrus.Fields{
 		"path":         remotePair.Path,
 		"local_joined": localJoinedTime,
 		"peer_joined":  peerJoinedTime,
-		"winner": "local",
+		"winner":       "local",
 	}).Info("Conflict resolved using oldest node rule")
-	
+
 	return false, nil
 }

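// Editor's sketch (illustrative, not part of this change): a compact restatement of
// the rule in resolveByOldestNode above. remoteWinsByAge is a hypothetical helper,
// not code from this diff: the peer's copy wins only when the peer joined the
// cluster strictly earlier than the local node; a zero peerJoined means the peer's
// join time is unknown, so the local copy is kept.
func remoteWinsByAge(localJoined, peerJoined int64) bool {
	return peerJoined > 0 && peerJoined < localJoined
}

// Example: remoteWinsByAge(1712000000000, 1712000050000) == false (the peer joined
// 50 seconds later, so local data is kept); swapping the arguments makes the remote
// copy win, in which case its StoredValue is fetched and stored with
// storeReplicatedDataWithMetadata, preserving the original UUID and timestamp.
// End of sketch.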
-// Query a member for their version of specific data
-func (s *Server) queryMemberForData(memberAddress, path string) (*StoredValue, bool) {
-	client := &http.Client{Timeout: 5 * time.Second}
-	url := fmt.Sprintf("http://%s/kv/%s", memberAddress, path)
-	
-	resp, err := client.Get(url)
+// getLocalData is a utility to retrieve a StoredValue from local DB.
+func (s *Server) getLocalData(path string) (*StoredValue, bool) {
+	var storedValue StoredValue
+	err := s.db.View(func(txn *badger.Txn) error {
+		item, err := txn.Get([]byte(path))
+		if err != nil {
+			return err
+		}
+
+		return item.Value(func(val []byte) error {
+			return json.Unmarshal(val, &storedValue)
+		})
+	})
+
 	if err != nil {
 		return nil, false
 	}
-	defer resp.Body.Close()
-	
-	if resp.StatusCode != http.StatusOK {
-		return nil, false
-	}
-	
-	var data json.RawMessage
-	if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
-		return nil, false
-	}
-	
-	// We need to get the metadata too - this is a simplified approach
-	// In a full implementation, we'd have a separate endpoint for metadata queries
-	localData, exists := s.getLocalData(path)
-	if exists {
-		return localData, true
-	}
-	
-	return nil, false
+
+	return &storedValue, true
 }

 func main() {
 	configPath := "./config.yaml"
-	
+
 	// Simple CLI argument parsing
 	if len(os.Args) > 1 {
 		configPath = os.Args[1]
 	}
-	
+
 	config, err := loadConfig(configPath)
 	if err != nil {
 		fmt.Fprintf(os.Stderr, "Failed to load configuration: %v\n", err)
 		os.Exit(1)
 	}
-	
+
 	server, err := NewServer(config)
 	if err != nil {
 		fmt.Fprintf(os.Stderr, "Failed to create server: %v\n", err)
 		os.Exit(1)
 	}
-	
+
 	// Handle graceful shutdown
 	sigCh := make(chan os.Signal, 1)
 	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
-	
+
 	go func() {
 		<-sigCh
 		server.Stop()
 	}()
-	
+
 	if err := server.Start(); err != nil && err != http.ErrServerClosed {
 		fmt.Fprintf(os.Stderr, "Server error: %v\n", err)
 		os.Exit(1)