#!/usr/bin/python
"""
vrfydmn - a milter service for postfix
Copyright (C) 2014  R.N.S.

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.

"""
from __future__ import print_function

import sys
import os
import pwd
import grp
import signal
import Milter
import threading
import traceback
import argparse

try:
    # noinspection PyUnresolvedReferences
    import ldap
    # noinspection PyUnresolvedReferences
    import ldap.sasl
    # noinspection PyUnresolvedReferences
    from ldap.ldapobject import ReconnectLDAPObject
except ImportError:
    no_module_ldap = True
else:
    no_module_ldap = False

try:
    # noinspection PyUnresolvedReferences,PyPep8Naming
    import MySQLdb as mdb
except ImportError:
    no_module_sql = True
else:
    no_module_sql = False

try:
    # noinspection PyUnresolvedReferences
    import memcache
except ImportError:
    no_module_memcache = True
else:
    no_module_memcache = False

try:
    # noinspection PyUnresolvedReferences
    import setproctitle
    setproctitle.setproctitle("vrfydmn")
except ImportError:
    pass


from syslog import *
from syslog import syslog as syslog
from getopt import getopt
from email.header import decode_header
from email.utils import parseaddr
from threading import Thread
from Queue import Queue


NAME = "vrfydmn"

# Defaults
BINDADDR = "[::1]"
PORT = 30072
MILTERUSER = "milter"
MILTERGROUP = "milter"
LDAP_TIMEOUT = 60
MAX_SQL_CONNECT_RETRIES = 3
VERSION = "0.9.1"

__version__ = VERSION
__author__ = "Christian Roessner <c@roessner.co>"
__copyright__ = "Copyright (C) 2011  R.N.S."


# noinspection PyUnresolvedReferences
class Cfg(object):
    """Process-wide settings and shared handles used by all threads"""

    # Backend handles; they stay None until something assigns a driver
    # or client object to them
    memcached_con = None
    sql_con = None
    ldap_con = None

    # Requests for the database worker thread are passed through this queue
    workerQueue = Queue()

    # Verdict for mails that fail the check; both values may be replaced
    # after the command line has been parsed
    hold = False
    action = Milter.REJECT


