Improved input service. New Manager web app. Directory and small readme for output service.

This commit is contained in:
Kalzu Rekku
2026-01-06 14:27:26 +02:00
parent ec9fec5ce3
commit f7056082f6
11 changed files with 1695 additions and 293 deletions

View File

@@ -3,34 +3,41 @@ package main
import (
"bufio"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log"
"math/rand"
"net"
"net/http"
"net/netip"
"os"
"os/signal"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
)
const (
repoDir = "cloud-provider-ip-addresses"
port = 8080
stateDir = "progress_state"
saveInterval = 30 * time.Second
repoDir = "cloud-provider-ip-addresses"
port = 8080
stateDir = "progress_state"
saveInterval = 30 * time.Second
cleanupInterval = 5 * time.Minute
generatorTTL = 24 * time.Hour
maxImportSize = 10 * 1024 * 1024 // 10MB
)
// GeneratorState represents the serializable state of a generator
type GeneratorState struct {
CurrentFile int `json:"current_file"`
CurrentCIDRs []string `json:"current_cidrs"`
ActiveGenStates []HostGenState `json:"active_gen_states"`
CIDRFiles []string `json:"cidr_files"`
RemainingCIDRs []string `json:"remaining_cidrs"`
CurrentGen *HostGenState `json:"current_gen,omitempty"`
TotalCIDRs int `json:"total_cidrs"`
}
type HostGenState struct {
@@ -41,67 +48,72 @@ type HostGenState struct {
// IPGenerator generates IPs from CIDR ranges lazily
type IPGenerator struct {
mu sync.Mutex
cidrFiles []string
currentFile int
currentCIDRs []string
activeGens []*hostGenerator
rng *rand.Rand
totalCIDRsCount int
consumer string
dirty bool
mu sync.Mutex
rng *rand.Rand
totalCIDRsCount int
remainingCIDRs []string
currentGen *hostGenerator
consumer string
dirty atomic.Bool
}
type hostGenerator struct {
cidr string
network *net.IPNet
current net.IP
prefix netip.Prefix
current netip.Addr
last netip.Addr
done bool
}
func addrToUint32(a netip.Addr) uint32 {
b := a.As4()
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
}
func uint32ToAddr(u uint32) netip.Addr {
return netip.AddrFrom4([4]byte{byte(u >> 24), byte(u >> 16), byte(u >> 8), byte(u)})
}
func newHostGenerator(cidr string) (*hostGenerator, error) {
_, network, err := net.ParseCIDR(cidr)
prefix, err := netip.ParsePrefix(cidr)
if err != nil {
return nil, err
}
// Only IPv4
if network.IP.To4() == nil {
return nil, fmt.Errorf("not IPv4")
prefix = prefix.Masked()
if !prefix.IsValid() || !prefix.Addr().Is4() {
return nil, fmt.Errorf("invalid IPv4 prefix")
}
// Check if multicast
if network.IP.IsMulticast() {
if prefix.Addr().IsMulticast() {
return nil, fmt.Errorf("multicast network")
}
ones, bits := network.Mask.Size()
hg := &hostGenerator{
cidr: cidr,
network: network,
current: make(net.IP, len(network.IP)),
}
copy(hg.current, network.IP)
ip := prefix.Addr()
maskLen := prefix.Bits()
// For /32, just use the single address
if ones == bits {
return hg, nil
var first, last netip.Addr
lastUint := addrToUint32(ip) | ((1 << (32 - uint(maskLen))) - 1)
last = uint32ToAddr(lastUint)
if maskLen == 32 {
first = ip
last = ip
} else if maskLen == 31 {
first = ip
// last already ip + 1
} else {
first = ip.Next()
last = last.Prev()
}
// For other networks, skip network address (start at .1)
hg.increment()
return hg, nil
}
func (hg *hostGenerator) increment() {
for i := len(hg.current) - 1; i >= 0; i-- {
hg.current[i]++
if hg.current[i] != 0 {
break
}
if !prefix.Contains(first) || !prefix.Contains(last) {
return nil, fmt.Errorf("invalid range")
}
return &hostGenerator{
prefix: prefix,
current: first,
last: last,
done: false,
}, nil
}
func (hg *hostGenerator) next() (string, bool) {
@@ -109,58 +121,31 @@ func (hg *hostGenerator) next() (string, bool) {
return "", false
}
ones, bits := hg.network.Mask.Size()
// Handle /32 specially
if ones == bits {
if !hg.current.Equal(hg.network.IP) {
hg.done = true
return "", false
}
ip := hg.current.String()
hg.done = true
return ip, true
}
// Check if we're still in the network
if !hg.network.Contains(hg.current) {
if !hg.prefix.Contains(hg.current) || addrToUint32(hg.current) > addrToUint32(hg.last) {
hg.done = true
return "", false
}
// Check if this is the broadcast address (last IP in range)
broadcast := make(net.IP, len(hg.network.IP))
copy(broadcast, hg.network.IP)
for i := range broadcast {
broadcast[i] |= ^hg.network.Mask[i]
}
if hg.current.Equal(broadcast) {
hg.done = true
return "", false
}
// Skip multicast addresses
if hg.current.IsMulticast() {
hg.increment()
hg.current = hg.current.Next()
return hg.next()
}
ip := hg.current.String()
hg.increment()
hg.current = hg.current.Next()
return ip, true
}
func (hg *hostGenerator) getState() HostGenState {
return HostGenState{
CIDR: hg.cidr,
CIDR: hg.prefix.String(),
Current: hg.current.String(),
Done: hg.done,
}
}
func newIPGenerator(consumer string) (*IPGenerator, error) {
func newIPGenerator(s *Server, consumer string) (*IPGenerator, error) {
gen := &IPGenerator{
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
consumer: consumer,
@@ -173,207 +158,127 @@ func newIPGenerator(consumer string) (*IPGenerator, error) {
}
// No saved state, initialize fresh
// Find all IP files
err := filepath.Walk(repoDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() && strings.HasSuffix(path, ".txt") && strings.Contains(strings.ToLower(path), "ips") {
gen.cidrFiles = append(gen.cidrFiles, path)
}
return nil
gen.remainingCIDRs = append([]string{}, s.allCIDRs...)
gen.rng.Shuffle(len(gen.remainingCIDRs), func(i, j int) {
gen.remainingCIDRs[i], gen.remainingCIDRs[j] = gen.remainingCIDRs[j], gen.remainingCIDRs[i]
})
gen.totalCIDRsCount = len(gen.remainingCIDRs)
gen.dirty.Store(true)
if err != nil {
return nil, fmt.Errorf("failed to scan repo directory: %w", err)
}
if len(gen.cidrFiles) == 0 {
return nil, fmt.Errorf("no IP files found in %s", repoDir)
}
// Load first batch of CIDRs
if err := gen.loadNextFile(); err != nil {
return nil, err
}
log.Printf("🆕 New generator for %s: %d IP files, %d CIDRs", consumer, len(gen.cidrFiles), gen.totalCIDRsCount)
log.Printf("📁 Found %d IP files", len(gen.cidrFiles))
log.Printf("📊 Total CIDRs discovered: %d", gen.totalCIDRsCount)
log.Printf("🆕 New generator for %s: %d total CIDRs", consumer, gen.totalCIDRsCount)
return gen, nil
}
func (g *IPGenerator) loadNextFile() error {
if g.currentFile >= len(g.cidrFiles) {
// Wrap around and reshuffle
g.currentFile = 0
g.rng.Shuffle(len(g.cidrFiles), func(i, j int) {
g.cidrFiles[i], g.cidrFiles[j] = g.cidrFiles[j], g.cidrFiles[i]
})
}
filepath := g.cidrFiles[g.currentFile]
g.currentFile++
file, err := os.Open(filepath)
if err != nil {
return fmt.Errorf("failed to open %s: %w", filepath, err)
}
defer file.Close()
g.currentCIDRs = g.currentCIDRs[:0] // Clear but keep capacity
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
fields := strings.Fields(line)
for _, field := range fields {
if field != "" {
// Basic validation
if strings.Contains(field, "/") || net.ParseIP(field) != nil {
g.currentCIDRs = append(g.currentCIDRs, field)
g.totalCIDRsCount++
}
}
}
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("error reading %s: %w", filepath, err)
}
// Shuffle CIDRs from this file
g.rng.Shuffle(len(g.currentCIDRs), func(i, j int) {
g.currentCIDRs[i], g.currentCIDRs[j] = g.currentCIDRs[j], g.currentCIDRs[i]
})
// Initialize generators for this batch
g.activeGens = make([]*hostGenerator, 0, len(g.currentCIDRs))
for _, cidr := range g.currentCIDRs {
// Ensure it has CIDR notation
if !strings.Contains(cidr, "/") {
cidr = cidr + "/32"
}
gen, err := newHostGenerator(cidr)
if err != nil {
// Skip invalid CIDRs silently
continue
}
g.activeGens = append(g.activeGens, gen)
}
g.dirty = true
return nil
}
func (g *IPGenerator) Next() (string, error) {
g.mu.Lock()
defer g.mu.Unlock()
for {
// If no active generators, load next file
if len(g.activeGens) == 0 {
if err := g.loadNextFile(); err != nil {
return "", fmt.Errorf("failed to load next file: %w", err)
}
if len(g.activeGens) == 0 {
if g.currentGen == nil || g.currentGen.done {
if len(g.remainingCIDRs) == 0 {
return "", fmt.Errorf("no more IPs available")
}
cidr := g.remainingCIDRs[0]
g.remainingCIDRs = g.remainingCIDRs[1:]
if !strings.Contains(cidr, "/") {
cidr += "/32"
}
var err error
g.currentGen, err = newHostGenerator(cidr)
if err != nil {
g.dirty.Store(true)
continue
}
}
// Pick a random generator
idx := g.rng.Intn(len(g.activeGens))
gen := g.activeGens[idx]
ip, ok := gen.next()
ip, ok := g.currentGen.next()
if !ok {
// This generator is exhausted, remove it
g.activeGens = append(g.activeGens[:idx], g.activeGens[idx+1:]...)
g.dirty = true
g.currentGen = nil
g.dirty.Store(true)
continue
}
g.dirty = true
g.dirty.Store(true)
return ip, nil
}
}
func (g *IPGenerator) buildState() GeneratorState {
// Assumes mu is held
state := GeneratorState{
RemainingCIDRs: append([]string{}, g.remainingCIDRs...),
TotalCIDRs: g.totalCIDRsCount,
}
if g.currentGen != nil && !g.currentGen.done {
state.CurrentGen = &HostGenState{
CIDR: g.currentGen.prefix.String(),
Current: g.currentGen.current.String(),
Done: false,
}
}
return state
}
func (g *IPGenerator) getState() GeneratorState {
g.mu.Lock()
defer g.mu.Unlock()
activeStates := make([]HostGenState, len(g.activeGens))
for i, gen := range g.activeGens {
activeStates[i] = gen.getState()
}
return GeneratorState{
CurrentFile: g.currentFile,
CurrentCIDRs: g.currentCIDRs,
ActiveGenStates: activeStates,
CIDRFiles: g.cidrFiles,
}
return g.buildState()
}
func (g *IPGenerator) saveState() error {
if !g.dirty {
g.mu.Lock()
if !g.dirty.Load() {
g.mu.Unlock()
return nil
}
state := g.getState()
state := g.buildState()
g.dirty.Store(false)
g.mu.Unlock()
// Ensure state directory exists
if err := os.MkdirAll(stateDir, 0755); err != nil {
return fmt.Errorf("failed to create state directory: %w", err)
}
// Use consumer as filename (sanitize for filesystem)
filename := strings.ReplaceAll(g.consumer, ":", "_")
filename = strings.ReplaceAll(filename, "/", "_")
filepath := filepath.Join(stateDir, filename+".json")
// Use hash of consumer as filename
hash := sha256.Sum256([]byte(g.consumer))
filename := hex.EncodeToString(hash[:])
filePath := filepath.Join(stateDir, filename+".json")
// Write to temp file first, then rename for atomic write
tempPath := filepath + ".tmp"
// Write to temp file first, then rename
tempPath := filePath + ".tmp"
file, err := os.Create(tempPath)
if err != nil {
return fmt.Errorf("failed to create temp state file: %w", err)
}
defer file.Close()
encoder := json.NewEncoder(file)
encoder.SetIndent("", " ")
if err := encoder.Encode(state); err != nil {
file.Close()
os.Remove(tempPath)
return fmt.Errorf("failed to encode state: %w", err)
}
if err := file.Close(); err != nil {
os.Remove(tempPath)
return fmt.Errorf("failed to close temp state file: %w", err)
}
if err := os.Rename(tempPath, filepath); err != nil {
if err := os.Rename(tempPath, filePath); err != nil {
os.Remove(tempPath)
return fmt.Errorf("failed to rename state file: %w", err)
}
g.dirty = false
return nil
}
func (g *IPGenerator) loadState() error {
// Use consumer as filename (sanitize for filesystem)
filename := strings.ReplaceAll(g.consumer, ":", "_")
filename = strings.ReplaceAll(filename, "/", "_")
filepath := filepath.Join(stateDir, filename+".json")
// Use hash of consumer as filename
hash := sha256.Sum256([]byte(g.consumer))
filename := hex.EncodeToString(hash[:])
filePath := filepath.Join(stateDir, filename+".json")
file, err := os.Open(filepath)
file, err := os.Open(filePath)
if err != nil {
return err
}
@@ -385,27 +290,20 @@ func (g *IPGenerator) loadState() error {
}
// Restore state
g.cidrFiles = state.CIDRFiles
g.currentFile = state.CurrentFile
g.currentCIDRs = state.CurrentCIDRs
g.totalCIDRsCount = len(state.CurrentCIDRs)
g.remainingCIDRs = state.RemainingCIDRs
g.totalCIDRsCount = state.TotalCIDRs
// Rebuild active generators from state
g.activeGens = make([]*hostGenerator, 0, len(state.ActiveGenStates))
for _, genState := range state.ActiveGenStates {
gen, err := newHostGenerator(genState.CIDR)
if state.CurrentGen != nil {
gen, err := newHostGenerator(state.CurrentGen.CIDR)
if err != nil {
continue
return err
}
// Restore current IP position
gen.current = net.ParseIP(genState.Current)
if gen.current == nil {
continue
gen.current, err = netip.ParseAddr(state.CurrentGen.Current)
if err != nil {
return err
}
gen.done = genState.Done
g.activeGens = append(g.activeGens, gen)
gen.done = state.CurrentGen.Done
g.currentGen = gen
}
return nil
@@ -414,21 +312,86 @@ func (g *IPGenerator) loadState() error {
// Server holds per-consumer generators
type Server struct {
generators map[string]*IPGenerator
lastAccess map[string]time.Time
allCIDRs []string
mu sync.RWMutex
stopSaver chan struct{}
stopCleanup chan struct{}
wg sync.WaitGroup
}
func newServer() *Server {
s := &Server{
generators: make(map[string]*IPGenerator),
stopSaver: make(chan struct{}),
generators: make(map[string]*IPGenerator),
lastAccess: make(map[string]time.Time),
stopSaver: make(chan struct{}),
stopCleanup: make(chan struct{}),
}
if err := s.loadAllCIDRs(); err != nil {
log.Fatalf("❌ Failed to load CIDRs: %v", err)
}
s.startPeriodicSaver()
s.startCleanup()
return s
}
func (s *Server) loadAllCIDRs() error {
// Find all IP files
var fileList []string
err := filepath.Walk(repoDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() && strings.HasSuffix(path, ".txt") && strings.Contains(strings.ToLower(path), "ips") {
fileList = append(fileList, path)
}
return nil
})
if err != nil {
return fmt.Errorf("failed to scan repo directory: %w", err)
}
if len(fileList) == 0 {
return fmt.Errorf("no IP files found in %s", repoDir)
}
// Load all CIDRs
for _, path := range fileList {
file, err := os.Open(path)
if err != nil {
log.Printf("⚠️ Failed to open %s: %v", path, err)
continue
}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
fields := strings.Fields(line)
for _, field := range fields {
if field != "" {
if strings.Contains(field, "/") || netip.MustParseAddr(field).IsValid() {
s.allCIDRs = append(s.allCIDRs, field)
}
}
}
}
file.Close()
if err := scanner.Err(); err != nil {
log.Printf("⚠️ Error reading %s: %v", path, err)
}
}
log.Printf("📁 Found %d IP files", len(fileList))
log.Printf("📊 Total CIDRs discovered: %d", len(s.allCIDRs))
return nil
}
func (s *Server) startPeriodicSaver() {
s.wg.Add(1)
go func() {
@@ -441,7 +404,6 @@ func (s *Server) startPeriodicSaver() {
case <-ticker.C:
s.saveAllStates()
case <-s.stopSaver:
// Final save before shutdown
s.saveAllStates()
return
}
@@ -449,23 +411,63 @@ func (s *Server) startPeriodicSaver() {
}()
}
func (s *Server) startCleanup() {
s.wg.Add(1)
go func() {
defer s.wg.Done()
ticker := time.NewTicker(cleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.cleanupOldGenerators()
case <-s.stopCleanup:
return
}
}
}()
}
func (s *Server) cleanupOldGenerators() {
s.mu.Lock()
defer s.mu.Unlock()
now := time.Now()
for consumer, t := range s.lastAccess {
if now.Sub(t) > generatorTTL {
delete(s.generators, consumer)
delete(s.lastAccess, consumer)
// Remove state file
hash := sha256.Sum256([]byte(consumer))
fn := hex.EncodeToString(hash[:]) + ".json"
p := filepath.Join(stateDir, fn)
if err := os.Remove(p); err != nil && !os.IsNotExist(err) {
log.Printf("⚠️ Failed to remove state file %s: %v", p, err)
}
}
}
}
func (s *Server) saveAllStates() {
s.mu.RLock()
generators := make([]*IPGenerator, 0, len(s.generators))
gens := make([]*IPGenerator, 0, len(s.generators))
for _, gen := range s.generators {
generators = append(generators, gen)
gens = append(gens, gen)
}
s.mu.RUnlock()
for _, gen := range generators {
for _, gen := range gens {
if err := gen.saveState(); err != nil {
log.Printf("⚠️ Failed to save state for %s: %v", gen.consumer, err)
log.Printf("⚠️ Failed to save state for %s: %v", gen.consumer, err)
}
}
}
func (s *Server) shutdown() {
close(s.stopSaver)
close(s.stopCleanup)
s.wg.Wait()
log.Println("💾 All states saved")
}
@@ -476,26 +478,30 @@ func (s *Server) getGenerator(consumer string) (*IPGenerator, error) {
s.mu.RUnlock()
if exists {
s.mu.Lock()
s.lastAccess[consumer] = time.Now()
s.mu.Unlock()
return gen, nil
}
// Create new generator for this consumer
s.mu.Lock()
defer s.mu.Unlock()
// Double-check after acquiring write lock
// Double-check
if gen, exists := s.generators[consumer]; exists {
s.lastAccess[consumer] = time.Now()
return gen, nil
}
newGen, err := newIPGenerator(consumer)
newGen, err := newIPGenerator(s, consumer)
if err != nil {
return nil, err
}
s.generators[consumer] = newGen
s.lastAccess[consumer] = time.Now()
log.Printf("🆕 New consumer: %s", consumer)
return newGen, nil
}
@@ -505,38 +511,42 @@ func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) {
return
}
consumer := r.RemoteAddr
if host, _, err := net.SplitHostPort(consumer); err == nil {
consumer = host
addrPort, err := netip.ParseAddrPort(r.RemoteAddr)
consumerStr := r.RemoteAddr
if err == nil {
consumerStr = addrPort.Addr().String()
} else {
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err == nil {
consumerStr = host
}
}
gen, err := s.getGenerator(consumer)
gen, err := s.getGenerator(consumerStr)
if err != nil {
log.Printf("❌ Failed to get generator for %s: %v", consumer, err)
log.Printf("❌ Failed to get generator for %s: %v", consumerStr, err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
ip, err := gen.Next()
if err != nil {
log.Printf("❌ Failed to get IP for %s: %v", consumer, err)
log.Printf("❌ Failed to get IP for %s: %v", consumerStr, err)
http.Error(w, "No more IPs available", http.StatusServiceUnavailable)
return
}
log.Printf("📤 Serving IP to %s: %s", consumer, ip)
log.Printf("📤 Serving IP to %s: %s", consumerStr, ip)
w.Header().Set("Content-Type", "text/plain")
fmt.Fprintf(w, "%s\n", ip)
}
type ConsumerStatus struct {
Consumer string `json:"consumer"`
CurrentFile int `json:"current_file"`
TotalFiles int `json:"total_files"`
ActiveCIDRs int `json:"active_cidrs"`
TotalCIDRs int `json:"total_cidrs_discovered"`
CurrentFilePath string `json:"current_file_path,omitempty"`
Consumer string `json:"consumer"`
RemainingCIDRs int `json:"remaining_cidrs"`
HasActiveGen bool `json:"has_active_gen"`
TotalCIDRs int `json:"total_cidrs"`
}
type StatusResponse struct {
@@ -565,14 +575,10 @@ func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
for consumer, gen := range s.generators {
gen.mu.Lock()
status := ConsumerStatus{
Consumer: consumer,
CurrentFile: gen.currentFile,
TotalFiles: len(gen.cidrFiles),
ActiveCIDRs: len(gen.activeGens),
TotalCIDRs: gen.totalCIDRsCount,
}
if gen.currentFile > 0 && gen.currentFile <= len(gen.cidrFiles) {
status.CurrentFilePath = gen.cidrFiles[gen.currentFile-1]
Consumer: consumer,
RemainingCIDRs: len(gen.remainingCIDRs),
HasActiveGen: gen.currentGen != nil,
TotalCIDRs: gen.totalCIDRsCount,
}
gen.mu.Unlock()
@@ -588,8 +594,8 @@ func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
}
type ExportResponse struct {
ExportedAt time.Time `json:"exported_at"`
States map[string]GeneratorState `json:"states"`
ExportedAt time.Time `json:"exported_at"`
States map[string]GeneratorState `json:"states"`
}
func (s *Server) handleExport(w http.ResponseWriter, r *http.Request) {
@@ -606,7 +612,7 @@ func (s *Server) handleExport(w http.ResponseWriter, r *http.Request) {
response := ExportResponse{
ExportedAt: time.Now(),
States: make(map[string]GeneratorState),
States: make(map[string]GeneratorState, len(s.generators)),
}
for consumer, gen := range s.generators {
@@ -614,9 +620,9 @@ func (s *Server) handleExport(w http.ResponseWriter, r *http.Request) {
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=state-export-%s.json",
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=state-export-%s.json",
time.Now().Format("2006-01-02-150405")))
encoder := json.NewEncoder(w)
encoder.SetIndent("", " ")
if err := encoder.Encode(response); err != nil {
@@ -630,9 +636,16 @@ func (s *Server) handleImport(w http.ResponseWriter, r *http.Request) {
return
}
reader := http.MaxBytesReader(w, r.Body, maxImportSize)
defer r.Body.Close()
var exportData ExportResponse
if err := json.NewDecoder(r.Body).Decode(&exportData); err != nil {
http.Error(w, fmt.Sprintf("Failed to decode import data: %v", err), http.StatusBadRequest)
if err := json.NewDecoder(reader).Decode(&exportData); err != nil {
if err == io.EOF || strings.Contains(err.Error(), "EOF") {
http.Error(w, "Invalid or empty request body", http.StatusBadRequest)
} else {
http.Error(w, fmt.Sprintf("Failed to decode import data: %v", err), http.StatusBadRequest)
}
return
}
@@ -646,14 +659,14 @@ func (s *Server) handleImport(w http.ResponseWriter, r *http.Request) {
failed := 0
for consumer, state := range exportData.States {
// Sanitize consumer name for filename
filename := strings.ReplaceAll(consumer, ":", "_")
filename = strings.ReplaceAll(filename, "/", "_")
filepath := filepath.Join(stateDir, filename+".json")
// Use hash for filename
hash := sha256.Sum256([]byte(consumer))
filename := hex.EncodeToString(hash[:])
filePath := filepath.Join(stateDir, filename+".json")
file, err := os.Create(filepath)
file, err := os.Create(filePath)
if err != nil {
log.Printf("⚠️ Failed to create state file for %s: %v", consumer, err)
log.Printf("⚠️ Failed to create state file for %s: %v", consumer, err)
failed++
continue
}
@@ -662,8 +675,9 @@ func (s *Server) handleImport(w http.ResponseWriter, r *http.Request) {
encoder.SetIndent("", " ")
if err := encoder.Encode(state); err != nil {
file.Close()
log.Printf("⚠️ Failed to encode state for %s: %v", consumer, err)
log.Printf("⚠️ Failed to encode state for %s: %v", consumer, err)
failed++
os.Remove(filePath)
continue
}
@@ -680,7 +694,7 @@ func (s *Server) handleImport(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
log.Printf("📥 Imported %d consumer states (%d failed)", imported, failed)
}
@@ -717,7 +731,7 @@ func main() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// Stop periodic saver and save final state
// Stop savers and save final state
server.shutdown()
if err := httpServer.Shutdown(ctx); err != nil {
@@ -739,4 +753,4 @@ func main() {
}
log.Println("✅ Server stopped cleanly")
}
}