#!/usr/bin/python
#
# (C)opyright Xelerance 2007 - 2009, Paul Wouters <paul@xelerance.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#

import base64
import codecs
import commands
import datetime
import encodings.punycode
import getopt
import os
import re
import subprocess
import sys

# File object opened for -o by main(); stays 0 (false) while output goes
# to stdout.
outputfp = 0

# NOTE(review): no code visible in this file reads this module-level value
# (getkey/getkeys use their own local); kept so the module namespace is unchanged.
errors = ""

# dnspython provides the resolver used for every DNSKEY lookup.
try:
	import dns.resolver
except ImportError:
	print("dnskey-pull requires the python-dns package from http://www.dnspython.org/")
	print("Fedora: yum install python-dns")
	print("Debian: apt-get install python-dnspython   (NOT python-dns!)")
	# A missing dependency is a failure: do not report success to the caller.
	sys.exit(1)

def usage():
	"""Print the command synopsis, examples and option summary to stdout."""
	helpLines = (
		"dnskey-pull: Pull DNSKEY records of Key Signing Keys for resolver priming",
		"usage: dnskey-pull [-a] [-t] [-o output] [-s <ns>] zone [ zone ...]",
		"       dnskey-pull [-o output] url [ url ...]",
		"",
		"examples:",
		"          dnskey-pull -a -s f.root-servers.net . arpa",
		"          dnskey-pull -a -s ns-pri.ripe.net. e164.arpa .",
		"          dnskey-pull -t xelerance.com >> /etc/named.conf",
		"          dnskey-pull xelerance.com",
		"          dnskey-pull https://registro.br/ksk/index.html",
		"          -a = use AXFR to find all NS/DNSKEY records in a zone",
		"          -t = Add the 'trusted-keys' wrapper around the keys found",
		"If no nameserver is specified, the system resolver is used for AXFR.",
		"The system resolver is always used for DNSKEY queries.",
	)
	# One print per entry, so the output is line-for-line what it always was.
	for text in helpLines:
		print(text)

# Calculate the key tag of a DNSKEY record (RFC 4034, Appendix B).
def calcKeyTag(dnskey, flags, protocol, algorithm ):
	"""Return the 16-bit key tag of a DNSKEY record.

	dnskey    -- raw (not base64 encoded) public key, as a string with
	             one character per octet
	flags     -- DNSKEY flags field (16 bit integer)
	protocol  -- DNSKEY protocol field (8 bit integer)
	algorithm -- DNSKEY algorithm field (8 bit integer)

	Returns "" when no key material is given, otherwise an integer in
	the range 0..65535.

	NOTE: RFC 4034 Appendix B.1 defines a different calculation for
	algorithm 1 (RSA/MD5); that special case is not implemented here.
	"""
	if not dnskey:
		return ""
	# The tag is a checksum over the RDATA in wire order: flags, protocol
	# and algorithm come first, followed by the public key.  Putting the
	# header behind the key only gives the same answer for keys with an
	# even number of octets.
	header = chr((flags >> 8) & 0xff) + chr(flags & 0xff) + chr(protocol & 0xff) + chr(algorithm & 0xff)
	ac = 0
	for i, c in enumerate(header + dnskey):
		# Octets at even offsets are the high byte of a 16-bit word,
		# octets at odd offsets the low byte.
		if i & 1:
			ac += ord(c)
		else:
			ac += ord(c) << 8
	# Fold the carry out of the low 16 bits back in once.
	ac += (ac >> 16) & 0xffff
	return ac & 0xffff


# Known to fail: this page does not wrap its DNSKEY record in parentheses.
#page = commands.getoutput("curl https://ens.museum/dnssec/Kmuseum.+005+39226.key")

#page = commands.getoutput("curl https://registro.br/ksk/index.html")
#page = commands.getoutput("curl http://www.iis.se/docs/ksk.txt")
#page = commands.getoutput("curl http://dnssec.nic.pr/keys/pr.key")
def _joinUntil(pages, start, closer, limit=40):
	"""Join pages[start:] up to and including the first line holding closer.

	At most limit lines are examined and the search never runs past the
	end of pages.  Returns None when no such line is found.
	"""
	for end in range(start, min(start + limit, len(pages))):
		if closer in pages[end]:
			return "".join(pages[start:end + 1])
	return None


def _dnskeyToTrustedKey(record):
	"""Turn the text of one DNSKEY record into a trusted-keys entry.

	Returns "" when the record cannot be parsed or when the lowest bit
	of its flags field is clear (a ZSK, which is ignored).
	"""
	record = re.sub(" +", " ", record)
	try:
		# A third token of "IN" means a TTL sits between owner and class.
		hasTtl = record.split()[2] == "IN"
	except IndexError:
		return ""
	expected = 7
	if hasTtl:
		expected = 8
	fields = record.split(" ", expected - 1)
	if len(fields) != expected:
		return ""
	zone = fields[0]
	(optflags, protocol, algs, keyblob) = fields[-4:]
	try:
		isKsk = int(optflags) % 2 != 0
	except ValueError:
		return ""
	if not isKsk:
		return ""
	# Drop everything from the closing bracket on, then the opening
	# bracket and the blanks that split the base64 text.
	keyblob = re.sub("\\).*$", "", keyblob).replace("(", "").replace(" ", "")
	return '"%s" %s %s %s "%s";\n' % (zone, optflags, protocol, algs, keyblob)