# noinspection PyIncorrectDocstring,PyUnresolvedReferences
class VrfyDmnMilter(Milter.Base):
    """
    Milter that takes the email address from the Sender:-header (or, if
    there is none, from the From:-header) and looks up its domain component
    in the configured sources (LDAP, SQL and/or a Postfix domains file).
    If a match is found, mail is allowed to pass, else the mail is rejected,
    quarantined or fixed, depending on the command line options. With
    --opposite the meaning of a match is inverted.
    """

    def __init__(self):
        self.__id = Milter.uniqueID()
        self.__ipname = None
        self.__ip = None
        self.__port = None
        self.__reset_message_state()

    def __reset_message_state(self):
        """Forget everything that was collected for a previous message.

        One SMTP connection can transport several messages. Without this
        reset, a verdict or a Reply-To flag of an earlier message would leak
        into a later one that takes one of the early "skip" exits in eoh().
        """
        self.__reject = False
        self.__dryrun_reject = False
        self.__email = ""
        self.__mail_from = ""
        self.__original_from = ""
        self.__add_header = True
        self.__header = dict()

    @Milter.noreply
    def connect(self, ipname, family, hostaddr):
        """connect callback. Remembers the peer for debug output"""

        self.__ip = hostaddr[0]
        self.__ipname = ipname
        self.__port = hostaddr[1]

        if config.debug:
            print("id=%i connect from %s[%s]:%s" % (self.__id,
                                                    self.__ipname,
                                                    self.__ip,
                                                    self.__port))

        return Milter.CONTINUE

    @Milter.noreply
    def envfrom(self, mailfrom, *dummy):
        """Callback that is called when MAIL FROM: is recognized. This marks
        the start of a new message, so all per-message state is reset here
        before the envelope sender is stored.
        """

        self.__reset_message_state()
        self.__mail_from = parseaddr(mailfrom)[1]

        return Milter.CONTINUE

    @Milter.noreply
    def header(self, name, hval):
        """header callback gets called for each header. Only From, Sender
        and Reply-To are of interest.
        """
        if config.debug:
            print("%s: %s" % (name, hval))

        key = name.lower()

        if key == "from":
            self.__original_from = hval
            self.__header[key] = hval
        elif key == "sender":
            # Mailinglists...
            self.__header[key] = hval
        elif key == "reply-to":
            # Never add a second Reply-To header in eom()
            self.__add_header = False

        return Milter.CONTINUE

    def __extract_email(self):
        """Return the address found in the Sender (preferred) or From header.

        An empty string is returned, if neither header was seen or if no
        address could be extracted.
        """
        header_name = "sender" if "sender" in self.__header else "from"
        raw_value = self.__header.get(header_name)
        if raw_value is None:
            return ""

        # Extract email from most right tuple
        decoded = decode_header(raw_value)[-1][0]

        # Try to find the email address and to cut off garbage
        # NOTE: Assume the last occurrence of an "@" represents the
        # email address.
        # NOTE: RFC5322 allows a mailbox-list for the From field!
        # Currently this fact is ignored!
        email = ""
        for component in decoded.split():
            if "@" in component:
                email = component

        return parseaddr(email)[1]

    @staticmethod
    def __query_backend(backend, method_name, search_key):
        """Ask the DB worker thread and return True, if the key is accepted.

        A result of None means "backend not available" and is treated like a
        match. A driver object that does not exist (yet) is handled the same
        way instead of raising an AttributeError.
        """
        if backend is None:
            return True

        response = Queue()
        Cfg.workerQueue.put((response,
                             getattr(backend, method_name),  # func
                             (search_key,)))                 # *args
        result = response.get()

        return result is True or result is None

    @staticmethod
    def __in_postfix_domains(from_domain):
        """Return True, if from_domain is a listed domain or a sub domain.

        The comparison is done on label boundaries and is case insensitive.
        A plain substring test would accept "notexample.com" or
        "example.com.evil.org" for the key "example.com".
        """
        candidate = from_domain.lower()

        for key_domain in PfDomains.domains:
            key = key_domain.lower()
            if key.startswith("."):
                # ".example.com" style keys only cover sub domains
                if candidate.endswith(key):
                    return True
            elif candidate == key or candidate.endswith("." + key):
                return True

        return False

    def __is_known(self, from_domain, email):
        """Query LDAP, SQL and the Postfix file (in this order) and stop at
        the first source that knows the sender."""
        search_key = email if config.email else from_domain

        if config.ldap and self.__query_backend(Cfg.ldap_con,
                                                "query_ldap",
                                                search_key):
            return True

        if config.sql and self.__query_backend(Cfg.sql_con,
                                               "query_sql",
                                               search_key):
            return True

        if config.file:
            return self.__in_postfix_domains(from_domain)

        return False

    def eoh(self):
        """eoh - end of header. Gets called after all headers have been
        processed. The verdict for the message is made here.

        Returns Milter.REJECT, if the check failed and the configured action
        is "reject", else Milter.CONTINUE.
        """
        queue_id = self.getsymval('i')

        email = self.__extract_email()
        self.__email = email

        # From: <> found (or no usable header at all), skip this mail
        if email == "":
            if config.debug:
                print("id=%i %s return_value=skip"
                      % (self.__id, queue_id))
            syslog(LOG_INFO, "%s: return_value=skip" % queue_id)

            return Milter.CONTINUE

        # Cut local part from email. The domain follows the last "@", as a
        # quoted local part may contain an "@" itself
        _, at_sign, from_domain = email.rpartition("@")
        if not at_sign:
            if config.debug:
                print("id=%i %s unhandled=<%s> return_value=skip"
                      % (self.__id, queue_id, email))
            syslog(LOG_INFO,
                   "%s: unhandled=<%s> return_value=skip" % (
                    queue_id, email))

            return Milter.CONTINUE

        found = self.__is_known(from_domain, email)

        if config.fix:
            # Never reject; eom() repairs the header instead
            self.__dryrun_reject = not found
        elif config.opposite:
            self.__reject = found
        else:
            self.__reject = not found

        result = "continue"
        if self.__reject:
            if Cfg.action == Milter.CONTINUE:
                if Cfg.hold:
                    result = "quarantine"
            elif Cfg.action == Milter.REJECT:
                result = "reject"

        if config.debug:
            print("id=%i %s header_from=<%s> mail_from=<%s> return_value=%s"
                  % (self.__id,
                     queue_id,
                     email,
                     self.__mail_from,
                     result))
        syslog(LOG_INFO, "%s: header_from=<%s> mail_from=<%s> return_value=%s"
               % (queue_id, email, self.__mail_from, result))

        if self.__reject and Cfg.action == Milter.REJECT:
            self.setreply("554", xcode="5.7.0", msg="Reject Queue-ID: %s - "
                          "RFC5322 from address: <%s>"
                          % (queue_id, self.__email))

            return Milter.REJECT

        return Milter.CONTINUE

    def eom(self):
        """eom - end of message. Quarantines the mail, if that was the
        verdict. If --fix was given at the command line, we replace the
        broken From:-header with the MAIL FROM value and keep the original
        value in a Reply-To:-header, unless the mail already had one."""

        if self.__reject and Cfg.hold:
            self.quarantine("%s: header_from=<%s> mail_from=<%s>"
                            % (self.getsymval("i"),
                               self.__email,
                               self.__mail_from))

        if self.__dryrun_reject and \
                self.__email != self.__mail_from:

            self.chgheader("From", 0, "<%s>" % self.__mail_from)
            if config.debug:
                print("id=%i %s header_from=<%s> mail_from=<%s>"
                      % (self.__id,
                         self.getsymval('i'),
                         self.__email,
                         self.__mail_from))
            syslog(LOG_INFO, "%s: header_from=<%s> mail_from=<%s>"
                   % (self.getsymval('i'), self.__email, self.__mail_from))

            if self.__add_header:
                self.addheader("Reply-To", self.__original_from)
                # The decoded value is only used for logging
                decoded_from = decode_header(self.__original_from)
                new_from = " ".join([s for s, _ in decoded_from])
                if config.debug:
                    print("id=%i %s reply_to: %s"
                          % (self.__id, self.getsymval('i'), new_from))
                syslog(LOG_INFO, "%s: reply_to: %s"
                       % (self.getsymval('i'), new_from))

        return Milter.CONTINUE

    def close(self):
        """close callback"""

        if config.debug:
            print("id=%i disconnect from %s[%s]:%s" % (self.__id,
                                                       self.__ipname,
                                                       self.__ip,
                                                       self.__port))

        return Milter.CONTINUE


