From 381e906b6a462f1a0589c72d519854a6ae782ce6 Mon Sep 17 00:00:00 2001
From: "Stuart D. Gathman" <stuart@gathman.org>
Date: Thu, 1 Dec 2016 23:59:31 -0500
Subject: [PATCH] Implement setsymlist decorator and test framework

---
 Milter/__init__.py | 37 +++++++++++++++++++++++++++++++++----
 Milter/test.py     | 45 +++++++++++++++++++++++++++++++++++++++++----
 sample.py          |  8 +++++++-
 testsample.py      |  5 +++++
 4 files changed, 86 insertions(+), 9 deletions(-)

diff --git a/Milter/__init__.py b/Milter/__init__.py
index 6e9da64..85c5ab0 100755
--- a/Milter/__init__.py
+++ b/Milter/__init__.py
@@ -48,6 +48,12 @@ OPTIONAL_CALLBACKS = {
   'header':(P_NR_HDR,P_NOHDRS)
 }
 
+MACRO_CALLBACKS = {
+  'connect': M_CONNECT,
+  'hello': M_HELO, 'envfrom': M_ENVFROM, 'envrcpt': M_ENVRCPT,
+  'data': M_DATA, 'eom': M_EOM, 'eoh': M_EOH
+}
+
 ## @private
 R = re.compile(r'%+')
 
@@ -141,6 +147,7 @@ def nocallback(func):
   except KeyError:
     raise ValueError(
       '@nocallback applied to non-optional method: '+func.__name__)
