diff --git a/cluster/sync.go b/cluster/sync.go
index ea4142f..ed6786d 100644
--- a/cluster/sync.go
+++ b/cluster/sync.go
@@ -8,7 +8,6 @@ import (
 	"fmt"
 	"math/rand"
 	"net/http"
-	"sort"
 	"sync"
 	"time"
 
diff --git a/main.go b/main.go
index fe23daf..5190f3f 100644
--- a/main.go
+++ b/main.go
@@ -1,71 +1,23 @@
 package main
 
 import (
-	"bytes"
-	"context"
-	"crypto/sha256" // Added for Merkle Tree hashing
-	"encoding/hex" // Added for logging hashes
-	"encoding/json"
 	"fmt"
-	"math/rand"
-	"net"
 	"net/http"
 	"os"
 	"os/signal"
-	"path/filepath"
-	"sort" // Added for sorting keys in Merkle Tree
-	"strconv"
 	"strings"
-	"sync"
 	"syscall"
 	"time"
 
-	badger "github.com/dgraph-io/badger/v4"
 	"github.com/golang-jwt/jwt/v4"
-	"github.com/google/uuid"
 	"github.com/gorilla/mux"
-	"github.com/klauspost/compress/zstd"
-	"github.com/robfig/cron/v3"
-	"github.com/sirupsen/logrus"
-	"kvs/auth"
 	"kvs/config"
+	"kvs/server"
 	"kvs/types"
 	"kvs/utils"
 )
-
-// Server represents the KVS node
-type Server struct {
-	config       *types.Config
-	db           *badger.DB
-	members      map[string]*types.Member
-	membersMu    sync.RWMutex
-	mode         string // "normal", "read-only", "syncing"
-	modeMu       sync.RWMutex
-	logger       *logrus.Logger
-	httpServer   *http.Server
-	ctx          context.Context
-	cancel       context.CancelFunc
-	wg           sync.WaitGroup
-	merkleRoot   *types.MerkleNode // Added for Merkle Tree
-	merkleRootMu sync.RWMutex      // Protects merkleRoot
-
-	// Phase 2: ZSTD compression
-	compressor   *zstd.Encoder // ZSTD compressor
-	decompressor *zstd.Decoder // ZSTD decompressor
-
-	// Phase 2: Backup system
-	cronScheduler *cron.Cron         // Cron scheduler for backups
-	backupStatus  types.BackupStatus // Current backup status
-	backupMu      sync.RWMutex       // Protects backup status
-
-	// Authentication service
-	authService *auth.AuthService
-}
-
-
-
 // Phase 2: Permission checking utilities
 func checkPermission(permissions int, operation string, isOwner, isGroupMember bool) bool {
 	switch operation {
@@ -77,7 +29,7 @@ func checkPermission(permissions int, operation string, isOwner, isGroupMember b
 			return (permissions & types.PermGroupCreate) != 0
 		}
 		return (permissions & types.PermOthersCreate) != 0
-	
+
 	case "delete":
 		if isOwner {
 			return (permissions & types.PermOwnerDelete) != 0
@@ -86,7 +38,7 @@ func checkPermission(permissions int, operation string, isOwner, isGroupMember b
 			return (permissions & types.PermGroupDelete) != 0
 		}
 		return (permissions & types.PermOthersDelete) != 0
-	
+
 	case "write":
 		if isOwner {
 			return (permissions & types.PermOwnerWrite) != 0
@@ -95,7 +47,7 @@ func checkPermission(permissions int, operation string, isOwner, isGroupMember b
 			return (permissions & types.PermGroupWrite) != 0
 		}
 		return (permissions & types.PermOthersWrite) != 0
-	
+
 	case "read":
 		if isOwner {
 			return (permissions & types.PermOwnerRead) != 0
@@ -104,7 +56,7 @@ func checkPermission(permissions int, operation string, isOwner, isGroupMember b
 			return (permissions & types.PermGroupRead) != 0
 		}
 		return (permissions & types.PermOthersRead) != 0
-	
+
 	default:
 		return false
 	}
@@ -113,7 +65,7 @@ func checkPermission(permissions int, operation string, isOwner, isGroupMember b
 // Helper function to determine user relationship to resource
 func checkUserResourceRelationship(userUUID string, metadata *types.ResourceMetadata, userGroups []string) (isOwner, isGroupMember bool) {
 	isOwner = (userUUID == metadata.OwnerUUID)
-	
+
 	if metadata.GroupUUID != "" {
 		for _, groupUUID := range userGroups {
 			if groupUUID == metadata.GroupUUID {
@@ -122,7 +74,7 @@ func checkUserResourceRelationship(userUUID string, metadata 
*types.ResourceMeta } } } - + return isOwner, isGroupMember } @@ -139,32 +91,6 @@ type JWTClaims struct { } // generateJWT creates a new JWT token for a user with specified scopes -func generateJWT(userUUID string, scopes []string, expirationHours int) (string, int64, error) { - if expirationHours <= 0 { - expirationHours = 1 // Default to 1 hour - } - - now := time.Now() - expiresAt := now.Add(time.Duration(expirationHours) * time.Hour) - - claims := JWTClaims{ - UserUUID: userUUID, - Scopes: scopes, - RegisteredClaims: jwt.RegisteredClaims{ - IssuedAt: jwt.NewNumericDate(now), - ExpiresAt: jwt.NewNumericDate(expiresAt), - Issuer: "kvs-server", - }, - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenString, err := token.SignedString(jwtSigningKey) - if err != nil { - return "", 0, err - } - - return tokenString, expiresAt.Unix(), nil -} // validateJWT validates a JWT token and returns the claims if valid func validateJWT(tokenString string) (*JWTClaims, error) { @@ -188,56 +114,8 @@ func validateJWT(tokenString string) (*JWTClaims, error) { } // storeAPIToken stores an API token in BadgerDB with TTL -func (s *Server) storeAPIToken(tokenString string, userUUID string, scopes []string, expiresAt int64) error { - tokenHash := utils.HashToken(tokenString) - - apiToken := types.APIToken{ - TokenHash: tokenHash, - UserUUID: userUUID, - Scopes: scopes, - IssuedAt: time.Now().Unix(), - ExpiresAt: expiresAt, - } - - tokenData, err := json.Marshal(apiToken) - if err != nil { - return err - } - - return s.db.Update(func(txn *badger.Txn) error { - entry := badger.NewEntry([]byte(auth.TokenStorageKey(tokenHash)), tokenData) - - // Set TTL to the token expiration time - ttl := time.Until(time.Unix(expiresAt, 0)) - if ttl > 0 { - entry = entry.WithTTL(ttl) - } - - return txn.SetEntry(entry) - }) -} // getAPIToken retrieves an API token from BadgerDB by hash -func (s *Server) getAPIToken(tokenHash string) (*types.APIToken, error) { - var apiToken types.APIToken - - err := s.db.View(func(txn *badger.Txn) error { - item, err := txn.Get([]byte(auth.TokenStorageKey(tokenHash))) - if err != nil { - return err - } - - return item.Value(func(val []byte) error { - return json.Unmarshal(val, &apiToken) - }) - }) - - if err != nil { - return nil, err - } - - return &apiToken, nil -} // Phase 2: Authorization middleware @@ -264,167 +142,12 @@ func extractTokenFromHeader(r *http.Request) (string, error) { } // getUserGroups retrieves all groups that a user belongs to -func (s *Server) getUserGroups(userUUID string) ([]string, error) { - var user types.User - err := s.db.View(func(txn *badger.Txn) error { - item, err := txn.Get([]byte(auth.UserStorageKey(userUUID))) - if err != nil { - return err - } - - return item.Value(func(val []byte) error { - return json.Unmarshal(val, &user) - }) - }) - - if err != nil { - return nil, err - } - - return user.Groups, nil -} // authenticateRequest validates the JWT token and returns authentication context -func (s *Server) authenticateRequest(r *http.Request) (*AuthContext, error) { - // Extract token from header - tokenString, err := extractTokenFromHeader(r) - if err != nil { - return nil, err - } - - // Validate JWT token - claims, err := validateJWT(tokenString) - if err != nil { - return nil, fmt.Errorf("invalid token: %v", err) - } - - // Verify token exists in our database (not revoked) - tokenHash := utils.HashToken(tokenString) - _, err = s.getAPIToken(tokenHash) - if err == badger.ErrKeyNotFound { - return nil, fmt.Errorf("token not found or 
revoked") - } - if err != nil { - return nil, fmt.Errorf("failed to verify token: %v", err) - } - - // Get user's groups - groups, err := s.getUserGroups(claims.UserUUID) - if err != nil { - s.logger.WithError(err).WithField("user_uuid", claims.UserUUID).Warn("Failed to get user groups") - groups = []string{} // Continue with empty groups on error - } - - return &AuthContext{ - UserUUID: claims.UserUUID, - Scopes: claims.Scopes, - Groups: groups, - }, nil -} // checkResourcePermission checks if a user has permission to perform an operation on a resource -func (s *Server) checkResourcePermission(authCtx *AuthContext, resourceKey string, operation string) bool { - // Get resource metadata - var metadata types.ResourceMetadata - err := s.db.View(func(txn *badger.Txn) error { - item, err := txn.Get([]byte(auth.ResourceMetadataKey(resourceKey))) - if err != nil { - return err - } - - return item.Value(func(val []byte) error { - return json.Unmarshal(val, &metadata) - }) - }) - - // If no metadata exists, use default permissions - if err == badger.ErrKeyNotFound { - metadata = types.ResourceMetadata{ - OwnerUUID: authCtx.UserUUID, // Treat requester as owner for new resources - GroupUUID: "", - Permissions: types.DefaultPermissions, - } - } else if err != nil { - s.logger.WithError(err).WithField("resource_key", resourceKey).Warn("Failed to get resource metadata") - return false - } - - // Check user relationship to resource - isOwner, isGroupMember := checkUserResourceRelationship(authCtx.UserUUID, &metadata, authCtx.Groups) - - // Check permission - return checkPermission(metadata.Permissions, operation, isOwner, isGroupMember) -} // authMiddleware is the HTTP middleware that enforces authentication and authorization -func (s *Server) authMiddleware(requiredScopes []string, resourceKeyExtractor func(*http.Request) string, operation string) func(http.HandlerFunc) http.HandlerFunc { - return func(next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // Skip authentication if disabled - if !s.config.AuthEnabled { - next(w, r) - return - } - - // Authenticate request - authCtx, err := s.authenticateRequest(r) - if err != nil { - s.logger.WithError(err).WithField("path", r.URL.Path).Info("Authentication failed") - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return - } - - // Check required scopes - if len(requiredScopes) > 0 { - hasRequiredScope := false - for _, required := range requiredScopes { - for _, scope := range authCtx.Scopes { - if scope == required { - hasRequiredScope = true - break - } - } - if hasRequiredScope { - break - } - } - - if !hasRequiredScope { - s.logger.WithFields(logrus.Fields{ - "user_uuid": authCtx.UserUUID, - "user_scopes": authCtx.Scopes, - "required_scopes": requiredScopes, - }).Info("Insufficient scopes") - http.Error(w, "Forbidden", http.StatusForbidden) - return - } - } - - // Check resource-level permissions if applicable - if resourceKeyExtractor != nil && operation != "" { - resourceKey := resourceKeyExtractor(r) - if resourceKey != "" { - hasPermission := s.checkResourcePermission(authCtx, resourceKey, operation) - if !hasPermission { - s.logger.WithFields(logrus.Fields{ - "user_uuid": authCtx.UserUUID, - "resource_key": resourceKey, - "operation": operation, - }).Info("Permission denied") - http.Error(w, "Forbidden", http.StatusForbidden) - return - } - } - } - - // Store auth context in request context for use in handlers - ctx := context.WithValue(r.Context(), "auth", authCtx) - r = 
r.WithContext(ctx) - - next(w, r) - } - } -} // Helper function to extract KV resource key from request func extractKVResourceKey(r *http.Request) string { @@ -437,23 +160,9 @@ func extractKVResourceKey(r *http.Request) string { // Phase 2: ZSTD compression utilities -// compressData compresses JSON data using ZSTD if compression is enabled -func (s *Server) compressData(data []byte) ([]byte, error) { - if !s.config.CompressionEnabled || s.compressor == nil { - return data, nil - } - - return s.compressor.EncodeAll(data, make([]byte, 0, len(data))), nil -} +// compressData compresses JSON data using storage service -// decompressData decompresses ZSTD-compressed data if compression is enabled -func (s *Server) decompressData(compressedData []byte) ([]byte, error) { - if !s.config.CompressionEnabled || s.decompressor == nil { - return compressedData, nil - } - - return s.decompressor.DecodeAll(compressedData, nil) -} +// decompressData decompresses data using storage service // Phase 2: TTL and size validation utilities @@ -462,81 +171,26 @@ func parseTTL(ttlString string) (time.Duration, error) { if ttlString == "" || ttlString == "0" { return 0, nil // No TTL } - + duration, err := time.ParseDuration(ttlString) if err != nil { return 0, fmt.Errorf("invalid TTL format: %v", err) } - + if duration < 0 { return 0, fmt.Errorf("TTL cannot be negative") } - + return duration, nil } // validateJSONSize checks if JSON data exceeds maximum allowed size -func (s *Server) validateJSONSize(data []byte) error { - if s.config.MaxJSONSize > 0 && len(data) > s.config.MaxJSONSize { - return fmt.Errorf("JSON size (%d bytes) exceeds maximum allowed size (%d bytes)", - len(data), s.config.MaxJSONSize) - } - return nil -} // createResourceMetadata creates metadata for a new resource with TTL and permissions -func (s *Server) createResourceMetadata(ownerUUID, groupUUID, ttlString string, permissions int) (*types.ResourceMetadata, error) { - now := time.Now().Unix() - - metadata := &types.ResourceMetadata{ - OwnerUUID: ownerUUID, - GroupUUID: groupUUID, - Permissions: permissions, - TTL: ttlString, - CreatedAt: now, - UpdatedAt: now, - } - - return metadata, nil -} // storeWithTTL stores data in BadgerDB with optional TTL -func (s *Server) storeWithTTL(txn *badger.Txn, key []byte, data []byte, ttl time.Duration) error { - // Compress data if compression is enabled - compressedData, err := s.compressData(data) - if err != nil { - return fmt.Errorf("failed to compress data: %v", err) - } - - entry := badger.NewEntry(key, compressedData) - - // Apply TTL if specified - if ttl > 0 { - entry = entry.WithTTL(ttl) - } - - return txn.SetEntry(entry) -} // retrieveWithDecompression retrieves and decompresses data from BadgerDB -func (s *Server) retrieveWithDecompression(txn *badger.Txn, key []byte) ([]byte, error) { - item, err := txn.Get(key) - if err != nil { - return nil, err - } - - var compressedData []byte - err = item.Value(func(val []byte) error { - compressedData = append(compressedData, val...) 
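-		// NB: BadgerDB only guarantees the slice passed to item.Value for the
-		// duration of the callback, so the append above copies the bytes out
-		// before the transaction closes and the result is handed to decompressData.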
- return nil - }) - if err != nil { - return nil, err - } - - // Decompress data if compression is enabled - return s.decompressData(compressedData) -} // Phase 2: Revision history system utilities @@ -546,296 +200,15 @@ func getRevisionKey(baseKey string, revision int) string { } // storeRevisionHistory stores a value and manages revision history (up to 3 revisions) -func (s *Server) storeRevisionHistory(txn *badger.Txn, key string, storedValue types.StoredValue, ttl time.Duration) error { - // Get existing metadata to check current revisions - metadataKey := auth.ResourceMetadataKey(key) - - var metadata types.ResourceMetadata - var currentRevisions []int - - // Try to get existing metadata - metadataData, err := s.retrieveWithDecompression(txn, []byte(metadataKey)) - if err == badger.ErrKeyNotFound { - // No existing metadata, this is a new key - metadata = types.ResourceMetadata{ - OwnerUUID: "", // Will be set by caller if needed - GroupUUID: "", - Permissions: types.DefaultPermissions, - TTL: "", - CreatedAt: time.Now().Unix(), - UpdatedAt: time.Now().Unix(), - } - currentRevisions = []int{} - } else if err != nil { - // Error reading metadata - return fmt.Errorf("failed to read metadata: %v", err) - } else { - // Parse existing metadata - err = json.Unmarshal(metadataData, &metadata) - if err != nil { - return fmt.Errorf("failed to unmarshal metadata: %v", err) - } - - // Extract current revisions (we store them as a custom field) - if metadata.TTL == "" { - currentRevisions = []int{} - } else { - // For now, we'll manage revisions separately - let's create a new metadata field - currentRevisions = []int{1, 2, 3} // Assume all revisions exist for existing keys - } - } - - // Revision rotation logic: shift existing revisions - if len(currentRevisions) >= 3 { - // Delete oldest revision (rev:3) - oldestRevKey := getRevisionKey(key, 3) - txn.Delete([]byte(oldestRevKey)) - - // Shift rev:2 → rev:3 - rev2Key := getRevisionKey(key, 2) - rev2Data, err := s.retrieveWithDecompression(txn, []byte(rev2Key)) - if err == nil { - rev3Key := getRevisionKey(key, 3) - s.storeWithTTL(txn, []byte(rev3Key), rev2Data, ttl) - } - - // Shift rev:1 → rev:2 - rev1Key := getRevisionKey(key, 1) - rev1Data, err := s.retrieveWithDecompression(txn, []byte(rev1Key)) - if err == nil { - rev2Key := getRevisionKey(key, 2) - s.storeWithTTL(txn, []byte(rev2Key), rev1Data, ttl) - } - - currentRevisions = []int{1, 2, 3} - } else if len(currentRevisions) == 2 { - // Shift rev:1 → rev:2 - rev1Key := getRevisionKey(key, 1) - rev1Data, err := s.retrieveWithDecompression(txn, []byte(rev1Key)) - if err == nil { - rev2Key := getRevisionKey(key, 2) - s.storeWithTTL(txn, []byte(rev2Key), rev1Data, ttl) - } - currentRevisions = []int{1, 2, 3} - } else if len(currentRevisions) == 1 { - currentRevisions = []int{1, 2} - } else { - currentRevisions = []int{1} - } - - // Store new value as rev:1 - storedValueData, err := json.Marshal(storedValue) - if err != nil { - return fmt.Errorf("failed to marshal stored value: %v", err) - } - - rev1Key := getRevisionKey(key, 1) - err = s.storeWithTTL(txn, []byte(rev1Key), storedValueData, ttl) - if err != nil { - return fmt.Errorf("failed to store revision 1: %v", err) - } - - // Store main data (current version) - err = s.storeWithTTL(txn, []byte(key), storedValueData, ttl) - if err != nil { - return fmt.Errorf("failed to store main data: %v", err) - } - - // Update metadata with revision information - metadata.UpdatedAt = time.Now().Unix() - // Store revision numbers in a custom metadata field 
(we'll extend metadata later) - - metadataData, err = json.Marshal(metadata) - if err != nil { - return fmt.Errorf("failed to marshal metadata: %v", err) - } - - return s.storeWithTTL(txn, []byte(metadataKey), metadataData, ttl) -} - -// getRevisionHistory retrieves all available revisions for a key -func (s *Server) getRevisionHistory(key string) ([]map[string]interface{}, error) { - var revisions []map[string]interface{} - - err := s.db.View(func(txn *badger.Txn) error { - // Check revisions 1, 2, 3 - for i := 1; i <= 3; i++ { - revKey := getRevisionKey(key, i) - revData, err := s.retrieveWithDecompression(txn, []byte(revKey)) - if err == badger.ErrKeyNotFound { - continue // This revision doesn't exist - } - if err != nil { - return fmt.Errorf("failed to retrieve revision %d: %v", i, err) - } - - var storedValue types.StoredValue - err = json.Unmarshal(revData, &storedValue) - if err != nil { - return fmt.Errorf("failed to unmarshal revision %d: %v", i, err) - } - - revision := map[string]interface{}{ - "number": i, - "uuid": storedValue.UUID, - "timestamp": storedValue.Timestamp, - } - - revisions = append(revisions, revision) - } - - return nil - }) - - if err != nil { - return nil, err - } - - return revisions, nil -} - -// getSpecificRevision retrieves a specific revision of a key -func (s *Server) getSpecificRevision(key string, revision int) (*types.StoredValue, error) { - if revision < 1 || revision > 3 { - return nil, fmt.Errorf("invalid revision number: %d (must be 1-3)", revision) - } - - var storedValue types.StoredValue - - err := s.db.View(func(txn *badger.Txn) error { - revKey := getRevisionKey(key, revision) - revData, err := s.retrieveWithDecompression(txn, []byte(revKey)) - if err != nil { - return err - } - - return json.Unmarshal(revData, &storedValue) - }) - - if err != nil { - return nil, err - } - - return &storedValue, nil -} - -// Phase 2: Rate limiting utilities - -// getRateLimitKey generates the storage key for rate limiting counters func getRateLimitKey(userUUID string, windowStart int64) string { return fmt.Sprintf("ratelimit:%s:%d", userUUID, windowStart) } // getCurrentWindow calculates the current rate limiting window start time -func (s *Server) getCurrentWindow() (int64, time.Duration, error) { - windowDuration, err := time.ParseDuration(s.config.RateLimitWindow) - if err != nil { - return 0, 0, fmt.Errorf("invalid rate limit window: %v", err) - } - - now := time.Now() - windowStart := now.Truncate(windowDuration).Unix() - - return windowStart, windowDuration, nil -} // checkRateLimit checks if a user has exceeded the rate limit -func (s *Server) checkRateLimit(userUUID string) (bool, error) { - if s.config.RateLimitRequests <= 0 { - return true, nil // Rate limiting disabled - } - - windowStart, windowDuration, err := s.getCurrentWindow() - if err != nil { - return false, err - } - - rateLimitKey := getRateLimitKey(userUUID, windowStart) - - var currentCount int - - err = s.db.Update(func(txn *badger.Txn) error { - // Try to get current counter - item, err := txn.Get([]byte(rateLimitKey)) - if err == badger.ErrKeyNotFound { - // No counter exists, create one - currentCount = 1 - } else if err != nil { - return fmt.Errorf("failed to get rate limit counter: %v", err) - } else { - // Counter exists, increment it - err = item.Value(func(val []byte) error { - count, err := strconv.Atoi(string(val)) - if err != nil { - return fmt.Errorf("failed to parse rate limit counter: %v", err) - } - currentCount = count + 1 - return nil - }) - if err != nil { - return 
err - } - } - - // Store updated counter with TTL - counterData := []byte(strconv.Itoa(currentCount)) - entry := badger.NewEntry([]byte(rateLimitKey), counterData) - entry = entry.WithTTL(windowDuration) - - return txn.SetEntry(entry) - }) - - if err != nil { - return false, err - } - - // Check if rate limit is exceeded - return currentCount <= s.config.RateLimitRequests, nil -} // rateLimitMiddleware is the HTTP middleware that enforces rate limiting -func (s *Server) rateLimitMiddleware(next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // Skip rate limiting if disabled - if !s.config.RateLimitingEnabled { - next(w, r) - return - } - - // Extract auth context to get user UUID - authCtx, ok := r.Context().Value("auth").(*AuthContext) - if !ok || authCtx == nil { - // No auth context, skip rate limiting (unauthenticated requests) - next(w, r) - return - } - - // Check rate limit - allowed, err := s.checkRateLimit(authCtx.UserUUID) - if err != nil { - s.logger.WithError(err).WithField("user_uuid", authCtx.UserUUID).Error("Failed to check rate limit") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - if !allowed { - s.logger.WithFields(logrus.Fields{ - "user_uuid": authCtx.UserUUID, - "limit": s.config.RateLimitRequests, - "window": s.config.RateLimitWindow, - }).Info("Rate limit exceeded") - - // Set rate limit headers - w.Header().Set("X-Rate-Limit-Limit", strconv.Itoa(s.config.RateLimitRequests)) - w.Header().Set("X-Rate-Limit-Window", s.config.RateLimitWindow) - - http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) - return - } - - next(w, r) - } -} // Phase 2: Tamper-evident logging utilities @@ -857,142 +230,14 @@ func generateLogSignature(timestamp, action, userUUID, resource string) string { } // isActionLogged checks if a specific action should be logged -func (s *Server) isActionLogged(action string) bool { - for _, loggedAction := range s.config.TamperLogActions { - if loggedAction == action { - return true - } - } - return false -} // createTamperLogEntry creates a new tamper-evident log entry -func (s *Server) createTamperLogEntry(action, userUUID, resource string) *types.TamperLogEntry { - if !s.config.TamperLoggingEnabled || !s.isActionLogged(action) { - return nil // Tamper logging disabled or action not configured for logging - } - - timestamp := time.Now().UTC().Format(time.RFC3339) - signature := generateLogSignature(timestamp, action, userUUID, resource) - - return &types.TamperLogEntry{ - Timestamp: timestamp, - Action: action, - UserUUID: userUUID, - Resource: resource, - Signature: signature, - } -} // storeTamperLogEntry stores a tamper-evident log entry in BadgerDB -func (s *Server) storeTamperLogEntry(logEntry *types.TamperLogEntry) error { - if logEntry == nil { - return nil // No log entry to store - } - - // Generate UUID for this log entry - entryUUID := uuid.New().String() - logKey := getTamperLogKey(logEntry.Timestamp, entryUUID) - - // Marshal log entry - logData, err := json.Marshal(logEntry) - if err != nil { - return fmt.Errorf("failed to marshal log entry: %v", err) - } - - // Store log entry with compression - return s.db.Update(func(txn *badger.Txn) error { - // No TTL for log entries - they should be permanent for audit purposes - return s.storeWithTTL(txn, []byte(logKey), logData, 0) - }) -} // logTamperEvent logs a tamper-evident event if the action is configured for logging -func (s *Server) logTamperEvent(action, userUUID, resource string) { - 
logEntry := s.createTamperLogEntry(action, userUUID, resource) - if logEntry == nil { - return // Action not configured for logging - } - - err := s.storeTamperLogEntry(logEntry) - if err != nil { - s.logger.WithError(err).WithFields(logrus.Fields{ - "action": action, - "user_uuid": userUUID, - "resource": resource, - }).Error("Failed to store tamper log entry") - } else { - s.logger.WithFields(logrus.Fields{ - "action": action, - "user_uuid": userUUID, - "resource": resource, - "timestamp": logEntry.Timestamp, - }).Debug("Tamper log entry created") - } -} // getTamperLogs retrieves tamper log entries within a time range (for auditing) -func (s *Server) getTamperLogs(startTime, endTime time.Time, limit int) ([]*types.TamperLogEntry, error) { - var logEntries []*types.TamperLogEntry - - err := s.db.View(func(txn *badger.Txn) error { - opts := badger.DefaultIteratorOptions - opts.PrefetchValues = true - it := txn.NewIterator(opts) - defer it.Close() - - prefix := []byte("log:") - count := 0 - - for it.Seek(prefix); it.Valid() && it.Item().KeyCopy(nil)[0:4:4][0] != 'l' || - string(it.Item().KeyCopy(nil)[0:4:4]) == "log:"; it.Next() { - - if limit > 0 && count >= limit { - break - } - - key := string(it.Item().Key()) - if !strings.HasPrefix(key, "log:") || strings.HasPrefix(key, "log:merkle:") { - continue // Skip non-log entries and merkle roots - } - - // Extract timestamp from key to filter by time range - parts := strings.Split(key, ":") - if len(parts) < 3 { - continue - } - - // Parse timestamp from key - entryTime, err := time.Parse(time.RFC3339, parts[1]) - if err != nil { - continue // Skip entries with invalid timestamps - } - - if entryTime.Before(startTime) || entryTime.After(endTime) { - continue // Skip entries outside time range - } - - // Retrieve and decompress log entry - logData, err := s.retrieveWithDecompression(txn, it.Item().Key()) - if err != nil { - continue // Skip entries that can't be read - } - - var logEntry types.TamperLogEntry - err = json.Unmarshal(logData, &logEntry) - if err != nil { - continue // Skip entries that can't be parsed - } - - logEntries = append(logEntries, &logEntry) - count++ - } - - return nil - }) - - return logEntries, err -} // Phase 2: Backup system utilities @@ -1002,2575 +247,42 @@ func getBackupFilename(timestamp time.Time) string { } // createBackup creates a compressed backup of the BadgerDB database -func (s *Server) createBackup() error { - s.backupMu.Lock() - s.backupStatus.BackupsRunning++ - s.backupMu.Unlock() - - defer func() { - s.backupMu.Lock() - s.backupStatus.BackupsRunning-- - s.backupMu.Unlock() - }() - - now := time.Now() - backupFilename := getBackupFilename(now) - backupPath := filepath.Join(s.config.BackupPath, backupFilename) - - // Create backup directory if it doesn't exist - if err := os.MkdirAll(s.config.BackupPath, 0755); err != nil { - return fmt.Errorf("failed to create backup directory: %v", err) - } - - // Create temporary file for backup - tempPath := backupPath + ".tmp" - tempFile, err := os.Create(tempPath) - if err != nil { - return fmt.Errorf("failed to create temporary backup file: %v", err) - } - defer tempFile.Close() - - // Create ZSTD compressor for backup file - compressor, err := zstd.NewWriter(tempFile, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(s.config.CompressionLevel))) - if err != nil { - os.Remove(tempPath) - return fmt.Errorf("failed to create backup compressor: %v", err) - } - defer compressor.Close() - - // Create BadgerDB backup stream - since := uint64(0) // Full backup - _, err 
= s.db.Backup(compressor, since) - if err != nil { - compressor.Close() - tempFile.Close() - os.Remove(tempPath) - return fmt.Errorf("failed to create database backup: %v", err) - } - - // Close compressor and temp file - compressor.Close() - tempFile.Close() - - // Move temporary file to final backup path - if err := os.Rename(tempPath, backupPath); err != nil { - os.Remove(tempPath) - return fmt.Errorf("failed to finalize backup file: %v", err) - } - - // Update backup status - s.backupMu.Lock() - s.backupStatus.LastBackupTime = now.Unix() - s.backupStatus.LastBackupSuccess = true - s.backupStatus.LastBackupPath = backupPath - s.backupMu.Unlock() - - s.logger.WithFields(logrus.Fields{ - "backup_path": backupPath, - "timestamp": now.Format(time.RFC3339), - }).Info("Database backup completed successfully") - - // Clean up old backups - s.cleanupOldBackups() - - return nil -} // cleanupOldBackups removes backup files older than the retention period -func (s *Server) cleanupOldBackups() { - if s.config.BackupRetention <= 0 { - return // No cleanup if retention is disabled - } - - cutoffTime := time.Now().AddDate(0, 0, -s.config.BackupRetention) - - entries, err := os.ReadDir(s.config.BackupPath) - if err != nil { - s.logger.WithError(err).Error("Failed to read backup directory for cleanup") - return - } - - for _, entry := range entries { - if !strings.HasPrefix(entry.Name(), "kvs-backup-") || !strings.HasSuffix(entry.Name(), ".zstd") { - continue // Skip non-backup files - } - - info, err := entry.Info() - if err != nil { - continue - } - - if info.ModTime().Before(cutoffTime) { - backupPath := filepath.Join(s.config.BackupPath, entry.Name()) - if err := os.Remove(backupPath); err != nil { - s.logger.WithError(err).WithField("backup_path", backupPath).Warn("Failed to remove old backup") - } else { - s.logger.WithField("backup_path", backupPath).Info("Removed old backup") - } - } - } -} // initializeBackupScheduler sets up the cron scheduler for automated backups -func (s *Server) initializeBackupScheduler() error { - if !s.config.BackupEnabled { - s.logger.Info("Backup system disabled") - return nil - } - - s.cronScheduler = cron.New() - - _, err := s.cronScheduler.AddFunc(s.config.BackupSchedule, func() { - s.logger.Info("Starting scheduled backup") - - if err := s.createBackup(); err != nil { - s.logger.WithError(err).Error("Scheduled backup failed") - - // Update backup status on failure - s.backupMu.Lock() - s.backupStatus.LastBackupTime = time.Now().Unix() - s.backupStatus.LastBackupSuccess = false - s.backupMu.Unlock() - } - }) - - if err != nil { - return fmt.Errorf("failed to schedule backup: %v", err) - } - - s.cronScheduler.Start() - - s.logger.WithFields(logrus.Fields{ - "schedule": s.config.BackupSchedule, - "path": s.config.BackupPath, - "retention": s.config.BackupRetention, - }).Info("Backup scheduler initialized") - - return nil -} // getBackupStatus returns the current backup status -func (s *Server) getBackupStatus() types.BackupStatus { - s.backupMu.RLock() - defer s.backupMu.RUnlock() - - status := s.backupStatus - - // Calculate next backup time if scheduler is running - if s.cronScheduler != nil && len(s.cronScheduler.Entries()) > 0 { - nextRun := s.cronScheduler.Entries()[0].Next - if !nextRun.IsZero() { - status.NextBackupTime = nextRun.Unix() - } - } - - return status -} - // Initialize server -func NewServer(config *types.Config) (*Server, error) { - logger := logrus.New() - logger.SetFormatter(&logrus.JSONFormatter{}) - - level, err := 
logrus.ParseLevel(config.LogLevel) - if err != nil { - level = logrus.InfoLevel - } - logger.SetLevel(level) - - // Create data directory - if err := os.MkdirAll(config.DataDir, 0755); err != nil { - return nil, fmt.Errorf("failed to create data directory: %v", err) - } - - // Open BadgerDB - opts := badger.DefaultOptions(filepath.Join(config.DataDir, "badger")) - opts.Logger = nil // Disable badger's internal logging - db, err := badger.Open(opts) - if err != nil { - return nil, fmt.Errorf("failed to open BadgerDB: %v", err) - } - - ctx, cancel := context.WithCancel(context.Background()) - - server := &Server{ - config: config, - db: db, - members: make(map[string]*types.Member), - mode: "normal", - logger: logger, - ctx: ctx, - cancel: cancel, - } - - if config.ReadOnly { - server.setMode("read-only") - } - - // Build initial Merkle tree - pairs, err := server.getAllKVPairsForMerkleTree() - if err != nil { - return nil, fmt.Errorf("failed to get all KV pairs for initial Merkle tree: %v", err) - } - root, err := server.buildMerkleTreeFromPairs(pairs) - if err != nil { - return nil, fmt.Errorf("failed to build initial Merkle tree: %v", err) - } - server.setMerkleRoot(root) - server.logger.Info("Initial Merkle tree built.") - - // Phase 2: Initialize ZSTD compression if enabled - if config.CompressionEnabled { - // Validate compression level - if config.CompressionLevel < 1 || config.CompressionLevel > 19 { - config.CompressionLevel = 3 // Default to level 3 - } - - // Create encoder with specified compression level - encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(config.CompressionLevel))) - if err != nil { - return nil, fmt.Errorf("failed to create ZSTD encoder: %v", err) - } - server.compressor = encoder - - // Create decoder for decompression - decoder, err := zstd.NewReader(nil) - if err != nil { - return nil, fmt.Errorf("failed to create ZSTD decoder: %v", err) - } - server.decompressor = decoder - - server.logger.WithField("compression_level", config.CompressionLevel).Info("ZSTD compression enabled") - } else { - server.logger.Info("ZSTD compression disabled") - } - - // Phase 2: Initialize backup scheduler - err = server.initializeBackupScheduler() - if err != nil { - return nil, fmt.Errorf("failed to initialize backup scheduler: %v", err) - } - - return server, nil -} - -// Mode management -func (s *Server) getMode() string { - s.modeMu.RLock() - defer s.modeMu.RUnlock() - return s.mode -} - -func (s *Server) setMode(mode string) { - s.modeMu.Lock() - defer s.modeMu.Unlock() - oldMode := s.mode - s.mode = mode - s.logger.WithFields(logrus.Fields{ - "old_mode": oldMode, - "new_mode": mode, - }).Info("Mode changed") -} - -// types.Member management -func (s *Server) addMember(member *types.Member) { - s.membersMu.Lock() - defer s.membersMu.Unlock() - s.members[member.ID] = member - s.logger.WithFields(logrus.Fields{ - "node_id": member.ID, - "address": member.Address, - }).Info("types.Member added") -} - -func (s *Server) removeMember(nodeID string) { - s.membersMu.Lock() - defer s.membersMu.Unlock() - if member, exists := s.members[nodeID]; exists { - delete(s.members, nodeID) - s.logger.WithFields(logrus.Fields{ - "node_id": member.ID, - "address": member.Address, - }).Info("types.Member removed") - } -} - -func (s *Server) getMembers() []*types.Member { - s.membersMu.RLock() - defer s.membersMu.RUnlock() - members := make([]*types.Member, 0, len(s.members)) - for _, member := range s.members { - members = append(members, member) - } - return 
members -} - -// HTTP Handlers -func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) { - mode := s.getMode() - memberCount := len(s.getMembers()) - - health := map[string]interface{}{ - "status": "ok", - "mode": mode, - "member_count": memberCount, - "node_id": s.config.NodeID, - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(health) -} - -func (s *Server) getKVHandler(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - path := vars["path"] - - var storedValue types.StoredValue - err := s.db.View(func(txn *badger.Txn) error { - item, err := txn.Get([]byte(path)) - if err != nil { - return err - } - - return item.Value(func(val []byte) error { - return json.Unmarshal(val, &storedValue) - }) - }) - - if err == badger.ErrKeyNotFound { - http.Error(w, "Not Found", http.StatusNotFound) - return - } - if err != nil { - s.logger.WithError(err).WithField("path", path).Error("Failed to get value") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - // CHANGE: Return the entire types.StoredValue, not just Data - json.NewEncoder(w).Encode(storedValue) -} - -func (s *Server) putKVHandler(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - path := vars["path"] - - mode := s.getMode() - if mode == "syncing" { - http.Error(w, "Service Unavailable", http.StatusServiceUnavailable) - return - } - - if mode == "read-only" && !s.isClusterMember(r.RemoteAddr) { - http.Error(w, "Forbidden", http.StatusForbidden) - return - } - - var data json.RawMessage - if err := json.NewDecoder(r.Body).Decode(&data); err != nil { - http.Error(w, "Bad Request", http.StatusBadRequest) - return - } - - now := time.Now().UnixMilli() - newUUID := uuid.New().String() - - storedValue := types.StoredValue{ - UUID: newUUID, - Timestamp: now, - Data: data, - } - - valueBytes, err := json.Marshal(storedValue) - if err != nil { - s.logger.WithError(err).Error("Failed to marshal stored value") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - var isUpdate bool - err = s.db.Update(func(txn *badger.Txn) error { - // Check if key exists - _, err := txn.Get([]byte(path)) - isUpdate = (err == nil) - - // Store main data - if err := txn.Set([]byte(path), valueBytes); err != nil { - return err - } - - // Store timestamp index - indexKey := fmt.Sprintf("_ts:%020d:%s", now, path) - return txn.Set([]byte(indexKey), []byte(newUUID)) - }) - - if err != nil { - s.logger.WithError(err).WithField("path", path).Error("Failed to store value") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - response := types.PutResponse{ - UUID: newUUID, - Timestamp: now, - } - - status := http.StatusCreated - if isUpdate { - status = http.StatusOK - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - json.NewEncoder(w).Encode(response) - - s.logger.WithFields(logrus.Fields{ - "path": path, - "uuid": newUUID, - "timestamp": now, - "is_update": isUpdate, - }).Info("Value stored") -} - -func (s *Server) deleteKVHandler(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - path := vars["path"] - - mode := s.getMode() - if mode == "syncing" { - http.Error(w, "Service Unavailable", http.StatusServiceUnavailable) - return - } - - if mode == "read-only" && !s.isClusterMember(r.RemoteAddr) { - http.Error(w, "Forbidden", http.StatusForbidden) - return - } - - var found bool - err := 
s.db.Update(func(txn *badger.Txn) error { - // Check if key exists and get timestamp for index cleanup - item, err := txn.Get([]byte(path)) - if err == badger.ErrKeyNotFound { - return nil - } - if err != nil { - return err - } - found = true - - var storedValue types.StoredValue - err = item.Value(func(val []byte) error { - return json.Unmarshal(val, &storedValue) - }) - if err != nil { - return err - } - - // Delete main data - if err := txn.Delete([]byte(path)); err != nil { - return err - } - - // Delete timestamp index - indexKey := fmt.Sprintf("_ts:%020d:%s", storedValue.Timestamp, path) - return txn.Delete([]byte(indexKey)) - }) - - if err != nil { - s.logger.WithError(err).WithField("path", path).Error("Failed to delete value") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - if !found { - http.Error(w, "Not Found", http.StatusNotFound) - return - } - - w.WriteHeader(http.StatusNoContent) - - s.logger.WithField("path", path).Info("Value deleted") -} - -func (s *Server) getMembersHandler(w http.ResponseWriter, r *http.Request) { - members := s.getMembers() - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(members) -} - -func (s *Server) joinMemberHandler(w http.ResponseWriter, r *http.Request) { - var req types.JoinRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Bad Request", http.StatusBadRequest) - return - } - - now := time.Now().UnixMilli() - member := &types.Member{ - ID: req.ID, - Address: req.Address, - LastSeen: now, - JoinedTimestamp: req.JoinedTimestamp, - } - - s.addMember(member) - - // Return current member list - members := s.getMembers() - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(members) -} - -func (s *Server) leaveMemberHandler(w http.ResponseWriter, r *http.Request) { - var req types.LeaveRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Bad Request", http.StatusBadRequest) - return - } - - s.removeMember(req.ID) - w.WriteHeader(http.StatusNoContent) -} - -func (s *Server) pairsByTimeHandler(w http.ResponseWriter, r *http.Request) { - var req types.PairsByTimeRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Bad Request", http.StatusBadRequest) - return - } - - // Default limit to 15 as per spec - if req.Limit <= 0 { - req.Limit = 15 - } - - var pairs []types.PairsByTimeResponse - - err := s.db.View(func(txn *badger.Txn) error { - opts := badger.DefaultIteratorOptions - opts.PrefetchSize = req.Limit - it := txn.NewIterator(opts) - defer it.Close() - - prefix := []byte("_ts:") - // The original logic for prefix filtering here was incomplete. - // For Merkle tree sync, this handler is no longer used for core sync. - // It remains as a client-facing API. 
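-		// Index keys written by putKVHandler have the form "_ts:{timestamp}:{path}",
-		// with the timestamp zero-padded to 20 digits, so the lexicographic
-		// iteration below also visits entries in ascending chronological order.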
- - for it.Seek(prefix); it.ValidForPrefix(prefix) && len(pairs) < req.Limit; it.Next() { - item := it.Item() - key := string(item.Key()) - - // Parse timestamp index key: "_ts:{timestamp}:{path}" - parts := strings.SplitN(key, ":", 3) - if len(parts) != 3 { - continue - } - - timestamp, err := strconv.ParseInt(parts[1], 10, 64) - if err != nil { - continue - } - - // Filter by timestamp range - if req.StartTimestamp > 0 && timestamp < req.StartTimestamp { - continue - } - if req.EndTimestamp > 0 && timestamp >= req.EndTimestamp { - continue - } - - path := parts[2] - - // Filter by prefix if specified - if req.Prefix != "" && !strings.HasPrefix(path, req.Prefix) { - continue - } - - var uuid string - err = item.Value(func(val []byte) error { - uuid = string(val) - return nil - }) - if err != nil { - continue - } - - pairs = append(pairs, types.PairsByTimeResponse{ - Path: path, - UUID: uuid, - Timestamp: timestamp, - }) - } - - return nil - }) - - if err != nil { - s.logger.WithError(err).Error("Failed to query pairs by time") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - if len(pairs) == 0 { - w.WriteHeader(http.StatusNoContent) - return - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(pairs) -} - -func (s *Server) gossipHandler(w http.ResponseWriter, r *http.Request) { - var remoteMemberList []types.Member - if err := json.NewDecoder(r.Body).Decode(&remoteMemberList); err != nil { - http.Error(w, "Bad Request", http.StatusBadRequest) - return - } - - // Merge the received member list - s.mergeMemberList(remoteMemberList) - - // Respond with our current member list - localMembers := s.getMembers() - gossipResponse := make([]types.Member, len(localMembers)) - for i, member := range localMembers { - gossipResponse[i] = *member - } - - // Add ourselves to the response - selfMember := types.Member{ - ID: s.config.NodeID, - Address: fmt.Sprintf("%s:%d", s.config.BindAddress, s.config.Port), - LastSeen: time.Now().UnixMilli(), - JoinedTimestamp: s.getJoinedTimestamp(), - } - gossipResponse = append(gossipResponse, selfMember) - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(gossipResponse) - - s.logger.WithField("remote_members", len(remoteMemberList)).Debug("Processed gossip request") -} - -// Utility function to check if request is from cluster member -func (s *Server) isClusterMember(remoteAddr string) bool { - host, _, err := net.SplitHostPort(remoteAddr) - if err != nil { - return false - } - - s.membersMu.RLock() - defer s.membersMu.RUnlock() - - for _, member := range s.members { - memberHost, _, err := net.SplitHostPort(member.Address) - if err == nil && memberHost == host { - return true - } - } - - return false -} - -// Setup HTTP routes -func (s *Server) setupRoutes() *mux.Router { - router := mux.NewRouter() - - // Health endpoint - router.HandleFunc("/health", s.healthHandler).Methods("GET") - - // KV endpoints - router.HandleFunc("/kv/{path:.+}", s.getKVHandler).Methods("GET") - router.HandleFunc("/kv/{path:.+}", s.putKVHandler).Methods("PUT") - router.HandleFunc("/kv/{path:.+}", s.deleteKVHandler).Methods("DELETE") - - // types.Member endpoints - router.HandleFunc("/members/", s.getMembersHandler).Methods("GET") - router.HandleFunc("/members/join", s.joinMemberHandler).Methods("POST") - router.HandleFunc("/members/leave", s.leaveMemberHandler).Methods("DELETE") - router.HandleFunc("/members/gossip", s.gossipHandler).Methods("POST") - 
router.HandleFunc("/members/pairs_by_time", s.pairsByTimeHandler).Methods("POST") // Still available for clients - - // Merkle Tree endpoints - router.HandleFunc("/merkle_tree/root", s.getMerkleRootHandler).Methods("GET") - router.HandleFunc("/merkle_tree/diff", s.getMerkleDiffHandler).Methods("POST") - router.HandleFunc("/kv_range", s.getKVRangeHandler).Methods("POST") // New endpoint for fetching ranges - - // Phase 2: types.User Management endpoints - router.HandleFunc("/api/users", s.createUserHandler).Methods("POST") - router.HandleFunc("/api/users/{uuid}", s.getUserHandler).Methods("GET") - router.HandleFunc("/api/users/{uuid}", s.updateUserHandler).Methods("PUT") - router.HandleFunc("/api/users/{uuid}", s.deleteUserHandler).Methods("DELETE") - - // Phase 2: types.Group Management endpoints - router.HandleFunc("/api/groups", s.createGroupHandler).Methods("POST") - router.HandleFunc("/api/groups/{uuid}", s.getGroupHandler).Methods("GET") - router.HandleFunc("/api/groups/{uuid}", s.updateGroupHandler).Methods("PUT") - router.HandleFunc("/api/groups/{uuid}", s.deleteGroupHandler).Methods("DELETE") - - // Phase 2: Token Management endpoints - router.HandleFunc("/api/tokens", s.createTokenHandler).Methods("POST") - - // Phase 2: Revision History endpoints - router.HandleFunc("/api/data/{key}/history", s.getRevisionHistoryHandler).Methods("GET") - router.HandleFunc("/api/data/{key}/history/{revision}", s.getSpecificRevisionHandler).Methods("GET") - - // Phase 2: Backup Status endpoint - router.HandleFunc("/api/backup/status", s.getBackupStatusHandler).Methods("GET") - - return router -} - -// Start the server -func (s *Server) Start() error { - router := s.setupRoutes() - - addr := fmt.Sprintf("%s:%d", s.config.BindAddress, s.config.Port) - s.httpServer = &http.Server{ - Addr: addr, - Handler: router, - } - - s.logger.WithFields(logrus.Fields{ - "node_id": s.config.NodeID, - "address": addr, - }).Info("Starting KVS server") - - // Start gossip and sync routines - s.startBackgroundTasks() - - // Try to join cluster if seed nodes are configured and clustering is enabled - if s.config.ClusteringEnabled && len(s.config.SeedNodes) > 0 { - go s.bootstrap() - } - - return s.httpServer.ListenAndServe() -} - -// Stop the server gracefully -func (s *Server) Stop() error { - s.logger.Info("Shutting down KVS server") - - s.cancel() - s.wg.Wait() - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - if err := s.httpServer.Shutdown(ctx); err != nil { - s.logger.WithError(err).Error("HTTP server shutdown error") - } - - if err := s.db.Close(); err != nil { - s.logger.WithError(err).Error("BadgerDB close error") - } - - return nil -} - -// Background tasks (gossip, sync, Merkle tree rebuild, etc.) 
-func (s *Server) startBackgroundTasks() { - // Start clustering-related routines only if clustering is enabled - if s.config.ClusteringEnabled { - // Start gossip routine - s.wg.Add(1) - go s.gossipRoutine() - - // Start sync routine (now Merkle-based) - s.wg.Add(1) - go s.syncRoutine() - } - - // Start Merkle tree rebuild routine (always needed for data integrity) - s.wg.Add(1) - go s.merkleTreeRebuildRoutine() -} - -// Gossip routine - runs periodically to exchange member lists -func (s *Server) gossipRoutine() { - defer s.wg.Done() - - for { - // Random interval between 1-2 minutes - minInterval := time.Duration(s.config.GossipIntervalMin) * time.Second - maxInterval := time.Duration(s.config.GossipIntervalMax) * time.Second - interval := minInterval + time.Duration(rand.Int63n(int64(maxInterval-minInterval))) - - select { - case <-s.ctx.Done(): - return - case <-time.After(interval): - s.performGossipRound() - } - } -} - -// Perform a gossip round with random healthy peers -func (s *Server) performGossipRound() { - members := s.getHealthyMembers() - if len(members) == 0 { - s.logger.Debug("No healthy members for gossip round") - return - } - - // Select 1-3 random peers for gossip - maxPeers := 3 - if len(members) < maxPeers { - maxPeers = len(members) - } - - // Shuffle and select - rand.Shuffle(len(members), func(i, j int) { - members[i], members[j] = members[j], members[i] - }) - - selectedPeers := members[:rand.Intn(maxPeers)+1] - - for _, peer := range selectedPeers { - go s.gossipWithPeer(peer) - } -} - -// Gossip with a specific peer -func (s *Server) gossipWithPeer(peer *types.Member) { - s.logger.WithField("peer", peer.Address).Debug("Starting gossip with peer") - - // Get our current member list - localMembers := s.getMembers() - - // Send our member list to the peer - gossipData := make([]types.Member, len(localMembers)) - for i, member := range localMembers { - gossipData[i] = *member - } - - // Add ourselves to the list - selfMember := types.Member{ - ID: s.config.NodeID, - Address: fmt.Sprintf("%s:%d", s.config.BindAddress, s.config.Port), - LastSeen: time.Now().UnixMilli(), - JoinedTimestamp: s.getJoinedTimestamp(), - } - gossipData = append(gossipData, selfMember) - - jsonData, err := json.Marshal(gossipData) - if err != nil { - s.logger.WithError(err).Error("Failed to marshal gossip data") - return - } - - // Send HTTP request to peer - client := &http.Client{Timeout: 5 * time.Second} - url := fmt.Sprintf("http://%s/members/gossip", peer.Address) - - resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonData)) - if err != nil { - s.logger.WithFields(logrus.Fields{ - "peer": peer.Address, - "error": err.Error(), - }).Warn("Failed to gossip with peer") - s.markPeerUnhealthy(peer.ID) - return - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - s.logger.WithFields(logrus.Fields{ - "peer": peer.Address, - "status": resp.StatusCode, - }).Warn("Gossip request failed") - s.markPeerUnhealthy(peer.ID) - return - } - - // Process response - peer's member list - var remoteMemberList []types.Member - if err := json.NewDecoder(resp.Body).Decode(&remoteMemberList); err != nil { - s.logger.WithError(err).Error("Failed to decode gossip response") - return - } - - // Merge remote member list with our local list - s.mergeMemberList(remoteMemberList) - - // Update peer's last seen timestamp - s.updateMemberLastSeen(peer.ID, time.Now().UnixMilli()) - - s.logger.WithField("peer", peer.Address).Debug("Completed gossip with peer") -} - -// Get healthy 
members (exclude those marked as down) -func (s *Server) getHealthyMembers() []*types.Member { - s.membersMu.RLock() - defer s.membersMu.RUnlock() - - now := time.Now().UnixMilli() - healthyMembers := make([]*types.Member, 0) - - for _, member := range s.members { - // Consider member healthy if last seen within last 5 minutes - if now-member.LastSeen < 5*60*1000 { - healthyMembers = append(healthyMembers, member) - } - } - - return healthyMembers -} - -// Mark a peer as unhealthy -func (s *Server) markPeerUnhealthy(nodeID string) { - s.membersMu.Lock() - defer s.membersMu.Unlock() - - if member, exists := s.members[nodeID]; exists { - // Mark as last seen a long time ago to indicate unhealthy - member.LastSeen = time.Now().UnixMilli() - 10*60*1000 // 10 minutes ago - s.logger.WithField("node_id", nodeID).Warn("Marked peer as unhealthy") - } -} - -// Update member's last seen timestamp -func (s *Server) updateMemberLastSeen(nodeID string, timestamp int64) { - s.membersMu.Lock() - defer s.membersMu.Unlock() - - if member, exists := s.members[nodeID]; exists { - member.LastSeen = timestamp - } -} - -// Merge remote member list with local member list -func (s *Server) mergeMemberList(remoteMembers []types.Member) { - s.membersMu.Lock() - defer s.membersMu.Unlock() - - now := time.Now().UnixMilli() - - for _, remoteMember := range remoteMembers { - // Skip ourselves - if remoteMember.ID == s.config.NodeID { - continue - } - - if localMember, exists := s.members[remoteMember.ID]; exists { - // Update existing member - if remoteMember.LastSeen > localMember.LastSeen { - localMember.LastSeen = remoteMember.LastSeen - } - // Keep the earlier joined timestamp - if remoteMember.JoinedTimestamp < localMember.JoinedTimestamp { - localMember.JoinedTimestamp = remoteMember.JoinedTimestamp - } - } else { - // Add new member - newMember := &types.Member{ - ID: remoteMember.ID, - Address: remoteMember.Address, - LastSeen: remoteMember.LastSeen, - JoinedTimestamp: remoteMember.JoinedTimestamp, - } - s.members[remoteMember.ID] = newMember - s.logger.WithFields(logrus.Fields{ - "node_id": remoteMember.ID, - "address": remoteMember.Address, - }).Info("Discovered new member through gossip") - } - } - - // Clean up old members (not seen for more than 10 minutes) - toRemove := make([]string, 0) - for nodeID, member := range s.members { - if now-member.LastSeen > 10*60*1000 { // 10 minutes - toRemove = append(toRemove, nodeID) - } - } - - for _, nodeID := range toRemove { - delete(s.members, nodeID) - s.logger.WithField("node_id", nodeID).Info("Removed stale member") - } -} - -// Get this node's joined timestamp (startup time) -func (s *Server) getJoinedTimestamp() int64 { - // For now, use a simple approach - this should be stored persistently - return time.Now().UnixMilli() -} - -// Sync routine - handles regular and catch-up syncing -func (s *Server) syncRoutine() { - defer s.wg.Done() - - syncTicker := time.NewTicker(time.Duration(s.config.SyncInterval) * time.Second) - defer syncTicker.Stop() - - for { - select { - case <-s.ctx.Done(): - return - case <-syncTicker.C: - s.performMerkleSync() // Use Merkle sync instead of regular sync - } - } -} - -// Merkle Tree Implementation - -// calculateHash generates a SHA256 hash for a given byte slice -func calculateHash(data []byte) []byte { - h := sha256.New() - h.Write(data) - return h.Sum(nil) -} - -// calculateLeafHash generates a hash for a leaf node based on its path, UUID, timestamp, and data -func (s *Server) calculateLeafHash(path string, storedValue 
*types.StoredValue) []byte { - // Concatenate path, UUID, timestamp, and the raw data bytes for hashing - // Ensure a consistent order of fields for hashing - dataToHash := bytes.Buffer{} - dataToHash.WriteString(path) - dataToHash.WriteByte(':') - dataToHash.WriteString(storedValue.UUID) - dataToHash.WriteByte(':') - dataToHash.WriteString(strconv.FormatInt(storedValue.Timestamp, 10)) - dataToHash.WriteByte(':') - dataToHash.Write(storedValue.Data) // Use raw bytes of json.RawMessage - - return calculateHash(dataToHash.Bytes()) -} - -// getAllKVPairsForMerkleTree retrieves all key-value pairs needed for Merkle tree construction. -func (s *Server) getAllKVPairsForMerkleTree() (map[string]*types.StoredValue, error) { - pairs := make(map[string]*types.StoredValue) - err := s.db.View(func(txn *badger.Txn) error { - opts := badger.DefaultIteratorOptions - opts.PrefetchValues = true // We need the values for hashing - it := txn.NewIterator(opts) - defer it.Close() - - // Iterate over all actual data keys (not _ts: indexes) - for it.Rewind(); it.Valid(); it.Next() { - item := it.Item() - key := string(item.Key()) - - if strings.HasPrefix(key, "_ts:") { - continue // Skip index keys - } - - var storedValue types.StoredValue - err := item.Value(func(val []byte) error { - return json.Unmarshal(val, &storedValue) - }) - if err != nil { - s.logger.WithError(err).WithField("key", key).Warn("Failed to unmarshal stored value for Merkle tree, skipping") - continue - } - pairs[key] = &storedValue - } - return nil - }) - if err != nil { - return nil, err - } - return pairs, nil -} - -// buildMerkleTreeFromPairs constructs a Merkle Tree from the KVS data. -// This version uses a recursive approach to build a balanced tree from sorted keys. -func (s *Server) buildMerkleTreeFromPairs(pairs map[string]*types.StoredValue) (*types.MerkleNode, error) { - if len(pairs) == 0 { - return &types.MerkleNode{Hash: calculateHash([]byte("empty_tree")), StartKey: "", EndKey: ""}, nil - } - - // Sort keys to ensure consistent tree structure - keys := make([]string, 0, len(pairs)) - for k := range pairs { - keys = append(keys, k) - } - sort.Strings(keys) - - // Create leaf nodes - leafNodes := make([]*types.MerkleNode, len(keys)) - for i, key := range keys { - storedValue := pairs[key] - hash := s.calculateLeafHash(key, storedValue) - leafNodes[i] = &types.MerkleNode{Hash: hash, StartKey: key, EndKey: key} - } - - // Recursively build parent nodes - return s.buildMerkleTreeRecursive(leafNodes) -} - -// buildMerkleTreeRecursive builds the tree from a slice of nodes. -func (s *Server) buildMerkleTreeRecursive(nodes []*types.MerkleNode) (*types.MerkleNode, error) { - if len(nodes) == 0 { - return nil, nil - } - if len(nodes) == 1 { - return nodes[0], nil - } - - var nextLevel []*types.MerkleNode - for i := 0; i < len(nodes); i += 2 { - left := nodes[i] - var right *types.MerkleNode - if i+1 < len(nodes) { - right = nodes[i+1] - } - - var combinedHash []byte - var endKey string - - if right != nil { - combinedHash = calculateHash(append(left.Hash, right.Hash...)) - endKey = right.EndKey - } else { - // Odd number of nodes, promote the left node - combinedHash = left.Hash - endKey = left.EndKey - } - - parentNode := &types.MerkleNode{ - Hash: combinedHash, - StartKey: left.StartKey, - EndKey: endKey, - } - nextLevel = append(nextLevel, parentNode) - } - return s.buildMerkleTreeRecursive(nextLevel) -} - -// getMerkleRoot returns the current Merkle root of the server. 
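-// Reads take merkleRootMu.RLock, so lookups never block one another and the
-// rebuild routine can atomically swap in a fresh root via setMerkleRoot.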
-func (s *Server) getMerkleRoot() *types.MerkleNode { - s.merkleRootMu.RLock() - defer s.merkleRootMu.RUnlock() - return s.merkleRoot -} - -// setMerkleRoot sets the current Merkle root of the server. -func (s *Server) setMerkleRoot(root *types.MerkleNode) { - s.merkleRootMu.Lock() - defer s.merkleRootMu.Unlock() - s.merkleRoot = root -} - -// merkleTreeRebuildRoutine periodically rebuilds the Merkle tree. -func (s *Server) merkleTreeRebuildRoutine() { - defer s.wg.Done() - ticker := time.NewTicker(time.Duration(s.config.SyncInterval) * time.Second) // Use sync interval for now - defer ticker.Stop() - - for { - select { - case <-s.ctx.Done(): - return - case <-ticker.C: - s.logger.Debug("Rebuilding Merkle tree...") - pairs, err := s.getAllKVPairsForMerkleTree() - if err != nil { - s.logger.WithError(err).Error("Failed to get KV pairs for Merkle tree rebuild") - continue - } - newRoot, err := s.buildMerkleTreeFromPairs(pairs) - if err != nil { - s.logger.WithError(err).Error("Failed to rebuild Merkle tree") - continue - } - s.setMerkleRoot(newRoot) - s.logger.Debug("Merkle tree rebuilt.") - } - } -} - -// getMerkleRootHandler returns the root hash of the local Merkle Tree -func (s *Server) getMerkleRootHandler(w http.ResponseWriter, r *http.Request) { - root := s.getMerkleRoot() - if root == nil { - http.Error(w, "Merkle tree not initialized", http.StatusInternalServerError) - return - } - - resp := types.MerkleRootResponse{ - Root: root, - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) -} - -// getMerkleDiffHandler is used by a peer to request children hashes for a given node/range. -func (s *Server) getMerkleDiffHandler(w http.ResponseWriter, r *http.Request) { - var req types.MerkleTreeDiffRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Bad Request", http.StatusBadRequest) - return - } - - localPairs, err := s.getAllKVPairsForMerkleTree() - if err != nil { - s.logger.WithError(err).Error("Failed to get KV pairs for Merkle diff") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - // Build the local types.MerkleNode for the requested range to compare with the remote's hash - localSubTreeRoot, err := s.buildMerkleTreeFromPairs(s.filterPairsByRange(localPairs, req.ParentNode.StartKey, req.ParentNode.EndKey)) - if err != nil { - s.logger.WithError(err).Error("Failed to build sub-Merkle tree for diff request") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - if localSubTreeRoot == nil { // This can happen if the range is empty locally - localSubTreeRoot = &types.MerkleNode{Hash: calculateHash([]byte("empty_tree")), StartKey: req.ParentNode.StartKey, EndKey: req.ParentNode.EndKey} - } - - resp := types.MerkleTreeDiffResponse{} - - // If hashes match, no need to send children or keys - if bytes.Equal(req.LocalHash, localSubTreeRoot.Hash) { - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) - return - } - - // Hashes differ, so we need to provide more detail. 
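// --- Editor's note (illustrative sketch, not part of this diff): one round of
// this exchange, seen from the requesting peer (cf. diffMerkleTreesRecursive
// further down), is roughly:
//
//	req := types.MerkleTreeDiffRequest{ParentNode: *remoteNode, LocalHash: localNode.Hash}
//	resp, err := s.requestMerkleDiff(peerAddress, req)
//	if err == nil && len(resp.Keys) > 0 {
//		// leaf-level divergence: fetch and reconcile each listed key
//	} else if err == nil {
//		for _, child := range resp.Children {
//			// interior divergence: recurse into each differing sub-range
//		}
//	}
//
// --- End editor's note. ---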
- // Get all keys within the parent node's range locally - var keysInRange []string - for key := range s.filterPairsByRange(localPairs, req.ParentNode.StartKey, req.ParentNode.EndKey) { - keysInRange = append(keysInRange, key) - } - sort.Strings(keysInRange) - - const diffLeafThreshold = 10 // If a range has <= 10 keys, we consider it a leaf-level diff - - if len(keysInRange) <= diffLeafThreshold { - // This is a leaf-level diff, return the actual keys in the range - resp.Keys = keysInRange - } else { - // types.Group keys into sub-ranges and return their types.MerkleNode representations - // For simplicity, let's split the range into two halves. - mid := len(keysInRange) / 2 - - leftKeys := keysInRange[:mid] - rightKeys := keysInRange[mid:] - - if len(leftKeys) > 0 { - leftRangePairs := s.filterPairsByRange(localPairs, leftKeys[0], leftKeys[len(leftKeys)-1]) - leftNode, err := s.buildMerkleTreeFromPairs(leftRangePairs) - if err != nil { - s.logger.WithError(err).Error("Failed to build left child node for diff") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - if leftNode != nil { - resp.Children = append(resp.Children, *leftNode) - } - } - - if len(rightKeys) > 0 { - rightRangePairs := s.filterPairsByRange(localPairs, rightKeys[0], rightKeys[len(rightKeys)-1]) - rightNode, err := s.buildMerkleTreeFromPairs(rightRangePairs) - if err != nil { - s.logger.WithError(err).Error("Failed to build right child node for diff") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - if rightNode != nil { - resp.Children = append(resp.Children, *rightNode) - } - } - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) -} - -// Helper to filter a map of types.StoredValue by key range -func (s *Server) filterPairsByRange(allPairs map[string]*types.StoredValue, startKey, endKey string) map[string]*types.StoredValue { - filtered := make(map[string]*types.StoredValue) - for key, value := range allPairs { - if (startKey == "" || key >= startKey) && (endKey == "" || key <= endKey) { - filtered[key] = value - } - } - return filtered -} - -// getKVRangeHandler fetches a range of KV pairs -func (s *Server) getKVRangeHandler(w http.ResponseWriter, r *http.Request) { - var req types.KVRangeRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Bad Request", http.StatusBadRequest) - return - } - - var pairs []struct { - Path string `json:"path"` - StoredValue types.StoredValue `json:"stored_value"` - } - - err := s.db.View(func(txn *badger.Txn) error { - opts := badger.DefaultIteratorOptions - opts.PrefetchValues = true - it := txn.NewIterator(opts) - defer it.Close() - - count := 0 - // Start iteration from the requested StartKey - for it.Seek([]byte(req.StartKey)); it.Valid(); it.Next() { - item := it.Item() - key := string(item.Key()) - - if strings.HasPrefix(key, "_ts:") { - continue // Skip index keys - } - - // Stop if we exceed the EndKey (if provided) - if req.EndKey != "" && key > req.EndKey { - break - } - - // Stop if we hit the limit (if provided) - if req.Limit > 0 && count >= req.Limit { - break - } - - var storedValue types.StoredValue - err := item.Value(func(val []byte) error { - return json.Unmarshal(val, &storedValue) - }) - if err != nil { - s.logger.WithError(err).WithField("key", key).Warn("Failed to unmarshal stored value in KV range, skipping") - continue - } - - pairs = append(pairs, struct { - Path string `json:"path"` - StoredValue 
types.StoredValue `json:"stored_value"` - }{Path: key, StoredValue: storedValue}) - count++ - } - return nil - }) - - if err != nil { - s.logger.WithError(err).Error("Failed to query KV range") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(types.KVRangeResponse{Pairs: pairs}) -} - -// Phase 2: types.User Management API Handlers - -// createUserHandler handles POST /api/users -func (s *Server) createUserHandler(w http.ResponseWriter, r *http.Request) { - var req types.CreateUserRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Bad Request", http.StatusBadRequest) - return - } - - if req.Nickname == "" { - http.Error(w, "Nickname is required", http.StatusBadRequest) - return - } - - // Generate UUID for the user - userUUID := uuid.New().String() - now := time.Now().Unix() - - user := types.User{ - UUID: userUUID, - NicknameHash: utils.HashUserNickname(req.Nickname), - Groups: []string{}, - CreatedAt: now, - UpdatedAt: now, - } - - // Store user in BadgerDB - userData, err := json.Marshal(user) - if err != nil { - s.logger.WithError(err).Error("Failed to marshal user data") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - err = s.db.Update(func(txn *badger.Txn) error { - return txn.Set([]byte(auth.UserStorageKey(userUUID)), userData) - }) - - if err != nil { - s.logger.WithError(err).Error("Failed to store user") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - s.logger.WithField("user_uuid", userUUID).Info("types.User created successfully") - - response := types.CreateUserResponse{UUID: userUUID} - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) -} - -// getUserHandler handles GET /api/users/{uuid} -func (s *Server) getUserHandler(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - userUUID := vars["uuid"] - - if userUUID == "" { - http.Error(w, "types.User UUID is required", http.StatusBadRequest) - return - } - - var user types.User - err := s.db.View(func(txn *badger.Txn) error { - item, err := txn.Get([]byte(auth.UserStorageKey(userUUID))) - if err != nil { - return err - } - - return item.Value(func(val []byte) error { - return json.Unmarshal(val, &user) - }) - }) - - if err == badger.ErrKeyNotFound { - http.Error(w, "types.User not found", http.StatusNotFound) - return - } - - if err != nil { - s.logger.WithError(err).Error("Failed to get user") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - response := types.GetUserResponse{ - UUID: user.UUID, - NicknameHash: user.NicknameHash, - Groups: user.Groups, - CreatedAt: user.CreatedAt, - UpdatedAt: user.UpdatedAt, - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) -} - -// updateUserHandler handles PUT /api/users/{uuid} -func (s *Server) updateUserHandler(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - userUUID := vars["uuid"] - - if userUUID == "" { - http.Error(w, "types.User UUID is required", http.StatusBadRequest) - return - } - - var req types.UpdateUserRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Bad Request", http.StatusBadRequest) - return - } - - err := s.db.Update(func(txn *badger.Txn) error { - // Get existing user - item, err := txn.Get([]byte(auth.UserStorageKey(userUUID))) - if err != 
nil { - return err - } - - var user types.User - err = item.Value(func(val []byte) error { - return json.Unmarshal(val, &user) - }) - if err != nil { - return err - } - - // Update fields if provided - now := time.Now().Unix() - user.UpdatedAt = now - - if req.Nickname != "" { - user.NicknameHash = utils.HashUserNickname(req.Nickname) - } - - if req.Groups != nil { - user.Groups = req.Groups - } - - // Store updated user - userData, err := json.Marshal(user) - if err != nil { - return err - } - - return txn.Set([]byte(auth.UserStorageKey(userUUID)), userData) - }) - - if err == badger.ErrKeyNotFound { - http.Error(w, "types.User not found", http.StatusNotFound) - return - } - - if err != nil { - s.logger.WithError(err).Error("Failed to update user") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - s.logger.WithField("user_uuid", userUUID).Info("types.User updated successfully") - w.WriteHeader(http.StatusOK) -} - -// deleteUserHandler handles DELETE /api/users/{uuid} -func (s *Server) deleteUserHandler(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - userUUID := vars["uuid"] - - if userUUID == "" { - http.Error(w, "types.User UUID is required", http.StatusBadRequest) - return - } - - err := s.db.Update(func(txn *badger.Txn) error { - // Check if user exists first - _, err := txn.Get([]byte(auth.UserStorageKey(userUUID))) - if err != nil { - return err - } - - // Delete the user - return txn.Delete([]byte(auth.UserStorageKey(userUUID))) - }) - - if err == badger.ErrKeyNotFound { - http.Error(w, "types.User not found", http.StatusNotFound) - return - } - - if err != nil { - s.logger.WithError(err).Error("Failed to delete user") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - s.logger.WithField("user_uuid", userUUID).Info("types.User deleted successfully") - w.WriteHeader(http.StatusOK) -} - -// Phase 2: types.Group Management API Handlers - -// createGroupHandler handles POST /api/groups -func (s *Server) createGroupHandler(w http.ResponseWriter, r *http.Request) { - var req types.CreateGroupRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Bad Request", http.StatusBadRequest) - return - } - - if req.Groupname == "" { - http.Error(w, "Groupname is required", http.StatusBadRequest) - return - } - - // Generate UUID for the group - groupUUID := uuid.New().String() - now := time.Now().Unix() - - group := types.Group{ - UUID: groupUUID, - NameHash: utils.HashGroupName(req.Groupname), - Members: req.Members, - CreatedAt: now, - UpdatedAt: now, - } - - if group.Members == nil { - group.Members = []string{} - } - - // Store group in BadgerDB - groupData, err := json.Marshal(group) - if err != nil { - s.logger.WithError(err).Error("Failed to marshal group data") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - err = s.db.Update(func(txn *badger.Txn) error { - return txn.Set([]byte(auth.GroupStorageKey(groupUUID)), groupData) - }) - - if err != nil { - s.logger.WithError(err).Error("Failed to store group") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - s.logger.WithField("group_uuid", groupUUID).Info("types.Group created successfully") - - response := types.CreateGroupResponse{UUID: groupUUID} - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) -} - -// getGroupHandler handles GET /api/groups/{uuid} -func (s *Server) getGroupHandler(w 
http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - groupUUID := vars["uuid"] - - if groupUUID == "" { - http.Error(w, "types.Group UUID is required", http.StatusBadRequest) - return - } - - var group types.Group - err := s.db.View(func(txn *badger.Txn) error { - item, err := txn.Get([]byte(auth.GroupStorageKey(groupUUID))) - if err != nil { - return err - } - - return item.Value(func(val []byte) error { - return json.Unmarshal(val, &group) - }) - }) - - if err == badger.ErrKeyNotFound { - http.Error(w, "types.Group not found", http.StatusNotFound) - return - } - - if err != nil { - s.logger.WithError(err).Error("Failed to get group") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - response := types.GetGroupResponse{ - UUID: group.UUID, - NameHash: group.NameHash, - Members: group.Members, - CreatedAt: group.CreatedAt, - UpdatedAt: group.UpdatedAt, - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) -} - -// updateGroupHandler handles PUT /api/groups/{uuid} -func (s *Server) updateGroupHandler(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - groupUUID := vars["uuid"] - - if groupUUID == "" { - http.Error(w, "types.Group UUID is required", http.StatusBadRequest) - return - } - - var req types.UpdateGroupRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Bad Request", http.StatusBadRequest) - return - } - - err := s.db.Update(func(txn *badger.Txn) error { - // Get existing group - item, err := txn.Get([]byte(auth.GroupStorageKey(groupUUID))) - if err != nil { - return err - } - - var group types.Group - err = item.Value(func(val []byte) error { - return json.Unmarshal(val, &group) - }) - if err != nil { - return err - } - - // Update fields - now := time.Now().Unix() - group.UpdatedAt = now - group.Members = req.Members - - if group.Members == nil { - group.Members = []string{} - } - - // Store updated group - groupData, err := json.Marshal(group) - if err != nil { - return err - } - - return txn.Set([]byte(auth.GroupStorageKey(groupUUID)), groupData) - }) - - if err == badger.ErrKeyNotFound { - http.Error(w, "types.Group not found", http.StatusNotFound) - return - } - - if err != nil { - s.logger.WithError(err).Error("Failed to update group") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - s.logger.WithField("group_uuid", groupUUID).Info("types.Group updated successfully") - w.WriteHeader(http.StatusOK) -} - -// deleteGroupHandler handles DELETE /api/groups/{uuid} -func (s *Server) deleteGroupHandler(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - groupUUID := vars["uuid"] - - if groupUUID == "" { - http.Error(w, "types.Group UUID is required", http.StatusBadRequest) - return - } - - err := s.db.Update(func(txn *badger.Txn) error { - // Check if group exists first - _, err := txn.Get([]byte(auth.GroupStorageKey(groupUUID))) - if err != nil { - return err - } - - // Delete the group - return txn.Delete([]byte(auth.GroupStorageKey(groupUUID))) - }) - - if err == badger.ErrKeyNotFound { - http.Error(w, "types.Group not found", http.StatusNotFound) - return - } - - if err != nil { - s.logger.WithError(err).Error("Failed to delete group") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - s.logger.WithField("group_uuid", groupUUID).Info("types.Group deleted successfully") - w.WriteHeader(http.StatusOK) -} - -// Phase 2: Token Management API 
Handlers - -// createTokenHandler handles POST /api/tokens -func (s *Server) createTokenHandler(w http.ResponseWriter, r *http.Request) { - var req types.CreateTokenRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "Bad Request", http.StatusBadRequest) - return - } - - if req.UserUUID == "" { - http.Error(w, "types.User UUID is required", http.StatusBadRequest) - return - } - - if len(req.Scopes) == 0 { - http.Error(w, "At least one scope is required", http.StatusBadRequest) - return - } - - // Verify user exists - err := s.db.View(func(txn *badger.Txn) error { - _, err := txn.Get([]byte(auth.UserStorageKey(req.UserUUID))) - return err - }) - - if err == badger.ErrKeyNotFound { - http.Error(w, "types.User not found", http.StatusNotFound) - return - } - - if err != nil { - s.logger.WithError(err).Error("Failed to verify user") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - // Generate JWT token - tokenString, expiresAt, err := generateJWT(req.UserUUID, req.Scopes, 1) // 1 hour default - if err != nil { - s.logger.WithError(err).Error("Failed to generate JWT token") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - // Store token in BadgerDB - err = s.storeAPIToken(tokenString, req.UserUUID, req.Scopes, expiresAt) - if err != nil { - s.logger.WithError(err).Error("Failed to store API token") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - s.logger.WithFields(logrus.Fields{ - "user_uuid": req.UserUUID, - "scopes": req.Scopes, - "expires_at": expiresAt, - }).Info("API token created successfully") - - response := types.CreateTokenResponse{ - Token: tokenString, - ExpiresAt: expiresAt, - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) -} - -// Phase 2: Revision History API Handlers - -// getRevisionHistoryHandler handles GET /api/data/{key}/history -func (s *Server) getRevisionHistoryHandler(w http.ResponseWriter, r *http.Request) { - // Check if revision history is enabled - if !s.config.RevisionHistoryEnabled { - http.Error(w, "Revision history is disabled", http.StatusServiceUnavailable) - return - } - - vars := mux.Vars(r) - key := vars["key"] - - if key == "" { - http.Error(w, "Key is required", http.StatusBadRequest) - return - } - - revisions, err := s.getRevisionHistory(key) - if err != nil { - s.logger.WithError(err).WithField("key", key).Error("Failed to get revision history") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - if len(revisions) == 0 { - http.Error(w, "No revisions found", http.StatusNotFound) - return - } - - response := map[string]interface{}{ - "revisions": revisions, - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(response) -} - -// getSpecificRevisionHandler handles GET /api/data/{key}/history/{revision} -func (s *Server) getSpecificRevisionHandler(w http.ResponseWriter, r *http.Request) { - // Check if revision history is enabled - if !s.config.RevisionHistoryEnabled { - http.Error(w, "Revision history is disabled", http.StatusServiceUnavailable) - return - } - - vars := mux.Vars(r) - key := vars["key"] - revisionStr := vars["revision"] - - if key == "" { - http.Error(w, "Key is required", http.StatusBadRequest) - return - } - - if revisionStr == "" { - http.Error(w, "Revision is required", http.StatusBadRequest) - return - } - - revision, err := strconv.Atoi(revisionStr) - if err 
!= nil { - http.Error(w, "Invalid revision number", http.StatusBadRequest) - return - } - - storedValue, err := s.getSpecificRevision(key, revision) - if err == badger.ErrKeyNotFound { - http.Error(w, "Revision not found", http.StatusNotFound) - return - } - - if err != nil { - s.logger.WithError(err).WithFields(logrus.Fields{ - "key": key, - "revision": revision, - }).Error("Failed to get specific revision") - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(storedValue) -} - -// Phase 2: Backup Status API Handler - -// getBackupStatusHandler handles GET /api/backup/status -func (s *Server) getBackupStatusHandler(w http.ResponseWriter, r *http.Request) { - status := s.getBackupStatus() - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(status) -} // performMerkleSync performs a synchronization round using Merkle Trees -func (s *Server) performMerkleSync() { - members := s.getHealthyMembers() - if len(members) == 0 { - s.logger.Debug("No healthy members for Merkle sync") - return - } - - // Select random peer - peer := members[rand.Intn(len(members))] - - s.logger.WithField("peer", peer.Address).Info("Starting Merkle tree sync") - - localRoot := s.getMerkleRoot() - if localRoot == nil { - s.logger.Error("Local Merkle root is nil, cannot perform sync") - return - } - - // 1. Get remote peer's Merkle root - remoteRootResp, err := s.requestMerkleRoot(peer.Address) - if err != nil { - s.logger.WithError(err).WithField("peer", peer.Address).Error("Failed to get remote Merkle root") - s.markPeerUnhealthy(peer.ID) - return - } - remoteRoot := remoteRootResp.Root - - // 2. Compare roots and start recursive diffing if they differ - if !bytes.Equal(localRoot.Hash, remoteRoot.Hash) { - s.logger.WithFields(logrus.Fields{ - "peer": peer.Address, - "local_root": hex.EncodeToString(localRoot.Hash), - "remote_root": hex.EncodeToString(remoteRoot.Hash), - }).Info("Merkle roots differ, starting recursive diff") - s.diffMerkleTreesRecursive(peer.Address, localRoot, remoteRoot) - } else { - s.logger.WithField("peer", peer.Address).Info("Merkle roots match, no sync needed") - } - - s.logger.WithField("peer", peer.Address).Info("Completed Merkle tree sync") -} // requestMerkleRoot requests the Merkle root from a peer -func (s *Server) requestMerkleRoot(peerAddress string) (*types.MerkleRootResponse, error) { - client := &http.Client{Timeout: 10 * time.Second} - url := fmt.Sprintf("http://%s/merkle_tree/root", peerAddress) - - resp, err := client.Get(url) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("peer returned status %d for Merkle root", resp.StatusCode) - } - - var merkleRootResp types.MerkleRootResponse - if err := json.NewDecoder(resp.Body).Decode(&merkleRootResp); err != nil { - return nil, err - } - return &merkleRootResp, nil -} // diffMerkleTreesRecursive recursively compares local and remote Merkle tree nodes -func (s *Server) diffMerkleTreesRecursive(peerAddress string, localNode, remoteNode *types.MerkleNode) { - // If hashes match, this subtree is in sync. - if bytes.Equal(localNode.Hash, remoteNode.Hash) { - return - } - - // Hashes differ, need to go deeper. - // Request children from the remote peer for the current range. 
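// --- Editor's note (illustrative, not part of this diff): the walk below is
// only sound because both peers hash leaves from an identical preimage, as in
// calculateLeafHash above:
//
//	preimage := path + ":" + storedValue.UUID + ":" +
//		strconv.FormatInt(storedValue.Timestamp, 10) + ":" + string(storedValue.Data)
//	leafHash := sha256.Sum256([]byte(preimage))
//
// Any divergence in field order or separators on one side would make every
// range appear out of sync. --- End editor's note. ---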
- req := types.MerkleTreeDiffRequest{ - ParentNode: *remoteNode, // We are asking the remote peer about its children for this range - LocalHash: localNode.Hash, // Our hash for this range - } - - remoteDiffResp, err := s.requestMerkleDiff(peerAddress, req) - if err != nil { - s.logger.WithError(err).WithFields(logrus.Fields{ - "peer": peerAddress, - "start_key": localNode.StartKey, - "end_key": localNode.EndKey, - }).Error("Failed to get Merkle diff from peer") - return - } - - if len(remoteDiffResp.Keys) > 0 { - // This is a leaf-level diff, we have the actual keys that are different. - // Fetch and compare each key. - s.logger.WithFields(logrus.Fields{ - "peer": peerAddress, - "start_key": localNode.StartKey, - "end_key": localNode.EndKey, - "num_keys": len(remoteDiffResp.Keys), - }).Info("Found divergent keys, fetching and comparing data") - - for _, key := range remoteDiffResp.Keys { - // Fetch the individual key from the peer - remoteStoredValue, err := s.fetchSingleKVFromPeer(peerAddress, key) - if err != nil { - s.logger.WithError(err).WithFields(logrus.Fields{ - "peer": peerAddress, - "key": key, - }).Error("Failed to fetch single KV from peer during diff") - continue - } - - localStoredValue, localExists := s.getLocalData(key) - - if remoteStoredValue == nil { - // Key was deleted on remote, delete locally if exists - if localExists { - s.logger.WithField("key", key).Info("Key deleted on remote, deleting locally") - s.deleteKVLocally(key, localStoredValue.Timestamp) // Need a local delete function - } - continue - } - - if !localExists { - // Local data is missing, store the remote data - if err := s.storeReplicatedDataWithMetadata(key, remoteStoredValue); err != nil { - s.logger.WithError(err).WithField("key", key).Error("Failed to store missing replicated data") - } else { - s.logger.WithField("key", key).Info("Fetched and stored missing data from peer") - } - } else if localStoredValue.Timestamp < remoteStoredValue.Timestamp { - // Remote is newer, store the remote data - if err := s.storeReplicatedDataWithMetadata(key, remoteStoredValue); err != nil { - s.logger.WithError(err).WithField("key", key).Error("Failed to store newer replicated data") - } else { - s.logger.WithField("key", key).Info("Fetched and stored newer data from peer") - } - } else if localStoredValue.Timestamp == remoteStoredValue.Timestamp && localStoredValue.UUID != remoteStoredValue.UUID { - // Timestamp collision, engage conflict resolution - remotePair := types.PairsByTimeResponse{ // Re-use this struct for conflict resolution - Path: key, - UUID: remoteStoredValue.UUID, - Timestamp: remoteStoredValue.Timestamp, - } - resolved, err := s.resolveConflict(key, localStoredValue, &remotePair, peerAddress) - if err != nil { - s.logger.WithError(err).WithField("path", key).Error("Failed to resolve conflict during Merkle sync") - } else if resolved { - s.logger.WithField("path", key).Info("Conflict resolved, updated local data during Merkle sync") - } else { - s.logger.WithField("path", key).Info("Conflict resolved, kept local data during Merkle sync") - } - } - // If local is newer or same timestamp and same UUID, do nothing. - } - } else if len(remoteDiffResp.Children) > 0 { - // Not a leaf level, continue recursive diff for children. 
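// --- Editor's note (illustrative, not part of this diff): recapping the
// key-level branch that ends above, the reconciliation is last-write-wins
// with a UUID tie-breaker, roughly:
//
//	switch {
//	case remote == nil:                      // deleted on peer -> delete locally
//	case !localExists:                       // missing locally -> store remote copy
//	case local.Timestamp < remote.Timestamp: // peer is newer   -> store remote copy
//	case local.Timestamp == remote.Timestamp && local.UUID != remote.UUID:
//		// timestamp collision -> resolveConflict (majority vote, oldest-node tie-breaker)
//	default: // local is newer, or the copies are identical -> keep local
//	}
//
// --- End editor's note. ---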
- localPairs, err := s.getAllKVPairsForMerkleTree() - if err != nil { - s.logger.WithError(err).Error("Failed to get KV pairs for local children comparison") - return - } - - for _, remoteChild := range remoteDiffResp.Children { - // Build the local Merkle node for this child's range - localChildNode, err := s.buildMerkleTreeFromPairs(s.filterPairsByRange(localPairs, remoteChild.StartKey, remoteChild.EndKey)) - if err != nil { - s.logger.WithError(err).WithFields(logrus.Fields{ - "start_key": remoteChild.StartKey, - "end_key": remoteChild.EndKey, - }).Error("Failed to build local child node for diff") - continue - } - - if localChildNode == nil || !bytes.Equal(localChildNode.Hash, remoteChild.Hash) { - // If local child node is nil (meaning local has no data in this range) - // or hashes differ, then we need to fetch the data. - if localChildNode == nil { - s.logger.WithFields(logrus.Fields{ - "peer": peerAddress, - "start_key": remoteChild.StartKey, - "end_key": remoteChild.EndKey, - }).Info("Local node missing data in remote child's range, fetching full range") - s.fetchAndStoreRange(peerAddress, remoteChild.StartKey, remoteChild.EndKey) - } else { - s.diffMerkleTreesRecursive(peerAddress, localChildNode, &remoteChild) - } - } - } - } -} // requestMerkleDiff requests children hashes or keys for a given node/range from a peer -func (s *Server) requestMerkleDiff(peerAddress string, req types.MerkleTreeDiffRequest) (*types.MerkleTreeDiffResponse, error) { - jsonData, err := json.Marshal(req) - if err != nil { - return nil, err - } - - client := &http.Client{Timeout: 10 * time.Second} - url := fmt.Sprintf("http://%s/merkle_tree/diff", peerAddress) - - resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonData)) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("peer returned status %d for Merkle diff", resp.StatusCode) - } - - var diffResp types.MerkleTreeDiffResponse - if err := json.NewDecoder(resp.Body).Decode(&diffResp); err != nil { - return nil, err - } - return &diffResp, nil -} // fetchSingleKVFromPeer fetches a single KV pair from a peer -func (s *Server) fetchSingleKVFromPeer(peerAddress, path string) (*types.StoredValue, error) { - client := &http.Client{Timeout: 5 * time.Second} - url := fmt.Sprintf("http://%s/kv/%s", peerAddress, path) - - resp, err := client.Get(url) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode == http.StatusNotFound { - return nil, nil // Key might have been deleted on the peer - } - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("peer returned status %d for path %s", resp.StatusCode, path) - } - - var storedValue types.StoredValue - if err := json.NewDecoder(resp.Body).Decode(&storedValue); err != nil { - return nil, fmt.Errorf("failed to decode types.StoredValue from peer: %v", err) - } - return &storedValue, nil -} // storeReplicatedDataWithMetadata stores replicated data preserving its original metadata -func (s *Server) storeReplicatedDataWithMetadata(path string, storedValue *types.StoredValue) error { - valueBytes, err := json.Marshal(storedValue) - if err != nil { - return err - } - - return s.db.Update(func(txn *badger.Txn) error { - // Store main data - if err := txn.Set([]byte(path), valueBytes); err != nil { - return err - } - - // Store timestamp index - indexKey := fmt.Sprintf("_ts:%020d:%s", storedValue.Timestamp, path) - return txn.Set([]byte(indexKey), []byte(storedValue.UUID)) - }) -} 
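// --- Editor's note (illustrative, not part of this diff): the index key
// written above uses "_ts:%020d:%s", and the zero-padding makes BadgerDB's
// lexicographic key order match numeric timestamp order:
//
//	fmt.Sprintf("_ts:%020d:%s", 1700000000123, "app/config")
//	// -> "_ts:00000001700000000123:app/config"
//
// ("app/config" is a made-up path.) Time-ordered scans can then use a plain
// prefix iteration over "_ts:". --- End editor's note. ---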
// deleteKVLocally deletes a key-value pair and its associated timestamp index locally. -func (s *Server) deleteKVLocally(path string, timestamp int64) error { - return s.db.Update(func(txn *badger.Txn) error { - if err := txn.Delete([]byte(path)); err != nil { - return err - } - indexKey := fmt.Sprintf("_ts:%020d:%s", timestamp, path) - return txn.Delete([]byte(indexKey)) - }) -} // fetchAndStoreRange fetches a range of KV pairs from a peer and stores them locally -func (s *Server) fetchAndStoreRange(peerAddress string, startKey, endKey string) error { - req := types.KVRangeRequest{ - StartKey: startKey, - EndKey: endKey, - Limit: 0, // No limit - } - jsonData, err := json.Marshal(req) - if err != nil { - return err - } - - client := &http.Client{Timeout: 30 * time.Second} // Longer timeout for range fetches - url := fmt.Sprintf("http://%s/kv_range", peerAddress) - - resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonData)) - if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("peer returned status %d for KV range fetch", resp.StatusCode) - } - - var rangeResp types.KVRangeResponse - if err := json.NewDecoder(resp.Body).Decode(&rangeResp); err != nil { - return err - } - - for _, pair := range rangeResp.Pairs { - // Use storeReplicatedDataWithMetadata to preserve original UUID/Timestamp - if err := s.storeReplicatedDataWithMetadata(pair.Path, &pair.StoredValue); err != nil { - s.logger.WithError(err).WithFields(logrus.Fields{ - "peer": peerAddress, - "path": pair.Path, - }).Error("Failed to store fetched range data") - } else { - s.logger.WithFields(logrus.Fields{ - "peer": peerAddress, - "path": pair.Path, - }).Debug("Stored data from fetched range") - } - } - return nil -} // Bootstrap - join cluster using seed nodes -func (s *Server) bootstrap() { - if len(s.config.SeedNodes) == 0 { - s.logger.Info("No seed nodes configured, running as standalone") - return - } - - s.logger.Info("Starting bootstrap process") - s.setMode("syncing") - - // Try to join cluster via each seed node - joined := false - for _, seedAddr := range s.config.SeedNodes { - if s.attemptJoin(seedAddr) { - joined = true - break - } - } - - if !joined { - s.logger.Warn("Failed to join cluster via seed nodes, running as standalone") - s.setMode("normal") - return - } - - // Wait a bit for member discovery - time.Sleep(2 * time.Second) - - // Perform gradual sync (now Merkle-based) - s.performGradualSync() - - // Switch to normal mode - s.setMode("normal") - s.logger.Info("Bootstrap completed, entering normal mode") -} // Attempt to join cluster via a seed node -func (s *Server) attemptJoin(seedAddr string) bool { - joinReq := types.JoinRequest{ - ID: s.config.NodeID, - Address: fmt.Sprintf("%s:%d", s.config.BindAddress, s.config.Port), - JoinedTimestamp: time.Now().UnixMilli(), - } - - jsonData, err := json.Marshal(joinReq) - if err != nil { - s.logger.WithError(err).Error("Failed to marshal join request") - return false - } - - client := &http.Client{Timeout: 10 * time.Second} - url := fmt.Sprintf("http://%s/members/join", seedAddr) - - resp, err := client.Post(url, "application/json", bytes.NewBuffer(jsonData)) - if err != nil { - s.logger.WithFields(logrus.Fields{ - "seed": seedAddr, - "error": err.Error(), - }).Warn("Failed to contact seed node") - return false - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - s.logger.WithFields(logrus.Fields{ - "seed": seedAddr, - "status": resp.StatusCode, - 
}).Warn("Seed node rejected join request") - return false - } - - // Process member list response - var memberList []types.Member - if err := json.NewDecoder(resp.Body).Decode(&memberList); err != nil { - s.logger.WithError(err).Error("Failed to decode member list from seed") - return false - } - - // Add all members to our local list - for _, member := range memberList { - if member.ID != s.config.NodeID { - s.addMember(&member) - } - } - - s.logger.WithFields(logrus.Fields{ - "seed": seedAddr, - "member_count": len(memberList), - }).Info("Successfully joined cluster") - - return true -} // Perform gradual sync (Merkle-based version) -func (s *Server) performGradualSync() { - s.logger.Info("Starting gradual sync (Merkle-based)") - - members := s.getHealthyMembers() - if len(members) == 0 { - s.logger.Info("No healthy members for gradual sync") - return - } - - // For now, just do a few rounds of Merkle sync - for i := 0; i < 3; i++ { - s.performMerkleSync() // Use Merkle sync - time.Sleep(time.Duration(s.config.ThrottleDelayMs) * time.Millisecond) - } - - s.logger.Info("Gradual sync completed") -} // Resolve conflict between local and remote data using majority vote and oldest node tie-breaker -func (s *Server) resolveConflict(path string, localData *types.StoredValue, remotePair *types.PairsByTimeResponse, peerAddress string) (bool, error) { - s.logger.WithFields(logrus.Fields{ - "path": path, - "timestamp": localData.Timestamp, - "local_uuid": localData.UUID, - "remote_uuid": remotePair.UUID, - }).Info("Starting conflict resolution with majority vote") - - // Get list of healthy members for voting - members := s.getHealthyMembers() - if len(members) == 0 { - // No other members to consult, use oldest node rule (local vs remote) - // We'll consider the peer as the "remote" node for comparison - return s.resolveByOldestNode(localData, remotePair, peerAddress) - } - - // Query all healthy members for their version of this path - votes := make(map[string]int) // UUID -> vote count - uuidToTimestamp := make(map[string]int64) - uuidToJoinedTime := make(map[string]int64) - - // Add our local vote - votes[localData.UUID] = 1 - uuidToTimestamp[localData.UUID] = localData.Timestamp - uuidToJoinedTime[localData.UUID] = s.getJoinedTimestamp() - - // Add the remote peer's vote - // Note: remotePair.Timestamp is used here, but for a full Merkle sync, - // we would have already fetched the full remoteStoredValue. - // For consistency, let's assume remotePair accurately reflects the remote's types.StoredValue metadata. 
- votes[remotePair.UUID]++ // Increment vote, as it's already counted implicitly by being the source of divergence - uuidToTimestamp[remotePair.UUID] = remotePair.Timestamp - // We'll need to get the peer's joined timestamp - s.membersMu.RLock() - for _, member := range s.members { - if member.Address == peerAddress { - uuidToJoinedTime[remotePair.UUID] = member.JoinedTimestamp - break - } - } - s.membersMu.RUnlock() - - // Query other members - for _, member := range members { - if member.Address == peerAddress { - continue // Already handled the peer - } - - memberData, err := s.fetchSingleKVFromPeer(member.Address, path) // Use the new fetchSingleKVFromPeer - if err != nil || memberData == nil { - s.logger.WithFields(logrus.Fields{ - "member_address": member.Address, - "path": path, - "error": err, - }).Debug("Failed to query member for data during conflict resolution, skipping vote") - continue - } - - // Only count votes for data with the same timestamp (as per original logic for collision) - if memberData.Timestamp == localData.Timestamp { - votes[memberData.UUID]++ - if _, exists := uuidToTimestamp[memberData.UUID]; !exists { - uuidToTimestamp[memberData.UUID] = memberData.Timestamp - uuidToJoinedTime[memberData.UUID] = member.JoinedTimestamp - } - } - } - - // Find the UUID with majority votes - maxVotes := 0 - var winningUUIDs []string - - for uuid, voteCount := range votes { - if voteCount > maxVotes { - maxVotes = voteCount - winningUUIDs = []string{uuid} - } else if voteCount == maxVotes { - winningUUIDs = append(winningUUIDs, uuid) - } - } - - var winnerUUID string - if len(winningUUIDs) == 1 { - winnerUUID = winningUUIDs[0] - } else { - // Tie-breaker: oldest node (earliest joined timestamp) - oldestJoinedTime := int64(0) - for _, uuid := range winningUUIDs { - joinedTime := uuidToJoinedTime[uuid] - if oldestJoinedTime == 0 || (joinedTime != 0 && joinedTime < oldestJoinedTime) { // joinedTime can be 0 if not found - oldestJoinedTime = joinedTime - winnerUUID = uuid - } - } - - s.logger.WithFields(logrus.Fields{ - "path": path, - "tied_votes": maxVotes, - "winner_uuid": winnerUUID, - "oldest_joined": oldestJoinedTime, - }).Info("Resolved conflict using oldest node tie-breaker") - } - - // If remote UUID wins, fetch and store the remote data - if winnerUUID == remotePair.UUID { - // We need the full types.StoredValue for the winning remote data. - // Since remotePair only has UUID/Timestamp, we must fetch the data. 
- winningRemoteStoredValue, err := s.fetchSingleKVFromPeer(peerAddress, path) - if err != nil || winningRemoteStoredValue == nil { - return false, fmt.Errorf("failed to fetch winning remote data for conflict resolution: %v", err) - } - - err = s.storeReplicatedDataWithMetadata(path, winningRemoteStoredValue) - if err != nil { - return false, fmt.Errorf("failed to store winning data: %v", err) - } - - s.logger.WithFields(logrus.Fields{ - "path": path, - "winner_uuid": winnerUUID, - "winner_votes": maxVotes, - "total_nodes": len(members) + 2, // +2 for local and peer - }).Info("Conflict resolved: remote data wins") - - return true, nil - } - - // Local data wins, no action needed - s.logger.WithFields(logrus.Fields{ - "path": path, - "winner_uuid": winnerUUID, - "winner_votes": maxVotes, - "total_nodes": len(members) + 2, - }).Info("Conflict resolved: local data wins") - - return false, nil -} // Resolve conflict using oldest node rule when no other members available -func (s *Server) resolveByOldestNode(localData *types.StoredValue, remotePair *types.PairsByTimeResponse, peerAddress string) (bool, error) { - // Find the peer's joined timestamp - peerJoinedTime := int64(0) - s.membersMu.RLock() - for _, member := range s.members { - if member.Address == peerAddress { - peerJoinedTime = member.JoinedTimestamp - break - } - } - s.membersMu.RUnlock() - - localJoinedTime := s.getJoinedTimestamp() - - // Oldest node wins - if peerJoinedTime > 0 && peerJoinedTime < localJoinedTime { - // Peer is older, fetch remote data - winningRemoteStoredValue, err := s.fetchSingleKVFromPeer(peerAddress, remotePair.Path) - if err != nil || winningRemoteStoredValue == nil { - return false, fmt.Errorf("failed to fetch data from older node for conflict resolution: %v", err) - } - - err = s.storeReplicatedDataWithMetadata(remotePair.Path, winningRemoteStoredValue) - if err != nil { - return false, fmt.Errorf("failed to store data from older node: %v", err) - } - - s.logger.WithFields(logrus.Fields{ - "path": remotePair.Path, - "local_joined": localJoinedTime, - "peer_joined": peerJoinedTime, - "winner": "remote", - }).Info("Conflict resolved using oldest node rule") - - return true, nil - } - - // Local node is older or equal, keep local data - s.logger.WithFields(logrus.Fields{ - "path": remotePair.Path, - "local_joined": localJoinedTime, - "peer_joined": peerJoinedTime, - "winner": "local", - }).Info("Conflict resolved using oldest node rule") - - return false, nil -} // getLocalData is a utility to retrieve a types.StoredValue from local DB. 
-func (s *Server) getLocalData(path string) (*types.StoredValue, bool) { - var storedValue types.StoredValue - err := s.db.View(func(txn *badger.Txn) error { - item, err := txn.Get([]byte(path)) - if err != nil { - return err - } - - return item.Value(func(val []byte) error { - return json.Unmarshal(val, &storedValue) - }) - }) - - if err != nil { - return nil, false - } - - return &storedValue, true -} func main() { configPath := "./config.yaml" @@ -3586,7 +298,7 @@ func main() { os.Exit(1) } - server, err := NewServer(cfg) + kvServer, err := server.NewServer(cfg) if err != nil { fmt.Fprintf(os.Stderr, "Failed to create server: %v\n", err) os.Exit(1) @@ -3598,11 +310,11 @@ func main() { go func() { <-sigCh - server.Stop() + kvServer.Stop() }() - if err := server.Start(); err != nil && err != http.ErrServerClosed { + if err := kvServer.Start(); err != nil && err != http.ErrServerClosed { fmt.Fprintf(os.Stderr, "Server error: %v\n", err) os.Exit(1) } -} \ No newline at end of file +} diff --git a/server/handlers.go b/server/handlers.go index 874778f..d750063 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -1,22 +1,37 @@ package server import ( + "bytes" + "crypto/sha256" "encoding/json" "fmt" "net" "net/http" + "sort" "strconv" "strings" "time" - "github.com/dgraph-io/badger/v3" + "github.com/dgraph-io/badger/v4" + "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" "github.com/gorilla/mux" "github.com/sirupsen/logrus" - "github.com/kalzu/kvs/types" + "kvs/auth" + "kvs/types" + "kvs/utils" ) +// JWTClaims represents the custom claims for JWT tokens +type JWTClaims struct { + UserUUID string `json:"user_uuid"` + Scopes []string `json:"scopes"` + jwt.RegisteredClaims +} + +var jwtSigningKey = []byte("your-super-secret-key") // TODO: Move to config + // healthHandler returns server health status func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) { mode := s.getMode() @@ -295,7 +310,7 @@ func (s *Server) pairsByTimeHandler(w http.ResponseWriter, r *http.Request) { // The original logic for prefix filtering here was incomplete. // For Merkle tree sync, this handler is no longer used for core sync. // It remains as a client-facing API. 
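// --- Editor's note (illustrative, not part of this diff): the loop below is
// Badger's standard prefix scan; the general pattern, for some hypothetical
// prefix, is
//
//	prefix := []byte("some/prefix")
//	for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() {
//		// every item seen here has a key starting with "some/prefix"
//	}
//
// ValidForPrefix ends the scan at the first key outside the prefix.
// --- End editor's note. ---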
- + for it.Seek(prefix); it.ValidForPrefix(prefix) && len(pairs) < req.Limit; it.Next() { item := it.Item() key := string(item.Key()) @@ -396,7 +411,898 @@ func (s *Server) gossipHandler(w http.ResponseWriter, r *http.Request) { // getBackupStatusHandler returns current backup status func (s *Server) getBackupStatusHandler(w http.ResponseWriter, r *http.Request) { status := s.getBackupStatus() - + w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(status) -} \ No newline at end of file +} + +// getMerkleRootHandler returns the current Merkle tree root +func (s *Server) getMerkleRootHandler(w http.ResponseWriter, r *http.Request) { + root := s.syncService.GetMerkleRoot() + if root == nil { + http.Error(w, "Merkle tree not initialized", http.StatusInternalServerError) + return + } + + resp := types.MerkleRootResponse{ + Root: root, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} + +func (s *Server) getMerkleDiffHandler(w http.ResponseWriter, r *http.Request) { + var req types.MerkleTreeDiffRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + + localPairs, err := s.getAllKVPairsForMerkleTree() + if err != nil { + s.logger.WithError(err).Error("Failed to get KV pairs for Merkle diff") + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + // Build the local types.MerkleNode for the requested range to compare with the remote's hash + localSubTreeRoot, err := s.buildMerkleTreeFromPairs(s.filterPairsByRange(localPairs, req.ParentNode.StartKey, req.ParentNode.EndKey)) + if err != nil { + s.logger.WithError(err).Error("Failed to build sub-Merkle tree for diff request") + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + if localSubTreeRoot == nil { // This can happen if the range is empty locally + localSubTreeRoot = &types.MerkleNode{Hash: calculateHash([]byte("empty_tree")), StartKey: req.ParentNode.StartKey, EndKey: req.ParentNode.EndKey} + } + + resp := types.MerkleTreeDiffResponse{} + + // If hashes match, no need to send children or keys + if bytes.Equal(req.LocalHash, localSubTreeRoot.Hash) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + return + } + + // Hashes differ, so we need to provide more detail. + // Get all keys within the parent node's range locally + var keysInRange []string + for key := range s.filterPairsByRange(localPairs, req.ParentNode.StartKey, req.ParentNode.EndKey) { + keysInRange = append(keysInRange, key) + } + sort.Strings(keysInRange) + + const diffLeafThreshold = 10 // If a range has <= 10 keys, we consider it a leaf-level diff + + if len(keysInRange) <= diffLeafThreshold { + // This is a leaf-level diff, return the actual keys in the range + resp.Keys = keysInRange + } else { + // Group keys into sub-ranges and return their types.MerkleNode representations + // For simplicity, let's split the range into two halves.
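// --- Editor's note (illustrative, not part of this diff): halving the key
// range each round bounds the walk at roughly log2(n) round trips before a
// range drops under diffLeafThreshold; for example:
//
//	rounds, n := 0, 10_000
//	for ; n > 10; n /= 2 { // 10 == diffLeafThreshold
//		rounds++
//	}
//	// rounds == 10
//
// The trade-off, visible above, is that every request rebuilds sub-tree
// hashes from a full scan of the local pairs. --- End editor's note. ---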
+ mid := len(keysInRange) / 2 + + leftKeys := keysInRange[:mid] + rightKeys := keysInRange[mid:] + + if len(leftKeys) > 0 { + leftRangePairs := s.filterPairsByRange(localPairs, leftKeys[0], leftKeys[len(leftKeys)-1]) + leftNode, err := s.buildMerkleTreeFromPairs(leftRangePairs) + if err != nil { + s.logger.WithError(err).Error("Failed to build left child node for diff") + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + if leftNode != nil { + resp.Children = append(resp.Children, *leftNode) + } + } + + if len(rightKeys) > 0 { + rightRangePairs := s.filterPairsByRange(localPairs, rightKeys[0], rightKeys[len(rightKeys)-1]) + rightNode, err := s.buildMerkleTreeFromPairs(rightRangePairs) + if err != nil { + s.logger.WithError(err).Error("Failed to build right child node for diff") + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + if rightNode != nil { + resp.Children = append(resp.Children, *rightNode) + } + } + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} +func (s *Server) getKVRangeHandler(w http.ResponseWriter, r *http.Request) { + var req types.KVRangeRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + + var pairs []struct { + Path string `json:"path"` + StoredValue types.StoredValue `json:"stored_value"` + } + + err := s.db.View(func(txn *badger.Txn) error { + opts := badger.DefaultIteratorOptions + opts.PrefetchValues = true + it := txn.NewIterator(opts) + defer it.Close() + + count := 0 + // Start iteration from the requested StartKey + for it.Seek([]byte(req.StartKey)); it.Valid(); it.Next() { + item := it.Item() + key := string(item.Key()) + + if strings.HasPrefix(key, "_ts:") { + continue // Skip index keys + } + + // Stop if we exceed the EndKey (if provided) + if req.EndKey != "" && key > req.EndKey { + break + } + + // Stop if we hit the limit (if provided) + if req.Limit > 0 && count >= req.Limit { + break + } + + var storedValue types.StoredValue + err := item.Value(func(val []byte) error { + return json.Unmarshal(val, &storedValue) + }) + if err != nil { + s.logger.WithError(err).WithField("key", key).Warn("Failed to unmarshal stored value in KV range, skipping") + continue + } + + pairs = append(pairs, struct { + Path string `json:"path"` + StoredValue types.StoredValue `json:"stored_value"` + }{Path: key, StoredValue: storedValue}) + count++ + } + return nil + }) + + if err != nil { + s.logger.WithError(err).Error("Failed to query KV range") + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(types.KVRangeResponse{Pairs: pairs}) +} +func (s *Server) createUserHandler(w http.ResponseWriter, r *http.Request) { + var req types.CreateUserRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + + if req.Nickname == "" { + http.Error(w, "Nickname is required", http.StatusBadRequest) + return + } + + // Generate UUID for the user + userUUID := uuid.New().String() + now := time.Now().Unix() + + user := types.User{ + UUID: userUUID, + NicknameHash: utils.HashUserNickname(req.Nickname), + Groups: []string{}, + CreatedAt: now, + UpdatedAt: now, + } + + // Store user in BadgerDB + userData, err := json.Marshal(user) + if err != nil { + s.logger.WithError(err).Error("Failed to marshal 
user data") + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + err = s.db.Update(func(txn *badger.Txn) error { + return txn.Set([]byte(auth.UserStorageKey(userUUID)), userData) + }) + + if err != nil { + s.logger.WithError(err).Error("Failed to store user") + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + s.logger.WithField("user_uuid", userUUID).Info("types.User created successfully") + + response := types.CreateUserResponse{UUID: userUUID} + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) +} +func (s *Server) getUserHandler(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + userUUID := vars["uuid"] + + if userUUID == "" { + http.Error(w, "types.User UUID is required", http.StatusBadRequest) + return + } + + var user types.User + err := s.db.View(func(txn *badger.Txn) error { + item, err := txn.Get([]byte(auth.UserStorageKey(userUUID))) + if err != nil { + return err + } + + return item.Value(func(val []byte) error { + return json.Unmarshal(val, &user) + }) + }) + + if err == badger.ErrKeyNotFound { + http.Error(w, "types.User not found", http.StatusNotFound) + return + } + + if err != nil { + s.logger.WithError(err).Error("Failed to get user") + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + response := types.GetUserResponse{ + UUID: user.UUID, + NicknameHash: user.NicknameHash, + Groups: user.Groups, + CreatedAt: user.CreatedAt, + UpdatedAt: user.UpdatedAt, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) +} +func (s *Server) updateUserHandler(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + userUUID := vars["uuid"] + + if userUUID == "" { + http.Error(w, "types.User UUID is required", http.StatusBadRequest) + return + } + + var req types.UpdateUserRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + + err := s.db.Update(func(txn *badger.Txn) error { + // Get existing user + item, err := txn.Get([]byte(auth.UserStorageKey(userUUID))) + if err != nil { + return err + } + + var user types.User + err = item.Value(func(val []byte) error { + return json.Unmarshal(val, &user) + }) + if err != nil { + return err + } + + // Update fields if provided + now := time.Now().Unix() + user.UpdatedAt = now + + if req.Nickname != "" { + user.NicknameHash = utils.HashUserNickname(req.Nickname) + } + + if req.Groups != nil { + user.Groups = req.Groups + } + + // Store updated user + userData, err := json.Marshal(user) + if err != nil { + return err + } + + return txn.Set([]byte(auth.UserStorageKey(userUUID)), userData) + }) + + if err == badger.ErrKeyNotFound { + http.Error(w, "types.User not found", http.StatusNotFound) + return + } + + if err != nil { + s.logger.WithError(err).Error("Failed to update user") + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + s.logger.WithField("user_uuid", userUUID).Info("types.User updated successfully") + w.WriteHeader(http.StatusOK) +} +func (s *Server) deleteUserHandler(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + userUUID := vars["uuid"] + + if userUUID == "" { + http.Error(w, "types.User UUID is required", http.StatusBadRequest) + return + } + + err := s.db.Update(func(txn *badger.Txn) error { + // Check if user exists first + _, err := txn.Get([]byte(auth.UserStorageKey(userUUID))) 
+ if err != nil { + return err + } + + // Delete the user + return txn.Delete([]byte(auth.UserStorageKey(userUUID))) + }) + + if err == badger.ErrKeyNotFound { + http.Error(w, "User not found", http.StatusNotFound) + return + } + + if err != nil { + s.logger.WithError(err).Error("Failed to delete user") + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + s.logger.WithField("user_uuid", userUUID).Info("User deleted successfully") + w.WriteHeader(http.StatusOK) +} +func (s *Server) createGroupHandler(w http.ResponseWriter, r *http.Request) { + var req types.CreateGroupRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + + if req.Groupname == "" { + http.Error(w, "Groupname is required", http.StatusBadRequest) + return + } + + // Generate UUID for the group + groupUUID := uuid.New().String() + now := time.Now().Unix() + + group := types.Group{ + UUID: groupUUID, + NameHash: utils.HashGroupName(req.Groupname), + Members: req.Members, + CreatedAt: now, + UpdatedAt: now, + } + + if group.Members == nil { + group.Members = []string{} + } + + // Store group in BadgerDB + groupData, err := json.Marshal(group) + if err != nil { + s.logger.WithError(err).Error("Failed to marshal group data") + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + err = s.db.Update(func(txn *badger.Txn) error { + return txn.Set([]byte(auth.GroupStorageKey(groupUUID)), groupData) + }) + + if err != nil { + s.logger.WithError(err).Error("Failed to store group") + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + s.logger.WithField("group_uuid", groupUUID).Info("Group created successfully") + + response := types.CreateGroupResponse{UUID: groupUUID} + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) +} +func (s *Server) getGroupHandler(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + groupUUID := vars["uuid"] + + if groupUUID == "" { + http.Error(w, "Group UUID is required", http.StatusBadRequest) + return + } + + var group types.Group + err := s.db.View(func(txn *badger.Txn) error { + item, err := txn.Get([]byte(auth.GroupStorageKey(groupUUID))) + if err != nil { + return err + } + + return item.Value(func(val []byte) error { + return json.Unmarshal(val, &group) + }) + }) + + if err == badger.ErrKeyNotFound { + http.Error(w, "Group not found", http.StatusNotFound) + return + } + + if err != nil { + s.logger.WithError(err).Error("Failed to get group") + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + response := types.GetGroupResponse{ + UUID: group.UUID, + NameHash: group.NameHash, + Members: group.Members, + CreatedAt: group.CreatedAt, + UpdatedAt: group.UpdatedAt, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) +} +func (s *Server) updateGroupHandler(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + groupUUID := vars["uuid"] + + if groupUUID == "" { + http.Error(w, "Group UUID is required", http.StatusBadRequest) + return + } + + var req types.UpdateGroupRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + + err := s.db.Update(func(txn *badger.Txn) error { + // Get existing group + item, err := txn.Get([]byte(auth.GroupStorageKey(groupUUID))) + 
+
+func (s *Server) updateGroupHandler(w http.ResponseWriter, r *http.Request) {
+	vars := mux.Vars(r)
+	groupUUID := vars["uuid"]
+
+	if groupUUID == "" {
+		http.Error(w, "Group UUID is required", http.StatusBadRequest)
+		return
+	}
+
+	var req types.UpdateGroupRequest
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		http.Error(w, "Bad Request", http.StatusBadRequest)
+		return
+	}
+
+	err := s.db.Update(func(txn *badger.Txn) error {
+		// Get existing group
+		item, err := txn.Get([]byte(auth.GroupStorageKey(groupUUID)))
+		if err != nil {
+			return err
+		}
+
+		var group types.Group
+		err = item.Value(func(val []byte) error {
+			return json.Unmarshal(val, &group)
+		})
+		if err != nil {
+			return err
+		}
+
+		// Update fields
+		now := time.Now().Unix()
+		group.UpdatedAt = now
+		group.Members = req.Members
+
+		if group.Members == nil {
+			group.Members = []string{}
+		}
+
+		// Store updated group
+		groupData, err := json.Marshal(group)
+		if err != nil {
+			return err
+		}
+
+		return txn.Set([]byte(auth.GroupStorageKey(groupUUID)), groupData)
+	})
+
+	if err == badger.ErrKeyNotFound {
+		http.Error(w, "Group not found", http.StatusNotFound)
+		return
+	}
+
+	if err != nil {
+		s.logger.WithError(err).Error("Failed to update group")
+		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+		return
+	}
+
+	s.logger.WithField("group_uuid", groupUUID).Info("Group updated successfully")
+	w.WriteHeader(http.StatusOK)
+}
+
+func (s *Server) deleteGroupHandler(w http.ResponseWriter, r *http.Request) {
+	vars := mux.Vars(r)
+	groupUUID := vars["uuid"]
+
+	if groupUUID == "" {
+		http.Error(w, "Group UUID is required", http.StatusBadRequest)
+		return
+	}
+
+	err := s.db.Update(func(txn *badger.Txn) error {
+		// Check if group exists first
+		_, err := txn.Get([]byte(auth.GroupStorageKey(groupUUID)))
+		if err != nil {
+			return err
+		}
+
+		// Delete the group
+		return txn.Delete([]byte(auth.GroupStorageKey(groupUUID)))
+	})
+
+	if err == badger.ErrKeyNotFound {
+		http.Error(w, "Group not found", http.StatusNotFound)
+		return
+	}
+
+	if err != nil {
+		s.logger.WithError(err).Error("Failed to delete group")
+		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+		return
+	}
+
+	s.logger.WithField("group_uuid", groupUUID).Info("Group deleted successfully")
+	w.WriteHeader(http.StatusOK)
+}
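Note the asymmetry with user updates: updateGroupHandler overwrites Members unconditionally, and the get-decode-set sequence runs inside a single Badger Update transaction, so a conflicting concurrent update aborts with badger.ErrConflict rather than interleaving. A standalone sketch of that read-modify-write pattern (in-memory Badger; the key and payload are hypothetical):

package main

import (
	"fmt"

	badger "github.com/dgraph-io/badger/v4"
)

func main() {
	db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true))
	if err != nil {
		panic(err)
	}
	defer db.Close()

	key := []byte("group:demo") // hypothetical storage key
	if err := db.Update(func(txn *badger.Txn) error {
		return txn.Set(key, []byte(`["u1"]`))
	}); err != nil {
		panic(err)
	}

	// Read-modify-write inside one transaction, as updateGroupHandler does.
	err = db.Update(func(txn *badger.Txn) error {
		item, err := txn.Get(key)
		if err != nil {
			return err
		}
		var current []byte
		if err := item.Value(func(v []byte) error {
			current = append([]byte(nil), v...) // copy: val is only valid inside the closure
			return nil
		}); err != nil {
			return err
		}
		_ = current // decode, mutate, and re-encode the record here
		return txn.Set(key, []byte(`["u1","u2"]`))
	})
	if err != nil {
		panic(err)
	}
	fmt.Println("updated atomically")
}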
+
+func (s *Server) createTokenHandler(w http.ResponseWriter, r *http.Request) {
+	var req types.CreateTokenRequest
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		http.Error(w, "Bad Request", http.StatusBadRequest)
+		return
+	}
+
+	if req.UserUUID == "" {
+		http.Error(w, "User UUID is required", http.StatusBadRequest)
+		return
+	}
+
+	if len(req.Scopes) == 0 {
+		http.Error(w, "At least one scope is required", http.StatusBadRequest)
+		return
+	}
+
+	// Verify user exists
+	err := s.db.View(func(txn *badger.Txn) error {
+		_, err := txn.Get([]byte(auth.UserStorageKey(req.UserUUID)))
+		return err
+	})
+
+	if err == badger.ErrKeyNotFound {
+		http.Error(w, "User not found", http.StatusNotFound)
+		return
+	}
+
+	if err != nil {
+		s.logger.WithError(err).Error("Failed to verify user")
+		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+		return
+	}
+
+	// Generate JWT token
+	tokenString, expiresAt, err := generateJWT(req.UserUUID, req.Scopes, 1) // 1 hour default
+	if err != nil {
+		s.logger.WithError(err).Error("Failed to generate JWT token")
+		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+		return
+	}
+
+	// Store token in BadgerDB
+	err = s.storeAPIToken(tokenString, req.UserUUID, req.Scopes, expiresAt)
+	if err != nil {
+		s.logger.WithError(err).Error("Failed to store API token")
+		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+		return
+	}
+
+	s.logger.WithFields(logrus.Fields{
+		"user_uuid":  req.UserUUID,
+		"scopes":     req.Scopes,
+		"expires_at": expiresAt,
+	}).Info("API token created successfully")
+
+	response := types.CreateTokenResponse{
+		Token:     tokenString,
+		ExpiresAt: expiresAt,
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+	json.NewEncoder(w).Encode(response)
+}
+
+func (s *Server) getRevisionHistoryHandler(w http.ResponseWriter, r *http.Request) {
+	// Check if revision history is enabled
+	if !s.config.RevisionHistoryEnabled {
+		http.Error(w, "Revision history is disabled", http.StatusServiceUnavailable)
+		return
+	}
+
+	vars := mux.Vars(r)
+	key := vars["key"]
+
+	if key == "" {
+		http.Error(w, "Key is required", http.StatusBadRequest)
+		return
+	}
+
+	revisions, err := s.getRevisionHistory(key)
+	if err != nil {
+		s.logger.WithError(err).WithField("key", key).Error("Failed to get revision history")
+		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+		return
+	}
+
+	if len(revisions) == 0 {
+		http.Error(w, "No revisions found", http.StatusNotFound)
+		return
+	}
+
+	response := map[string]interface{}{
+		"revisions": revisions,
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+	json.NewEncoder(w).Encode(response)
+}
+
+func (s *Server) getSpecificRevisionHandler(w http.ResponseWriter, r *http.Request) {
+	// Check if revision history is enabled
+	if !s.config.RevisionHistoryEnabled {
+		http.Error(w, "Revision history is disabled", http.StatusServiceUnavailable)
+		return
+	}
+
+	vars := mux.Vars(r)
+	key := vars["key"]
+	revisionStr := vars["revision"]
+
+	if key == "" {
+		http.Error(w, "Key is required", http.StatusBadRequest)
+		return
+	}
+
+	if revisionStr == "" {
+		http.Error(w, "Revision is required", http.StatusBadRequest)
+		return
+	}
+
+	revision, err := strconv.Atoi(revisionStr)
+	if err != nil {
+		http.Error(w, "Invalid revision number", http.StatusBadRequest)
+		return
+	}
+
+	storedValue, err := s.getSpecificRevision(key, revision)
+	if err == badger.ErrKeyNotFound {
+		http.Error(w, "Revision not found", http.StatusNotFound)
+		return
+	}
+
+	if err != nil {
+		s.logger.WithError(err).WithFields(logrus.Fields{
+			"key":      key,
+			"revision": revision,
+		}).Error("Failed to get specific revision")
+		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+		return
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+	json.NewEncoder(w).Encode(storedValue)
+}
+
+// calculateHash computes SHA256 hash of data
+func calculateHash(data []byte) []byte {
+	h := sha256.New()
+	h.Write(data)
+	return h.Sum(nil)
+}
+
+// getAllKVPairsForMerkleTree retrieves all key-value pairs for Merkle tree operations
+func (s *Server) getAllKVPairsForMerkleTree() (map[string]*types.StoredValue, error) {
+	pairs := make(map[string]*types.StoredValue)
+	err := s.db.View(func(txn *badger.Txn) error {
+		opts := badger.DefaultIteratorOptions
+		opts.PrefetchValues = true // We need the values for hashing
+		it := txn.NewIterator(opts)
+		defer it.Close()
+
+		// Iterate over all actual data keys (not _ts: indexes)
+		for it.Rewind(); it.Valid(); it.Next() {
+			item := it.Item()
+			key := string(item.Key())
+
+			if strings.HasPrefix(key, "_ts:") {
+				continue // Skip index keys
+			}
+
+			var storedValue types.StoredValue
+			err := item.Value(func(val []byte) error {
+				return json.Unmarshal(val, &storedValue)
+			})
+			if err != nil {
+				s.logger.WithError(err).WithField("key", key).Warn("Failed to unmarshal stored value for Merkle tree, skipping")
+				continue
+			}
+			pairs[key] = &storedValue
+		}
+		return nil
+	})
+	if err != nil {
+		return nil, err
+	}
+	return pairs, nil
+}
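getAllKVPairsForMerkleTree hashes only primary records: anything under the _ts: prefix is treated as derived index data and skipped, so index churn cannot move the Merkle root. A standalone sketch of the skip pattern (the exact _ts: key layout is an assumption; only the prefix is visible in this diff):

package main

import (
	"fmt"
	"strings"

	badger "github.com/dgraph-io/badger/v4"
)

func main() {
	db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true))
	if err != nil {
		panic(err)
	}
	defer db.Close()

	_ = db.Update(func(txn *badger.Txn) error {
		_ = txn.Set([]byte("users/1"), []byte(`{"v":1}`))
		return txn.Set([]byte("_ts:123:users/1"), []byte("users/1")) // hypothetical index entry, excluded
	})

	_ = db.View(func(txn *badger.Txn) error {
		it := txn.NewIterator(badger.DefaultIteratorOptions)
		defer it.Close()
		for it.Rewind(); it.Valid(); it.Next() {
			key := string(it.Item().Key())
			if strings.HasPrefix(key, "_ts:") {
				continue // same skip as getAllKVPairsForMerkleTree
			}
			fmt.Println("would hash:", key) // prints only users/1
		}
		return nil
	})
}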
+
+// buildMerkleTreeFromPairs constructs a Merkle tree from key-value pairs
+func (s *Server) buildMerkleTreeFromPairs(pairs map[string]*types.StoredValue) (*types.MerkleNode, error) {
+	if len(pairs) == 0 {
+		return &types.MerkleNode{Hash: calculateHash([]byte("empty_tree")), StartKey: "", EndKey: ""}, nil
+	}
+
+	// Sort keys to ensure consistent tree structure
+	keys := make([]string, 0, len(pairs))
+	for k := range pairs {
+		keys = append(keys, k)
+	}
+	sort.Strings(keys)
+
+	// Create leaf nodes
+	leafNodes := make([]*types.MerkleNode, len(keys))
+	for i, key := range keys {
+		storedValue := pairs[key]
+		hash := s.calculateLeafHash(key, storedValue)
+		leafNodes[i] = &types.MerkleNode{Hash: hash, StartKey: key, EndKey: key}
+	}
+
+	// Recursively build parent nodes
+	return s.buildMerkleTreeRecursive(leafNodes)
+}
+
+// filterPairsByRange filters key-value pairs by key range
+func (s *Server) filterPairsByRange(allPairs map[string]*types.StoredValue, startKey, endKey string) map[string]*types.StoredValue {
+	filtered := make(map[string]*types.StoredValue)
+	for key, value := range allPairs {
+		if (startKey == "" || key >= startKey) && (endKey == "" || key <= endKey) {
+			filtered[key] = value
+		}
+	}
+	return filtered
+}
+
+// calculateLeafHash generates a hash for a leaf node
+func (s *Server) calculateLeafHash(path string, storedValue *types.StoredValue) []byte {
+	// Concatenate path, UUID, timestamp, and the raw data bytes for hashing
+	// Ensure a consistent order of fields for hashing
+	dataToHash := bytes.Buffer{}
+	dataToHash.WriteString(path)
+	dataToHash.WriteByte(':')
+	dataToHash.WriteString(storedValue.UUID)
+	dataToHash.WriteByte(':')
+	dataToHash.WriteString(strconv.FormatInt(storedValue.Timestamp, 10))
+	dataToHash.WriteByte(':')
+	dataToHash.Write(storedValue.Data) // Use raw bytes of json.RawMessage
+
+	return calculateHash(dataToHash.Bytes())
+}
+
+// buildMerkleTreeRecursive builds Merkle tree recursively from nodes
+func (s *Server) buildMerkleTreeRecursive(nodes []*types.MerkleNode) (*types.MerkleNode, error) {
+	if len(nodes) == 0 {
+		return nil, nil
+	}
+	if len(nodes) == 1 {
+		return nodes[0], nil
+	}
+
+	var nextLevel []*types.MerkleNode
+	for i := 0; i < len(nodes); i += 2 {
+		left := nodes[i]
+		var right *types.MerkleNode
+		if i+1 < len(nodes) {
+			right = nodes[i+1]
+		}
+
+		var combinedHash []byte
+		var endKey string
+
+		if right != nil {
+			combinedHash = calculateHash(append(left.Hash, right.Hash...))
+			endKey = right.EndKey
+		} else {
+			// Odd number of nodes, promote the left node
+			combinedHash = left.Hash
+			endKey = left.EndKey
+		}
+
+		parentNode := &types.MerkleNode{
+			Hash:     combinedHash,
+			StartKey: left.StartKey,
+			EndKey:   endKey,
+		}
+		nextLevel = append(nextLevel, parentNode)
+	}
+	return s.buildMerkleTreeRecursive(nextLevel)
+}
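Together these functions fix the tree shape: leaves are SHA-256 over path:uuid:timestamp:data in sorted key order, a parent hashes the concatenation of its children, and a leftover odd node is promoted unchanged. A self-contained miniature of the same construction, showing that two stores with equal contents agree on the root regardless of insertion order:

package main

import (
	"bytes"
	"crypto/sha256"
	"fmt"
	"sort"
	"strconv"
)

type rec struct {
	uuid string
	ts   int64
	data []byte
}

// leafHash mirrors calculateLeafHash: path:uuid:timestamp:data.
func leafHash(path string, r rec) []byte {
	var b bytes.Buffer
	b.WriteString(path)
	b.WriteByte(':')
	b.WriteString(r.uuid)
	b.WriteByte(':')
	b.WriteString(strconv.FormatInt(r.ts, 10))
	b.WriteByte(':')
	b.Write(r.data)
	sum := sha256.Sum256(b.Bytes())
	return sum[:]
}

func root(pairs map[string]rec) []byte {
	keys := make([]string, 0, len(pairs))
	for k := range pairs {
		keys = append(keys, k)
	}
	sort.Strings(keys) // deterministic structure, as in buildMerkleTreeFromPairs

	level := make([][]byte, 0, len(keys))
	for _, k := range keys {
		level = append(level, leafHash(k, pairs[k]))
	}
	for len(level) > 1 {
		var next [][]byte
		for i := 0; i < len(level); i += 2 {
			if i+1 < len(level) {
				sum := sha256.Sum256(append(level[i], level[i+1]...))
				next = append(next, sum[:])
			} else {
				next = append(next, level[i]) // odd node promoted unchanged
			}
		}
		level = next
	}
	return level[0]
}

func main() {
	a := map[string]rec{
		"k1": {"u1", 100, []byte(`"v1"`)},
		"k2": {"u2", 200, []byte(`"v2"`)},
		"k3": {"u3", 300, []byte(`"v3"`)},
	}
	b := map[string]rec{ // same logical contents, different insertion order
		"k3": {"u3", 300, []byte(`"v3"`)},
		"k1": {"u1", 100, []byte(`"v1"`)},
		"k2": {"u2", 200, []byte(`"v2"`)},
	}
	fmt.Printf("equal roots: %t\n", bytes.Equal(root(a), root(b))) // true
}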
+
+// generateJWT creates a new JWT token for a user with specified scopes
+func generateJWT(userUUID string, scopes []string, expirationHours int) (string, int64, error) {
+	if expirationHours <= 0 {
+		expirationHours = 1 // Default to 1 hour
+	}
+
+	now := time.Now()
+	expiresAt := now.Add(time.Duration(expirationHours) * time.Hour)
+
+	claims := JWTClaims{
+		UserUUID: userUUID,
+		Scopes:   scopes,
+		RegisteredClaims: jwt.RegisteredClaims{
+			IssuedAt:  jwt.NewNumericDate(now),
+			ExpiresAt: jwt.NewNumericDate(expiresAt),
+			Issuer:    "kvs-server",
+		},
+	}
+
+	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+	tokenString, err := token.SignedString(jwtSigningKey)
+	if err != nil {
+		return "", 0, err
+	}
+
+	return tokenString, expiresAt.Unix(), nil
+}
+
+// storeAPIToken stores an API token in BadgerDB with TTL
+func (s *Server) storeAPIToken(tokenString string, userUUID string, scopes []string, expiresAt int64) error {
+	tokenHash := utils.HashToken(tokenString)
+
+	apiToken := types.APIToken{
+		TokenHash: tokenHash,
+		UserUUID:  userUUID,
+		Scopes:    scopes,
+		IssuedAt:  time.Now().Unix(),
+		ExpiresAt: expiresAt,
+	}
+
+	tokenData, err := json.Marshal(apiToken)
+	if err != nil {
+		return err
+	}
+
+	return s.db.Update(func(txn *badger.Txn) error {
+		entry := badger.NewEntry([]byte(auth.TokenStorageKey(tokenHash)), tokenData)
+
+		// Set TTL to the token expiration time
+		ttl := time.Until(time.Unix(expiresAt, 0))
+		if ttl > 0 {
+			entry = entry.WithTTL(ttl)
+		}
+
+		return txn.SetEntry(entry)
+	})
+}
+
+// getRevisionHistory retrieves revision history for a key
+func (s *Server) getRevisionHistory(key string) ([]map[string]interface{}, error) {
+	return s.revisionService.GetRevisionHistory(key)
+}
+
+// getSpecificRevision retrieves a specific revision of a key
+func (s *Server) getSpecificRevision(key string, revision int) (*types.StoredValue, error) {
+	return s.revisionService.GetSpecificRevision(key, revision)
+}
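generateJWT is a stock HS256 issue path; the parse side lives in validateJWT. A self-contained round trip against github.com/golang-jwt/jwt/v4 (the claim JSON tags and the throwaway key are assumptions standing in for JWTClaims' real tags and for jwtSigningKey):

package main

import (
	"fmt"
	"time"

	"github.com/golang-jwt/jwt/v4"
)

// claims mirrors the JWTClaims shape; the json tags are assumed.
type claims struct {
	UserUUID string   `json:"user_uuid"`
	Scopes   []string `json:"scopes"`
	jwt.RegisteredClaims
}

func main() {
	key := []byte("demo-only-key") // stand-in for jwtSigningKey

	now := time.Now()
	c := claims{
		UserUUID: "u-123",
		Scopes:   []string{"read"},
		RegisteredClaims: jwt.RegisteredClaims{
			IssuedAt:  jwt.NewNumericDate(now),
			ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)),
			Issuer:    "kvs-server",
		},
	}
	signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, c).SignedString(key)
	if err != nil {
		panic(err)
	}

	// Parse and verify with the same key, recovering the claims.
	parsed := claims{}
	tok, err := jwt.ParseWithClaims(signed, &parsed, func(t *jwt.Token) (interface{}, error) {
		return key, nil
	})
	if err != nil || !tok.Valid {
		panic(err)
	}
	fmt.Println(parsed.UserUUID, parsed.Scopes) // u-123 [read]
}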
diff --git a/server/lifecycle.go b/server/lifecycle.go
index 507f296..aec4718 100644
--- a/server/lifecycle.go
+++ b/server/lifecycle.go
@@ -74,21 +74,6 @@ func (s *Server) startBackgroundTasks() {
 }
 
 // bootstrap joins cluster using seed nodes via bootstrap service
 func (s *Server) bootstrap() {
-	if len(s.config.SeedNodes) == 0 {
-		s.logger.Info("No seed nodes configured, running as standalone")
-		return
-	}
-
-	s.logger.Info("Starting bootstrap process")
-	s.setMode("syncing")
-	// Use bootstrap service to join cluster
-	if err := s.bootstrapService.JoinCluster(); err != nil {
-		s.logger.WithError(err).Error("Failed to join cluster")
-		s.setMode("normal")
-		return
-	}
-
-	s.setMode("normal")
-	s.logger.Info("Successfully joined cluster")
-}
\ No newline at end of file
+	s.bootstrapService.Bootstrap()
+}
diff --git a/server/routes.go b/server/routes.go
index d0cf5d8..630b7d0 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -51,4 +51,4 @@ func (s *Server) setupRoutes() *mux.Router {
 	router.HandleFunc("/api/backup/status", s.getBackupStatusHandler).Methods("GET")
 
 	return router
-}
\ No newline at end of file
+}
diff --git a/server/server.go b/server/server.go
index ebd2fa4..e1632ee 100644
--- a/server/server.go
+++ b/server/server.go
@@ -7,46 +7,47 @@ import (
 	"os"
 	"path/filepath"
 	"sync"
+	"time"
 
-	"github.com/dgraph-io/badger/v3"
+	"github.com/dgraph-io/badger/v4"
 	"github.com/robfig/cron/v3"
 	"github.com/sirupsen/logrus"
 
-	"github.com/kalzu/kvs/auth"
-	"github.com/kalzu/kvs/cluster"
-	"github.com/kalzu/kvs/storage"
-	"github.com/kalzu/kvs/types"
+	"kvs/auth"
+	"kvs/cluster"
+	"kvs/storage"
+	"kvs/types"
 )
 
 // Server represents the KVS node
 type Server struct {
-	config     *types.Config
-	db         *badger.DB
-	mode       string // "normal", "read-only", "syncing"
-	modeMu     sync.RWMutex
-	logger     *logrus.Logger
-	httpServer *http.Server
-	ctx        context.Context
-	cancel     context.CancelFunc
-	wg         sync.WaitGroup
-
+	config     *types.Config
+	db         *badger.DB
+	mode       string // "normal", "read-only", "syncing"
+	modeMu     sync.RWMutex
+	logger     *logrus.Logger
+	httpServer *http.Server
+	ctx        context.Context
+	cancel     context.CancelFunc
+	wg         sync.WaitGroup
+
 	// Cluster services
 	gossipService    *cluster.GossipService
 	syncService      *cluster.SyncService
 	merkleService    *cluster.MerkleService
 	bootstrapService *cluster.BootstrapService
-
+
 	// Storage services
 	storageService  *storage.StorageService
 	revisionService *storage.RevisionService
-
+
 	// Phase 2: Backup system
-	cronScheduler *cron.Cron         // Cron scheduler for backups
-	backupStatus  types.BackupStatus // Current backup status
-	backupMu      sync.RWMutex       // Protects backup status
-
+	cronScheduler *cron.Cron         // Cron scheduler for backups
+	backupStatus  types.BackupStatus // Current backup status
+	backupMu      sync.RWMutex       // Protects backup status
+
 	// Authentication service
-	authService *auth.AuthService
+	authService *auth.AuthService
 }
 
 // NewServer initializes and returns a new Server instance
@@ -109,7 +110,7 @@ func NewServer(config *types.Config) (*Server, error) {
 		return nil, fmt.Errorf("failed to initialize storage service: %v", err)
 	}
 	server.storageService = storageService
-
+
 	// Initialize revision service
 	server.revisionService = storage.NewRevisionService(storageService)
 
@@ -168,9 +169,9 @@ func (s *Server) getJoinedTimestamp() int64 {
 func (s *Server) getBackupStatus() types.BackupStatus {
 	s.backupMu.RLock()
 	defer s.backupMu.RUnlock()
-
+
 	status := s.backupStatus
-
+
 	// Calculate next backup time if scheduler is running
 	if s.cronScheduler != nil && len(s.cronScheduler.Entries()) > 0 {
 		nextRun := s.cronScheduler.Entries()[0].Next
@@ -178,6 +179,6 @@ func (s *Server) getBackupStatus() types.BackupStatus {
 			status.NextBackupTime = nextRun.Unix()
 		}
 	}
-
+
 	return status
-}
\ No newline at end of file
+}
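getBackupStatus reads Entries()[0].Next off the running cron scheduler to report the next backup time. A standalone check of that behavior with robfig/cron/v3 (the daily-at-03:00 spec is a hypothetical stand-in for the configured backup schedule; the code assumes the first entry is the soonest one, which holds trivially with a single scheduled job):

package main

import (
	"fmt"
	"time"

	"github.com/robfig/cron/v3"
)

func main() {
	c := cron.New()
	// Hypothetical schedule standing in for the backup cron spec.
	if _, err := c.AddFunc("0 3 * * *", func() { fmt.Println("backup tick") }); err != nil {
		panic(err)
	}
	c.Start()
	defer c.Stop()

	time.Sleep(10 * time.Millisecond) // give the run loop a moment to compute Next

	next := c.Entries()[0].Next
	if !next.IsZero() {
		fmt.Println("next backup at unix", next.Unix()) // what NextBackupTime reports
	}
}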