// kalzu-value-store/main.go

package main
import (
"bytes"
"context"
"crypto/sha256" // Added for Merkle Tree hashing
"encoding/hex" // Added for logging hashes
"encoding/json"
"fmt"
"math/rand"
"net"
"net/http"
"os"
"os/signal"
"path/filepath"
"sort" // Added for sorting keys in Merkle Tree
"strconv"
"strings"
"sync"
"syscall"
"time"
badger "github.com/dgraph-io/badger/v4"
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v3"
)
// Core data structures
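// StoredValue wraps every value written to the store: UUID identifies the individual write,
// Timestamp records the write time in Unix milliseconds, and Data carries the client's raw JSON payload.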
type StoredValue struct {
UUID string `json:"uuid"`
Timestamp int64 `json:"timestamp"`
Data json.RawMessage `json:"data"`
}
type Member struct {
ID string `json:"id"`
Address string `json:"address"`
LastSeen int64 `json:"last_seen"`
JoinedTimestamp int64 `json:"joined_timestamp"`
}
type JoinRequest struct {
ID string `json:"id"`
Address string `json:"address"`
JoinedTimestamp int64 `json:"joined_timestamp"`
}
type LeaveRequest struct {
ID string `json:"id"`
}
type PairsByTimeRequest struct {
StartTimestamp int64 `json:"start_timestamp"`
EndTimestamp int64 `json:"end_timestamp"`
Limit int `json:"limit"`
Prefix string `json:"prefix,omitempty"`
}
type PairsByTimeResponse struct {
Path string `json:"path"`
UUID string `json:"uuid"`
Timestamp int64 `json:"timestamp"`
}
type PutResponse struct {
UUID string `json:"uuid"`
Timestamp int64 `json:"timestamp"`
}
// Merkle Tree specific data structures
type MerkleNode struct {
Hash []byte `json:"hash"`
StartKey string `json:"start_key"` // The first key in this node's range
EndKey string `json:"end_key"` // The last key in this node's range
}
// MerkleRootResponse is the response for getting the root hash
type MerkleRootResponse struct {
Root *MerkleNode `json:"root"`
}
// MerkleTreeDiffRequest is used to request children hashes for a given key range
type MerkleTreeDiffRequest struct {
ParentNode MerkleNode `json:"parent_node"` // The node whose children we want to compare (from the remote peer's perspective)
LocalHash []byte `json:"local_hash"` // The local hash of this node/range (from the requesting peer's perspective)
}
// MerkleTreeDiffResponse returns the remote children nodes or the actual keys if it's a leaf level
type MerkleTreeDiffResponse struct {
Children []MerkleNode `json:"children,omitempty"` // Children of the remote node
Keys []string `json:"keys,omitempty"` // Actual keys if this is a leaf-level diff
}
// For fetching a range of KV pairs
type KVRangeRequest struct {
StartKey string `json:"start_key"`
EndKey string `json:"end_key"`
Limit int `json:"limit"` // Max number of items to return
}
type KVRangeResponse struct {
Pairs []struct {
Path string `json:"path"`
StoredValue StoredValue `json:"stored_value"`
} `json:"pairs"`
}
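// Sync protocol overview: a node fetches a peer's root hash via GET /merkle_tree/root; if the roots
// differ it walks the tree with POST /merkle_tree/diff, which returns either child ranges to recurse
// into or the concrete keys of a small divergent range. Divergent data is then pulled with
// GET /kv/{path} or in bulk with POST /kv_range and reconciled by timestamp.
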
// Configuration
type Config struct {
NodeID string `yaml:"node_id"`
BindAddress string `yaml:"bind_address"`
Port int `yaml:"port"`
DataDir string `yaml:"data_dir"`
SeedNodes []string `yaml:"seed_nodes"`
ReadOnly bool `yaml:"read_only"`
LogLevel string `yaml:"log_level"`
GossipIntervalMin int `yaml:"gossip_interval_min"`
GossipIntervalMax int `yaml:"gossip_interval_max"`
SyncInterval int `yaml:"sync_interval"`
CatchupInterval int `yaml:"catchup_interval"`
BootstrapMaxAgeHours int `yaml:"bootstrap_max_age_hours"`
ThrottleDelayMs int `yaml:"throttle_delay_ms"`
FetchDelayMs int `yaml:"fetch_delay_ms"`
}
// Server represents the KVS node
type Server struct {
config *Config
db *badger.DB
members map[string]*Member
membersMu sync.RWMutex
mode string // "normal", "read-only", "syncing"
modeMu sync.RWMutex
logger *logrus.Logger
httpServer *http.Server
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
merkleRoot *MerkleNode // Added for Merkle Tree
merkleRootMu sync.RWMutex // Protects merkleRoot
}
// Default configuration
func defaultConfig() *Config {
hostname, _ := os.Hostname()
return &Config{
NodeID: hostname,
BindAddress: "127.0.0.1",
Port: 8080,
DataDir: "./data",
SeedNodes: []string{},
ReadOnly: false,
LogLevel: "info",
GossipIntervalMin: 60, // 1 minute
GossipIntervalMax: 120, // 2 minutes
SyncInterval: 300, // 5 minutes
CatchupInterval: 120, // 2 minutes
BootstrapMaxAgeHours: 720, // 30 days
ThrottleDelayMs: 100,
FetchDelayMs: 50,
}
}
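// Illustrative config.yaml (keys follow the yaml tags on Config; values are the defaults above,
// except node_id and seed_nodes, which are made-up examples). The gossip and sync interval fields
// are in seconds, the *_ms fields in milliseconds, and bootstrap_max_age_hours in hours:
//
//   node_id: node-1
//   bind_address: "127.0.0.1"
//   port: 8080
//   data_dir: "./data"
//   seed_nodes:
//     - "10.0.0.2:8080"
//   read_only: false
//   log_level: "info"
//   gossip_interval_min: 60
//   gossip_interval_max: 120
//   sync_interval: 300
//   catchup_interval: 120
//   bootstrap_max_age_hours: 720
//   throttle_delay_ms: 100
//   fetch_delay_ms: 50
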
// Load configuration from file or create default
func loadConfig(configPath string) (*Config, error) {
config := defaultConfig()
if _, err := os.Stat(configPath); os.IsNotExist(err) {
// Create default config file
if err := os.MkdirAll(filepath.Dir(configPath), 0755); err != nil {
return nil, fmt.Errorf("failed to create config directory: %v", err)
}
data, err := yaml.Marshal(config)
if err != nil {
return nil, fmt.Errorf("failed to marshal default config: %v", err)
}
if err := os.WriteFile(configPath, data, 0644); err != nil {
return nil, fmt.Errorf("failed to write default config: %v", err)
}
fmt.Printf("Created default configuration at %s\n", configPath)
return config, nil
}
data, err := os.ReadFile(configPath)
if err != nil {
return nil, fmt.Errorf("failed to read config file: %v", err)
}
if err := yaml.Unmarshal(data, config); err != nil {
return nil, fmt.Errorf("failed to parse config file: %v", err)
}
return config, nil
}
// Initialize server
func NewServer(config *Config) (*Server, error) {
logger := logrus.New()
logger.SetFormatter(&logrus.JSONFormatter{})
level, err := logrus.ParseLevel(config.LogLevel)
if err != nil {
level = logrus.InfoLevel
}
logger.SetLevel(level)
// Create data directory
if err := os.MkdirAll(config.DataDir, 0755); err != nil {
return nil, fmt.Errorf("failed to create data directory: %v", err)
}
// Open BadgerDB
opts := badger.DefaultOptions(filepath.Join(config.DataDir, "badger"))
opts.Logger = nil // Disable badger's internal logging
db, err := badger.Open(opts)
if err != nil {
return nil, fmt.Errorf("failed to open BadgerDB: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
server := &Server{
config: config,
db: db,
members: make(map[string]*Member),
mode: "normal",
logger: logger,
ctx: ctx,
cancel: cancel,
}
if config.ReadOnly {
server.setMode("read-only")
}
// Build initial Merkle tree
pairs, err := server.getAllKVPairsForMerkleTree()
if err != nil {
return nil, fmt.Errorf("failed to get all KV pairs for initial Merkle tree: %v", err)
}
root, err := server.buildMerkleTreeFromPairs(pairs)
if err != nil {
return nil, fmt.Errorf("failed to build initial Merkle tree: %v", err)
}
server.setMerkleRoot(root)
server.logger.Info("Initial Merkle tree built.")
return server, nil
}
// Mode management
func (s *Server) getMode() string {
s.modeMu.RLock()
defer s.modeMu.RUnlock()
return s.mode
}
func (s *Server) setMode(mode string) {
s.modeMu.Lock()
defer s.modeMu.Unlock()
oldMode := s.mode
s.mode = mode
s.logger.WithFields(logrus.Fields{
"old_mode": oldMode,
"new_mode": mode,
}).Info("Mode changed")
}
// Member management
func (s *Server) addMember(member *Member) {
s.membersMu.Lock()
defer s.membersMu.Unlock()
s.members[member.ID] = member
s.logger.WithFields(logrus.Fields{
"node_id": member.ID,
"address": member.Address,
}).Info("Member added")
}
func (s *Server) removeMember(nodeID string) {
s.membersMu.Lock()
defer s.membersMu.Unlock()
if member, exists := s.members[nodeID]; exists {
delete(s.members, nodeID)
s.logger.WithFields(logrus.Fields{
"node_id": member.ID,
"address": member.Address,
}).Info("Member removed")
}
}
func (s *Server) getMembers() []*Member {
s.membersMu.RLock()
defer s.membersMu.RUnlock()
members := make([]*Member, 0, len(s.members))
for _, member := range s.members {
members = append(members, member)
}
return members
}
// HTTP Handlers
func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
mode := s.getMode()
memberCount := len(s.getMembers())
health := map[string]interface{}{
"status": "ok",
"mode": mode,
"member_count": memberCount,
"node_id": s.config.NodeID,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(health)
}
func (s *Server) getKVHandler(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
path := vars["path"]
var storedValue StoredValue
err := s.db.View(func(txn *badger.Txn) error {
item, err := txn.Get([]byte(path))
if err != nil {
return err
}
return item.Value(func(val []byte) error {
return json.Unmarshal(val, &storedValue)
})
})
if err == badger.ErrKeyNotFound {
http.Error(w, "Not Found", http.StatusNotFound)
return
}
if err != nil {
s.logger.WithError(err).WithField("path", path).Error("Failed to get value")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
// Return the entire StoredValue (UUID, timestamp, and data), not just Data, so peers can replicate metadata during sync
json.NewEncoder(w).Encode(storedValue)
}
func (s *Server) putKVHandler(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
path := vars["path"]
mode := s.getMode()
if mode == "syncing" {
http.Error(w, "Service Unavailable", http.StatusServiceUnavailable)
return
}
if mode == "read-only" && !s.isClusterMember(r.RemoteAddr) {
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
var data json.RawMessage
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
now := time.Now().UnixMilli()
newUUID := uuid.New().String()
storedValue := StoredValue{
UUID: newUUID,
Timestamp: now,
Data: data,
}
valueBytes, err := json.Marshal(storedValue)
if err != nil {
s.logger.WithError(err).Error("Failed to marshal stored value")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
var isUpdate bool
err = s.db.Update(func(txn *badger.Txn) error {
// Check if key exists
_, err := txn.Get([]byte(path))
isUpdate = (err == nil)
// Store main data
if err := txn.Set([]byte(path), valueBytes); err != nil {
return err
}
// Store timestamp index
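// Index key layout: "_ts:" + zero-padded 20-digit millisecond timestamp + ":" + path, with the
// entry's UUID as the value, e.g. "_ts:00000001712345678901:app/settings" (illustrative values).
// Note that overwriting an existing key leaves the index entry for its previous timestamp in place.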
indexKey := fmt.Sprintf("_ts:%020d:%s", now, path)
return txn.Set([]byte(indexKey), []byte(newUUID))
})
if err != nil {
s.logger.WithError(err).WithField("path", path).Error("Failed to store value")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
response := PutResponse{
UUID: newUUID,
Timestamp: now,
}
status := http.StatusCreated
if isUpdate {
status = http.StatusOK
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(response)
s.logger.WithFields(logrus.Fields{
"path": path,
"uuid": newUUID,
"timestamp": now,
"is_update": isUpdate,
}).Info("Value stored")
}
func (s *Server) deleteKVHandler(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
path := vars["path"]
mode := s.getMode()
if mode == "syncing" {
http.Error(w, "Service Unavailable", http.StatusServiceUnavailable)
return
}
if mode == "read-only" && !s.isClusterMember(r.RemoteAddr) {
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
var found bool
err := s.db.Update(func(txn *badger.Txn) error {
// Check if key exists and get timestamp for index cleanup
item, err := txn.Get([]byte(path))
if err == badger.ErrKeyNotFound {
return nil
}
if err != nil {
return err
}
found = true
var storedValue StoredValue
err = item.Value(func(val []byte) error {
return json.Unmarshal(val, &storedValue)
})
if err != nil {
return err
}
// Delete main data
if err := txn.Delete([]byte(path)); err != nil {
return err
}
// Delete timestamp index
indexKey := fmt.Sprintf("_ts:%020d:%s", storedValue.Timestamp, path)
return txn.Delete([]byte(indexKey))
})
if err != nil {
s.logger.WithError(err).WithField("path", path).Error("Failed to delete value")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
if !found {
http.Error(w, "Not Found", http.StatusNotFound)
return
}
w.WriteHeader(http.StatusNoContent)
s.logger.WithField("path", path).Info("Value deleted")
}
func (s *Server) getMembersHandler(w http.ResponseWriter, r *http.Request) {
members := s.getMembers()
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(members)
}
func (s *Server) joinMemberHandler(w http.ResponseWriter, r *http.Request) {
var req JoinRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
now := time.Now().UnixMilli()
member := &Member{
ID: req.ID,
Address: req.Address,
LastSeen: now,
JoinedTimestamp: req.JoinedTimestamp,
}
s.addMember(member)
// Return current member list
members := s.getMembers()
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(members)
}
func (s *Server) leaveMemberHandler(w http.ResponseWriter, r *http.Request) {
var req LeaveRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
s.removeMember(req.ID)
w.WriteHeader(http.StatusNoContent)
}
func (s *Server) pairsByTimeHandler(w http.ResponseWriter, r *http.Request) {
var req PairsByTimeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
// Default limit to 15 as per spec
if req.Limit <= 0 {
req.Limit = 15
}
var pairs []PairsByTimeResponse
err := s.db.View(func(txn *badger.Txn) error {
opts := badger.DefaultIteratorOptions
opts.PrefetchSize = req.Limit
it := txn.NewIterator(opts)
defer it.Close()
prefix := []byte("_ts:")
// This handler is no longer used for core sync (the Merkle-based sync below replaced it);
// it remains available as a client-facing API. Timestamp and prefix filters are applied per
// index entry while iterating, so the limit counts matching entries only.
for it.Seek(prefix); it.ValidForPrefix(prefix) && len(pairs) < req.Limit; it.Next() {
item := it.Item()
key := string(item.Key())
// Parse timestamp index key: "_ts:{timestamp}:{path}"
parts := strings.SplitN(key, ":", 3)
if len(parts) != 3 {
continue
}
timestamp, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
continue
}
// Filter by timestamp range
if req.StartTimestamp > 0 && timestamp < req.StartTimestamp {
continue
}
if req.EndTimestamp > 0 && timestamp >= req.EndTimestamp {
continue
}
path := parts[2]
// Filter by prefix if specified
if req.Prefix != "" && !strings.HasPrefix(path, req.Prefix) {
continue
}
var uuid string
err = item.Value(func(val []byte) error {
uuid = string(val)
return nil
})
if err != nil {
continue
}
pairs = append(pairs, PairsByTimeResponse{
Path: path,
UUID: uuid,
Timestamp: timestamp,
})
}
return nil
})
if err != nil {
s.logger.WithError(err).Error("Failed to query pairs by time")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
if len(pairs) == 0 {
w.WriteHeader(http.StatusNoContent)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(pairs)
}
func (s *Server) gossipHandler(w http.ResponseWriter, r *http.Request) {
var remoteMemberList []Member
if err := json.NewDecoder(r.Body).Decode(&remoteMemberList); err != nil {
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
// Merge the received member list
s.mergeMemberList(remoteMemberList)
// Respond with our current member list
localMembers := s.getMembers()
gossipResponse := make([]Member, len(localMembers))
for i, member := range localMembers {
gossipResponse[i] = *member
}
// Add ourselves to the response
selfMember := Member{
ID: s.config.NodeID,
Address: fmt.Sprintf("%s:%d", s.config.BindAddress, s.config.Port),
LastSeen: time.Now().UnixMilli(),
JoinedTimestamp: s.getJoinedTimestamp(),
}
gossipResponse = append(gossipResponse, selfMember)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(gossipResponse)
s.logger.WithField("remote_members", len(remoteMemberList)).Debug("Processed gossip request")
}
// Utility function to check if request is from cluster member
func (s *Server) isClusterMember(remoteAddr string) bool {
host, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return false
}
s.membersMu.RLock()
defer s.membersMu.RUnlock()
for _, member := range s.members {
memberHost, _, err := net.SplitHostPort(member.Address)
if err == nil && memberHost == host {
return true
}
}
return false
}
// Setup HTTP routes
func (s *Server) setupRoutes() *mux.Router {
router := mux.NewRouter()
// Health endpoint
router.HandleFunc("/health", s.healthHandler).Methods("GET")
// KV endpoints
router.HandleFunc("/kv/{path:.+}", s.getKVHandler).Methods("GET")
router.HandleFunc("/kv/{path:.+}", s.putKVHandler).Methods("PUT")
router.HandleFunc("/kv/{path:.+}", s.deleteKVHandler).Methods("DELETE")
// Member endpoints
router.HandleFunc("/members/", s.getMembersHandler).Methods("GET")
router.HandleFunc("/members/join", s.joinMemberHandler).Methods("POST")
router.HandleFunc("/members/leave", s.leaveMemberHandler).Methods("DELETE")
router.HandleFunc("/members/gossip", s.gossipHandler).Methods("POST")
router.HandleFunc("/members/pairs_by_time", s.pairsByTimeHandler).Methods("POST") // Still available for clients
// Merkle Tree endpoints
router.HandleFunc("/merkle_tree/root", s.getMerkleRootHandler).Methods("GET")
router.HandleFunc("/merkle_tree/diff", s.getMerkleDiffHandler).Methods("POST")
router.HandleFunc("/kv_range", s.getKVRangeHandler).Methods("POST") // New endpoint for fetching ranges
return router
}
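// Illustrative client usage (port 8080 and bind address 127.0.0.1 are the defaults; the path and
// payload below are made up):
//
//   curl -X PUT http://127.0.0.1:8080/kv/app/settings -d '{"theme":"dark"}'   # -> {"uuid":"...","timestamp":...}
//   curl http://127.0.0.1:8080/kv/app/settings                                # -> the full StoredValue JSON
//   curl -X DELETE http://127.0.0.1:8080/kv/app/settings                      # -> 204 No Content
//   curl http://127.0.0.1:8080/health                                         # -> node status and member count
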
// Start the server
func (s *Server) Start() error {
router := s.setupRoutes()
addr := fmt.Sprintf("%s:%d", s.config.BindAddress, s.config.Port)
s.httpServer = &http.Server{
Addr: addr,
Handler: router,
}
s.logger.WithFields(logrus.Fields{
"node_id": s.config.NodeID,
"address": addr,
}).Info("Starting KVS server")
// Start gossip and sync routines
s.startBackgroundTasks()
// Try to join cluster if seed nodes are configured
if len(s.config.SeedNodes) > 0 {
go s.bootstrap()
}
return s.httpServer.ListenAndServe()
}
// Stop the server gracefully
func (s *Server) Stop() error {
s.logger.Info("Shutting down KVS server")
s.cancel()
s.wg.Wait()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := s.httpServer.Shutdown(ctx); err != nil {
s.logger.WithError(err).Error("HTTP server shutdown error")
}
if err := s.db.Close(); err != nil {
s.logger.WithError(err).Error("BadgerDB close error")
}
return nil
}
// Background tasks (gossip, sync, Merkle tree rebuild, etc.)
func (s *Server) startBackgroundTasks() {
// Start gossip routine
s.wg.Add(1)
go s.gossipRoutine()
// Start sync routine (now Merkle-based)
s.wg.Add(1)
go s.syncRoutine()
// Start Merkle tree rebuild routine
s.wg.Add(1)
go s.merkleTreeRebuildRoutine()
}
// Gossip routine - runs periodically to exchange member lists
func (s *Server) gossipRoutine() {
defer s.wg.Done()
for {
// Pick a random interval between GossipIntervalMin and GossipIntervalMax seconds
minInterval := time.Duration(s.config.GossipIntervalMin) * time.Second
maxInterval := time.Duration(s.config.GossipIntervalMax) * time.Second
interval := minInterval
if maxInterval > minInterval {
interval += time.Duration(rand.Int63n(int64(maxInterval - minInterval)))
}
select {
case <-s.ctx.Done():
return
case <-time.After(interval):
s.performGossipRound()
}
}
}
// Perform a gossip round with random healthy peers
func (s *Server) performGossipRound() {
members := s.getHealthyMembers()
if len(members) == 0 {
s.logger.Debug("No healthy members for gossip round")
return
}
// Select 1-3 random peers for gossip
maxPeers := 3
if len(members) < maxPeers {
maxPeers = len(members)
}
// Shuffle and select
rand.Shuffle(len(members), func(i, j int) {
members[i], members[j] = members[j], members[i]
})
selectedPeers := members[:rand.Intn(maxPeers)+1]
for _, peer := range selectedPeers {
go s.gossipWithPeer(peer)
}
}
// Gossip with a specific peer
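// The exchange is push-pull: our member list (plus an entry for this node) is POSTed to the peer's
// /members/gossip endpoint, the member list in the response is merged locally, and the peer's
// LastSeen is refreshed; on any failure the peer is marked unhealthy.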
func (s *Server) gossipWithPeer(peer *Member) {
s.logger.WithField("peer", peer.Address).Debug("Starting gossip with peer")
// Get our current member list
localMembers := s.getMembers()
// Send our member list to the peer
gossipData := make([]Member, len(localMembers))
for i, member := range localMembers {
gossipData[i] = *member
}
// Add ourselves to the list
selfMember := Member{
ID: s.config.NodeID,
Address: fmt.Sprintf("%s:%d", s.config.BindAddress, s.config.Port),
LastSeen: time.Now().UnixMilli(),
JoinedTimestamp: s.getJoinedTimestamp(),
}
gossipData = append(gossipData, selfMember)
jsonData, err := json.Marshal(gossipData)
if err != nil {
s.logger.WithError(err).Error("Failed to marshal gossip data")
return
}
// Send HTTP request to peer
client := &http.Client{Timeout: 5 * time.Second}
url := fmt.Sprintf("http://%s/members/gossip", peer.Address)
resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonData))
if err != nil {
s.logger.WithFields(logrus.Fields{
"peer": peer.Address,
"error": err.Error(),
}).Warn("Failed to gossip with peer")
s.markPeerUnhealthy(peer.ID)
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
s.logger.WithFields(logrus.Fields{
"peer": peer.Address,
"status": resp.StatusCode,
}).Warn("Gossip request failed")
s.markPeerUnhealthy(peer.ID)
return
}
// Process response - peer's member list
var remoteMemberList []Member
if err := json.NewDecoder(resp.Body).Decode(&remoteMemberList); err != nil {
s.logger.WithError(err).Error("Failed to decode gossip response")
return
}
// Merge remote member list with our local list
s.mergeMemberList(remoteMemberList)
// Update peer's last seen timestamp
s.updateMemberLastSeen(peer.ID, time.Now().UnixMilli())
s.logger.WithField("peer", peer.Address).Debug("Completed gossip with peer")
}
// Get healthy members (exclude those marked as down)
func (s *Server) getHealthyMembers() []*Member {
s.membersMu.RLock()
defer s.membersMu.RUnlock()
now := time.Now().UnixMilli()
healthyMembers := make([]*Member, 0)
for _, member := range s.members {
// Consider member healthy if last seen within last 5 minutes
if now-member.LastSeen < 5*60*1000 {
healthyMembers = append(healthyMembers, member)
}
}
return healthyMembers
}
// Mark a peer as unhealthy
func (s *Server) markPeerUnhealthy(nodeID string) {
s.membersMu.Lock()
defer s.membersMu.Unlock()
if member, exists := s.members[nodeID]; exists {
// Mark as last seen a long time ago to indicate unhealthy
member.LastSeen = time.Now().UnixMilli() - 10*60*1000 // 10 minutes ago
s.logger.WithField("node_id", nodeID).Warn("Marked peer as unhealthy")
}
}
// Update member's last seen timestamp
func (s *Server) updateMemberLastSeen(nodeID string, timestamp int64) {
s.membersMu.Lock()
defer s.membersMu.Unlock()
if member, exists := s.members[nodeID]; exists {
member.LastSeen = timestamp
}
}
// Merge remote member list with local member list
func (s *Server) mergeMemberList(remoteMembers []Member) {
s.membersMu.Lock()
defer s.membersMu.Unlock()
now := time.Now().UnixMilli()
for _, remoteMember := range remoteMembers {
// Skip ourselves
if remoteMember.ID == s.config.NodeID {
continue
}
if localMember, exists := s.members[remoteMember.ID]; exists {
// Update existing member
if remoteMember.LastSeen > localMember.LastSeen {
localMember.LastSeen = remoteMember.LastSeen
}
// Keep the earlier joined timestamp
if remoteMember.JoinedTimestamp < localMember.JoinedTimestamp {
localMember.JoinedTimestamp = remoteMember.JoinedTimestamp
}
} else {
// Add new member
newMember := &Member{
ID: remoteMember.ID,
Address: remoteMember.Address,
LastSeen: remoteMember.LastSeen,
JoinedTimestamp: remoteMember.JoinedTimestamp,
}
s.members[remoteMember.ID] = newMember
s.logger.WithFields(logrus.Fields{
"node_id": remoteMember.ID,
"address": remoteMember.Address,
}).Info("Discovered new member through gossip")
}
}
// Clean up old members (not seen for more than 10 minutes)
toRemove := make([]string, 0)
for nodeID, member := range s.members {
if now-member.LastSeen > 10*60*1000 { // 10 minutes
toRemove = append(toRemove, nodeID)
}
}
for _, nodeID := range toRemove {
delete(s.members, nodeID)
s.logger.WithField("node_id", nodeID).Info("Removed stale member")
}
}
// Get this node's joined timestamp (startup time)
func (s *Server) getJoinedTimestamp() int64 {
// For now this simply returns the current time on every call; it should be captured once at startup
// and persisted, otherwise the oldest-node tie-breaker in conflict resolution is unreliable.
return time.Now().UnixMilli()
}
// Sync routine - handles regular and catch-up syncing
func (s *Server) syncRoutine() {
defer s.wg.Done()
syncTicker := time.NewTicker(time.Duration(s.config.SyncInterval) * time.Second)
defer syncTicker.Stop()
for {
select {
case <-s.ctx.Done():
return
case <-syncTicker.C:
s.performMerkleSync() // Use Merkle sync instead of regular sync
}
}
}
// Merkle Tree Implementation
// calculateHash generates a SHA256 hash for a given byte slice
func calculateHash(data []byte) []byte {
h := sha256.New()
h.Write(data)
return h.Sum(nil)
}
// calculateLeafHash generates a hash for a leaf node based on its path, UUID, timestamp, and data
func (s *Server) calculateLeafHash(path string, storedValue *StoredValue) []byte {
// Concatenate path, UUID, timestamp, and the raw data bytes for hashing
// Ensure a consistent order of fields for hashing
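// Preimage layout: "<path>:<uuid>:<timestamp in ms>:" followed by the raw JSON bytes
// (e.g. "app/settings:9f1c...:1712345678901:{...}", illustrative values), so any change to the
// path, UUID, timestamp, or payload changes the leaf hash.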
dataToHash := bytes.Buffer{}
dataToHash.WriteString(path)
dataToHash.WriteByte(':')
dataToHash.WriteString(storedValue.UUID)
dataToHash.WriteByte(':')
dataToHash.WriteString(strconv.FormatInt(storedValue.Timestamp, 10))
dataToHash.WriteByte(':')
dataToHash.Write(storedValue.Data) // Use raw bytes of json.RawMessage
return calculateHash(dataToHash.Bytes())
}
// getAllKVPairsForMerkleTree retrieves all key-value pairs needed for Merkle tree construction.
func (s *Server) getAllKVPairsForMerkleTree() (map[string]*StoredValue, error) {
pairs := make(map[string]*StoredValue)
err := s.db.View(func(txn *badger.Txn) error {
opts := badger.DefaultIteratorOptions
opts.PrefetchValues = true // We need the values for hashing
it := txn.NewIterator(opts)
defer it.Close()
// Iterate over all actual data keys (not _ts: indexes)
for it.Rewind(); it.Valid(); it.Next() {
item := it.Item()
key := string(item.Key())
if strings.HasPrefix(key, "_ts:") {
continue // Skip index keys
}
var storedValue StoredValue
err := item.Value(func(val []byte) error {
return json.Unmarshal(val, &storedValue)
})
if err != nil {
s.logger.WithError(err).WithField("key", key).Warn("Failed to unmarshal stored value for Merkle tree, skipping")
continue
}
pairs[key] = &storedValue
}
return nil
})
if err != nil {
return nil, err
}
return pairs, nil
}
// buildMerkleTreeFromPairs constructs a Merkle Tree from the KVS data.
// This version uses a recursive approach to build a balanced tree from sorted keys.
func (s *Server) buildMerkleTreeFromPairs(pairs map[string]*StoredValue) (*MerkleNode, error) {
if len(pairs) == 0 {
return &MerkleNode{Hash: calculateHash([]byte("empty_tree")), StartKey: "", EndKey: ""}, nil
}
// Sort keys to ensure consistent tree structure
keys := make([]string, 0, len(pairs))
for k := range pairs {
keys = append(keys, k)
}
sort.Strings(keys)
// Create leaf nodes
leafNodes := make([]*MerkleNode, len(keys))
for i, key := range keys {
storedValue := pairs[key]
hash := s.calculateLeafHash(key, storedValue)
leafNodes[i] = &MerkleNode{Hash: hash, StartKey: key, EndKey: key}
}
// Recursively build parent nodes
return s.buildMerkleTreeRecursive(leafNodes)
}
// buildMerkleTreeRecursive builds the tree from a slice of nodes.
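// Small worked example: leaves La, Lb, Lc (for sorted keys a, b, c) are paired into H(La||Lb)
// covering [a, b], Lc is promoted unchanged because the count is odd, and the next level yields
// the root H(H(La||Lb)||Lc) covering [a, c].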
func (s *Server) buildMerkleTreeRecursive(nodes []*MerkleNode) (*MerkleNode, error) {
if len(nodes) == 0 {
return nil, nil
}
if len(nodes) == 1 {
return nodes[0], nil
}
var nextLevel []*MerkleNode
for i := 0; i < len(nodes); i += 2 {
left := nodes[i]
var right *MerkleNode
if i+1 < len(nodes) {
right = nodes[i+1]
}
var combinedHash []byte
var endKey string
if right != nil {
combinedHash = calculateHash(append(left.Hash, right.Hash...))
endKey = right.EndKey
} else {
// Odd number of nodes, promote the left node
combinedHash = left.Hash
endKey = left.EndKey
}
parentNode := &MerkleNode{
Hash: combinedHash,
StartKey: left.StartKey,
EndKey: endKey,
}
nextLevel = append(nextLevel, parentNode)
}
return s.buildMerkleTreeRecursive(nextLevel)
}
// getMerkleRoot returns the current Merkle root of the server.
func (s *Server) getMerkleRoot() *MerkleNode {
s.merkleRootMu.RLock()
defer s.merkleRootMu.RUnlock()
return s.merkleRoot
}
// setMerkleRoot sets the current Merkle root of the server.
func (s *Server) setMerkleRoot(root *MerkleNode) {
s.merkleRootMu.Lock()
defer s.merkleRootMu.Unlock()
s.merkleRoot = root
}
// merkleTreeRebuildRoutine periodically rebuilds the Merkle tree.
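// The tree is rebuilt from a full key scan on every tick (SyncInterval seconds); PUT and DELETE do
// not update it incrementally, so local changes become visible to Merkle sync only after the next rebuild.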
func (s *Server) merkleTreeRebuildRoutine() {
defer s.wg.Done()
ticker := time.NewTicker(time.Duration(s.config.SyncInterval) * time.Second) // Use sync interval for now
defer ticker.Stop()
for {
select {
case <-s.ctx.Done():
return
case <-ticker.C:
s.logger.Debug("Rebuilding Merkle tree...")
pairs, err := s.getAllKVPairsForMerkleTree()
if err != nil {
s.logger.WithError(err).Error("Failed to get KV pairs for Merkle tree rebuild")
continue
}
newRoot, err := s.buildMerkleTreeFromPairs(pairs)
if err != nil {
s.logger.WithError(err).Error("Failed to rebuild Merkle tree")
continue
}
s.setMerkleRoot(newRoot)
s.logger.Debug("Merkle tree rebuilt.")
}
}
}
// getMerkleRootHandler returns the root hash of the local Merkle Tree
func (s *Server) getMerkleRootHandler(w http.ResponseWriter, r *http.Request) {
root := s.getMerkleRoot()
if root == nil {
http.Error(w, "Merkle tree not initialized", http.StatusInternalServerError)
return
}
resp := MerkleRootResponse{
Root: root,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
// getMerkleDiffHandler is used by a peer to request children hashes for a given node/range.
func (s *Server) getMerkleDiffHandler(w http.ResponseWriter, r *http.Request) {
var req MerkleTreeDiffRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
localPairs, err := s.getAllKVPairsForMerkleTree()
if err != nil {
s.logger.WithError(err).Error("Failed to get KV pairs for Merkle diff")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
// Build the local MerkleNode for the requested range to compare with the remote's hash
localSubTreeRoot, err := s.buildMerkleTreeFromPairs(s.filterPairsByRange(localPairs, req.ParentNode.StartKey, req.ParentNode.EndKey))
if err != nil {
s.logger.WithError(err).Error("Failed to build sub-Merkle tree for diff request")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
if localSubTreeRoot == nil { // This can happen if the range is empty locally
localSubTreeRoot = &MerkleNode{Hash: calculateHash([]byte("empty_tree")), StartKey: req.ParentNode.StartKey, EndKey: req.ParentNode.EndKey}
}
resp := MerkleTreeDiffResponse{}
// If hashes match, no need to send children or keys
if bytes.Equal(req.LocalHash, localSubTreeRoot.Hash) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
return
}
// Hashes differ, so we need to provide more detail.
// Get all keys within the parent node's range locally
var keysInRange []string
for key := range s.filterPairsByRange(localPairs, req.ParentNode.StartKey, req.ParentNode.EndKey) {
keysInRange = append(keysInRange, key)
}
sort.Strings(keysInRange)
const diffLeafThreshold = 10 // If a range has <= 10 keys, we consider it a leaf-level diff
if len(keysInRange) <= diffLeafThreshold {
// This is a leaf-level diff, return the actual keys in the range
resp.Keys = keysInRange
} else {
// Group keys into sub-ranges and return their MerkleNode representations
// For simplicity, let's split the range into two halves.
mid := len(keysInRange) / 2
leftKeys := keysInRange[:mid]
rightKeys := keysInRange[mid:]
if len(leftKeys) > 0 {
leftRangePairs := s.filterPairsByRange(localPairs, leftKeys[0], leftKeys[len(leftKeys)-1])
leftNode, err := s.buildMerkleTreeFromPairs(leftRangePairs)
if err != nil {
s.logger.WithError(err).Error("Failed to build left child node for diff")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
if leftNode != nil {
resp.Children = append(resp.Children, *leftNode)
}
}
if len(rightKeys) > 0 {
rightRangePairs := s.filterPairsByRange(localPairs, rightKeys[0], rightKeys[len(rightKeys)-1])
rightNode, err := s.buildMerkleTreeFromPairs(rightRangePairs)
if err != nil {
s.logger.WithError(err).Error("Failed to build right child node for diff")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
if rightNode != nil {
resp.Children = append(resp.Children, *rightNode)
}
}
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
// Helper to filter a map of StoredValue by key range
func (s *Server) filterPairsByRange(allPairs map[string]*StoredValue, startKey, endKey string) map[string]*StoredValue {
filtered := make(map[string]*StoredValue)
for key, value := range allPairs {
if (startKey == "" || key >= startKey) && (endKey == "" || key <= endKey) {
filtered[key] = value
}
}
return filtered
}
// getKVRangeHandler fetches a range of KV pairs
func (s *Server) getKVRangeHandler(w http.ResponseWriter, r *http.Request) {
var req KVRangeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
var pairs []struct {
Path string `json:"path"`
StoredValue StoredValue `json:"stored_value"`
}
err := s.db.View(func(txn *badger.Txn) error {
opts := badger.DefaultIteratorOptions
opts.PrefetchValues = true
it := txn.NewIterator(opts)
defer it.Close()
count := 0
// Start iteration from the requested StartKey
for it.Seek([]byte(req.StartKey)); it.Valid(); it.Next() {
item := it.Item()
key := string(item.Key())
if strings.HasPrefix(key, "_ts:") {
continue // Skip index keys
}
// Stop if we exceed the EndKey (if provided)
if req.EndKey != "" && key > req.EndKey {
break
}
// Stop if we hit the limit (if provided)
if req.Limit > 0 && count >= req.Limit {
break
}
var storedValue StoredValue
err := item.Value(func(val []byte) error {
return json.Unmarshal(val, &storedValue)
})
if err != nil {
s.logger.WithError(err).WithField("key", key).Warn("Failed to unmarshal stored value in KV range, skipping")
continue
}
pairs = append(pairs, struct {
Path string `json:"path"`
StoredValue StoredValue `json:"stored_value"`
}{Path: key, StoredValue: storedValue})
count++
}
return nil
})
if err != nil {
s.logger.WithError(err).Error("Failed to query KV range")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(KVRangeResponse{Pairs: pairs})
}
// performMerkleSync performs a synchronization round using Merkle Trees
func (s *Server) performMerkleSync() {
members := s.getHealthyMembers()
if len(members) == 0 {
s.logger.Debug("No healthy members for Merkle sync")
return
}
// Select random peer
peer := members[rand.Intn(len(members))]
s.logger.WithField("peer", peer.Address).Info("Starting Merkle tree sync")
localRoot := s.getMerkleRoot()
if localRoot == nil {
s.logger.Error("Local Merkle root is nil, cannot perform sync")
return
}
// 1. Get remote peer's Merkle root
remoteRootResp, err := s.requestMerkleRoot(peer.Address)
if err != nil {
s.logger.WithError(err).WithField("peer", peer.Address).Error("Failed to get remote Merkle root")
s.markPeerUnhealthy(peer.ID)
return
}
remoteRoot := remoteRootResp.Root
// 2. Compare roots and start recursive diffing if they differ
if !bytes.Equal(localRoot.Hash, remoteRoot.Hash) {
s.logger.WithFields(logrus.Fields{
"peer": peer.Address,
"local_root": hex.EncodeToString(localRoot.Hash),
"remote_root": hex.EncodeToString(remoteRoot.Hash),
}).Info("Merkle roots differ, starting recursive diff")
s.diffMerkleTreesRecursive(peer.Address, localRoot, remoteRoot)
} else {
s.logger.WithField("peer", peer.Address).Info("Merkle roots match, no sync needed")
}
s.logger.WithField("peer", peer.Address).Info("Completed Merkle tree sync")
}
// requestMerkleRoot requests the Merkle root from a peer
func (s *Server) requestMerkleRoot(peerAddress string) (*MerkleRootResponse, error) {
client := &http.Client{Timeout: 10 * time.Second}
url := fmt.Sprintf("http://%s/merkle_tree/root", peerAddress)
resp, err := client.Get(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("peer returned status %d for Merkle root", resp.StatusCode)
}
var merkleRootResp MerkleRootResponse
if err := json.NewDecoder(resp.Body).Decode(&merkleRootResp); err != nil {
return nil, err
}
return &merkleRootResp, nil
}
// diffMerkleTreesRecursive recursively compares local and remote Merkle tree nodes
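// When a range's hashes differ, the peer is asked via /merkle_tree/diff either for child sub-ranges,
// which are recursed into, or, for small ranges, for the concrete keys; those keys are fetched
// individually, missing or older local copies are overwritten, remote deletions are applied locally,
// and timestamp ties are handed to resolveConflict.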
func (s *Server) diffMerkleTreesRecursive(peerAddress string, localNode, remoteNode *MerkleNode) {
// If hashes match, this subtree is in sync.
if bytes.Equal(localNode.Hash, remoteNode.Hash) {
return
}
// Hashes differ, need to go deeper.
// Request children from the remote peer for the current range.
req := MerkleTreeDiffRequest{
ParentNode: *remoteNode, // We are asking the remote peer about its children for this range
LocalHash: localNode.Hash, // Our hash for this range
}
remoteDiffResp, err := s.requestMerkleDiff(peerAddress, req)
if err != nil {
s.logger.WithError(err).WithFields(logrus.Fields{
"peer": peerAddress,
"start_key": localNode.StartKey,
"end_key": localNode.EndKey,
}).Error("Failed to get Merkle diff from peer")
return
}
if len(remoteDiffResp.Keys) > 0 {
// This is a leaf-level diff, we have the actual keys that are different.
// Fetch and compare each key.
s.logger.WithFields(logrus.Fields{
"peer": peerAddress,
"start_key": localNode.StartKey,
"end_key": localNode.EndKey,
"num_keys": len(remoteDiffResp.Keys),
}).Info("Found divergent keys, fetching and comparing data")
for _, key := range remoteDiffResp.Keys {
// Fetch the individual key from the peer
remoteStoredValue, err := s.fetchSingleKVFromPeer(peerAddress, key)
if err != nil {
s.logger.WithError(err).WithFields(logrus.Fields{
"peer": peerAddress,
"key": key,
}).Error("Failed to fetch single KV from peer during diff")
continue
}
localStoredValue, localExists := s.getLocalData(key)
if remoteStoredValue == nil {
// Key was deleted on remote, delete locally if exists
if localExists {
s.logger.WithField("key", key).Info("Key deleted on remote, deleting locally")
if err := s.deleteKVLocally(key, localStoredValue.Timestamp); err != nil {
s.logger.WithError(err).WithField("key", key).Error("Failed to delete key locally during Merkle sync")
}
}
continue
}
if !localExists {
// Local data is missing, store the remote data
if err := s.storeReplicatedDataWithMetadata(key, remoteStoredValue); err != nil {
s.logger.WithError(err).WithField("key", key).Error("Failed to store missing replicated data")
} else {
s.logger.WithField("key", key).Info("Fetched and stored missing data from peer")
}
} else if localStoredValue.Timestamp < remoteStoredValue.Timestamp {
// Remote is newer, store the remote data
if err := s.storeReplicatedDataWithMetadata(key, remoteStoredValue); err != nil {
s.logger.WithError(err).WithField("key", key).Error("Failed to store newer replicated data")
} else {
s.logger.WithField("key", key).Info("Fetched and stored newer data from peer")
}
} else if localStoredValue.Timestamp == remoteStoredValue.Timestamp && localStoredValue.UUID != remoteStoredValue.UUID {
// Timestamp collision, engage conflict resolution
remotePair := PairsByTimeResponse{ // Re-use this struct for conflict resolution
Path: key,
UUID: remoteStoredValue.UUID,
Timestamp: remoteStoredValue.Timestamp,
}
resolved, err := s.resolveConflict(key, localStoredValue, &remotePair, peerAddress)
if err != nil {
s.logger.WithError(err).WithField("path", key).Error("Failed to resolve conflict during Merkle sync")
} else if resolved {
s.logger.WithField("path", key).Info("Conflict resolved, updated local data during Merkle sync")
} else {
s.logger.WithField("path", key).Info("Conflict resolved, kept local data during Merkle sync")
}
}
// If local is newer or same timestamp and same UUID, do nothing.
}
} else if len(remoteDiffResp.Children) > 0 {
// Not a leaf level, continue recursive diff for children.
localPairs, err := s.getAllKVPairsForMerkleTree()
if err != nil {
s.logger.WithError(err).Error("Failed to get KV pairs for local children comparison")
return
}
for _, remoteChild := range remoteDiffResp.Children {
// Build the local Merkle node for this child's range
localChildNode, err := s.buildMerkleTreeFromPairs(s.filterPairsByRange(localPairs, remoteChild.StartKey, remoteChild.EndKey))
if err != nil {
s.logger.WithError(err).WithFields(logrus.Fields{
"start_key": remoteChild.StartKey,
"end_key": remoteChild.EndKey,
}).Error("Failed to build local child node for diff")
continue
}
if localChildNode == nil || !bytes.Equal(localChildNode.Hash, remoteChild.Hash) {
// If local child node is nil (meaning local has no data in this range)
// or hashes differ, then we need to fetch the data.
if localChildNode == nil {
s.logger.WithFields(logrus.Fields{
"peer": peerAddress,
"start_key": remoteChild.StartKey,
"end_key": remoteChild.EndKey,
}).Info("Local node missing data in remote child's range, fetching full range")
if err := s.fetchAndStoreRange(peerAddress, remoteChild.StartKey, remoteChild.EndKey); err != nil {
s.logger.WithError(err).WithField("peer", peerAddress).Error("Failed to fetch and store range from peer")
}
} else {
s.diffMerkleTreesRecursive(peerAddress, localChildNode, &remoteChild)
}
}
}
}
}
// requestMerkleDiff requests children hashes or keys for a given node/range from a peer
func (s *Server) requestMerkleDiff(peerAddress string, req MerkleTreeDiffRequest) (*MerkleTreeDiffResponse, error) {
jsonData, err := json.Marshal(req)
if err != nil {
return nil, err
}
client := &http.Client{Timeout: 10 * time.Second}
url := fmt.Sprintf("http://%s/merkle_tree/diff", peerAddress)
resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("peer returned status %d for Merkle diff", resp.StatusCode)
}
var diffResp MerkleTreeDiffResponse
if err := json.NewDecoder(resp.Body).Decode(&diffResp); err != nil {
return nil, err
}
return &diffResp, nil
}
// fetchSingleKVFromPeer fetches a single KV pair from a peer
func (s *Server) fetchSingleKVFromPeer(peerAddress, path string) (*StoredValue, error) {
client := &http.Client{Timeout: 5 * time.Second}
url := fmt.Sprintf("http://%s/kv/%s", peerAddress, path)
resp, err := client.Get(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return nil, nil // Key might have been deleted on the peer
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("peer returned status %d for path %s", resp.StatusCode, path)
}
var storedValue StoredValue
if err := json.NewDecoder(resp.Body).Decode(&storedValue); err != nil {
return nil, fmt.Errorf("failed to decode StoredValue from peer: %v", err)
}
return &storedValue, nil
}
// storeReplicatedDataWithMetadata stores replicated data preserving its original metadata
func (s *Server) storeReplicatedDataWithMetadata(path string, storedValue *StoredValue) error {
valueBytes, err := json.Marshal(storedValue)
if err != nil {
return err
}
return s.db.Update(func(txn *badger.Txn) error {
// Store main data
if err := txn.Set([]byte(path), valueBytes); err != nil {
return err
}
// Store timestamp index
indexKey := fmt.Sprintf("_ts:%020d:%s", storedValue.Timestamp, path)
return txn.Set([]byte(indexKey), []byte(storedValue.UUID))
})
}
// deleteKVLocally deletes a key-value pair and its associated timestamp index locally.
func (s *Server) deleteKVLocally(path string, timestamp int64) error {
return s.db.Update(func(txn *badger.Txn) error {
if err := txn.Delete([]byte(path)); err != nil {
return err
}
indexKey := fmt.Sprintf("_ts:%020d:%s", timestamp, path)
return txn.Delete([]byte(indexKey))
})
}
// fetchAndStoreRange fetches a range of KV pairs from a peer and stores them locally
func (s *Server) fetchAndStoreRange(peerAddress string, startKey, endKey string) error {
req := KVRangeRequest{
StartKey: startKey,
EndKey: endKey,
Limit: 0, // No limit
}
jsonData, err := json.Marshal(req)
if err != nil {
return err
}
client := &http.Client{Timeout: 30 * time.Second} // Longer timeout for range fetches
url := fmt.Sprintf("http://%s/kv_range", peerAddress)
resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonData))
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("peer returned status %d for KV range fetch", resp.StatusCode)
}
var rangeResp KVRangeResponse
if err := json.NewDecoder(resp.Body).Decode(&rangeResp); err != nil {
return err
}
for _, pair := range rangeResp.Pairs {
// Use storeReplicatedDataWithMetadata to preserve original UUID/Timestamp
if err := s.storeReplicatedDataWithMetadata(pair.Path, &pair.StoredValue); err != nil {
s.logger.WithError(err).WithFields(logrus.Fields{
"peer": peerAddress,
"path": pair.Path,
}).Error("Failed to store fetched range data")
} else {
s.logger.WithFields(logrus.Fields{
"peer": peerAddress,
"path": pair.Path,
}).Debug("Stored data from fetched range")
}
}
return nil
}
// Bootstrap - join cluster using seed nodes
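// The node enters "syncing" mode, joins via the first reachable seed, waits briefly for member
// discovery, runs a few Merkle sync rounds (performGradualSync), and then switches to "normal" mode;
// if no seed accepts the join it falls back to running standalone.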
func (s *Server) bootstrap() {
if len(s.config.SeedNodes) == 0 {
s.logger.Info("No seed nodes configured, running as standalone")
return
}
s.logger.Info("Starting bootstrap process")
s.setMode("syncing")
// Try to join cluster via each seed node
joined := false
for _, seedAddr := range s.config.SeedNodes {
if s.attemptJoin(seedAddr) {
joined = true
break
}
}
if !joined {
s.logger.Warn("Failed to join cluster via seed nodes, running as standalone")
s.setMode("normal")
return
}
// Wait a bit for member discovery
time.Sleep(2 * time.Second)
// Perform gradual sync (now Merkle-based)
s.performGradualSync()
// Switch to normal mode
s.setMode("normal")
s.logger.Info("Bootstrap completed, entering normal mode")
}
// Attempt to join cluster via a seed node
func (s *Server) attemptJoin(seedAddr string) bool {
joinReq := JoinRequest{
ID: s.config.NodeID,
Address: fmt.Sprintf("%s:%d", s.config.BindAddress, s.config.Port),
JoinedTimestamp: time.Now().UnixMilli(),
}
jsonData, err := json.Marshal(joinReq)
if err != nil {
s.logger.WithError(err).Error("Failed to marshal join request")
return false
}
client := &http.Client{Timeout: 10 * time.Second}
url := fmt.Sprintf("http://%s/members/join", seedAddr)
resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonData))
if err != nil {
s.logger.WithFields(logrus.Fields{
"seed": seedAddr,
"error": err.Error(),
}).Warn("Failed to contact seed node")
return false
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
s.logger.WithFields(logrus.Fields{
"seed": seedAddr,
"status": resp.StatusCode,
}).Warn("Seed node rejected join request")
return false
}
// Process member list response
var memberList []Member
if err := json.NewDecoder(resp.Body).Decode(&memberList); err != nil {
s.logger.WithError(err).Error("Failed to decode member list from seed")
return false
}
// Add all members to our local list (copy each element so the stored pointers are distinct
// even on Go versions where the range variable is reused across iterations)
for i := range memberList {
member := memberList[i]
if member.ID != s.config.NodeID {
s.addMember(&member)
}
}
s.logger.WithFields(logrus.Fields{
"seed": seedAddr,
"member_count": len(memberList),
}).Info("Successfully joined cluster")
return true
}
// Perform gradual sync (Merkle-based version)
func (s *Server) performGradualSync() {
s.logger.Info("Starting gradual sync (Merkle-based)")
members := s.getHealthyMembers()
if len(members) == 0 {
s.logger.Info("No healthy members for gradual sync")
return
}
// For now, just do a few rounds of Merkle sync
for i := 0; i < 3; i++ {
s.performMerkleSync() // Use Merkle sync
time.Sleep(time.Duration(s.config.ThrottleDelayMs) * time.Millisecond)
}
s.logger.Info("Gradual sync completed")
}
// Resolve conflict between local and remote data using majority vote and oldest node tie-breaker
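// Each healthy member is queried for its copy of the path; copies that share the colliding timestamp
// vote by UUID, ties are broken in favor of the node with the earliest JoinedTimestamp, and the
// function returns true when the remote version wins and has been stored locally.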
func (s *Server) resolveConflict(path string, localData *StoredValue, remotePair *PairsByTimeResponse, peerAddress string) (bool, error) {
s.logger.WithFields(logrus.Fields{
"path": path,
"timestamp": localData.Timestamp,
"local_uuid": localData.UUID,
"remote_uuid": remotePair.UUID,
}).Info("Starting conflict resolution with majority vote")
// Get list of healthy members for voting
members := s.getHealthyMembers()
if len(members) == 0 {
// No other members to consult, use oldest node rule (local vs remote)
// We'll consider the peer as the "remote" node for comparison
return s.resolveByOldestNode(localData, remotePair, peerAddress)
}
// Query all healthy members for their version of this path
votes := make(map[string]int) // UUID -> vote count
uuidToTimestamp := make(map[string]int64)
uuidToJoinedTime := make(map[string]int64)
// Add our local vote
votes[localData.UUID] = 1
uuidToTimestamp[localData.UUID] = localData.Timestamp
uuidToJoinedTime[localData.UUID] = s.getJoinedTimestamp()
// Add the remote peer's vote.
// remotePair carries only the UUID and timestamp; during Merkle sync the full remote
// StoredValue has already been fetched, so this metadata is assumed to be accurate.
votes[remotePair.UUID]++ // a missing map key starts at zero, so this records exactly one vote for the peer
uuidToTimestamp[remotePair.UUID] = remotePair.Timestamp
// We'll need to get the peer's joined timestamp
s.membersMu.RLock()
for _, member := range s.members {
if member.Address == peerAddress {
uuidToJoinedTime[remotePair.UUID] = member.JoinedTimestamp
break
}
}
s.membersMu.RUnlock()
// Query other members
for _, member := range members {
if member.Address == peerAddress {
continue // Already handled the peer
}
memberData, err := s.fetchSingleKVFromPeer(member.Address, path) // Use the new fetchSingleKVFromPeer
if err != nil || memberData == nil {
s.logger.WithFields(logrus.Fields{
"member_address": member.Address,
"path": path,
"error": err,
}).Debug("Failed to query member for data during conflict resolution, skipping vote")
continue
}
// Only count votes for data with the same timestamp (as per original logic for collision)
if memberData.Timestamp == localData.Timestamp {
votes[memberData.UUID]++
if _, exists := uuidToTimestamp[memberData.UUID]; !exists {
uuidToTimestamp[memberData.UUID] = memberData.Timestamp
uuidToJoinedTime[memberData.UUID] = member.JoinedTimestamp
}
}
}
// Find the UUID with majority votes
maxVotes := 0
var winningUUIDs []string
for uuid, voteCount := range votes {
if voteCount > maxVotes {
maxVotes = voteCount
winningUUIDs = []string{uuid}
} else if voteCount == maxVotes {
winningUUIDs = append(winningUUIDs, uuid)
}
}
var winnerUUID string
if len(winningUUIDs) == 1 {
winnerUUID = winningUUIDs[0]
} else {
// Tie-breaker: oldest node (earliest joined timestamp)
oldestJoinedTime := int64(0)
for _, uuid := range winningUUIDs {
joinedTime := uuidToJoinedTime[uuid]
if oldestJoinedTime == 0 || (joinedTime != 0 && joinedTime < oldestJoinedTime) { // joinedTime can be 0 if not found
oldestJoinedTime = joinedTime
winnerUUID = uuid
}
}
s.logger.WithFields(logrus.Fields{
"path": path,
"tied_votes": maxVotes,
"winner_uuid": winnerUUID,
"oldest_joined": oldestJoinedTime,
}).Info("Resolved conflict using oldest node tie-breaker")
}
// If remote UUID wins, fetch and store the remote data
if winnerUUID == remotePair.UUID {
// We need the full StoredValue for the winning remote data.
// Since remotePair only has UUID/Timestamp, we must fetch the data.
winningRemoteStoredValue, err := s.fetchSingleKVFromPeer(peerAddress, path)
if err != nil || winningRemoteStoredValue == nil {
return false, fmt.Errorf("failed to fetch winning remote data for conflict resolution: %v", err)
}
err = s.storeReplicatedDataWithMetadata(path, winningRemoteStoredValue)
if err != nil {
return false, fmt.Errorf("failed to store winning data: %v", err)
}
s.logger.WithFields(logrus.Fields{
"path": path,
"winner_uuid": winnerUUID,
"winner_votes": maxVotes,
"total_nodes": len(members) + 2, // +2 for local and peer
}).Info("Conflict resolved: remote data wins")
return true, nil
}
// Local data wins, no action needed
s.logger.WithFields(logrus.Fields{
"path": path,
"winner_uuid": winnerUUID,
"winner_votes": maxVotes,
"total_nodes": len(members) + 2,
}).Info("Conflict resolved: local data wins")
return false, nil
}
// Resolve conflict using oldest node rule when no other members available
func (s *Server) resolveByOldestNode(localData *StoredValue, remotePair *PairsByTimeResponse, peerAddress string) (bool, error) {
// Find the peer's joined timestamp
peerJoinedTime := int64(0)
s.membersMu.RLock()
for _, member := range s.members {
if member.Address == peerAddress {
peerJoinedTime = member.JoinedTimestamp
break
}
}
s.membersMu.RUnlock()
localJoinedTime := s.getJoinedTimestamp()
// Oldest node wins
if peerJoinedTime > 0 && peerJoinedTime < localJoinedTime {
// Peer is older, fetch remote data
winningRemoteStoredValue, err := s.fetchSingleKVFromPeer(peerAddress, remotePair.Path)
if err != nil || winningRemoteStoredValue == nil {
return false, fmt.Errorf("failed to fetch data from older node for conflict resolution: %v", err)
}
err = s.storeReplicatedDataWithMetadata(remotePair.Path, winningRemoteStoredValue)
if err != nil {
return false, fmt.Errorf("failed to store data from older node: %v", err)
}
s.logger.WithFields(logrus.Fields{
"path": remotePair.Path,
"local_joined": localJoinedTime,
"peer_joined": peerJoinedTime,
"winner": "remote",
}).Info("Conflict resolved using oldest node rule")
return true, nil
}
// Local node is older or equal, keep local data
s.logger.WithFields(logrus.Fields{
"path": remotePair.Path,
"local_joined": localJoinedTime,
"peer_joined": peerJoinedTime,
"winner": "local",
}).Info("Conflict resolved using oldest node rule")
return false, nil
}
// getLocalData is a utility to retrieve a StoredValue from local DB.
func (s *Server) getLocalData(path string) (*StoredValue, bool) {
var storedValue StoredValue
err := s.db.View(func(txn *badger.Txn) error {
item, err := txn.Get([]byte(path))
if err != nil {
return err
}
return item.Value(func(val []byte) error {
return json.Unmarshal(val, &storedValue)
})
})
if err != nil {
return nil, false
}
return &storedValue, true
}
func main() {
configPath := "./config.yaml"
// Simple CLI argument parsing
if len(os.Args) > 1 {
configPath = os.Args[1]
}
config, err := loadConfig(configPath)
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to load configuration: %v\n", err)
os.Exit(1)
}
server, err := NewServer(config)
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to create server: %v\n", err)
os.Exit(1)
}
// Handle graceful shutdown
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigCh
server.Stop()
}()
if err := server.Start(); err != nil && err != http.ErrServerClosed {
fmt.Fprintf(os.Stderr, "Server error: %v\n", err)
os.Exit(1)
}
}