# noinspection PyMethodParameters,PyShadowingNames
class MetaPfDomains(type):
    """
    PfDomains is a central store for all postfix domains that the milter
    recognizes as trusted domains.

    The "domains" property lives on the meta class, so it is used on the
    class that has this meta class: reading it returns the current list of
    domains, assigning a file name to it (re-)loads the list from that file.
    """

    _domains = list()

    __lock = threading.Lock()

    def _set_postfix_domains(meta, pf_file):
        """Read pf_file and replace the stored list of domains.

        Comments and empty lines are skipped; the first whitespace separated
        word of every other line is taken as a domain. If the file cannot be
        read, the error is logged and the previous list stays in place.
        """
        generated_list = list()

        try:
            with open(pf_file) as fd:
                for raw_line in fd:
                    line = raw_line.strip()

                    # Skip comments and empty lines
                    if line == "" or line.startswith("#"):
                        continue

                    generated_list.append(line.split()[0])

        # On Python 2 open() raises IOError, which is not an OSError
        except (IOError, OSError) as e:
            # Unable to read Postfix domains!
            if config.debug:
                print('Unable to read %s: %s' % (pf_file, e), file=sys.stderr)
            syslog(LOG_ERR, 'Unable to read %s: %s' % (pf_file, e))
            return

        # Publish the complete list in one step. This also covers a file
        # that does not contain any domain (anymore).
        with MetaPfDomains.__lock:
            meta._domains = generated_list

    def _get_postfix_domains(meta):
        """Return the list of domains that was loaded last"""
        return meta._domains

    domains = property(_get_postfix_domains, _set_postfix_domains)


class PfDomains(object):
    """Class level access point for the list of trusted Postfix domains.

    The "domains" property is provided by the meta class MetaPfDomains, so
    it is read and assigned on the class itself (PfDomains.domains); no
    instance is needed. Assigning a file name loads the domains from that
    file, reading returns the list that was loaded last.
    """

    # Python 2 way of selecting the meta class
    __metaclass__ = MetaPfDomains


# noinspection PyShadowingNames
class Domains(object):
    """
    Common base class of the database drivers. The constructor parses a
    simple "key = value" configuration file and stores the result in the
    dictionary self._filecontent. Each driver itself knows how to deal with
    the values found.

    Rules for the configuration file:

    - empty lines and lines starting with "#" are ignored
    - a line is split at the first "="; the key is lower-cased and both
      key and value are stripped
    - a line without "=" is stored as key with an empty list as value
    - lines with an empty key are dropped

    If the file cannot be read, the error is logged and self._cf_read_err
    is set to True.
    """
    def __init__(self, cffile):
        self._filecontent = {}
        self._cf_read_err = False

        try:
            with open(cffile, "r") as fd:
                for raw in fd:
                    stripped = raw.strip()
                    if stripped == "" or stripped.startswith("#"):
                        continue

                    name, separator, value = raw.partition('=')
                    if separator:
                        k = name.strip().lower()
                        v = value.strip()
                    else:
                        # Key without a value
                        k = stripped.lower()
                        v = []

                    if k != "":
                        self._filecontent[k] = v

        except Exception as e:
            # Unable to read config file
            if config.debug:
                print('Unable to read %s: %s' % (cffile, e), file=sys.stderr)
            syslog(LOG_ERR, 'Unable to read %s: %s' % (cffile, e))
            self._cf_read_err = True


