212 lines
5.3 KiB
Go
212 lines
5.3 KiB
Go
package main
|
|
|
|
import (
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// RateLimiter implements per-IP rate limiting with a sliding window: each key
// may have at most limit allowed requests within any window-long period.
//
// NewRateLimiter starts a background goroutine that periodically evicts idle
// visitors; call Stop to terminate it once the limiter is no longer needed.
type RateLimiter struct {
	mu       sync.RWMutex // guards visitors and every visitor's lastSeen
	visitors map[string]*visitor
	limit    int           // max requests
	window   time.Duration // time window

	stop     chan struct{} // closed by Stop to end the cleanup goroutine
	stopOnce sync.Once
}

// visitor holds the sliding-window state for a single key.
type visitor struct {
	requests []time.Time // timestamps of allowed requests, oldest first; guarded by mu
	mu       sync.Mutex

	// lastSeen is when getVisitor last returned this visitor. It is guarded by
	// RateLimiter.mu (not visitor.mu) so cleanup can decide on eviction under
	// the same lock that lookups take, and never evicts a visitor that a
	// request has just fetched but not yet locked.
	lastSeen time.Time
}

// NewRateLimiter returns a limiter that allows at most limit requests per key
// within any window-long period. A limit <= 0 denies every request.
//
// It starts a goroutine that evicts idle visitors every 5 minutes until Stop
// is called.
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
	rl := &RateLimiter{
		visitors: make(map[string]*visitor),
		limit:    limit,
		window:   window,
		stop:     make(chan struct{}),
	}

	// Cleanup old visitors every 5 minutes.
	go rl.cleanupLoop(5 * time.Minute)

	return rl
}

// cleanupLoop runs cleanup every interval until Stop closes rl.stop.
func (rl *RateLimiter) cleanupLoop(interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			rl.cleanup()
		case <-rl.stop:
			return
		}
	}
}

// Stop terminates the background cleanup goroutine. It is safe to call more
// than once. The limiter keeps answering Allow after Stop, but idle visitors
// are no longer evicted.
func (rl *RateLimiter) Stop() {
	rl.stopOnce.Do(func() {
		if rl.stop != nil {
			close(rl.stop)
		}
	})
}

// getVisitor returns the visitor for ip, creating it on first use, and stamps
// it as seen now so a concurrent cleanup will not evict it.
func (rl *RateLimiter) getVisitor(ip string) *visitor {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	v, exists := rl.visitors[ip]
	if !exists {
		v = &visitor{}
		rl.visitors[ip] = v
	}
	v.lastSeen = time.Now()
	return v
}

// Allow reports whether a request from ip is permitted now, and records it
// if so. A request is permitted when fewer than rl.limit requests from ip
// were allowed during the preceding rl.window. Denied requests are not
// recorded, so they do not extend the penalty.
func (rl *RateLimiter) Allow(ip string) bool {
	v := rl.getVisitor(ip)

	v.mu.Lock()
	defer v.mu.Unlock()

	now := time.Now()
	cutoff := now.Add(-rl.window)

	// Drop timestamps that slid out of the window. Filtering in place reuses
	// the backing array (never longer than limit) instead of allocating a
	// new slice on every request.
	kept := v.requests[:0]
	for _, ts := range v.requests {
		if ts.After(cutoff) {
			kept = append(kept, ts)
		}
	}
	v.requests = kept

	// Check if limit exceeded.
	if len(v.requests) >= rl.limit {
		return false
	}

	v.requests = append(v.requests, now)
	return true
}

