Files
ping_service/manager/main.go
2026-01-08 12:11:26 +02:00

649 lines
19 KiB
Go
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/json"
"encoding/pem"
"flag"
"fmt"
"io"
"log"
"math/big"
"net"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/pquerna/otp/totp"
"golang.org/x/crypto/acme/autocert"
)
// Package-level state shared by the HTTP handlers and the background
// goroutines started in main.
var (
	// logger receives application events, including the AUTH_FAILURE lines
	// written for fail2ban.
	logger *Logger

	// store holds the registered users and their TOTP secrets.
	store *UserStore

	// sessions maps a session ID to its Session. It holds both the short-lived
	// sessions created after the user-ID step and the authenticated sessions
	// created after TOTP verification. The embedded RWMutex guards m.
	sessions = struct {
		sync.RWMutex
		m map[string]*Session
	}{m: make(map[string]*Session)}

	// authRateLimiter is the aggressive limiter wrapped around the login
	// endpoints.
	authRateLimiter *RateLimiter
	// apiRateLimiter is the moderate limiter intended for API endpoints.
	// NOTE(review): no route registered in this file uses it; confirm whether
	// another file applies it.
	apiRateLimiter *RateLimiter
)
// Session is one server-side login session, stored in the sessions map under
// its ID. The same type is used for the pending session issued after the
// user-ID step and for the authenticated session issued after TOTP.
type Session struct {
	UserID    string    // user the session belongs to
	ExpiresAt time.Time // instant after which the session is no longer accepted
}
// main is the manager's entry point. It parses flags (most defaulting to
// environment variables) and initialises encryption and the user store. It
// then handles the -add-user CLI command, or starts the worker/gateway
// components and background tasks. Finally it registers all routes on
// http.DefaultServeMux and serves them over TLS: through Let's Encrypt
// (autocert) when -domain is set, otherwise with a local self-signed
// certificate.
func main() {
	// --- FLAGS ---
	addUser := flag.String("add-user", "", "Add a new user (provide user ID)")
	port := flag.String("port", os.Getenv("MANAGER_PORT"), "Port to run the server on (use 443 for Let's Encrypt)")
	domain := flag.String("domain", os.Getenv("DYFI_DOMAIN"), "Your dy.fi domain (e.g. example.dy.fi)")
	dyfiUser := flag.String("dyfi-user", os.Getenv("DYFI_USER"), "dy.fi username (email)")
	dyfiPass := flag.String("dyfi-pass", os.Getenv("DYFI_PASS"), "dy.fi password")
	email := flag.String("email", os.Getenv("ACME_EMAIL"), "Email for Let's Encrypt notifications")
	logFile := flag.String("log", os.Getenv("LOG_FILE"), "Path to log file for fail2ban")
	enableGateway := flag.Bool("enable-gateway", false, "Enable gateway/proxy mode for external workers")
	flag.Parse()

	logger = NewLogger(*logFile)

	// --- ENCRYPTION INITIALIZATION ---
	serverKey := os.Getenv("SERVER_KEY")
	if serverKey == "" {
		logger.Warn("SERVER_KEY not set, generating new key")
		var err error
		serverKey, err = GenerateServerKey()
		if err != nil {
			logger.Error("Failed to generate server key: %v", err)
			log.Fatal(err)
		}
		fmt.Printf("\n⚠ IMPORTANT: Save this SERVER_KEY to your environment:\n")
		fmt.Printf("export SERVER_KEY=%s\n\n", serverKey)
	}
	crypto, err := NewCrypto(serverKey)
	if err != nil {
		logger.Error("Failed to initialize crypto: %v", err)
		log.Fatal(err)
	}
	store = NewUserStore("users_data", crypto)

	// --- CLI COMMANDS ---
	// Handled before any background work starts. Adding a user needs only the
	// user store, and a one-shot command must not start the health poller or
	// push dy.fi DNS updates as a side effect.
	if *addUser != "" {
		handleNewUser(*addUser)
		return
	}

	// Initialize worker store and health poller
	workerStore = NewWorkerStore("workers_data.json")
	healthPoller = NewHealthPoller(workerStore, 60*time.Second)
	healthPoller.Start()
	logger.Info("Worker health poller started (60s interval)")

	// Initialize gateway components (if enabled)
	if *enableGateway {
		apiKeyStore = NewAPIKeyStore("apikeys_data", crypto)
		proxyManager = NewProxyManager(workerStore)
		logger.Info("Gateway mode enabled - API key auth and proxy available")
	} else {
		logger.Info("Gateway mode disabled (use --enable-gateway to enable)")
	}

	// Initialize rate limiters: aggressive for auth endpoints, moderate for API.
	authRateLimiter = NewRateLimiter(10, 1*time.Minute)
	apiRateLimiter = NewRateLimiter(100, 1*time.Minute)
	logger.Info("Rate limiters initialized (auth: 10/min, api: 100/min)")

	// --- BACKGROUND TASKS ---
	// Reload user store from disk periodically
	go func() {
		ticker := time.NewTicker(1 * time.Minute)
		for range ticker.C {
			if err := store.Reload(); err != nil {
				logger.Error("Failed to reload user store: %v", err)
			}
		}
	}()
	// Cleanup expired sessions
	go func() {
		ticker := time.NewTicker(5 * time.Minute)
		for range ticker.C {
			cleanupSessions()
		}
	}()

	// dy.fi Dynamic DNS Updater
	if *domain != "" && *dyfiUser != "" {
		startDyfiUpdater(*domain, *dyfiUser, *dyfiPass, *port)
	}

	// --- TEMPLATE LOADING ---
	tmpl, err := LoadTemplate()
	if err != nil {
		log.Fatal(err)
	}
	appTmpl, err := LoadAppTemplate()
	if err != nil {
		log.Fatal(err)
	}

	// --- ROUTE HELPERS ---
	// requireAPISession runs next only for requests that carry a valid
	// authenticated session. All other requests get a JSON 401.
	requireAPISession := func(next http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			if getValidSession(r, crypto) == nil {
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusUnauthorized)
				json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
				return
			}
			next(w, r)
		}
	}
	// requirePageSession is the browser-page counterpart: unauthenticated
	// requests are redirected to the login page.
	requirePageSession := func(next http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			if getValidSession(r, crypto) == nil {
				http.Redirect(w, r, "/", http.StatusSeeOther)
				return
			}
			next(w, r)
		}
	}
	// renderLogin executes the login template and logs template errors
	// instead of dropping them.
	renderLogin := func(w http.ResponseWriter, data map[string]interface{}) {
		if err := tmpl.Execute(w, data); err != nil {
			logger.Error("Failed to render login page: %v", err)
		}
	}

	// --- ROUTES ---
	// Routes must be defined BEFORE the server starts.

	// Public health endpoint (no auth required) for monitoring and dy.fi failover
	http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"status":"healthy"}`))
	})
	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		// "/" is a catch-all pattern in ServeMux. Only the exact root path is
		// the login page; any other unmatched path is a real 404.
		if r.URL.Path != "/" {
			http.NotFound(w, r)
			return
		}
		if getValidSession(r, crypto) != nil {
			http.Redirect(w, r, "/app", http.StatusSeeOther)
			return
		}
		renderLogin(w, map[string]interface{}{"Step2": false})
	})
	http.HandleFunc("/app", requirePageSession(func(w http.ResponseWriter, r *http.Request) {
		http.Redirect(w, r, "/dashboard", http.StatusSeeOther)
	}))
	http.HandleFunc("/dashboard", requirePageSession(handleDashboard))
	http.HandleFunc("/rest-client", func(w http.ResponseWriter, r *http.Request) {
		session := getValidSession(r, crypto)
		if session == nil {
			http.Redirect(w, r, "/", http.StatusSeeOther)
			return
		}
		if err := appTmpl.Execute(w, map[string]interface{}{"UserID": session.UserID}); err != nil {
			logger.Error("Failed to render REST client page: %v", err)
		}
	})
	http.HandleFunc("/logout", func(w http.ResponseWriter, r *http.Request) {
		// The cookie carries base64(encrypt(sessionID)), but the sessions map
		// is keyed by the plaintext ID. Decrypt first, or the server-side
		// session would survive logout.
		if cookie, err := r.Cookie("auth_session"); err == nil {
			if enc, err := base64.StdEncoding.DecodeString(cookie.Value); err == nil {
				if sid, err := crypto.DecryptWithServerKey(enc); err == nil {
					sessions.Lock()
					delete(sessions.m, string(sid))
					sessions.Unlock()
				}
			}
		}
		http.SetCookie(w, &http.Cookie{
			Name:   "auth_session",
			Value:  "",
			Path:   "/",
			MaxAge: -1,
		})
		http.Redirect(w, r, "/", http.StatusSeeOther)
	})

	// API: Worker management endpoints
	http.HandleFunc("/api/workers/list", requireAPISession(handleAPIWorkersList))
	http.HandleFunc("/api/workers/register", requireAPISession(handleAPIWorkersRegister))
	http.HandleFunc("/api/workers/remove", requireAPISession(handleAPIWorkersRemove))
	http.HandleFunc("/api/workers/get", requireAPISession(handleAPIWorkersGet))
	http.HandleFunc("/api/request", requireAPISession(func(w http.ResponseWriter, r *http.Request) {
		var req struct {
			Method  string            `json:"method"`
			URL     string            `json:"url"`
			Headers map[string]string `json:"headers"`
			Body    string            `json:"body"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		result := makeHTTPRequest(req.Method, req.URL, req.Headers, req.Body)
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(result)
	}))

	// Gateway endpoints (API key auth) - only if gateway is enabled
	if *enableGateway {
		http.HandleFunc("/api/gateway/target", APIKeyAuthMiddleware(apiKeyStore, handleGatewayTarget))
		http.HandleFunc("/api/gateway/result", APIKeyAuthMiddleware(apiKeyStore, handleGatewayResult))
		http.HandleFunc("/api/gateway/stats", requireAPISession(handleGatewayStats))
		// API key management endpoints (TOTP-authenticated session required).
		// NOTE(review): there is no role check here, so any logged-in user can
		// manage API keys; confirm that is intended.
		http.HandleFunc("/api/apikeys/generate", requireAPISession(handleAPIKeyGenerate))
		http.HandleFunc("/api/apikeys/list", requireAPISession(handleAPIKeyList))
		http.HandleFunc("/api/apikeys/revoke", requireAPISession(handleAPIKeyRevoke))
		logger.Info("Gateway routes registered")
	}

	// Login step 1: check the user ID and issue a short-lived pending session.
	http.HandleFunc("/verify-user", RateLimitMiddleware(authRateLimiter, func(w http.ResponseWriter, r *http.Request) {
		userID := strings.TrimSpace(r.FormValue("userid"))
		// Input validation
		if !ValidateInput(userID, 100) {
			logger.Warn("AUTH_FAILURE: Invalid user ID format from IP %s", getIP(r))
			renderLogin(w, map[string]interface{}{"Step2": false, "Error": "Invalid input"})
			return
		}
		user, err := store.GetUser(userID)
		if err != nil || user == nil {
			// FAIL2BAN TRIGGER
			logger.Warn("AUTH_FAILURE: User not found: %s from IP %s", userID, getIP(r))
			renderLogin(w, map[string]interface{}{"Step2": false, "Error": "User not found"})
			return
		}
		sessionID, err := newRandomSessionID()
		if err != nil {
			logger.Error("Failed to create session: %v", err)
			http.Error(w, "Internal server error", http.StatusInternalServerError)
			return
		}
		sessions.Lock()
		sessions.m[sessionID] = &Session{
			UserID:    userID,
			ExpiresAt: time.Now().Add(5 * time.Minute),
		}
		sessions.Unlock()
		http.SetCookie(w, &http.Cookie{
			Name:     "temp_session",
			Value:    sessionID,
			Path:     "/",
			MaxAge:   300, // matches the 5-minute server-side expiry
			HttpOnly: true,
			Secure:   true,
			SameSite: http.SameSiteStrictMode,
		})
		renderLogin(w, map[string]interface{}{"Step2": true})
	}))

	// Login step 2: verify the TOTP code and swap the pending session for an
	// authenticated one.
	http.HandleFunc("/verify-totp", RateLimitMiddleware(authRateLimiter, func(w http.ResponseWriter, r *http.Request) {
		cookie, err := r.Cookie("temp_session")
		if err != nil {
			http.Redirect(w, r, "/", http.StatusSeeOther)
			return
		}
		sessions.RLock()
		session, ok := sessions.m[cookie.Value]
		sessions.RUnlock()
		if !ok || time.Now().After(session.ExpiresAt) {
			http.Redirect(w, r, "/", http.StatusSeeOther)
			return
		}
		// The user may no longer exist (e.g. removed before a store reload).
		// Send them back to step 1 instead of dereferencing a nil user.
		user, err := store.GetUser(session.UserID)
		if err != nil || user == nil {
			logger.Warn("User %s no longer available during TOTP step (IP %s)", session.UserID, getIP(r))
			sessions.Lock()
			delete(sessions.m, cookie.Value)
			sessions.Unlock()
			http.Redirect(w, r, "/", http.StatusSeeOther)
			return
		}
		totpCode := strings.TrimSpace(r.FormValue("totp"))
		// Input validation for TOTP code
		if !ValidateInput(totpCode, 10) {
			logger.Warn("AUTH_FAILURE: Invalid TOTP format for user %s from IP %s", session.UserID, getIP(r))
			renderLogin(w, map[string]interface{}{"Step2": true, "Error": "Invalid input"})
			return
		}
		if !totp.Validate(totpCode, user.TOTPSecret) {
			// --- FAIL2BAN TRIGGER ---
			logger.Warn("AUTH_FAILURE: Invalid TOTP for user %s from IP %s", session.UserID, getIP(r))
			renderLogin(w, map[string]interface{}{"Step2": true, "Error": "Invalid TOTP code"})
			return
		}

		// Build the authenticated session (1 hour) before touching the map, so
		// a failure here leaves the pending login intact.
		authSessionID, err := newRandomSessionID()
		if err != nil {
			logger.Error("Failed to create session: %v", err)
			http.Error(w, "Internal server error", http.StatusInternalServerError)
			return
		}
		encryptedSession, err := crypto.EncryptWithServerKey([]byte(authSessionID))
		if err != nil {
			logger.Error("Failed to encrypt session: %v", err)
			http.Error(w, "Internal server error", http.StatusInternalServerError)
			return
		}
		sessions.Lock()
		delete(sessions.m, cookie.Value) // the pending session is single-use
		sessions.m[authSessionID] = &Session{
			UserID:    session.UserID,
			ExpiresAt: time.Now().Add(1 * time.Hour),
		}
		sessions.Unlock()

		// Drop the spent temp cookie and issue the authenticated one.
		http.SetCookie(w, &http.Cookie{
			Name:   "temp_session",
			Value:  "",
			Path:   "/",
			MaxAge: -1,
		})
		http.SetCookie(w, &http.Cookie{
			Name:     "auth_session",
			Value:    base64.StdEncoding.EncodeToString(encryptedSession),
			Path:     "/",
			HttpOnly: true,
			Secure:   true,
			SameSite: http.SameSiteStrictMode,
			MaxAge:   3600,
		})
		// Redirect to the main application
		http.Redirect(w, r, "/app", http.StatusSeeOther)
	}))

	// --- SERVER STARTUP ---
	listenPort := *port
	if listenPort == "" {
		// Addr ":" would make net/http listen on a random ephemeral port.
		listenPort = "443"
	}
	// Security headers and a 10MB request-size cap wrap every route.
	baseHandler := SecurityHeadersMiddleware(
		MaxBytesMiddleware(10*1024*1024, http.DefaultServeMux),
	)

	if *domain != "" {
		// Let's Encrypt / ACME Setup
		certManager := autocert.Manager{
			Prompt:     autocert.AcceptTOS,
			HostPolicy: autocert.HostWhitelist(*domain),
			Cache:      autocert.DirCache("certs_cache"),
			Email:      *email,
		}
		// Let's Encrypt requires port 80 for the HTTP-01 challenge
		go func() {
			logger.Info("Starting HTTP-01 challenge listener on port 80")
			// This handler answers challenges and redirects other HTTP traffic to HTTPS
			log.Fatal(http.ListenAndServe(":80", certManager.HTTPHandler(nil)))
		}()
		server := newHardenedServer(":"+listenPort, baseHandler, hardenTLSConfig(certManager.TLSConfig()))
		logger.Info("Secure Server starting with Let's Encrypt on https://%s", *domain)
		logger.Info("Security: Rate limiting enabled, headers hardened, timeouts configured")
		log.Fatal(server.ListenAndServeTLS("", "")) // Certs provided by autocert
	} else {
		// Fallback to Self-Signed Certs
		certFile, keyFile, err := setupCerts()
		if err != nil {
			logger.Error("TLS Setup Error: %v", err)
			log.Fatal(err)
		}
		server := newHardenedServer(":"+listenPort, baseHandler, hardenTLSConfig(&tls.Config{}))
		logger.Info("Secure Server starting with self-signed certs on https://localhost:%s", listenPort)
		logger.Info("Security: Rate limiting enabled, headers hardened, timeouts configured")
		log.Fatal(server.ListenAndServeTLS(certFile, keyFile))
	}
}