# noinspection PyShadowingNames
class LDAPDomains(Domains):
    """
    The LDAPDomains driver connects to the first reachable server out of a
    list of LDAP servers. It supports simple and SASL authentication, as
    well as TLS connections.

    If no server could be contacted, the driver has no connection and
    query_ldap() returns None for every request.
    """
    def __init__(self, cffile):
        """Parse the config file cffile and try to connect.

        Raises an Exception for unknown config parameters, an unsupported
        TLS reqcert option or an unsupported (or missing) SASL mechanism.
        """
        Domains.__init__(self, cffile)

        self.__con = None

        if self._cf_read_err:
            return

        self.__host = ["ldap://127.0.0.1/"]
        self.__base = ""
        self.__bindmethod = "simple"
        self.__binddn = None
        self.__bindpw = None
        self.__saslmech = None
        self.__authz_id = ""
        self.__filter = "(objectClass=*)"
        self.__result_attrs = []
        self.__scope = "sub"
        self.__usetls = "no"
        self.__cipher = "TLSv1"
        self.__reqcert = "never"
        self.__cert = None
        self.__key = None
        self.__cacert = None

        for k, v in self._filecontent.items():
            if k == "host":
                self.__host = [server.strip() for server in v.split(',')]
            elif k == "base":
                self.__base = v
            elif k == "bindmethod":
                self.__bindmethod = v
            elif k == "binddn":
                self.__binddn = v
            elif k == "bindpw":
                self.__bindpw = v
            elif k == "saslmech":
                self.__saslmech = v
            elif k == "authzid":
                self.__authz_id = v
            elif k == "filter":
                self.__filter = v
            elif k == "result_attrs":
                self.__result_attrs = [attr.strip() for attr in v.split(',')]
            elif k == "scope":
                self.__scope = v
            elif k == "usetls":
                self.__usetls = v
            elif k == "cipher":
                self.__cipher = v
            elif k == "reqcert":
                self.__reqcert = v
            elif k == "cert":
                self.__cert = v
            elif k == "key":
                self.__key = v
            elif k == "cacert":
                self.__cacert = v
            else:
                raise Exception("Unsupported parameter %s: %s" % (k, v))

        tls = False
        sasl = False

        # Do we connect with TLS?
        if self.__usetls.lower() in ("yes", "true", "1"):
            reqcert_options = {"never": ldap.OPT_X_TLS_NEVER,
                               "allow": ldap.OPT_X_TLS_ALLOW,
                               "try": ldap.OPT_X_TLS_TRY,
                               "demand": ldap.OPT_X_TLS_DEMAND}
            if self.__reqcert not in reqcert_options:
                raise Exception("Unsupported TLS reqcert Option %s" %
                                self.__reqcert)
            ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT,
                            reqcert_options[self.__reqcert])
            ldap.set_option(ldap.OPT_X_TLS_CIPHER_SUITE, self.__cipher)
            if self.__cacert is not None:
                ldap.set_option(ldap.OPT_X_TLS_CACERTFILE, self.__cacert)
            if self.__cert is not None:
                ldap.set_option(ldap.OPT_X_TLS_CERTFILE, self.__cert)
            if self.__key is not None:
                ldap.set_option(ldap.OPT_X_TLS_KEYFILE, self.__key)
            tls = True

        # Are we SASL binding to our servers?
        auth_tokens = None
        if self.__bindmethod == "sasl":
            # A missing saslmech ends up in the "unsupported" branch
            mech = (self.__saslmech or "").lower()
            if mech == "digest-md5":
                auth_tokens = ldap.sasl.digest_md5(self.__binddn,
                                                   self.__bindpw)
            elif mech == "cram-md5":
                auth_tokens = ldap.sasl.cram_md5(self.__binddn, self.__bindpw)
            elif mech == "external":
                auth_tokens = ldap.sasl.external(self.__authz_id)
            elif mech == "gssapi":
                auth_tokens = ldap.sasl.gssapi(self.__authz_id)
            else:
                raise Exception("Unsupported SASL mech %s" % self.__saslmech)
            sasl = True

        # Use the first server that accepts TLS (if requested) and the bind.
        # A server that failed must not be kept as connection, so the
        # candidate is only stored after everything succeeded.
        con = None
        for server in self.__host:
            try:
                candidate = ReconnectLDAPObject(server, retry_max=1000000)
                if tls:
                    candidate.start_tls_s()
                if sasl:
                    candidate.sasl_interactive_bind_s("", auth_tokens)
                else:
                    candidate.simple_bind_s(self.__binddn, self.__bindpw)
            except Exception as e:
                if config.debug:
                    traceback.print_exc()
                syslog(LOG_ERR,
                       "LDAP connect to %s failed: %s" % (server, e))
                continue
            con = candidate
            break

        self.__con = con

    # This method is synchronized with Queue()
    def query_ldap(self, from_domain):
        """Search the directory for from_domain (a domain or an email).

        Returns True, if the search returned at least one attribute value,
        False, if not and None, if there is no LDAP connection. If memcached
        is in use, the answer is taken from and stored in the cache.

        Raises an Exception for an unsupported LDAP scope; errors of the
        LDAP library are passed to the caller.
        """
        if self.__con is None:
            return None

        if config.memcached:
            result = Cfg.memcached_con.get(from_domain)
            if result is not None:
                result = bool(result)
                if result:
                    text = "continue"
                else:
                    text = "reject"
                if config.debug:
                    print("cached_result: %s=%s" % (from_domain, text))
                syslog(LOG_INFO, "cached_result: %s=%s" % (from_domain, text))
                return result

        if self.__scope in ("sub", "subtree"):
            scope = ldap.SCOPE_SUBTREE
        elif self.__scope in ("one", "onelevel"):
            scope = ldap.SCOPE_ONELEVEL
        elif self.__scope in ("base", "exact"):
            scope = ldap.SCOPE_BASE
        else:
            raise Exception("Unsupported LDAP scope %s" % self.__scope)

        # The search key comes from a mail header, so characters with a
        # special meaning in LDAP filters must be escaped
        from ldap.filter import escape_filter_chars

        if "%s" in self.__filter:
            ldap_filter = self.__filter.replace(
                "%s", escape_filter_chars(from_domain))
        else:
            # A filter without placeholder is used as it is
            ldap_filter = self.__filter

        if config.debug:
            print("====> LDAP-filter: %s" % ldap_filter)

        result = self.__con.search_st(self.__base,
                                      scope,
                                      ldap_filter,
                                      self.__result_attrs,
                                      timeout=LDAP_TIMEOUT)

        # result - list of tuples (dn, attributes)
        found = False
        for row in result:
            attrs = row[1]
            # Skip rows that do not carry an attribute dictionary, for
            # example search references
            if not isinstance(attrs, dict):
                continue
            for values in attrs.values():
                for res_attrs in values:
                    found = True
                    if config.debug:
                        print("----> res_attrs = %s" % res_attrs)

        if found:
            if config.memcached:
                Cfg.memcached_con.set(from_domain, True, time=3600)
            return True
        else:
            if config.memcached:
                Cfg.memcached_con.set(from_domain, False, time=60)
            return False


