Added rate limits and some other basic security settings.
This commit is contained in:
107
main.go
107
main.go
@@ -9,6 +9,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
@@ -21,13 +22,24 @@ const (
|
||||
dbFile = "auth.db"
|
||||
jwtCookie = "auth_token"
|
||||
sessionDuration = 24 * time.Hour
|
||||
maxLoginAttempts = 5
|
||||
rateLimitWindow = 15 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
db *sql.DB
|
||||
jwtSecret string
|
||||
|
||||
// Rate limiting
|
||||
loginAttempts = make(map[string]*attemptTracker)
|
||||
attemptsMux sync.RWMutex
|
||||
)
|
||||
|
||||
type attemptTracker struct {
|
||||
count int
|
||||
firstAttempt time.Time
|
||||
}
|
||||
|
||||
type Claims struct {
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
@@ -36,11 +48,10 @@ func main() {
|
||||
generate := flag.Bool("generate", false, "Generate a new TOTP seed and add it to the database")
|
||||
flag.Parse()
|
||||
|
||||
// Load JWT secret from environment or use default (warn if using default)
|
||||
// Load JWT secret from environment
|
||||
jwtSecret = os.Getenv("JWT_SECRET")
|
||||
if jwtSecret == "" {
|
||||
jwtSecret = "change_this_secret_to_something_secure"
|
||||
log.Println("WARNING: Using default JWT secret. Set JWT_SECRET environment variable for production!")
|
||||
log.Fatal("JWT_SECRET environment variable must be set!")
|
||||
}
|
||||
|
||||
var err error
|
||||
@@ -71,6 +82,9 @@ func main() {
|
||||
generateSeed()
|
||||
}
|
||||
|
||||
// Start cleanup goroutine for rate limiting
|
||||
go cleanupRateLimits()
|
||||
|
||||
http.HandleFunc("/verify", verifyHandler)
|
||||
http.HandleFunc("/login", loginHandler)
|
||||
http.HandleFunc("/health", healthHandler)
|
||||
@@ -98,6 +112,68 @@ func generateSeed() {
|
||||
fmt.Println("Use this to set up your authenticator app.")
|
||||
}
|
||||
|
||||
func cleanupRateLimits() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
attemptsMux.Lock()
|
||||
now := time.Now()
|
||||
for ip, tracker := range loginAttempts {
|
||||
if now.Sub(tracker.firstAttempt) > rateLimitWindow {
|
||||
delete(loginAttempts, ip)
|
||||
}
|
||||
}
|
||||
attemptsMux.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func checkRateLimit(ip string) bool {
|
||||
attemptsMux.Lock()
|
||||
defer attemptsMux.Unlock()
|
||||
|
||||
tracker, exists := loginAttempts[ip]
|
||||
if !exists {
|
||||
return true
|
||||
}
|
||||
|
||||
// Reset if window expired
|
||||
if time.Since(tracker.firstAttempt) > rateLimitWindow {
|
||||
delete(loginAttempts, ip)
|
||||
return true
|
||||
}
|
||||
|
||||
return tracker.count < maxLoginAttempts
|
||||
}
|
||||
|
||||
func recordFailedAttempt(ip string) {
|
||||
attemptsMux.Lock()
|
||||
defer attemptsMux.Unlock()
|
||||
|
||||
tracker, exists := loginAttempts[ip]
|
||||
if !exists {
|
||||
loginAttempts[ip] = &attemptTracker{
|
||||
count: 1,
|
||||
firstAttempt: time.Now(),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Reset if window expired
|
||||
if time.Since(tracker.firstAttempt) > rateLimitWindow {
|
||||
tracker.count = 1
|
||||
tracker.firstAttempt = time.Now()
|
||||
return
|
||||
}
|
||||
|
||||
tracker.count++
|
||||
}
|
||||
|
||||
func clearRateLimit(ip string) {
|
||||
attemptsMux.Lock()
|
||||
defer attemptsMux.Unlock()
|
||||
delete(loginAttempts, ip)
|
||||
}
|
||||
|
||||
func healthHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
@@ -219,6 +295,18 @@ func loginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Rate limiting check
|
||||
clientIP := r.RemoteAddr
|
||||
if !checkRateLimit(clientIP) {
|
||||
data := loginData{
|
||||
Next: r.URL.Query().Get("next"),
|
||||
Error: "Too many failed attempts. Try again in 15 minutes.",
|
||||
}
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
loginTmpl.Execute(w, data)
|
||||
return
|
||||
}
|
||||
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
@@ -231,6 +319,11 @@ func loginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
next = "/"
|
||||
}
|
||||
|
||||
// Validate redirect target to prevent open redirects
|
||||
if !strings.HasPrefix(next, "/") {
|
||||
next = "/"
|
||||
}
|
||||
|
||||
if validateOTP(otpCode) {
|
||||
tokenStr, err := generateJWT()
|
||||
if err != nil {
|
||||
@@ -239,13 +332,16 @@ func loginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Clear rate limit on successful login
|
||||
clearRateLimit(clientIP)
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: jwtCookie,
|
||||
Value: tokenStr,
|
||||
Expires: time.Now().Add(sessionDuration),
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
// Secure: true, // Uncomment for HTTPS
|
||||
Secure: true,
|
||||
Path: "/",
|
||||
})
|
||||
|
||||
@@ -253,7 +349,8 @@ func loginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Invalid OTP
|
||||
// Invalid OTP - record failed attempt
|
||||
recordFailedAttempt(clientIP)
|
||||
data := loginData{Next: next, Error: "Invalid OTP"}
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
loginTmpl.Execute(w, data)
|
||||
|
||||
Reference in New Issue
Block a user