Files
ping_service/manager/main.go

425 lines
11 KiB
Go
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/json"
"encoding/pem"
"flag"
"fmt"
"io"
"log"
"math/big"
"net"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/pquerna/otp/totp"
"golang.org/x/crypto/acme/autocert"
)
// store is the process-wide user store. It stays nil until main assigns it
// after the encryption layer has been initialised.
var store *UserStore

// logger is the process-wide logger. It stays nil until main assigns it right
// after flag parsing.
var logger *Logger

// sessions is the in-memory session table, keyed by session ID. It holds both
// the short-lived sessions created after the user-ID step and the
// authenticated sessions created after the TOTP step. The embedded RWMutex
// guards m; take it for every read or write of the map.
var sessions = struct {
	sync.RWMutex
	m map[string]*Session
}{m: map[string]*Session{}}
// Session is one entry of the in-memory session table.
type Session struct {
	// UserID is the ID of the user the session was created for.
	UserID string
	// ExpiresAt is the instant after which the session is no longer accepted
	// and may be removed by the periodic cleanup.
	ExpiresAt time.Time
}
// main parses the command-line flags, initialises encryption and the user
// store, starts the background maintenance goroutines, registers the HTTP
// routes and finally serves HTTPS: with Let's Encrypt certificates when a
// domain is configured, otherwise with a locally generated self-signed pair.
//
// With -add-user it creates the user via handleNewUser and returns without
// starting the server.
func main() {
	// --- FLAGS ---
	addUser := flag.String("add-user", "", "Add a new user (provide user ID)")
	port := flag.String("port", os.Getenv("MANAGER_PORT"), "Port to run the server on (use 443 for Let's Encrypt)")
	domain := flag.String("domain", os.Getenv("DYFI_DOMAIN"), "Your dy.fi domain (e.g. example.dy.fi)")
	dyfiUser := flag.String("dyfi-user", os.Getenv("DYFI_USER"), "dy.fi username (email)")
	dyfiPass := flag.String("dyfi-pass", os.Getenv("DYFI_PASS"), "dy.fi password")
	email := flag.String("email", os.Getenv("ACME_EMAIL"), "Email for Let's Encrypt notifications")
	flag.Parse()

	logger = NewLogger()

	// --- ENCRYPTION INITIALIZATION ---
	serverKey := os.Getenv("SERVER_KEY")
	if serverKey == "" {
		logger.Warn("SERVER_KEY not set, generating new key")
		var err error
		serverKey, err = GenerateServerKey()
		if err != nil {
			logger.Error("Failed to generate server key: %v", err)
			log.Fatal(err)
		}
		fmt.Printf("\n⚠ IMPORTANT: Save this SERVER_KEY to your environment:\n")
		fmt.Printf("export SERVER_KEY=%s\n\n", serverKey)
	}
	crypto, err := NewCrypto(serverKey)
	if err != nil {
		logger.Error("Failed to initialize crypto: %v", err)
		log.Fatal(err)
	}
	store = NewUserStore("users_data", crypto)

	// --- BACKGROUND TASKS ---
	// Reload user store from disk periodically. The goroutine lives as long as
	// the process, so its ticker is never stopped.
	go func() {
		ticker := time.NewTicker(1 * time.Minute)
		for range ticker.C {
			if err := store.Reload(); err != nil {
				logger.Error("Failed to reload user store: %v", err)
			}
		}
	}()
	// Cleanup expired sessions
	go func() {
		ticker := time.NewTicker(5 * time.Minute)
		for range ticker.C {
			cleanupSessions()
		}
	}()
	// dy.fi Dynamic DNS Updater
	if *domain != "" && *dyfiUser != "" {
		startDyfiUpdater(*domain, *dyfiUser, *dyfiPass)
	}

	// --- CLI COMMANDS ---
	if *addUser != "" {
		handleNewUser(*addUser)
		return
	}

	// --- TEMPLATE LOADING ---
	tmpl, err := LoadTemplate()
	if err != nil {
		log.Fatal(err)
	}
	appTmpl, err := LoadAppTemplate()
	if err != nil {
		log.Fatal(err)
	}

	// --- HANDLER HELPERS ---
	// newSessionID returns an unguessable, cookie-safe session identifier.
	// Session IDs are bearer secrets, so they must come from a CSPRNG rather
	// than from the clock.
	newSessionID := func() (string, error) {
		buf := make([]byte, 32)
		if _, err := rand.Read(buf); err != nil {
			return "", err
		}
		return base64.RawURLEncoding.EncodeToString(buf), nil
	}
	// renderLogin and renderApp execute the page templates and log a failed
	// render instead of dropping the error.
	renderLogin := func(w http.ResponseWriter, data map[string]interface{}) {
		if err := tmpl.Execute(w, data); err != nil {
			logger.Error("Failed to render login page: %v", err)
		}
	}
	renderApp := func(w http.ResponseWriter, data map[string]interface{}) {
		if err := appTmpl.Execute(w, data); err != nil {
			logger.Error("Failed to render app page: %v", err)
		}
	}
	// internalError answers with a generic 500 after logging the cause.
	internalError := func(w http.ResponseWriter, what string, err error) {
		logger.Error("%s: %v", what, err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
	}

	// --- ROUTES ---
	// Routes must be defined BEFORE the server starts
	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		if session := getValidSession(r, crypto); session != nil {
			http.Redirect(w, r, "/app", http.StatusSeeOther)
			return
		}
		renderLogin(w, map[string]interface{}{"Step2": false})
	})

	http.HandleFunc("/app", func(w http.ResponseWriter, r *http.Request) {
		session := getValidSession(r, crypto)
		if session == nil {
			http.Redirect(w, r, "/", http.StatusSeeOther)
			return
		}
		renderApp(w, map[string]interface{}{"UserID": session.UserID})
	})

	http.HandleFunc("/logout", func(w http.ResponseWriter, r *http.Request) {
		// The cookie carries the base64 of the ENCRYPTED session ID, while the
		// session table is keyed by the plaintext ID, so the value has to be
		// decoded and decrypted before the entry can be removed.
		if cookie, err := r.Cookie("auth_session"); err == nil {
			if enc, err := base64.StdEncoding.DecodeString(cookie.Value); err == nil {
				if sid, err := crypto.DecryptWithServerKey(enc); err == nil {
					sessions.Lock()
					delete(sessions.m, string(sid))
					sessions.Unlock()
				}
			}
		}
		http.SetCookie(w, &http.Cookie{
			Name:     "auth_session",
			Value:    "",
			Path:     "/",
			HttpOnly: true,
			Secure:   true,
			SameSite: http.SameSiteStrictMode,
			MaxAge:   -1,
		})
		http.Redirect(w, r, "/", http.StatusSeeOther)
	})

	http.HandleFunc("/api/request", func(w http.ResponseWriter, r *http.Request) {
		session := getValidSession(r, crypto)
		if session == nil {
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusUnauthorized)
			if err := json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"}); err != nil {
				logger.Error("Failed to write unauthorized response: %v", err)
			}
			return
		}
		var req struct {
			Method  string            `json:"method"`
			URL     string            `json:"url"`
			Headers map[string]string `json:"headers"`
			Body    string            `json:"body"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		result := makeHTTPRequest(req.Method, req.URL, req.Headers, req.Body)
		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(result); err != nil {
			logger.Error("Failed to write request result: %v", err)
		}
	})

	// Step 1 of the login: look the user up and open a short-lived session
	// that only authorises the TOTP step.
	http.HandleFunc("/verify-user", func(w http.ResponseWriter, r *http.Request) {
		userID := strings.TrimSpace(r.FormValue("userid"))
		user, err := store.GetUser(userID)
		if err != nil || user == nil {
			renderLogin(w, map[string]interface{}{"Step2": false, "Error": "User not found"})
			return
		}
		sessionID, err := newSessionID()
		if err != nil {
			internalError(w, "Failed to generate session ID", err)
			return
		}
		sessions.Lock()
		sessions.m[sessionID] = &Session{
			UserID:    userID,
			ExpiresAt: time.Now().Add(5 * time.Minute),
		}
		sessions.Unlock()
		http.SetCookie(w, &http.Cookie{
			Name:     "temp_session",
			Value:    sessionID,
			Path:     "/",
			HttpOnly: true,
			Secure:   true,
			SameSite: http.SameSiteStrictMode,
		})
		renderLogin(w, map[string]interface{}{"Step2": true})
	})

	// Step 2 of the login: check the TOTP code and trade the temporary
	// session for an authenticated one.
	http.HandleFunc("/verify-totp", func(w http.ResponseWriter, r *http.Request) {
		cookie, err := r.Cookie("temp_session")
		if err != nil {
			http.Redirect(w, r, "/", http.StatusSeeOther)
			return
		}
		sessions.RLock()
		session, ok := sessions.m[cookie.Value]
		sessions.RUnlock()
		if !ok || time.Now().After(session.ExpiresAt) {
			http.Redirect(w, r, "/", http.StatusSeeOther)
			return
		}
		// The store is reloaded from disk in the background, so the user may
		// have disappeared since step 1; never dereference a missing user.
		user, err := store.GetUser(session.UserID)
		if err != nil || user == nil {
			http.Redirect(w, r, "/", http.StatusSeeOther)
			return
		}
		totpCode := strings.TrimSpace(r.FormValue("totp"))
		if !totp.Validate(totpCode, user.TOTPSecret) {
			renderLogin(w, map[string]interface{}{"Step2": true, "Error": "Invalid TOTP code"})
			return
		}

		// Prepare everything that can fail BEFORE touching the session table,
		// so a failure leaves the temporary session intact.
		authSessionID, err := newSessionID()
		if err != nil {
			internalError(w, "Failed to generate session ID", err)
			return
		}
		encryptedSession, err := crypto.EncryptWithServerKey([]byte(authSessionID))
		if err != nil {
			internalError(w, "Failed to encrypt session ID", err)
			return
		}

		// Swap temp -> auth under one write lock and re-check that the temp
		// session still exists, so a temp session is redeemable exactly once
		// even under concurrent requests.
		sessions.Lock()
		if _, pending := sessions.m[cookie.Value]; !pending {
			sessions.Unlock()
			http.Redirect(w, r, "/", http.StatusSeeOther)
			return
		}
		delete(sessions.m, cookie.Value)
		sessions.m[authSessionID] = &Session{
			UserID:    session.UserID,
			ExpiresAt: time.Now().Add(1 * time.Hour),
		}
		sessions.Unlock()

		// The temporary cookie is spent; expire it in the browser as well.
		http.SetCookie(w, &http.Cookie{
			Name:     "temp_session",
			Value:    "",
			Path:     "/",
			HttpOnly: true,
			Secure:   true,
			SameSite: http.SameSiteStrictMode,
			MaxAge:   -1,
		})
		http.SetCookie(w, &http.Cookie{
			Name:     "auth_session",
			Value:    base64.StdEncoding.EncodeToString(encryptedSession),
			Path:     "/",
			HttpOnly: true,
			Secure:   true,
			SameSite: http.SameSiteStrictMode,
			MaxAge:   3600,
		})
		http.Redirect(w, r, "/app", http.StatusSeeOther)
	})

	// --- SERVER STARTUP ---
	if *port == "" {
		// With an empty port the listen address would be ":", for which
		// net.Listen picks an arbitrary free port.
		logger.Warn("No port configured, defaulting to 443")
		*port = "443"
	}
	// Bound how long a client may take to send its request headers. Whole
	// request/response timeouts are left unset because /api/request can
	// legitimately take as long as the outbound call it relays.
	const readHeaderTimeout = 10 * time.Second

	if *domain != "" {
		// Let's Encrypt / ACME Setup
		certManager := autocert.Manager{
			Prompt:     autocert.AcceptTOS,
			HostPolicy: autocert.HostWhitelist(*domain),
			Cache:      autocert.DirCache("certs_cache"),
			Email:      *email,
		}
		// Let's Encrypt requires port 80 for the HTTP-01 challenge
		go func() {
			logger.Info("Starting HTTP-01 challenge listener on port 80")
			// This handler automatically redirects HTTP to HTTPS while solving challenges
			challengeServer := &http.Server{
				Addr:              ":80",
				Handler:           certManager.HTTPHandler(nil),
				ReadHeaderTimeout: readHeaderTimeout,
			}
			log.Fatal(challengeServer.ListenAndServe())
		}()
		server := &http.Server{
			Addr:              ":" + *port,
			TLSConfig:         certManager.TLSConfig(),
			ReadHeaderTimeout: readHeaderTimeout,
		}
		logger.Info("Secure Server starting with Let's Encrypt on https://%s", *domain)
		log.Fatal(server.ListenAndServeTLS("", "")) // Certs provided by autocert
	} else {
		// Fallback to Self-Signed Certs
		certFile, keyFile, err := setupCerts()
		if err != nil {
			logger.Error("TLS Setup Error: %v", err)
			log.Fatal(err)
		}
		server := &http.Server{
			Addr: ":" + *port,
			TLSConfig: &tls.Config{
				MinVersion: tls.VersionTLS12,
			},
			ReadHeaderTimeout: readHeaderTimeout,
		}
		logger.Info("Secure Server starting with self-signed certs on https://localhost:%s", *port)
		log.Fatal(server.ListenAndServeTLS(certFile, keyFile))
	}
}
// setupCerts returns the paths of a self-signed certificate and its private
// key ("cert.pem" and "key.pem", relative to the working directory). When
// either file is missing it generates a fresh ECDSA P-256 key and a
// certificate valid for one year covering "localhost" and 127.0.0.1, and
// writes both files.
//
// Any failure while generating or writing the pair is returned, in which case
// both paths are empty.
func setupCerts() (string, string, error) {
	const certFile, keyFile = "cert.pem", "key.pem"

	// Reuse the pair only when BOTH halves are present. A certificate without
	// its key (or the reverse) is useless, so the pair is regenerated rather
	// than letting the TLS listener fail later.
	_, certErr := os.Stat(certFile)
	_, keyErr := os.Stat(keyFile)
	if certErr == nil && keyErr == nil {
		return certFile, keyFile, nil
	}

	logger.Info("Generating self-signed certificates for HTTPS...")

	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return "", "", fmt.Errorf("generating private key: %w", err)
	}
	// Random serial number below 2^128.
	serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return "", "", fmt.Errorf("generating serial number: %w", err)
	}

	now := time.Now()
	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject:      pkix.Name{Organization: []string{"TwoStepAuth Dev"}},
		NotBefore:    now,
		NotAfter:     now.Add(365 * 24 * time.Hour),
		KeyUsage:     x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		DNSNames:     []string{"localhost"},
		IPAddresses:  []net.IP{net.ParseIP("127.0.0.1")},
	}
	// Self-signed: the template acts as its own parent.
	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
	if err != nil {
		return "", "", fmt.Errorf("creating certificate: %w", err)
	}
	privBytes, err := x509.MarshalECPrivateKey(priv)
	if err != nil {
		return "", "", fmt.Errorf("encoding private key: %w", err)
	}

	// The key is written with owner-only permissions.
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes})
	if err := os.WriteFile(keyFile, keyPEM, 0600); err != nil {
		return "", "", fmt.Errorf("writing %s: %w", keyFile, err)
	}
	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
	if err := os.WriteFile(certFile, certPEM, 0644); err != nil {
		return "", "", fmt.Errorf("writing %s: %w", certFile, err)
	}
	return certFile, keyFile, nil
}
// handleNewUser implements the -add-user command: it generates a TOTP secret,
// stores the user under userID, and prints the user ID, the secret and a QR
// code of the matching otpauth:// enrolment URL.
//
// If no secret can be generated the process exits, so that a user is never
// stored with an empty secret.
func handleNewUser(userID string) {
	secret, err := generateSecret()
	if err != nil {
		logger.Error("Failed to generate TOTP secret: %v", err)
		log.Fatal(err)
	}
	// NOTE(review): AddUser's result, if it has one, is not visible from this
	// file; confirm whether it returns an error that should be checked here.
	store.AddUser(userID, secret)
	otpauthURL := fmt.Sprintf("otpauth://totp/TwoStepAuth:%s?secret=%s&issuer=TwoStepAuth", userID, secret)
	fmt.Printf("\nUser: %s\nSecret: %s\n", userID, secret)
	PrintQRCode(otpauthURL)
}
// getValidSession resolves the "auth_session" cookie of r to a live session.
//
// The cookie value is expected to be the standard-base64 encoding of the
// session ID encrypted with the server key. The function returns nil when the
// cookie is missing, is not valid base64, cannot be decrypted, or names a
// session that is unknown or has expired. Otherwise it returns the *Session
// held in the session table (the same pointer, not a copy).
func getValidSession(r *http.Request, crypto *Crypto) *Session {
	cookie, err := r.Cookie("auth_session")
	if err != nil {
		return nil
	}
	enc, err := base64.StdEncoding.DecodeString(cookie.Value)
	if err != nil {
		// DecodeString returns the bytes decoded before the error; reject the
		// cookie instead of handing that partial prefix to decryption.
		return nil
	}
	sid, err := crypto.DecryptWithServerKey(enc)
	if err != nil {
		return nil
	}
	sessions.RLock()
	session, ok := sessions.m[string(sid)]
	sessions.RUnlock()
	if !ok || time.Now().After(session.ExpiresAt) {
		return nil
	}
	return session
}
// cleanupSessions drops every session whose expiry time lies in the past from
// the session table. The table's write lock is held for the whole sweep.
func cleanupSessions() {
	sessions.Lock()
	defer sessions.Unlock()

	cutoff := time.Now()
	for key, entry := range sessions.m {
		if !entry.ExpiresAt.Before(cutoff) {
			continue // still valid
		}
		delete(sessions.m, key)
	}
}
// makeHTTPRequest performs one outbound HTTP request and reports the outcome
// as a JSON-friendly map.
//
// method, url and body are passed to http.NewRequest (an empty body sends no
// request body). Each entry of headers is set on the request; a "Host" entry
// (any letter case) is applied through Request.Host, because net/http ignores
// a Host value placed in the header map of a client request.
//
// On success the map holds:
//
//	"status"   int               response status code
//	"headers"  map[string]string response headers, multiple values joined with ", "
//	"body"     string            response body, capped at 10 MiB
//	"duration" int64             milliseconds until the response headers arrived
//	"truncated" bool             present (true) only when the body hit the cap
//
// On failure it holds "error" (string) and, once the request was actually
// attempted, "duration". The whole exchange is limited to 30 seconds.
func makeHTTPRequest(method, url string, headers map[string]string, body string) map[string]interface{} {
	// The target server is arbitrary, so bound how much of its response is
	// buffered in memory.
	const maxResponseBytes = 10 << 20

	client := &http.Client{Timeout: 30 * time.Second}

	var reqBody io.Reader
	if body != "" {
		reqBody = strings.NewReader(body)
	}
	req, err := http.NewRequest(method, url, reqBody)
	if err != nil {
		return map[string]interface{}{
			"error": err.Error(),
		}
	}
	for key, value := range headers {
		if strings.EqualFold(key, "Host") {
			req.Host = value
			continue
		}
		req.Header.Set(key, value)
	}

	start := time.Now()
	resp, err := client.Do(req)
	duration := time.Since(start).Milliseconds()
	if err != nil {
		return map[string]interface{}{
			"error":    err.Error(),
			"duration": duration,
		}
	}
	defer resp.Body.Close()

	// Read one byte beyond the cap so that truncation can be detected.
	bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1))
	if err != nil {
		return map[string]interface{}{
			"error":    err.Error(),
			"duration": duration,
		}
	}
	truncated := len(bodyBytes) > maxResponseBytes
	if truncated {
		bodyBytes = bodyBytes[:maxResponseBytes]
	}

	respHeaders := make(map[string]string, len(resp.Header))
	for key, values := range resp.Header {
		respHeaders[key] = strings.Join(values, ", ")
	}

	result := map[string]interface{}{
		"status":   resp.StatusCode,
		"headers":  respHeaders,
		"body":     string(bodyBytes),
		"duration": duration,
	}
	if truncated {
		result["truncated"] = true
	}
	return result
}