Improved input service. New Manager web app. Directory and small readme for output service.

This commit is contained in:
Kalzu Rekku
2026-01-06 14:27:26 +02:00
parent ec9fec5ce3
commit f7056082f6
11 changed files with 1695 additions and 293 deletions

View File

@@ -3,34 +3,41 @@ package main
import ( import (
"bufio" "bufio"
"context" "context"
"crypto/sha256"
"encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"log" "log"
"math/rand" "math/rand"
"net" "net"
"net/http" "net/http"
"net/netip"
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"syscall" "syscall"
"time" "time"
) )
const ( const (
repoDir = "cloud-provider-ip-addresses" repoDir = "cloud-provider-ip-addresses"
port = 8080 port = 8080
stateDir = "progress_state" stateDir = "progress_state"
saveInterval = 30 * time.Second saveInterval = 30 * time.Second
cleanupInterval = 5 * time.Minute
generatorTTL = 24 * time.Hour
maxImportSize = 10 * 1024 * 1024 // 10MB
) )
// GeneratorState represents the serializable state of a generator // GeneratorState represents the serializable state of a generator
type GeneratorState struct { type GeneratorState struct {
CurrentFile int `json:"current_file"` RemainingCIDRs []string `json:"remaining_cidrs"`
CurrentCIDRs []string `json:"current_cidrs"` CurrentGen *HostGenState `json:"current_gen,omitempty"`
ActiveGenStates []HostGenState `json:"active_gen_states"` TotalCIDRs int `json:"total_cidrs"`
CIDRFiles []string `json:"cidr_files"`
} }
type HostGenState struct { type HostGenState struct {
@@ -41,67 +48,72 @@ type HostGenState struct {
// IPGenerator generates IPs from CIDR ranges lazily // IPGenerator generates IPs from CIDR ranges lazily
type IPGenerator struct { type IPGenerator struct {
mu sync.Mutex mu sync.Mutex
cidrFiles []string rng *rand.Rand
currentFile int totalCIDRsCount int
currentCIDRs []string remainingCIDRs []string
activeGens []*hostGenerator currentGen *hostGenerator
rng *rand.Rand consumer string
totalCIDRsCount int dirty atomic.Bool
consumer string
dirty bool
} }
type hostGenerator struct { type hostGenerator struct {
cidr string prefix netip.Prefix
network *net.IPNet current netip.Addr
current net.IP last netip.Addr
done bool done bool
} }
func addrToUint32(a netip.Addr) uint32 {
b := a.As4()
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
}
func uint32ToAddr(u uint32) netip.Addr {
return netip.AddrFrom4([4]byte{byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
}
func newHostGenerator(cidr string) (*hostGenerator, error) { func newHostGenerator(cidr string) (*hostGenerator, error) {
_, network, err := net.ParseCIDR(cidr) prefix, err := netip.ParsePrefix(cidr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
prefix = prefix.Masked()
// Only IPv4 if !prefix.IsValid() || !prefix.Addr().Is4() {
if network.IP.To4() == nil { return nil, fmt.Errorf("invalid IPv4 prefix")
return nil, fmt.Errorf("not IPv4")
} }
if prefix.Addr().IsMulticast() {
// Check if multicast
if network.IP.IsMulticast() {
return nil, fmt.Errorf("multicast network") return nil, fmt.Errorf("multicast network")
} }
ones, bits := network.Mask.Size() ip := prefix.Addr()
maskLen := prefix.Bits()
hg := &hostGenerator{ var first, last netip.Addr
cidr: cidr, lastUint := addrToUint32(ip) | ((1 << (32 - uint(maskLen))) - 1)
network: network, last = uint32ToAddr(lastUint)
current: make(net.IP, len(network.IP)),
}
copy(hg.current, network.IP)
// For /32, just use the single address if maskLen == 32 {
if ones == bits { first = ip
return hg, nil last = ip
} else if maskLen == 31 {
first = ip
// last already ip + 1
} else {
first = ip.Next()
last = last.Prev()
} }
// For other networks, skip network address (start at .1) if !prefix.Contains(first) || !prefix.Contains(last) {
hg.increment() return nil, fmt.Errorf("invalid range")
return hg, nil
}
func (hg *hostGenerator) increment() {
for i := len(hg.current) - 1; i >= 0; i-- {
hg.current[i]++
if hg.current[i] != 0 {
break
}
} }
return &hostGenerator{
prefix: prefix,
current: first,
last: last,
done: false,
}, nil
} }
func (hg *hostGenerator) next() (string, bool) { func (hg *hostGenerator) next() (string, bool) {
@@ -109,58 +121,31 @@ func (hg *hostGenerator) next() (string, bool) {
return "", false return "", false
} }
ones, bits := hg.network.Mask.Size() if !hg.prefix.Contains(hg.current) || addrToUint32(hg.current) > addrToUint32(hg.last) {
// Handle /32 specially
if ones == bits {
if !hg.current.Equal(hg.network.IP) {
hg.done = true
return "", false
}
ip := hg.current.String()
hg.done = true
return ip, true
}
// Check if we're still in the network
if !hg.network.Contains(hg.current) {
hg.done = true hg.done = true
return "", false return "", false
} }
// Check if this is the broadcast address (last IP in range)
broadcast := make(net.IP, len(hg.network.IP))
copy(broadcast, hg.network.IP)
for i := range broadcast {
broadcast[i] |= ^hg.network.Mask[i]
}
if hg.current.Equal(broadcast) {
hg.done = true
return "", false
}
// Skip multicast addresses
if hg.current.IsMulticast() { if hg.current.IsMulticast() {
hg.increment() hg.current = hg.current.Next()
return hg.next() return hg.next()
} }
ip := hg.current.String() ip := hg.current.String()
hg.increment() hg.current = hg.current.Next()
return ip, true return ip, true
} }
func (hg *hostGenerator) getState() HostGenState { func (hg *hostGenerator) getState() HostGenState {
return HostGenState{ return HostGenState{
CIDR: hg.cidr, CIDR: hg.prefix.String(),
Current: hg.current.String(), Current: hg.current.String(),
Done: hg.done, Done: hg.done,
} }
} }
func newIPGenerator(consumer string) (*IPGenerator, error) { func newIPGenerator(s *Server, consumer string) (*IPGenerator, error) {
gen := &IPGenerator{ gen := &IPGenerator{
rng: rand.New(rand.NewSource(time.Now().UnixNano())), rng: rand.New(rand.NewSource(time.Now().UnixNano())),
consumer: consumer, consumer: consumer,
@@ -173,207 +158,127 @@ func newIPGenerator(consumer string) (*IPGenerator, error) {
} }
// No saved state, initialize fresh // No saved state, initialize fresh
// Find all IP files gen.remainingCIDRs = append([]string{}, s.allCIDRs...)
err := filepath.Walk(repoDir, func(path string, info os.FileInfo, err error) error { gen.rng.Shuffle(len(gen.remainingCIDRs), func(i, j int) {
if err != nil { gen.remainingCIDRs[i], gen.remainingCIDRs[j] = gen.remainingCIDRs[j], gen.remainingCIDRs[i]
return err
}
if !info.IsDir() && strings.HasSuffix(path, ".txt") && strings.Contains(strings.ToLower(path), "ips") {
gen.cidrFiles = append(gen.cidrFiles, path)
}
return nil
}) })
gen.totalCIDRsCount = len(gen.remainingCIDRs)
gen.dirty.Store(true)
if err != nil { log.Printf("🆕 New generator for %s: %d total CIDRs", consumer, gen.totalCIDRsCount)
return nil, fmt.Errorf("failed to scan repo directory: %w", err)
}
if len(gen.cidrFiles) == 0 {
return nil, fmt.Errorf("no IP files found in %s", repoDir)
}
// Load first batch of CIDRs
if err := gen.loadNextFile(); err != nil {
return nil, err
}
log.Printf("🆕 New generator for %s: %d IP files, %d CIDRs", consumer, len(gen.cidrFiles), gen.totalCIDRsCount)
log.Printf("📁 Found %d IP files", len(gen.cidrFiles))
log.Printf("📊 Total CIDRs discovered: %d", gen.totalCIDRsCount)
return gen, nil return gen, nil
} }
func (g *IPGenerator) loadNextFile() error {
if g.currentFile >= len(g.cidrFiles) {
// Wrap around and reshuffle
g.currentFile = 0
g.rng.Shuffle(len(g.cidrFiles), func(i, j int) {
g.cidrFiles[i], g.cidrFiles[j] = g.cidrFiles[j], g.cidrFiles[i]
})
}
filepath := g.cidrFiles[g.currentFile]
g.currentFile++
file, err := os.Open(filepath)
if err != nil {
return fmt.Errorf("failed to open %s: %w", filepath, err)
}
defer file.Close()
g.currentCIDRs = g.currentCIDRs[:0] // Clear but keep capacity
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
fields := strings.Fields(line)
for _, field := range fields {
if field != "" {
// Basic validation
if strings.Contains(field, "/") || net.ParseIP(field) != nil {
g.currentCIDRs = append(g.currentCIDRs, field)
g.totalCIDRsCount++
}
}
}
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("error reading %s: %w", filepath, err)
}
// Shuffle CIDRs from this file
g.rng.Shuffle(len(g.currentCIDRs), func(i, j int) {
g.currentCIDRs[i], g.currentCIDRs[j] = g.currentCIDRs[j], g.currentCIDRs[i]
})
// Initialize generators for this batch
g.activeGens = make([]*hostGenerator, 0, len(g.currentCIDRs))
for _, cidr := range g.currentCIDRs {
// Ensure it has CIDR notation
if !strings.Contains(cidr, "/") {
cidr = cidr + "/32"
}
gen, err := newHostGenerator(cidr)
if err != nil {
// Skip invalid CIDRs silently
continue
}
g.activeGens = append(g.activeGens, gen)
}
g.dirty = true
return nil
}
func (g *IPGenerator) Next() (string, error) { func (g *IPGenerator) Next() (string, error) {
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer g.mu.Unlock()
for { for {
// If no active generators, load next file if g.currentGen == nil || g.currentGen.done {
if len(g.activeGens) == 0 { if len(g.remainingCIDRs) == 0 {
if err := g.loadNextFile(); err != nil {
return "", fmt.Errorf("failed to load next file: %w", err)
}
if len(g.activeGens) == 0 {
return "", fmt.Errorf("no more IPs available") return "", fmt.Errorf("no more IPs available")
} }
cidr := g.remainingCIDRs[0]
g.remainingCIDRs = g.remainingCIDRs[1:]
if !strings.Contains(cidr, "/") {
cidr += "/32"
}
var err error
g.currentGen, err = newHostGenerator(cidr)
if err != nil {
g.dirty.Store(true)
continue
}
} }
// Pick a random generator ip, ok := g.currentGen.next()
idx := g.rng.Intn(len(g.activeGens))
gen := g.activeGens[idx]
ip, ok := gen.next()
if !ok { if !ok {
// This generator is exhausted, remove it g.currentGen = nil
g.activeGens = append(g.activeGens[:idx], g.activeGens[idx+1:]...) g.dirty.Store(true)
g.dirty = true
continue continue
} }
g.dirty = true g.dirty.Store(true)
return ip, nil return ip, nil
} }
} }
func (g *IPGenerator) buildState() GeneratorState {
// Assumes mu is held
state := GeneratorState{
RemainingCIDRs: append([]string{}, g.remainingCIDRs...),
TotalCIDRs: g.totalCIDRsCount,
}
if g.currentGen != nil && !g.currentGen.done {
state.CurrentGen = &HostGenState{
CIDR: g.currentGen.prefix.String(),
Current: g.currentGen.current.String(),
Done: false,
}
}
return state
}
func (g *IPGenerator) getState() GeneratorState { func (g *IPGenerator) getState() GeneratorState {
g.mu.Lock() g.mu.Lock()
defer g.mu.Unlock() defer g.mu.Unlock()
return g.buildState()
activeStates := make([]HostGenState, len(g.activeGens))
for i, gen := range g.activeGens {
activeStates[i] = gen.getState()
}
return GeneratorState{
CurrentFile: g.currentFile,
CurrentCIDRs: g.currentCIDRs,
ActiveGenStates: activeStates,
CIDRFiles: g.cidrFiles,
}
} }
func (g *IPGenerator) saveState() error { func (g *IPGenerator) saveState() error {
if !g.dirty { g.mu.Lock()
if !g.dirty.Load() {
g.mu.Unlock()
return nil return nil
} }
state := g.buildState()
state := g.getState() g.dirty.Store(false)
g.mu.Unlock()
// Ensure state directory exists // Ensure state directory exists
if err := os.MkdirAll(stateDir, 0755); err != nil { if err := os.MkdirAll(stateDir, 0755); err != nil {
return fmt.Errorf("failed to create state directory: %w", err) return fmt.Errorf("failed to create state directory: %w", err)
} }
// Use consumer as filename (sanitize for filesystem) // Use hash of consumer as filename
filename := strings.ReplaceAll(g.consumer, ":", "_") hash := sha256.Sum256([]byte(g.consumer))
filename = strings.ReplaceAll(filename, "/", "_") filename := hex.EncodeToString(hash[:])
filepath := filepath.Join(stateDir, filename+".json") filePath := filepath.Join(stateDir, filename+".json")
// Write to temp file first, then rename for atomic write // Write to temp file first, then rename
tempPath := filepath + ".tmp" tempPath := filePath + ".tmp"
file, err := os.Create(tempPath) file, err := os.Create(tempPath)
if err != nil { if err != nil {
return fmt.Errorf("failed to create temp state file: %w", err) return fmt.Errorf("failed to create temp state file: %w", err)
} }
defer file.Close()
encoder := json.NewEncoder(file) encoder := json.NewEncoder(file)
encoder.SetIndent("", " ") encoder.SetIndent("", " ")
if err := encoder.Encode(state); err != nil { if err := encoder.Encode(state); err != nil {
file.Close()
os.Remove(tempPath) os.Remove(tempPath)
return fmt.Errorf("failed to encode state: %w", err) return fmt.Errorf("failed to encode state: %w", err)
} }
if err := file.Close(); err != nil { if err := os.Rename(tempPath, filePath); err != nil {
os.Remove(tempPath)
return fmt.Errorf("failed to close temp state file: %w", err)
}
if err := os.Rename(tempPath, filepath); err != nil {
os.Remove(tempPath) os.Remove(tempPath)
return fmt.Errorf("failed to rename state file: %w", err) return fmt.Errorf("failed to rename state file: %w", err)
} }
g.dirty = false
return nil return nil
} }
func (g *IPGenerator) loadState() error { func (g *IPGenerator) loadState() error {
// Use consumer as filename (sanitize for filesystem) // Use hash of consumer as filename
filename := strings.ReplaceAll(g.consumer, ":", "_") hash := sha256.Sum256([]byte(g.consumer))
filename = strings.ReplaceAll(filename, "/", "_") filename := hex.EncodeToString(hash[:])
filepath := filepath.Join(stateDir, filename+".json") filePath := filepath.Join(stateDir, filename+".json")
file, err := os.Open(filepath) file, err := os.Open(filePath)
if err != nil { if err != nil {
return err return err
} }
@@ -385,27 +290,20 @@ func (g *IPGenerator) loadState() error {
} }
// Restore state // Restore state
g.cidrFiles = state.CIDRFiles g.remainingCIDRs = state.RemainingCIDRs
g.currentFile = state.CurrentFile g.totalCIDRsCount = state.TotalCIDRs
g.currentCIDRs = state.CurrentCIDRs
g.totalCIDRsCount = len(state.CurrentCIDRs)
// Rebuild active generators from state if state.CurrentGen != nil {
g.activeGens = make([]*hostGenerator, 0, len(state.ActiveGenStates)) gen, err := newHostGenerator(state.CurrentGen.CIDR)
for _, genState := range state.ActiveGenStates {
gen, err := newHostGenerator(genState.CIDR)
if err != nil { if err != nil {
continue return err
} }
gen.current, err = netip.ParseAddr(state.CurrentGen.Current)
// Restore current IP position if err != nil {
gen.current = net.ParseIP(genState.Current) return err
if gen.current == nil {
continue
} }
gen.done = genState.Done gen.done = state.CurrentGen.Done
g.currentGen = gen
g.activeGens = append(g.activeGens, gen)
} }
return nil return nil
@@ -414,21 +312,86 @@ func (g *IPGenerator) loadState() error {
// Server holds per-consumer generators // Server holds per-consumer generators
type Server struct { type Server struct {
generators map[string]*IPGenerator generators map[string]*IPGenerator
lastAccess map[string]time.Time
allCIDRs []string
mu sync.RWMutex mu sync.RWMutex
stopSaver chan struct{} stopSaver chan struct{}
stopCleanup chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
} }
func newServer() *Server { func newServer() *Server {
s := &Server{ s := &Server{
generators: make(map[string]*IPGenerator), generators: make(map[string]*IPGenerator),
stopSaver: make(chan struct{}), lastAccess: make(map[string]time.Time),
stopSaver: make(chan struct{}),
stopCleanup: make(chan struct{}),
}
if err := s.loadAllCIDRs(); err != nil {
log.Fatalf("❌ Failed to load CIDRs: %v", err)
} }
s.startPeriodicSaver() s.startPeriodicSaver()
s.startCleanup()
return s return s
} }
func (s *Server) loadAllCIDRs() error {
// Find all IP files
var fileList []string
err := filepath.Walk(repoDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() && strings.HasSuffix(path, ".txt") && strings.Contains(strings.ToLower(path), "ips") {
fileList = append(fileList, path)
}
return nil
})
if err != nil {
return fmt.Errorf("failed to scan repo directory: %w", err)
}
if len(fileList) == 0 {
return fmt.Errorf("no IP files found in %s", repoDir)
}
// Load all CIDRs
for _, path := range fileList {
file, err := os.Open(path)
if err != nil {
log.Printf("⚠️ Failed to open %s: %v", path, err)
continue
}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
fields := strings.Fields(line)
for _, field := range fields {
if field != "" {
if strings.Contains(field, "/") || netip.MustParseAddr(field).IsValid() {
s.allCIDRs = append(s.allCIDRs, field)
}
}
}
}
file.Close()
if err := scanner.Err(); err != nil {
log.Printf("⚠️ Error reading %s: %v", path, err)
}
}
log.Printf("📁 Found %d IP files", len(fileList))
log.Printf("📊 Total CIDRs discovered: %d", len(s.allCIDRs))
return nil
}
func (s *Server) startPeriodicSaver() { func (s *Server) startPeriodicSaver() {
s.wg.Add(1) s.wg.Add(1)
go func() { go func() {
@@ -441,7 +404,6 @@ func (s *Server) startPeriodicSaver() {
case <-ticker.C: case <-ticker.C:
s.saveAllStates() s.saveAllStates()
case <-s.stopSaver: case <-s.stopSaver:
// Final save before shutdown
s.saveAllStates() s.saveAllStates()
return return
} }
@@ -449,23 +411,63 @@ func (s *Server) startPeriodicSaver() {
}() }()
} }
func (s *Server) startCleanup() {
s.wg.Add(1)
go func() {
defer s.wg.Done()
ticker := time.NewTicker(cleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.cleanupOldGenerators()
case <-s.stopCleanup:
return
}
}
}()
}
func (s *Server) cleanupOldGenerators() {
s.mu.Lock()
defer s.mu.Unlock()
now := time.Now()
for consumer, t := range s.lastAccess {
if now.Sub(t) > generatorTTL {
delete(s.generators, consumer)
delete(s.lastAccess, consumer)
// Remove state file
hash := sha256.Sum256([]byte(consumer))
fn := hex.EncodeToString(hash[:]) + ".json"
p := filepath.Join(stateDir, fn)
if err := os.Remove(p); err != nil && !os.IsNotExist(err) {
log.Printf("⚠️ Failed to remove state file %s: %v", p, err)
}
}
}
}
func (s *Server) saveAllStates() { func (s *Server) saveAllStates() {
s.mu.RLock() s.mu.RLock()
generators := make([]*IPGenerator, 0, len(s.generators)) gens := make([]*IPGenerator, 0, len(s.generators))
for _, gen := range s.generators { for _, gen := range s.generators {
generators = append(generators, gen) gens = append(gens, gen)
} }
s.mu.RUnlock() s.mu.RUnlock()
for _, gen := range generators { for _, gen := range gens {
if err := gen.saveState(); err != nil { if err := gen.saveState(); err != nil {
log.Printf("⚠️ Failed to save state for %s: %v", gen.consumer, err) log.Printf("⚠️ Failed to save state for %s: %v", gen.consumer, err)
} }
} }
} }
func (s *Server) shutdown() { func (s *Server) shutdown() {
close(s.stopSaver) close(s.stopSaver)
close(s.stopCleanup)
s.wg.Wait() s.wg.Wait()
log.Println("💾 All states saved") log.Println("💾 All states saved")
} }
@@ -476,24 +478,28 @@ func (s *Server) getGenerator(consumer string) (*IPGenerator, error) {
s.mu.RUnlock() s.mu.RUnlock()
if exists { if exists {
s.mu.Lock()
s.lastAccess[consumer] = time.Now()
s.mu.Unlock()
return gen, nil return gen, nil
} }
// Create new generator for this consumer
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
// Double-check after acquiring write lock // Double-check
if gen, exists := s.generators[consumer]; exists { if gen, exists := s.generators[consumer]; exists {
s.lastAccess[consumer] = time.Now()
return gen, nil return gen, nil
} }
newGen, err := newIPGenerator(consumer) newGen, err := newIPGenerator(s, consumer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.generators[consumer] = newGen s.generators[consumer] = newGen
s.lastAccess[consumer] = time.Now()
log.Printf("🆕 New consumer: %s", consumer) log.Printf("🆕 New consumer: %s", consumer)
return newGen, nil return newGen, nil
@@ -505,38 +511,42 @@ func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) {
return return
} }
consumer := r.RemoteAddr addrPort, err := netip.ParseAddrPort(r.RemoteAddr)
if host, _, err := net.SplitHostPort(consumer); err == nil { consumerStr := r.RemoteAddr
consumer = host if err == nil {
consumerStr = addrPort.Addr().String()
} else {
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err == nil {
consumerStr = host
}
} }
gen, err := s.getGenerator(consumer) gen, err := s.getGenerator(consumerStr)
if err != nil { if err != nil {
log.Printf("❌ Failed to get generator for %s: %v", consumer, err) log.Printf("❌ Failed to get generator for %s: %v", consumerStr, err)
http.Error(w, "Internal server error", http.StatusInternalServerError) http.Error(w, "Internal server error", http.StatusInternalServerError)
return return
} }
ip, err := gen.Next() ip, err := gen.Next()
if err != nil { if err != nil {
log.Printf("❌ Failed to get IP for %s: %v", consumer, err) log.Printf("❌ Failed to get IP for %s: %v", consumerStr, err)
http.Error(w, "No more IPs available", http.StatusServiceUnavailable) http.Error(w, "No more IPs available", http.StatusServiceUnavailable)
return return
} }
log.Printf("📤 Serving IP to %s: %s", consumer, ip) log.Printf("📤 Serving IP to %s: %s", consumerStr, ip)
w.Header().Set("Content-Type", "text/plain") w.Header().Set("Content-Type", "text/plain")
fmt.Fprintf(w, "%s\n", ip) fmt.Fprintf(w, "%s\n", ip)
} }
type ConsumerStatus struct { type ConsumerStatus struct {
Consumer string `json:"consumer"` Consumer string `json:"consumer"`
CurrentFile int `json:"current_file"` RemainingCIDRs int `json:"remaining_cidrs"`
TotalFiles int `json:"total_files"` HasActiveGen bool `json:"has_active_gen"`
ActiveCIDRs int `json:"active_cidrs"` TotalCIDRs int `json:"total_cidrs"`
TotalCIDRs int `json:"total_cidrs_discovered"`
CurrentFilePath string `json:"current_file_path,omitempty"`
} }
type StatusResponse struct { type StatusResponse struct {
@@ -565,14 +575,10 @@ func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
for consumer, gen := range s.generators { for consumer, gen := range s.generators {
gen.mu.Lock() gen.mu.Lock()
status := ConsumerStatus{ status := ConsumerStatus{
Consumer: consumer, Consumer: consumer,
CurrentFile: gen.currentFile, RemainingCIDRs: len(gen.remainingCIDRs),
TotalFiles: len(gen.cidrFiles), HasActiveGen: gen.currentGen != nil,
ActiveCIDRs: len(gen.activeGens), TotalCIDRs: gen.totalCIDRsCount,
TotalCIDRs: gen.totalCIDRsCount,
}
if gen.currentFile > 0 && gen.currentFile <= len(gen.cidrFiles) {
status.CurrentFilePath = gen.cidrFiles[gen.currentFile-1]
} }
gen.mu.Unlock() gen.mu.Unlock()
@@ -588,8 +594,8 @@ func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
} }
type ExportResponse struct { type ExportResponse struct {
ExportedAt time.Time `json:"exported_at"` ExportedAt time.Time `json:"exported_at"`
States map[string]GeneratorState `json:"states"` States map[string]GeneratorState `json:"states"`
} }
func (s *Server) handleExport(w http.ResponseWriter, r *http.Request) { func (s *Server) handleExport(w http.ResponseWriter, r *http.Request) {
@@ -606,7 +612,7 @@ func (s *Server) handleExport(w http.ResponseWriter, r *http.Request) {
response := ExportResponse{ response := ExportResponse{
ExportedAt: time.Now(), ExportedAt: time.Now(),
States: make(map[string]GeneratorState), States: make(map[string]GeneratorState, len(s.generators)),
} }
for consumer, gen := range s.generators { for consumer, gen := range s.generators {
@@ -630,9 +636,16 @@ func (s *Server) handleImport(w http.ResponseWriter, r *http.Request) {
return return
} }
reader := http.MaxBytesReader(w, r.Body, maxImportSize)
defer r.Body.Close()
var exportData ExportResponse var exportData ExportResponse
if err := json.NewDecoder(r.Body).Decode(&exportData); err != nil { if err := json.NewDecoder(reader).Decode(&exportData); err != nil {
http.Error(w, fmt.Sprintf("Failed to decode import data: %v", err), http.StatusBadRequest) if err == io.EOF || strings.Contains(err.Error(), "EOF") {
http.Error(w, "Invalid or empty request body", http.StatusBadRequest)
} else {
http.Error(w, fmt.Sprintf("Failed to decode import data: %v", err), http.StatusBadRequest)
}
return return
} }
@@ -646,14 +659,14 @@ func (s *Server) handleImport(w http.ResponseWriter, r *http.Request) {
failed := 0 failed := 0
for consumer, state := range exportData.States { for consumer, state := range exportData.States {
// Sanitize consumer name for filename // Use hash for filename
filename := strings.ReplaceAll(consumer, ":", "_") hash := sha256.Sum256([]byte(consumer))
filename = strings.ReplaceAll(filename, "/", "_") filename := hex.EncodeToString(hash[:])
filepath := filepath.Join(stateDir, filename+".json") filePath := filepath.Join(stateDir, filename+".json")
file, err := os.Create(filepath) file, err := os.Create(filePath)
if err != nil { if err != nil {
log.Printf("⚠️ Failed to create state file for %s: %v", consumer, err) log.Printf("⚠️ Failed to create state file for %s: %v", consumer, err)
failed++ failed++
continue continue
} }
@@ -662,8 +675,9 @@ func (s *Server) handleImport(w http.ResponseWriter, r *http.Request) {
encoder.SetIndent("", " ") encoder.SetIndent("", " ")
if err := encoder.Encode(state); err != nil { if err := encoder.Encode(state); err != nil {
file.Close() file.Close()
log.Printf("⚠️ Failed to encode state for %s: %v", consumer, err) log.Printf("⚠️ Failed to encode state for %s: %v", consumer, err)
failed++ failed++
os.Remove(filePath)
continue continue
} }
@@ -717,7 +731,7 @@ func main() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
// Stop periodic saver and save final state // Stop savers and save final state
server.shutdown() server.shutdown()
if err := httpServer.Shutdown(ctx); err != nil { if err := httpServer.Shutdown(ctx); err != nil {

67
manager/cert.go Normal file
View File

@@ -0,0 +1,67 @@
package main
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"os"
"time"
)
func CheckAndGenerateCerts() (string, string, error) {
certFile := "cert.pem"
keyFile := "key.pem"
// If files already exist, just use them
if _, err := os.Stat(certFile); err == nil {
return certFile, keyFile, nil
}
logger.Info("Generating self-signed TLS certificates...")
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return "", "", err
}
notBefore := time.Now()
notAfter := notBefore.Add(365 * 24 * time.Hour)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, _ := rand.Int(rand.Reader, serialNumberLimit)
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"TwoStepAuth Dev"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
DNSNames: []string{"localhost"},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
return "", "", err
}
certOut, _ := os.Create(certFile)
pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
certOut.Close()
keyOut, _ := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
privBytes, _ := x509.MarshalECPrivateKey(priv)
pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes})
keyOut.Close()
return certFile, keyFile, nil
}

109
manager/crypto.go Normal file
View File

@@ -0,0 +1,109 @@
package main
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"io"
"golang.org/x/crypto/pbkdf2"
)
type Crypto struct {
serverKey []byte
}
func NewCrypto(serverKeyStr string) (*Crypto, error) {
// Decode the base64 server key
serverKey, err := base64.StdEncoding.DecodeString(serverKeyStr)
if err != nil {
return nil, err
}
if len(serverKey) != 32 {
return nil, errors.New("invalid server key length")
}
logger.Info("Crypto initialized with server key")
return &Crypto{serverKey: serverKey}, nil
}
func GenerateServerKey() (string, error) {
key := make([]byte, 32)
if _, err := rand.Read(key); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(key), nil
}
func (c *Crypto) deriveUserKey(userID string) []byte {
// Derive a 32-byte key from user ID using PBKDF2
// Using server key as salt for additional security
return pbkdf2.Key([]byte(userID), c.serverKey, 100000, 32, sha256.New)
}
func (c *Crypto) encrypt(data []byte, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}
ciphertext := gcm.Seal(nonce, nonce, data, nil)
return ciphertext, nil
}
func (c *Crypto) decrypt(data []byte, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return nil, errors.New("ciphertext too short")
}
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, err
}
return plaintext, nil
}
func (c *Crypto) EncryptWithServerKey(data []byte) ([]byte, error) {
return c.encrypt(data, c.serverKey)
}
func (c *Crypto) DecryptWithServerKey(data []byte) ([]byte, error) {
return c.decrypt(data, c.serverKey)
}
func (c *Crypto) EncryptWithUserKey(data []byte, userID string) ([]byte, error) {
userKey := c.deriveUserKey(userID)
return c.encrypt(data, userKey)
}
func (c *Crypto) DecryptWithUserKey(data []byte, userID string) ([]byte, error) {
userKey := c.deriveUserKey(userID)
return c.decrypt(data, userKey)
}

42
manager/dyfi.go Normal file
View File

@@ -0,0 +1,42 @@
package main
import (
"fmt"
"net/http"
"time"
)
func startDyfiUpdater(hostname, username, password string) {
if hostname == "" || username == "" || password == "" {
return
}
logger.Info("Starting dy.fi updater for %s", hostname)
update := func() {
url := fmt.Sprintf("https://www.dy.fi/nic/update?hostname=%s", hostname)
req, _ := http.NewRequest("GET", url, nil)
req.SetBasicAuth(username, password)
req.Header.Set("User-Agent", "Go-TwoStepAuth-Client/1.0")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
logger.Error("dy.fi update failed: %v", err)
return
}
defer resp.Body.Close()
logger.Info("dy.fi update status: %s", resp.Status)
}
// Update immediately on start
update()
// Update every 7 days (dy.fi requires update at least every 30 days)
go func() {
ticker := time.NewTicker(7 * 24 * time.Hour)
for range ticker.C {
update()
}
}()
}

15
manager/go.mod Normal file
View File

@@ -0,0 +1,15 @@
module manager
go 1.25.0
require (
github.com/pquerna/otp v1.5.0
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
golang.org/x/crypto v0.46.0
)
require (
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/text v0.32.0 // indirect
)

17
manager/gr.go Normal file
View File

@@ -0,0 +1,17 @@
package main
import (
"fmt"
"github.com/skip2/go-qrcode"
)
func PrintQRCode(content string) {
qr, err := qrcode.New(content, qrcode.Medium)
if err != nil {
logger.Error("Failed to generate QR code: %v", err)
return
}
// Generate QR code as string (ASCII art)
fmt.Println(qr.ToSmallString(false))
}

33
manager/logger.go Normal file
View File

@@ -0,0 +1,33 @@
package main
import (
"fmt"
"log"
"os"
)
type Logger struct {
infoLog *log.Logger
warnLog *log.Logger
errorLog *log.Logger
}
func NewLogger() *Logger {
return &Logger{
infoLog: log.New(os.Stdout, "INFO ", log.Ldate|log.Ltime|log.Lshortfile),
warnLog: log.New(os.Stdout, "WARN ", log.Ldate|log.Ltime|log.Lshortfile),
errorLog: log.New(os.Stderr, "ERROR ", log.Ldate|log.Ltime|log.Lshortfile),
}
}
func (l *Logger) Info(format string, v ...interface{}) {
l.infoLog.Output(2, fmt.Sprintf(format, v...))
}
func (l *Logger) Warn(format string, v ...interface{}) {
l.warnLog.Output(2, fmt.Sprintf(format, v...))
}
func (l *Logger) Error(format string, v ...interface{}) {
l.errorLog.Output(2, fmt.Sprintf(format, v...))
}

424
manager/main.go Normal file
View File

@@ -0,0 +1,424 @@
package main
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/json"
"encoding/pem"
"flag"
"fmt"
"io"
"log"
"math/big"
"net"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/pquerna/otp/totp"
"golang.org/x/crypto/acme/autocert"
)
var (
store *UserStore
sessions = struct {
sync.RWMutex
m map[string]*Session
}{m: make(map[string]*Session)}
logger *Logger
)
type Session struct {
UserID string
ExpiresAt time.Time
}
func main() {
// --- FLAGS ---
addUser := flag.String("add-user", "", "Add a new user (provide user ID)")
port := flag.String("port", os.Getenv("MANAGER_PORT"), "Port to run the server on (use 443 for Let's Encrypt)")
domain := flag.String("domain", os.Getenv("DYFI_DOMAIN"), "Your dy.fi domain (e.g. example.dy.fi)")
dyfiUser := flag.String("dyfi-user", os.Getenv("DYFI_USER"), "dy.fi username (email)")
dyfiPass := flag.String("dyfi-pass", os.Getenv("DYFI_PASS"), "dy.fi password")
email := flag.String("email", os.Getenv("ACME_EMAIL"), "Email for Let's Encrypt notifications")
flag.Parse()
logger = NewLogger()
// --- ENCRYPTION INITIALIZATION ---
serverKey := os.Getenv("SERVER_KEY")
if serverKey == "" {
logger.Warn("SERVER_KEY not set, generating new key")
var err error
serverKey, err = GenerateServerKey()
if err != nil {
logger.Error("Failed to generate server key: %v", err)
log.Fatal(err)
}
fmt.Printf("\n⚠ IMPORTANT: Save this SERVER_KEY to your environment:\n")
fmt.Printf("export SERVER_KEY=%s\n\n", serverKey)
}
crypto, err := NewCrypto(serverKey)
if err != nil {
logger.Error("Failed to initialize crypto: %v", err)
log.Fatal(err)
}
store = NewUserStore("users_data", crypto)
// --- BACKGROUND TASKS ---
// Reload user store from disk periodically
go func() {
ticker := time.NewTicker(1 * time.Minute)
for range ticker.C {
if err := store.Reload(); err != nil {
logger.Error("Failed to reload user store: %v", err)
}
}
}()
// Cleanup expired sessions
go func() {
ticker := time.NewTicker(5 * time.Minute)
for range ticker.C {
cleanupSessions()
}
}()
// dy.fi Dynamic DNS Updater
if *domain != "" && *dyfiUser != "" {
startDyfiUpdater(*domain, *dyfiUser, *dyfiPass)
}
// --- CLI COMMANDS ---
if *addUser != "" {
handleNewUser(*addUser)
return
}
// --- TEMPLATE LOADING ---
tmpl, err := LoadTemplate()
if err != nil {
log.Fatal(err)
}
appTmpl, err := LoadAppTemplate()
if err != nil {
log.Fatal(err)
}
// --- ROUTES ---
// Routes must be defined BEFORE the server starts
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if session := getValidSession(r, crypto); session != nil {
http.Redirect(w, r, "/app", http.StatusSeeOther)
return
}
tmpl.Execute(w, map[string]interface{}{"Step2": false})
})
http.HandleFunc("/app", func(w http.ResponseWriter, r *http.Request) {
session := getValidSession(r, crypto)
if session == nil {
http.Redirect(w, r, "/", http.StatusSeeOther)
return
}
appTmpl.Execute(w, map[string]interface{}{"UserID": session.UserID})
})
http.HandleFunc("/logout", func(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("auth_session")
if err == nil {
sessions.Lock()
delete(sessions.m, cookie.Value)
sessions.Unlock()
}
http.SetCookie(w, &http.Cookie{
Name: "auth_session",
Value: "",
Path: "/",
MaxAge: -1,
})
http.Redirect(w, r, "/", http.StatusSeeOther)
})
http.HandleFunc("/api/request", func(w http.ResponseWriter, r *http.Request) {
session := getValidSession(r, crypto)
if session == nil {
w.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"})
return
}
var req struct {
Method string `json:"method"`
URL string `json:"url"`
Headers map[string]string `json:"headers"`
Body string `json:"body"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
result := makeHTTPRequest(req.Method, req.URL, req.Headers, req.Body)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(result)
})
http.HandleFunc("/verify-user", func(w http.ResponseWriter, r *http.Request) {
userID := strings.TrimSpace(r.FormValue("userid"))
user, err := store.GetUser(userID)
if err != nil || user == nil {
tmpl.Execute(w, map[string]interface{}{"Step2": false, "Error": "User not found"})
return
}
sessionID := fmt.Sprintf("%d", time.Now().UnixNano())
sessions.Lock()
sessions.m[sessionID] = &Session{
UserID: userID,
ExpiresAt: time.Now().Add(5 * time.Minute),
}
sessions.Unlock()
http.SetCookie(w, &http.Cookie{
Name: "temp_session",
Value: sessionID,
Path: "/",
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteStrictMode,
})
tmpl.Execute(w, map[string]interface{}{"Step2": true})
})
http.HandleFunc("/verify-totp", func(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("temp_session")
if err != nil {
http.Redirect(w, r, "/", http.StatusSeeOther)
return
}
sessions.RLock()
session, ok := sessions.m[cookie.Value]
sessions.RUnlock()
if !ok || time.Now().After(session.ExpiresAt) {
http.Redirect(w, r, "/", http.StatusSeeOther)
return
}
user, _ := store.GetUser(session.UserID)
totpCode := strings.TrimSpace(r.FormValue("totp"))
if !totp.Validate(totpCode, user.TOTPSecret) {
tmpl.Execute(w, map[string]interface{}{"Step2": true, "Error": "Invalid TOTP code"})
return
}
sessions.Lock()
delete(sessions.m, cookie.Value)
authSessionID := fmt.Sprintf("%d", time.Now().UnixNano())
sessions.m[authSessionID] = &Session{
UserID: session.UserID,
ExpiresAt: time.Now().Add(1 * time.Hour),
}
sessions.Unlock()
encryptedSession, _ := crypto.EncryptWithServerKey([]byte(authSessionID))
http.SetCookie(w, &http.Cookie{
Name: "auth_session",
Value: base64.StdEncoding.EncodeToString(encryptedSession),
Path: "/",
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteStrictMode,
MaxAge: 3600,
})
http.Redirect(w, r, "/app", http.StatusSeeOther)
})
// --- SERVER STARTUP ---
if *domain != "" {
// Let's Encrypt / ACME Setup
certManager := autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist(*domain),
Cache: autocert.DirCache("certs_cache"),
Email: *email,
}
// Let's Encrypt requires port 80 for the HTTP-01 challenge
go func() {
logger.Info("Starting HTTP-01 challenge listener on port 80")
// This handler automatically redirects HTTP to HTTPS while solving challenges
log.Fatal(http.ListenAndServe(":80", certManager.HTTPHandler(nil)))
}()
server := &http.Server{
Addr: ":" + *port,
TLSConfig: certManager.TLSConfig(),
}
logger.Info("Secure Server starting with Let's Encrypt on https://%s", *domain)
log.Fatal(server.ListenAndServeTLS("", "")) // Certs provided by autocert
} else {
// Fallback to Self-Signed Certs
certFile, keyFile, err := setupCerts()
if err != nil {
logger.Error("TLS Setup Error: %v", err)
log.Fatal(err)
}
server := &http.Server{
Addr: ":" + *port,
TLSConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
},
}
logger.Info("Secure Server starting with self-signed certs on https://localhost:%s", *port)
log.Fatal(server.ListenAndServeTLS(certFile, keyFile))
}
}
// setupCerts creates self-signed certs if they don't exist
func setupCerts() (string, string, error) {
certFile, keyFile := "cert.pem", "key.pem"
if _, err := os.Stat(certFile); err == nil {
return certFile, keyFile, nil
}
logger.Info("Generating self-signed certificates for HTTPS...")
priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
serialNumber, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{Organization: []string{"TwoStepAuth Dev"}},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
DNSNames: []string{"localhost"},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
derBytes, _ := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
certOut, _ := os.Create(certFile)
pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
certOut.Close()
keyOut, _ := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
privBytes, _ := x509.MarshalECPrivateKey(priv)
pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes})
keyOut.Close()
return certFile, keyFile, nil
}
func handleNewUser(userID string) {
secret, _ := generateSecret()
store.AddUser(userID, secret)
otpauthURL := fmt.Sprintf("otpauth://totp/TwoStepAuth:%s?secret=%s&issuer=TwoStepAuth", userID, secret)
fmt.Printf("\nUser: %s\nSecret: %s\n", userID, secret)
PrintQRCode(otpauthURL)
}
func getValidSession(r *http.Request, crypto *Crypto) *Session {
cookie, err := r.Cookie("auth_session")
if err != nil {
return nil
}
enc, _ := base64.StdEncoding.DecodeString(cookie.Value)
sid, err := crypto.DecryptWithServerKey(enc)
if err != nil {
return nil
}
sessions.RLock()
session, ok := sessions.m[string(sid)]
sessions.RUnlock()
if !ok || time.Now().After(session.ExpiresAt) {
return nil
}
return session
}
func cleanupSessions() {
sessions.Lock()
defer sessions.Unlock()
now := time.Now()
for id, s := range sessions.m {
if now.After(s.ExpiresAt) {
delete(sessions.m, id)
}
}
}
func makeHTTPRequest(method, url string, headers map[string]string, body string) map[string]interface{} {
client := &http.Client{Timeout: 30 * time.Second}
var reqBody io.Reader
if body != "" {
reqBody = strings.NewReader(body)
}
req, err := http.NewRequest(method, url, reqBody)
if err != nil {
return map[string]interface{}{
"error": err.Error(),
}
}
for key, value := range headers {
req.Header.Set(key, value)
}
start := time.Now()
resp, err := client.Do(req)
duration := time.Since(start).Milliseconds()
if err != nil {
return map[string]interface{}{
"error": err.Error(),
"duration": duration,
}
}
defer resp.Body.Close()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return map[string]interface{}{
"error": err.Error(),
"duration": duration,
}
}
respHeaders := make(map[string]string)
for key, values := range resp.Header {
respHeaders[key] = strings.Join(values, ", ")
}
return map[string]interface{}{
"status": resp.StatusCode,
"headers": respHeaders,
"body": string(bodyBytes),
"duration": duration,
}
}

215
manager/store.go Normal file
View File

@@ -0,0 +1,215 @@
package main
import (
"crypto/rand"
"crypto/sha256"
"encoding/base32"
"encoding/hex"
"encoding/json"
"os"
"path/filepath"
"sync"
)
type User struct {
ID string `json:"id"`
TOTPSecret string `json:"totp_secret"`
}
type UserStore struct {
mu sync.RWMutex
filePath string
crypto *Crypto
cache map[string]*encryptedUserEntry
}
type encryptedUserEntry struct {
UserIDHash string `json:"hash"`
Data []byte `json:"data"`
}
type encryptedStore struct {
Users []encryptedUserEntry `json:"users"`
}
func NewUserStore(dataDir string, crypto *Crypto) *UserStore {
// Create data directory if it doesn't exist
if err := os.MkdirAll(dataDir, 0700); err != nil {
logger.Error("Failed to create data directory: %v", err)
}
filePath := filepath.Join(dataDir, "users.enc")
logger.Info("Initialized user store at: %s", filePath)
store := &UserStore{
filePath: filePath,
crypto: crypto,
cache: make(map[string]*encryptedUserEntry),
}
store.loadCache()
return store
}
func (s *UserStore) hashUserID(userID string) string {
hash := sha256.Sum256([]byte(userID))
return hex.EncodeToString(hash[:])
}
func (s *UserStore) Reload() error {
s.mu.Lock()
defer s.mu.Unlock()
// Clear existing cache
s.cache = make(map[string]*encryptedUserEntry)
// Reload from disk
return s.loadCacheInternal()
}
func (s *UserStore) loadCache() error {
s.mu.Lock()
defer s.mu.Unlock()
return s.loadCacheInternal()
}
func (s *UserStore) loadCacheInternal() error {
// Read encrypted store file
encryptedData, err := os.ReadFile(s.filePath)
if err != nil {
if os.IsNotExist(err) {
logger.Info("No existing user store found, starting fresh")
return nil
}
logger.Error("Failed to read store file: %v", err)
return err
}
// Decrypt with server key
decryptedData, err := s.crypto.DecryptWithServerKey(encryptedData)
if err != nil {
logger.Error("Failed to decrypt store: %v", err)
return err
}
var store encryptedStore
if err := json.Unmarshal(decryptedData, &store); err != nil {
logger.Error("Failed to unmarshal store: %v", err)
return err
}
// Load into cache
for i := range store.Users {
s.cache[store.Users[i].UserIDHash] = &store.Users[i]
}
logger.Info("Loaded %d encrypted user entries into cache", len(s.cache))
return nil
}
func (s *UserStore) save() error {
// Build store structure from cache
store := encryptedStore{
Users: make([]encryptedUserEntry, 0, len(s.cache)),
}
for _, entry := range s.cache {
store.Users = append(store.Users, *entry)
}
// Marshal to JSON
storeData, err := json.Marshal(store)
if err != nil {
return err
}
// Encrypt with server key
encryptedData, err := s.crypto.EncryptWithServerKey(storeData)
if err != nil {
logger.Error("Failed to encrypt store: %v", err)
return err
}
// Write to file
logger.Info("Saving user store with %d entries", len(s.cache))
if err := os.WriteFile(s.filePath, encryptedData, 0600); err != nil {
logger.Error("Failed to write store file: %v", err)
return err
}
return nil
}
func (s *UserStore) GetUser(userID string) (*User, error) {
s.mu.RLock()
defer s.mu.RUnlock()
userHash := s.hashUserID(userID)
entry, exists := s.cache[userHash]
if !exists {
logger.Warn("User not found in cache")
return nil, nil
}
// Decrypt with user key (derived from user ID)
userData, err := s.crypto.DecryptWithUserKey(entry.Data, userID)
if err != nil {
logger.Error("Failed to decrypt user data: %v", err)
return nil, err
}
var user User
if err := json.Unmarshal(userData, &user); err != nil {
logger.Error("Failed to unmarshal user data: %v", err)
return nil, err
}
logger.Info("Successfully loaded user: %s", user.ID)
return &user, nil
}
func (s *UserStore) AddUser(userID, totpSecret string) error {
s.mu.Lock()
defer s.mu.Unlock()
user := &User{
ID: userID,
TOTPSecret: totpSecret,
}
// Marshal user data
userData, err := json.Marshal(user)
if err != nil {
return err
}
// Encrypt with user key (derived from user ID)
userEncrypted, err := s.crypto.EncryptWithUserKey(userData, userID)
if err != nil {
logger.Error("Failed to encrypt with user key: %v", err)
return err
}
// Add to cache
userHash := s.hashUserID(userID)
s.cache[userHash] = &encryptedUserEntry{
UserIDHash: userHash,
Data: userEncrypted,
}
// Save entire store (encrypted with server key)
if err := s.save(); err != nil {
return err
}
logger.Info("Successfully saved user: %s", userID)
return nil
}
func generateSecret() (string, error) {
b := make([]byte, 20)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base32.StdEncoding.EncodeToString(b), nil
}

459
manager/template.go Normal file
View File

@@ -0,0 +1,459 @@
package main
import (
"html/template"
"os"
)
func LoadTemplate() (*template.Template, error) {
// Try to load from file first
if _, err := os.Stat("template.html"); err == nil {
logger.Info("Loading template from template.html")
return template.ParseFiles("template.html")
}
// Fall back to embedded template
logger.Info("Using embedded template")
return template.New("page").Parse(embeddedTemplate)
}
func LoadAppTemplate() (*template.Template, error) {
// Try to load from file first
if _, err := os.Stat("app.html"); err == nil {
logger.Info("Loading app template from app.html")
return template.ParseFiles("app.html")
}
// Fall back to embedded template
logger.Info("Using embedded app template")
return template.New("app").Parse(embeddedAppTemplate)
}
const embeddedTemplate = `<!DOCTYPE html>
<html>
<head>
<title>Two-Step Authentication</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Arial, sans-serif;
background: linear-gradient(135deg, #0f0f1e 0%, #1a1a2e 50%, #16213e 100%);
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
color: #b8c5d6;
}
.container {
text-align: center;
width: 100%;
max-width: 500px;
padding: 20px;
}
h1 {
font-size: 28px;
margin-bottom: 40px;
color: #4a9eff;
font-weight: 300;
letter-spacing: 1px;
}
.form-group {
display: flex;
gap: 15px;
align-items: center;
justify-content: center;
}
input {
flex: 1;
max-width: 300px;
padding: 18px 24px;
font-size: 18px;
border: 2px solid #2c3e50;
background: #1e2835;
color: #e0e6ed;
border-radius: 8px;
outline: none;
transition: all 0.3s;
}
input:focus {
border-color: #4a9eff;
background: #252f3f;
box-shadow: 0 0 20px rgba(74, 158, 255, 0.2);
}
input::placeholder { color: #5a6c7d; }
button {
padding: 18px 32px;
font-size: 18px;
background: linear-gradient(135deg, #1e40af 0%, #3b82f6 100%);
color: white;
border: none;
border-radius: 8px;
cursor: pointer;
transition: all 0.3s;
font-weight: 500;
}
button:hover {
background: linear-gradient(135deg, #2563eb 0%, #60a5fa 100%);
transform: translateY(-2px);
box-shadow: 0 4px 20px rgba(37, 99, 235, 0.3);
}
button:active { transform: translateY(0); }
.error {
color: #ef4444;
margin-top: 20px;
font-size: 16px;
background: rgba(239, 68, 68, 0.1);
padding: 12px 20px;
border-radius: 6px;
border: 1px solid rgba(239, 68, 68, 0.3);
}
.success {
margin-top: 30px;
padding: 20px;
background: rgba(34, 197, 94, 0.1);
border: 2px solid #22c55e;
border-radius: 8px;
font-size: 18px;
}
</style>
</head>
<body>
<div class="container">
{{if .Step2}}
<h1>Enter TOTP Code</h1>
<form method="POST" action="/verify-totp">
<div class="form-group">
<input type="text" name="totp" placeholder="000000" autofocus required pattern="[0-9]{6}" maxlength="6">
<button type="submit">Verify</button>
</div>
</form>
{{else}}
<h1>Enter User ID</h1>
<form method="POST" action="/verify-user">
<div class="form-group">
<input type="text" name="userid" placeholder="User ID" autofocus required>
<button type="submit">Continue</button>
</div>
</form>
{{end}}
{{if .Error}}<div class="error">{{.Error}}</div>{{end}}
</div>
</body>
</html>`
const embeddedAppTemplate = `<!DOCTYPE html>
<html>
<head>
<title>REST API Client</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Arial, sans-serif;
background: linear-gradient(135deg, #0f0f1e 0%, #1a1a2e 50%, #16213e 100%);
min-height: 100vh;
color: #b8c5d6;
padding: 20px;
}
.container {
max-width: 1200px;
margin: 0 auto;
}
.header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 30px;
padding-bottom: 20px;
border-bottom: 2px solid #2c3e50;
}
h1 {
font-size: 28px;
color: #4a9eff;
font-weight: 300;
}
.user-info {
display: flex;
gap: 15px;
align-items: center;
}
.username {
color: #b8c5d6;
font-size: 16px;
}
.logout-btn {
padding: 10px 20px;
background: #dc2626;
color: white;
border: none;
border-radius: 6px;
cursor: pointer;
font-size: 14px;
text-decoration: none;
display: inline-block;
}
.logout-btn:hover {
background: #ef4444;
}
.request-form {
background: #1e2835;
padding: 25px;
border-radius: 10px;
margin-bottom: 20px;
border: 2px solid #2c3e50;
}
.form-row {
display: flex;
gap: 10px;
margin-bottom: 15px;
}
.form-group {
flex: 1;
display: flex;
flex-direction: column;
gap: 8px;
}
label {
color: #8b9bb0;
font-size: 14px;
font-weight: 500;
}
select, input, textarea {
padding: 12px;
background: #252f3f;
border: 2px solid #2c3e50;
color: #e0e6ed;
border-radius: 6px;
font-size: 14px;
font-family: 'Courier New', monospace;
}
select:focus, input:focus, textarea:focus {
outline: none;
border-color: #4a9eff;
}
textarea {
resize: vertical;
min-height: 100px;
}
.headers-input {
font-size: 13px;
}
button {
padding: 12px 30px;
background: linear-gradient(135deg, #1e40af 0%, #3b82f6 100%);
color: white;
border: none;
border-radius: 6px;
cursor: pointer;
font-size: 16px;
font-weight: 500;
width: 100%;
}
button:hover {
background: linear-gradient(135deg, #2563eb 0%, #60a5fa 100%);
}
button:disabled {
opacity: 0.5;
cursor: not-allowed;
}
.response-section {
background: #1e2835;
padding: 25px;
border-radius: 10px;
border: 2px solid #2c3e50;
display: none;
}
.response-section.visible {
display: block;
}
.response-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 15px;
padding-bottom: 15px;
border-bottom: 1px solid #2c3e50;
}
.status {
font-size: 18px;
font-weight: 600;
}
.status.success { color: #22c55e; }
.status.error { color: #ef4444; }
.duration {
color: #8b9bb0;
font-size: 14px;
}
.response-body, .response-headers {
background: #252f3f;
padding: 15px;
border-radius: 6px;
margin-top: 15px;
overflow-x: auto;
}
.response-body pre, .response-headers pre {
margin: 0;
color: #e0e6ed;
font-family: 'Courier New', monospace;
font-size: 13px;
line-height: 1.5;
}
.section-title {
color: #4a9eff;
font-size: 14px;
font-weight: 600;
margin-bottom: 10px;
}
.error-message {
background: rgba(239, 68, 68, 0.1);
border: 1px solid #ef4444;
color: #ef4444;
padding: 15px;
border-radius: 6px;
margin-top: 15px;
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>REST API Client</h1>
<div class="user-info">
<span class="username">👤 {{.UserID}}</span>
<a href="/logout" class="logout-btn">Logout</a>
</div>
</div>
<div class="request-form">
<div class="form-row">
<div class="form-group" style="flex: 0 0 120px;">
<label>Method</label>
<select id="method">
<option>GET</option>
<option>POST</option>
<option>PUT</option>
<option>PATCH</option>
<option>DELETE</option>
</select>
</div>
<div class="form-group">
<label>URL</label>
<input type="text" id="url" placeholder="https://api.example.com/endpoint">
</div>
</div>
<div class="form-group">
<label>Headers (JSON format)</label>
<textarea id="headers" class="headers-input" placeholder='{"Content-Type": "application/json", "Authorization": "Bearer token"}'></textarea>
</div>
<div class="form-group">
<label>Request Body</label>
<textarea id="body" placeholder='{"key": "value"}'></textarea>
</div>
<button onclick="sendRequest()" id="sendBtn">Send Request</button>
</div>
<div class="response-section" id="responseSection">
<div class="response-header">
<span class="status" id="status"></span>
<span class="duration" id="duration"></span>
</div>
<div id="errorMessage" class="error-message" style="display: none;"></div>
<div id="responseHeaders">
<div class="section-title">Response Headers</div>
<div class="response-headers">
<pre id="headersContent"></pre>
</div>
</div>
<div class="response-body">
<div class="section-title">Response Body</div>
<pre id="bodyContent"></pre>
</div>
</div>
</div>
<script>
async function sendRequest() {
const method = document.getElementById('method').value;
const url = document.getElementById('url').value;
const headersText = document.getElementById('headers').value;
const body = document.getElementById('body').value;
const sendBtn = document.getElementById('sendBtn');
const responseSection = document.getElementById('responseSection');
const errorMessage = document.getElementById('errorMessage');
if (!url) {
alert('Please enter a URL');
return;
}
let headers = {};
if (headersText.trim()) {
try {
headers = JSON.parse(headersText);
} catch (e) {
alert('Invalid JSON in headers');
return;
}
}
sendBtn.disabled = true;
sendBtn.textContent = 'Sending...';
responseSection.classList.remove('visible');
errorMessage.style.display = 'none';
try {
const response = await fetch('/api/request', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ method, url, headers, body })
});
const result = await response.json();
responseSection.classList.add('visible');
if (result.error) {
document.getElementById('status').textContent = 'Error';
document.getElementById('status').className = 'status error';
errorMessage.textContent = result.error;
errorMessage.style.display = 'block';
document.getElementById('responseHeaders').style.display = 'none';
document.getElementById('bodyContent').textContent = '';
} else {
document.getElementById('status').textContent = 'Status: ' + result.status;
document.getElementById('status').className = result.status < 400 ? 'status success' : 'status error';
document.getElementById('responseHeaders').style.display = 'block';
const formattedHeaders = Object.entries(result.headers)
.map(([key, value]) => key + ': ' + value)
.join('\n');
document.getElementById('headersContent').textContent = formattedHeaders;
try {
const parsed = JSON.parse(result.body);
document.getElementById('bodyContent').textContent = JSON.stringify(parsed, null, 2);
} catch {
document.getElementById('bodyContent').textContent = result.body;
}
}
document.getElementById('duration').textContent = result.duration + 'ms';
} catch (error) {
responseSection.classList.add('visible');
document.getElementById('status').textContent = 'Request Failed';
document.getElementById('status').className = 'status error';
errorMessage.textContent = error.message;
errorMessage.style.display = 'block';
document.getElementById('responseHeaders').style.display = 'none';
} finally {
sendBtn.disabled = false;
sendBtn.textContent = 'Send Request';
}
}
// Allow Enter key in textareas
document.getElementById('url').addEventListener('keypress', function(e) {
if (e.key === 'Enter') sendRequest();
});
</script>
</body>
</html>`

7
output_service/README.md Normal file
View File

@@ -0,0 +1,7 @@
# output service
Service to receive output from ping_service instances.
Builds database of mappable nodes.
Updates input services address lists with all working endpoints and working hops from the traces.
Have reporting api endpoints for the manager to monitor the progress.