Skip to content
Snippets Groups Projects
Commit 381e906b authored by Stuart D. Gathman's avatar Stuart D. Gathman
Browse files

Implement setsymlist decorator and test framework

parent 20727847
No related branches found
No related tags found
No related merge requests found
...@@ -48,6 +48,12 @@ OPTIONAL_CALLBACKS = { ...@@ -48,6 +48,12 @@ OPTIONAL_CALLBACKS = {
'header':(P_NR_HDR,P_NOHDRS) 'header':(P_NR_HDR,P_NOHDRS)
} }
MACRO_CALLBACKS = {
'connect': M_CONNECT,
'hello': M_HELO, 'envfrom': M_ENVFROM, 'envrcpt': M_ENVRCPT,
'data': M_DATA, 'eom': M_EOM, 'eoh': M_EOH
}
## @private ## @private
R = re.compile(r'%+') R = re.compile(r'%+')
...@@ -141,6 +147,7 @@ def nocallback(func): ...@@ -141,6 +147,7 @@ def nocallback(func):
except KeyError: except KeyError:
raise ValueError( raise ValueError(
'@nocallback applied to non-optional method: '+func.__name__) '@nocallback applied to non-optional method: '+func.__name__)
@wraps(func)
def wrapper(self,*args): def wrapper(self,*args):
if func(self,*args) != CONTINUE: if func(self,*args) != CONTINUE:
raise RuntimeError('%s return code must be CONTINUE with @nocallback' raise RuntimeError('%s return code must be CONTINUE with @nocallback'
...@@ -173,6 +180,19 @@ def noreply(func): ...@@ -173,6 +180,19 @@ def noreply(func):
wrapper.milter_protocol = nr_mask wrapper.milter_protocol = nr_mask
return wrapper return wrapper
## Function decorator to set macros used in a callback.
# By default, the MTA sends all macros defined for a callback.
# If some or all of these are unused, the bandwidth can be saved
# by listing the ones that are used.
# @since 1.0.2
def symlist(func,*syms):
if func.__name__ not in MACRO_CALLBACKS:
raise ValueError('@symlist applied to non-symlist method: '+func.__name__)
if len(syms) > 5:
raise ValueError('@symlist limited to 5 macros by MTA: '+func.__name__)
func._symlist = syms
return func
## Disabled action exception. ## Disabled action exception.
# set_flags() can tell the MTA that this application will not use certain # set_flags() can tell the MTA that this application will not use certain
# features (such as CHGFROM). This can also be negotiated for each # features (such as CHGFROM). This can also be negotiated for each
...@@ -393,6 +413,11 @@ class Base(object): ...@@ -393,6 +413,11 @@ class Base(object):
def negotiate(self,opts): def negotiate(self,opts):
try: try:
self._actions,p,f1,f2 = opts self._actions,p,f1,f2 = opts
for func,stage in MACRO_CALLBACKS.items():
func = getattr(self,func)
syms = getattr(func,'_symlist',None)
if syms is not None:
self.setsymlist(stage,syms)
opts[1] = self._protocol = p & ~self.protocol_mask() opts[1] = self._protocol = p & ~self.protocol_mask()
opts[2] = 0 opts[2] = 0
opts[3] = 0 opts[3] = 0
...@@ -443,23 +468,27 @@ class Base(object): ...@@ -443,23 +468,27 @@ class Base(object):
# set. The protocol stages are M_CONNECT, M_HELO, M_ENVFROM, M_ENVRCPT, # set. The protocol stages are M_CONNECT, M_HELO, M_ENVFROM, M_ENVRCPT,
# M_DATA, M_EOM, M_EOH. # M_DATA, M_EOM, M_EOH.
# #
# May only be called from negotiate callback. # May only be called from negotiate callback. Hence, this is an advanced
# feature. Use the @@symlist function decorator to conviently set
# the macros used by a callback.
# @since 0.9.8, previous version was misspelled! # @since 0.9.8, previous version was misspelled!
# @param stage the protocol stage to set to macro list for, # @param stage the protocol stage to set to macro list for,
# one of the M_* constants defined in Milter # one of the M_* constants defined in Milter
# @param macros space separated and/or lists of strings # @param macros space separated and/or lists of strings
def setsymlist(self,stage,*macros): def setsymlist(self,stage,*macros):
if not self._actions & SETSYMLIST: raise DisabledAction("SETSYMLIST") if not self._actions & SETSYMLIST: raise DisabledAction("SETSYMLIST")
if len(macros) > 5:
raise ValueError('setsymlist limited to 5 macros by MTA')
a = [] a = []
for m in macros: for m in macros:
try: try:
m = m.encode('utf8') m = m.encode('utf8')
except: pass except: pass
try: try:
m = m.split(' ') m = m.split(b' ')
except: pass
a += m a += m
return self._ctx.setsymlist(stage,' '.join(a)) except: pass
return self._ctx.setsymlist(stage,b' '.join(a))
# Milter methods which can only be called from eom callback. # Milter methods which can only be called from eom callback.
......
...@@ -40,6 +40,9 @@ class TestBase(object): ...@@ -40,6 +40,9 @@ class TestBase(object):
self._reply = None self._reply = None
## The rfc822 message object for the current email being fed to the %milter. ## The rfc822 message object for the current email being fed to the %milter.
self._msg = None self._msg = None
## The protocol stage for macros returned
self._stage = None
## The macros returned by protocol stage
self._symlist = [ None, None, None, None, None, None, None ] self._symlist = [ None, None, None, None, None, None, None ]
def log(self,*msg): def log(self,*msg):
...@@ -54,8 +57,12 @@ class TestBase(object): ...@@ -54,8 +57,12 @@ class TestBase(object):
self._macros[name] = val self._macros[name] = val
def getsymval(self,name): def getsymval(self,name):
# FIXME: track stage, and use _symlist stage = self._stage
return self._macros.get(name,'') if stage >= 0:
syms = self._symlist[stage]
if syms is not None and name not in syms:
return None
return self._macros.get(name,None)
def replacebody(self,chunk): def replacebody(self,chunk):
if self._body: if self._body:
...@@ -113,7 +120,10 @@ class TestBase(object): ...@@ -113,7 +120,10 @@ class TestBase(object):
self._reply = (rcode,xcode) + msg self._reply = (rcode,xcode) + msg
def setsymlist(self,stage,macros): def setsymlist(self,stage,macros):
if not self._actions & SETSYMLIST: raise DisabledAction("SETSYMLIST") if not self._actions & SETSYMLIST:
raise DisabledAction("SETSYMLIST")
if self._stage != -1:
raise RuntimeError("setsymlist may only be called from negotiate")
# not used yet, but just for grins we save the data # not used yet, but just for grins we save the data
a = [] a = []
for m in macros: for m in macros:
...@@ -121,9 +131,13 @@ class TestBase(object): ...@@ -121,9 +131,13 @@ class TestBase(object):
m = m.encode('utf8') m = m.encode('utf8')
except: pass except: pass
try: try:
m = m.split(' ') m = m.split(b' ')
except: pass except: pass
a += m a += m
if len(a) > 5:
raise ValueError('setsymlist limited to 5 macros by MTA')
if self._symlist[stage] is not None:
raise ValueError('setsymlist already called for stage:'+stage)
self._symlist[stage] = set(a) self._symlist[stage] = set(a)
## Feed a file like object to the %milter. Calls envfrom, envrcpt for ## Feed a file like object to the %milter. Calls envfrom, envrcpt for
...@@ -144,16 +158,32 @@ class TestBase(object): ...@@ -144,16 +158,32 @@ class TestBase(object):
self._reply = None self._reply = None
self._sender = '<%s>'%sender self._sender = '<%s>'%sender
msg = mime.message_from_file(fp) msg = mime.message_from_file(fp)
# envfrom
self._stage = Milter.M_ENVFROM
rc = self.envfrom(self._sender) rc = self.envfrom(self._sender)
self._stage = None
if rc != Milter.CONTINUE: return rc if rc != Milter.CONTINUE: return rc
# envrcpt
for rcpt in (rcpt,) + rcpts: for rcpt in (rcpt,) + rcpts:
self._stage = Milter.M_ENVRCPT
rc = self.envrcpt('<%s>'%rcpt) rc = self.envrcpt('<%s>'%rcpt)
self._stage = None
if rc != Milter.CONTINUE: return rc if rc != Milter.CONTINUE: return rc
# data
self._stage = Milter.M_DATA
rc = self.data()
self._stage = None
if rc != Milter.CONTINUE: return rc
# header
for h,val in msg.items(): for h,val in msg.items():
rc = self.header(h,val) rc = self.header(h,val)
if rc != Milter.CONTINUE: return rc if rc != Milter.CONTINUE: return rc
# eoh
self._stage = Milter.M_EOH
rc = self.eoh() rc = self.eoh()
self._stage = None
if rc != Milter.CONTINUE: return rc if rc != Milter.CONTINUE: return rc
# body
header,body = msg.as_bytes().split(b'\n\n',1) header,body = msg.as_bytes().split(b'\n\n',1)
bfp = BytesIO(body) bfp = BytesIO(body)
while 1: while 1:
...@@ -163,7 +193,9 @@ class TestBase(object): ...@@ -163,7 +193,9 @@ class TestBase(object):
if rc != Milter.CONTINUE: return rc if rc != Milter.CONTINUE: return rc
self._msg = msg self._msg = msg
self._body = BytesIO() self._body = BytesIO()
self._stage = Milter.M_EOM
rc = self.eom() rc = self.eom()
self._stage = None
if self._bodyreplaced: if self._bodyreplaced:
body = self._body.getvalue() body = self._body.getvalue()
self._body = BytesIO() self._body = BytesIO()
...@@ -189,12 +221,17 @@ class TestBase(object): ...@@ -189,12 +221,17 @@ class TestBase(object):
self._body = None self._body = None
self._bodyreplaced = False self._bodyreplaced = False
opts = [ Milter.CURR_ACTS,~0,0,0 ] opts = [ Milter.CURR_ACTS,~0,0,0 ]
self._stage = -1
rc = self.negotiate(opts) rc = self.negotiate(opts)
self._stage = Milter.M_CONNECT
rc = super(TestBase,self).connect(host,1,(ip,1234)) rc = super(TestBase,self).connect(host,1,(ip,1234))
if rc != Milter.CONTINUE: if rc != Milter.CONTINUE:
self._stage = None
self.close() self.close()
return rc return rc
self._stage = Milter.M_HELO
rc = self.hello(helo) rc = self.hello(helo)
self._stage = None
if rc != Milter.CONTINUE: if rc != Milter.CONTINUE:
self.close() self.close()
return rc return rc
...@@ -33,18 +33,24 @@ class sampleMilter(Milter.Milter): ...@@ -33,18 +33,24 @@ class sampleMilter(Milter.Milter):
self.fp = None self.fp = None
self.bodysize = 0 self.bodysize = 0
self.id = Milter.uniqueID() self.id = Milter.uniqueID()
self.user = None
# multiple messages can be received on a single connection # multiple messages can be received on a single connection
# envfrom (MAIL FROM in the SMTP protocol) seems to mark the start # envfrom (MAIL FROM in the SMTP protocol) seems to mark the start
# of each message. # of each message.
@Milter.symlist('{auth_authen}')
@Milter.noreply @Milter.noreply
def envfrom(self,f,*str): def envfrom(self,f,*str):
"start of MAIL transaction" "start of MAIL transaction"
self.log("mail from",f,str)
self.fp = BytesIO() self.fp = BytesIO()
self.tempname = None self.tempname = None
self.mailfrom = f self.mailfrom = f
self.bodysize = 0 self.bodysize = 0
self.user = self.getsymval('{auth_authen}')
if self.user:
self.log("user",self.user,"sent mail from",f,str)
else:
self.log("mail from",f,str)
return Milter.CONTINUE return Milter.CONTINUE
def envrcpt(self,to,*str): def envrcpt(self,to,*str):
......
...@@ -13,9 +13,14 @@ class BMSMilterTestCase(unittest.TestCase): ...@@ -13,9 +13,14 @@ class BMSMilterTestCase(unittest.TestCase):
def testDefang(self,fname='virus1'): def testDefang(self,fname='virus1'):
milter = TestMilter() milter = TestMilter()
milter.setsymval('{auth_authen}','batman')
milter.setsymval('{auth_type}','batcomputer')
milter.setsymval('j','mailhost')
rc = milter.connect() rc = milter.connect()
self.failUnless(rc == Milter.CONTINUE) self.failUnless(rc == Milter.CONTINUE)
rc = milter.feedMsg(fname) rc = milter.feedMsg(fname)
self.failUnless(milter.user == 'batman',"getsymval failed")
self.failUnless(milter.auth_type != 'batcomputer',"setsymlist failed")
self.failUnless(rc == Milter.ACCEPT) self.failUnless(rc == Milter.ACCEPT)
self.failUnless(milter._bodyreplaced,"Message body not replaced") self.failUnless(milter._bodyreplaced,"Message body not replaced")
fp = milter._body fp = milter._body
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment