445 lines
12 KiB
Go
445 lines
12 KiB
Go
package main
|
||
|
||
import (
|
||
"crypto/ecdsa"
|
||
"crypto/elliptic"
|
||
"crypto/rand"
|
||
"crypto/tls"
|
||
"crypto/x509"
|
||
"crypto/x509/pkix"
|
||
"encoding/base64"
|
||
"encoding/json"
|
||
"encoding/pem"
|
||
"flag"
|
||
"fmt"
|
||
"io"
|
||
"log"
|
||
"math/big"
|
||
"net"
|
||
"net/http"
|
||
"os"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/pquerna/otp/totp"
|
||
"golang.org/x/crypto/acme/autocert"
|
||
)
|
||
|
||
var (
|
||
store *UserStore
|
||
sessions = struct {
|
||
sync.RWMutex
|
||
m map[string]*Session
|
||
}{m: make(map[string]*Session)}
|
||
logger *Logger
|
||
)
|
||
|
||
// Session is a server-side login session kept in the sessions map. The same
// type is used for the temporary session created once a user ID is accepted
// and for the authenticated session created after TOTP verification.
type Session struct {
	UserID    string    // ID of the user the session belongs to
	ExpiresAt time.Time // instant after which the session is treated as invalid
}
|
||
|
||
func main() {
|
||
// --- FLAGS ---
|
||
addUser := flag.String("add-user", "", "Add a new user (provide user ID)")
|
||
port := flag.String("port", os.Getenv("MANAGER_PORT"), "Port to run the server on (use 443 for Let's Encrypt)")
|
||
domain := flag.String("domain", os.Getenv("DYFI_DOMAIN"), "Your dy.fi domain (e.g. example.dy.fi)")
|
||
dyfiUser := flag.String("dyfi-user", os.Getenv("DYFI_USER"), "dy.fi username (email)")
|
||
dyfiPass := flag.String("dyfi-pass", os.Getenv("DYFI_PASS"), "dy.fi password")
|
||
email := flag.String("email", os.Getenv("ACME_EMAIL"), "Email for Let's Encrypt notifications")
|
||
logFile := flag.String("log", os.Getenv("LOG_FILE"), "Path to log file for fail2ban")
|
||
|
||
flag.Parse()
|
||
|
||
logger = NewLogger(*logFile)
|
||
|
||
// --- ENCRYPTION INITIALIZATION ---
|
||
serverKey := os.Getenv("SERVER_KEY")
|
||
if serverKey == "" {
|
||
logger.Warn("SERVER_KEY not set, generating new key")
|
||
var err error
|
||
serverKey, err = GenerateServerKey()
|
||
if err != nil {
|
||
logger.Error("Failed to generate server key: %v", err)
|
||
log.Fatal(err)
|
||
}
|
||
fmt.Printf("\n⚠️ IMPORTANT: Save this SERVER_KEY to your environment:\n")
|
||
fmt.Printf("export SERVER_KEY=%s\n\n", serverKey)
|
||
}
|
||
|
||
crypto, err := NewCrypto(serverKey)
|
||
if err != nil {
|
||
logger.Error("Failed to initialize crypto: %v", err)
|
||
log.Fatal(err)
|
||
}
|
||
|
||
store = NewUserStore("users_data", crypto)
|
||
|
||
// --- BACKGROUND TASKS ---
|
||
// Reload user store from disk periodically
|
||
go func() {
|
||
ticker := time.NewTicker(1 * time.Minute)
|
||
for range ticker.C {
|
||
if err := store.Reload(); err != nil {
|
||
logger.Error("Failed to reload user store: %v", err)
|
||
}
|
||
}
|
||
}()
|
||
|
||
// Cleanup expired sessions
|
||
go func() {
|
||
ticker := time.NewTicker(5 * time.Minute)
|
||
for range ticker.C {
|
||
cleanupSessions()
|
||
}
|
||
}()
|
||
|
||
// dy.fi Dynamic DNS Updater
|
||
if *domain != "" && *dyfiUser != "" {
|
||
startDyfiUpdater(*domain, *dyfiUser, *dyfiPass)
|
||
}
|
||
|
||
// --- CLI COMMANDS ---
|
||
if *addUser != "" {
|
||
handleNewUser(*addUser)
|
||
return
|
||
}
|
||
|
||
// --- TEMPLATE LOADING ---
|
||
tmpl, err := LoadTemplate()
|
||
if err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
appTmpl, err := LoadAppTemplate()
|
||
if err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
|
||
// --- ROUTES ---
|
||
// Routes must be defined BEFORE the server starts
|
||
|
||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||
if session := getValidSession(r, crypto); session != nil {
|
||
http.Redirect(w, r, "/app", http.StatusSeeOther)
|
||
return
|
||
}
|
||
tmpl.Execute(w, map[string]interface{}{"Step2": false})
|
||
})
|
||
|
||
http.HandleFunc("/app", func(w http.ResponseWriter, r *http.Request) {
|
||
session := getValidSession(r, crypto)
|
||
if session == nil {
|
||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||
return
|
||
}
|
||
appTmpl.Execute(w, map[string]interface{}{"UserID": session.UserID})
|
||
})
|
||
|
||
http.HandleFunc("/logout", func(w http.ResponseWriter, r *http.Request) {
|
||
cookie, err := r.Cookie("auth_session")
|
||
if err == nil {
|
||
sessions.Lock()
|
||
delete(sessions.m, cookie.Value)
|
||
sessions.Unlock()
|
||
}
|
||
http.SetCookie(w, &http.Cookie{
|
||
Name: "auth_session",
|
||
Value: "",
|
||
Path: "/",
|
||
MaxAge: -1,
|
||
})
|
||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||
})
|
||
|
||
http.HandleFunc("/api/request", func(w http.ResponseWriter, r *http.Request) {
|
||
session := getValidSession(r, crypto)
|
||
if session == nil {
|
||
w.WriteHeader(http.StatusUnauthorized)
|
||
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
|
||
return
|
||
}
|
||
|
||
var req struct {
|
||
Method string `json:"method"`
|
||
URL string `json:"url"`
|
||
Headers map[string]string `json:"headers"`
|
||
Body string `json:"body"`
|
||
}
|
||
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
w.WriteHeader(http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
result := makeHTTPRequest(req.Method, req.URL, req.Headers, req.Body)
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(result)
|
||
})
|
||
|
||
http.HandleFunc("/verify-user", func(w http.ResponseWriter, r *http.Request) {
|
||
userID := strings.TrimSpace(r.FormValue("userid"))
|
||
user, err := store.GetUser(userID)
|
||
if err != nil || user == nil {
|
||
// FAIL2BAN TRIGGER
|
||
logger.Warn("AUTH_FAILURE: User not found: %s from IP %s", userID, getIP(r))
|
||
tmpl.Execute(w, map[string]interface{}{"Step2": false, "Error": "User not found"})
|
||
return
|
||
}
|
||
|
||
sessionID := fmt.Sprintf("%d", time.Now().UnixNano())
|
||
sessions.Lock()
|
||
sessions.m[sessionID] = &Session{
|
||
UserID: userID,
|
||
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||
}
|
||
sessions.Unlock()
|
||
|
||
http.SetCookie(w, &http.Cookie{
|
||
Name: "temp_session",
|
||
Value: sessionID,
|
||
Path: "/",
|
||
HttpOnly: true,
|
||
Secure: true,
|
||
SameSite: http.SameSiteStrictMode,
|
||
})
|
||
tmpl.Execute(w, map[string]interface{}{"Step2": true})
|
||
})
|
||
|
||
http.HandleFunc("/verify-totp", func(w http.ResponseWriter, r *http.Request) {
|
||
cookie, err := r.Cookie("temp_session")
|
||
if err != nil {
|
||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
sessions.RLock()
|
||
session, ok := sessions.m[cookie.Value]
|
||
sessions.RUnlock()
|
||
|
||
if !ok || time.Now().After(session.ExpiresAt) {
|
||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||
return
|
||
}
|
||
|
||
// Get the user from the store and the TOTP code from the form
|
||
user, _ := store.GetUser(session.UserID)
|
||
totpCode := strings.TrimSpace(r.FormValue("totp"))
|
||
|
||
// Validate the TOTP code
|
||
if !totp.Validate(totpCode, user.TOTPSecret) {
|
||
// --- FAIL2BAN TRIGGER ---
|
||
logger.Warn("AUTH_FAILURE: Invalid TOTP for user %s from IP %s", session.UserID, getIP(r))
|
||
|
||
tmpl.Execute(w, map[string]interface{}{"Step2": true, "Error": "Invalid TOTP code"})
|
||
return
|
||
}
|
||
|
||
sessions.Lock()
|
||
delete(sessions.m, cookie.Value)
|
||
|
||
// Create a new long-lived authenticated session (1 hour)
|
||
authSessionID := fmt.Sprintf("%d", time.Now().UnixNano())
|
||
sessions.m[authSessionID] = &Session{
|
||
UserID: session.UserID,
|
||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||
}
|
||
sessions.Unlock()
|
||
|
||
encryptedSession, _ := crypto.EncryptWithServerKey([]byte(authSessionID))
|
||
|
||
http.SetCookie(w, &http.Cookie{
|
||
Name: "auth_session",
|
||
Value: base64.StdEncoding.EncodeToString(encryptedSession),
|
||
Path: "/",
|
||
HttpOnly: true,
|
||
Secure: true,
|
||
SameSite: http.SameSiteStrictMode,
|
||
MaxAge: 3600,
|
||
})
|
||
|
||
// Redirect to the main application
|
||
http.Redirect(w, r, "/app", http.StatusSeeOther)
|
||
})
|
||
|
||
// --- SERVER STARTUP ---
|
||
|
||
if *domain != "" {
|
||
// Let's Encrypt / ACME Setup
|
||
certManager := autocert.Manager{
|
||
Prompt: autocert.AcceptTOS,
|
||
HostPolicy: autocert.HostWhitelist(*domain),
|
||
Cache: autocert.DirCache("certs_cache"),
|
||
Email: *email,
|
||
}
|
||
|
||
// Let's Encrypt requires port 80 for the HTTP-01 challenge
|
||
go func() {
|
||
logger.Info("Starting HTTP-01 challenge listener on port 80")
|
||
// This handler automatically redirects HTTP to HTTPS while solving challenges
|
||
log.Fatal(http.ListenAndServe(":80", certManager.HTTPHandler(nil)))
|
||
}()
|
||
|
||
server := &http.Server{
|
||
Addr: ":" + *port,
|
||
TLSConfig: certManager.TLSConfig(),
|
||
}
|
||
|
||
logger.Info("Secure Server starting with Let's Encrypt on https://%s", *domain)
|
||
log.Fatal(server.ListenAndServeTLS("", "")) // Certs provided by autocert
|
||
} else {
|
||
// Fallback to Self-Signed Certs
|
||
certFile, keyFile, err := setupCerts()
|
||
if err != nil {
|
||
logger.Error("TLS Setup Error: %v", err)
|
||
log.Fatal(err)
|
||
}
|
||
|
||
server := &http.Server{
|
||
Addr: ":" + *port,
|
||
TLSConfig: &tls.Config{
|
||
MinVersion: tls.VersionTLS12,
|
||
},
|
||
}
|
||
|
||
logger.Info("Secure Server starting with self-signed certs on https://localhost:%s", *port)
|
||
log.Fatal(server.ListenAndServeTLS(certFile, keyFile))
|
||
}
|
||
}
|
||
|
||
// setupCerts creates self-signed certs if they don't exist
|
||
func setupCerts() (string, string, error) {
|
||
certFile, keyFile := "cert.pem", "key.pem"
|
||
if _, err := os.Stat(certFile); err == nil {
|
||
return certFile, keyFile, nil
|
||
}
|
||
|
||
logger.Info("Generating self-signed certificates for HTTPS...")
|
||
priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||
|
||
serialNumber, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||
template := x509.Certificate{
|
||
SerialNumber: serialNumber,
|
||
Subject: pkix.Name{Organization: []string{"TwoStepAuth Dev"}},
|
||
NotBefore: time.Now(),
|
||
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||
DNSNames: []string{"localhost"},
|
||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||
}
|
||
|
||
derBytes, _ := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||
|
||
certOut, _ := os.Create(certFile)
|
||
pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||
certOut.Close()
|
||
|
||
keyOut, _ := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||
privBytes, _ := x509.MarshalECPrivateKey(priv)
|
||
pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes})
|
||
keyOut.Close()
|
||
|
||
return certFile, keyFile, nil
|
||
}
|
||
|
||
func handleNewUser(userID string) {
|
||
secret, _ := generateSecret()
|
||
store.AddUser(userID, secret)
|
||
otpauthURL := fmt.Sprintf("otpauth://totp/TwoStepAuth:%s?secret=%s&issuer=TwoStepAuth", userID, secret)
|
||
fmt.Printf("\nUser: %s\nSecret: %s\n", userID, secret)
|
||
PrintQRCode(otpauthURL)
|
||
}
|
||
|
||
func getValidSession(r *http.Request, crypto *Crypto) *Session {
|
||
cookie, err := r.Cookie("auth_session")
|
||
if err != nil {
|
||
return nil
|
||
}
|
||
enc, _ := base64.StdEncoding.DecodeString(cookie.Value)
|
||
sid, err := crypto.DecryptWithServerKey(enc)
|
||
if err != nil {
|
||
return nil
|
||
}
|
||
sessions.RLock()
|
||
session, ok := sessions.m[string(sid)]
|
||
sessions.RUnlock()
|
||
if !ok || time.Now().After(session.ExpiresAt) {
|
||
return nil
|
||
}
|
||
return session
|
||
}
|
||
|
||
func cleanupSessions() {
|
||
sessions.Lock()
|
||
defer sessions.Unlock()
|
||
now := time.Now()
|
||
for id, s := range sessions.m {
|
||
if now.After(s.ExpiresAt) {
|
||
delete(sessions.m, id)
|
||
}
|
||
}
|
||
}
|
||
|
||
// getIP returns the client IP address that is written into the AUTH_FAILURE
// log lines fail2ban acts on.
//
// X-Forwarded-For is supplied by the client, and this server normally
// terminates TLS itself. The header is therefore honoured only when the
// direct peer is a loopback or private-network address, i.e. a reverse proxy
// such as Nginx on the same host or LAN. Otherwise a client connecting
// directly could put any address in the header to dodge bans or get an
// innocent IP banned. Proxies on public addresses (e.g. Cloudflare) are not
// trusted. TODO: make the trusted-proxy set configurable if such a setup is
// needed.
//
// If RemoteAddr has no port, it is returned as-is.
func getIP(r *http.Request) string {
	peer, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		peer = r.RemoteAddr
	}

	if ip := net.ParseIP(peer); ip != nil && (ip.IsLoopback() || ip.IsPrivate()) {
		if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
			// The left-most entry is the original client as seen by the
			// first proxy.
			if client := strings.TrimSpace(strings.Split(xff, ",")[0]); client != "" {
				return client
			}
		}
	}
	return peer
}
|
||
|
||
// maxProxyResponseBytes caps how much of an upstream response body
// makeHTTPRequest buffers and returns. The target URL comes from the client,
// so an arbitrarily large response must not be able to exhaust memory.
const maxProxyResponseBytes = 10 << 20 // 10 MiB

// proxyClient is shared by all makeHTTPRequest calls; http.Client is safe for
// concurrent use, so there is no need to build one per request.
var proxyClient = &http.Client{Timeout: 30 * time.Second}

// makeHTTPRequest performs an outbound HTTP request on behalf of the
// /api/request endpoint and returns a JSON-friendly summary.
//
// On success the map contains:
//   - "status": response status code (int)
//   - "headers": response headers (map[string]string), multi-valued headers joined with ", "
//   - "body": response body (string), cut off at maxProxyResponseBytes
//   - "duration": milliseconds until the response headers arrived (int64)
//   - "truncated": true, present only when the body exceeded the cap
//
// On failure the map contains "error" (string). It also contains "duration"
// when the request was actually attempted.
//
// An empty method means GET (net/http semantics). A "Host" entry in headers
// (any case) sets the request's Host, because net/http ignores a Host field
// in Request.Header.
func makeHTTPRequest(method, url string, headers map[string]string, body string) map[string]interface{} {
	var reqBody io.Reader
	if body != "" {
		reqBody = strings.NewReader(body)
	}

	req, err := http.NewRequest(method, url, reqBody)
	if err != nil {
		return map[string]interface{}{
			"error": err.Error(),
		}
	}

	for key, value := range headers {
		if strings.EqualFold(key, "Host") {
			req.Host = value
			continue
		}
		req.Header.Set(key, value)
	}

	start := time.Now()
	resp, err := proxyClient.Do(req)
	duration := time.Since(start).Milliseconds()
	if err != nil {
		return map[string]interface{}{
			"error":    err.Error(),
			"duration": duration,
		}
	}
	defer resp.Body.Close()

	// Read one byte past the cap so an over-long body can be detected.
	bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxProxyResponseBytes+1))
	if err != nil {
		return map[string]interface{}{
			"error":    err.Error(),
			"duration": duration,
		}
	}
	truncated := len(bodyBytes) > maxProxyResponseBytes
	if truncated {
		bodyBytes = bodyBytes[:maxProxyResponseBytes]
	}

	respHeaders := make(map[string]string, len(resp.Header))
	for key, values := range resp.Header {
		respHeaders[key] = strings.Join(values, ", ")
	}

	result := map[string]interface{}{
		"status":   resp.StatusCode,
		"headers":  respHeaders,
		"body":     string(bodyBytes),
		"duration": duration,
	}
	if truncated {
		result["truncated"] = true
	}
	return result
}
|