// newRandomSessionID returns a 256-bit session identifier read from
// crypto/rand, encoded as unpadded URL-safe base64 so it is a valid cookie
// value. Session IDs must be unguessable; a timestamp can be estimated by
// anyone who knows roughly when a login happened.
func newRandomSessionID() (string, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", fmt.Errorf("reading random bytes: %w", err)
	}
	return base64.RawURLEncoding.EncodeToString(b), nil
}

// hardenTLSConfig requires TLS 1.2 or newer on cfg and restricts TLS 1.2 to
// ECDHE key exchange with AEAD ciphers. TLS 1.3 suites are not configurable in
// crypto/tls. The deprecated PreferServerCipherSuites field is omitted because
// it has been ignored since Go 1.18. It mutates and returns cfg.
func hardenTLSConfig(cfg *tls.Config) *tls.Config {
	cfg.MinVersion = tls.VersionTLS12
	cfg.CipherSuites = []uint16{
		tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
		tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
		tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
		tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
		tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
		tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
	}
	return cfg
}

// newHardenedServer builds the HTTPS server shared by both startup modes. It
// sets timeouts that bound slow clients (ReadHeaderTimeout protects against
// slowloris) and a 1MB cap on request headers.
func newHardenedServer(addr string, handler http.Handler, tlsConfig *tls.Config) *http.Server {
	return &http.Server{
		Addr:              addr,
		Handler:           handler,
		TLSConfig:         tlsConfig,
		ReadTimeout:       15 * time.Second,  // reading request headers + body
		WriteTimeout:      30 * time.Second,  // writing the response
		IdleTimeout:       120 * time.Second, // keep-alive idle time
		ReadHeaderTimeout: 5 * time.Second,
		MaxHeaderBytes:    1 << 20, // 1MB
	}
}
// setupCerts returns the certificate and private-key paths used for HTTPS
// when no -domain is configured.
//
// If both cert.pem and key.pem already exist they are reused as-is.
// Otherwise it generates a new self-signed ECDSA P-256 certificate for
// "localhost" and 127.0.0.1, valid for one year, and writes both files; the
// key is written with 0600 permissions.
//
// Every generation and write step returns its error. The caller then fails
// with a clear message instead of failing later on a missing or truncated
// file.
func setupCerts() (string, string, error) {
	certFile, keyFile := "cert.pem", "key.pem"

	// Reuse an existing pair only when both halves are present.
	_, certErr := os.Stat(certFile)
	_, keyErr := os.Stat(keyFile)
	if certErr == nil && keyErr == nil {
		return certFile, keyFile, nil
	}

	logger.Info("Generating self-signed certificates for HTTPS...")

	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return "", "", fmt.Errorf("generating ECDSA key: %w", err)
	}
	serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return "", "", fmt.Errorf("generating certificate serial number: %w", err)
	}
	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject:      pkix.Name{Organization: []string{"TwoStepAuth Dev"}},
		NotBefore:    time.Now(),
		NotAfter:     time.Now().Add(365 * 24 * time.Hour),
		// KeyEncipherment only applies to RSA keys; an ECDSA server key
		// needs DigitalSignature alone.
		KeyUsage:    x509.KeyUsageDigitalSignature,
		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		DNSNames:    []string{"localhost"},
		IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
	}
	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
	if err != nil {
		return "", "", fmt.Errorf("creating certificate: %w", err)
	}
	privBytes, err := x509.MarshalECPrivateKey(priv)
	if err != nil {
		return "", "", fmt.Errorf("marshaling private key: %w", err)
	}

	// writePEM writes a single PEM block to path and reports Close errors,
	// which are where buffered write failures surface.
	writePEM := func(path string, perm os.FileMode, block *pem.Block) error {
		f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
		if err != nil {
			return err
		}
		if err := pem.Encode(f, block); err != nil {
			f.Close()
			return err
		}
		return f.Close()
	}

	if err := writePEM(keyFile, 0600, &pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes}); err != nil {
		return "", "", fmt.Errorf("writing %s: %w", keyFile, err)
	}
	if err := writePEM(certFile, 0644, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
		return "", "", fmt.Errorf("writing %s: %w", certFile, err)
	}
	return certFile, keyFile, nil
}
// handleNewUser implements the -add-user CLI command. It generates a TOTP
// secret for userID, stores the user, and prints the secret together with an
// otpauth:// QR code for enrolling an authenticator app.
//
// If secret generation fails, the process exits instead of registering the
// user with whatever value accompanied the error.
func handleNewUser(userID string) {
	secret, err := generateSecret()
	if err != nil {
		logger.Error("Failed to generate TOTP secret for %s: %v", userID, err)
		log.Fatal(err)
	}
	// NOTE(review): any result from AddUser is not checked here; confirm
	// whether it can fail (duplicate user, write error) and surface it.
	store.AddUser(userID, secret)
	// NOTE(review): userID goes into the otpauth label unescaped, so IDs
	// containing spaces, '?', '&' or '#' yield a malformed URI; consider
	// url.PathEscape.
	otpauthURL := fmt.Sprintf("otpauth://totp/TwoStepAuth:%s?secret=%s&issuer=TwoStepAuth", userID, secret)
	fmt.Printf("\nUser: %s\nSecret: %s\n", userID, secret)
	PrintQRCode(otpauthURL)
}
// getValidSession resolves the request's "auth_session" cookie to a live
// authenticated session, or returns nil.
//
// The cookie value is the standard-base64 form of the session ID encrypted
// with the server key. It is decoded, decrypted and looked up in the
// sessions map. A session whose ExpiresAt has passed is rejected; expired
// entries stay in the map until cleanupSessions purges them.
func getValidSession(r *http.Request, crypto *Crypto) *Session {
	cookie, err := r.Cookie("auth_session")
	if err != nil {
		return nil
	}
	// A value that is not valid base64 was not issued by this server. Reject
	// it here rather than passing a partially decoded value to decryption.
	enc, err := base64.StdEncoding.DecodeString(cookie.Value)
	if err != nil || len(enc) == 0 {
		return nil
	}
	sid, err := crypto.DecryptWithServerKey(enc)
	if err != nil {
		return nil
	}
	sessions.RLock()
	session, ok := sessions.m[string(sid)]
	sessions.RUnlock()
	if !ok || time.Now().After(session.ExpiresAt) {
		return nil
	}
	return session
}
// cleanupSessions removes every session, pending or authenticated, whose
// expiry instant lies in the past. The write lock is held for the whole
// sweep and released via defer.
func cleanupSessions() {
	sessions.Lock()
	defer sessions.Unlock()

	cutoff := time.Now()
	for sid, sess := range sessions.m {
		if !sess.ExpiresAt.Before(cutoff) {
			continue // still valid
		}
		delete(sessions.m, sid)
	}
}
// getIP returns the client address recorded in AUTH_FAILURE log lines.
//
// It prefers the first (left-most) entry of X-Forwarded-For, trimmed of
// whitespace. Otherwise it uses the host part of RemoteAddr, or RemoteAddr
// itself when it carries no port.
//
// SECURITY NOTE(review): X-Forwarded-For is client-controlled. It is only
// trustworthy behind a proxy that overwrites it. When this process
// terminates TLS itself, any client can put an arbitrary address here and
// evade fail2ban or get another IP banned. Consider honouring the header
// only when RemoteAddr is a known proxy.
func getIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	// RemoteAddr without a port: return it unchanged rather than "".
	return r.RemoteAddr
}
func makeHTTPRequest(method, url string, headers map[string]string, body string) map[string]interface{} {
client := &http.Client{Timeout: 30 * time.Second}
var reqBody io.Reader
if body != "" {
reqBody = strings.NewReader(body)
}
req, err := http.NewRequest(method, url, reqBody)
if err != nil {
return map[string]interface{}{
"error": err.Error(),
}
}
for key, value := range headers {
req.Header.Set(key, value)
}
start := time.Now()
resp, err := client.Do(req)
duration := time.Since(start).Milliseconds()
if err != nil {
return map[string]interface{}{
"error": err.Error(),
"duration": duration,
}
}
defer resp.Body.Close()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return map[string]interface{}{
"error": err.Error(),
"duration": duration,
}
}
respHeaders := make(map[string]string)
for key, values := range resp.Header {
respHeaders[key] = strings.Join(values, ", ")
}
return map[string]interface{}{
"status": resp.StatusCode,
"headers": respHeaders,
"body": string(bodyBytes),
"duration": duration,
}
}