From b2e0b2ebc6c6a6000b2a0f003d79a6f6d24ab8dc Mon Sep 17 00:00:00 2001
From: Stuart Gathman <stuart@gathman.org>
Date: Tue, 28 Aug 2012 06:02:14 +0000
Subject: [PATCH] Fix CNAME following bug.

---
 Milter/dns.py | 28 ++++++++++++++++++++++++----
 1 file changed, 24 insertions(+), 4 deletions(-)

diff --git a/Milter/dns.py b/Milter/dns.py
index fe9b54a..f8991c7 100644
--- a/Milter/dns.py
+++ b/Milter/dns.py
@@ -70,15 +70,23 @@ class Session(object):
     pre: qtype in ['A', 'AAAA', 'MX', 'PTR', 'TXT', 'SPF']
     post: isinstance(__return__, types.ListType)
     """
+    if name.endswith('.'): name = name[:-1]
+    if not reduce(lambda x,y:x and 0 < len(y) < 64, name.split('.'),True):
+        return []   # invalid DNS name (too long or empty)
     result = self.cache.get( (name, qtype) )
     cname = None
+    if result: return result
+    cnamek = (name,'CNAME')
+    cname = self.cache.get( cnamek )
 
-    if not result:
+    if cname:
+        cname = cname[0]
+    else:
         safe2cache = Session.SAFE2CACHE
         for k, v in DNSLookup(name, qtype):
-            if k == (name, 'CNAME'):
+            if k == cnamek:
                 cname = v
-            if (qtype,k[1]) in safe2cache:
+            if k[1] == 'CNAME' or (qtype,k[1]) in safe2cache:
                 self.cache.setdefault(k, []).append(v)
         result = self.cache.get( (name, qtype), [])
     if not result and cname:
@@ -89,10 +97,22 @@ class Session(object):
             raise DNSError('Length of CNAME chain exceeds %d' % MAX_CNAME)
         cnames[name] = cname
         if cname in cnames:
-            raise DNSError, 'CNAME loop'
+            raise DNSError('CNAME loop')
         result = self.dns(cname, qtype, cnames=cnames)
+        if result:
+            self.cache[(name,qtype)] = result
     return result
 
+  def dns_txt(self, domainname, enc='ascii'):
+    "Get a list of TXT records for a domain name."
+    if domainname:
+        try:
+            return [''.join(s.decode(enc) for s in a)
+                for a in self.dns(domainname, 'TXT')]
+        except UnicodeEncodeError:
+            raise DNSError('Non-ascii character in SPF TXT record.')
+    return []
+
 DNS.DiscoverNameServers()
 
 if __name__ == '__main__':
-- 
GitLab