Monday, August 21, 2017

Recursive DNS Resolution

Recursive DNS resolution (source: [gist link]). Requirements:

  • dns (the dnspython package)
  • clientsubnetoption


# -*- coding: utf8 -*-
"""
@author: boyxuper
@date: 2017/8/21 16:07
"""
import functools
import random
import socket
from collections import defaultdict
from itertools import chain

import clientsubnetoption
import dns
import dns.exception
import dns.message
import dns.name
import dns.query
import dns.rcode
import dns.rdatatype
import dns.resolver

_logger = lambda *args: None

def logger_full(fmt, *args):
    """Print one log line to stdout.

    When positional ``args`` are given the line is ``fmt % args``; otherwise
    ``fmt`` is printed unchanged (it may be any printable object, and a
    literal ``%`` in it is left alone).
    """
    line = fmt % args if args else fmt
    print(line)

# IPv4-only fallbacks for interpreters whose ``socket`` module lacks
# inet_pton/inet_ntop (historically Python 2 on Windows). They ignore the
# address-family argument, so install them only when the real functions are
# missing; unconditionally overriding them would break IPv6 process-wide.
if not hasattr(socket, 'inet_pton'):
    socket.inet_pton = lambda _, p: socket.inet_aton(p)
if not hasattr(socket, 'inet_ntop'):
    socket.inet_ntop = lambda _, n: socket.inet_ntoa(n)


def retry(n=3, exc_list=()):
    """Build a decorator that retries a call when it raises one of ``exc_list``.

    Each failed attempt prints ``RETRY <attempt>/<n>, [<error>]``. Exceptions
    not listed in ``exc_list`` propagate immediately without retrying.

    :param n: maximum number of attempts; values below 1 are treated as 1
        (the wrapped function is always called at least once).
    :param exc_list: tuple of exception classes that trigger a retry.
    :return: a decorator. The wrapped function returns the first successful
        result, or re-raises the last caught exception (with its original
        traceback) once all attempts have failed.
    """
    attempts = max(1, n)

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for tried in range(1, attempts + 1):
                try:
                    return fn(*args, **kwargs)
                except exc_list as err:
                    print('RETRY %s/%s, [%s]' % (tried, n, err))
                    if tried >= attempts:
                        # Bare ``raise`` keeps the original traceback and does not
                        # depend on ``err`` outliving the except clause (Py3 unbinds it).
                        raise
        return wrapper
    return decorator


@retry(n=5, exc_list=(dns.exception.Timeout, dns.resolver.NoAnswer, ))
def query_dns0(domain, ns_ip, client_ip=None, **kwargs):
    """Send a single UDP query for ``domain`` and return ``(response, server_ip)``.

    :param domain: name to query (anything ``dns.message.make_query`` accepts).
    :param ns_ip: server IP string, or a zero-argument callable returning one.
        The callable is invoked once per attempt, so a retry can pick a
        different server.
    :param client_ip: if truthy, attached via
        ``clientsubnetoption.ClientSubnetOption`` on an EDNS query.
    :param kwargs: forwarded to ``dns.message.make_query``; ``rdtype``
        defaults to ``dns.rdatatype.A``.
    :return: tuple of the ``dns.query.udp`` response and the server IP used.
    """
    if 'rdtype' not in kwargs:
        kwargs['rdtype'] = dns.rdatatype.A
    query = dns.message.make_query(domain, **kwargs)

    if client_ip:
        subnet_option = clientsubnetoption.ClientSubnetOption(client_ip)
        query.use_edns(options=[subnet_option])

    server = ns_ip() if callable(ns_ip) else ns_ip
    reply = dns.query.udp(query, server, timeout=TIMEOUT)
    return reply, server


# Only authority (name-server) data is cached across lookups; final answers are not.
# Maps a zone name in absolute text form (trailing dot, e.g. 'com.') to a list of
# its authorities: IPv4 literals and/or NS host names. Seeded with the root zone
# and a few TLDs; resolve_A() appends to it as referrals arrive.
KNOWN_AUTHORITIES = defaultdict(list, {
    '.': ['192.5.5.241', '199.7.83.42', '192.58.128.30', '192.36.148.17'],
    'net.': [
        '192.48.79.30', '192.35.51.30', '192.52.178.30', '192.5.6.30',
        '192.26.92.30', '192.41.162.30', '192.12.94.30', '192.54.112.30',
        '192.31.80.30', '192.43.172.30', '192.33.14.30', '192.55.83.30', '192.42.93.30'],
    'com.': ['192.33.14.30', '192.12.94.30', '192.55.83.30', '192.52.178.30',
             '192.26.92.30', '192.48.79.30', '192.5.6.30', '192.43.172.30',
             '192.42.93.30', '192.31.80.30', '192.54.112.30', '192.35.51.30', '192.41.162.30'],
    'cn.': [
        '203.119.29.1', '203.119.27.1', '203.119.28.1',
        '203.119.26.1', '202.112.0.44', '203.119.25.1'],
})
# NS host name -> set of its IPv4 addresses (from glue records or resolve_A()).
NS_IPS = defaultdict(set)
# Per-query UDP timeout passed to dns.query.udp (seconds).
TIMEOUT = 2

