Working to add TXT records
This commit is contained in:
parent
26e9454ba9
commit
8ee0203de8
480
MiniDiscovery.py
480
MiniDiscovery.py
@ -77,12 +77,13 @@ class ServiceInstance(BaseModel):
|
||||
tags: List[str] = Field(
|
||||
default_factory=list, description="Optional list of tags for filtering"
|
||||
)
|
||||
health: str = Field(
|
||||
default="passing", description="Health status ('passing', 'failing', 'unknown')"
|
||||
)
|
||||
metadata: Dict[str, str] = Field(
|
||||
default_factory=dict, description="Optional key-value metadata"
|
||||
)
|
||||
health: str = Field(
|
||||
default="passing", description="Health status ('passing', 'failing', 'unknown')"
|
||||
)
|
||||
last_updated: float = Field(default_factory=time.time)
|
||||
|
||||
|
||||
class TokenInfo(BaseModel):
|
||||
@ -867,171 +868,357 @@ class MiniDiscoveryResolver(common.ResolverBase):
|
||||
query_type: int,
|
||||
timeout: Optional[Tuple[int]] = None,
|
||||
) -> Deferred:
|
||||
"""
|
||||
Main lookup entry point. Decodes the name, checks the suffix,
|
||||
and dispatches to specific handlers based on query_type.
|
||||
"""
|
||||
d = Deferred()
|
||||
name_str_debug = repr(name)
|
||||
name_str_debug = repr(name) # For logging errors
|
||||
|
||||
try:
|
||||
# 1. Decode the incoming query name safely
|
||||
try:
|
||||
name_str = name.decode("utf-8").lower()
|
||||
except UnicodeDecodeError as decode_err:
|
||||
print(
|
||||
f"DNS lookup error: Cannot decode query name {name_str_debug} as UTF-8: {decode_err}"
|
||||
)
|
||||
d.errback(Failure(decode_err))
|
||||
return d
|
||||
|
||||
# 2. Check suffix
|
||||
if not name_str.endswith(DNS_QUERY_SUFFIX):
|
||||
except UnicodeDecodeError:
|
||||
# Log? Respond with format error? For now, treat as non-existent.
|
||||
print(f"DNS: Cannot decode query name {name_str_debug} as UTF-8.")
|
||||
# This should result in NXDOMAIN if we don't callback anything
|
||||
# d.callback(([], [], [])) # Alternatively, explicitly return empty
|
||||
return d # Let Twisted handle NXDOMAIN? Best to callback empty.
|
||||
# Returning empty is safer for our specific resolver.
|
||||
d.callback(([], [], []))
|
||||
return d
|
||||
|
||||
# --- SRV Service Type Lookup Logic ---
|
||||
is_srv_type_query = False
|
||||
service_tag_to_find = None
|
||||
if name_str.startswith("_") and (
|
||||
"_tcp." in name_str or "_udp." in name_str
|
||||
):
|
||||
parts = name_str.split(".")
|
||||
# Example: _ssh._tcp.laiska.local
|
||||
if (
|
||||
len(parts) >= 4
|
||||
and parts[0].startswith("_")
|
||||
and parts[1] in ["_tcp", "_udp"]
|
||||
):
|
||||
# Assume format _service._proto.domain.suffix...
|
||||
service_tag_to_find = parts[0][1:] # Extract 'ssh' from '_ssh'
|
||||
# Ensure the query type is actually SRV
|
||||
if query_type == dns.SRV:
|
||||
is_srv_type_query = True
|
||||
else:
|
||||
# Requesting A/other records for SRV-style name doesn't make sense here
|
||||
d.callback(([], [], []))
|
||||
return d
|
||||
|
||||
# --- Instance Fetching ---
|
||||
instances_to_process = []
|
||||
if is_srv_type_query:
|
||||
print(
|
||||
f"DNS: Performing SRV type lookup for tag '{service_tag_to_find}'"
|
||||
)
|
||||
all_services = registry.get_all_services()
|
||||
for service_name, service_instances in all_services.items():
|
||||
for instance in service_instances:
|
||||
# Check health AND tag presence
|
||||
if (
|
||||
instance.health == "passing"
|
||||
and service_tag_to_find in instance.tags
|
||||
):
|
||||
instances_to_process.append(instance)
|
||||
if not instances_to_process:
|
||||
print(
|
||||
f"DNS: No passing instances found with tag '{service_tag_to_find}'"
|
||||
)
|
||||
d.callback(([], [], [])) # No matching instances found
|
||||
# 2. Check for the expected suffix
|
||||
if not name_str.endswith(DNS_QUERY_SUFFIX):
|
||||
# Not a query for our domain, return empty.
|
||||
d.callback(([], [], []))
|
||||
return d
|
||||
|
||||
# 3. Extract the base name (part before the suffix)
|
||||
base_name = name_str[: -len(DNS_QUERY_SUFFIX)]
|
||||
if not base_name: # Query was just the suffix itself
|
||||
d.callback(([], [], []))
|
||||
return d
|
||||
|
||||
# 4. Dispatch based on query type
|
||||
handler = None
|
||||
if query_type == dns.A:
|
||||
handler = self._handle_a_query
|
||||
elif query_type == dns.SRV:
|
||||
handler = self._handle_srv_query
|
||||
elif query_type == dns.TXT:
|
||||
handler = self._handle_txt_query
|
||||
# Add elif for AAAA if needed in the future
|
||||
# elif query_type == dns.AAAA:
|
||||
# handler = self._handle_aaaa_query # Implement this if needed
|
||||
else:
|
||||
service_name = name_str[: -len(DNS_QUERY_SUFFIX)].split(".")[-1]
|
||||
if not service_name:
|
||||
d.callback(([], [], []))
|
||||
return d
|
||||
instances_to_process = registry.get_service(
|
||||
service_name, only_passing=True
|
||||
)
|
||||
if not instances_to_process:
|
||||
d.callback(([], [], []))
|
||||
return d
|
||||
# Unsupported query type for our resolver
|
||||
d.callback(([], [], []))
|
||||
return d
|
||||
|
||||
# --- Build DNS Records ---
|
||||
answers = []
|
||||
authority = []
|
||||
additional = []
|
||||
|
||||
# Check query_type against the records we can generate
|
||||
if (
|
||||
query_type == dns.A and not is_srv_type_query
|
||||
): # Only generate A for direct name query
|
||||
for instance in instances_to_process:
|
||||
instance_addr_debug = instance.address
|
||||
try:
|
||||
ip_address_string = instance.address
|
||||
answers.append(
|
||||
dns.RRHeader(
|
||||
name=name,
|
||||
type=dns.A,
|
||||
cls=cls,
|
||||
ttl=DNS_DEFAULT_TTL,
|
||||
# Pass the string, Twisted handles conversion
|
||||
payload=dns.Record_A(
|
||||
address=ip_address_string, ttl=DNS_DEFAULT_TTL
|
||||
),
|
||||
)
|
||||
)
|
||||
except (
|
||||
Exception
|
||||
) as record_e: # Catch potential errors during record creation itself
|
||||
print(
|
||||
f"Warning: Error creating A record for IP '{instance.address}' (Service ID {instance.id}): {record_e}. Skipping."
|
||||
)
|
||||
|
||||
elif (
|
||||
query_type == dns.SRV
|
||||
): # Generate SRV for direct name query OR SRV type query
|
||||
for instance in instances_to_process:
|
||||
try:
|
||||
ip_address_string = instance.address
|
||||
# SRV target should still be bytes
|
||||
target_name = f"{instance.id}{DNS_QUERY_SUFFIX}".encode("utf-8")
|
||||
answers.append(
|
||||
dns.RRHeader(
|
||||
name=name,
|
||||
type=dns.SRV,
|
||||
cls=cls,
|
||||
ttl=DNS_DEFAULT_TTL,
|
||||
payload=dns.Record_SRV(
|
||||
priority=0,
|
||||
weight=10,
|
||||
port=instance.port,
|
||||
target=target_name,
|
||||
ttl=DNS_DEFAULT_TTL,
|
||||
),
|
||||
)
|
||||
)
|
||||
additional.append(
|
||||
dns.RRHeader(
|
||||
name=target_name,
|
||||
type=dns.A,
|
||||
cls=cls,
|
||||
ttl=DNS_DEFAULT_TTL,
|
||||
payload=dns.Record_A(
|
||||
address=ip_address_string, ttl=DNS_DEFAULT_TTL
|
||||
),
|
||||
)
|
||||
)
|
||||
except (
|
||||
Exception
|
||||
) as record_e: # Catch potential errors during record creation itself
|
||||
print(
|
||||
f"Warning: Error creating SRV/additional record for IP '{instance.address}' (Service ID {instance.id}): {record_e}. Skipping."
|
||||
)
|
||||
|
||||
# If we successfully built records (or correctly skipped bad IPs)
|
||||
# 5. Execute the handler and set callback/errback
|
||||
# The handlers currently don't return Deferreds, so we call directly
|
||||
answers, authority, additional = handler(name, base_name, cls)
|
||||
d.callback((answers, authority, additional))
|
||||
|
||||
# --- Catch ANY OTHER unexpected error during the lookup process ---
|
||||
except Exception as e:
|
||||
# Log the errors
|
||||
# Catch-all for unexpected errors during dispatch or handler execution
|
||||
print(
|
||||
f"!!! Unhandled exception during DNS lookup for query name {name_str_debug} !!!"
|
||||
f"!!! Unhandled exception during DNS lookup for {name_str_debug} "
|
||||
f"(Type: {query_type}) !!!"
|
||||
)
|
||||
print(f"Exception Type: {type(e).__name__}")
|
||||
print(f"Exception Args: {e.args}")
|
||||
print("--- Traceback ---")
|
||||
traceback.print_exc()
|
||||
print("--- End Traceback ---")
|
||||
d.errback(Failure(e)) # Signal SERVFAIL
|
||||
# Signal DNS server failure (SERVFAIL)
|
||||
d.errback(Failure(e))
|
||||
|
||||
return d
|
||||
|
||||
# --- Helper Methods for Specific Record Types ---
|
||||
|
||||
def _parse_srv_query(self, base_name: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Parses SRV-style queries: _tag._proto.service or _tag._proto
|
||||
Returns (tag, service_name) or (None, None) if not SRV-style.
|
||||
"""
|
||||
parts = base_name.split(".")
|
||||
if (
|
||||
len(parts) >= 2
|
||||
and parts[0].startswith("_")
|
||||
and parts[1] in ["_tcp", "_udp"]
|
||||
):
|
||||
tag = parts[0][1:] # Remove leading '_'
|
||||
# Service name is the part *after* _tag._proto, if it exists
|
||||
service_name = parts[2] if len(parts) > 2 else None
|
||||
# We currently ignore the rest of the parts (like datacenter in consul)
|
||||
return tag, service_name
|
||||
return None, None # Not an SRV-style query name
|
||||
|
||||
def _get_instances_for_query(
|
||||
self, base_name: str, is_srv_query: bool = False
|
||||
) -> List[ServiceInstance]:
|
||||
"""Fetches relevant, passing service instances based on the query name."""
|
||||
instances = []
|
||||
tag_filter = None
|
||||
service_name_filter = None
|
||||
|
||||
if is_srv_query:
|
||||
tag_filter, service_name_filter = self._parse_srv_query(base_name)
|
||||
if tag_filter is None: # Not a valid _tag._proto... query
|
||||
return [] # Return empty, SRV query handler expects specific format
|
||||
|
||||
else: # A or TXT query: service name is the last part
|
||||
service_name_filter = base_name.split(".")[-1]
|
||||
|
||||
if service_name_filter:
|
||||
# Query targets a specific service (with potential tag filter for SRV)
|
||||
service_instances = self.registry.get_service(
|
||||
service_name_filter, only_passing=True
|
||||
)
|
||||
if tag_filter: # SRV query with tag and service
|
||||
instances = [
|
||||
inst for inst in service_instances if tag_filter in inst.tags
|
||||
]
|
||||
else: # A/TXT query, or SRV query without tag (using service name)
|
||||
instances = service_instances # Already filtered for passing
|
||||
|
||||
elif tag_filter: # SRV query for a tag across all services
|
||||
all_services = self.registry.get_all_services()
|
||||
for name, service_instances in all_services.items():
|
||||
for inst in service_instances:
|
||||
if inst.health == "passing" and tag_filter in inst.tags:
|
||||
instances.append(inst)
|
||||
else:
|
||||
# This case shouldn't be reached if initial checks are correct
|
||||
# (e.g., A/TXT query needs a service name part)
|
||||
print(f"Warning: Could not determine filter for query '{base_name}'")
|
||||
|
||||
print(
|
||||
f"DNS Lookup: base_name='{base_name}', is_srv={is_srv_query}, "
|
||||
f"tag='{tag_filter}', service='{service_name_filter}'. Found {len(instances)} instances."
|
||||
)
|
||||
return instances
|
||||
|
||||
def _handle_a_query(
|
||||
self, name: bytes, base_name: str, cls: int
|
||||
) -> Tuple[List, List, List]:
|
||||
"""Handles A record lookups."""
|
||||
answers = []
|
||||
instances = self._get_instances_for_query(base_name, is_srv_query=False)
|
||||
|
||||
for instance in instances:
|
||||
try:
|
||||
# Twisted's Record_A expects the IP address string
|
||||
payload = dns.Record_A(address=instance.address, ttl=DNS_DEFAULT_TTL)
|
||||
rr = dns.RRHeader(
|
||||
name=name, # Respond with the original query name
|
||||
type=dns.A,
|
||||
cls=cls,
|
||||
ttl=DNS_DEFAULT_TTL,
|
||||
payload=payload,
|
||||
)
|
||||
answers.append(rr)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Warning: Error creating A record for instance {instance.id} "
|
||||
f"(IP: {instance.address}): {e}. Skipping."
|
||||
)
|
||||
|
||||
return answers, [], [] # No authority or additional records for basic A
|
||||
|
||||
def _handle_srv_query(
|
||||
self, name: bytes, base_name: str, cls: int
|
||||
) -> Tuple[List, List, List]:
|
||||
"""Handles SRV record lookups (service or tag based)."""
|
||||
answers = []
|
||||
additional = []
|
||||
instances = self._get_instances_for_query(base_name, is_srv_query=True)
|
||||
|
||||
# If _get_instances_for_query returned empty because parsing failed,
|
||||
# we might want to try interpreting the name differently, e.g.,
|
||||
# as a direct SRV lookup for a service name like `service.domain.suffix`.
|
||||
# For now, we strictly follow the _tag._proto logic defined above.
|
||||
# If you want `srvlookup service.domain.suffix`, the logic in
|
||||
# _get_instances_for_query needs adjustment or another branch here.
|
||||
|
||||
for instance in instances:
|
||||
try:
|
||||
# SRV target points to a name that resolves to the instance's A record.
|
||||
# Conventionally: <instance_id>.<service_name>.<domain_suffix>.
|
||||
# Let's use: <instance_id>.node.<domain_suffix> for simplicity,
|
||||
# or maybe <instance_id>.<service_name>...
|
||||
# Using just instance ID + suffix is simple and unique.
|
||||
# Ensure the target ends with the suffix too!
|
||||
target_name_str = f"{instance.id}{DNS_QUERY_SUFFIX}"
|
||||
target_name_bytes = target_name_str.encode("utf-8")
|
||||
|
||||
srv_payload = dns.Record_SRV(
|
||||
priority=0, # Lower is more preferred
|
||||
weight=10, # Relative weight for same priority
|
||||
port=instance.port,
|
||||
target=target_name_bytes, # Must be bytes
|
||||
ttl=DNS_DEFAULT_TTL, # TTL for the SRV record itself
|
||||
)
|
||||
srv_rr = dns.RRHeader(
|
||||
name=name, # Respond with the original query name
|
||||
type=dns.SRV,
|
||||
cls=cls,
|
||||
ttl=DNS_DEFAULT_TTL,
|
||||
payload=srv_payload,
|
||||
)
|
||||
answers.append(srv_rr)
|
||||
|
||||
# Add corresponding A record for the target in the additional section
|
||||
a_payload = dns.Record_A(address=instance.address, ttl=DNS_DEFAULT_TTL)
|
||||
a_rr = dns.RRHeader(
|
||||
name=target_name_bytes, # Name matches SRV target
|
||||
type=dns.A,
|
||||
cls=cls,
|
||||
ttl=DNS_DEFAULT_TTL, # TTL for the additional A record
|
||||
payload=a_payload,
|
||||
)
|
||||
additional.append(a_rr)
|
||||
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Warning: Error creating SRV/A record for instance {instance.id} "
|
||||
f"(Addr: {instance.address}:{instance.port}): {e}. Skipping."
|
||||
)
|
||||
|
||||
return answers, [], additional
|
||||
|
||||
def _handle_txt_query(
|
||||
self, name: bytes, base_name: str, cls: int
|
||||
) -> Tuple[List, List, List]:
|
||||
"""Handles TXT record lookups, returning service metadata."""
|
||||
answers = []
|
||||
instances = self._get_instances_for_query(base_name, is_srv_query=False)
|
||||
|
||||
for instance in instances:
|
||||
# --- Initialize list for this instance ---
|
||||
txt_data = []
|
||||
instance_id_str = str(instance.id) # Use consistently
|
||||
|
||||
try:
|
||||
print(f"DNS TXT: Processing instance {instance_id_str}") # Log start
|
||||
|
||||
# --- Process Tags ---
|
||||
if isinstance(instance.tags, list):
|
||||
for tag in instance.tags:
|
||||
try:
|
||||
# Ensure tag is string before encoding
|
||||
txt_data.append(str(tag).encode("utf-8"))
|
||||
except Exception as tag_enc_err:
|
||||
print(
|
||||
f"ERROR encoding tag '{repr(tag)}' (type: {type(tag)}) for instance {instance_id_str}: {tag_enc_err}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"WARNING: Instance {instance_id_str} tags are not a list: {type(instance.tags)}"
|
||||
)
|
||||
|
||||
# --- Process Metadata ---
|
||||
if isinstance(instance.metadata, dict):
|
||||
for k, v in instance.metadata.items():
|
||||
try:
|
||||
# Ensure key/value are strings before formatting/encoding
|
||||
key_str = str(k)
|
||||
val_str = str(v)
|
||||
txt_data.append(f"{key_str}={val_str}".encode("utf-8"))
|
||||
except Exception as meta_enc_err:
|
||||
print(
|
||||
f"ERROR encoding metadata item '{repr(k)}':'{repr(v)}' (types: {type(k)}/{type(v)}) for instance {instance_id_str}: {meta_enc_err}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"WARNING: Instance {instance_id_str} metadata is not a dict: {type(instance.metadata)}"
|
||||
)
|
||||
|
||||
# --- Process Instance ID ---
|
||||
try:
|
||||
txt_data.append(f"instance_id={instance_id_str}".encode("utf-8"))
|
||||
except Exception as id_enc_err:
|
||||
print(
|
||||
f"ERROR encoding instance ID for {instance_id_str}: {id_enc_err}"
|
||||
)
|
||||
|
||||
# --- **** THE CRITICAL DEBUGGING STEP **** ---
|
||||
print(
|
||||
f"DNS TXT DEBUG: Data for instance {instance_id_str} BEFORE Record_TXT:"
|
||||
)
|
||||
valid_types = True
|
||||
if not isinstance(txt_data, list):
|
||||
print(f" FATAL: txt_data is NOT a list! Type: {type(txt_data)}")
|
||||
valid_types = False
|
||||
else:
|
||||
for i, item in enumerate(txt_data):
|
||||
item_type = type(item)
|
||||
print(f" Item {i}: Type={item_type}, Value={repr(item)}")
|
||||
if item_type is not bytes:
|
||||
print(f" ^^^^^ ERROR: Item {i} is NOT bytes!")
|
||||
valid_types = False
|
||||
# --- **** END DEBUGGING STEP **** ---
|
||||
|
||||
if not txt_data:
|
||||
print(
|
||||
f"DNS TXT: No valid TXT data generated for instance {instance_id_str}, skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
# Only proceed if all items were bytes
|
||||
if not valid_types:
|
||||
print(
|
||||
f"DNS TXT ERROR: txt_data for {instance_id_str} contained non-bytes elements. Skipping record creation."
|
||||
)
|
||||
continue # Skip this instance if data is bad
|
||||
|
||||
# --- Create Payload and RR Header ---
|
||||
# This is where the error occurs if txt_data contains non-bytes
|
||||
print(
|
||||
f"DNS TXT: Attempting to create Record_TXT for instance {instance_id_str}..."
|
||||
)
|
||||
payload = dns.Record_TXT(txt_data, ttl=DNS_DEFAULT_TTL)
|
||||
print(
|
||||
f"DNS TXT: Record_TXT created successfully for {instance_id_str}."
|
||||
)
|
||||
|
||||
rr = dns.RRHeader(
|
||||
name=name,
|
||||
type=dns.TXT,
|
||||
cls=cls,
|
||||
ttl=DNS_DEFAULT_TTL,
|
||||
payload=payload,
|
||||
)
|
||||
answers.append(rr)
|
||||
print(
|
||||
f"DNS TXT: RRHeader created and added for instance {instance_id_str}."
|
||||
)
|
||||
|
||||
# Catch errors specifically during the DNS object creation phase
|
||||
except TypeError as te_dns:
|
||||
print(
|
||||
f"FATAL DNS TypeError creating TXT record for {instance_id_str}: {te_dns}"
|
||||
)
|
||||
print(
|
||||
" This likely means the list passed to Record_TXT contained non-bytes elements."
|
||||
)
|
||||
traceback.print_exc() # Crucial to see where in Twisted it fails
|
||||
except Exception as e_dns:
|
||||
print(
|
||||
f"ERROR creating TXT DNS objects for instance {instance_id_str}: {e_dns.__class__.__name__}: {e_dns}"
|
||||
)
|
||||
traceback.print_exc()
|
||||
|
||||
# Log the final result before returning
|
||||
print(
|
||||
f"DNS TXT: Finished processing query for '{base_name}'. Found {len(instances)} instances, generated {len(answers)} TXT records."
|
||||
)
|
||||
return answers, [], []
|
||||
|
||||
|
||||
# --- Health Checker ---
|
||||
def check_service_health(instance: ServiceInstance) -> str:
|
||||
@ -1073,7 +1260,10 @@ def run_dns_server(db_path: str, hmac_key: bytes, port: int):
|
||||
db_path, hmac_key
|
||||
) # DNS process needs its own registry instance
|
||||
resolver = MiniDiscoveryResolver(registry)
|
||||
factory = server.DNSServerFactory(clients=[resolver])
|
||||
factory = server.DNSServerFactory(
|
||||
clients=[resolver],
|
||||
# verbose=2 # Uncomment for very detailed Twisted DNS logging
|
||||
)
|
||||
protocol = dns.DNSDatagramProtocol(controller=factory)
|
||||
|
||||
# Listen on UDP and TCP
|
||||
|
Loading…
x
Reference in New Issue
Block a user