commit 58a6cf205090f0220f13f2bbe59d5920809322b0 Author: Kalzu Rekku Date: Tue Nov 4 21:13:32 2025 +0200 Initial commit. diff --git a/main.go b/main.go new file mode 100644 index 0000000..7ab4c43 --- /dev/null +++ b/main.go @@ -0,0 +1,317 @@ +package main + +import ( + "database/sql" + "flag" + "fmt" + "html/template" + "log" + "net/http" + "os" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/pquerna/otp/totp" + + _ "github.com/mattn/go-sqlite3" +) + +const ( + dbFile = "auth.db" + jwtCookie = "auth_token" + sessionDuration = 24 * time.Hour +) + +var ( + db *sql.DB + jwtSecret string +) + +type Claims struct { + jwt.RegisteredClaims +} + +func main() { + generate := flag.Bool("generate", false, "Generate a new TOTP seed and add it to the database") + flag.Parse() + + // Load JWT secret from environment or use default (warn if using default) + jwtSecret = os.Getenv("JWT_SECRET") + if jwtSecret == "" { + jwtSecret = "change_this_secret_to_something_secure" + log.Println("WARNING: Using default JWT secret. Set JWT_SECRET environment variable for production!") + } + + var err error + db, err = sql.Open("sqlite3", dbFile) + if err != nil { + log.Fatal(err) + } + defer db.Close() + + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS seeds (id INTEGER PRIMARY KEY, secret TEXT UNIQUE)`) + if err != nil { + log.Fatal(err) + } + + if *generate { + generateSeed() + os.Exit(0) + } + + // Check if there are any seeds, if not, generate one + var count int + err = db.QueryRow(`SELECT COUNT(*) FROM seeds`).Scan(&count) + if err != nil { + log.Fatal(err) + } + if count == 0 { + log.Println("No seeds found, generating one...") + generateSeed() + } + + http.HandleFunc("/verify", verifyHandler) + http.HandleFunc("/login", loginHandler) + http.HandleFunc("/health", healthHandler) + + log.Println("Starting auth server on :3000") + log.Fatal(http.ListenAndServe(":3000", nil)) +} + +func generateSeed() { + key, err := totp.Generate(totp.GenerateOpts{ + Issuer: "ForwardAuthApp", + AccountName: "user", + }) + if err != nil { + log.Fatal(err) + } + + secret := key.Secret() + _, err = db.Exec(`INSERT INTO seeds (secret) VALUES (?)`, secret) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("New TOTP seed generated:\nSecret: %s\nOTPAuth URL: %s\n", secret, key.URL()) + fmt.Println("Use this to set up your authenticator app.") +} + +func healthHandler(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) +} + +func verifyHandler(w http.ResponseWriter, r *http.Request) { + tokenStr, err := getCookie(r, jwtCookie) + if err == nil { + _, err = validateJWT(tokenStr) + if err == nil { + w.WriteHeader(http.StatusNoContent) // 204 + return + } + } + + // Not authenticated, redirect to login + next := r.Header.Get("X-Original-URI") + if next == "" { + next = "/" + } + http.Redirect(w, r, "/login?next="+next, http.StatusFound) // 302 +} + +var loginTmpl = template.Must(template.New("login").Parse(` + + + + Login + + + +
+

Enter OTP

+ {{if .Error}} +
{{.Error}}
+ {{end}} +
+ + + + +
+
+ + +`)) + +type loginData struct { + Next string + Error string +} + +func loginHandler(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" { + next := r.URL.Query().Get("next") + data := loginData{Next: next} + loginTmpl.Execute(w, data) + return + } + + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + err := r.ParseForm() + if err != nil { + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + otpCode := strings.TrimSpace(r.FormValue("otp")) + next := r.FormValue("next") + if next == "" { + next = "/" + } + + if validateOTP(otpCode) { + tokenStr, err := generateJWT() + if err != nil { + log.Printf("Error generating JWT: %v", err) + http.Error(w, "Internal error", http.StatusInternalServerError) + return + } + + http.SetCookie(w, &http.Cookie{ + Name: jwtCookie, + Value: tokenStr, + Expires: time.Now().Add(sessionDuration), + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + // Secure: true, // Uncomment for HTTPS + Path: "/", + }) + + http.Redirect(w, r, next, http.StatusFound) + return + } + + // Invalid OTP + data := loginData{Next: next, Error: "Invalid OTP"} + w.WriteHeader(http.StatusUnauthorized) + loginTmpl.Execute(w, data) +} + +func validateOTP(otpCode string) bool { + rows, err := db.Query(`SELECT secret FROM seeds`) + if err != nil { + log.Printf("Error querying seeds: %v", err) + return false + } + defer rows.Close() + + for rows.Next() { + var secret string + err = rows.Scan(&secret) + if err != nil { + log.Printf("Error scanning seed: %v", err) + continue + } + if totp.Validate(otpCode, secret) { + return true + } + } + + if err = rows.Err(); err != nil { + log.Printf("Error iterating rows: %v", err) + } + + return false +} + +func generateJWT() (string, error) { + claims := Claims{ + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(sessionDuration)), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString([]byte(jwtSecret)) +} + +func validateJWT(tokenStr string) (*Claims, error) { + claims := &Claims{} + token, err := jwt.ParseWithClaims(tokenStr, claims, func(token *jwt.Token) (interface{}, error) { + return []byte(jwtSecret), nil + }) + if err != nil || !token.Valid { + return nil, err + } + return claims, nil +} + +func getCookie(r *http.Request, name string) (string, error) { + cookie, err := r.Cookie(name) + if err != nil { + return "", err + } + return cookie.Value, nil +} \ No newline at end of file