diff --git a/input_service/http_input_service.go b/input_service/http_input_service.go index 16ec888..c848d7b 100644 --- a/input_service/http_input_service.go +++ b/input_service/http_input_service.go @@ -3,34 +3,41 @@ package main import ( "bufio" "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "fmt" + "io" "log" "math/rand" "net" "net/http" + "net/netip" "os" "os/signal" "path/filepath" "strings" "sync" + "sync/atomic" "syscall" "time" ) const ( - repoDir = "cloud-provider-ip-addresses" - port = 8080 - stateDir = "progress_state" - saveInterval = 30 * time.Second + repoDir = "cloud-provider-ip-addresses" + port = 8080 + stateDir = "progress_state" + saveInterval = 30 * time.Second + cleanupInterval = 5 * time.Minute + generatorTTL = 24 * time.Hour + maxImportSize = 10 * 1024 * 1024 // 10MB ) // GeneratorState represents the serializable state of a generator type GeneratorState struct { - CurrentFile int `json:"current_file"` - CurrentCIDRs []string `json:"current_cidrs"` - ActiveGenStates []HostGenState `json:"active_gen_states"` - CIDRFiles []string `json:"cidr_files"` + RemainingCIDRs []string `json:"remaining_cidrs"` + CurrentGen *HostGenState `json:"current_gen,omitempty"` + TotalCIDRs int `json:"total_cidrs"` } type HostGenState struct { @@ -41,67 +48,72 @@ type HostGenState struct { // IPGenerator generates IPs from CIDR ranges lazily type IPGenerator struct { - mu sync.Mutex - cidrFiles []string - currentFile int - currentCIDRs []string - activeGens []*hostGenerator - rng *rand.Rand - totalCIDRsCount int - consumer string - dirty bool + mu sync.Mutex + rng *rand.Rand + totalCIDRsCount int + remainingCIDRs []string + currentGen *hostGenerator + consumer string + dirty atomic.Bool } type hostGenerator struct { - cidr string - network *net.IPNet - current net.IP + prefix netip.Prefix + current netip.Addr + last netip.Addr done bool } +func addrToUint32(a netip.Addr) uint32 { + b := a.As4() + return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]) +} + +func uint32ToAddr(u uint32) netip.Addr { + return netip.AddrFrom4([4]byte{byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)}) +} + func newHostGenerator(cidr string) (*hostGenerator, error) { - _, network, err := net.ParseCIDR(cidr) + prefix, err := netip.ParsePrefix(cidr) if err != nil { return nil, err } - - // Only IPv4 - if network.IP.To4() == nil { - return nil, fmt.Errorf("not IPv4") + prefix = prefix.Masked() + if !prefix.IsValid() || !prefix.Addr().Is4() { + return nil, fmt.Errorf("invalid IPv4 prefix") } - - // Check if multicast - if network.IP.IsMulticast() { + if prefix.Addr().IsMulticast() { return nil, fmt.Errorf("multicast network") } - ones, bits := network.Mask.Size() - - hg := &hostGenerator{ - cidr: cidr, - network: network, - current: make(net.IP, len(network.IP)), - } - copy(hg.current, network.IP) + ip := prefix.Addr() + maskLen := prefix.Bits() - // For /32, just use the single address - if ones == bits { - return hg, nil + var first, last netip.Addr + lastUint := addrToUint32(ip) | ((1 << (32 - uint(maskLen))) - 1) + last = uint32ToAddr(lastUint) + + if maskLen == 32 { + first = ip + last = ip + } else if maskLen == 31 { + first = ip + // last already ip + 1 + } else { + first = ip.Next() + last = last.Prev() } - // For other networks, skip network address (start at .1) - hg.increment() - - return hg, nil -} - -func (hg *hostGenerator) increment() { - for i := len(hg.current) - 1; i >= 0; i-- { - hg.current[i]++ - if hg.current[i] != 0 { - break - } + if !prefix.Contains(first) || !prefix.Contains(last) { + return nil, fmt.Errorf("invalid range") } + + return &hostGenerator{ + prefix: prefix, + current: first, + last: last, + done: false, + }, nil } func (hg *hostGenerator) next() (string, bool) { @@ -109,58 +121,31 @@ func (hg *hostGenerator) next() (string, bool) { return "", false } - ones, bits := hg.network.Mask.Size() - - // Handle /32 specially - if ones == bits { - if !hg.current.Equal(hg.network.IP) { - hg.done = true - return "", false - } - ip := hg.current.String() - hg.done = true - return ip, true - } - - // Check if we're still in the network - if !hg.network.Contains(hg.current) { + if !hg.prefix.Contains(hg.current) || addrToUint32(hg.current) > addrToUint32(hg.last) { hg.done = true return "", false } - // Check if this is the broadcast address (last IP in range) - broadcast := make(net.IP, len(hg.network.IP)) - copy(broadcast, hg.network.IP) - for i := range broadcast { - broadcast[i] |= ^hg.network.Mask[i] - } - - if hg.current.Equal(broadcast) { - hg.done = true - return "", false - } - - // Skip multicast addresses if hg.current.IsMulticast() { - hg.increment() + hg.current = hg.current.Next() return hg.next() } ip := hg.current.String() - hg.increment() - + hg.current = hg.current.Next() + return ip, true } func (hg *hostGenerator) getState() HostGenState { return HostGenState{ - CIDR: hg.cidr, + CIDR: hg.prefix.String(), Current: hg.current.String(), Done: hg.done, } } -func newIPGenerator(consumer string) (*IPGenerator, error) { +func newIPGenerator(s *Server, consumer string) (*IPGenerator, error) { gen := &IPGenerator{ rng: rand.New(rand.NewSource(time.Now().UnixNano())), consumer: consumer, @@ -173,207 +158,127 @@ func newIPGenerator(consumer string) (*IPGenerator, error) { } // No saved state, initialize fresh - // Find all IP files - err := filepath.Walk(repoDir, func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - if !info.IsDir() && strings.HasSuffix(path, ".txt") && strings.Contains(strings.ToLower(path), "ips") { - gen.cidrFiles = append(gen.cidrFiles, path) - } - return nil + gen.remainingCIDRs = append([]string{}, s.allCIDRs...) + gen.rng.Shuffle(len(gen.remainingCIDRs), func(i, j int) { + gen.remainingCIDRs[i], gen.remainingCIDRs[j] = gen.remainingCIDRs[j], gen.remainingCIDRs[i] }) + gen.totalCIDRsCount = len(gen.remainingCIDRs) + gen.dirty.Store(true) - if err != nil { - return nil, fmt.Errorf("failed to scan repo directory: %w", err) - } - - if len(gen.cidrFiles) == 0 { - return nil, fmt.Errorf("no IP files found in %s", repoDir) - } - - // Load first batch of CIDRs - if err := gen.loadNextFile(); err != nil { - return nil, err - } - - log.Printf("šŸ†• New generator for %s: %d IP files, %d CIDRs", consumer, len(gen.cidrFiles), gen.totalCIDRsCount) - log.Printf("šŸ“ Found %d IP files", len(gen.cidrFiles)) - log.Printf("šŸ“Š Total CIDRs discovered: %d", gen.totalCIDRsCount) + log.Printf("šŸ†• New generator for %s: %d total CIDRs", consumer, gen.totalCIDRsCount) return gen, nil } -func (g *IPGenerator) loadNextFile() error { - if g.currentFile >= len(g.cidrFiles) { - // Wrap around and reshuffle - g.currentFile = 0 - g.rng.Shuffle(len(g.cidrFiles), func(i, j int) { - g.cidrFiles[i], g.cidrFiles[j] = g.cidrFiles[j], g.cidrFiles[i] - }) - } - - filepath := g.cidrFiles[g.currentFile] - g.currentFile++ - - file, err := os.Open(filepath) - if err != nil { - return fmt.Errorf("failed to open %s: %w", filepath, err) - } - defer file.Close() - - g.currentCIDRs = g.currentCIDRs[:0] // Clear but keep capacity - scanner := bufio.NewScanner(file) - - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" || strings.HasPrefix(line, "#") { - continue - } - - fields := strings.Fields(line) - for _, field := range fields { - if field != "" { - // Basic validation - if strings.Contains(field, "/") || net.ParseIP(field) != nil { - g.currentCIDRs = append(g.currentCIDRs, field) - g.totalCIDRsCount++ - } - } - } - } - - if err := scanner.Err(); err != nil { - return fmt.Errorf("error reading %s: %w", filepath, err) - } - - // Shuffle CIDRs from this file - g.rng.Shuffle(len(g.currentCIDRs), func(i, j int) { - g.currentCIDRs[i], g.currentCIDRs[j] = g.currentCIDRs[j], g.currentCIDRs[i] - }) - - // Initialize generators for this batch - g.activeGens = make([]*hostGenerator, 0, len(g.currentCIDRs)) - for _, cidr := range g.currentCIDRs { - // Ensure it has CIDR notation - if !strings.Contains(cidr, "/") { - cidr = cidr + "/32" - } - - gen, err := newHostGenerator(cidr) - if err != nil { - // Skip invalid CIDRs silently - continue - } - g.activeGens = append(g.activeGens, gen) - } - - g.dirty = true - return nil -} - func (g *IPGenerator) Next() (string, error) { g.mu.Lock() defer g.mu.Unlock() for { - // If no active generators, load next file - if len(g.activeGens) == 0 { - if err := g.loadNextFile(); err != nil { - return "", fmt.Errorf("failed to load next file: %w", err) - } - if len(g.activeGens) == 0 { + if g.currentGen == nil || g.currentGen.done { + if len(g.remainingCIDRs) == 0 { return "", fmt.Errorf("no more IPs available") } + + cidr := g.remainingCIDRs[0] + g.remainingCIDRs = g.remainingCIDRs[1:] + + if !strings.Contains(cidr, "/") { + cidr += "/32" + } + + var err error + g.currentGen, err = newHostGenerator(cidr) + if err != nil { + g.dirty.Store(true) + continue + } } - // Pick a random generator - idx := g.rng.Intn(len(g.activeGens)) - gen := g.activeGens[idx] - - ip, ok := gen.next() + ip, ok := g.currentGen.next() if !ok { - // This generator is exhausted, remove it - g.activeGens = append(g.activeGens[:idx], g.activeGens[idx+1:]...) - g.dirty = true + g.currentGen = nil + g.dirty.Store(true) continue } - g.dirty = true + g.dirty.Store(true) return ip, nil } } +func (g *IPGenerator) buildState() GeneratorState { + // Assumes mu is held + state := GeneratorState{ + RemainingCIDRs: append([]string{}, g.remainingCIDRs...), + TotalCIDRs: g.totalCIDRsCount, + } + if g.currentGen != nil && !g.currentGen.done { + state.CurrentGen = &HostGenState{ + CIDR: g.currentGen.prefix.String(), + Current: g.currentGen.current.String(), + Done: false, + } + } + return state +} + func (g *IPGenerator) getState() GeneratorState { g.mu.Lock() defer g.mu.Unlock() - - activeStates := make([]HostGenState, len(g.activeGens)) - for i, gen := range g.activeGens { - activeStates[i] = gen.getState() - } - - return GeneratorState{ - CurrentFile: g.currentFile, - CurrentCIDRs: g.currentCIDRs, - ActiveGenStates: activeStates, - CIDRFiles: g.cidrFiles, - } + return g.buildState() } func (g *IPGenerator) saveState() error { - if !g.dirty { + g.mu.Lock() + if !g.dirty.Load() { + g.mu.Unlock() return nil } - - state := g.getState() + state := g.buildState() + g.dirty.Store(false) + g.mu.Unlock() // Ensure state directory exists if err := os.MkdirAll(stateDir, 0755); err != nil { return fmt.Errorf("failed to create state directory: %w", err) } - // Use consumer as filename (sanitize for filesystem) - filename := strings.ReplaceAll(g.consumer, ":", "_") - filename = strings.ReplaceAll(filename, "/", "_") - filepath := filepath.Join(stateDir, filename+".json") + // Use hash of consumer as filename + hash := sha256.Sum256([]byte(g.consumer)) + filename := hex.EncodeToString(hash[:]) + filePath := filepath.Join(stateDir, filename+".json") - // Write to temp file first, then rename for atomic write - tempPath := filepath + ".tmp" + // Write to temp file first, then rename + tempPath := filePath + ".tmp" file, err := os.Create(tempPath) if err != nil { return fmt.Errorf("failed to create temp state file: %w", err) } + defer file.Close() encoder := json.NewEncoder(file) encoder.SetIndent("", " ") if err := encoder.Encode(state); err != nil { - file.Close() os.Remove(tempPath) return fmt.Errorf("failed to encode state: %w", err) } - if err := file.Close(); err != nil { - os.Remove(tempPath) - return fmt.Errorf("failed to close temp state file: %w", err) - } - - if err := os.Rename(tempPath, filepath); err != nil { + if err := os.Rename(tempPath, filePath); err != nil { os.Remove(tempPath) return fmt.Errorf("failed to rename state file: %w", err) } - g.dirty = false return nil } func (g *IPGenerator) loadState() error { - // Use consumer as filename (sanitize for filesystem) - filename := strings.ReplaceAll(g.consumer, ":", "_") - filename = strings.ReplaceAll(filename, "/", "_") - filepath := filepath.Join(stateDir, filename+".json") + // Use hash of consumer as filename + hash := sha256.Sum256([]byte(g.consumer)) + filename := hex.EncodeToString(hash[:]) + filePath := filepath.Join(stateDir, filename+".json") - file, err := os.Open(filepath) + file, err := os.Open(filePath) if err != nil { return err } @@ -385,27 +290,20 @@ func (g *IPGenerator) loadState() error { } // Restore state - g.cidrFiles = state.CIDRFiles - g.currentFile = state.CurrentFile - g.currentCIDRs = state.CurrentCIDRs - g.totalCIDRsCount = len(state.CurrentCIDRs) + g.remainingCIDRs = state.RemainingCIDRs + g.totalCIDRsCount = state.TotalCIDRs - // Rebuild active generators from state - g.activeGens = make([]*hostGenerator, 0, len(state.ActiveGenStates)) - for _, genState := range state.ActiveGenStates { - gen, err := newHostGenerator(genState.CIDR) + if state.CurrentGen != nil { + gen, err := newHostGenerator(state.CurrentGen.CIDR) if err != nil { - continue + return err } - - // Restore current IP position - gen.current = net.ParseIP(genState.Current) - if gen.current == nil { - continue + gen.current, err = netip.ParseAddr(state.CurrentGen.Current) + if err != nil { + return err } - gen.done = genState.Done - - g.activeGens = append(g.activeGens, gen) + gen.done = state.CurrentGen.Done + g.currentGen = gen } return nil @@ -414,21 +312,86 @@ func (g *IPGenerator) loadState() error { // Server holds per-consumer generators type Server struct { generators map[string]*IPGenerator + lastAccess map[string]time.Time + allCIDRs []string mu sync.RWMutex stopSaver chan struct{} + stopCleanup chan struct{} wg sync.WaitGroup } func newServer() *Server { s := &Server{ - generators: make(map[string]*IPGenerator), - stopSaver: make(chan struct{}), + generators: make(map[string]*IPGenerator), + lastAccess: make(map[string]time.Time), + stopSaver: make(chan struct{}), + stopCleanup: make(chan struct{}), } - + + if err := s.loadAllCIDRs(); err != nil { + log.Fatalf("āŒ Failed to load CIDRs: %v", err) + } + s.startPeriodicSaver() + s.startCleanup() + return s } +func (s *Server) loadAllCIDRs() error { + // Find all IP files + var fileList []string + err := filepath.Walk(repoDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if !info.IsDir() && strings.HasSuffix(path, ".txt") && strings.Contains(strings.ToLower(path), "ips") { + fileList = append(fileList, path) + } + return nil + }) + if err != nil { + return fmt.Errorf("failed to scan repo directory: %w", err) + } + + if len(fileList) == 0 { + return fmt.Errorf("no IP files found in %s", repoDir) + } + + // Load all CIDRs + for _, path := range fileList { + file, err := os.Open(path) + if err != nil { + log.Printf("āš ļø Failed to open %s: %v", path, err) + continue + } + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + fields := strings.Fields(line) + for _, field := range fields { + if field != "" { + if strings.Contains(field, "/") || netip.MustParseAddr(field).IsValid() { + s.allCIDRs = append(s.allCIDRs, field) + } + } + } + } + file.Close() + if err := scanner.Err(); err != nil { + log.Printf("āš ļø Error reading %s: %v", path, err) + } + } + + log.Printf("šŸ“ Found %d IP files", len(fileList)) + log.Printf("šŸ“Š Total CIDRs discovered: %d", len(s.allCIDRs)) + + return nil +} + func (s *Server) startPeriodicSaver() { s.wg.Add(1) go func() { @@ -441,7 +404,6 @@ func (s *Server) startPeriodicSaver() { case <-ticker.C: s.saveAllStates() case <-s.stopSaver: - // Final save before shutdown s.saveAllStates() return } @@ -449,23 +411,63 @@ func (s *Server) startPeriodicSaver() { }() } +func (s *Server) startCleanup() { + s.wg.Add(1) + go func() { + defer s.wg.Done() + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.cleanupOldGenerators() + case <-s.stopCleanup: + return + } + } + }() +} + +func (s *Server) cleanupOldGenerators() { + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now() + for consumer, t := range s.lastAccess { + if now.Sub(t) > generatorTTL { + delete(s.generators, consumer) + delete(s.lastAccess, consumer) + + // Remove state file + hash := sha256.Sum256([]byte(consumer)) + fn := hex.EncodeToString(hash[:]) + ".json" + p := filepath.Join(stateDir, fn) + if err := os.Remove(p); err != nil && !os.IsNotExist(err) { + log.Printf("āš ļø Failed to remove state file %s: %v", p, err) + } + } + } +} + func (s *Server) saveAllStates() { s.mu.RLock() - generators := make([]*IPGenerator, 0, len(s.generators)) + gens := make([]*IPGenerator, 0, len(s.generators)) for _, gen := range s.generators { - generators = append(generators, gen) + gens = append(gens, gen) } s.mu.RUnlock() - for _, gen := range generators { + for _, gen := range gens { if err := gen.saveState(); err != nil { - log.Printf("āš ļø Failed to save state for %s: %v", gen.consumer, err) + log.Printf("āš ļø Failed to save state for %s: %v", gen.consumer, err) } } } func (s *Server) shutdown() { close(s.stopSaver) + close(s.stopCleanup) s.wg.Wait() log.Println("šŸ’¾ All states saved") } @@ -476,26 +478,30 @@ func (s *Server) getGenerator(consumer string) (*IPGenerator, error) { s.mu.RUnlock() if exists { + s.mu.Lock() + s.lastAccess[consumer] = time.Now() + s.mu.Unlock() return gen, nil } - // Create new generator for this consumer s.mu.Lock() defer s.mu.Unlock() - // Double-check after acquiring write lock + // Double-check if gen, exists := s.generators[consumer]; exists { + s.lastAccess[consumer] = time.Now() return gen, nil } - newGen, err := newIPGenerator(consumer) + newGen, err := newIPGenerator(s, consumer) if err != nil { return nil, err } s.generators[consumer] = newGen + s.lastAccess[consumer] = time.Now() log.Printf("šŸ†• New consumer: %s", consumer) - + return newGen, nil } @@ -505,38 +511,42 @@ func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) { return } - consumer := r.RemoteAddr - if host, _, err := net.SplitHostPort(consumer); err == nil { - consumer = host + addrPort, err := netip.ParseAddrPort(r.RemoteAddr) + consumerStr := r.RemoteAddr + if err == nil { + consumerStr = addrPort.Addr().String() + } else { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err == nil { + consumerStr = host + } } - gen, err := s.getGenerator(consumer) + gen, err := s.getGenerator(consumerStr) if err != nil { - log.Printf("āŒ Failed to get generator for %s: %v", consumer, err) + log.Printf("āŒ Failed to get generator for %s: %v", consumerStr, err) http.Error(w, "Internal server error", http.StatusInternalServerError) return } ip, err := gen.Next() if err != nil { - log.Printf("āŒ Failed to get IP for %s: %v", consumer, err) + log.Printf("āŒ Failed to get IP for %s: %v", consumerStr, err) http.Error(w, "No more IPs available", http.StatusServiceUnavailable) return } - log.Printf("šŸ“¤ Serving IP to %s: %s", consumer, ip) + log.Printf("šŸ“¤ Serving IP to %s: %s", consumerStr, ip) w.Header().Set("Content-Type", "text/plain") fmt.Fprintf(w, "%s\n", ip) } type ConsumerStatus struct { - Consumer string `json:"consumer"` - CurrentFile int `json:"current_file"` - TotalFiles int `json:"total_files"` - ActiveCIDRs int `json:"active_cidrs"` - TotalCIDRs int `json:"total_cidrs_discovered"` - CurrentFilePath string `json:"current_file_path,omitempty"` + Consumer string `json:"consumer"` + RemainingCIDRs int `json:"remaining_cidrs"` + HasActiveGen bool `json:"has_active_gen"` + TotalCIDRs int `json:"total_cidrs"` } type StatusResponse struct { @@ -565,14 +575,10 @@ func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) { for consumer, gen := range s.generators { gen.mu.Lock() status := ConsumerStatus{ - Consumer: consumer, - CurrentFile: gen.currentFile, - TotalFiles: len(gen.cidrFiles), - ActiveCIDRs: len(gen.activeGens), - TotalCIDRs: gen.totalCIDRsCount, - } - if gen.currentFile > 0 && gen.currentFile <= len(gen.cidrFiles) { - status.CurrentFilePath = gen.cidrFiles[gen.currentFile-1] + Consumer: consumer, + RemainingCIDRs: len(gen.remainingCIDRs), + HasActiveGen: gen.currentGen != nil, + TotalCIDRs: gen.totalCIDRsCount, } gen.mu.Unlock() @@ -588,8 +594,8 @@ func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) { } type ExportResponse struct { - ExportedAt time.Time `json:"exported_at"` - States map[string]GeneratorState `json:"states"` + ExportedAt time.Time `json:"exported_at"` + States map[string]GeneratorState `json:"states"` } func (s *Server) handleExport(w http.ResponseWriter, r *http.Request) { @@ -606,7 +612,7 @@ func (s *Server) handleExport(w http.ResponseWriter, r *http.Request) { response := ExportResponse{ ExportedAt: time.Now(), - States: make(map[string]GeneratorState), + States: make(map[string]GeneratorState, len(s.generators)), } for consumer, gen := range s.generators { @@ -614,9 +620,9 @@ func (s *Server) handleExport(w http.ResponseWriter, r *http.Request) { } w.Header().Set("Content-Type", "application/json") - w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=state-export-%s.json", + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=state-export-%s.json", time.Now().Format("2006-01-02-150405"))) - + encoder := json.NewEncoder(w) encoder.SetIndent("", " ") if err := encoder.Encode(response); err != nil { @@ -630,9 +636,16 @@ func (s *Server) handleImport(w http.ResponseWriter, r *http.Request) { return } + reader := http.MaxBytesReader(w, r.Body, maxImportSize) + defer r.Body.Close() + var exportData ExportResponse - if err := json.NewDecoder(r.Body).Decode(&exportData); err != nil { - http.Error(w, fmt.Sprintf("Failed to decode import data: %v", err), http.StatusBadRequest) + if err := json.NewDecoder(reader).Decode(&exportData); err != nil { + if err == io.EOF || strings.Contains(err.Error(), "EOF") { + http.Error(w, "Invalid or empty request body", http.StatusBadRequest) + } else { + http.Error(w, fmt.Sprintf("Failed to decode import data: %v", err), http.StatusBadRequest) + } return } @@ -646,14 +659,14 @@ func (s *Server) handleImport(w http.ResponseWriter, r *http.Request) { failed := 0 for consumer, state := range exportData.States { - // Sanitize consumer name for filename - filename := strings.ReplaceAll(consumer, ":", "_") - filename = strings.ReplaceAll(filename, "/", "_") - filepath := filepath.Join(stateDir, filename+".json") + // Use hash for filename + hash := sha256.Sum256([]byte(consumer)) + filename := hex.EncodeToString(hash[:]) + filePath := filepath.Join(stateDir, filename+".json") - file, err := os.Create(filepath) + file, err := os.Create(filePath) if err != nil { - log.Printf("āš ļø Failed to create state file for %s: %v", consumer, err) + log.Printf("āš ļø Failed to create state file for %s: %v", consumer, err) failed++ continue } @@ -662,8 +675,9 @@ func (s *Server) handleImport(w http.ResponseWriter, r *http.Request) { encoder.SetIndent("", " ") if err := encoder.Encode(state); err != nil { file.Close() - log.Printf("āš ļø Failed to encode state for %s: %v", consumer, err) + log.Printf("āš ļø Failed to encode state for %s: %v", consumer, err) failed++ + os.Remove(filePath) continue } @@ -680,7 +694,7 @@ func (s *Server) handleImport(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(response) - + log.Printf("šŸ“„ Imported %d consumer states (%d failed)", imported, failed) } @@ -717,7 +731,7 @@ func main() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - // Stop periodic saver and save final state + // Stop savers and save final state server.shutdown() if err := httpServer.Shutdown(ctx); err != nil { @@ -739,4 +753,4 @@ func main() { } log.Println("āœ… Server stopped cleanly") -} +} \ No newline at end of file diff --git a/manager/cert.go b/manager/cert.go new file mode 100644 index 0000000..4274eda --- /dev/null +++ b/manager/cert.go @@ -0,0 +1,67 @@ +package main + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "os" + "time" +) + +func CheckAndGenerateCerts() (string, string, error) { + certFile := "cert.pem" + keyFile := "key.pem" + + // If files already exist, just use them + if _, err := os.Stat(certFile); err == nil { + return certFile, keyFile, nil + } + + logger.Info("Generating self-signed TLS certificates...") + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return "", "", err + } + + notBefore := time.Now() + notAfter := notBefore.Add(365 * 24 * time.Hour) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, _ := rand.Int(rand.Reader, serialNumberLimit) + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"TwoStepAuth Dev"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"localhost"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return "", "", err + } + + certOut, _ := os.Create(certFile) + pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + certOut.Close() + + keyOut, _ := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + privBytes, _ := x509.MarshalECPrivateKey(priv) + pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes}) + keyOut.Close() + + return certFile, keyFile, nil +} diff --git a/manager/crypto.go b/manager/crypto.go new file mode 100644 index 0000000..80804d3 --- /dev/null +++ b/manager/crypto.go @@ -0,0 +1,109 @@ +package main + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "errors" + "io" + + "golang.org/x/crypto/pbkdf2" +) + +type Crypto struct { + serverKey []byte +} + +func NewCrypto(serverKeyStr string) (*Crypto, error) { + // Decode the base64 server key + serverKey, err := base64.StdEncoding.DecodeString(serverKeyStr) + if err != nil { + return nil, err + } + + if len(serverKey) != 32 { + return nil, errors.New("invalid server key length") + } + + logger.Info("Crypto initialized with server key") + return &Crypto{serverKey: serverKey}, nil +} + +func GenerateServerKey() (string, error) { + key := make([]byte, 32) + if _, err := rand.Read(key); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(key), nil +} + +func (c *Crypto) deriveUserKey(userID string) []byte { + // Derive a 32-byte key from user ID using PBKDF2 + // Using server key as salt for additional security + return pbkdf2.Key([]byte(userID), c.serverKey, 100000, 32, sha256.New) +} + +func (c *Crypto) encrypt(data []byte, key []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + + ciphertext := gcm.Seal(nonce, nonce, data, nil) + return ciphertext, nil +} + +func (c *Crypto) decrypt(data []byte, key []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + nonceSize := gcm.NonceSize() + if len(data) < nonceSize { + return nil, errors.New("ciphertext too short") + } + + nonce, ciphertext := data[:nonceSize], data[nonceSize:] + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, err + } + + return plaintext, nil +} + +func (c *Crypto) EncryptWithServerKey(data []byte) ([]byte, error) { + return c.encrypt(data, c.serverKey) +} + +func (c *Crypto) DecryptWithServerKey(data []byte) ([]byte, error) { + return c.decrypt(data, c.serverKey) +} + +func (c *Crypto) EncryptWithUserKey(data []byte, userID string) ([]byte, error) { + userKey := c.deriveUserKey(userID) + return c.encrypt(data, userKey) +} + +func (c *Crypto) DecryptWithUserKey(data []byte, userID string) ([]byte, error) { + userKey := c.deriveUserKey(userID) + return c.decrypt(data, userKey) +} diff --git a/manager/dyfi.go b/manager/dyfi.go new file mode 100644 index 0000000..09d1acd --- /dev/null +++ b/manager/dyfi.go @@ -0,0 +1,42 @@ +package main + +import ( + "fmt" + "net/http" + "time" +) + +func startDyfiUpdater(hostname, username, password string) { + if hostname == "" || username == "" || password == "" { + return + } + + logger.Info("Starting dy.fi updater for %s", hostname) + + update := func() { + url := fmt.Sprintf("https://www.dy.fi/nic/update?hostname=%s", hostname) + req, _ := http.NewRequest("GET", url, nil) + req.SetBasicAuth(username, password) + req.Header.Set("User-Agent", "Go-TwoStepAuth-Client/1.0") + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + logger.Error("dy.fi update failed: %v", err) + return + } + defer resp.Body.Close() + logger.Info("dy.fi update status: %s", resp.Status) + } + + // Update immediately on start + update() + + // Update every 7 days (dy.fi requires update at least every 30 days) + go func() { + ticker := time.NewTicker(7 * 24 * time.Hour) + for range ticker.C { + update() + } + }() +} diff --git a/manager/go.mod b/manager/go.mod new file mode 100644 index 0000000..6febe7b --- /dev/null +++ b/manager/go.mod @@ -0,0 +1,15 @@ +module manager + +go 1.25.0 + +require ( + github.com/pquerna/otp v1.5.0 + github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e + golang.org/x/crypto v0.46.0 +) + +require ( + github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect + golang.org/x/net v0.47.0 // indirect + golang.org/x/text v0.32.0 // indirect +) diff --git a/manager/gr.go b/manager/gr.go new file mode 100644 index 0000000..99bfee8 --- /dev/null +++ b/manager/gr.go @@ -0,0 +1,17 @@ +package main + +import ( + "fmt" + "github.com/skip2/go-qrcode" +) + +func PrintQRCode(content string) { + qr, err := qrcode.New(content, qrcode.Medium) + if err != nil { + logger.Error("Failed to generate QR code: %v", err) + return + } + + // Generate QR code as string (ASCII art) + fmt.Println(qr.ToSmallString(false)) +} diff --git a/manager/logger.go b/manager/logger.go new file mode 100644 index 0000000..6c10d89 --- /dev/null +++ b/manager/logger.go @@ -0,0 +1,33 @@ +package main + +import ( + "fmt" + "log" + "os" +) + +type Logger struct { + infoLog *log.Logger + warnLog *log.Logger + errorLog *log.Logger +} + +func NewLogger() *Logger { + return &Logger{ + infoLog: log.New(os.Stdout, "INFO ", log.Ldate|log.Ltime|log.Lshortfile), + warnLog: log.New(os.Stdout, "WARN ", log.Ldate|log.Ltime|log.Lshortfile), + errorLog: log.New(os.Stderr, "ERROR ", log.Ldate|log.Ltime|log.Lshortfile), + } +} + +func (l *Logger) Info(format string, v ...interface{}) { + l.infoLog.Output(2, fmt.Sprintf(format, v...)) +} + +func (l *Logger) Warn(format string, v ...interface{}) { + l.warnLog.Output(2, fmt.Sprintf(format, v...)) +} + +func (l *Logger) Error(format string, v ...interface{}) { + l.errorLog.Output(2, fmt.Sprintf(format, v...)) +} diff --git a/manager/main.go b/manager/main.go new file mode 100644 index 0000000..232d205 --- /dev/null +++ b/manager/main.go @@ -0,0 +1,424 @@ +package main + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/json" + "encoding/pem" + "flag" + "fmt" + "io" + "log" + "math/big" + "net" + "net/http" + "os" + "strings" + "sync" + "time" + + "github.com/pquerna/otp/totp" + "golang.org/x/crypto/acme/autocert" +) + +var ( + store *UserStore + sessions = struct { + sync.RWMutex + m map[string]*Session + }{m: make(map[string]*Session)} + logger *Logger +) + +type Session struct { + UserID string + ExpiresAt time.Time +} + +func main() { + // --- FLAGS --- + addUser := flag.String("add-user", "", "Add a new user (provide user ID)") + port := flag.String("port", os.Getenv("MANAGER_PORT"), "Port to run the server on (use 443 for Let's Encrypt)") + domain := flag.String("domain", os.Getenv("DYFI_DOMAIN"), "Your dy.fi domain (e.g. example.dy.fi)") + dyfiUser := flag.String("dyfi-user", os.Getenv("DYFI_USER"), "dy.fi username (email)") + dyfiPass := flag.String("dyfi-pass", os.Getenv("DYFI_PASS"), "dy.fi password") + email := flag.String("email", os.Getenv("ACME_EMAIL"), "Email for Let's Encrypt notifications") + + flag.Parse() + + logger = NewLogger() + + // --- ENCRYPTION INITIALIZATION --- + serverKey := os.Getenv("SERVER_KEY") + if serverKey == "" { + logger.Warn("SERVER_KEY not set, generating new key") + var err error + serverKey, err = GenerateServerKey() + if err != nil { + logger.Error("Failed to generate server key: %v", err) + log.Fatal(err) + } + fmt.Printf("\nāš ļø IMPORTANT: Save this SERVER_KEY to your environment:\n") + fmt.Printf("export SERVER_KEY=%s\n\n", serverKey) + } + + crypto, err := NewCrypto(serverKey) + if err != nil { + logger.Error("Failed to initialize crypto: %v", err) + log.Fatal(err) + } + + store = NewUserStore("users_data", crypto) + + // --- BACKGROUND TASKS --- + // Reload user store from disk periodically + go func() { + ticker := time.NewTicker(1 * time.Minute) + for range ticker.C { + if err := store.Reload(); err != nil { + logger.Error("Failed to reload user store: %v", err) + } + } + }() + + // Cleanup expired sessions + go func() { + ticker := time.NewTicker(5 * time.Minute) + for range ticker.C { + cleanupSessions() + } + }() + + // dy.fi Dynamic DNS Updater + if *domain != "" && *dyfiUser != "" { + startDyfiUpdater(*domain, *dyfiUser, *dyfiPass) + } + + // --- CLI COMMANDS --- + if *addUser != "" { + handleNewUser(*addUser) + return + } + + // --- TEMPLATE LOADING --- + tmpl, err := LoadTemplate() + if err != nil { + log.Fatal(err) + } + appTmpl, err := LoadAppTemplate() + if err != nil { + log.Fatal(err) + } + + // --- ROUTES --- + // Routes must be defined BEFORE the server starts + + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if session := getValidSession(r, crypto); session != nil { + http.Redirect(w, r, "/app", http.StatusSeeOther) + return + } + tmpl.Execute(w, map[string]interface{}{"Step2": false}) + }) + + http.HandleFunc("/app", func(w http.ResponseWriter, r *http.Request) { + session := getValidSession(r, crypto) + if session == nil { + http.Redirect(w, r, "/", http.StatusSeeOther) + return + } + appTmpl.Execute(w, map[string]interface{}{"UserID": session.UserID}) + }) + + http.HandleFunc("/logout", func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie("auth_session") + if err == nil { + sessions.Lock() + delete(sessions.m, cookie.Value) + sessions.Unlock() + } + http.SetCookie(w, &http.Cookie{ + Name: "auth_session", + Value: "", + Path: "/", + MaxAge: -1, + }) + http.Redirect(w, r, "/", http.StatusSeeOther) + }) + + http.HandleFunc("/api/request", func(w http.ResponseWriter, r *http.Request) { + session := getValidSession(r, crypto) + if session == nil { + w.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(w).Encode(map[string]string{"error": "Unauthorized"}) + return + } + + var req struct { + Method string `json:"method"` + URL string `json:"url"` + Headers map[string]string `json:"headers"` + Body string `json:"body"` + } + + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + result := makeHTTPRequest(req.Method, req.URL, req.Headers, req.Body) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(result) + }) + + http.HandleFunc("/verify-user", func(w http.ResponseWriter, r *http.Request) { + userID := strings.TrimSpace(r.FormValue("userid")) + user, err := store.GetUser(userID) + if err != nil || user == nil { + tmpl.Execute(w, map[string]interface{}{"Step2": false, "Error": "User not found"}) + return + } + + sessionID := fmt.Sprintf("%d", time.Now().UnixNano()) + sessions.Lock() + sessions.m[sessionID] = &Session{ + UserID: userID, + ExpiresAt: time.Now().Add(5 * time.Minute), + } + sessions.Unlock() + + http.SetCookie(w, &http.Cookie{ + Name: "temp_session", + Value: sessionID, + Path: "/", + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteStrictMode, + }) + tmpl.Execute(w, map[string]interface{}{"Step2": true}) + }) + + http.HandleFunc("/verify-totp", func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie("temp_session") + if err != nil { + http.Redirect(w, r, "/", http.StatusSeeOther) + return + } + + sessions.RLock() + session, ok := sessions.m[cookie.Value] + sessions.RUnlock() + + if !ok || time.Now().After(session.ExpiresAt) { + http.Redirect(w, r, "/", http.StatusSeeOther) + return + } + + user, _ := store.GetUser(session.UserID) + totpCode := strings.TrimSpace(r.FormValue("totp")) + + if !totp.Validate(totpCode, user.TOTPSecret) { + tmpl.Execute(w, map[string]interface{}{"Step2": true, "Error": "Invalid TOTP code"}) + return + } + + sessions.Lock() + delete(sessions.m, cookie.Value) + + authSessionID := fmt.Sprintf("%d", time.Now().UnixNano()) + sessions.m[authSessionID] = &Session{ + UserID: session.UserID, + ExpiresAt: time.Now().Add(1 * time.Hour), + } + sessions.Unlock() + + encryptedSession, _ := crypto.EncryptWithServerKey([]byte(authSessionID)) + + http.SetCookie(w, &http.Cookie{ + Name: "auth_session", + Value: base64.StdEncoding.EncodeToString(encryptedSession), + Path: "/", + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteStrictMode, + MaxAge: 3600, + }) + + http.Redirect(w, r, "/app", http.StatusSeeOther) + }) + + // --- SERVER STARTUP --- + + if *domain != "" { + // Let's Encrypt / ACME Setup + certManager := autocert.Manager{ + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostWhitelist(*domain), + Cache: autocert.DirCache("certs_cache"), + Email: *email, + } + + // Let's Encrypt requires port 80 for the HTTP-01 challenge + go func() { + logger.Info("Starting HTTP-01 challenge listener on port 80") + // This handler automatically redirects HTTP to HTTPS while solving challenges + log.Fatal(http.ListenAndServe(":80", certManager.HTTPHandler(nil))) + }() + + server := &http.Server{ + Addr: ":" + *port, + TLSConfig: certManager.TLSConfig(), + } + + logger.Info("Secure Server starting with Let's Encrypt on https://%s", *domain) + log.Fatal(server.ListenAndServeTLS("", "")) // Certs provided by autocert + } else { + // Fallback to Self-Signed Certs + certFile, keyFile, err := setupCerts() + if err != nil { + logger.Error("TLS Setup Error: %v", err) + log.Fatal(err) + } + + server := &http.Server{ + Addr: ":" + *port, + TLSConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + }, + } + + logger.Info("Secure Server starting with self-signed certs on https://localhost:%s", *port) + log.Fatal(server.ListenAndServeTLS(certFile, keyFile)) + } +} + +// setupCerts creates self-signed certs if they don't exist +func setupCerts() (string, string, error) { + certFile, keyFile := "cert.pem", "key.pem" + if _, err := os.Stat(certFile); err == nil { + return certFile, keyFile, nil + } + + logger.Info("Generating self-signed certificates for HTTPS...") + priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + serialNumber, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{Organization: []string{"TwoStepAuth Dev"}}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * 24 * time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + DNSNames: []string{"localhost"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + derBytes, _ := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + + certOut, _ := os.Create(certFile) + pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + certOut.Close() + + keyOut, _ := os.OpenFile(keyFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + privBytes, _ := x509.MarshalECPrivateKey(priv) + pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes}) + keyOut.Close() + + return certFile, keyFile, nil +} + +func handleNewUser(userID string) { + secret, _ := generateSecret() + store.AddUser(userID, secret) + otpauthURL := fmt.Sprintf("otpauth://totp/TwoStepAuth:%s?secret=%s&issuer=TwoStepAuth", userID, secret) + fmt.Printf("\nUser: %s\nSecret: %s\n", userID, secret) + PrintQRCode(otpauthURL) +} + +func getValidSession(r *http.Request, crypto *Crypto) *Session { + cookie, err := r.Cookie("auth_session") + if err != nil { + return nil + } + enc, _ := base64.StdEncoding.DecodeString(cookie.Value) + sid, err := crypto.DecryptWithServerKey(enc) + if err != nil { + return nil + } + sessions.RLock() + session, ok := sessions.m[string(sid)] + sessions.RUnlock() + if !ok || time.Now().After(session.ExpiresAt) { + return nil + } + return session +} + +func cleanupSessions() { + sessions.Lock() + defer sessions.Unlock() + now := time.Now() + for id, s := range sessions.m { + if now.After(s.ExpiresAt) { + delete(sessions.m, id) + } + } +} + +func makeHTTPRequest(method, url string, headers map[string]string, body string) map[string]interface{} { + client := &http.Client{Timeout: 30 * time.Second} + + var reqBody io.Reader + if body != "" { + reqBody = strings.NewReader(body) + } + + req, err := http.NewRequest(method, url, reqBody) + if err != nil { + return map[string]interface{}{ + "error": err.Error(), + } + } + + for key, value := range headers { + req.Header.Set(key, value) + } + + start := time.Now() + resp, err := client.Do(req) + duration := time.Since(start).Milliseconds() + + if err != nil { + return map[string]interface{}{ + "error": err.Error(), + "duration": duration, + } + } + defer resp.Body.Close() + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return map[string]interface{}{ + "error": err.Error(), + "duration": duration, + } + } + + respHeaders := make(map[string]string) + for key, values := range resp.Header { + respHeaders[key] = strings.Join(values, ", ") + } + + return map[string]interface{}{ + "status": resp.StatusCode, + "headers": respHeaders, + "body": string(bodyBytes), + "duration": duration, + } +} diff --git a/manager/store.go b/manager/store.go new file mode 100644 index 0000000..052cd0a --- /dev/null +++ b/manager/store.go @@ -0,0 +1,215 @@ +package main + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base32" + "encoding/hex" + "encoding/json" + "os" + "path/filepath" + "sync" +) + +type User struct { + ID string `json:"id"` + TOTPSecret string `json:"totp_secret"` +} + +type UserStore struct { + mu sync.RWMutex + filePath string + crypto *Crypto + cache map[string]*encryptedUserEntry +} + +type encryptedUserEntry struct { + UserIDHash string `json:"hash"` + Data []byte `json:"data"` +} + +type encryptedStore struct { + Users []encryptedUserEntry `json:"users"` +} + +func NewUserStore(dataDir string, crypto *Crypto) *UserStore { + // Create data directory if it doesn't exist + if err := os.MkdirAll(dataDir, 0700); err != nil { + logger.Error("Failed to create data directory: %v", err) + } + + filePath := filepath.Join(dataDir, "users.enc") + logger.Info("Initialized user store at: %s", filePath) + + store := &UserStore{ + filePath: filePath, + crypto: crypto, + cache: make(map[string]*encryptedUserEntry), + } + + store.loadCache() + return store +} + +func (s *UserStore) hashUserID(userID string) string { + hash := sha256.Sum256([]byte(userID)) + return hex.EncodeToString(hash[:]) +} + +func (s *UserStore) Reload() error { + s.mu.Lock() + defer s.mu.Unlock() + + // Clear existing cache + s.cache = make(map[string]*encryptedUserEntry) + + // Reload from disk + return s.loadCacheInternal() +} + +func (s *UserStore) loadCache() error { + s.mu.Lock() + defer s.mu.Unlock() + return s.loadCacheInternal() +} + +func (s *UserStore) loadCacheInternal() error { + // Read encrypted store file + encryptedData, err := os.ReadFile(s.filePath) + if err != nil { + if os.IsNotExist(err) { + logger.Info("No existing user store found, starting fresh") + return nil + } + logger.Error("Failed to read store file: %v", err) + return err + } + + // Decrypt with server key + decryptedData, err := s.crypto.DecryptWithServerKey(encryptedData) + if err != nil { + logger.Error("Failed to decrypt store: %v", err) + return err + } + + var store encryptedStore + if err := json.Unmarshal(decryptedData, &store); err != nil { + logger.Error("Failed to unmarshal store: %v", err) + return err + } + + // Load into cache + for i := range store.Users { + s.cache[store.Users[i].UserIDHash] = &store.Users[i] + } + + logger.Info("Loaded %d encrypted user entries into cache", len(s.cache)) + return nil +} + +func (s *UserStore) save() error { + // Build store structure from cache + store := encryptedStore{ + Users: make([]encryptedUserEntry, 0, len(s.cache)), + } + + for _, entry := range s.cache { + store.Users = append(store.Users, *entry) + } + + // Marshal to JSON + storeData, err := json.Marshal(store) + if err != nil { + return err + } + + // Encrypt with server key + encryptedData, err := s.crypto.EncryptWithServerKey(storeData) + if err != nil { + logger.Error("Failed to encrypt store: %v", err) + return err + } + + // Write to file + logger.Info("Saving user store with %d entries", len(s.cache)) + if err := os.WriteFile(s.filePath, encryptedData, 0600); err != nil { + logger.Error("Failed to write store file: %v", err) + return err + } + + return nil +} + +func (s *UserStore) GetUser(userID string) (*User, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + userHash := s.hashUserID(userID) + entry, exists := s.cache[userHash] + if !exists { + logger.Warn("User not found in cache") + return nil, nil + } + + // Decrypt with user key (derived from user ID) + userData, err := s.crypto.DecryptWithUserKey(entry.Data, userID) + if err != nil { + logger.Error("Failed to decrypt user data: %v", err) + return nil, err + } + + var user User + if err := json.Unmarshal(userData, &user); err != nil { + logger.Error("Failed to unmarshal user data: %v", err) + return nil, err + } + + logger.Info("Successfully loaded user: %s", user.ID) + return &user, nil +} + +func (s *UserStore) AddUser(userID, totpSecret string) error { + s.mu.Lock() + defer s.mu.Unlock() + + user := &User{ + ID: userID, + TOTPSecret: totpSecret, + } + + // Marshal user data + userData, err := json.Marshal(user) + if err != nil { + return err + } + + // Encrypt with user key (derived from user ID) + userEncrypted, err := s.crypto.EncryptWithUserKey(userData, userID) + if err != nil { + logger.Error("Failed to encrypt with user key: %v", err) + return err + } + + // Add to cache + userHash := s.hashUserID(userID) + s.cache[userHash] = &encryptedUserEntry{ + UserIDHash: userHash, + Data: userEncrypted, + } + + // Save entire store (encrypted with server key) + if err := s.save(); err != nil { + return err + } + + logger.Info("Successfully saved user: %s", userID) + return nil +} + +func generateSecret() (string, error) { + b := make([]byte, 20) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base32.StdEncoding.EncodeToString(b), nil +} diff --git a/manager/template.go b/manager/template.go new file mode 100644 index 0000000..b554d44 --- /dev/null +++ b/manager/template.go @@ -0,0 +1,459 @@ +package main + +import ( + "html/template" + "os" +) + +func LoadTemplate() (*template.Template, error) { + // Try to load from file first + if _, err := os.Stat("template.html"); err == nil { + logger.Info("Loading template from template.html") + return template.ParseFiles("template.html") + } + + // Fall back to embedded template + logger.Info("Using embedded template") + return template.New("page").Parse(embeddedTemplate) +} + +func LoadAppTemplate() (*template.Template, error) { + // Try to load from file first + if _, err := os.Stat("app.html"); err == nil { + logger.Info("Loading app template from app.html") + return template.ParseFiles("app.html") + } + + // Fall back to embedded template + logger.Info("Using embedded app template") + return template.New("app").Parse(embeddedAppTemplate) +} + +const embeddedTemplate = ` + + + Two-Step Authentication + + + +
+ {{if .Step2}} +

Enter TOTP Code

+
+
+ + +
+
+ {{else}} +

Enter User ID

+
+
+ + +
+
+ {{end}} + {{if .Error}}
{{.Error}}
{{end}} +
+ +` + +const embeddedAppTemplate = ` + + + REST API Client + + + +
+
+

REST API Client

+ +
+ +
+
+
+ + +
+
+ + +
+
+
+ + +
+
+ + +
+ +
+ +
+
+ + +
+ +
+
Response Headers
+
+

+                
+
+
+
Response Body
+

+            
+
+
+ + + +` diff --git a/output_service/README.md b/output_service/README.md new file mode 100644 index 0000000..b10296c --- /dev/null +++ b/output_service/README.md @@ -0,0 +1,7 @@ +# output service + +Service to receive output from ping_service instances. +Builds database of mappable nodes. +Updates input services address lists with all working endpoints and working hops from the traces. + +Have reporting api endpoints for the manager to monitor the progress.