Working to add TXT records

This commit is contained in:
Kalzu Rekku 2025-05-03 11:38:00 +03:00
parent 26e9454ba9
commit 8ee0203de8

View File

@ -77,12 +77,13 @@ class ServiceInstance(BaseModel):
tags: List[str] = Field(
default_factory=list, description="Optional list of tags for filtering"
)
health: str = Field(
default="passing", description="Health status ('passing', 'failing', 'unknown')"
)
metadata: Dict[str, str] = Field(
default_factory=dict, description="Optional key-value metadata"
)
health: str = Field(
default="passing", description="Health status ('passing', 'failing', 'unknown')"
)
last_updated: float = Field(default_factory=time.time)
class TokenInfo(BaseModel):
@ -867,171 +868,357 @@ class MiniDiscoveryResolver(common.ResolverBase):
query_type: int,
timeout: Optional[Tuple[int]] = None,
) -> Deferred:
"""
Main lookup entry point. Decodes the name, checks the suffix,
and dispatches to specific handlers based on query_type.
"""
d = Deferred()
name_str_debug = repr(name)
name_str_debug = repr(name) # For logging errors
try:
# 1. Decode the incoming query name safely
try:
name_str = name.decode("utf-8").lower()
except UnicodeDecodeError as decode_err:
print(
f"DNS lookup error: Cannot decode query name {name_str_debug} as UTF-8: {decode_err}"
)
d.errback(Failure(decode_err))
except UnicodeDecodeError:
# Log? Respond with format error? For now, treat as non-existent.
print(f"DNS: Cannot decode query name {name_str_debug} as UTF-8.")
# This should result in NXDOMAIN if we don't callback anything
# d.callback(([], [], [])) # Alternatively, explicitly return empty
return d # Let Twisted handle NXDOMAIN? Best to callback empty.
# Returning empty is safer for our specific resolver.
d.callback(([], [], []))
return d
# 2. Check suffix
# 2. Check for the expected suffix
if not name_str.endswith(DNS_QUERY_SUFFIX):
# Not a query for our domain, return empty.
d.callback(([], [], []))
return d
# --- SRV Service Type Lookup Logic ---
is_srv_type_query = False
service_tag_to_find = None
if name_str.startswith("_") and (
"_tcp." in name_str or "_udp." in name_str
):
parts = name_str.split(".")
# Example: _ssh._tcp.laiska.local
if (
len(parts) >= 4
and parts[0].startswith("_")
and parts[1] in ["_tcp", "_udp"]
):
# Assume format _service._proto.domain.suffix...
service_tag_to_find = parts[0][1:] # Extract 'ssh' from '_ssh'
# Ensure the query type is actually SRV
if query_type == dns.SRV:
is_srv_type_query = True
# 3. Extract the base name (part before the suffix)
base_name = name_str[: -len(DNS_QUERY_SUFFIX)]
if not base_name: # Query was just the suffix itself
d.callback(([], [], []))
return d
# 4. Dispatch based on query type
handler = None
if query_type == dns.A:
handler = self._handle_a_query
elif query_type == dns.SRV:
handler = self._handle_srv_query
elif query_type == dns.TXT:
handler = self._handle_txt_query
# Add elif for AAAA if needed in the future
# elif query_type == dns.AAAA:
# handler = self._handle_aaaa_query # Implement this if needed
else:
# Requesting A/other records for SRV-style name doesn't make sense here
# Unsupported query type for our resolver
d.callback(([], [], []))
return d
# --- Instance Fetching ---
instances_to_process = []
if is_srv_type_query:
print(
f"DNS: Performing SRV type lookup for tag '{service_tag_to_find}'"
)
all_services = registry.get_all_services()
for service_name, service_instances in all_services.items():
for instance in service_instances:
# Check health AND tag presence
if (
instance.health == "passing"
and service_tag_to_find in instance.tags
):
instances_to_process.append(instance)
if not instances_to_process:
print(
f"DNS: No passing instances found with tag '{service_tag_to_find}'"
)
d.callback(([], [], [])) # No matching instances found
return d
else:
service_name = name_str[: -len(DNS_QUERY_SUFFIX)].split(".")[-1]
if not service_name:
d.callback(([], [], []))
return d
instances_to_process = registry.get_service(
service_name, only_passing=True
)
if not instances_to_process:
d.callback(([], [], []))
return d
# --- Build DNS Records ---
answers = []
authority = []
additional = []
# Check query_type against the records we can generate
if (
query_type == dns.A and not is_srv_type_query
): # Only generate A for direct name query
for instance in instances_to_process:
instance_addr_debug = instance.address
try:
ip_address_string = instance.address
answers.append(
dns.RRHeader(
name=name,
type=dns.A,
cls=cls,
ttl=DNS_DEFAULT_TTL,
# Pass the string, Twisted handles conversion
payload=dns.Record_A(
address=ip_address_string, ttl=DNS_DEFAULT_TTL
),
)
)
except (
Exception
) as record_e: # Catch potential errors during record creation itself
print(
f"Warning: Error creating A record for IP '{instance.address}' (Service ID {instance.id}): {record_e}. Skipping."
)
elif (
query_type == dns.SRV
): # Generate SRV for direct name query OR SRV type query
for instance in instances_to_process:
try:
ip_address_string = instance.address
# SRV target should still be bytes
target_name = f"{instance.id}{DNS_QUERY_SUFFIX}".encode("utf-8")
answers.append(
dns.RRHeader(
name=name,
type=dns.SRV,
cls=cls,
ttl=DNS_DEFAULT_TTL,
payload=dns.Record_SRV(
priority=0,
weight=10,
port=instance.port,
target=target_name,
ttl=DNS_DEFAULT_TTL,
),
)
)
additional.append(
dns.RRHeader(
name=target_name,
type=dns.A,
cls=cls,
ttl=DNS_DEFAULT_TTL,
payload=dns.Record_A(
address=ip_address_string, ttl=DNS_DEFAULT_TTL
),
)
)
except (
Exception
) as record_e: # Catch potential errors during record creation itself
print(
f"Warning: Error creating SRV/additional record for IP '{instance.address}' (Service ID {instance.id}): {record_e}. Skipping."
)
# If we successfully built records (or correctly skipped bad IPs)
# 5. Execute the handler and set callback/errback
# The handlers currently don't return Deferreds, so we call directly
answers, authority, additional = handler(name, base_name, cls)
d.callback((answers, authority, additional))
# --- Catch ANY OTHER unexpected error during the lookup process ---
except Exception as e:
# Log the errors
# Catch-all for unexpected errors during dispatch or handler execution
print(
f"!!! Unhandled exception during DNS lookup for query name {name_str_debug} !!!"
f"!!! Unhandled exception during DNS lookup for {name_str_debug} "
f"(Type: {query_type}) !!!"
)
print(f"Exception Type: {type(e).__name__}")
print(f"Exception Args: {e.args}")
print("--- Traceback ---")
traceback.print_exc()
print("--- End Traceback ---")
d.errback(Failure(e)) # Signal SERVFAIL
# Signal DNS server failure (SERVFAIL)
d.errback(Failure(e))
return d
# --- Helper Methods for Specific Record Types ---
def _parse_srv_query(self, base_name: str) -> Tuple[Optional[str], Optional[str]]:
"""
Parses SRV-style queries: _tag._proto.service or _tag._proto
Returns (tag, service_name) or (None, None) if not SRV-style.
"""
parts = base_name.split(".")
if (
len(parts) >= 2
and parts[0].startswith("_")
and parts[1] in ["_tcp", "_udp"]
):
tag = parts[0][1:] # Remove leading '_'
# Service name is the part *after* _tag._proto, if it exists
service_name = parts[2] if len(parts) > 2 else None
# We currently ignore the rest of the parts (like datacenter in consul)
return tag, service_name
return None, None # Not an SRV-style query name
def _get_instances_for_query(
self, base_name: str, is_srv_query: bool = False
) -> List[ServiceInstance]:
"""Fetches relevant, passing service instances based on the query name."""
instances = []
tag_filter = None
service_name_filter = None
if is_srv_query:
tag_filter, service_name_filter = self._parse_srv_query(base_name)
if tag_filter is None: # Not a valid _tag._proto... query
return [] # Return empty, SRV query handler expects specific format
else: # A or TXT query: service name is the last part
service_name_filter = base_name.split(".")[-1]
if service_name_filter:
# Query targets a specific service (with potential tag filter for SRV)
service_instances = self.registry.get_service(
service_name_filter, only_passing=True
)
if tag_filter: # SRV query with tag and service
instances = [
inst for inst in service_instances if tag_filter in inst.tags
]
else: # A/TXT query, or SRV query without tag (using service name)
instances = service_instances # Already filtered for passing
elif tag_filter: # SRV query for a tag across all services
all_services = self.registry.get_all_services()
for name, service_instances in all_services.items():
for inst in service_instances:
if inst.health == "passing" and tag_filter in inst.tags:
instances.append(inst)
else:
# This case shouldn't be reached if initial checks are correct
# (e.g., A/TXT query needs a service name part)
print(f"Warning: Could not determine filter for query '{base_name}'")
print(
f"DNS Lookup: base_name='{base_name}', is_srv={is_srv_query}, "
f"tag='{tag_filter}', service='{service_name_filter}'. Found {len(instances)} instances."
)
return instances
def _handle_a_query(
self, name: bytes, base_name: str, cls: int
) -> Tuple[List, List, List]:
"""Handles A record lookups."""
answers = []
instances = self._get_instances_for_query(base_name, is_srv_query=False)
for instance in instances:
try:
# Twisted's Record_A expects the IP address string
payload = dns.Record_A(address=instance.address, ttl=DNS_DEFAULT_TTL)
rr = dns.RRHeader(
name=name, # Respond with the original query name
type=dns.A,
cls=cls,
ttl=DNS_DEFAULT_TTL,
payload=payload,
)
answers.append(rr)
except Exception as e:
print(
f"Warning: Error creating A record for instance {instance.id} "
f"(IP: {instance.address}): {e}. Skipping."
)
return answers, [], [] # No authority or additional records for basic A
def _handle_srv_query(
self, name: bytes, base_name: str, cls: int
) -> Tuple[List, List, List]:
"""Handles SRV record lookups (service or tag based)."""
answers = []
additional = []
instances = self._get_instances_for_query(base_name, is_srv_query=True)
# If _get_instances_for_query returned empty because parsing failed,
# we might want to try interpreting the name differently, e.g.,
# as a direct SRV lookup for a service name like `service.domain.suffix`.
# For now, we strictly follow the _tag._proto logic defined above.
# If you want `srvlookup service.domain.suffix`, the logic in
# _get_instances_for_query needs adjustment or another branch here.
for instance in instances:
try:
# SRV target points to a name that resolves to the instance's A record.
# Conventionally: <instance_id>.<service_name>.<domain_suffix>.
# Let's use: <instance_id>.node.<domain_suffix> for simplicity,
# or maybe <instance_id>.<service_name>...
# Using just instance ID + suffix is simple and unique.
# Ensure the target ends with the suffix too!
target_name_str = f"{instance.id}{DNS_QUERY_SUFFIX}"
target_name_bytes = target_name_str.encode("utf-8")
srv_payload = dns.Record_SRV(
priority=0, # Lower is more preferred
weight=10, # Relative weight for same priority
port=instance.port,
target=target_name_bytes, # Must be bytes
ttl=DNS_DEFAULT_TTL, # TTL for the SRV record itself
)
srv_rr = dns.RRHeader(
name=name, # Respond with the original query name
type=dns.SRV,
cls=cls,
ttl=DNS_DEFAULT_TTL,
payload=srv_payload,
)
answers.append(srv_rr)
# Add corresponding A record for the target in the additional section
a_payload = dns.Record_A(address=instance.address, ttl=DNS_DEFAULT_TTL)
a_rr = dns.RRHeader(
name=target_name_bytes, # Name matches SRV target
type=dns.A,
cls=cls,
ttl=DNS_DEFAULT_TTL, # TTL for the additional A record
payload=a_payload,
)
additional.append(a_rr)
except Exception as e:
print(
f"Warning: Error creating SRV/A record for instance {instance.id} "
f"(Addr: {instance.address}:{instance.port}): {e}. Skipping."
)
return answers, [], additional
def _handle_txt_query(
self, name: bytes, base_name: str, cls: int
) -> Tuple[List, List, List]:
"""Handles TXT record lookups, returning service metadata."""
answers = []
instances = self._get_instances_for_query(base_name, is_srv_query=False)
for instance in instances:
# --- Initialize list for this instance ---
txt_data = []
instance_id_str = str(instance.id) # Use consistently
try:
print(f"DNS TXT: Processing instance {instance_id_str}") # Log start
# --- Process Tags ---
if isinstance(instance.tags, list):
for tag in instance.tags:
try:
# Ensure tag is string before encoding
txt_data.append(str(tag).encode("utf-8"))
except Exception as tag_enc_err:
print(
f"ERROR encoding tag '{repr(tag)}' (type: {type(tag)}) for instance {instance_id_str}: {tag_enc_err}"
)
else:
print(
f"WARNING: Instance {instance_id_str} tags are not a list: {type(instance.tags)}"
)
# --- Process Metadata ---
if isinstance(instance.metadata, dict):
for k, v in instance.metadata.items():
try:
# Ensure key/value are strings before formatting/encoding
key_str = str(k)
val_str = str(v)
txt_data.append(f"{key_str}={val_str}".encode("utf-8"))
except Exception as meta_enc_err:
print(
f"ERROR encoding metadata item '{repr(k)}':'{repr(v)}' (types: {type(k)}/{type(v)}) for instance {instance_id_str}: {meta_enc_err}"
)
else:
print(
f"WARNING: Instance {instance_id_str} metadata is not a dict: {type(instance.metadata)}"
)
# --- Process Instance ID ---
try:
txt_data.append(f"instance_id={instance_id_str}".encode("utf-8"))
except Exception as id_enc_err:
print(
f"ERROR encoding instance ID for {instance_id_str}: {id_enc_err}"
)
# --- **** THE CRITICAL DEBUGGING STEP **** ---
print(
f"DNS TXT DEBUG: Data for instance {instance_id_str} BEFORE Record_TXT:"
)
valid_types = True
if not isinstance(txt_data, list):
print(f" FATAL: txt_data is NOT a list! Type: {type(txt_data)}")
valid_types = False
else:
for i, item in enumerate(txt_data):
item_type = type(item)
print(f" Item {i}: Type={item_type}, Value={repr(item)}")
if item_type is not bytes:
print(f" ^^^^^ ERROR: Item {i} is NOT bytes!")
valid_types = False
# --- **** END DEBUGGING STEP **** ---
if not txt_data:
print(
f"DNS TXT: No valid TXT data generated for instance {instance_id_str}, skipping."
)
continue
# Only proceed if all items were bytes
if not valid_types:
print(
f"DNS TXT ERROR: txt_data for {instance_id_str} contained non-bytes elements. Skipping record creation."
)
continue # Skip this instance if data is bad
# --- Create Payload and RR Header ---
# This is where the error occurs if txt_data contains non-bytes
print(
f"DNS TXT: Attempting to create Record_TXT for instance {instance_id_str}..."
)
payload = dns.Record_TXT(txt_data, ttl=DNS_DEFAULT_TTL)
print(
f"DNS TXT: Record_TXT created successfully for {instance_id_str}."
)
rr = dns.RRHeader(
name=name,
type=dns.TXT,
cls=cls,
ttl=DNS_DEFAULT_TTL,
payload=payload,
)
answers.append(rr)
print(
f"DNS TXT: RRHeader created and added for instance {instance_id_str}."
)
# Catch errors specifically during the DNS object creation phase
except TypeError as te_dns:
print(
f"FATAL DNS TypeError creating TXT record for {instance_id_str}: {te_dns}"
)
print(
" This likely means the list passed to Record_TXT contained non-bytes elements."
)
traceback.print_exc() # Crucial to see where in Twisted it fails
except Exception as e_dns:
print(
f"ERROR creating TXT DNS objects for instance {instance_id_str}: {e_dns.__class__.__name__}: {e_dns}"
)
traceback.print_exc()
# Log the final result before returning
print(
f"DNS TXT: Finished processing query for '{base_name}'. Found {len(instances)} instances, generated {len(answers)} TXT records."
)
return answers, [], []
# --- Health Checker ---
def check_service_health(instance: ServiceInstance) -> str:
@ -1073,7 +1260,10 @@ def run_dns_server(db_path: str, hmac_key: bytes, port: int):
db_path, hmac_key
) # DNS process needs its own registry instance
resolver = MiniDiscoveryResolver(registry)
factory = server.DNSServerFactory(clients=[resolver])
factory = server.DNSServerFactory(
clients=[resolver],
# verbose=2 # Uncomment for very detailed Twisted DNS logging
)
protocol = dns.DNSDatagramProtocol(controller=factory)
# Listen on UDP and TCP