// Command auth is a small forward-auth server. GET /verify answers 204 when the
// request carries a valid session cookie and otherwise redirects to /login,
// where a user exchanges a TOTP code (checked against seeds stored in a local
// SQLite database) for a signed JWT session cookie.
package main

import (
	"crypto/rand"
	"database/sql"
	"encoding/hex"
	"errors"
	"flag"
	"fmt"
	"html/template"
	"log"
	"net/http"
	"net/url"
	"os"
	"strings"
	"time"

	"github.com/golang-jwt/jwt/v5"
	_ "github.com/mattn/go-sqlite3"
	"github.com/pquerna/otp/totp"
)

const (
	dbFile          = "auth.db"
	jwtCookie       = "auth_token"
	sessionDuration = 24 * time.Hour
)

var (
	db        *sql.DB
	jwtSecret string
)

// errInvalidToken is returned by validateJWT when the parser reports a token as
// not valid without supplying an error of its own.
var errInvalidToken = errors.New("invalid token")

// Claims is the JWT payload for a session. Only the registered claims
// (exp, iat) are populated.
type Claims struct {
	jwt.RegisteredClaims
}

func main() {
	generate := flag.Bool("generate", false, "Generate a new TOTP seed and add it to the database")
	addr := flag.String("addr", ":3000", "Address for the HTTP server to listen on")
	flag.Parse()

	jwtSecret = os.Getenv("JWT_SECRET")
	if jwtSecret == "" {
		// A hard-coded fallback would let anyone who has read the source forge
		// session tokens, so use a random per-process secret instead. The cost
		// is that sessions do not survive a restart.
		secret, err := randomSecret(32)
		if err != nil {
			log.Fatalf("generating JWT secret: %v", err)
		}
		jwtSecret = secret
		log.Println("WARNING: JWT_SECRET not set; using a random secret, so sessions will not survive a restart. Set JWT_SECRET environment variable for production!")
	}

	var err error
	db, err = sql.Open("sqlite3", dbFile)
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	if _, err = db.Exec(`CREATE TABLE IF NOT EXISTS seeds (id INTEGER PRIMARY KEY, secret TEXT UNIQUE)`); err != nil {
		log.Fatal(err)
	}

	if *generate {
		generateSeed()
		// Return rather than os.Exit so the deferred db.Close still runs.
		return
	}

	// Make sure at least one seed exists so the server is usable out of the box.
	var count int
	if err = db.QueryRow(`SELECT COUNT(*) FROM seeds`).Scan(&count); err != nil {
		log.Fatal(err)
	}
	if count == 0 {
		log.Println("No seeds found, generating one...")
		generateSeed()
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/verify", verifyHandler)
	mux.HandleFunc("/login", loginHandler)
	mux.HandleFunc("/health", healthHandler)

	// Explicit timeouts keep slow or idle clients from pinning connections.
	srv := &http.Server{
		Addr:              *addr,
		Handler:           mux,
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       10 * time.Second,
		WriteTimeout:      10 * time.Second,
		IdleTimeout:       60 * time.Second,
	}
	log.Printf("Starting auth server on %s", *addr)
	log.Fatal(srv.ListenAndServe())
}

// randomSecret returns n bytes from crypto/rand, hex-encoded.
func randomSecret(n int) (string, error) {
	b := make([]byte, n)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return hex.EncodeToString(b), nil
}

// generateSeed creates a new TOTP key, stores its secret in the seeds table and
// prints the secret and otpauth URL so it can be enrolled in an authenticator
// app. Any failure is fatal.
func generateSeed() {
	key, err := totp.Generate(totp.GenerateOpts{
		Issuer:      "ForwardAuthApp",
		AccountName: "user",
	})
	if err != nil {
		log.Fatal(err)
	}
	secret := key.Secret()
	if _, err = db.Exec(`INSERT INTO seeds (secret) VALUES (?)`, secret); err != nil {
		log.Fatal(err)
	}
	fmt.Printf("New TOTP seed generated:\nSecret: %s\nOTPAuth URL: %s\n", secret, key.URL())
	fmt.Println("Use this to set up your authenticator app.")
}

// healthHandler always answers 200 "OK".
func healthHandler(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusOK)
	// A failed write only means the client went away; nothing useful to do.
	_, _ = w.Write([]byte("OK"))
}

// verifyHandler answers 204 when the request carries a valid session cookie.
// Otherwise it redirects (302) to /login, passing the original URI from the
// X-Original-URI header as the post-login target.
func verifyHandler(w http.ResponseWriter, r *http.Request) {
	if tokenStr, err := getCookie(r, jwtCookie); err == nil {
		if _, err := validateJWT(tokenStr); err == nil {
			w.WriteHeader(http.StatusNoContent) // 204
			return
		}
	}

	// Not authenticated, redirect to login. The target must be query-escaped,
	// or a raw "&" or "#" in it would truncate or split the next parameter.
	next := safeRedirectTarget(r.Header.Get("X-Original-URI"))
	http.Redirect(w, r, "/login?next="+url.QueryEscape(next), http.StatusFound) // 302
}