def locate_nearest_authority(dns_name):
    """Return the deepest suffix zone of ``dns_name`` with cached authorities.

    Suffixes are tried from the full name down to the root ``'.'``; the first
    one with a non-empty entry in ``KNOWN_AUTHORITIES`` wins.

    :type dns_name: dns.name.Name
    :return: ``(zone, authorities)`` where ``zone`` is the matching
        ``dns.name.Name`` suffix and ``authorities`` is the cached list of
        IPs and/or NS host names (the live list object from the cache).
    :raises LookupError: if no suffix, not even the root, has authorities.
    """
    for depth in range(len(dns_name), 0, -1):
        _, sub = dns_name.split(depth)
        # ``.get`` does not trigger the defaultdict factory, and empty lists are
        # skipped so callers never receive a list with no server to query.
        authorities = KNOWN_AUTHORITIES.get(sub.to_text())
        if authorities:
            return sub, authorities

    # Explicit raise (not ``assert``) so the check survives ``python -O``.
    raise LookupError('no known authority for %s' % dns_name)

def is_ip(s):
    """Return True if ``s`` is accepted by ``socket.inet_aton`` as an IPv4 address.

    Used to tell literal server IPs apart from NS host names. Non-strings
    (e.g. ``None``) and unparsable text return False instead of raising.
    Note that ``inet_aton`` also accepts shorthand forms such as ``'127.1'``.
    """
    try:
        socket.inet_aton(s)
    except (socket.error, TypeError, ValueError):
        return False
    else:
        return True

def _authority_iterator(authority_ips, client_ip, logger=_logger):
    """Yield name-server IPs taken from ``authority_ips``, cycling repeatedly.

    Each pass first yields the entries that are IPv4 literals (per
    :func:`is_ip`), starting at a random offset. It then yields one address
    for each NS host name. Host names are resolved with :func:`resolve_A`
    and cached in ``NS_IPS``.

    :param authority_ips: list of IP strings and/or NS host names.
    :param client_ip: forwarded to :func:`resolve_A` for NS-name lookups.
    :param logger: forwarded to :func:`resolve_A` for NS-name lookups.

    The generator ends immediately for an empty list, and also ends when a
    full pass produces no address. This avoids an endless loop that never
    yields.
    """
    if not authority_ips:
        return

    # randrange excludes the upper bound; randint(0, len) could return len itself.
    pos = random.randrange(len(authority_ips))

    # Keep cycling so even a short authority list can serve every retry.
    while True:
        produced = False
        non_ips = []
        # Slices are taken eagerly, so concurrent extension of the cached list
        # (e.g. by resolve_A) does not affect this pass.
        for name in chain(authority_ips[pos:], authority_ips[:pos]):
            if is_ip(name):
                produced = True
                yield name
            else:
                non_ips.append(name)

        # Host names are resolved only after every literal IP has been offered.
        for name in non_ips:
            if name in NS_IPS:
                server_ips = list(NS_IPS[name])
            else:
                _, server_ips, _ = resolve_A(name, client_ip=client_ip, logger=logger)
                NS_IPS[name] = set(server_ips)

            if server_ips:
                produced = True
                yield server_ips[pos % len(server_ips)]

        if not produced:
            return


@retry(n=5, exc_list=(dns.exception.Timeout, dns.resolver.NoAnswer, ))
def query_dns(dns_name, client_ip=None, logger=_logger, **kwargs):
    """Query ``dns_name`` at its nearest cached authority and return the response.

    :param dns_name: ``dns.name.Name`` to query.
    :param client_ip: optional client address handed to :func:`query_dns0`.
    :param logger: ``logger(fmt, *args)`` callable. It is also passed down to
        any NS host-name resolution done by the server iterator.
    :param kwargs: forwarded to :func:`query_dns0` / ``make_query``
        (``rdtype`` defaults to A there).
    :return: the ``dns.message.Message`` response, whose rcode is NOERROR.
    :raises Exception: on NXDOMAIN or any other non-NOERROR rcode.
    """
    sub, authority_ips = locate_nearest_authority(dns_name)
    # Use the caller's logger (not the silent default) for nested lookups.
    iterator = _authority_iterator(authority_ips, client_ip, logger=logger)

    logger('querying %s @NS"%s": %r', dns_name, sub, authority_ips)
    # query_dns0 calls ns_ip() once per attempt, so every retry advances to the
    # next server. ``next(it)`` works on both Python 2 and 3, unlike ``it.next``.
    response, ns_ip = query_dns0(dns_name, ns_ip=lambda: next(iterator), client_ip=client_ip, **kwargs)
    resp_code = response.rcode()
    if resp_code != dns.rcode.NOERROR:
        if resp_code == dns.rcode.NXDOMAIN:
            # NXDOMAIN is about the queried name, not the authority zone ``sub``.
            raise Exception('%s does not exist on %s.' % (dns_name, ns_ip))
        else:
            raise Exception('Error %s' % dns.rcode.to_text(resp_code))

    return response