// cleanup evicts visitors that have not been looked up during the last two
// windows (data is kept for 2x window). Such a visitor's recorded requests
// all predate its lastSeen, so in practice they have already aged out of the
// window and eviction discards no live rate-limit state.
func (rl *RateLimiter) cleanup() {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	cutoff := time.Now().Add(-2 * rl.window)
	for ip, v := range rl.visitors {
		// lastSeen is guarded by rl.mu, which we hold; no need for v.mu.
		if v.lastSeen.Before(cutoff) {
			delete(rl.visitors, ip)
		}
	}
}
|
|
|
|
// RateLimitMiddleware wraps handlers with rate limiting
|
|
func RateLimitMiddleware(rl *RateLimiter, next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
ip := getIP(r)
|
|
|
|
if !rl.Allow(ip) {
|
|
logger.Warn("RATE_LIMIT_EXCEEDED: Too many requests from IP %s", ip)
|
|
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
|
|
next(w, r)
|
|
}
|
|
}
|
|
|
|
// SecurityHeadersMiddleware adds a fixed set of security headers to every
// response and then delegates to next.
//
// The headers are set before next runs, so next may still override any of
// them for a particular response.
func SecurityHeadersMiddleware(next http.Handler) http.Handler {
	// Content Security Policy.
	// This is restrictive - adjust if you need to load external resources.
	const csp = "default-src 'self'; " +
		"script-src 'self' 'unsafe-inline'; " + // unsafe-inline needed for embedded scripts in templates
		"style-src 'self' 'unsafe-inline'; " + // unsafe-inline needed for embedded styles
		"img-src 'self' data:; " +
		"font-src 'self'; " +
		"connect-src 'self'; " +
		"frame-ancestors 'none'; " +
		"base-uri 'self'; " +
		"form-action 'self'"

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()

		// HSTS: force HTTPS for 1 year, include subdomains, opt in to preload lists.
		h.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")

		// Prevent clickjacking (legacy counterpart of CSP frame-ancestors).
		h.Set("X-Frame-Options", "DENY")

		// Prevent MIME sniffing.
		h.Set("X-Content-Type-Options", "nosniff")

		// Explicitly disable the legacy XSS auditor. The former value
		// "1; mode=block" is discouraged: modern browsers have removed the
		// auditor, and in older ones it could itself be abused (e.g. for
		// cross-site information leaks). XSS mitigation comes from the CSP.
		h.Set("X-XSS-Protection", "0")

		h.Set("Content-Security-Policy", csp)

		// Referrer Policy.
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")

		// Permissions Policy (formerly Feature-Policy).
		h.Set("Permissions-Policy", "geolocation=(), microphone=(), camera=(), payment=()")

		next.ServeHTTP(w, r)
	})
}
|
|
|
|
// MaxBytesMiddleware caps how much of the request body next can read. It
// replaces r.Body with an http.MaxBytesReader limited to maxBytes before
// delegating. Reading beyond the cap yields an error that next must handle.
func MaxBytesMiddleware(maxBytes int64, next http.Handler) http.Handler {
	capBody := func(w http.ResponseWriter, r *http.Request) {
		limited := http.MaxBytesReader(w, r.Body, maxBytes)
		r.Body = limited
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(capBody)
}
|
|
|
|
// ValidateInput reports whether input passes basic checks. The input must be
// at most maxLength bytes long (bytes, not characters) and must contain no
// NUL characters. It only validates: input is never modified.
func ValidateInput(input string, maxLength int) bool {
	if len(input) > maxLength {
		return false
	}

	// Reject null bytes (security risk, e.g. they can truncate strings handed
	// to C-based APIs). Checking raw bytes matches checking decoded runes:
	// 0x00 never appears inside a multi-byte UTF-8 sequence.
	for i := 0; i < len(input); i++ {
		if input[i] == 0x00 {
			return false
		}
	}
	return true
}
|
|
|
|
// APIKeyAuthMiddleware validates API key from Authorization header
|
|
func APIKeyAuthMiddleware(store *APIKeyStore, next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
authHeader := r.Header.Get("Authorization")
|
|
|
|
// Expected format: "Bearer <api-key>"
|
|
if authHeader == "" {
|
|
logger.Warn("API_KEY_MISSING: Request from IP %s", getIP(r))
|
|
http.Error(w, "Missing Authorization header", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Parse Bearer token
|
|
var apiKey string
|
|
if len(authHeader) > 7 && authHeader[:7] == "Bearer " {
|
|
apiKey = authHeader[7:]
|
|
} else {
|
|
logger.Warn("API_KEY_INVALID_FORMAT: Request from IP %s", getIP(r))
|
|
http.Error(w, "Invalid Authorization header format. Use: Bearer <api-key>", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Validate API key
|
|
key, valid := store.Validate(apiKey)
|
|
if !valid {
|
|
logger.Warn("API_KEY_INVALID: Failed auth from IP %s", getIP(r))
|
|
http.Error(w, "Invalid or disabled API key", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Record usage
|
|
store.RecordUsage(apiKey)
|
|
|
|
logger.Info("API_KEY_AUTH: %s (type: %s) from IP %s", key.Name, key.WorkerType, getIP(r))
|
|
next(w, r)
|
|
}
|
|
}
|