From 8ee0203de8ea7a91581c1ae0979ac2d911c30bc9 Mon Sep 17 00:00:00 2001 From: Kalzu Rekku Date: Sat, 3 May 2025 11:38:00 +0300 Subject: [PATCH] Working to add TXT records --- MiniDiscovery.py | 480 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 335 insertions(+), 145 deletions(-) diff --git a/MiniDiscovery.py b/MiniDiscovery.py index ae6b54c..5472356 100644 --- a/MiniDiscovery.py +++ b/MiniDiscovery.py @@ -77,12 +77,13 @@ class ServiceInstance(BaseModel): tags: List[str] = Field( default_factory=list, description="Optional list of tags for filtering" ) - health: str = Field( - default="passing", description="Health status ('passing', 'failing', 'unknown')" - ) metadata: Dict[str, str] = Field( default_factory=dict, description="Optional key-value metadata" ) + health: str = Field( + default="passing", description="Health status ('passing', 'failing', 'unknown')" + ) + last_updated: float = Field(default_factory=time.time) class TokenInfo(BaseModel): @@ -867,171 +868,357 @@ class MiniDiscoveryResolver(common.ResolverBase): query_type: int, timeout: Optional[Tuple[int]] = None, ) -> Deferred: + """ + Main lookup entry point. Decodes the name, checks the suffix, + and dispatches to specific handlers based on query_type. + """ d = Deferred() - name_str_debug = repr(name) + name_str_debug = repr(name) # For logging errors + try: # 1. Decode the incoming query name safely try: name_str = name.decode("utf-8").lower() - except UnicodeDecodeError as decode_err: - print( - f"DNS lookup error: Cannot decode query name {name_str_debug} as UTF-8: {decode_err}" - ) - d.errback(Failure(decode_err)) - return d - - # 2. Check suffix - if not name_str.endswith(DNS_QUERY_SUFFIX): + except UnicodeDecodeError: + # Log? Respond with format error? For now, treat as non-existent. + print(f"DNS: Cannot decode query name {name_str_debug} as UTF-8.") + # This should result in NXDOMAIN if we don't callback anything + # d.callback(([], [], [])) # Alternatively, explicitly return empty + return d # Let Twisted handle NXDOMAIN? Best to callback empty. + # Returning empty is safer for our specific resolver. d.callback(([], [], [])) return d - # --- SRV Service Type Lookup Logic --- - is_srv_type_query = False - service_tag_to_find = None - if name_str.startswith("_") and ( - "_tcp." in name_str or "_udp." in name_str - ): - parts = name_str.split(".") - # Example: _ssh._tcp.laiska.local - if ( - len(parts) >= 4 - and parts[0].startswith("_") - and parts[1] in ["_tcp", "_udp"] - ): - # Assume format _service._proto.domain.suffix... - service_tag_to_find = parts[0][1:] # Extract 'ssh' from '_ssh' - # Ensure the query type is actually SRV - if query_type == dns.SRV: - is_srv_type_query = True - else: - # Requesting A/other records for SRV-style name doesn't make sense here - d.callback(([], [], [])) - return d - - # --- Instance Fetching --- - instances_to_process = [] - if is_srv_type_query: - print( - f"DNS: Performing SRV type lookup for tag '{service_tag_to_find}'" - ) - all_services = registry.get_all_services() - for service_name, service_instances in all_services.items(): - for instance in service_instances: - # Check health AND tag presence - if ( - instance.health == "passing" - and service_tag_to_find in instance.tags - ): - instances_to_process.append(instance) - if not instances_to_process: - print( - f"DNS: No passing instances found with tag '{service_tag_to_find}'" - ) - d.callback(([], [], [])) # No matching instances found + # 2. Check for the expected suffix + if not name_str.endswith(DNS_QUERY_SUFFIX): + # Not a query for our domain, return empty. + d.callback(([], [], [])) return d + + # 3. Extract the base name (part before the suffix) + base_name = name_str[: -len(DNS_QUERY_SUFFIX)] + if not base_name: # Query was just the suffix itself + d.callback(([], [], [])) + return d + + # 4. Dispatch based on query type + handler = None + if query_type == dns.A: + handler = self._handle_a_query + elif query_type == dns.SRV: + handler = self._handle_srv_query + elif query_type == dns.TXT: + handler = self._handle_txt_query + # Add elif for AAAA if needed in the future + # elif query_type == dns.AAAA: + # handler = self._handle_aaaa_query # Implement this if needed else: - service_name = name_str[: -len(DNS_QUERY_SUFFIX)].split(".")[-1] - if not service_name: - d.callback(([], [], [])) - return d - instances_to_process = registry.get_service( - service_name, only_passing=True - ) - if not instances_to_process: - d.callback(([], [], [])) - return d + # Unsupported query type for our resolver + d.callback(([], [], [])) + return d - # --- Build DNS Records --- - answers = [] - authority = [] - additional = [] - - # Check query_type against the records we can generate - if ( - query_type == dns.A and not is_srv_type_query - ): # Only generate A for direct name query - for instance in instances_to_process: - instance_addr_debug = instance.address - try: - ip_address_string = instance.address - answers.append( - dns.RRHeader( - name=name, - type=dns.A, - cls=cls, - ttl=DNS_DEFAULT_TTL, - # Pass the string, Twisted handles conversion - payload=dns.Record_A( - address=ip_address_string, ttl=DNS_DEFAULT_TTL - ), - ) - ) - except ( - Exception - ) as record_e: # Catch potential errors during record creation itself - print( - f"Warning: Error creating A record for IP '{instance.address}' (Service ID {instance.id}): {record_e}. Skipping." - ) - - elif ( - query_type == dns.SRV - ): # Generate SRV for direct name query OR SRV type query - for instance in instances_to_process: - try: - ip_address_string = instance.address - # SRV target should still be bytes - target_name = f"{instance.id}{DNS_QUERY_SUFFIX}".encode("utf-8") - answers.append( - dns.RRHeader( - name=name, - type=dns.SRV, - cls=cls, - ttl=DNS_DEFAULT_TTL, - payload=dns.Record_SRV( - priority=0, - weight=10, - port=instance.port, - target=target_name, - ttl=DNS_DEFAULT_TTL, - ), - ) - ) - additional.append( - dns.RRHeader( - name=target_name, - type=dns.A, - cls=cls, - ttl=DNS_DEFAULT_TTL, - payload=dns.Record_A( - address=ip_address_string, ttl=DNS_DEFAULT_TTL - ), - ) - ) - except ( - Exception - ) as record_e: # Catch potential errors during record creation itself - print( - f"Warning: Error creating SRV/additional record for IP '{instance.address}' (Service ID {instance.id}): {record_e}. Skipping." - ) - - # If we successfully built records (or correctly skipped bad IPs) + # 5. Execute the handler and set callback/errback + # The handlers currently don't return Deferreds, so we call directly + answers, authority, additional = handler(name, base_name, cls) d.callback((answers, authority, additional)) - # --- Catch ANY OTHER unexpected error during the lookup process --- except Exception as e: - # Log the errors + # Catch-all for unexpected errors during dispatch or handler execution print( - f"!!! Unhandled exception during DNS lookup for query name {name_str_debug} !!!" + f"!!! Unhandled exception during DNS lookup for {name_str_debug} " + f"(Type: {query_type}) !!!" ) print(f"Exception Type: {type(e).__name__}") print(f"Exception Args: {e.args}") print("--- Traceback ---") traceback.print_exc() print("--- End Traceback ---") - d.errback(Failure(e)) # Signal SERVFAIL + # Signal DNS server failure (SERVFAIL) + d.errback(Failure(e)) return d + # --- Helper Methods for Specific Record Types --- + + def _parse_srv_query(self, base_name: str) -> Tuple[Optional[str], Optional[str]]: + """ + Parses SRV-style queries: _tag._proto.service or _tag._proto + Returns (tag, service_name) or (None, None) if not SRV-style. + """ + parts = base_name.split(".") + if ( + len(parts) >= 2 + and parts[0].startswith("_") + and parts[1] in ["_tcp", "_udp"] + ): + tag = parts[0][1:] # Remove leading '_' + # Service name is the part *after* _tag._proto, if it exists + service_name = parts[2] if len(parts) > 2 else None + # We currently ignore the rest of the parts (like datacenter in consul) + return tag, service_name + return None, None # Not an SRV-style query name + + def _get_instances_for_query( + self, base_name: str, is_srv_query: bool = False + ) -> List[ServiceInstance]: + """Fetches relevant, passing service instances based on the query name.""" + instances = [] + tag_filter = None + service_name_filter = None + + if is_srv_query: + tag_filter, service_name_filter = self._parse_srv_query(base_name) + if tag_filter is None: # Not a valid _tag._proto... query + return [] # Return empty, SRV query handler expects specific format + + else: # A or TXT query: service name is the last part + service_name_filter = base_name.split(".")[-1] + + if service_name_filter: + # Query targets a specific service (with potential tag filter for SRV) + service_instances = self.registry.get_service( + service_name_filter, only_passing=True + ) + if tag_filter: # SRV query with tag and service + instances = [ + inst for inst in service_instances if tag_filter in inst.tags + ] + else: # A/TXT query, or SRV query without tag (using service name) + instances = service_instances # Already filtered for passing + + elif tag_filter: # SRV query for a tag across all services + all_services = self.registry.get_all_services() + for name, service_instances in all_services.items(): + for inst in service_instances: + if inst.health == "passing" and tag_filter in inst.tags: + instances.append(inst) + else: + # This case shouldn't be reached if initial checks are correct + # (e.g., A/TXT query needs a service name part) + print(f"Warning: Could not determine filter for query '{base_name}'") + + print( + f"DNS Lookup: base_name='{base_name}', is_srv={is_srv_query}, " + f"tag='{tag_filter}', service='{service_name_filter}'. Found {len(instances)} instances." + ) + return instances + + def _handle_a_query( + self, name: bytes, base_name: str, cls: int + ) -> Tuple[List, List, List]: + """Handles A record lookups.""" + answers = [] + instances = self._get_instances_for_query(base_name, is_srv_query=False) + + for instance in instances: + try: + # Twisted's Record_A expects the IP address string + payload = dns.Record_A(address=instance.address, ttl=DNS_DEFAULT_TTL) + rr = dns.RRHeader( + name=name, # Respond with the original query name + type=dns.A, + cls=cls, + ttl=DNS_DEFAULT_TTL, + payload=payload, + ) + answers.append(rr) + except Exception as e: + print( + f"Warning: Error creating A record for instance {instance.id} " + f"(IP: {instance.address}): {e}. Skipping." + ) + + return answers, [], [] # No authority or additional records for basic A + + def _handle_srv_query( + self, name: bytes, base_name: str, cls: int + ) -> Tuple[List, List, List]: + """Handles SRV record lookups (service or tag based).""" + answers = [] + additional = [] + instances = self._get_instances_for_query(base_name, is_srv_query=True) + + # If _get_instances_for_query returned empty because parsing failed, + # we might want to try interpreting the name differently, e.g., + # as a direct SRV lookup for a service name like `service.domain.suffix`. + # For now, we strictly follow the _tag._proto logic defined above. + # If you want `srvlookup service.domain.suffix`, the logic in + # _get_instances_for_query needs adjustment or another branch here. + + for instance in instances: + try: + # SRV target points to a name that resolves to the instance's A record. + # Conventionally: ... + # Let's use: .node. for simplicity, + # or maybe .... + # Using just instance ID + suffix is simple and unique. + # Ensure the target ends with the suffix too! + target_name_str = f"{instance.id}{DNS_QUERY_SUFFIX}" + target_name_bytes = target_name_str.encode("utf-8") + + srv_payload = dns.Record_SRV( + priority=0, # Lower is more preferred + weight=10, # Relative weight for same priority + port=instance.port, + target=target_name_bytes, # Must be bytes + ttl=DNS_DEFAULT_TTL, # TTL for the SRV record itself + ) + srv_rr = dns.RRHeader( + name=name, # Respond with the original query name + type=dns.SRV, + cls=cls, + ttl=DNS_DEFAULT_TTL, + payload=srv_payload, + ) + answers.append(srv_rr) + + # Add corresponding A record for the target in the additional section + a_payload = dns.Record_A(address=instance.address, ttl=DNS_DEFAULT_TTL) + a_rr = dns.RRHeader( + name=target_name_bytes, # Name matches SRV target + type=dns.A, + cls=cls, + ttl=DNS_DEFAULT_TTL, # TTL for the additional A record + payload=a_payload, + ) + additional.append(a_rr) + + except Exception as e: + print( + f"Warning: Error creating SRV/A record for instance {instance.id} " + f"(Addr: {instance.address}:{instance.port}): {e}. Skipping." + ) + + return answers, [], additional + + def _handle_txt_query( + self, name: bytes, base_name: str, cls: int + ) -> Tuple[List, List, List]: + """Handles TXT record lookups, returning service metadata.""" + answers = [] + instances = self._get_instances_for_query(base_name, is_srv_query=False) + + for instance in instances: + # --- Initialize list for this instance --- + txt_data = [] + instance_id_str = str(instance.id) # Use consistently + + try: + print(f"DNS TXT: Processing instance {instance_id_str}") # Log start + + # --- Process Tags --- + if isinstance(instance.tags, list): + for tag in instance.tags: + try: + # Ensure tag is string before encoding + txt_data.append(str(tag).encode("utf-8")) + except Exception as tag_enc_err: + print( + f"ERROR encoding tag '{repr(tag)}' (type: {type(tag)}) for instance {instance_id_str}: {tag_enc_err}" + ) + else: + print( + f"WARNING: Instance {instance_id_str} tags are not a list: {type(instance.tags)}" + ) + + # --- Process Metadata --- + if isinstance(instance.metadata, dict): + for k, v in instance.metadata.items(): + try: + # Ensure key/value are strings before formatting/encoding + key_str = str(k) + val_str = str(v) + txt_data.append(f"{key_str}={val_str}".encode("utf-8")) + except Exception as meta_enc_err: + print( + f"ERROR encoding metadata item '{repr(k)}':'{repr(v)}' (types: {type(k)}/{type(v)}) for instance {instance_id_str}: {meta_enc_err}" + ) + else: + print( + f"WARNING: Instance {instance_id_str} metadata is not a dict: {type(instance.metadata)}" + ) + + # --- Process Instance ID --- + try: + txt_data.append(f"instance_id={instance_id_str}".encode("utf-8")) + except Exception as id_enc_err: + print( + f"ERROR encoding instance ID for {instance_id_str}: {id_enc_err}" + ) + + # --- **** THE CRITICAL DEBUGGING STEP **** --- + print( + f"DNS TXT DEBUG: Data for instance {instance_id_str} BEFORE Record_TXT:" + ) + valid_types = True + if not isinstance(txt_data, list): + print(f" FATAL: txt_data is NOT a list! Type: {type(txt_data)}") + valid_types = False + else: + for i, item in enumerate(txt_data): + item_type = type(item) + print(f" Item {i}: Type={item_type}, Value={repr(item)}") + if item_type is not bytes: + print(f" ^^^^^ ERROR: Item {i} is NOT bytes!") + valid_types = False + # --- **** END DEBUGGING STEP **** --- + + if not txt_data: + print( + f"DNS TXT: No valid TXT data generated for instance {instance_id_str}, skipping." + ) + continue + + # Only proceed if all items were bytes + if not valid_types: + print( + f"DNS TXT ERROR: txt_data for {instance_id_str} contained non-bytes elements. Skipping record creation." + ) + continue # Skip this instance if data is bad + + # --- Create Payload and RR Header --- + # This is where the error occurs if txt_data contains non-bytes + print( + f"DNS TXT: Attempting to create Record_TXT for instance {instance_id_str}..." + ) + payload = dns.Record_TXT(txt_data, ttl=DNS_DEFAULT_TTL) + print( + f"DNS TXT: Record_TXT created successfully for {instance_id_str}." + ) + + rr = dns.RRHeader( + name=name, + type=dns.TXT, + cls=cls, + ttl=DNS_DEFAULT_TTL, + payload=payload, + ) + answers.append(rr) + print( + f"DNS TXT: RRHeader created and added for instance {instance_id_str}." + ) + + # Catch errors specifically during the DNS object creation phase + except TypeError as te_dns: + print( + f"FATAL DNS TypeError creating TXT record for {instance_id_str}: {te_dns}" + ) + print( + " This likely means the list passed to Record_TXT contained non-bytes elements." + ) + traceback.print_exc() # Crucial to see where in Twisted it fails + except Exception as e_dns: + print( + f"ERROR creating TXT DNS objects for instance {instance_id_str}: {e_dns.__class__.__name__}: {e_dns}" + ) + traceback.print_exc() + + # Log the final result before returning + print( + f"DNS TXT: Finished processing query for '{base_name}'. Found {len(instances)} instances, generated {len(answers)} TXT records." + ) + return answers, [], [] + # --- Health Checker --- def check_service_health(instance: ServiceInstance) -> str: @@ -1073,7 +1260,10 @@ def run_dns_server(db_path: str, hmac_key: bytes, port: int): db_path, hmac_key ) # DNS process needs its own registry instance resolver = MiniDiscoveryResolver(registry) - factory = server.DNSServerFactory(clients=[resolver]) + factory = server.DNSServerFactory( + clients=[resolver], + # verbose=2 # Uncomment for very detailed Twisted DNS logging + ) protocol = dns.DNSDatagramProtocol(controller=factory) # Listen on UDP and TCP