折角勉強したんで bitstring を使ってみたよ
なんかはてな良く分からないな、こまったな。
id:blooper:20110203 で DNS のこと少し勉強して、id:blooper:20110509 で bitstring のこと少し勉強したんで、struct 使わずに bitstring で書き直してみた。
_uint8() みたいなしんどい関数用意しなくてもいいのはいいんだけど、結局書き方がベタな気がして、もっとフォーマットを定義してしまった方がすっきりするのかなぁ。でも面倒くさいなぁ。
まぁ、ちーとずつ bitstring 読んでくのに bitstring は便利だわ。
# coding=utf8 import random import select import socket import time import bitstring random.seed() # Google 使っちゃうよ DNSSERVER = '8.8.8.8' PORT = 53 FORMAT_HEADER = ', '.join(( 'uint:16=ID', 'uint:1=QR', 'uint:4=Opcode', 'uint:1=AA', 'uint:1=TC', 'uint:1=RD', 'uint:1=RA', 'uint:3=Z', 'uint:4=RCODE', 'uint:16=QDCOUNT', 'uint:16=ANCOUNT', 'uint:16=NSCOUNT', 'uint:16=ARCOUNT', )) FORMAT_QUESTION = ', '.join(( 'bytes:QNAME_LEN=QNAME', 'uint:16=QTYPE', 'uint:16=QCLASS', )) QTYPES = { 'A': 1, 'NS': 2, 'MD': 3, 'MF': 4, 'CNAME': 5, 'SOA': 6, 'MB': 7, 'MG': 8, 'MR': 9, 'NULL': 10, 'WKS': 11, 'PTR': 12, 'HINFO': 13, 'MINFO': 14, 'MX': 15, 'TXT': 16, } class Resolver(object): QUEUE_LIMIT = 30 # 溜めておく queue の上限数 queue = set() # 未処理のもの入れ results = {} # 結果入れ sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) fd = sock.fileno() def set(self, fqdn): """登録して、溜まってたら処理する""" if fqdn not in self.results: qdata = self._query(fqdn) self.sock.sendto(qdata, (DNSSERVER, PORT)) self.results[fqdn] = None self.queue.add(fqdn) if len(self.queue) > self.QUEUE_LIMIT: self.sweep() def get(self, fqdn, timeout=10): """溜まってるの処理して質問に答える""" if self.queue: self.sweep(timeout) return self.results.get(fqdn) def sweep(self, timeout=10): """溜まってるの処理する""" start = time.time() while self.queue: if len(select.select([self.fd], [], [], 0.001)[0]) > 0: # 受けとれるデータがあったら処理して結果に入れる rdata = self.sock.recvfrom(8192)[0] # qr は response なら 1 qr = ord(rdata[3]) >> 7 # rcode は成功なら 0 rcode = ord(rdata[4]) & 31 # response で成功なら parse する if (qr, rcode) == (1, 0): rfqdn, response = DNSResponse(rdata).parse() self.queue.remove(rfqdn) self.results[rfqdn] = response if time.time() - start > timeout: # 時間切れは終了 break for q in self.queue: del self.results[q] self.queue.clear() def _query(self, fqdn): """fqdn の IP address を下さいという query を作る""" fqdn = fqdn.encode('utf-8') # d.hatena.ne.jp -> \x01d\x06hatena\x02ne\x02jp\x00 QNAME = b''.join([chr(len(x)) + x for x in fqdn.split(b'.')]) + b'\x00' query = bitstring.pack( FORMAT_HEADER + ', ' + FORMAT_QUESTION, ID=random.randint(0, 2 ** 16 - 1), QR=0, Opcode=0, AA=0, TC=0, RD=1, RA=0, Z=0, RCODE=0, QDCOUNT=1, ANCOUNT=0, NSCOUNT=0, ARCOUNT=0, QNAME_LEN=len(QNAME)*8, QNAME=QNAME, QTYPE=QTYPES['A'], QCLASS=1 ) return query.bytes class DNSResponse(object): def __init__(self, data): self.cbs = bitstring.ConstBitStream(bytes=data) def parse(self): """DNS response をパースして list of tuples で返す data -> [(name, type, class ttl, response_data)] """ result = [] self.cbs.pos = 12 * 8 # これ hard coding でいいの? # response には query が含まれてるのね qname = self._pick_name() qtype = self.cbs.read('uint:16') qclass = self.cbs.read('uint:16') while True: # response data を喰いつくす、全部読んだら終わり if len(self.cbs) <= self.cbs.pos: break rec_name = self._pick_name() rec_type = self.cbs.read('uint:16') rec_class = self.cbs.read('uint:16') rec_ttl = self.cbs.read('uint:32') rec_dlength = self.cbs.read('uint:16') if rec_type == 1: # A レコードだったら rec_data = self._parse_ipv4addr() elif rec_type == 5: # CNAME レコードだったら rec_data = self._pick_name() elif rec_type == 6: # SOA レコードだったら name = self._pick_name() rname = self._pick_name() serial = self.cbs.read('uint:32') refresh = self.cbs.read('uint:32') retry = self.cbs.read('uint:32') expire = self.cbs.read('uint:32') minimum = self.cbs.read('uint:32') rec_data = (name, rname, serial, refresh, retry, expire) else: # それ以外だったら parse しないで bytes として rec_data = self.cbs.read('bytes:%d' % (rec_dlength,)) result.append((rec_name, rec_type, rec_class, rec_ttl, rec_data)) return qname, result def _pick_name(self): """self.data の self.pos bytes 目から始まる名前を拾う""" name_list = [] while True: length = self.cbs.read('uint:8') if length == 0: break elif length >= 192: # length (2 octets) の頭が 11 だと compressed なので別処理 self.cbs.pos -= 8 name_pos = self.cbs.readlist('bits:2, uint:14')[1] * 8 pos, self.cbs.pos = self.cbs.pos, name_pos name = self._pick_name() self.cbs.pos = pos name_list.append(name) break else: name_list.append(self.cbs.read('bytes:%d' % (length,))) return b'.'.join(name_list) def _parse_ipv4addr(self): """self.data の self.pos bytes 目から 4 byte 読んで self.pos をずらす 返り値は IP address としての str """ return '.'.join([str(x) for x in self.cbs.readlist('4*uint:8')]) if __name__ == '__main__': import sys resolver = Resolver() for fqdn in sys.argv[1:]: resolver.set(fqdn) for fqdn in sys.argv[1:]: print(resolver.get(fqdn))