package main import ( "database/sql" "flag" "fmt" "html/template" "log" "net/http" "os" "strings" "sync" "time" "github.com/golang-jwt/jwt/v5" "github.com/pquerna/otp/totp" _ "github.com/mattn/go-sqlite3" ) const ( dbFile = "auth.db" jwtCookie = "auth_token" sessionDuration = 24 * time.Hour maxLoginAttempts = 5 rateLimitWindow = 15 * time.Minute ) var ( db *sql.DB jwtSecret string // Rate limiting loginAttempts = make(map[string]*attemptTracker) attemptsMux sync.RWMutex ) type attemptTracker struct { count int firstAttempt time.Time } type Claims struct { jwt.RegisteredClaims } func main() { generate := flag.Bool("generate", false, "Generate a new TOTP seed and add it to the database") flag.Parse() // Load JWT secret from environment jwtSecret = os.Getenv("JWT_SECRET") if jwtSecret == "" { log.Fatal("JWT_SECRET environment variable must be set!") } var err error db, err = sql.Open("sqlite3", dbFile) if err != nil { log.Fatal(err) } defer db.Close() _, err = db.Exec(`CREATE TABLE IF NOT EXISTS seeds (id INTEGER PRIMARY KEY, secret TEXT UNIQUE)`) if err != nil { log.Fatal(err) } if *generate { generateSeed() os.Exit(0) } // Check if there are any seeds, if not, generate one var count int err = db.QueryRow(`SELECT COUNT(*) FROM seeds`).Scan(&count) if err != nil { log.Fatal(err) } if count == 0 { log.Println("No seeds found, generating one...") generateSeed() } // Start cleanup goroutine for rate limiting go cleanupRateLimits() http.HandleFunc("/verify", verifyHandler) http.HandleFunc("/login", loginHandler) http.HandleFunc("/health", healthHandler) log.Println("Starting auth server on :3000") log.Fatal(http.ListenAndServe(":3000", nil)) } func generateSeed() { key, err := totp.Generate(totp.GenerateOpts{ Issuer: "ForwardAuthApp", AccountName: "user", }) if err != nil { log.Fatal(err) } secret := key.Secret() _, err = db.Exec(`INSERT INTO seeds (secret) VALUES (?)`, secret) if err != nil { log.Fatal(err) } fmt.Printf("New TOTP seed generated:\nSecret: %s\nOTPAuth URL: %s\n", secret, key.URL()) fmt.Println("Use this to set up your authenticator app.") } func cleanupRateLimits() { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() for range ticker.C { attemptsMux.Lock() now := time.Now() for ip, tracker := range loginAttempts { if now.Sub(tracker.firstAttempt) > rateLimitWindow { delete(loginAttempts, ip) } } attemptsMux.Unlock() } } func checkRateLimit(ip string) bool { attemptsMux.Lock() defer attemptsMux.Unlock() tracker, exists := loginAttempts[ip] if !exists { return true } // Reset if window expired if time.Since(tracker.firstAttempt) > rateLimitWindow { delete(loginAttempts, ip) return true } return tracker.count < maxLoginAttempts } func recordFailedAttempt(ip string) { attemptsMux.Lock() defer attemptsMux.Unlock() tracker, exists := loginAttempts[ip] if !exists { loginAttempts[ip] = &attemptTracker{ count: 1, firstAttempt: time.Now(), } return } // Reset if window expired if time.Since(tracker.firstAttempt) > rateLimitWindow { tracker.count = 1 tracker.firstAttempt = time.Now() return } tracker.count++ } func clearRateLimit(ip string) { attemptsMux.Lock() defer attemptsMux.Unlock() delete(loginAttempts, ip) } func healthHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("OK")) } func verifyHandler(w http.ResponseWriter, r *http.Request) { tokenStr, err := getCookie(r, jwtCookie) if err == nil { _, err = validateJWT(tokenStr) if err == nil { w.WriteHeader(http.StatusNoContent) // 204 return } } // Not authenticated, redirect to login next := r.Header.Get("X-Original-URI") if next == "" { next = "/" } http.Redirect(w, r, "/login?next="+next, http.StatusFound) // 302 } var loginTmpl = template.Must(template.New("login").Parse(`