382 lines
9.0 KiB
Go
382 lines
9.0 KiB
Go
package main
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"os"
|
|
"path/filepath"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// generateTestServerKey creates a test server key for crypto operations
|
|
func generateTestServerKey() string {
|
|
key := make([]byte, 32)
|
|
rand.Read(key)
|
|
return base64.StdEncoding.EncodeToString(key)
|
|
}
|
|
|
|
// TestFileLockingBasic verifies file locking works
|
|
func TestFileLockingBasic(t *testing.T) {
|
|
tempDir, err := os.MkdirTemp("", "store_lock_test")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
// Create test crypto instance
|
|
crypto, err := NewCrypto(generateTestServerKey())
|
|
if err != nil {
|
|
t.Fatalf("Failed to create crypto: %v", err)
|
|
}
|
|
|
|
store := NewUserStore(tempDir, crypto)
|
|
|
|
// Acquire read lock
|
|
lockFile, err := store.acquireFileLock(false)
|
|
if err != nil {
|
|
t.Fatalf("Failed to acquire read lock: %v", err)
|
|
}
|
|
|
|
if lockFile == nil {
|
|
t.Error("Lock file should not be nil")
|
|
}
|
|
|
|
// Release lock
|
|
store.releaseFileLock(lockFile)
|
|
}
|
|
|
|
// TestFileLockingExclusiveBlocksReaders verifies exclusive lock blocks readers
|
|
func TestFileLockingExclusiveBlocksReaders(t *testing.T) {
|
|
tempDir, err := os.MkdirTemp("", "store_exclusive_test")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
crypto, err := NewCrypto(generateTestServerKey())
|
|
if err != nil {
|
|
t.Fatalf("Failed to create crypto: %v", err)
|
|
}
|
|
|
|
store := NewUserStore(tempDir, crypto)
|
|
|
|
// Acquire exclusive lock
|
|
writeLock, err := store.acquireFileLock(true)
|
|
if err != nil {
|
|
t.Fatalf("Failed to acquire write lock: %v", err)
|
|
}
|
|
defer store.releaseFileLock(writeLock)
|
|
|
|
// Try to acquire read lock (should fail/timeout quickly)
|
|
done := make(chan bool)
|
|
go func() {
|
|
readLock, err := store.acquireFileLock(false)
|
|
if err == nil {
|
|
store.releaseFileLock(readLock)
|
|
t.Error("Read lock should have been blocked by write lock")
|
|
}
|
|
done <- true
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
// Expected - read lock was blocked
|
|
case <-time.After(2 * time.Second):
|
|
t.Error("Read lock acquisition took too long")
|
|
}
|
|
}
|
|
|
|
// TestFileLockingMultipleReaders verifies multiple readers can coexist
|
|
func TestFileLockingMultipleReaders(t *testing.T) {
|
|
tempDir, err := os.MkdirTemp("", "store_multi_read_test")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
crypto, err := NewCrypto(generateTestServerKey())
|
|
if err != nil {
|
|
t.Fatalf("Failed to create crypto: %v", err)
|
|
}
|
|
|
|
store := NewUserStore(tempDir, crypto)
|
|
|
|
// Acquire first read lock
|
|
lock1, err := store.acquireFileLock(false)
|
|
if err != nil {
|
|
t.Fatalf("Failed to acquire first read lock: %v", err)
|
|
}
|
|
defer store.releaseFileLock(lock1)
|
|
|
|
// Acquire second read lock (should succeed)
|
|
lock2, err := store.acquireFileLock(false)
|
|
if err != nil {
|
|
t.Fatalf("Failed to acquire second read lock: %v", err)
|
|
}
|
|
defer store.releaseFileLock(lock2)
|
|
|
|
// Both locks acquired successfully
|
|
}
|
|
|
|
// TestUserStoreAddAndGet verifies basic user storage and retrieval
|
|
func TestUserStoreAddAndGet(t *testing.T) {
|
|
tempDir, err := os.MkdirTemp("", "store_user_test")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
crypto, err := NewCrypto(generateTestServerKey())
|
|
if err != nil {
|
|
t.Fatalf("Failed to create crypto: %v", err)
|
|
}
|
|
|
|
store := NewUserStore(tempDir, crypto)
|
|
|
|
testUser := "testuser"
|
|
testSecret := "ABCDEFGHIJKLMNOP"
|
|
|
|
// Add user
|
|
if err := store.AddUser(testUser, testSecret); err != nil {
|
|
t.Fatalf("Failed to add user: %v", err)
|
|
}
|
|
|
|
// Retrieve user
|
|
user, err := store.GetUser(testUser)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get user: %v", err)
|
|
}
|
|
|
|
if user == nil {
|
|
t.Fatal("User should not be nil")
|
|
}
|
|
|
|
if user.ID != testUser {
|
|
t.Errorf("User ID mismatch: expected %s, got %s", testUser, user.ID)
|
|
}
|
|
|
|
if user.TOTPSecret != testSecret {
|
|
t.Errorf("TOTP secret mismatch: expected %s, got %s", testSecret, user.TOTPSecret)
|
|
}
|
|
}
|
|
|
|
// TestUserStoreReload verifies reload doesn't lose data
|
|
func TestUserStoreReload(t *testing.T) {
|
|
tempDir, err := os.MkdirTemp("", "store_reload_test")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
crypto, err := NewCrypto(generateTestServerKey())
|
|
if err != nil {
|
|
t.Fatalf("Failed to create crypto: %v", err)
|
|
}
|
|
|
|
store := NewUserStore(tempDir, crypto)
|
|
|
|
// Add user
|
|
if err := store.AddUser("user1", "SECRET1"); err != nil {
|
|
t.Fatalf("Failed to add user: %v", err)
|
|
}
|
|
|
|
// Reload
|
|
if err := store.Reload(); err != nil {
|
|
t.Fatalf("Failed to reload: %v", err)
|
|
}
|
|
|
|
// Verify user still exists
|
|
user, err := store.GetUser("user1")
|
|
if err != nil {
|
|
t.Fatalf("Failed to get user after reload: %v", err)
|
|
}
|
|
|
|
if user == nil {
|
|
t.Error("User should still exist after reload")
|
|
}
|
|
}
|
|
|
|
// TestUserStoreConcurrentAccess verifies thread-safe access
|
|
func TestUserStoreConcurrentAccess(t *testing.T) {
|
|
tempDir, err := os.MkdirTemp("", "store_concurrent_test")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
crypto, err := NewCrypto(generateTestServerKey())
|
|
if err != nil {
|
|
t.Fatalf("Failed to create crypto: %v", err)
|
|
}
|
|
|
|
store := NewUserStore(tempDir, crypto)
|
|
|
|
// Add initial user
|
|
if err := store.AddUser("initial", "SECRET"); err != nil {
|
|
t.Fatalf("Failed to add initial user: %v", err)
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
errors := make(chan error, 20)
|
|
|
|
// Concurrent readers
|
|
for i := 0; i < 10; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
for j := 0; j < 10; j++ {
|
|
_, err := store.GetUser("initial")
|
|
if err != nil {
|
|
errors <- err
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
// Concurrent writers
|
|
for i := 0; i < 10; i++ {
|
|
wg.Add(1)
|
|
go func(id int) {
|
|
defer wg.Done()
|
|
userID := "user" + string(rune(id))
|
|
if err := store.AddUser(userID, "SECRET"+string(rune(id))); err != nil {
|
|
errors <- err
|
|
}
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
close(errors)
|
|
|
|
if len(errors) > 0 {
|
|
for err := range errors {
|
|
t.Errorf("Concurrent access error: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestUserStorePersistence verifies data survives store recreation
|
|
func TestUserStorePersistence(t *testing.T) {
|
|
tempDir, err := os.MkdirTemp("", "store_persist_test")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
crypto, err := NewCrypto(generateTestServerKey())
|
|
if err != nil {
|
|
t.Fatalf("Failed to create crypto: %v", err)
|
|
}
|
|
|
|
// Create first store and add user
|
|
store1 := NewUserStore(tempDir, crypto)
|
|
if err := store1.AddUser("persistent", "SECRETDATA"); err != nil {
|
|
t.Fatalf("Failed to add user: %v", err)
|
|
}
|
|
|
|
// Create second store (simulating restart)
|
|
store2 := NewUserStore(tempDir, crypto)
|
|
|
|
// Retrieve user
|
|
user, err := store2.GetUser("persistent")
|
|
if err != nil {
|
|
t.Fatalf("Failed to get user from new store: %v", err)
|
|
}
|
|
|
|
if user == nil {
|
|
t.Error("User should persist across store instances")
|
|
}
|
|
|
|
if user.TOTPSecret != "SECRETDATA" {
|
|
t.Error("User data should match original")
|
|
}
|
|
}
|
|
|
|
// TestUserStoreFileExists verifies store file is created
|
|
func TestUserStoreFileExists(t *testing.T) {
|
|
tempDir, err := os.MkdirTemp("", "store_file_test")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
crypto, err := NewCrypto(generateTestServerKey())
|
|
if err != nil {
|
|
t.Fatalf("Failed to create crypto: %v", err)
|
|
}
|
|
|
|
store := NewUserStore(tempDir, crypto)
|
|
|
|
// Add user (triggers save)
|
|
if err := store.AddUser("filetest", "SECRET"); err != nil {
|
|
t.Fatalf("Failed to add user: %v", err)
|
|
}
|
|
|
|
// Verify file exists
|
|
expectedFile := filepath.Join(tempDir, "users.enc")
|
|
if _, err := os.Stat(expectedFile); os.IsNotExist(err) {
|
|
t.Error("Store file should have been created")
|
|
}
|
|
}
|
|
|
|
// TestGenerateSecret verifies TOTP secret generation
|
|
func TestGenerateSecret(t *testing.T) {
|
|
secret, err := generateSecret()
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate secret: %v", err)
|
|
}
|
|
|
|
if len(secret) == 0 {
|
|
t.Error("Generated secret should not be empty")
|
|
}
|
|
|
|
// Base32 encoded 20 bytes should be 32 characters
|
|
expectedLength := 32
|
|
if len(secret) != expectedLength {
|
|
t.Errorf("Expected secret length %d, got %d", expectedLength, len(secret))
|
|
}
|
|
|
|
// Verify two generated secrets are different
|
|
secret2, err := generateSecret()
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate second secret: %v", err)
|
|
}
|
|
|
|
if secret == secret2 {
|
|
t.Error("Generated secrets should be unique")
|
|
}
|
|
}
|
|
|
|
// TestUserHashingConsistency verifies user ID hashing is consistent
|
|
func TestUserHashingConsistency(t *testing.T) {
|
|
tempDir, err := os.MkdirTemp("", "store_hash_test")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
crypto, err := NewCrypto(generateTestServerKey())
|
|
if err != nil {
|
|
t.Fatalf("Failed to create crypto: %v", err)
|
|
}
|
|
|
|
store := NewUserStore(tempDir, crypto)
|
|
|
|
userID := "testuser"
|
|
hash1 := store.hashUserID(userID)
|
|
hash2 := store.hashUserID(userID)
|
|
|
|
if hash1 != hash2 {
|
|
t.Error("Same user ID should produce same hash")
|
|
}
|
|
|
|
// Different user should produce different hash
|
|
hash3 := store.hashUserID("differentuser")
|
|
if hash1 == hash3 {
|
|
t.Error("Different users should produce different hashes")
|
|
}
|
|
}
|