# noinspection PyShadowingNames
class SQLDomains(Domains):
    """SQL class. Currently only MySQL is supported.

    Recognized config parameters: host, port, dbuser, dbpass, dbname and
    query. The query must contain one "%s" placeholder, which is replaced
    with the (escaped) search key, for example:

        query = SELECT domain FROM domains WHERE domain='%s'
    """

    # MySQL client error codes that indicate a lost connection:
    # 2006 = server has gone away, 2013 = lost connection during query
    _LOST_CONNECTION_CODES = (2006, 2013)

    def __init__(self, cffile):
        """Parse the config file cffile and open the first connection"""
        Domains.__init__(self, cffile)

        self.__con = None

        if self._cf_read_err:
            return

        self.__host = "localhost"
        self.__port = None
        self.__dbname = None
        self.__dbuser = None
        self.__dbpass = None
        self.__query = None

        for k, v in self._filecontent.items():
            if k == "host":
                self.__host = v
            elif k == "port":
                self.__port = int(v)
            elif k == "dbuser":
                self.__dbuser = v
            elif k == "dbpass":
                self.__dbpass = v
            elif k == "dbname":
                self.__dbname = v
            elif k == "query":
                self.__query = v

        # Initially connect, retry on errors; see below
        self.connect()

    def connect(self):
        """(Re-)connect to the SQL server. On failure the error is logged
        and the driver is left without a connection."""
        # Only pass a port, if one was configured, so the default of the
        # client library stays in effect otherwise
        options = {}
        if self.__port is not None:
            options["port"] = self.__port

        con = None
        try:
            con = mdb.connect(self.__host,
                              self.__dbuser,
                              self.__dbpass,
                              self.__dbname,
                              **options)
        except Exception as e:
            if config.debug:
                traceback.print_exc()
            syslog(LOG_ERR, "SQL connect failed: %s" % str(e))

        self.__con = con

    def query_sql(self, from_domain):
        """Run the configured query for from_domain (a domain or an email).

        Returns True, if the query returned at least one row, False, if not
        and None, if there is no connection, no query or an SQL error. If
        memcached is in use, the answer is taken from and stored in the
        cache.
        """
        if self.__con is None:
            return None
        if self.__query is None:
            return None

        if config.memcached:
            result = Cfg.memcached_con.get(from_domain)
            if result is not None:
                result = bool(result)
                if result:
                    text = "continue"
                else:
                    text = "reject"
                if config.debug:
                    print("cached_result: %s=%s" % (from_domain, text))
                syslog(LOG_INFO, "cached_result: %s=%s" % (from_domain, text))
                return result

        cur = None
        for retries in xrange(MAX_SQL_CONNECT_RETRIES):
            try:
                cur = self.__con.cursor()
                # The search key comes from a mail header; escape it before
                # it is put into the statement
                cur.execute(self.__query
                            % self.__con.escape_string(from_domain))
                break
            except Exception as e:
                # Not every exception carries an error code and a message
                code = e.args[0] if e.args else None
                if code in self._LOST_CONNECTION_CODES:
                    # Lost connection, try reconnect
                    msg = e.args[1] if len(e.args) > 1 else ""
                    syslog(LOG_ERR,
                           "SQL connection lost ({}/{}): code={} msg={}".format(
                            retries+1, MAX_SQL_CONNECT_RETRIES, code, msg))
                    self.connect()
                else:
                    syslog(LOG_ERR, "SQL error: %s" % str(e))
                    return None
        else:
            # All retries used up
            return None

        domains = cur.fetchall()
        cur.close()
        if config.debug:
            print("SQL result for %s: %s" % (from_domain, str(domains)))

        # Domain was found on SQL server, count > 0
        if len(domains) > 0:
            if config.memcached:
                Cfg.memcached_con.set(from_domain, True, time=3600)
            return True
        # ... not found
        else:
            if config.memcached:
                Cfg.memcached_con.set(from_domain, False, time=3600)
            return False


