package main

import (
	"bytes"
	"context"
	"crypto/sha256" // Added for Merkle Tree hashing
	"encoding/hex"  // Added for logging hashes
	"encoding/json"
	"fmt"
	"math/rand"
	"net"
	"net/http"
	"os"
	"os/signal"
	"path/filepath"
	"sort" // Added for sorting keys in Merkle Tree
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"

	badger "github.com/dgraph-io/badger/v4"
	"github.com/google/uuid"
	"github.com/gorilla/mux"
	"github.com/sirupsen/logrus"
	"gopkg.in/yaml.v3"
)

// Core data structures
type StoredValue struct {
	UUID      string          `json:"uuid"`
	Timestamp int64           `json:"timestamp"`
	Data      json.RawMessage `json:"data"`
}

type Member struct {
	ID              string `json:"id"`
	Address         string `json:"address"`
	LastSeen        int64  `json:"last_seen"`
	JoinedTimestamp int64  `json:"joined_timestamp"`
}

type JoinRequest struct {
	ID              string `json:"id"`
	Address         string `json:"address"`
	JoinedTimestamp int64  `json:"joined_timestamp"`
}

type LeaveRequest struct {
	ID string `json:"id"`
}

type PairsByTimeRequest struct {
	StartTimestamp int64  `json:"start_timestamp"`
	EndTimestamp   int64  `json:"end_timestamp"`
	Limit          int    `json:"limit"`
	Prefix         string `json:"prefix,omitempty"`
}

type PairsByTimeResponse struct {
	Path      string `json:"path"`
	UUID      string `json:"uuid"`
	Timestamp int64  `json:"timestamp"`
}

type PutResponse struct {
	UUID      string `json:"uuid"`
	Timestamp int64  `json:"timestamp"`
}

// Merkle Tree specific data structures
type MerkleNode struct {
	Hash     []byte `json:"hash"`
	StartKey string `json:"start_key"` // The first key in this node's range
	EndKey   string `json:"end_key"`   // The last key in this node's range
}

// MerkleRootResponse is the response for getting the root hash
type MerkleRootResponse struct {
	Root *MerkleNode `json:"root"`
}

// MerkleTreeDiffRequest is used to request children hashes for a given key range
type MerkleTreeDiffRequest struct {
	ParentNode MerkleNode `json:"parent_node"` // The node whose children we want to compare (from the remote peer's perspective)
	LocalHash  []byte     `json:"local_hash"`  // The local hash of this node/range (from the requesting peer's perspective)
}

// MerkleTreeDiffResponse returns the remote children nodes or the actual keys if it's a leaf level
type MerkleTreeDiffResponse struct {
	Children []MerkleNode `json:"children,omitempty"` // Children of the remote node
	Keys     []string     `json:"keys,omitempty"`     // Actual keys if this is a leaf-level diff
}

// For fetching a range of KV pairs
type KVRangeRequest struct {
	StartKey string `json:"start_key"`
	EndKey   string `json:"end_key"`
	Limit    int    `json:"limit"` // Max number of items to return
}

type KVRangeResponse struct {
	Pairs []struct {
		Path        string      `json:"path"`
		StoredValue StoredValue `json:"stored_value"`
	} `json:"pairs"`
}

// Configuration
type Config struct {
	NodeID               string   `yaml:"node_id"`
	BindAddress          string   `yaml:"bind_address"`
	Port                 int      `yaml:"port"`
	DataDir              string   `yaml:"data_dir"`
	SeedNodes            []string `yaml:"seed_nodes"`
	ReadOnly             bool     `yaml:"read_only"`
	LogLevel             string   `yaml:"log_level"`
	GossipIntervalMin    int      `yaml:"gossip_interval_min"`
	GossipIntervalMax    int      `yaml:"gossip_interval_max"`
	SyncInterval         int      `yaml:"sync_interval"`
	CatchupInterval      int      `yaml:"catchup_interval"`
	BootstrapMaxAgeHours int      `yaml:"bootstrap_max_age_hours"`
	ThrottleDelayMs      int      `yaml:"throttle_delay_ms"`
	FetchDelayMs         int      `yaml:"fetch_delay_ms"`
}

// Server represents the KVS node
type Server struct {
	config       *Config
	db           *badger.DB
	members      map[string]*Member
	membersMu    sync.RWMutex
	mode         string // "normal", "read-only", "syncing"
	modeMu       sync.RWMutex
	logger       *logrus.Logger
	httpServer   *http.Server
	ctx          context.Context
	cancel       context.CancelFunc
	wg           sync.WaitGroup
	merkleRoot   *MerkleNode  // Added for Merkle Tree
	merkleRootMu sync.RWMutex // Protects merkleRoot
}

// Default configuration
func defaultConfig() *Config {
	hostname, _ := os.Hostname()
	return &Config{
		NodeID:               hostname,
		BindAddress:          "127.0.0.1",
		Port:                 8080,
		DataDir:              "./data",
		SeedNodes:            []string{},
		ReadOnly:             false,
		LogLevel:             "info",
		GossipIntervalMin:    60,  // 1 minute
		GossipIntervalMax:    120, // 2 minutes
		SyncInterval:         300, // 5 minutes
		CatchupInterval:      120, // 2 minutes
		BootstrapMaxAgeHours: 720, // 30 days
		ThrottleDelayMs:      100,
		FetchDelayMs:         50,
	}
}

// Load configuration from file or create default
func loadConfig(configPath string) (*Config, error) {
	config := defaultConfig()

	if _, err := os.Stat(configPath); os.IsNotExist(err) {
		// Create default config file
		if err := os.MkdirAll(filepath.Dir(configPath), 0755); err != nil {
			return nil, fmt.Errorf("failed to create config directory: %v", err)
		}

		data, err := yaml.Marshal(config)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal default config: %v", err)
		}

		if err := os.WriteFile(configPath, data, 0644); err != nil {
			return nil, fmt.Errorf("failed to write default config: %v", err)
		}

		fmt.Printf("Created default configuration at %s\n", configPath)
		return config, nil
	}

	data, err := os.ReadFile(configPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read config file: %v", err)
	}

	if err := yaml.Unmarshal(data, config); err != nil {
		return nil, fmt.Errorf("failed to parse config file: %v", err)
	}

	return config, nil
}

// Initialize server
func NewServer(config *Config) (*Server, error) {
	logger := logrus.New()
	logger.SetFormatter(&logrus.JSONFormatter{})

	level, err := logrus.ParseLevel(config.LogLevel)
	if err != nil {
		level = logrus.InfoLevel
	}
	logger.SetLevel(level)

	// Create data directory
	if err := os.MkdirAll(config.DataDir, 0755); err != nil {
		return nil, fmt.Errorf("failed to create data directory: %v", err)
	}

	// Open BadgerDB
	opts := badger.DefaultOptions(filepath.Join(config.DataDir, "badger"))
	opts.Logger = nil // Disable badger's internal logging

	db, err := badger.Open(opts)
	if err != nil {
		return nil, fmt.Errorf("failed to open BadgerDB: %v", err)
	}

	ctx, cancel := context.WithCancel(context.Background())

	server := &Server{
		config:  config,
		db:      db,
		members: make(map[string]*Member),
		mode:    "normal",
		logger:  logger,
		ctx:     ctx,
		cancel:  cancel,
	}

	if config.ReadOnly {
		server.setMode("read-only")
	}

	// Build initial Merkle tree
	pairs, err := server.getAllKVPairsForMerkleTree()
	if err != nil {
		return nil, fmt.Errorf("failed to get all KV pairs for initial Merkle tree: %v", err)
	}
	root, err := server.buildMerkleTreeFromPairs(pairs)
	if err != nil {
		return nil, fmt.Errorf("failed to build initial Merkle tree: %v", err)
	}
	server.setMerkleRoot(root)
	server.logger.Info("Initial Merkle tree built.")

	return server, nil
}

// Mode management
func (s *Server) getMode() string {
	s.modeMu.RLock()
	defer s.modeMu.RUnlock()
	return s.mode
}

func (s *Server) setMode(mode string) {
	s.modeMu.Lock()
	defer s.modeMu.Unlock()
	oldMode := s.mode
	s.mode = mode
	s.logger.WithFields(logrus.Fields{
		"old_mode": oldMode,
		"new_mode": mode,
	}).Info("Mode changed")
}

// Member management
func (s *Server) addMember(member *Member) {
	s.membersMu.Lock()
	defer s.membersMu.Unlock()
	s.members[member.ID] = member
	s.logger.WithFields(logrus.Fields{
		"node_id": member.ID,
		"address": member.Address,
	}).Info("Member added")
}

func (s *Server) removeMember(nodeID string) {
	s.membersMu.Lock()
	defer s.membersMu.Unlock()
	if member, exists := s.members[nodeID]; exists {
		delete(s.members, nodeID)
		s.logger.WithFields(logrus.Fields{
			"node_id": member.ID,
			"address": member.Address,
		}).Info("Member removed")
	}
}

func (s *Server) getMembers() []*Member {
	s.membersMu.RLock()
	defer s.membersMu.RUnlock()
	members := make([]*Member, 0, len(s.members))
	for _, member := range s.members {
		members = append(members, member)
	}
	return members
}

// HTTP Handlers
func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
	mode := s.getMode()
	memberCount := len(s.getMembers())

	health := map[string]interface{}{
		"status":       "ok",
		"mode":         mode,
		"member_count": memberCount,
		"node_id":      s.config.NodeID,
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(health)
}

func (s *Server) getKVHandler(w http.ResponseWriter, r *http.Request) {
	vars := mux.Vars(r)
	path := vars["path"]

	var storedValue StoredValue
	err := s.db.View(func(txn *badger.Txn) error {
		item, err := txn.Get([]byte(path))
		if err != nil {
			return err
		}
		return item.Value(func(val []byte) error {
			return json.Unmarshal(val, &storedValue)
		})
	})

	if err == badger.ErrKeyNotFound {
		http.Error(w, "Not Found", http.StatusNotFound)
		return
	}
	if err != nil {
		s.logger.WithError(err).WithField("path", path).Error("Failed to get value")
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	// CHANGE: Return the entire StoredValue, not just Data
	json.NewEncoder(w).Encode(storedValue)
}

func (s *Server) putKVHandler(w http.ResponseWriter, r *http.Request) {
	vars := mux.Vars(r)
	path := vars["path"]

	mode := s.getMode()
	if mode == "syncing" {
		http.Error(w, "Service Unavailable", http.StatusServiceUnavailable)
		return
	}
	if mode == "read-only" && !s.isClusterMember(r.RemoteAddr) {
		http.Error(w, "Forbidden", http.StatusForbidden)
		return
	}

	var data json.RawMessage
	if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}

	now := time.Now().UnixMilli()
	newUUID := uuid.New().String()

	storedValue := StoredValue{
		UUID:      newUUID,
		Timestamp: now,
		Data:      data,
	}

	valueBytes, err := json.Marshal(storedValue)
	if err != nil {
		s.logger.WithError(err).Error("Failed to marshal stored value")
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	var isUpdate bool
	err = s.db.Update(func(txn *badger.Txn) error {
		// Check if key exists
		_, err := txn.Get([]byte(path))
		isUpdate = (err == nil)

		// Store main data
		if err := txn.Set([]byte(path), valueBytes); err != nil {
			return err
		}

		// Store timestamp index
		indexKey := fmt.Sprintf("_ts:%020d:%s", now, path)
		return txn.Set([]byte(indexKey), []byte(newUUID))
	})

	if err != nil {
		s.logger.WithError(err).WithField("path", path).Error("Failed to store value")
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	response := PutResponse{
		UUID:      newUUID,
		Timestamp: now,
	}

	status := http.StatusCreated
	if isUpdate {
		status = http.StatusOK
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	json.NewEncoder(w).Encode(response)

	s.logger.WithFields(logrus.Fields{
		"path":      path,
		"uuid":      newUUID,
		"timestamp": now,
		"is_update": isUpdate,
	}).Info("Value stored")
}

func (s *Server) deleteKVHandler(w http.ResponseWriter, r *http.Request) {
	vars := mux.Vars(r)
	path := vars["path"]

	mode := s.getMode()
	if mode == "syncing" {
		http.Error(w, "Service Unavailable", http.StatusServiceUnavailable)
		return
	}
	if mode == "read-only" && !s.isClusterMember(r.RemoteAddr) {
		http.Error(w, "Forbidden", http.StatusForbidden)
		return
	}

	var found bool
	err := s.db.Update(func(txn *badger.Txn) error {
		// Check if key exists and get timestamp for index cleanup
		item, err := txn.Get([]byte(path))
		if err == badger.ErrKeyNotFound {
			return nil
		}
		if err != nil {
			return err
		}
		found = true

		var storedValue StoredValue
		err = item.Value(func(val []byte) error {
			return json.Unmarshal(val, &storedValue)
		})
		if err != nil {
			return err
		}

		// Delete main data
		if err := txn.Delete([]byte(path)); err != nil {
			return err
		}

		// Delete timestamp index
		indexKey := fmt.Sprintf("_ts:%020d:%s", storedValue.Timestamp, path)
		return txn.Delete([]byte(indexKey))
	})

	if err != nil {
		s.logger.WithError(err).WithField("path", path).Error("Failed to delete value")
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	if !found {
		http.Error(w, "Not Found", http.StatusNotFound)
		return
	}

	w.WriteHeader(http.StatusNoContent)
	s.logger.WithField("path", path).Info("Value deleted")
}

func (s *Server) getMembersHandler(w http.ResponseWriter, r *http.Request) {
	members := s.getMembers()
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(members)
}

func (s *Server) joinMemberHandler(w http.ResponseWriter, r *http.Request) {
	var req JoinRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}

	now := time.Now().UnixMilli()
	member := &Member{
		ID:              req.ID,
		Address:         req.Address,
		LastSeen:        now,
		JoinedTimestamp: req.JoinedTimestamp,
	}

	s.addMember(member)

	// Return current member list
	members := s.getMembers()
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(members)
}

func (s *Server) leaveMemberHandler(w http.ResponseWriter, r *http.Request) {
	var req LeaveRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}

	s.removeMember(req.ID)
	w.WriteHeader(http.StatusNoContent)
}

func (s *Server) pairsByTimeHandler(w http.ResponseWriter, r *http.Request) {
	var req PairsByTimeRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}

	// Default limit to 15 as per spec
	if req.Limit <= 0 {
		req.Limit = 15
	}

	var pairs []PairsByTimeResponse
	err := s.db.View(func(txn *badger.Txn) error {
		opts := badger.DefaultIteratorOptions
		opts.PrefetchSize = req.Limit
		it := txn.NewIterator(opts)
		defer it.Close()

		prefix := []byte("_ts:")
		// The original logic for prefix filtering here was incomplete.
		// For Merkle tree sync, this handler is no longer used for core sync.
		// It remains as a client-facing API.
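		// Index keys scanned below have the form "_ts:{timestamp}:{path}", with the
		// millisecond timestamp zero-padded to 20 digits as written by putKVHandler,
		// e.g. "_ts:00000001712345678901:users/alice" (illustrative key).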
		for it.Seek(prefix); it.ValidForPrefix(prefix) && len(pairs) < req.Limit; it.Next() {
			item := it.Item()
			key := string(item.Key())

			// Parse timestamp index key: "_ts:{timestamp}:{path}"
			parts := strings.SplitN(key, ":", 3)
			if len(parts) != 3 {
				continue
			}

			timestamp, err := strconv.ParseInt(parts[1], 10, 64)
			if err != nil {
				continue
			}

			// Filter by timestamp range
			if req.StartTimestamp > 0 && timestamp < req.StartTimestamp {
				continue
			}
			if req.EndTimestamp > 0 && timestamp >= req.EndTimestamp {
				continue
			}

			path := parts[2]

			// Filter by prefix if specified
			if req.Prefix != "" && !strings.HasPrefix(path, req.Prefix) {
				continue
			}

			var uuid string
			err = item.Value(func(val []byte) error {
				uuid = string(val)
				return nil
			})
			if err != nil {
				continue
			}

			pairs = append(pairs, PairsByTimeResponse{
				Path:      path,
				UUID:      uuid,
				Timestamp: timestamp,
			})
		}
		return nil
	})

	if err != nil {
		s.logger.WithError(err).Error("Failed to query pairs by time")
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	if len(pairs) == 0 {
		w.WriteHeader(http.StatusNoContent)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(pairs)
}

func (s *Server) gossipHandler(w http.ResponseWriter, r *http.Request) {
	var remoteMemberList []Member
	if err := json.NewDecoder(r.Body).Decode(&remoteMemberList); err != nil {
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}

	// Merge the received member list
	s.mergeMemberList(remoteMemberList)

	// Respond with our current member list
	localMembers := s.getMembers()
	gossipResponse := make([]Member, len(localMembers))
	for i, member := range localMembers {
		gossipResponse[i] = *member
	}

	// Add ourselves to the response
	selfMember := Member{
		ID:              s.config.NodeID,
		Address:         fmt.Sprintf("%s:%d", s.config.BindAddress, s.config.Port),
		LastSeen:        time.Now().UnixMilli(),
		JoinedTimestamp: s.getJoinedTimestamp(),
	}
	gossipResponse = append(gossipResponse, selfMember)

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(gossipResponse)

	s.logger.WithField("remote_members", len(remoteMemberList)).Debug("Processed gossip request")
}

// Utility function to check if request is from cluster member
func (s *Server) isClusterMember(remoteAddr string) bool {
	host, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		return false
	}

	s.membersMu.RLock()
	defer s.membersMu.RUnlock()
	for _, member := range s.members {
		memberHost, _, err := net.SplitHostPort(member.Address)
		if err == nil && memberHost == host {
			return true
		}
	}
	return false
}

// Setup HTTP routes
func (s *Server) setupRoutes() *mux.Router {
	router := mux.NewRouter()

	// Health endpoint
	router.HandleFunc("/health", s.healthHandler).Methods("GET")

	// KV endpoints
	router.HandleFunc("/kv/{path:.+}", s.getKVHandler).Methods("GET")
	router.HandleFunc("/kv/{path:.+}", s.putKVHandler).Methods("PUT")
	router.HandleFunc("/kv/{path:.+}", s.deleteKVHandler).Methods("DELETE")

	// Member endpoints
	router.HandleFunc("/members/", s.getMembersHandler).Methods("GET")
	router.HandleFunc("/members/join", s.joinMemberHandler).Methods("POST")
	router.HandleFunc("/members/leave", s.leaveMemberHandler).Methods("DELETE")
	router.HandleFunc("/members/gossip", s.gossipHandler).Methods("POST")
	router.HandleFunc("/members/pairs_by_time", s.pairsByTimeHandler).Methods("POST") // Still available for clients

	// Merkle Tree endpoints
	router.HandleFunc("/merkle_tree/root", s.getMerkleRootHandler).Methods("GET")
	router.HandleFunc("/merkle_tree/diff", s.getMerkleDiffHandler).Methods("POST")
router.HandleFunc("/kv_range", s.getKVRangeHandler).Methods("POST") // New endpoint for fetching ranges return router } // Start the server func (s *Server) Start() error { router := s.setupRoutes() addr := fmt.Sprintf("%s:%d", s.config.BindAddress, s.config.Port) s.httpServer = &http.Server{ Addr: addr, Handler: router, } s.logger.WithFields(logrus.Fields{ "node_id": s.config.NodeID, "address": addr, }).Info("Starting KVS server") // Start gossip and sync routines s.startBackgroundTasks() // Try to join cluster if seed nodes are configured if len(s.config.SeedNodes) > 0 { go s.bootstrap() } return s.httpServer.ListenAndServe() } // Stop the server gracefully func (s *Server) Stop() error { s.logger.Info("Shutting down KVS server") s.cancel() s.wg.Wait() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := s.httpServer.Shutdown(ctx); err != nil { s.logger.WithError(err).Error("HTTP server shutdown error") } if err := s.db.Close(); err != nil { s.logger.WithError(err).Error("BadgerDB close error") } return nil } // Background tasks (gossip, sync, Merkle tree rebuild, etc.) func (s *Server) startBackgroundTasks() { // Start gossip routine s.wg.Add(1) go s.gossipRoutine() // Start sync routine (now Merkle-based) s.wg.Add(1) go s.syncRoutine() // Start Merkle tree rebuild routine s.wg.Add(1) go s.merkleTreeRebuildRoutine() } // Gossip routine - runs periodically to exchange member lists func (s *Server) gossipRoutine() { defer s.wg.Done() for { // Random interval between 1-2 minutes minInterval := time.Duration(s.config.GossipIntervalMin) * time.Second maxInterval := time.Duration(s.config.GossipIntervalMax) * time.Second interval := minInterval + time.Duration(rand.Int63n(int64(maxInterval-minInterval))) select { case <-s.ctx.Done(): return case <-time.After(interval): s.performGossipRound() } } } // Perform a gossip round with random healthy peers func (s *Server) performGossipRound() { members := s.getHealthyMembers() if len(members) == 0 { s.logger.Debug("No healthy members for gossip round") return } // Select 1-3 random peers for gossip maxPeers := 3 if len(members) < maxPeers { maxPeers = len(members) } // Shuffle and select rand.Shuffle(len(members), func(i, j int) { members[i], members[j] = members[j], members[i] }) selectedPeers := members[:rand.Intn(maxPeers)+1] for _, peer := range selectedPeers { go s.gossipWithPeer(peer) } } // Gossip with a specific peer func (s *Server) gossipWithPeer(peer *Member) { s.logger.WithField("peer", peer.Address).Debug("Starting gossip with peer") // Get our current member list localMembers := s.getMembers() // Send our member list to the peer gossipData := make([]Member, len(localMembers)) for i, member := range localMembers { gossipData[i] = *member } // Add ourselves to the list selfMember := Member{ ID: s.config.NodeID, Address: fmt.Sprintf("%s:%d", s.config.BindAddress, s.config.Port), LastSeen: time.Now().UnixMilli(), JoinedTimestamp: s.getJoinedTimestamp(), } gossipData = append(gossipData, selfMember) jsonData, err := json.Marshal(gossipData) if err != nil { s.logger.WithError(err).Error("Failed to marshal gossip data") return } // Send HTTP request to peer client := &http.Client{Timeout: 5 * time.Second} url := fmt.Sprintf("http://%s/members/gossip", peer.Address) resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonData)) if err != nil { s.logger.WithFields(logrus.Fields{ "peer": peer.Address, "error": err.Error(), }).Warn("Failed to gossip with peer") 
		s.markPeerUnhealthy(peer.ID)
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		s.logger.WithFields(logrus.Fields{
			"peer":   peer.Address,
			"status": resp.StatusCode,
		}).Warn("Gossip request failed")
		s.markPeerUnhealthy(peer.ID)
		return
	}

	// Process response - peer's member list
	var remoteMemberList []Member
	if err := json.NewDecoder(resp.Body).Decode(&remoteMemberList); err != nil {
		s.logger.WithError(err).Error("Failed to decode gossip response")
		return
	}

	// Merge remote member list with our local list
	s.mergeMemberList(remoteMemberList)

	// Update peer's last seen timestamp
	s.updateMemberLastSeen(peer.ID, time.Now().UnixMilli())

	s.logger.WithField("peer", peer.Address).Debug("Completed gossip with peer")
}

// Get healthy members (exclude those marked as down)
func (s *Server) getHealthyMembers() []*Member {
	s.membersMu.RLock()
	defer s.membersMu.RUnlock()

	now := time.Now().UnixMilli()
	healthyMembers := make([]*Member, 0)

	for _, member := range s.members {
		// Consider member healthy if last seen within last 5 minutes
		if now-member.LastSeen < 5*60*1000 {
			healthyMembers = append(healthyMembers, member)
		}
	}

	return healthyMembers
}

// Mark a peer as unhealthy
func (s *Server) markPeerUnhealthy(nodeID string) {
	s.membersMu.Lock()
	defer s.membersMu.Unlock()

	if member, exists := s.members[nodeID]; exists {
		// Mark as last seen a long time ago to indicate unhealthy
		member.LastSeen = time.Now().UnixMilli() - 10*60*1000 // 10 minutes ago
		s.logger.WithField("node_id", nodeID).Warn("Marked peer as unhealthy")
	}
}

// Update member's last seen timestamp
func (s *Server) updateMemberLastSeen(nodeID string, timestamp int64) {
	s.membersMu.Lock()
	defer s.membersMu.Unlock()

	if member, exists := s.members[nodeID]; exists {
		member.LastSeen = timestamp
	}
}

// Merge remote member list with local member list
func (s *Server) mergeMemberList(remoteMembers []Member) {
	s.membersMu.Lock()
	defer s.membersMu.Unlock()

	now := time.Now().UnixMilli()

	for _, remoteMember := range remoteMembers {
		// Skip ourselves
		if remoteMember.ID == s.config.NodeID {
			continue
		}

		if localMember, exists := s.members[remoteMember.ID]; exists {
			// Update existing member
			if remoteMember.LastSeen > localMember.LastSeen {
				localMember.LastSeen = remoteMember.LastSeen
			}
			// Keep the earlier joined timestamp
			if remoteMember.JoinedTimestamp < localMember.JoinedTimestamp {
				localMember.JoinedTimestamp = remoteMember.JoinedTimestamp
			}
		} else {
			// Add new member
			newMember := &Member{
				ID:              remoteMember.ID,
				Address:         remoteMember.Address,
				LastSeen:        remoteMember.LastSeen,
				JoinedTimestamp: remoteMember.JoinedTimestamp,
			}
			s.members[remoteMember.ID] = newMember

			s.logger.WithFields(logrus.Fields{
				"node_id": remoteMember.ID,
				"address": remoteMember.Address,
			}).Info("Discovered new member through gossip")
		}
	}

	// Clean up old members (not seen for more than 10 minutes)
	toRemove := make([]string, 0)
	for nodeID, member := range s.members {
		if now-member.LastSeen > 10*60*1000 { // 10 minutes
			toRemove = append(toRemove, nodeID)
		}
	}
	for _, nodeID := range toRemove {
		delete(s.members, nodeID)
		s.logger.WithField("node_id", nodeID).Info("Removed stale member")
	}
}

// Get this node's joined timestamp (startup time)
func (s *Server) getJoinedTimestamp() int64 {
	// For now, use a simple approach - this should be stored persistently
	return time.Now().UnixMilli()
}

// Sync routine - handles regular and catch-up syncing
func (s *Server) syncRoutine() {
	defer s.wg.Done()

	syncTicker := time.NewTicker(time.Duration(s.config.SyncInterval) * time.Second)
	defer syncTicker.Stop()

	for {
		select {
		case <-s.ctx.Done():
			return
		case <-syncTicker.C:
			s.performMerkleSync() // Use Merkle sync instead of regular sync
		}
	}
}

// Merkle Tree Implementation

// calculateHash generates a SHA256 hash for a given byte slice
func calculateHash(data []byte) []byte {
	h := sha256.New()
	h.Write(data)
	return h.Sum(nil)
}

// calculateLeafHash generates a hash for a leaf node based on its path, UUID, timestamp, and data
func (s *Server) calculateLeafHash(path string, storedValue *StoredValue) []byte {
	// Concatenate path, UUID, timestamp, and the raw data bytes for hashing.
	// Ensure a consistent order of fields for hashing.
	dataToHash := bytes.Buffer{}
	dataToHash.WriteString(path)
	dataToHash.WriteByte(':')
	dataToHash.WriteString(storedValue.UUID)
	dataToHash.WriteByte(':')
	dataToHash.WriteString(strconv.FormatInt(storedValue.Timestamp, 10))
	dataToHash.WriteByte(':')
	dataToHash.Write(storedValue.Data) // Use raw bytes of json.RawMessage
	return calculateHash(dataToHash.Bytes())
}

// getAllKVPairsForMerkleTree retrieves all key-value pairs needed for Merkle tree construction.
func (s *Server) getAllKVPairsForMerkleTree() (map[string]*StoredValue, error) {
	pairs := make(map[string]*StoredValue)
	err := s.db.View(func(txn *badger.Txn) error {
		opts := badger.DefaultIteratorOptions
		opts.PrefetchValues = true // We need the values for hashing
		it := txn.NewIterator(opts)
		defer it.Close()

		// Iterate over all actual data keys (not _ts: indexes)
		for it.Rewind(); it.Valid(); it.Next() {
			item := it.Item()
			key := string(item.Key())
			if strings.HasPrefix(key, "_ts:") {
				continue // Skip index keys
			}

			var storedValue StoredValue
			err := item.Value(func(val []byte) error {
				return json.Unmarshal(val, &storedValue)
			})
			if err != nil {
				s.logger.WithError(err).WithField("key", key).Warn("Failed to unmarshal stored value for Merkle tree, skipping")
				continue
			}
			pairs[key] = &storedValue
		}
		return nil
	})
	if err != nil {
		return nil, err
	}
	return pairs, nil
}

// buildMerkleTreeFromPairs constructs a Merkle Tree from the KVS data.
// This version uses a recursive approach to build a balanced tree from sorted keys.
func (s *Server) buildMerkleTreeFromPairs(pairs map[string]*StoredValue) (*MerkleNode, error) {
	if len(pairs) == 0 {
		return &MerkleNode{Hash: calculateHash([]byte("empty_tree")), StartKey: "", EndKey: ""}, nil
	}

	// Sort keys to ensure consistent tree structure
	keys := make([]string, 0, len(pairs))
	for k := range pairs {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	// Create leaf nodes
	leafNodes := make([]*MerkleNode, len(keys))
	for i, key := range keys {
		storedValue := pairs[key]
		hash := s.calculateLeafHash(key, storedValue)
		leafNodes[i] = &MerkleNode{Hash: hash, StartKey: key, EndKey: key}
	}

	// Recursively build parent nodes
	return s.buildMerkleTreeRecursive(leafNodes)
}

// buildMerkleTreeRecursive builds the tree from a slice of nodes.
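// For example (illustrative walk-through of the pairing below): three leaves
// [a, b, c] are first combined into [hash(a.Hash||b.Hash), c], with the unpaired
// trailing node promoted unchanged, and then into the root
// hash(hash(a.Hash||b.Hash)||c.Hash).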
func (s *Server) buildMerkleTreeRecursive(nodes []*MerkleNode) (*MerkleNode, error) {
	if len(nodes) == 0 {
		return nil, nil
	}
	if len(nodes) == 1 {
		return nodes[0], nil
	}

	var nextLevel []*MerkleNode
	for i := 0; i < len(nodes); i += 2 {
		left := nodes[i]
		var right *MerkleNode
		if i+1 < len(nodes) {
			right = nodes[i+1]
		}

		var combinedHash []byte
		var endKey string
		if right != nil {
			combinedHash = calculateHash(append(left.Hash, right.Hash...))
			endKey = right.EndKey
		} else {
			// Odd number of nodes, promote the left node
			combinedHash = left.Hash
			endKey = left.EndKey
		}

		parentNode := &MerkleNode{
			Hash:     combinedHash,
			StartKey: left.StartKey,
			EndKey:   endKey,
		}
		nextLevel = append(nextLevel, parentNode)
	}

	return s.buildMerkleTreeRecursive(nextLevel)
}

// getMerkleRoot returns the current Merkle root of the server.
func (s *Server) getMerkleRoot() *MerkleNode {
	s.merkleRootMu.RLock()
	defer s.merkleRootMu.RUnlock()
	return s.merkleRoot
}

// setMerkleRoot sets the current Merkle root of the server.
func (s *Server) setMerkleRoot(root *MerkleNode) {
	s.merkleRootMu.Lock()
	defer s.merkleRootMu.Unlock()
	s.merkleRoot = root
}

// merkleTreeRebuildRoutine periodically rebuilds the Merkle tree.
func (s *Server) merkleTreeRebuildRoutine() {
	defer s.wg.Done()

	ticker := time.NewTicker(time.Duration(s.config.SyncInterval) * time.Second) // Use sync interval for now
	defer ticker.Stop()

	for {
		select {
		case <-s.ctx.Done():
			return
		case <-ticker.C:
			s.logger.Debug("Rebuilding Merkle tree...")
			pairs, err := s.getAllKVPairsForMerkleTree()
			if err != nil {
				s.logger.WithError(err).Error("Failed to get KV pairs for Merkle tree rebuild")
				continue
			}
			newRoot, err := s.buildMerkleTreeFromPairs(pairs)
			if err != nil {
				s.logger.WithError(err).Error("Failed to rebuild Merkle tree")
				continue
			}
			s.setMerkleRoot(newRoot)
			s.logger.Debug("Merkle tree rebuilt.")
		}
	}
}

// getMerkleRootHandler returns the root hash of the local Merkle Tree
func (s *Server) getMerkleRootHandler(w http.ResponseWriter, r *http.Request) {
	root := s.getMerkleRoot()
	if root == nil {
		http.Error(w, "Merkle tree not initialized", http.StatusInternalServerError)
		return
	}

	resp := MerkleRootResponse{
		Root: root,
	}
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(resp)
}

// getMerkleDiffHandler is used by a peer to request children hashes for a given node/range.
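// If the requester's hash for the range matches the locally computed hash, the
// response is empty; otherwise it carries either the keys in the range (when the
// range holds at most diffLeafThreshold keys) or two child sub-ranges for the
// requester to recurse into.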
func (s *Server) getMerkleDiffHandler(w http.ResponseWriter, r *http.Request) {
	var req MerkleTreeDiffRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}

	localPairs, err := s.getAllKVPairsForMerkleTree()
	if err != nil {
		s.logger.WithError(err).Error("Failed to get KV pairs for Merkle diff")
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	// Build the local MerkleNode for the requested range to compare with the remote's hash
	localSubTreeRoot, err := s.buildMerkleTreeFromPairs(s.filterPairsByRange(localPairs, req.ParentNode.StartKey, req.ParentNode.EndKey))
	if err != nil {
		s.logger.WithError(err).Error("Failed to build sub-Merkle tree for diff request")
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}
	if localSubTreeRoot == nil {
		// This can happen if the range is empty locally
		localSubTreeRoot = &MerkleNode{Hash: calculateHash([]byte("empty_tree")), StartKey: req.ParentNode.StartKey, EndKey: req.ParentNode.EndKey}
	}

	resp := MerkleTreeDiffResponse{}

	// If hashes match, no need to send children or keys
	if bytes.Equal(req.LocalHash, localSubTreeRoot.Hash) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
		return
	}

	// Hashes differ, so we need to provide more detail.
	// Get all keys within the parent node's range locally
	var keysInRange []string
	for key := range s.filterPairsByRange(localPairs, req.ParentNode.StartKey, req.ParentNode.EndKey) {
		keysInRange = append(keysInRange, key)
	}
	sort.Strings(keysInRange)

	const diffLeafThreshold = 10 // If a range has <= 10 keys, we consider it a leaf-level diff
	if len(keysInRange) <= diffLeafThreshold {
		// This is a leaf-level diff, return the actual keys in the range
		resp.Keys = keysInRange
	} else {
		// Group keys into sub-ranges and return their MerkleNode representations.
		// For simplicity, split the range into two halves.
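		// For example, 25 divergent keys yield leftKeys = keysInRange[:12] and
		// rightKeys = keysInRange[12:].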
		mid := len(keysInRange) / 2
		leftKeys := keysInRange[:mid]
		rightKeys := keysInRange[mid:]

		if len(leftKeys) > 0 {
			leftRangePairs := s.filterPairsByRange(localPairs, leftKeys[0], leftKeys[len(leftKeys)-1])
			leftNode, err := s.buildMerkleTreeFromPairs(leftRangePairs)
			if err != nil {
				s.logger.WithError(err).Error("Failed to build left child node for diff")
				http.Error(w, "Internal Server Error", http.StatusInternalServerError)
				return
			}
			if leftNode != nil {
				resp.Children = append(resp.Children, *leftNode)
			}
		}

		if len(rightKeys) > 0 {
			rightRangePairs := s.filterPairsByRange(localPairs, rightKeys[0], rightKeys[len(rightKeys)-1])
			rightNode, err := s.buildMerkleTreeFromPairs(rightRangePairs)
			if err != nil {
				s.logger.WithError(err).Error("Failed to build right child node for diff")
				http.Error(w, "Internal Server Error", http.StatusInternalServerError)
				return
			}
			if rightNode != nil {
				resp.Children = append(resp.Children, *rightNode)
			}
		}
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(resp)
}

// Helper to filter a map of StoredValue by key range
func (s *Server) filterPairsByRange(allPairs map[string]*StoredValue, startKey, endKey string) map[string]*StoredValue {
	filtered := make(map[string]*StoredValue)
	for key, value := range allPairs {
		if (startKey == "" || key >= startKey) && (endKey == "" || key <= endKey) {
			filtered[key] = value
		}
	}
	return filtered
}

// getKVRangeHandler fetches a range of KV pairs
func (s *Server) getKVRangeHandler(w http.ResponseWriter, r *http.Request) {
	var req KVRangeRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}

	var pairs []struct {
		Path        string      `json:"path"`
		StoredValue StoredValue `json:"stored_value"`
	}

	err := s.db.View(func(txn *badger.Txn) error {
		opts := badger.DefaultIteratorOptions
		opts.PrefetchValues = true
		it := txn.NewIterator(opts)
		defer it.Close()

		count := 0
		// Start iteration from the requested StartKey
		for it.Seek([]byte(req.StartKey)); it.Valid(); it.Next() {
			item := it.Item()
			key := string(item.Key())

			if strings.HasPrefix(key, "_ts:") {
				continue // Skip index keys
			}

			// Stop if we exceed the EndKey (if provided)
			if req.EndKey != "" && key > req.EndKey {
				break
			}

			// Stop if we hit the limit (if provided)
			if req.Limit > 0 && count >= req.Limit {
				break
			}

			var storedValue StoredValue
			err := item.Value(func(val []byte) error {
				return json.Unmarshal(val, &storedValue)
			})
			if err != nil {
				s.logger.WithError(err).WithField("key", key).Warn("Failed to unmarshal stored value in KV range, skipping")
				continue
			}

			pairs = append(pairs, struct {
				Path        string      `json:"path"`
				StoredValue StoredValue `json:"stored_value"`
			}{Path: key, StoredValue: storedValue})
			count++
		}
		return nil
	})

	if err != nil {
		s.logger.WithError(err).Error("Failed to query KV range")
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(KVRangeResponse{Pairs: pairs})
}

// performMerkleSync performs a synchronization round using Merkle Trees
func (s *Server) performMerkleSync() {
	members := s.getHealthyMembers()
	if len(members) == 0 {
		s.logger.Debug("No healthy members for Merkle sync")
		return
	}

	// Select random peer
	peer := members[rand.Intn(len(members))]
	s.logger.WithField("peer", peer.Address).Info("Starting Merkle tree sync")

	localRoot := s.getMerkleRoot()
	if localRoot == nil {
		s.logger.Error("Local Merkle root is nil, cannot perform sync")
		return
	}

	// 1. Get remote peer's Merkle root
	remoteRootResp, err := s.requestMerkleRoot(peer.Address)
	if err != nil {
		s.logger.WithError(err).WithField("peer", peer.Address).Error("Failed to get remote Merkle root")
		s.markPeerUnhealthy(peer.ID)
		return
	}
	remoteRoot := remoteRootResp.Root

	// 2. Compare roots and start recursive diffing if they differ
	if !bytes.Equal(localRoot.Hash, remoteRoot.Hash) {
		s.logger.WithFields(logrus.Fields{
			"peer":        peer.Address,
			"local_root":  hex.EncodeToString(localRoot.Hash),
			"remote_root": hex.EncodeToString(remoteRoot.Hash),
		}).Info("Merkle roots differ, starting recursive diff")
		s.diffMerkleTreesRecursive(peer.Address, localRoot, remoteRoot)
	} else {
		s.logger.WithField("peer", peer.Address).Info("Merkle roots match, no sync needed")
	}

	s.logger.WithField("peer", peer.Address).Info("Completed Merkle tree sync")
}

// requestMerkleRoot requests the Merkle root from a peer
func (s *Server) requestMerkleRoot(peerAddress string) (*MerkleRootResponse, error) {
	client := &http.Client{Timeout: 10 * time.Second}
	url := fmt.Sprintf("http://%s/merkle_tree/root", peerAddress)

	resp, err := client.Get(url)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("peer returned status %d for Merkle root", resp.StatusCode)
	}

	var merkleRootResp MerkleRootResponse
	if err := json.NewDecoder(resp.Body).Decode(&merkleRootResp); err != nil {
		return nil, err
	}
	return &merkleRootResp, nil
}

// diffMerkleTreesRecursive recursively compares local and remote Merkle tree nodes
func (s *Server) diffMerkleTreesRecursive(peerAddress string, localNode, remoteNode *MerkleNode) {
	// If hashes match, this subtree is in sync.
	if bytes.Equal(localNode.Hash, remoteNode.Hash) {
		return
	}

	// Hashes differ, need to go deeper.
	// Request children from the remote peer for the current range.
	req := MerkleTreeDiffRequest{
		ParentNode: *remoteNode,    // We are asking the remote peer about its children for this range
		LocalHash:  localNode.Hash, // Our hash for this range
	}

	remoteDiffResp, err := s.requestMerkleDiff(peerAddress, req)
	if err != nil {
		s.logger.WithError(err).WithFields(logrus.Fields{
			"peer":      peerAddress,
			"start_key": localNode.StartKey,
			"end_key":   localNode.EndKey,
		}).Error("Failed to get Merkle diff from peer")
		return
	}

	if len(remoteDiffResp.Keys) > 0 {
		// This is a leaf-level diff, we have the actual keys that are different.
		// Fetch and compare each key.
		s.logger.WithFields(logrus.Fields{
			"peer":      peerAddress,
			"start_key": localNode.StartKey,
			"end_key":   localNode.EndKey,
			"num_keys":  len(remoteDiffResp.Keys),
		}).Info("Found divergent keys, fetching and comparing data")

		for _, key := range remoteDiffResp.Keys {
			// Fetch the individual key from the peer
			remoteStoredValue, err := s.fetchSingleKVFromPeer(peerAddress, key)
			if err != nil {
				s.logger.WithError(err).WithFields(logrus.Fields{
					"peer": peerAddress,
					"key":  key,
				}).Error("Failed to fetch single KV from peer during diff")
				continue
			}

			localStoredValue, localExists := s.getLocalData(key)

			if remoteStoredValue == nil {
				// Key was deleted on remote, delete locally if exists
				if localExists {
					s.logger.WithField("key", key).Info("Key deleted on remote, deleting locally")
					s.deleteKVLocally(key, localStoredValue.Timestamp)
				}
				continue
			}

			if !localExists {
				// Local data is missing, store the remote data
				if err := s.storeReplicatedDataWithMetadata(key, remoteStoredValue); err != nil {
					s.logger.WithError(err).WithField("key", key).Error("Failed to store missing replicated data")
				} else {
					s.logger.WithField("key", key).Info("Fetched and stored missing data from peer")
				}
			} else if localStoredValue.Timestamp < remoteStoredValue.Timestamp {
				// Remote is newer, store the remote data
				if err := s.storeReplicatedDataWithMetadata(key, remoteStoredValue); err != nil {
					s.logger.WithError(err).WithField("key", key).Error("Failed to store newer replicated data")
				} else {
					s.logger.WithField("key", key).Info("Fetched and stored newer data from peer")
				}
			} else if localStoredValue.Timestamp == remoteStoredValue.Timestamp && localStoredValue.UUID != remoteStoredValue.UUID {
				// Timestamp collision, engage conflict resolution
				remotePair := PairsByTimeResponse{ // Re-use this struct for conflict resolution
					Path:      key,
					UUID:      remoteStoredValue.UUID,
					Timestamp: remoteStoredValue.Timestamp,
				}
				resolved, err := s.resolveConflict(key, localStoredValue, &remotePair, peerAddress)
				if err != nil {
					s.logger.WithError(err).WithField("path", key).Error("Failed to resolve conflict during Merkle sync")
				} else if resolved {
					s.logger.WithField("path", key).Info("Conflict resolved, updated local data during Merkle sync")
				} else {
					s.logger.WithField("path", key).Info("Conflict resolved, kept local data during Merkle sync")
				}
			}
			// If local is newer, or same timestamp and same UUID, do nothing.
		}
	} else if len(remoteDiffResp.Children) > 0 {
		// Not a leaf level, continue recursive diff for children.
		localPairs, err := s.getAllKVPairsForMerkleTree()
		if err != nil {
			s.logger.WithError(err).Error("Failed to get KV pairs for local children comparison")
			return
		}

		for _, remoteChild := range remoteDiffResp.Children {
			// Build the local Merkle node for this child's range
			localChildNode, err := s.buildMerkleTreeFromPairs(s.filterPairsByRange(localPairs, remoteChild.StartKey, remoteChild.EndKey))
			if err != nil {
				s.logger.WithError(err).WithFields(logrus.Fields{
					"start_key": remoteChild.StartKey,
					"end_key":   remoteChild.EndKey,
				}).Error("Failed to build local child node for diff")
				continue
			}

			if localChildNode == nil || !bytes.Equal(localChildNode.Hash, remoteChild.Hash) {
				// If local child node is nil (meaning local has no data in this range)
				// or hashes differ, then we need to fetch the data.
				if localChildNode == nil {
					s.logger.WithFields(logrus.Fields{
						"peer":      peerAddress,
						"start_key": remoteChild.StartKey,
						"end_key":   remoteChild.EndKey,
					}).Info("Local node missing data in remote child's range, fetching full range")
					s.fetchAndStoreRange(peerAddress, remoteChild.StartKey, remoteChild.EndKey)
				} else {
					s.diffMerkleTreesRecursive(peerAddress, localChildNode, &remoteChild)
				}
			}
		}
	}
}

// requestMerkleDiff requests children hashes or keys for a given node/range from a peer
func (s *Server) requestMerkleDiff(peerAddress string, req MerkleTreeDiffRequest) (*MerkleTreeDiffResponse, error) {
	jsonData, err := json.Marshal(req)
	if err != nil {
		return nil, err
	}

	client := &http.Client{Timeout: 10 * time.Second}
	url := fmt.Sprintf("http://%s/merkle_tree/diff", peerAddress)

	resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("peer returned status %d for Merkle diff", resp.StatusCode)
	}

	var diffResp MerkleTreeDiffResponse
	if err := json.NewDecoder(resp.Body).Decode(&diffResp); err != nil {
		return nil, err
	}
	return &diffResp, nil
}

// fetchSingleKVFromPeer fetches a single KV pair from a peer
func (s *Server) fetchSingleKVFromPeer(peerAddress, path string) (*StoredValue, error) {
	client := &http.Client{Timeout: 5 * time.Second}
	url := fmt.Sprintf("http://%s/kv/%s", peerAddress, path)

	resp, err := client.Get(url)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusNotFound {
		return nil, nil // Key might have been deleted on the peer
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("peer returned status %d for path %s", resp.StatusCode, path)
	}

	var storedValue StoredValue
	if err := json.NewDecoder(resp.Body).Decode(&storedValue); err != nil {
		return nil, fmt.Errorf("failed to decode StoredValue from peer: %v", err)
	}
	return &storedValue, nil
}

// storeReplicatedDataWithMetadata stores replicated data preserving its original metadata
func (s *Server) storeReplicatedDataWithMetadata(path string, storedValue *StoredValue) error {
	valueBytes, err := json.Marshal(storedValue)
	if err != nil {
		return err
	}

	return s.db.Update(func(txn *badger.Txn) error {
		// Store main data
		if err := txn.Set([]byte(path), valueBytes); err != nil {
			return err
		}
		// Store timestamp index
		indexKey := fmt.Sprintf("_ts:%020d:%s", storedValue.Timestamp, path)
		return txn.Set([]byte(indexKey), []byte(storedValue.UUID))
	})
}

// deleteKVLocally deletes a key-value pair and its associated timestamp index locally.
func (s *Server) deleteKVLocally(path string, timestamp int64) error {
	return s.db.Update(func(txn *badger.Txn) error {
		if err := txn.Delete([]byte(path)); err != nil {
			return err
		}
		indexKey := fmt.Sprintf("_ts:%020d:%s", timestamp, path)
		return txn.Delete([]byte(indexKey))
	})
}

// fetchAndStoreRange fetches a range of KV pairs from a peer and stores them locally
func (s *Server) fetchAndStoreRange(peerAddress string, startKey, endKey string) error {
	req := KVRangeRequest{
		StartKey: startKey,
		EndKey:   endKey,
		Limit:    0, // No limit
	}
	jsonData, err := json.Marshal(req)
	if err != nil {
		return err
	}

	client := &http.Client{Timeout: 30 * time.Second} // Longer timeout for range fetches
	url := fmt.Sprintf("http://%s/kv_range", peerAddress)

	resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonData))
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("peer returned status %d for KV range fetch", resp.StatusCode)
	}

	var rangeResp KVRangeResponse
	if err := json.NewDecoder(resp.Body).Decode(&rangeResp); err != nil {
		return err
	}

	for _, pair := range rangeResp.Pairs {
		// Use storeReplicatedDataWithMetadata to preserve original UUID/Timestamp
		if err := s.storeReplicatedDataWithMetadata(pair.Path, &pair.StoredValue); err != nil {
			s.logger.WithError(err).WithFields(logrus.Fields{
				"peer": peerAddress,
				"path": pair.Path,
			}).Error("Failed to store fetched range data")
		} else {
			s.logger.WithFields(logrus.Fields{
				"peer": peerAddress,
				"path": pair.Path,
			}).Debug("Stored data from fetched range")
		}
	}
	return nil
}

// Bootstrap - join cluster using seed nodes
func (s *Server) bootstrap() {
	if len(s.config.SeedNodes) == 0 {
		s.logger.Info("No seed nodes configured, running as standalone")
		return
	}

	s.logger.Info("Starting bootstrap process")
	s.setMode("syncing")

	// Try to join cluster via each seed node
	joined := false
	for _, seedAddr := range s.config.SeedNodes {
		if s.attemptJoin(seedAddr) {
			joined = true
			break
		}
	}

	if !joined {
		s.logger.Warn("Failed to join cluster via seed nodes, running as standalone")
		s.setMode("normal")
		return
	}

	// Wait a bit for member discovery
	time.Sleep(2 * time.Second)

	// Perform gradual sync (now Merkle-based)
	s.performGradualSync()

	// Switch to normal mode
	s.setMode("normal")
	s.logger.Info("Bootstrap completed, entering normal mode")
}

// Attempt to join cluster via a seed node
func (s *Server) attemptJoin(seedAddr string) bool {
	joinReq := JoinRequest{
		ID:              s.config.NodeID,
		Address:         fmt.Sprintf("%s:%d", s.config.BindAddress, s.config.Port),
		JoinedTimestamp: time.Now().UnixMilli(),
	}

	jsonData, err := json.Marshal(joinReq)
	if err != nil {
		s.logger.WithError(err).Error("Failed to marshal join request")
		return false
	}

	client := &http.Client{Timeout: 10 * time.Second}
	url := fmt.Sprintf("http://%s/members/join", seedAddr)

	resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonData))
	if err != nil {
		s.logger.WithFields(logrus.Fields{
			"seed":  seedAddr,
			"error": err.Error(),
		}).Warn("Failed to contact seed node")
		return false
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		s.logger.WithFields(logrus.Fields{
			"seed":   seedAddr,
			"status": resp.StatusCode,
		}).Warn("Seed node rejected join request")
		return false
	}

	// Process member list response
	var memberList []Member
	if err := json.NewDecoder(resp.Body).Decode(&memberList); err != nil {
		s.logger.WithError(err).Error("Failed to decode member list from seed")
		return false
	}

	// Add all members to our local list
	for _, member := range memberList {
		if member.ID != s.config.NodeID {
			m := member // copy, so stored pointers do not alias the loop variable (pre-Go 1.22)
			s.addMember(&m)
		}
	}

	s.logger.WithFields(logrus.Fields{
		"seed":         seedAddr,
		"member_count": len(memberList),
	}).Info("Successfully joined cluster")
	return true
}

// Perform gradual sync (Merkle-based version)
func (s *Server) performGradualSync() {
	s.logger.Info("Starting gradual sync (Merkle-based)")

	members := s.getHealthyMembers()
	if len(members) == 0 {
		s.logger.Info("No healthy members for gradual sync")
		return
	}

	// For now, just do a few rounds of Merkle sync
	for i := 0; i < 3; i++ {
		s.performMerkleSync() // Use Merkle sync
		time.Sleep(time.Duration(s.config.ThrottleDelayMs) * time.Millisecond)
	}

	s.logger.Info("Gradual sync completed")
}

// Resolve conflict between local and remote data using majority vote and oldest node tie-breaker
func (s *Server) resolveConflict(path string, localData *StoredValue, remotePair *PairsByTimeResponse, peerAddress string) (bool, error) {
	s.logger.WithFields(logrus.Fields{
		"path":        path,
		"timestamp":   localData.Timestamp,
		"local_uuid":  localData.UUID,
		"remote_uuid": remotePair.UUID,
	}).Info("Starting conflict resolution with majority vote")

	// Get list of healthy members for voting
	members := s.getHealthyMembers()
	if len(members) == 0 {
		// No other members to consult, use oldest node rule (local vs remote).
		// We'll consider the peer as the "remote" node for comparison.
		return s.resolveByOldestNode(localData, remotePair, peerAddress)
	}

	// Query all healthy members for their version of this path
	votes := make(map[string]int) // UUID -> vote count
	uuidToTimestamp := make(map[string]int64)
	uuidToJoinedTime := make(map[string]int64)

	// Add our local vote
	votes[localData.UUID] = 1
	uuidToTimestamp[localData.UUID] = localData.Timestamp
	uuidToJoinedTime[localData.UUID] = s.getJoinedTimestamp()

	// Add the remote peer's vote.
	// Note: remotePair.Timestamp is used here, but for a full Merkle sync,
	// we would have already fetched the full remoteStoredValue.
	// For consistency, let's assume remotePair accurately reflects the remote's StoredValue metadata.
	votes[remotePair.UUID]++ // Increment vote, as it's already counted implicitly by being the source of divergence
	uuidToTimestamp[remotePair.UUID] = remotePair.Timestamp
	// We'll need to get the peer's joined timestamp
	s.membersMu.RLock()
	for _, member := range s.members {
		if member.Address == peerAddress {
			uuidToJoinedTime[remotePair.UUID] = member.JoinedTimestamp
			break
		}
	}
	s.membersMu.RUnlock()

	// Query other members
	for _, member := range members {
		if member.Address == peerAddress {
			continue // Already handled the peer
		}

		memberData, err := s.fetchSingleKVFromPeer(member.Address, path) // Use the new fetchSingleKVFromPeer
		if err != nil || memberData == nil {
			s.logger.WithFields(logrus.Fields{
				"member_address": member.Address,
				"path":           path,
				"error":          err,
			}).Debug("Failed to query member for data during conflict resolution, skipping vote")
			continue
		}

		// Only count votes for data with the same timestamp (as per original logic for collision)
		if memberData.Timestamp == localData.Timestamp {
			votes[memberData.UUID]++
			if _, exists := uuidToTimestamp[memberData.UUID]; !exists {
				uuidToTimestamp[memberData.UUID] = memberData.Timestamp
				uuidToJoinedTime[memberData.UUID] = member.JoinedTimestamp
			}
		}
	}

	// Find the UUID with majority votes
	maxVotes := 0
	var winningUUIDs []string
	for uuid, voteCount := range votes {
		if voteCount > maxVotes {
			maxVotes = voteCount
			winningUUIDs = []string{uuid}
		} else if voteCount == maxVotes {
			winningUUIDs = append(winningUUIDs, uuid)
		}
	}

	var winnerUUID string
	if len(winningUUIDs) == 1 {
		winnerUUID = winningUUIDs[0]
	} else {
		// Tie-breaker: oldest node (earliest joined timestamp)
		oldestJoinedTime := int64(0)
		for _, uuid := range winningUUIDs {
			joinedTime := uuidToJoinedTime[uuid]
			if oldestJoinedTime == 0 || (joinedTime != 0 && joinedTime < oldestJoinedTime) { // joinedTime can be 0 if not found
				oldestJoinedTime = joinedTime
				winnerUUID = uuid
			}
		}
		s.logger.WithFields(logrus.Fields{
			"path":          path,
			"tied_votes":    maxVotes,
			"winner_uuid":   winnerUUID,
			"oldest_joined": oldestJoinedTime,
		}).Info("Resolved conflict using oldest node tie-breaker")
	}

	// If remote UUID wins, fetch and store the remote data
	if winnerUUID == remotePair.UUID {
		// We need the full StoredValue for the winning remote data.
		// Since remotePair only has UUID/Timestamp, we must fetch the data.
		winningRemoteStoredValue, err := s.fetchSingleKVFromPeer(peerAddress, path)
		if err != nil || winningRemoteStoredValue == nil {
			return false, fmt.Errorf("failed to fetch winning remote data for conflict resolution: %v", err)
		}

		err = s.storeReplicatedDataWithMetadata(path, winningRemoteStoredValue)
		if err != nil {
			return false, fmt.Errorf("failed to store winning data: %v", err)
		}

		s.logger.WithFields(logrus.Fields{
			"path":         path,
			"winner_uuid":  winnerUUID,
			"winner_votes": maxVotes,
			"total_nodes":  len(members) + 2, // +2 for local and peer
		}).Info("Conflict resolved: remote data wins")
		return true, nil
	}

	// Local data wins, no action needed
	s.logger.WithFields(logrus.Fields{
		"path":         path,
		"winner_uuid":  winnerUUID,
		"winner_votes": maxVotes,
		"total_nodes":  len(members) + 2,
	}).Info("Conflict resolved: local data wins")
	return false, nil
}

// Resolve conflict using oldest node rule when no other members available
func (s *Server) resolveByOldestNode(localData *StoredValue, remotePair *PairsByTimeResponse, peerAddress string) (bool, error) {
	// Find the peer's joined timestamp
	peerJoinedTime := int64(0)
	s.membersMu.RLock()
	for _, member := range s.members {
		if member.Address == peerAddress {
			peerJoinedTime = member.JoinedTimestamp
			break
		}
	}
	s.membersMu.RUnlock()

	localJoinedTime := s.getJoinedTimestamp()

	// Oldest node wins
	if peerJoinedTime > 0 && peerJoinedTime < localJoinedTime {
		// Peer is older, fetch remote data
		winningRemoteStoredValue, err := s.fetchSingleKVFromPeer(peerAddress, remotePair.Path)
		if err != nil || winningRemoteStoredValue == nil {
			return false, fmt.Errorf("failed to fetch data from older node for conflict resolution: %v", err)
		}

		err = s.storeReplicatedDataWithMetadata(remotePair.Path, winningRemoteStoredValue)
		if err != nil {
			return false, fmt.Errorf("failed to store data from older node: %v", err)
		}

		s.logger.WithFields(logrus.Fields{
			"path":         remotePair.Path,
			"local_joined": localJoinedTime,
			"peer_joined":  peerJoinedTime,
			"winner":       "remote",
		}).Info("Conflict resolved using oldest node rule")
		return true, nil
	}

	// Local node is older or equal, keep local data
	s.logger.WithFields(logrus.Fields{
		"path":         remotePair.Path,
		"local_joined": localJoinedTime,
		"peer_joined":  peerJoinedTime,
		"winner":       "local",
	}).Info("Conflict resolved using oldest node rule")
	return false, nil
}

// getLocalData is a utility to retrieve a StoredValue from local DB.
func (s *Server) getLocalData(path string) (*StoredValue, bool) {
	var storedValue StoredValue
	err := s.db.View(func(txn *badger.Txn) error {
		item, err := txn.Get([]byte(path))
		if err != nil {
			return err
		}
		return item.Value(func(val []byte) error {
			return json.Unmarshal(val, &storedValue)
		})
	})
	if err != nil {
		return nil, false
	}
	return &storedValue, true
}

func main() {
	configPath := "./config.yaml"

	// Simple CLI argument parsing
	if len(os.Args) > 1 {
		configPath = os.Args[1]
	}

	config, err := loadConfig(configPath)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to load configuration: %v\n", err)
		os.Exit(1)
	}

	server, err := NewServer(config)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to create server: %v\n", err)
		os.Exit(1)
	}

	// Handle graceful shutdown
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	go func() {
		<-sigCh
		server.Stop()
	}()

	if err := server.Start(); err != nil && err != http.ErrServerClosed {
		fmt.Fprintf(os.Stderr, "Server error: %v\n", err)
		os.Exit(1)
	}
}
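// Below is an illustrative config.yaml for this server (not taken from the
// repository): field names follow the yaml tags on Config, the numeric values
// are the defaults from defaultConfig(), node_id defaults to the host name, and
// the seed_nodes entries are placeholder peers.
//
//	node_id: "node-1"
//	bind_address: "127.0.0.1"
//	port: 8080
//	data_dir: "./data"
//	seed_nodes: ["127.0.0.1:8081", "127.0.0.1:8082"]
//	read_only: false
//	log_level: "info"
//	gossip_interval_min: 60
//	gossip_interval_max: 120
//	sync_interval: 300
//	catchup_interval: 120
//	bootstrap_max_age_hours: 720
//	throttle_delay_ms: 100
//	fetch_delay_ms: 50
//
// Example client calls against a running node (hypothetical paths and payloads):
//
//	curl -X PUT -d '{"name":"alice"}' http://127.0.0.1:8080/kv/users/alice
//	curl http://127.0.0.1:8080/kv/users/alice
//	curl -X DELETE http://127.0.0.1:8080/kv/users/alice
//	curl http://127.0.0.1:8080/health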