def grabFromURL(url):
	"""Download url with curl and extract key material from the page.

	url -- location of a page carrying either a trusted-keys block or
	       DNSKEY records written with brackets

	Returns a trusted-keys block as a string, or "" when url is empty,
	curl cannot be started, or no usable key is found.
	"""
	if not url:
		return ""
	try:
		# Pass the URL as its own argument instead of building a shell
		# command line from it; -s keeps the progress meter quiet.
		proc = subprocess.Popen(["curl", "-s", url], stdout=subprocess.PIPE)
		page = proc.communicate()[0]
	except OSError:
		# curl is not installed or could not be executed
		return ""
	pages = page.replace("\t", " ").split("\n")

	# Try and find trusted-keys first, since its easy and our destination
	# format anyway.  If that fails, hunt for DNSKEY's.
	trustedkeys = []
	dnskeys = []
	for index, line in enumerate(pages):
		if "trusted-keys {" in line:
			block = _joinUntil(pages, index, "}")
			if block is not None:
				trustedkeys.append(block)
		elif "DNSKEY" in line and "(" in line:
			# possible start of a (bracketed) dnskey record
			record = _joinUntil(pages, index, ")")
			if record is not None:
				dnskeys.append(record)
		# anything else is junk as far as we're concerned

	if trustedkeys:
		# NOTE(review): only the first trusted-keys block on the page is
		# returned, as before -- confirm that later blocks are not wanted.
		pkey = re.sub(" +", " ", trustedkeys[0])
		if "{\n" not in pkey:
			pkey = pkey.replace("{", "{\n").replace("}", "\n}")
		return pkey

	# ZSKs and unparsable records contribute nothing; the KSKs remain.
	entries = "".join([_dnskeyToTrustedKey(record) for record in dnskeys])
	if not entries:
		# Failed to find any key on supplied url
		return ""
	return 'trusted-keys {\n' + entries + '};\n'

def main(argv=None):
	"""Parse the command line, collect the requested keys and emit them.

	argv -- full argument vector including the program name; defaults
	        to sys.argv

	Every positional argument is either a URL (https://, http:// or
	ftp://), which is handed to grabFromURL(), or a zone name, which is
	handed to getkeys() when -a was given and to getkey() otherwise.
	The result goes to the file named with -o, or to stdout.

	Leaves through sys.exit() with status 0 after -v/-h, 2 on usage
	errors and 1 when the output file cannot be written or no key could
	be obtained.
	"""
	if argv is None:
		argv = sys.argv
	try:
		# getopt marks long options that take a value with "=", not ":".
		opts, args = getopt.getopt(argv[1:], "tavho:s:", ["trusted","axfr", "version","help","output=","server="])
	except getopt.error:
		sys.stderr.write("ERROR parsing options\n")
		usage()
		sys.exit(2)

	# parse options
	trusted = 0
	axfr = 0
	ns = ""
	zone = ""
	global outputfp
	for o, a in opts:
		if o in ("-v", "--version"):
			print("dnskey-pull version 1.20 ")
			print("Author:\n Paul Wouters <paul@xelerance.com>")
			print("Source : http://www.xelerance.com/software/dnssec-conf/")
			sys.exit()
		if o in ("-h", "--help"):
			usage()
			sys.exit()

		if o in ("-t","--trusted"):
			trusted = 1
		if o in ("-a","--axfr"):
			axfr = 1
		if o in ("-o","--output"):
			if not a:
				print("error: no output file specified for -o")
				usage()
				sys.exit(2)
			try:
				outputfp = open(a,"w")
				# record the command line and the time (seconds cut off)
				outputfp.write("//; %s\n//; %s\n"%(" ".join(argv), re.sub(":[^:]*$","", str(datetime.datetime.today()))))
			except IOError:
				print("error writing file %s"%a)
				sys.exit(1)

		if o in ("-s","--server"):
			if not a:
				print("error: no server specified for -s")
				usage()
				sys.exit(2)
			ns = a

	if not args:
		usage()
		sys.exit(2)

	# Both accumulators live across the whole loop so that keys from
	# every argument survive until the output stage.
	obtainedKeys = ""
	urlkeys = ""
	for zone in args:
		if not zone:
			# an empty argument names nothing to look up
			continue
		# first check for url
		if zone.startswith(("https://", "http://", "ftp://")):
			urlkeys += grabFromURL(zone) or ""
			continue
		# fix dots, eg ".arpa" and "arpa" to "arpa."
		if zone != ".":
			if zone[-1] != ".":
				zone = zone + "."
			if zone[0] == ".":
				zone = zone[1:]
		# "or" guards against a helper that hands back None
		if axfr:
			obtainedKeys += getkeys(zone,ns) or ""
		else:
			obtainedKeys += getkey(zone,ns) or ""

	# if we got nothing, return with failure
	if not obtainedKeys and not urlkeys:
		sys.exit("error: obtaining DNSKEY's failed for %s"%zone)

	chunks = []
	if obtainedKeys:
		if trusted:
			chunks.append("trusted-keys {")
		chunks.append(obtainedKeys)
		if trusted:
			chunks.append("};")
	if urlkeys:
		# keys taken from a URL already carry their trusted-keys wrapper
		chunks.append(urlkeys)

	for chunk in chunks:
		if outputfp:
			outputfp.write(chunk+"\n")
		else:
			print(chunk)

	if outputfp:
		outputfp.close()