def resolve_A(domain, client_ip, logger=_logger, max_rounds=32):
    """Iteratively resolve the A records of ``domain``.

    Each round queries the nearest cached authority for the current name
    (via :func:`query_dns`) and learns from every section of the response:

    * A records are collected per owner name.
    * NS records are appended to ``KNOWN_AUTHORITIES`` for their zone (glue
      IPs when available, otherwise the NS host name). Glue is also stored
      in ``NS_IPS``. The next round therefore goes to a closer server.
    * A CNAME switches resolution to its target name.

    Example referral authority section when resolving www.baidu.com::

        response.authority.__len__() == 1
        response.authority.name == {Name}a.shifen.com.
        response.authority.items:
            0 = {NS} ns2.a.shifen.com.
            1 = {NS} ns3.a.shifen.com.
            2 = {NS} ns4.a.shifen.com.
            3 = {NS} ns1.a.shifen.com.
            4 = {NS} ns5.a.shifen.com.

    :param domain: host name to resolve; a trailing dot is appended if missing.
    :param client_ip: client address forwarded to :func:`query_dns`, or None.
    :param logger: ``logger(fmt, *args)`` callable for progress messages.
    :param max_rounds: upper bound on query rounds. It guards against
        responses that make no progress (e.g. NOERROR with no A record and
        no referral) and against CNAME loops, which would otherwise re-query
        forever.
    :return: ``(name, ips, ttl)``. ``name`` is the owner of the A records
        (after following CNAMEs), ``ips`` is its IPv4 addresses as text, and
        ``ttl`` is the TTL of that A RRset.
    :raises Exception: if no A record is found within ``max_rounds`` rounds;
        errors from :func:`query_dns` also propagate.
    """
    if not domain.endswith('.'):
        domain += '.'
    dns_name = dns.name.from_text(domain)

    for _ in range(max_rounds):
        response = query_dns(dns_name, client_ip=client_ip, logger=logger)
        instant_ips = defaultdict(list)
        # Owner name -> A RRset, so the TTL is available for whichever name
        # ``domain`` ends up being after CNAMEs (even if its A RRset was seen
        # before the CNAME that pointed to it).
        a_rrsets = {}
        # Authority (NS) records come last so they can use the glue IPs
        # collected from the additional/answer sections.
        for answer in chain(response.additional, response.answer, response.authority):
            answer_name = answer.name.to_text()

            for item in answer.items:
                item_text = item.to_text()
                if item.rdtype == dns.rdatatype.A:
                    a_rrsets[answer_name] = answer
                    instant_ips[answer_name].append(item_text)
                elif item.rdtype == dns.rdatatype.NS:
                    if domain not in instant_ips:
                        logger('got NS: %s -> %s', answer_name, item_text)

                    if item_text in instant_ips:
                        authority_ips = instant_ips[item_text]
                        NS_IPS[item_text].update(authority_ips)
                    else:
                        authority_ips = [item_text]

                    # Only add new entries; repeated referrals would otherwise grow
                    # the cached list without bound.
                    known = KNOWN_AUTHORITIES[answer_name]
                    known.extend([ip for ip in authority_ips if ip not in known])
                elif item.rdtype == dns.rdatatype.CNAME:
                    logger('CNAME: %s -> %s', domain, item_text)
                    domain, dns_name = item_text, dns.name.from_text(item_text)

        if domain in instant_ips:
            return domain, instant_ips[domain], a_rrsets[domain].ttl

    raise Exception('%s: no A record found after %d rounds' % (domain, max_rounds))


if __name__ == '__main__':
    # Demo run: resolve a handful of hosts with full tracing.
    # Swap ``logger_full`` for ``_logger`` to silence the step-by-step output.
    log = logger_full
    demo_queries = (
        ('www.baidu.com', '8.8.8.8'),
        ('www.taobao.com', '8.8.8.8'),
        ('www.google.com', '8.8.8.8'),
        ('dl.tiku.zhan.com', '8.8.8.8'),
        ('ns3.dnsv4.com', '8.8.8.8'),
        ('www.facebook.com', '8.8.8.8'),
        ('api.taoqian123.com', '8.8.8.8'),
        ('hotsoon.snssdk.com', '66.103.188.117'),
    )
    for host, subnet_ip in demo_queries:
        print(resolve_A(host, client_ip=subnet_ip, logger=log))