package main import ( "crypto/rand" "crypto/sha256" "encoding/base32" "encoding/hex" "encoding/json" "os" "path/filepath" "sync" "syscall" "time" ) type User struct { ID string `json:"id"` TOTPSecret string `json:"totp_secret"` } type UserStore struct { mu sync.RWMutex filePath string crypto *Crypto cache map[string]*encryptedUserEntry } type encryptedUserEntry struct { UserIDHash string `json:"hash"` Data []byte `json:"data"` } type encryptedStore struct { Users []encryptedUserEntry `json:"users"` } func NewUserStore(dataDir string, crypto *Crypto) *UserStore { // Create data directory if it doesn't exist if err := os.MkdirAll(dataDir, 0700); err != nil { logger.Error("Failed to create data directory: %v", err) } filePath := filepath.Join(dataDir, "users.enc") logger.Info("Initialized user store at: %s", filePath) store := &UserStore{ filePath: filePath, crypto: crypto, cache: make(map[string]*encryptedUserEntry), } store.loadCache() return store } func (s *UserStore) hashUserID(userID string) string { hash := sha256.Sum256([]byte(userID)) return hex.EncodeToString(hash[:]) } // acquireFileLock attempts to acquire an exclusive lock on the store file // Returns the file descriptor and an error if locking fails func (s *UserStore) acquireFileLock(forWrite bool) (*os.File, error) { lockPath := s.filePath + ".lock" // Create lock file if it doesn't exist lockFile, err := os.OpenFile(lockPath, os.O_CREATE|os.O_RDWR, 0600) if err != nil { return nil, err } // Try to acquire lock with timeout lockType := syscall.LOCK_SH // Shared lock for reads if forWrite { lockType = syscall.LOCK_EX // Exclusive lock for writes } // Use non-blocking lock with retry maxRetries := 10 for i := 0; i < maxRetries; i++ { err = syscall.Flock(int(lockFile.Fd()), lockType|syscall.LOCK_NB) if err == nil { return lockFile, nil } if err != syscall.EWOULDBLOCK { lockFile.Close() return nil, err } // Wait and retry time.Sleep(100 * time.Millisecond) } lockFile.Close() return nil, syscall.EWOULDBLOCK } // releaseFileLock releases the file lock func (s *UserStore) releaseFileLock(lockFile *os.File) { if lockFile != nil { syscall.Flock(int(lockFile.Fd()), syscall.LOCK_UN) lockFile.Close() } } func (s *UserStore) Reload() error { s.mu.Lock() defer s.mu.Unlock() // Clear existing cache s.cache = make(map[string]*encryptedUserEntry) // Reload from disk return s.loadCacheInternal() } func (s *UserStore) loadCache() error { s.mu.Lock() defer s.mu.Unlock() return s.loadCacheInternal() } func (s *UserStore) loadCacheInternal() error { // Acquire shared lock for reading (allows multiple readers, blocks writers) lockFile, err := s.acquireFileLock(false) if err != nil { logger.Warn("Failed to acquire read lock on user store: %v", err) // Continue without lock - degraded mode } else { defer s.releaseFileLock(lockFile) } // Read encrypted store file encryptedData, err := os.ReadFile(s.filePath) if err != nil { if os.IsNotExist(err) { logger.Info("No existing user store found, starting fresh") return nil } logger.Error("Failed to read store file: %v", err) return err } // Decrypt with server key decryptedData, err := s.crypto.DecryptWithServerKey(encryptedData) if err != nil { logger.Error("Failed to decrypt store: %v", err) return err } var store encryptedStore if err := json.Unmarshal(decryptedData, &store); err != nil { logger.Error("Failed to unmarshal store: %v", err) return err } // Load into cache for i := range store.Users { s.cache[store.Users[i].UserIDHash] = &store.Users[i] } logger.Info("Loaded %d encrypted user entries into cache", len(s.cache)) return nil } func (s *UserStore) save() error { // Acquire exclusive lock for writing (blocks all readers and writers) lockFile, err := s.acquireFileLock(true) if err != nil { logger.Error("Failed to acquire write lock on user store: %v", err) return err } defer s.releaseFileLock(lockFile) // Build store structure from cache store := encryptedStore{ Users: make([]encryptedUserEntry, 0, len(s.cache)), } for _, entry := range s.cache { store.Users = append(store.Users, *entry) } // Marshal to JSON storeData, err := json.Marshal(store) if err != nil { return err } // Encrypt with server key encryptedData, err := s.crypto.EncryptWithServerKey(storeData) if err != nil { logger.Error("Failed to encrypt store: %v", err) return err } // Write to temp file first for atomic operation tempPath := s.filePath + ".tmp" logger.Info("Saving user store with %d entries", len(s.cache)) if err := os.WriteFile(tempPath, encryptedData, 0600); err != nil { logger.Error("Failed to write temp store file: %v", err) return err } // Atomic rename if err := os.Rename(tempPath, s.filePath); err != nil { logger.Error("Failed to rename store file: %v", err) os.Remove(tempPath) return err } return nil } func (s *UserStore) GetUser(userID string) (*User, error) { s.mu.RLock() defer s.mu.RUnlock() userHash := s.hashUserID(userID) entry, exists := s.cache[userHash] if !exists { logger.Warn("User not found in cache") return nil, nil } // Decrypt with user key (derived from user ID) userData, err := s.crypto.DecryptWithUserKey(entry.Data, userID) if err != nil { logger.Error("Failed to decrypt user data: %v", err) return nil, err } var user User if err := json.Unmarshal(userData, &user); err != nil { logger.Error("Failed to unmarshal user data: %v", err) return nil, err } logger.Info("Successfully loaded user: %s", user.ID) return &user, nil } func (s *UserStore) AddUser(userID, totpSecret string) error { s.mu.Lock() defer s.mu.Unlock() user := &User{ ID: userID, TOTPSecret: totpSecret, } // Marshal user data userData, err := json.Marshal(user) if err != nil { return err } // Encrypt with user key (derived from user ID) userEncrypted, err := s.crypto.EncryptWithUserKey(userData, userID) if err != nil { logger.Error("Failed to encrypt with user key: %v", err) return err } // Add to cache userHash := s.hashUserID(userID) s.cache[userHash] = &encryptedUserEntry{ UserIDHash: userHash, Data: userEncrypted, } // Save entire store (encrypted with server key) if err := s.save(); err != nil { return err } logger.Info("Successfully saved user: %s", userID) return nil } func generateSecret() (string, error) { b := make([]byte, 20) if _, err := rand.Read(b); err != nil { return "", err } return base32.StdEncoding.EncodeToString(b), nil }