package cluster import ( "bytes" "context" "encoding/hex" "encoding/json" "fmt" "math/rand" "net/http" "sync" "time" badger "github.com/dgraph-io/badger/v4" "github.com/sirupsen/logrus" "kvs/types" ) // SyncService handles data synchronization between cluster nodes type SyncService struct { db *badger.DB config *types.Config gossipService *GossipService merkleService *MerkleService logger *logrus.Logger merkleRoot *types.MerkleNode merkleRootMu sync.RWMutex ctx context.Context cancel context.CancelFunc wg sync.WaitGroup } // NewSyncService creates a new sync service func NewSyncService(db *badger.DB, config *types.Config, gossipService *GossipService, merkleService *MerkleService, logger *logrus.Logger) *SyncService { ctx, cancel := context.WithCancel(context.Background()) return &SyncService{ db: db, config: config, gossipService: gossipService, merkleService: merkleService, logger: logger, ctx: ctx, cancel: cancel, } } // Start begins the sync routines func (s *SyncService) Start() { if !s.config.ClusteringEnabled { s.logger.Info("Clustering disabled, skipping sync routines") return } // Start sync routine s.wg.Add(1) go s.syncRoutine() // Start Merkle tree rebuild routine s.wg.Add(1) go s.merkleTreeRebuildRoutine() } // Stop terminates the sync service func (s *SyncService) Stop() { s.cancel() s.wg.Wait() } // GetMerkleRoot returns the current Merkle root func (s *SyncService) GetMerkleRoot() *types.MerkleNode { s.merkleRootMu.RLock() defer s.merkleRootMu.RUnlock() return s.merkleRoot } // SetMerkleRoot sets the current Merkle root func (s *SyncService) SetMerkleRoot(root *types.MerkleNode) { s.merkleRootMu.Lock() defer s.merkleRootMu.Unlock() s.merkleRoot = root } // syncRoutine handles regular and catch-up syncing func (s *SyncService) syncRoutine() { defer s.wg.Done() syncTicker := time.NewTicker(time.Duration(s.config.SyncInterval) * time.Second) defer syncTicker.Stop() for { select { case <-s.ctx.Done(): return case <-syncTicker.C: s.performMerkleSync() } } } // merkleTreeRebuildRoutine periodically rebuilds the Merkle tree func (s *SyncService) merkleTreeRebuildRoutine() { defer s.wg.Done() ticker := time.NewTicker(time.Duration(s.config.SyncInterval) * time.Second) defer ticker.Stop() for { select { case <-s.ctx.Done(): return case <-ticker.C: s.logger.Debug("Rebuilding Merkle tree...") pairs, err := s.merkleService.GetAllKVPairsForMerkleTree() if err != nil { s.logger.WithError(err).Error("Failed to get KV pairs for Merkle tree rebuild") continue } newRoot, err := s.merkleService.BuildMerkleTreeFromPairs(pairs) if err != nil { s.logger.WithError(err).Error("Failed to rebuild Merkle tree") continue } s.SetMerkleRoot(newRoot) s.logger.Debug("Merkle tree rebuilt.") } } } // InitializeMerkleTree builds the initial Merkle tree func (s *SyncService) InitializeMerkleTree() error { pairs, err := s.merkleService.GetAllKVPairsForMerkleTree() if err != nil { return fmt.Errorf("failed to get all KV pairs for initial Merkle tree: %v", err) } root, err := s.merkleService.BuildMerkleTreeFromPairs(pairs) if err != nil { return fmt.Errorf("failed to build initial Merkle tree: %v", err) } s.SetMerkleRoot(root) s.logger.Info("Initial Merkle tree built.") return nil } // performMerkleSync performs a synchronization round using Merkle Trees func (s *SyncService) performMerkleSync() { members := s.gossipService.GetHealthyMembers() if len(members) == 0 { s.logger.Debug("No healthy members for Merkle sync") return } // Select random peer peer := members[rand.Intn(len(members))] s.logger.WithField("peer", peer.Address).Info("Starting Merkle tree sync") localRoot := s.GetMerkleRoot() if localRoot == nil { s.logger.Error("Local Merkle root is nil, cannot perform sync") return } // 1. Get remote peer's Merkle root remoteRootResp, err := s.requestMerkleRoot(peer.Address) if err != nil { s.logger.WithError(err).WithField("peer", peer.Address).Error("Failed to get remote Merkle root") s.gossipService.markPeerUnhealthy(peer.ID) return } remoteRoot := remoteRootResp.Root // 2. Compare roots and start recursive diffing if they differ if !bytes.Equal(localRoot.Hash, remoteRoot.Hash) { s.logger.WithFields(logrus.Fields{ "peer": peer.Address, "local_root": hex.EncodeToString(localRoot.Hash), "remote_root": hex.EncodeToString(remoteRoot.Hash), }).Info("Merkle roots differ, starting recursive diff") s.diffMerkleTreesRecursive(peer.Address, localRoot, remoteRoot) } else { s.logger.WithField("peer", peer.Address).Info("Merkle roots match, no sync needed") } s.logger.WithField("peer", peer.Address).Info("Completed Merkle tree sync") } // requestMerkleRoot requests the Merkle root from a peer func (s *SyncService) requestMerkleRoot(peerAddress string) (*types.MerkleRootResponse, error) { client := &http.Client{Timeout: 10 * time.Second} url := fmt.Sprintf("http://%s/merkle_tree/root", peerAddress) resp, err := client.Get(url) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("peer returned status %d for Merkle root", resp.StatusCode) } var merkleRootResp types.MerkleRootResponse if err := json.NewDecoder(resp.Body).Decode(&merkleRootResp); err != nil { return nil, err } return &merkleRootResp, nil } // diffMerkleTreesRecursive recursively compares local and remote Merkle tree nodes func (s *SyncService) diffMerkleTreesRecursive(peerAddress string, localNode, remoteNode *types.MerkleNode) { // If hashes match, this subtree is in sync. if bytes.Equal(localNode.Hash, remoteNode.Hash) { return } // Hashes differ, need to go deeper. // Request children from the remote peer for the current range. req := types.MerkleTreeDiffRequest{ ParentNode: *remoteNode, // We are asking the remote peer about its children for this range LocalHash: localNode.Hash, // Our hash for this range } remoteDiffResp, err := s.requestMerkleDiff(peerAddress, req) if err != nil { s.logger.WithError(err).WithFields(logrus.Fields{ "peer": peerAddress, "start_key": localNode.StartKey, "end_key": localNode.EndKey, }).Error("Failed to get Merkle diff from peer") return } if len(remoteDiffResp.Keys) > 0 { // This is a leaf-level diff, we have the actual keys that are different. s.handleLeafLevelDiff(peerAddress, remoteDiffResp.Keys, localNode) } else if len(remoteDiffResp.Children) > 0 { // Not a leaf level, continue recursive diff for children. s.handleChildrenDiff(peerAddress, remoteDiffResp.Children) } } // handleLeafLevelDiff processes leaf-level differences func (s *SyncService) handleLeafLevelDiff(peerAddress string, keys []string, localNode *types.MerkleNode) { s.logger.WithFields(logrus.Fields{ "peer": peerAddress, "start_key": localNode.StartKey, "end_key": localNode.EndKey, "num_keys": len(keys), }).Info("Found divergent keys, fetching and comparing data") for _, key := range keys { // Fetch the individual key from the peer remoteStoredValue, err := s.fetchSingleKVFromPeer(peerAddress, key) if err != nil { s.logger.WithError(err).WithFields(logrus.Fields{ "peer": peerAddress, "key": key, }).Error("Failed to fetch single KV from peer during diff") continue } localStoredValue, localExists := s.getLocalData(key) if remoteStoredValue == nil { // Key was deleted on remote, delete locally if exists if localExists { s.logger.WithField("key", key).Info("Key deleted on remote, deleting locally") s.deleteKVLocally(key, localStoredValue.Timestamp) } continue } if !localExists { // Local data is missing, store the remote data if err := s.storeReplicatedDataWithMetadata(key, remoteStoredValue); err != nil { s.logger.WithError(err).WithField("key", key).Error("Failed to store missing replicated data") } else { s.logger.WithField("key", key).Info("Fetched and stored missing data from peer") } } else if localStoredValue.Timestamp < remoteStoredValue.Timestamp { // Remote is newer, store the remote data if err := s.storeReplicatedDataWithMetadata(key, remoteStoredValue); err != nil { s.logger.WithError(err).WithField("key", key).Error("Failed to store newer replicated data") } else { s.logger.WithField("key", key).Info("Fetched and stored newer data from peer") } } else if localStoredValue.Timestamp == remoteStoredValue.Timestamp && localStoredValue.UUID != remoteStoredValue.UUID { // Timestamp collision, engage conflict resolution s.resolveConflict(key, localStoredValue, remoteStoredValue, peerAddress) } // If local is newer or same timestamp and same UUID, do nothing. } } // fetchSingleKVFromPeer fetches a single KV pair from a peer func (s *SyncService) fetchSingleKVFromPeer(peerAddress, path string) (*types.StoredValue, error) { client := &http.Client{Timeout: 5 * time.Second} url := fmt.Sprintf("http://%s/kv/%s", peerAddress, path) resp, err := client.Get(url) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode == http.StatusNotFound { return nil, nil // Key might have been deleted on the peer } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("peer returned status %d for path %s", resp.StatusCode, path) } var storedValue types.StoredValue if err := json.NewDecoder(resp.Body).Decode(&storedValue); err != nil { return nil, fmt.Errorf("failed to decode types.StoredValue from peer: %v", err) } return &storedValue, nil } // getLocalData is a utility to retrieve a types.StoredValue from local DB func (s *SyncService) getLocalData(path string) (*types.StoredValue, bool) { var storedValue types.StoredValue err := s.db.View(func(txn *badger.Txn) error { item, err := txn.Get([]byte(path)) if err != nil { return err } return item.Value(func(val []byte) error { return json.Unmarshal(val, &storedValue) }) }) if err != nil { return nil, false } return &storedValue, true } // deleteKVLocally deletes a key-value pair and its associated timestamp index locally func (s *SyncService) deleteKVLocally(key string, timestamp int64) error { return s.db.Update(func(txn *badger.Txn) error { // Delete the main key if err := txn.Delete([]byte(key)); err != nil { return err } // Delete the timestamp index indexKey := fmt.Sprintf("_ts:%d:%s", timestamp, key) return txn.Delete([]byte(indexKey)) }) } // storeReplicatedDataWithMetadata stores replicated data preserving its original metadata func (s *SyncService) storeReplicatedDataWithMetadata(path string, storedValue *types.StoredValue) error { valueBytes, err := json.Marshal(storedValue) if err != nil { return err } return s.db.Update(func(txn *badger.Txn) error { // Store main data if err := txn.Set([]byte(path), valueBytes); err != nil { return err } // Store timestamp index indexKey := fmt.Sprintf("_ts:%020d:%s", storedValue.Timestamp, path) return txn.Set([]byte(indexKey), []byte(storedValue.UUID)) }) } // resolveConflict performs sophisticated conflict resolution with majority vote and oldest-node tie-breaking func (s *SyncService) resolveConflict(key string, local, remote *types.StoredValue, peerAddress string) error { s.logger.WithFields(logrus.Fields{ "key": key, "local_ts": local.Timestamp, "remote_ts": remote.Timestamp, "local_uuid": local.UUID, "remote_uuid": remote.UUID, "peer": peerAddress, }).Info("Resolving timestamp collision conflict") if remote.Timestamp > local.Timestamp { // Remote is newer, store it err := s.storeReplicatedDataWithMetadata(key, remote) if err == nil { s.logger.WithField("key", key).Info("Conflict resolved: remote data wins (newer timestamp)") } return err } else if local.Timestamp > remote.Timestamp { // Local is newer, keep local data s.logger.WithField("key", key).Info("Conflict resolved: local data wins (newer timestamp)") return nil } // Timestamps are equal - need sophisticated conflict resolution s.logger.WithField("key", key).Info("Timestamp collision detected, applying oldest-node rule") // Get cluster members to determine which node is older members := s.gossipService.GetMembers() // Find the local node and the remote node in membership var localMember, remoteMember *types.Member localNodeID := s.config.NodeID for _, member := range members { if member.ID == localNodeID { localMember = member } if member.Address == peerAddress { remoteMember = member } } // If we can't find membership info, fall back to UUID comparison for deterministic result if localMember == nil || remoteMember == nil { s.logger.WithField("key", key).Warn("Could not find membership info for conflict resolution, using UUID comparison") if remote.UUID < local.UUID { // Remote UUID lexically smaller (deterministic choice) err := s.storeReplicatedDataWithMetadata(key, remote) if err == nil { s.logger.WithField("key", key).Info("Conflict resolved: remote data wins (UUID tie-breaker)") } return err } s.logger.WithField("key", key).Info("Conflict resolved: local data wins (UUID tie-breaker)") return nil } // Apply oldest-node rule: node with earliest joined_timestamp wins if remoteMember.JoinedTimestamp < localMember.JoinedTimestamp { // Remote node is older, its data wins err := s.storeReplicatedDataWithMetadata(key, remote) if err == nil { s.logger.WithFields(logrus.Fields{ "key": key, "local_joined": localMember.JoinedTimestamp, "remote_joined": remoteMember.JoinedTimestamp, }).Info("Conflict resolved: remote data wins (oldest-node rule)") } return err } // Local node is older or equal, keep local data s.logger.WithFields(logrus.Fields{ "key": key, "local_joined": localMember.JoinedTimestamp, "remote_joined": remoteMember.JoinedTimestamp, }).Info("Conflict resolved: local data wins (oldest-node rule)") return nil } // requestMerkleDiff requests children hashes or keys for a given node/range from a peer func (s *SyncService) requestMerkleDiff(peerAddress string, req types.MerkleTreeDiffRequest) (*types.MerkleTreeDiffResponse, error) { jsonData, err := json.Marshal(req) if err != nil { return nil, err } client := &http.Client{Timeout: 10 * time.Second} url := fmt.Sprintf("http://%s/merkle_tree/diff", peerAddress) resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonData)) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("peer returned status %d for Merkle diff", resp.StatusCode) } var diffResp types.MerkleTreeDiffResponse if err := json.NewDecoder(resp.Body).Decode(&diffResp); err != nil { return nil, err } return &diffResp, nil } // handleChildrenDiff processes children-level differences func (s *SyncService) handleChildrenDiff(peerAddress string, children []types.MerkleNode) { localPairs, err := s.merkleService.GetAllKVPairsForMerkleTree() if err != nil { s.logger.WithError(err).Error("Failed to get KV pairs for local children comparison") return } for _, remoteChild := range children { // Build the local Merkle node for this child's range localChildNode, err := s.merkleService.BuildMerkleTreeFromPairs(FilterPairsByRange(localPairs, remoteChild.StartKey, remoteChild.EndKey)) if err != nil { s.logger.WithError(err).WithFields(logrus.Fields{ "start_key": remoteChild.StartKey, "end_key": remoteChild.EndKey, }).Error("Failed to build local child node for diff") continue } if localChildNode == nil || !bytes.Equal(localChildNode.Hash, remoteChild.Hash) { // If local child node is nil (meaning local has no data in this range) // or hashes differ, then we need to fetch the data. if localChildNode == nil { s.logger.WithFields(logrus.Fields{ "peer": peerAddress, "start_key": remoteChild.StartKey, "end_key": remoteChild.EndKey, }).Info("Local node missing data in remote child's range, fetching full range") s.fetchAndStoreRange(peerAddress, remoteChild.StartKey, remoteChild.EndKey) } else { s.diffMerkleTreesRecursive(peerAddress, localChildNode, &remoteChild) } } } } // fetchAndStoreRange fetches a range of KV pairs from a peer and stores them locally func (s *SyncService) fetchAndStoreRange(peerAddress string, startKey, endKey string) error { req := types.KVRangeRequest{ StartKey: startKey, EndKey: endKey, Limit: 0, // No limit } jsonData, err := json.Marshal(req) if err != nil { return err } client := &http.Client{Timeout: 30 * time.Second} // Longer timeout for range fetches url := fmt.Sprintf("http://%s/kv_range", peerAddress) resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonData)) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return fmt.Errorf("peer returned status %d for KV range fetch", resp.StatusCode) } var rangeResp types.KVRangeResponse if err := json.NewDecoder(resp.Body).Decode(&rangeResp); err != nil { return err } for _, pair := range rangeResp.Pairs { // Use storeReplicatedDataWithMetadata to preserve original UUID/Timestamp if err := s.storeReplicatedDataWithMetadata(pair.Path, &pair.StoredValue); err != nil { s.logger.WithError(err).WithFields(logrus.Fields{ "peer": peerAddress, "path": pair.Path, }).Error("Failed to store fetched range data") } else { s.logger.WithFields(logrus.Fields{ "peer": peerAddress, "path": pair.Path, }).Debug("Stored data from fetched range") } } return nil }