# noinspection PyUnresolvedReferences
def runner():
    """Registers the milter class and enters the milter main loop"""

    # Every new connection is handled by an instance of this class
    Milter.factory = VrfyDmnMilter

    # Actions the milter wants to use: change and add headers, quarantine
    Milter.set_flags(Milter.CHGHDRS + Milter.ADDHDRS + Milter.QUARANTINE)

    Milter.runmilter(NAME, config.socket, timeout=300)


# noinspection PyShadowingNames
def db_runner():
    """Implements a bi-directional queue that allows all milter threads to
    communicate with one single thread that does all DB operations.

    Each request on Cfg.workerQueue is a tuple (response queue, function,
    arguments). The function is called with the arguments and its result is
    put on the response queue. If the function fails, None is put on the
    response queue, so the waiting milter thread never blocks forever.

    The function returns at once, if no database driver could be set up.
    """
    if config.memcached:
        Cfg.memcached_con = memcache.Client([config.memcached], debug=0)

    # Set up each driver on its own. A broken config file for one driver
    # must not stop the worker for the other one.
    if config.ldap:
        try:
            Cfg.ldap_con = LDAPDomains(config.ldap)
        except Exception as e:
            if config.debug:
                traceback.print_exc()
            syslog(LOG_ERR, "Unable to set up LDAP driver: %s" % e)

    if config.sql:
        try:
            Cfg.sql_con = SQLDomains(config.sql)
        except Exception as e:
            if config.debug:
                traceback.print_exc()
            syslog(LOG_ERR, "Unable to set up SQL driver: %s" % e)

    if Cfg.ldap_con is None and Cfg.sql_con is None:
        return

    while True:
        req = Cfg.workerQueue.get()
        if not req:
            # Nothing more to process; skip
            continue
        # Queue, func, *args
        response, func, args = req
        try:
            result = func(*args)
        except Exception as e:
            # stderr is only usable in debug (foreground) mode, so the
            # error always goes to syslog
            if config.debug:
                traceback.print_exc()
            syslog(LOG_ERR, "Database query failed: %s" % e)
            result = None
        response.put(result)


