package cluster

import (
	"bytes"
	"crypto/sha256"
	"encoding/json"
	"fmt"
	"sort"
	"strconv"
	"strings"

	badger "github.com/dgraph-io/badger/v4"
	"github.com/sirupsen/logrus"

	"kvs/types"
)

// MerkleService builds Merkle trees over the key-value pairs stored in a
// Badger database.
type MerkleService struct {
	db     *badger.DB     // source of the key-value pairs to hash
	logger *logrus.Logger // receives warnings about entries that cannot be decoded
}

// NewMerkleService returns a MerkleService that reads from db and logs
// through logger.
func NewMerkleService(db *badger.DB, logger *logrus.Logger) *MerkleService {
	return &MerkleService{
		db:     db,
		logger: logger,
	}
}

// CalculateHash returns the SHA-256 digest of data as a freshly allocated
// 32-byte slice.
func CalculateHash(data []byte) []byte {
	sum := sha256.Sum256(data)
	return sum[:]
}

// CalculateLeafHash returns the hash of a leaf node. The hashed input is
// "<path>:<UUID>:<Timestamp in base 10>:<Data>", always in that field order.
//
// storedValue must not be nil; the method dereferences it and will panic
// otherwise.
func (s *MerkleService) CalculateLeafHash(path string, storedValue *types.StoredValue) []byte {
	// The byte layout below defines the leaf hash; changing the order or the
	// separators changes every hash produced by this service.
	var buf bytes.Buffer
	buf.WriteString(path)
	buf.WriteByte(':')
	buf.WriteString(storedValue.UUID)
	buf.WriteByte(':')
	buf.WriteString(strconv.FormatInt(storedValue.Timestamp, 10))
	buf.WriteByte(':')
	buf.Write(storedValue.Data) // raw bytes, no re-encoding

	return CalculateHash(buf.Bytes())
}

// GetAllKVPairsForMerkleTree reads every entry in the database, skipping keys
// that start with "_ts:" (index keys), and returns the decoded values keyed
// by their database key.
//
// Entries whose value cannot be read or JSON-decoded are logged at warning
// level and left out of the result rather than failing the whole scan. A
// non-nil error is returned only when the read transaction itself fails.
func (s *MerkleService) GetAllKVPairsForMerkleTree() (map[string]*types.StoredValue, error) {
	pairs := make(map[string]*types.StoredValue)

	err := s.db.View(func(txn *badger.Txn) error {
		opts := badger.DefaultIteratorOptions
		opts.PrefetchValues = true // every value is needed for hashing

		it := txn.NewIterator(opts)
		defer it.Close()

		for it.Rewind(); it.Valid(); it.Next() {
			item := it.Item()
			key := string(item.Key())
			if strings.HasPrefix(key, "_ts:") {
				continue // index key, not actual data
			}

			// A fresh value per iteration so that each map entry points at
			// its own StoredValue.
			storedValue := new(types.StoredValue)
			if err := item.Value(func(val []byte) error {
				return json.Unmarshal(val, storedValue)
			}); err != nil {
				s.logger.WithError(err).WithField("key", key).Warn("Failed to unmarshal stored value for Merkle tree, skipping")
				continue
			}
			pairs[key] = storedValue
		}
		return nil
	})
	if err != nil {
		return nil, err
	}
	return pairs, nil
}

// BuildMerkleTreeFromPairs builds a Merkle tree from pairs and returns its
// root node.
//
// Keys are sorted lexicographically first so that the same set of pairs
// always yields the same tree. An empty (or nil) map yields a single node
// whose hash is the hash of the literal "empty_tree" and whose StartKey and
// EndKey are empty. A nil value in the map is reported as an error.
func (s *MerkleService) BuildMerkleTreeFromPairs(pairs map[string]*types.StoredValue) (*types.MerkleNode, error) {
	if len(pairs) == 0 {
		return &types.MerkleNode{Hash: CalculateHash([]byte("empty_tree")), StartKey: "", EndKey: ""}, nil
	}

	keys := make([]string, 0, len(pairs))
	for k := range pairs {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	// One leaf per key; each leaf covers exactly its own key.
	leafNodes := make([]*types.MerkleNode, len(keys))
	for i, key := range keys {
		storedValue := pairs[key]
		if storedValue == nil {
			// CalculateLeafHash would dereference the nil pointer and panic.
			return nil, fmt.Errorf("nil stored value for key %q", key)
		}
		leafNodes[i] = &types.MerkleNode{
			Hash:     s.CalculateLeafHash(key, storedValue),
			StartKey: key,
			EndKey:   key,
		}
	}

	return s.buildMerkleTreeRecursive(leafNodes)
}

// buildMerkleTreeRecursive reduces nodes to a single root by pairing
// neighbours level by level.
//
// Each parent spans from its left child's StartKey to its right child's
// EndKey, and its hash is the hash of the two child hashes concatenated. When
// a level has an odd number of nodes, the last node is carried up with its
// hash unchanged. An empty slice yields (nil, nil); a single node is returned
// as is.
func (s *MerkleService) buildMerkleTreeRecursive(nodes []*types.MerkleNode) (*types.MerkleNode, error) {
	switch len(nodes) {
	case 0:
		return nil, nil
	case 1:
		return nodes[0], nil
	}

	nextLevel := make([]*types.MerkleNode, 0, (len(nodes)+1)/2)
	for i := 0; i < len(nodes); i += 2 {
		left := nodes[i]
		parent := &types.MerkleNode{
			StartKey: left.StartKey,
			EndKey:   left.EndKey,
			Hash:     left.Hash, // kept as is when there is no right sibling
		}

		if i+1 < len(nodes) {
			right := nodes[i+1]
			// Concatenate into a fresh buffer: appending directly to
			// left.Hash could write into its backing array if it has spare
			// capacity, corrupting a hash that is still referenced.
			combined := make([]byte, 0, len(left.Hash)+len(right.Hash))
			combined = append(combined, left.Hash...)
			combined = append(combined, right.Hash...)
			parent.Hash = CalculateHash(combined)
			parent.EndKey = right.EndKey
		}

		nextLevel = append(nextLevel, parent)
	}

	return s.buildMerkleTreeRecursive(nextLevel)
}

// FilterPairsByRange returns a new map holding the entries of allPairs whose
// key lies in [startKey, endKey], compared as strings. Both bounds are
// inclusive; an empty startKey or endKey leaves that side unbounded. The
// values are shared with allPairs, not copied.
func FilterPairsByRange(allPairs map[string]*types.StoredValue, startKey, endKey string) map[string]*types.StoredValue {
	filtered := make(map[string]*types.StoredValue)
	for key, value := range allPairs {
		if startKey != "" && key < startKey {
			continue
		}
		if endKey != "" && key > endKey {
			continue
		}
		filtered[key] = value
	}
	return filtered
}

// BuildSubtreeForRange builds a Merkle tree over the stored pairs whose keys
// fall within [startKey, endKey] (see FilterPairsByRange for the bound
// semantics) and returns its root.
//
// It loads all pairs from the database on every call and filters them in
// memory. The underlying error is wrapped, so callers can inspect it with
// errors.Is and errors.As.
func (s *MerkleService) BuildSubtreeForRange(startKey, endKey string) (*types.MerkleNode, error) {
	pairs, err := s.GetAllKVPairsForMerkleTree()
	if err != nil {
		return nil, fmt.Errorf("failed to get KV pairs for subtree: %w", err)
	}

	return s.BuildMerkleTreeFromPairs(FilterPairsByRange(pairs, startKey, endKey))
}