287 lines
6.5 KiB
Go
287 lines
6.5 KiB
Go
package main
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base32"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"os"
|
|
"path/filepath"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
)
|
|
|
|
// User is the plaintext per-user record. AddUser JSON-encodes it and encrypts
// it with the user key before caching it; GetUser decrypts and decodes it back.
// Field order is significant: it fixes the key order of the encoded JSON.
type User struct {
	ID         string `json:"id"`          // user identifier supplied by the caller
	TOTPSecret string `json:"totp_secret"` // secret passed to AddUser as totpSecret
}
|
|
|
|
type UserStore struct {
|
|
mu sync.RWMutex
|
|
filePath string
|
|
crypto *Crypto
|
|
cache map[string]*encryptedUserEntry
|
|
}
|
|
|
|
// encryptedUserEntry is one user's record as held in the cache and written to
// the store file: the hashed user ID plus the user-key ciphertext of the
// JSON-encoded User. encoding/json emits Data as base64. Field order fixes the
// key order of the encoded JSON.
type encryptedUserEntry struct {
	UserIDHash string `json:"hash"` // hex SHA-256 of the user ID
	Data       []byte `json:"data"` // user-key ciphertext
}
|
|
|
|
type encryptedStore struct {
|
|
Users []encryptedUserEntry `json:"users"`
|
|
}
|
|
|
|
func NewUserStore(dataDir string, crypto *Crypto) *UserStore {
|
|
// Create data directory if it doesn't exist
|
|
if err := os.MkdirAll(dataDir, 0700); err != nil {
|
|
logger.Error("Failed to create data directory: %v", err)
|
|
}
|
|
|
|
filePath := filepath.Join(dataDir, "users.enc")
|
|
logger.Info("Initialized user store at: %s", filePath)
|
|
|
|
store := &UserStore{
|
|
filePath: filePath,
|
|
crypto: crypto,
|
|
cache: make(map[string]*encryptedUserEntry),
|
|
}
|
|
|
|
store.loadCache()
|
|
return store
|
|
}
|
|
|
|
func (s *UserStore) hashUserID(userID string) string {
|
|
hash := sha256.Sum256([]byte(userID))
|
|
return hex.EncodeToString(hash[:])
|
|
}
|
|
|
|
// acquireFileLock attempts to acquire an exclusive lock on the store file
|
|
// Returns the file descriptor and an error if locking fails
|
|
func (s *UserStore) acquireFileLock(forWrite bool) (*os.File, error) {
|
|
lockPath := s.filePath + ".lock"
|
|
|
|
// Create lock file if it doesn't exist
|
|
lockFile, err := os.OpenFile(lockPath, os.O_CREATE|os.O_RDWR, 0600)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Try to acquire lock with timeout
|
|
lockType := syscall.LOCK_SH // Shared lock for reads
|
|
if forWrite {
|
|
lockType = syscall.LOCK_EX // Exclusive lock for writes
|
|
}
|
|
|
|
// Use non-blocking lock with retry
|
|
maxRetries := 10
|
|
for i := 0; i < maxRetries; i++ {
|
|
err = syscall.Flock(int(lockFile.Fd()), lockType|syscall.LOCK_NB)
|
|
if err == nil {
|
|
return lockFile, nil
|
|
}
|
|
if err != syscall.EWOULDBLOCK {
|
|
lockFile.Close()
|
|
return nil, err
|
|
}
|
|
// Wait and retry
|
|
time.Sleep(100 * time.Millisecond)
|
|
}
|
|
|
|
lockFile.Close()
|
|
return nil, syscall.EWOULDBLOCK
|
|
}
|
|
|
|
// releaseFileLock releases the file lock
|
|
func (s *UserStore) releaseFileLock(lockFile *os.File) {
|
|
if lockFile != nil {
|
|
syscall.Flock(int(lockFile.Fd()), syscall.LOCK_UN)
|
|
lockFile.Close()
|
|
}
|
|
}
|
|
|
|
func (s *UserStore) Reload() error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
// Clear existing cache
|
|
s.cache = make(map[string]*encryptedUserEntry)
|
|
|
|
// Reload from disk
|
|
return s.loadCacheInternal()
|
|
}
|
|
|
|
func (s *UserStore) loadCache() error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
return s.loadCacheInternal()
|
|
}
|
|
|
|
func (s *UserStore) loadCacheInternal() error {
|
|
// Acquire shared lock for reading (allows multiple readers, blocks writers)
|
|
lockFile, err := s.acquireFileLock(false)
|
|
if err != nil {
|
|
logger.Warn("Failed to acquire read lock on user store: %v", err)
|
|
// Continue without lock - degraded mode
|
|
} else {
|
|
defer s.releaseFileLock(lockFile)
|
|
}
|
|
|
|
// Read encrypted store file
|
|
encryptedData, err := os.ReadFile(s.filePath)
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
logger.Info("No existing user store found, starting fresh")
|
|
return nil
|
|
}
|
|
logger.Error("Failed to read store file: %v", err)
|
|
return err
|
|
}
|
|
|
|
// Decrypt with server key
|
|
decryptedData, err := s.crypto.DecryptWithServerKey(encryptedData)
|
|
if err != nil {
|
|
logger.Error("Failed to decrypt store: %v", err)
|
|
return err
|
|
}
|
|
|
|
var store encryptedStore
|
|
if err := json.Unmarshal(decryptedData, &store); err != nil {
|
|
logger.Error("Failed to unmarshal store: %v", err)
|
|
return err
|
|
}
|
|
|
|
// Load into cache
|
|
for i := range store.Users {
|
|
s.cache[store.Users[i].UserIDHash] = &store.Users[i]
|
|
}
|
|
|
|
logger.Info("Loaded %d encrypted user entries into cache", len(s.cache))
|
|
return nil
|
|
}
|
|
|
|
func (s *UserStore) save() error {
|
|
// Acquire exclusive lock for writing (blocks all readers and writers)
|
|
lockFile, err := s.acquireFileLock(true)
|
|
if err != nil {
|
|
logger.Error("Failed to acquire write lock on user store: %v", err)
|
|
return err
|
|
}
|
|
defer s.releaseFileLock(lockFile)
|
|
|
|
// Build store structure from cache
|
|
store := encryptedStore{
|
|
Users: make([]encryptedUserEntry, 0, len(s.cache)),
|
|
}
|
|
|
|
for _, entry := range s.cache {
|
|
store.Users = append(store.Users, *entry)
|
|
}
|
|
|
|
// Marshal to JSON
|
|
storeData, err := json.Marshal(store)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Encrypt with server key
|
|
encryptedData, err := s.crypto.EncryptWithServerKey(storeData)
|
|
if err != nil {
|
|
logger.Error("Failed to encrypt store: %v", err)
|
|
return err
|
|
}
|
|
|
|
// Write to temp file first for atomic operation
|
|
tempPath := s.filePath + ".tmp"
|
|
logger.Info("Saving user store with %d entries", len(s.cache))
|
|
if err := os.WriteFile(tempPath, encryptedData, 0600); err != nil {
|
|
logger.Error("Failed to write temp store file: %v", err)
|
|
return err
|
|
}
|
|
|
|
// Atomic rename
|
|
if err := os.Rename(tempPath, s.filePath); err != nil {
|
|
logger.Error("Failed to rename store file: %v", err)
|
|
os.Remove(tempPath)
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *UserStore) GetUser(userID string) (*User, error) {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
userHash := s.hashUserID(userID)
|
|
entry, exists := s.cache[userHash]
|
|
if !exists {
|
|
logger.Warn("User not found in cache")
|
|
return nil, nil
|
|
}
|
|
|
|
// Decrypt with user key (derived from user ID)
|
|
userData, err := s.crypto.DecryptWithUserKey(entry.Data, userID)
|
|
if err != nil {
|
|
logger.Error("Failed to decrypt user data: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
var user User
|
|
if err := json.Unmarshal(userData, &user); err != nil {
|
|
logger.Error("Failed to unmarshal user data: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
logger.Info("Successfully loaded user: %s", user.ID)
|
|
return &user, nil
|
|
}
|
|
|
|
func (s *UserStore) AddUser(userID, totpSecret string) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
user := &User{
|
|
ID: userID,
|
|
TOTPSecret: totpSecret,
|
|
}
|
|
|
|
// Marshal user data
|
|
userData, err := json.Marshal(user)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Encrypt with user key (derived from user ID)
|
|
userEncrypted, err := s.crypto.EncryptWithUserKey(userData, userID)
|
|
if err != nil {
|
|
logger.Error("Failed to encrypt with user key: %v", err)
|
|
return err
|
|
}
|
|
|
|
// Add to cache
|
|
userHash := s.hashUserID(userID)
|
|
s.cache[userHash] = &encryptedUserEntry{
|
|
UserIDHash: userHash,
|
|
Data: userEncrypted,
|
|
}
|
|
|
|
// Save entire store (encrypted with server key)
|
|
if err := s.save(); err != nil {
|
|
return err
|
|
}
|
|
|
|
logger.Info("Successfully saved user: %s", userID)
|
|
return nil
|
|
}
|
|
|
|
// generateSecret returns 20 bytes (160 bits) from crypto/rand encoded with
// standard base32. Twenty bytes encode to exactly 32 characters, so there is
// no padding. It returns an error only if the random source fails.
func generateSecret() (string, error) {
	var secret [20]byte
	if _, err := rand.Read(secret[:]); err != nil {
		return "", err
	}
	encoded := base32.StdEncoding.EncodeToString(secret[:])
	return encoded, nil
}
|