# noinspection PyProtectedMember,PyUnresolvedReferences
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Command line
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser(epilog="vrfydmn - verify domain milter")

    parser.add_argument("--socket", "-s",
                        type=str,
                        default="inet6:{0}@{1}".format(PORT, BINDADDR),
                        help="IPv4, IPv6 or unix socket (default: %(default)s)")
    parser.add_argument("--syslog_name", "-n",
                        type=str,
                        default=NAME,
                        help="Syslog name (default: %(default)s)")
    parser.add_argument("--user", "-u",
                        type=str,
                        default=MILTERUSER,
                        help="Run milter as this user (default: %(default)s)")
    parser.add_argument("--group", "-g",
                        type=str,
                        default=MILTERGROUP,
                        help="Run milter with this group "
                             "(default: %(default)s)")
    parser.add_argument("--pid", "-p",
                        type=str,
                        default=None,
                        help="Path for an optional PID file")
    parser.add_argument("--debug", "-d",
                        default=False,
                        action="store_true",
                        help="Run in foreground with debugging turned on")
    parser.add_argument("--file", "-f",
                        type=str,
                        default=None,
                        help="Postfix domains map file")
    parser.add_argument("--ldap", "-l",
                        type=str,
                        default=None,
                        help="Config file for a LDAP connection")
    parser.add_argument("--sql", "-S",
                        type=str,
                        default=None,
                        help="Config file for a SQL connection (Currently "
                             "MySQL only)")
    parser.add_argument("--memcached", "-m",
                        type=str,
                        default=None,
                        help="Memcache socket")
    parser.add_argument("--fix", "-F",
                        default=False,
                        action="store_true",
                        help="Replace broken From:-header with envelope sender")
    parser.add_argument("--opposite", "-O",
                        default=False,
                        action="store_true",
                        help="Reject mail, if a sender uses our domain")
    parser.add_argument("--action", "-a",
                        default="reject",
                        choices=["accept", "reject", "quarantine"],
                        help="If test fails: accept, reject or quarantine. "
                             "The --fix option implies 'accept' "
                             "(default: %(default)s)")
    parser.add_argument("--email", "-e",
                        default=False,
                        action="store_true",
                        help="Use email as search key, not only the domain")

    # "config" is a module level name; the classes and functions of this
    # file read it
    config = parser.parse_args()

    # ------------------------------------------------------------------
    # Check options and the availability of optional modules
    # ------------------------------------------------------------------
    if config.file:
        if not os.path.exists(config.file):
            print('No such file: %s' % config.file, file=sys.stderr)
            sys.exit(os.EX_USAGE)
        # Read list of domains
        PfDomains.domains = config.file

    if config.ldap:
        if no_module_ldap:
            config.ldap = None
            print("Missing python module ldap!", file=sys.stderr)

    if config.sql:
        if no_module_sql:
            config.sql = None
            print("Missing python module for SQL!", file=sys.stderr)

    if not (config.file or config.ldap or config.sql):
        print("You must specify at least one of --file, --ldap or --sql",
              file=sys.stderr)
        sys.exit(1)

    if config.memcached:
        if no_module_memcache:
            config.memcached = None
            print("Missing python module memcache!", file=sys.stderr)

    if config.action:
        if config.action == "accept":
            Cfg.action = Milter.CONTINUE
        elif config.action == "reject":
            Cfg.action = Milter.REJECT
        elif config.action == "quarantine":
            Cfg.action = Milter.CONTINUE
            Cfg.hold = True

    if config.fix and config.opposite:
        print("Do not use --fix and --opposite together", file=sys.stderr)
        sys.exit(os.EX_USAGE)

    if config.fix:
        # --fix implies "accept"
        Cfg.action = Milter.CONTINUE
        Cfg.hold = False

    # We want unbuffered stdout and stderr
    unbuf_stdout = os.fdopen(sys.stdout.fileno(), 'w', 0)
    unbuf_stderr = os.fdopen(sys.stderr.fileno(), 'w', 0)
    sys.stdout = unbuf_stdout
    sys.stderr = unbuf_stderr

    openlog(config.syslog_name, LOG_PID, LOG_MAIL)

    # ------------------------------------------------------------------
    # Drop privileges
    # ------------------------------------------------------------------
    try:
        uid = pwd.getpwnam(config.user)[2]
        gid = grp.getgrnam(config.group)[2]
    except KeyError as e:
        print("User or group not known: {0}".format(
            e.args[0] if e.args else e), file=sys.stderr)
        sys.exit(1)

    try:
        # Needs Python >=2.7
        os.initgroups(config.user, gid)
    except (AttributeError, OSError):
        # Best effort: not available on old Python versions and not
        # permitted, if we are not started as root
        pass

    try:
        os.setgid(gid)
    except OSError as e:
        print('Could not set effective group id: %s' % e, file=sys.stderr)
        sys.exit(1)
    try:
        os.setuid(uid)
    except OSError as e:
        print('Could not set effective user id: %s' % e, file=sys.stderr)
        sys.exit(1)

    # ------------------------------------------------------------------
    # Daemonize (double fork), unless we run in debug mode
    # ------------------------------------------------------------------
    if config.debug:
        print("Switched user to %s, group to %s" % (uid, gid))
        print("Staying in foreground...")
    else:
        try:
            pid = os.fork()
        except OSError as e:
            print("First fork failed: (%d) %s" % (e.errno, e.strerror),
                  file=sys.stderr)
            sys.exit(1)
        if pid == 0:
            os.setsid()
            try:
                pid = os.fork()
            except OSError as e:
                print("Second fork failed: (%d) %s" % (e.errno, e.strerror),
                      file=sys.stderr)
                sys.exit(1)
            if pid == 0:
                os.chdir("/")
                os.umask(0)
            else:
                # noinspection PyProtectedMember
                os._exit(0)
        else:
            # noinspection PyProtectedMember
            os._exit(0)

        # In daemon mode, we redirect stdin, stdout and stderr to /dev/null.
        # The file descriptors 0-2 are redirected as well, and sys.std*
        # must stay file objects (not descriptor numbers), as for example
        # the traceback module writes to sys.stderr.
        devnull_in = open(os.devnull, "r")
        devnull_out = open(os.devnull, "w")
        os.dup2(devnull_in.fileno(), 0)
        os.dup2(devnull_out.fileno(), 1)
        os.dup2(devnull_out.fileno(), 2)
        sys.stdin = devnull_in
        sys.stdout = devnull_out
        sys.stderr = devnull_out

    # ------------------------------------------------------------------
    # PID file
    # ------------------------------------------------------------------
    if config.pid:
        try:
            with open(config.pid, "w") as fd:
                fd.write(str(os.getpid()))
        except (IOError, OSError) as e:
            if config.debug:
                print("Cannot create PID file: (%d) %s"
                      % (e.errno, e.strerror), file=sys.stderr)
            syslog(LOG_ERR, "Cannot create PID file %s: %s" % (config.pid, e))

    # ------------------------------------------------------------------
    # Signal handling
    # ------------------------------------------------------------------
    # Set by the termination handler; the main thread waits for it below
    shutdown_requested = threading.Event()

    def finish(signum, frame):
        """SIGINT, SIGQUIT, SIGTERM: log and ask the main thread to exit"""
        _ = frame
        syslog(LOG_NOTICE,
               "%s-%s milter shutdown. Caught signal %d"
               % (NAME, VERSION, signum))
        shutdown_requested.set()

    def reload_postfix_domains(signum, frame):
        """SIGHUP: read the Postfix domains file again"""
        _ = frame
        if config.debug:
            print("%s-%s milter reload Postfix domains. Caught signal %d"
                  % (NAME, VERSION, signum))
        syslog(LOG_NOTICE,
               "%s-%s milter reload Postfix domains. Caught signal %d"
               % (NAME, VERSION, signum))

        if config.file:
            PfDomains.domains = config.file

    def print_postfix_domains(signum, frame):
        """SIGUSR1: write the current list of Postfix domains to syslog"""
        _ = signum, frame
        if config.file:
            # noinspection PyTypeChecker
            all_domains = ", ".join(PfDomains.domains)

            # max syslog line length
            offset = 1536
            str_sgmts = []
            max_len = len(all_domains)
            d = "..."

            # Split the list into chunks. A chunk gets a leading "...", if
            # it continues a previous chunk and a trailing "...", if more
            # chunks follow.
            for i in xrange(0, max_len, offset):
                prefix = d if i > 0 else ""
                suffix = d if i + offset < max_len else ""
                str_sgmts.append(
                    prefix + all_domains[i:i+offset].strip() + suffix)

            for part in str_sgmts:
                if config.debug:
                    print("%s-%s milter Postfix domains: [%s]"
                          % (NAME, VERSION, part))
                syslog(LOG_NOTICE,
                       "%s-%s milter Postfix domains: [%s]"
                       % (NAME, VERSION, part))

    signal.signal(signal.SIGINT, finish)
    signal.signal(signal.SIGQUIT, finish)
    signal.signal(signal.SIGTERM, finish)

    signal.signal(signal.SIGHUP, reload_postfix_domains)
    signal.signal(signal.SIGUSR1, print_postfix_domains)
    signal.siginterrupt(signal.SIGHUP, False)
    signal.siginterrupt(signal.SIGUSR1, False)

    # ------------------------------------------------------------------
    # Start the threads and wait
    # ------------------------------------------------------------------
    syslog(LOG_NOTICE, "%s-%s milter startup" % (NAME, VERSION))

    milter_t = Thread(target=runner)
    milter_t.daemon = True
    milter_t.start()

    # Worker thread for all kinds of databases
    db_runner_t = Thread(target=db_runner)
    db_runner_t.daemon = True
    db_runner_t.start()

    # Waiting for SIGNAL to terminate process. pause() returns after every
    # handled signal, so go back to sleep, unless a shutdown was requested.
    while not shutdown_requested.is_set():
        signal.pause()

    try:
        if config.pid and os.path.exists(config.pid):
            os.unlink(config.pid)
    # os.unlink() raises OSError
    except (IOError, OSError) as e:
        if config.debug:
            print("Cannot remove PID file: (%d) %s" % (e.errno, e.strerror),
                  file=sys.stderr)
        syslog(LOG_ERR, "Cannot remove PID file %s: %s" % (config.pid, e))
        sys.exit(1)

    sys.exit(0)

# vim: expandtab ts=4 sw=4