+  @wraps(func)
   def wrapper(self,*args):
     if func(self,*args) != CONTINUE:
       raise RuntimeError('%s return code must be CONTINUE with @nocallback'
@@ -173,6 +180,19 @@ def noreply(func):
   wrapper.milter_protocol = nr_mask
   return wrapper
 
+## Function decorator to set macros used in a callback.
+# By default, the MTA sends all macros defined for a callback.
+# If some or all of these are unused, the bandwidth can be saved
+# by listing the ones that are used.
+# @since 1.0.2
+def symlist(func,*syms):
+  if func.__name__ not in MACRO_CALLBACKS:
+    raise ValueError('@symlist applied to non-symlist method: '+func.__name__)
+  if len(syms) > 5:
+    raise ValueError('@symlist limited to 5 macros by MTA: '+func.__name__)
+  func._symlist = syms
+  return func
+
 ## Disabled action exception.
 # set_flags() can tell the MTA that this application will not use certain
 # features (such as CHGFROM).  This can also be negotiated for each
@@ -393,6 +413,11 @@ class Base(object):
   def negotiate(self,opts):
     try:
       self._actions,p,f1,f2 = opts
+      for func,stage in MACRO_CALLBACKS.items():
+        func = getattr(self,func)
+        syms = getattr(func,'_symlist',None)
+        if syms is not None:
+          self.setsymlist(stage,syms)
       opts[1] = self._protocol = p & ~self.protocol_mask()
       opts[2] = 0
       opts[3] = 0
@@ -443,23 +468,27 @@ class Base(object):
   # set.  The protocol stages are M_CONNECT, M_HELO, M_ENVFROM, M_ENVRCPT,
   # M_DATA, M_EOM, M_EOH.
   #
-  # May only be called from negotiate callback.
+  # May only be called from negotiate callback.  Hence, this is an advanced
+  # feature.  Use the @@symlist function decorator to conviently set
+  # the macros used by a callback.
   # @since 0.9.8, previous version was misspelled!
   # @param stage the protocol stage to set to macro list for, 
   # one of the M_* constants defined in Milter
   # @param macros space separated and/or lists of strings
   def setsymlist(self,stage,*macros):
     if not self._actions & SETSYMLIST: raise DisabledAction("SETSYMLIST")
+    if len(macros) > 5:
+      raise ValueError('setsymlist limited to 5 macros by MTA')
     a = []
     for m in macros:
       try:
         m = m.encode('utf8')
       except: pass
       try:
-        m = m.split(' ')
+        m = m.split(b' ')
+        a += m
       except: pass
-      a += m
-    return self._ctx.setsymlist(stage,' '.join(a))
+    return self._ctx.setsymlist(stage,b' '.join(a))
 
   # Milter methods which can only be called from eom callback.
 
diff --git a/Milter/test.py b/Milter/test.py
index f0dc58b..c20be63 100644
--- a/Milter/test.py
+++ b/Milter/test.py
@@ -40,6 +40,9 @@ class TestBase(object):
     self._reply = None
     ## The rfc822 message object for the current email being fed to the %milter.
     self._msg = None
+    ## The protocol stage for macros returned
+    self._stage = None
+    ## The macros returned by protocol stage
     self._symlist = [ None, None, None, None, None, None, None ]
 
   def log(self,*msg):
@@ -54,8 +57,12 @@ class TestBase(object):
     self._macros[name] = val
     
   def getsymval(self,name):
-    # FIXME: track stage, and use _symlist
-    return self._macros.get(name,'')
+    stage = self._stage
+    if stage >= 0:
+      syms = self._symlist[stage]
+      if syms is not None and name not in syms:
+        return None
+    return self._macros.get(name,None)
 
   def replacebody(self,chunk):
     if self._body:
@@ -113,7 +120,10 @@ class TestBase(object):
     self._reply = (rcode,xcode) + msg
 
   def setsymlist(self,stage,macros):
-    if not self._actions & SETSYMLIST: raise DisabledAction("SETSYMLIST")
+    if not self._actions & SETSYMLIST:
+      raise DisabledAction("SETSYMLIST")
+    if self._stage != -1:
+      raise RuntimeError("setsymlist may only be called from negotiate")
     # not used yet, but just for grins we save the data
     a = []
     for m in macros:
@@ -121,9 +131,13 @@ class TestBase(object):
         m = m.encode('utf8')
       except: pass
       try:
-        m = m.split(' ')
+        m = m.split(b' ')
       except: pass
       a += m
+    if len(a) > 5:
+      raise ValueError('setsymlist limited to 5 macros by MTA')
+    if self._symlist[stage] is not None:
+      raise ValueError('setsymlist already called for stage:'+stage)
     self._symlist[stage] = set(a)
 
   ## Feed a file like object to the %milter.  Calls envfrom, envrcpt for
@@ -144,16 +158,32 @@ class TestBase(object):
     self._reply = None
     self._sender = '<%s>'%sender
     msg = mime.message_from_file(fp)
+    # envfrom
+    self._stage = Milter.M_ENVFROM
     rc = self.envfrom(self._sender)
+    self._stage = None
     if rc != Milter.CONTINUE: return rc
+    # envrcpt
     for rcpt in (rcpt,) + rcpts:
+      self._stage = Milter.M_ENVRCPT
       rc = self.envrcpt('<%s>'%rcpt)
+      self._stage = None
       if rc != Milter.CONTINUE: return rc
+    # data
+    self._stage = Milter.M_DATA
+    rc = self.data()
+    self._stage = None
+    if rc != Milter.CONTINUE: return rc
+    # header
     for h,val in msg.items():
       rc = self.header(h,val)
       if rc != Milter.CONTINUE: return rc
+    # eoh
+    self._stage = Milter.M_EOH
     rc = self.eoh()
+    self._stage = None
     if rc != Milter.CONTINUE: return rc
+    # body
     header,body = msg.as_bytes().split(b'\n\n',1)
     bfp = BytesIO(body)
     while 1:
@@ -163,7 +193,9 @@ class TestBase(object):
       if rc != Milter.CONTINUE: return rc
     self._msg = msg
     self._body = BytesIO()
+    self._stage = Milter.M_EOM
     rc = self.eom()
+    self._stage = None
     if self._bodyreplaced:
       body = self._body.getvalue()
     self._body = BytesIO()
@@ -189,12 +221,17 @@ class TestBase(object):
     self._body = None
     self._bodyreplaced = False
     opts = [ Milter.CURR_ACTS,~0,0,0 ]
+    self._stage = -1
     rc = self.negotiate(opts)
+    self._stage = Milter.M_CONNECT
     rc =  super(TestBase,self).connect(host,1,(ip,1234)) 
     if rc != Milter.CONTINUE:
+      self._stage = None
       self.close()
       return rc
+    self._stage = Milter.M_HELO
     rc = self.hello(helo)
+    self._stage = None
     if rc != Milter.CONTINUE:
       self.close()
     return rc
diff --git a/sample.py b/sample.py
index 6c48296..4829397 100644
--- a/sample.py
+++ b/sample.py
@@ -33,18 +33,24 @@ class sampleMilter(Milter.Milter):
     self.fp = None
     self.bodysize = 0
     self.id = Milter.uniqueID()
+    self.user = None
 
   # multiple messages can be received on a single connection
   # envfrom (MAIL FROM in the SMTP protocol) seems to mark the start
   # of each message.
+  @Milter.symlist('{auth_authen}')
   @Milter.noreply
   def envfrom(self,f,*str):
     "start of MAIL transaction"
-    self.log("mail from",f,str)
     self.fp = BytesIO()
     self.tempname = None
     self.mailfrom = f
     self.bodysize = 0
+    self.user = self.getsymval('{auth_authen}')
+    if self.user:
+      self.log("user",self.user,"sent mail from",f,str)
+    else:
+      self.log("mail from",f,str)
     return Milter.CONTINUE
 
   def envrcpt(self,to,*str):
diff --git a/testsample.py b/testsample.py
index 6259257..2df8c0e 100644
--- a/testsample.py
+++ b/testsample.py
@@ -13,9 +13,14 @@ class BMSMilterTestCase(unittest.TestCase):
 
   def testDefang(self,fname='virus1'):
     milter = TestMilter()
+    milter.setsymval('{auth_authen}','batman')
+    milter.setsymval('{auth_type}','batcomputer')
+    milter.setsymval('j','mailhost')
     rc = milter.connect()
     self.failUnless(rc == Milter.CONTINUE)
     rc = milter.feedMsg(fname)
+    self.failUnless(milter.user == 'batman',"getsymval failed")
+    self.failUnless(milter.auth_type != 'batcomputer',"setsymlist failed")
     self.failUnless(rc == Milter.ACCEPT)
     self.failUnless(milter._bodyreplaced,"Message body not replaced")
     fp = milter._body
-- 
GitLab