// safeRedirectTarget returns next if it is a local absolute-path reference such
// as "/app?x=1", and "/" otherwise. This keeps the post-login redirect from
// pointing at another origin (an open redirect).
func safeRedirectTarget(next string) string {
	// Require exactly one leading slash: browsers treat "//host" and "/\host"
	// as scheme-relative URLs to another host.
	if !strings.HasPrefix(next, "/") || strings.HasPrefix(next, "//") || strings.HasPrefix(next, `/\`) {
		return "/"
	}
	// url.Parse rejects ASCII control characters. Browsers strip some of these
	// (e.g. "/\t/evil.example" becomes "//evil.example"), so reject them too.
	u, err := url.Parse(next)
	if err != nil || u.IsAbs() || u.Host != "" {
		return "/"
	}
	return next
}

// requestIsHTTPS reports whether the client reached us over HTTPS, either
// directly (r.TLS) or through a proxy that sets X-Forwarded-Proto.
func requestIsHTTPS(r *http.Request) bool {
	return r.TLS != nil || strings.EqualFold(r.Header.Get("X-Forwarded-Proto"), "https")
}

// loginTmpl renders the OTP entry form. It POSTs the "otp" code and the hidden
// "next" target back to /login, and shows .Error when set.
var loginTmpl = template.Must(template.New("login").Parse(`<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>Login</title>
</head>
<body>
<h1>Enter OTP</h1>
<form method="POST" action="/login">
<input type="hidden" name="next" value="{{.Next}}">
<input type="text" name="otp" inputmode="numeric" autocomplete="one-time-code" autofocus required>
<button type="submit">Login</button>
</form>
{{if .Error}}
<p role="alert">{{.Error}}</p>
{{end}}
</body>
</html>
`))

// loginData is the view model for loginTmpl.
type loginData struct {
	Next  string // post-login redirect target, round-tripped via a hidden field
	Error string // message shown above the form; empty for none
}

// renderLogin writes the login page with the given status. Template errors are
// only logged, because the status line has already been sent by then.
func renderLogin(w http.ResponseWriter, status int, data loginData) {
	w.Header().Set("Content-Type", "text/html; charset=utf-8")
	w.Header().Set("Cache-Control", "no-store")
	w.WriteHeader(status)
	if err := loginTmpl.Execute(w, data); err != nil {
		log.Printf("Error rendering login page: %v", err)
	}
}

// loginHandler serves the OTP form on GET. On POST it validates the submitted
// code; on success it sets the session cookie and redirects to the (sanitised)
// "next" target, and on failure it re-renders the form with a 401.
func loginHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodGet {
		renderLogin(w, http.StatusOK, loginData{Next: safeRedirectTarget(r.URL.Query().Get("next"))})
		return
	}
	if r.Method != http.MethodPost {
		w.Header().Set("Allow", "GET, POST")
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	if err := r.ParseForm(); err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	otpCode := strings.TrimSpace(r.FormValue("otp"))
	next := safeRedirectTarget(r.FormValue("next"))

	// NOTE(review): there is no rate limiting on failed attempts, and the
	// 6-digit code space is small enough to brute-force. Consider per-client
	// throttling or lockout in front of this handler.
	if !validateOTP(otpCode) {
		renderLogin(w, http.StatusUnauthorized, loginData{Next: next, Error: "Invalid OTP"})
		return
	}

	tokenStr, err := generateJWT()
	if err != nil {
		log.Printf("Error generating JWT: %v", err)
		http.Error(w, "Internal error", http.StatusInternalServerError)
		return
	}
	http.SetCookie(w, &http.Cookie{
		Name:     jwtCookie,
		Value:    tokenStr,
		Path:     "/",
		Expires:  time.Now().Add(sessionDuration),
		MaxAge:   int(sessionDuration / time.Second),
		HttpOnly: true,
		// Only mark the cookie Secure when the client is on HTTPS. Otherwise
		// the browser would refuse to send it back over plain HTTP.
		Secure:   requestIsHTTPS(r),
		SameSite: http.SameSiteLaxMode,
	})
	http.Redirect(w, r, next, http.StatusFound)
}

// validateOTP reports whether otpCode is accepted by totp.Validate for any
// seed stored in the database. Database errors are logged and count as a
// failed validation.
func validateOTP(otpCode string) bool {
	rows, err := db.Query(`SELECT secret FROM seeds`)
	if err != nil {
		log.Printf("Error querying seeds: %v", err)
		return false
	}
	defer rows.Close()

	for rows.Next() {
		var secret string
		if err := rows.Scan(&secret); err != nil {
			log.Printf("Error scanning seed: %v", err)
			continue
		}
		if totp.Validate(otpCode, secret) {
			return true
		}
	}
	if err := rows.Err(); err != nil {
		log.Printf("Error iterating rows: %v", err)
	}
	return false
}

// generateJWT returns an HS256-signed session token that expires after
// sessionDuration.
func generateJWT() (string, error) {
	now := time.Now()
	claims := Claims{
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(sessionDuration)),
			IssuedAt:  jwt.NewNumericDate(now),
		},
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	return token.SignedString([]byte(jwtSecret))
}

// validateJWT parses and verifies a session token. It returns the claims, or a
// non-nil error if the token is malformed, not HS256-signed with jwtSecret,
// missing its exp claim, or expired.
func validateJWT(tokenStr string) (*Claims, error) {
	claims := &Claims{}
	token, err := jwt.ParseWithClaims(tokenStr, claims,
		func(*jwt.Token) (any, error) { return []byte(jwtSecret), nil },
		// Pin the algorithm so a token cannot choose how it is verified.
		jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
		jwt.WithExpirationRequired(),
	)
	if err != nil {
		return nil, err
	}
	if !token.Valid {
		// Never return (nil, nil): callers treat a nil error as authenticated.
		return nil, errInvalidToken
	}
	return claims, nil
}

// getCookie returns the value of the named request cookie, or the error from
// r.Cookie (http.ErrNoCookie) if it is absent.
func getCookie(r *http.Request, name string) (string, error) {
	cookie, err := r.Cookie(name)
	if err != nil {
		return "", err
	}
	return cookie.Value, nil
}