折角勉強したんで bitstring を使ってみたよ

なんかはてな良く分からないな、こまったな。

id:blooper:20110203 で DNS のこと少し勉強して、id:blooper:20110509 で bitstring のこと少し勉強したんで、struct 使わずに bitstring で書き直してみた。

_uint8() みたいなしんどい関数用意しなくてもいいのはいいんだけど、結局書き方がベタな気がして、もっとフォーマットを定義してしまった方がすっきりするのかなぁ。でも面倒くさいなぁ。

まぁ、ちーとずつ bitstring 読んでくのに bitstring は便利だわ。

# coding=utf8
import random
import select
import socket
import time

import bitstring

random.seed()

# Google 使っちゃうよ
DNSSERVER = '8.8.8.8'
PORT = 53

FORMAT_HEADER = ', '.join((
    'uint:16=ID',
    'uint:1=QR',
    'uint:4=Opcode',
    'uint:1=AA',
    'uint:1=TC',
    'uint:1=RD',
    'uint:1=RA',
    'uint:3=Z',
    'uint:4=RCODE',
    'uint:16=QDCOUNT',
    'uint:16=ANCOUNT',
    'uint:16=NSCOUNT',
    'uint:16=ARCOUNT',
    ))

FORMAT_QUESTION = ', '.join((
    'bytes:QNAME_LEN=QNAME',
    'uint:16=QTYPE',
    'uint:16=QCLASS',
    ))

QTYPES = {
    'A': 1,
    'NS': 2,
    'MD': 3,
    'MF': 4,
    'CNAME': 5,
    'SOA': 6,
    'MB': 7,
    'MG': 8,
    'MR': 9,
    'NULL': 10,
    'WKS': 11,
    'PTR': 12,
    'HINFO': 13,
    'MINFO': 14,
    'MX': 15,
    'TXT': 16,
    }


class Resolver(object):
  QUEUE_LIMIT = 30 # 溜めておく queue の上限数

  queue = set()  # 未処理のもの入れ
  results = {}   # 結果入れ
  sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  fd = sock.fileno()

  def set(self, fqdn):
    """登録して、溜まってたら処理する"""

    if fqdn not in self.results:
      qdata = self._query(fqdn)
      self.sock.sendto(qdata, (DNSSERVER, PORT))
      self.results[fqdn] = None
      self.queue.add(fqdn)
      if len(self.queue) > self.QUEUE_LIMIT:
        self.sweep()

  def get(self, fqdn, timeout=10):
    """溜まってるの処理して質問に答える"""

    if self.queue:
      self.sweep(timeout)
    return self.results.get(fqdn)

  def sweep(self, timeout=10):
    """溜まってるの処理する"""

    start = time.time()
    while self.queue:
      if len(select.select([self.fd], [], [], 0.001)[0]) > 0:
        # 受けとれるデータがあったら処理して結果に入れる
        rdata = self.sock.recvfrom(8192)[0]
        # qr は response なら 1
        qr = ord(rdata[3]) >> 7
        # rcode は成功なら 0
        rcode = ord(rdata[4]) & 31

        # response で成功なら parse する
        if (qr, rcode) == (1, 0):
          rfqdn, response = DNSResponse(rdata).parse()
          self.queue.remove(rfqdn)
          self.results[rfqdn] = response
      if time.time() - start > timeout:
        # 時間切れは終了
        break
    for q in self.queue:
      del self.results[q]
    self.queue.clear()

  def _query(self, fqdn):
    """fqdn の IP address を下さいという query を作る"""

    fqdn = fqdn.encode('utf-8')

    # d.hatena.ne.jp -> \x01d\x06hatena\x02ne\x02jp\x00
    QNAME = b''.join([chr(len(x)) + x for x in fqdn.split(b'.')]) + b'\x00'
    query = bitstring.pack(
        FORMAT_HEADER + ', ' + FORMAT_QUESTION,
        ID=random.randint(0, 2 ** 16 - 1),
        QR=0, Opcode=0, AA=0, TC=0, RD=1, RA=0, Z=0, RCODE=0,
        QDCOUNT=1,
        ANCOUNT=0,
        NSCOUNT=0,
        ARCOUNT=0,
        QNAME_LEN=len(QNAME)*8,
        QNAME=QNAME,
        QTYPE=QTYPES['A'],
        QCLASS=1
        )

    return query.bytes


class DNSResponse(object):
  def __init__(self, data):
    self.cbs = bitstring.ConstBitStream(bytes=data)

  def parse(self):
    """DNS response をパースして list of tuples で返す

    data -> [(name, type, class ttl, response_data)]
    """
    result = []
    self.cbs.pos = 12 * 8  # これ hard coding でいいの?

    # response には query が含まれてるのね
    qname = self._pick_name()
    qtype = self.cbs.read('uint:16')
    qclass = self.cbs.read('uint:16')

    while True:
      # response data を喰いつくす、全部読んだら終わり
      if len(self.cbs) <= self.cbs.pos:
        break

      rec_name = self._pick_name()
      rec_type = self.cbs.read('uint:16')
      rec_class = self.cbs.read('uint:16')
      rec_ttl = self.cbs.read('uint:32')
      rec_dlength = self.cbs.read('uint:16')

      if rec_type == 1:
        # A レコードだったら
        rec_data = self._parse_ipv4addr()
      elif rec_type == 5:
        # CNAME レコードだったら
        rec_data = self._pick_name()
      elif rec_type == 6:
        # SOA レコードだったら
        name = self._pick_name()
        rname = self._pick_name()
        serial = self.cbs.read('uint:32')
        refresh = self.cbs.read('uint:32')
        retry = self.cbs.read('uint:32')
        expire = self.cbs.read('uint:32')
        minimum = self.cbs.read('uint:32')
        rec_data = (name, rname, serial, refresh, retry, expire)
      else:
        # それ以外だったら parse しないで bytes として
        rec_data = self.cbs.read('bytes:%d' % (rec_dlength,))
      result.append((rec_name, rec_type, rec_class, rec_ttl, rec_data))

    return qname, result

  def _pick_name(self):
    """self.data の self.pos bytes 目から始まる名前を拾う"""

    name_list = []
    while True:
      length = self.cbs.read('uint:8')
      if length == 0:
        break
      elif length >= 192:
        # length (2 octets) の頭が 11 だと compressed なので別処理
        self.cbs.pos -= 8
        name_pos = self.cbs.readlist('bits:2, uint:14')[1] * 8 
        pos, self.cbs.pos = self.cbs.pos, name_pos
        name = self._pick_name()
        self.cbs.pos = pos
        name_list.append(name)
        break
      else:
        name_list.append(self.cbs.read('bytes:%d' % (length,)))

    return b'.'.join(name_list)

  def _parse_ipv4addr(self):
    """self.data の self.pos bytes 目から 4 byte 読んで self.pos をずらす

    返り値は IP address としての str
    """
    return '.'.join([str(x) for x in self.cbs.readlist('4*uint:8')])


if __name__ == '__main__':
  import sys

  resolver = Resolver()
  for fqdn in sys.argv[1:]:
    resolver.set(fqdn)
  for fqdn in sys.argv[1:]:
    print(resolver.get(fqdn))