def NSonly(x):
	"""Tell whether x contains "IN" and "NS" separated by a single tab.

	Used as the filter predicate that picks the NS lines out of the
	zone transfer listing.
	"""
	marker = "IN\tNS"
	if marker in x:
		return True
	return False

# return the key obtained
def getkey(zone,ns):
	"""Look up the Key Signing Keys published at the apex of zone.

	zone -- zone name to query
	ns   -- not used here; the system resolver is always queried

	Returns the keys as trusted-keys entries, one per line, each with a
	trailing "// key id" comment.  For a zone starting with "xn--" a
	"//;" comment line with the decoded first label comes first.
	Returns "" when the lookup fails or gives no answer.
	"""
	# we're not using axfr, so we can only attempt DNSKEY's for the APEX
	# initialise resolver object - this does not use specified NS, because
	# that one might not be a public open resolver.
	# TODO: but the user knows best, so do it anyway
	res = dns.resolver.Resolver()
	res.use_edns(1, 0, 1400)
	try:
		answers = res.query(zone,'DNSKEY')
	except Exception:
		# Best effort: any failed lookup just means "no key for this
		# zone".  Ctrl-C and sys.exit() are not swallowed.
		return ""
	if not answers:
		return ""

	keyout = ""
	if zone.lower()[0:4] == "xn--":
		# Decode the first label only; feeding the rest of the name to
		# the decoder goes wrong when a later label contains a hyphen.
		label = zone[4:].split(".", 1)[0]
		puny = "//; %s\n"% encodings.punycode.punycode_decode(label, "")
		keyout += puny.encode("utf8")
	for rdata in answers.rrset:
		# check for KSK bit
		if (int(rdata.flags)%2) !=0:
			keytag = calcKeyTag(rdata.key,rdata.flags,rdata.protocol,rdata.algorithm)
			keyout += '"%s" %s %s %s "%s"; // key id = %s\n'%(zone, rdata.flags,
				rdata.protocol, rdata.algorithm,
				base64.b64encode(rdata.key), keytag )
	return keyout

def getkeys(zone,ns):
	"""Collect the Key Signing Keys of every name with an NS record in zone.

	zone -- zone to list with "host -t axfr"
	ns   -- name server to ask for the transfer; when empty, no server
	        argument is passed to host

	Returns the keys as trusted-keys entries, one per line.  For a name
	starting with "xn--" a "//;" comment line with the decoded first
	label precedes its keys.  Returns "" when host cannot be run or no
	key is found.  The text is returned rather than written out, so the
	caller can place it inside its own output.
	"""
	# no axfr support in python-dns yet
	cmd = ["host", "-t", "axfr", zone]
	if ns:
		cmd.append(ns)
	try:
		# Argument list instead of a shell command line; stderr is
		# folded into stdout so host's diagnostics stay off the terminal.
		proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
		listing = proc.communicate()[0]
	except OSError:
		# host is not installed or could not be executed
		return ""

	# Owner names of the NS lines, in order of first appearance.
	domainList = []
	seen = set()
	for line in filter(NSonly, listing.splitlines()):
		name = line.split("\t",1)[0]
		if name not in seen:
			seen.add(name)
			domainList.append(name)

	# initialise resolver object - this does not use specified NS, because
	# that one might not be a public open resolver.
	# TODO: but the user knows best, so do it anyway
	res = dns.resolver.Resolver()
	res.use_edns(1, 0, 1400)
	found = []
	for name in domainList:
		try:
			answers = res.query(name,'DNSKEY')
		except Exception:
			# Best effort: no DNSKEY for this name, try the next one.
			continue
		if not answers:
			continue
		if name.lower()[0:4] == "xn--":
			# Decode the first label only; feeding the rest of the name
			# to the decoder goes wrong when a later label has a hyphen.
			label = name[4:].split(".", 1)[0]
			puny = "//; %s\n"% encodings.punycode.punycode_decode(label, "")
			found.append(puny.encode("utf8"))
		for rdata in answers.rrset:
			# check for KSK bit
			if (int(rdata.flags)%2) !=0:
				found.append('"%s" %s %s %s "%s";\n'%(name, rdata.flags,
					rdata.protocol, rdata.algorithm,
					base64.b64encode(rdata.key) ))
	return "".join(found)


# Script entry point: main()'s return value becomes the exit status.
if __name__ == "__main__":
	status = main()
	sys.exit(status)

