package main import ( "net/http" "sync" "time" ) // RateLimiter implements per-IP rate limiting type RateLimiter struct { mu sync.RWMutex visitors map[string]*visitor limit int // max requests window time.Duration // time window } type visitor struct { requests []time.Time mu sync.Mutex } func NewRateLimiter(limit int, window time.Duration) *RateLimiter { rl := &RateLimiter{ visitors: make(map[string]*visitor), limit: limit, window: window, } // Cleanup old visitors every 5 minutes go func() { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() for range ticker.C { rl.cleanup() } }() return rl } func (rl *RateLimiter) getVisitor(ip string) *visitor { rl.mu.Lock() defer rl.mu.Unlock() v, exists := rl.visitors[ip] if !exists { v = &visitor{ requests: make([]time.Time, 0), } rl.visitors[ip] = v } return v } func (rl *RateLimiter) Allow(ip string) bool { v := rl.getVisitor(ip) v.mu.Lock() defer v.mu.Unlock() now := time.Now() cutoff := now.Add(-rl.window) // Remove old requests outside the time window validRequests := make([]time.Time, 0) for _, req := range v.requests { if req.After(cutoff) { validRequests = append(validRequests, req) } } v.requests = validRequests // Check if limit exceeded if len(v.requests) >= rl.limit { return false } // Add current request v.requests = append(v.requests, now) return true } func (rl *RateLimiter) cleanup() { rl.mu.Lock() defer rl.mu.Unlock() now := time.Now() cutoff := now.Add(-rl.window * 2) // Keep data for 2x window for ip, v := range rl.visitors { v.mu.Lock() if len(v.requests) == 0 || (len(v.requests) > 0 && v.requests[len(v.requests)-1].Before(cutoff)) { delete(rl.visitors, ip) } v.mu.Unlock() } } // RateLimitMiddleware wraps handlers with rate limiting func RateLimitMiddleware(rl *RateLimiter, next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ip := getIP(r) if !rl.Allow(ip) { logger.Warn("RATE_LIMIT_EXCEEDED: Too many requests from IP %s", ip) http.Error(w, "Too Many Requests", http.StatusTooManyRequests) return } next(w, r) } } // SecurityHeadersMiddleware adds security headers to all responses func SecurityHeadersMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // HSTS: Force HTTPS for 1 year, include subdomains w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload") // Prevent clickjacking w.Header().Set("X-Frame-Options", "DENY") // Prevent MIME sniffing w.Header().Set("X-Content-Type-Options", "nosniff") // XSS Protection (legacy browsers) w.Header().Set("X-XSS-Protection", "1; mode=block") // Content Security Policy // This is restrictive - adjust if you need to load external resources csp := "default-src 'self'; " + "script-src 'self' 'unsafe-inline'; " + // unsafe-inline needed for embedded scripts in templates "style-src 'self' 'unsafe-inline'; " + // unsafe-inline needed for embedded styles "img-src 'self' data:; " + "font-src 'self'; " + "connect-src 'self'; " + "frame-ancestors 'none'; " + "base-uri 'self'; " + "form-action 'self'" w.Header().Set("Content-Security-Policy", csp) // Referrer Policy w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") // Permissions Policy (formerly Feature-Policy) w.Header().Set("Permissions-Policy", "geolocation=(), microphone=(), camera=(), payment=()") next.ServeHTTP(w, r) }) } // MaxBytesMiddleware limits request body size func MaxBytesMiddleware(maxBytes int64, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxBytes) next.ServeHTTP(w, r) }) } // ValidateInput performs basic input validation and sanitization func ValidateInput(input string, maxLength int) bool { if len(input) > maxLength { return false } // Check for null bytes (security risk) for _, c := range input { if c == 0 { return false } } return true } // APIKeyAuthMiddleware validates API key from Authorization header func APIKeyAuthMiddleware(store *APIKeyStore, next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { authHeader := r.Header.Get("Authorization") // Expected format: "Bearer " if authHeader == "" { logger.Warn("API_KEY_MISSING: Request from IP %s", getIP(r)) http.Error(w, "Missing Authorization header", http.StatusUnauthorized) return } // Parse Bearer token var apiKey string if len(authHeader) > 7 && authHeader[:7] == "Bearer " { apiKey = authHeader[7:] } else { logger.Warn("API_KEY_INVALID_FORMAT: Request from IP %s", getIP(r)) http.Error(w, "Invalid Authorization header format. Use: Bearer ", http.StatusUnauthorized) return } // Validate API key key, valid := store.Validate(apiKey) if !valid { logger.Warn("API_KEY_INVALID: Failed auth from IP %s", getIP(r)) http.Error(w, "Invalid or disabled API key", http.StatusUnauthorized) return } // Record usage store.RecordUsage(apiKey) logger.Info("API_KEY_AUTH: %s (type: %s) from IP %s", key.Name, key.WorkerType, getIP(r)) next(w, r) } }