From c5becf88ce28a210e413a66e2dfc2763cadfd297 Mon Sep 17 00:00:00 2001
From: Art Lukyanchyk <artiom.lukyanchyk@hs-hannover.de>
Date: Fri, 29 Mar 2019 17:19:04 +0100
Subject: [PATCH] Pika 1.0 support. Plus better task autodiscover and other
 tweaks.

---
 pikatasks/__init__.py      | 29 ++++++++------------
 pikatasks/django_compat.py | 55 ++++++++++++++++++++------------------
 pikatasks/utils.py         |  3 +--
 pikatasks/worker.py        | 45 +++++++++++++++----------------
 4 files changed, 62 insertions(+), 70 deletions(-)

diff --git a/pikatasks/__init__.py b/pikatasks/__init__.py
index 1f9f1f2..675b278 100644
--- a/pikatasks/__init__.py
+++ b/pikatasks/__init__.py
@@ -10,6 +10,9 @@ import datetime
 import types
 
 
+assert pika.__version__ >= "1.0.0"
+
+
 class RPCError(Exception):
     """ Something RPC-related went wrong. """
     pass
@@ -30,9 +33,6 @@ class RPCMessageQueueError(RPCError):
     pass
 
 
-all_tasks = set()  # each registered task will show up here
-
-
 def serialize_to_bytes(data):
     # pika documentation is very fuzzy about it (and types in general)
     return serialization.serialize(data).encode("utf-8")
@@ -101,12 +101,7 @@ def rpc(_task_name, **kwargs):
             channel = conn.channel()
             channel.confirm_delivery()
             # start waiting for RPC response (required before sending a request with reply-to)
-            if pika.__version__ >= "1.0":
-                # multiple breaking changes in version 1.0.0b1
-                # if this is broken again, also fix the other basic_consume in this project
-                channel.basic_consume(queue="amq.rabbitmq.reply-to", on_message_callback=callback_result, auto_ack=True)
-            else:
-                channel.basic_consume(consumer_callback=callback_result, queue="amq.rabbitmq.reply-to", no_ack=True)
+            channel.basic_consume(queue="amq.rabbitmq.reply-to", on_message_callback=callback_result, auto_ack=True)
             # send a request
             channel.basic_publish(
                 exchange=settings.CLIENT_EXCHANGE_NAME,
@@ -138,14 +133,12 @@ def task(name):
     It doesn't replace the function with a wrapper. Instead, it adds additional properties to the function.
     Property .as_callback is a callable ready to be consumed by pika's functions like Channel.basic_consume
     """
-    if isinstance(name, str):
-        # used with an argument: @task("task_name")
-        func = None
+    if isinstance(name, str):  # used with an argument as @task(...)
+        func = None  # function is not yet known, need to return the decorator
         task_name = name
-    elif isinstance(name, types.FunctionType):
-        # used without an argument: @task
-        func = name
-        task_name = func.__name__
+    elif isinstance(name, types.FunctionType):  # used as @task
+        func = name  # function is passed as positional argument
+        task_name = func.__name__  # will use function name as task name
     else:
         raise TypeError("Cannot decorate this: {0}".format(repr(name)))
 
@@ -163,7 +156,7 @@ def task(name):
             channel.basic_ack(delivery_tag=method.delivery_tag)
             logger.debug("Received task {task_name}".format(**locals()))  # don't log the body, private data
             if django_compat.DJANGO:
-                django_compat.check_fix_db_connection()
+                django_compat.check_fix_db_connections()
             try:
                 task_kwargs = deserialize_from_bytes(body)
                 if not isinstance(task_kwargs, dict):
@@ -193,7 +186,7 @@ def task(name):
         func.as_callback = as_callback
         func.task_name = task_name
         func.task_queue = task_name
-        utils.all_tasks.add(func)
+        utils.known_tasks.add(func)
         return func
 
     if func:
diff --git a/pikatasks/django_compat.py b/pikatasks/django_compat.py
index a904041..4b53757 100644
--- a/pikatasks/django_compat.py
+++ b/pikatasks/django_compat.py
@@ -1,5 +1,5 @@
-import logging
 import importlib
+import itertools
 from . import utils
 from .utils import logger
 
@@ -12,6 +12,15 @@ except ImportError:
     DJANGO = None
 
 
+def requires_django(callable):
+    def wrapper(*args, **kwargs):
+        if not DJANGO:
+            raise ModuleNotFoundError("Cannot use django compat features without django itself.")
+        callable(*args, **kwargs)
+    return wrapper
+
+
+@requires_django
 def close_db_connections():
     """
     Closes all Django db connections.
@@ -28,24 +37,22 @@ def close_db_connections():
         logger.warning("Failed to close django db connections: {e.__class__.__qualname__}: {e}".format(e=e))
 
 
+@requires_django
 def check_worker_db_settings():
-    assert DJANGO
     t = int(django_conf.settings.CONN_MAX_AGE)
     if not t or t > 20 * 60:
         raise ValueError("When using django, CONN_MAX_AGE must be set to a sane value. The current value: {t} seconds.".format(t=t))
 
 
-def check_fix_db_connection():
+@requires_django
+def check_fix_db_connections():
     """
     I leave multiple options here to help solving possible future issues.
     This should fix OperationalError when nothing helps (starting with CONN_MAX_AGE).
-    Theis function has to be run *before* executing *each* task.
+    This function has to be run *before* executing *each* task.
     """
-    assert DJANGO
-
     # Option 1:
     django_db.close_old_connections()
-
     # # Option 2 (If Option 1 does not help):
     #     for name in django_db.connections:
     #         conn = django_db.connections[name]
@@ -53,36 +60,32 @@ def check_fix_db_connection():
     #             cursor = conn.cursor()  # test
     #         except django_db.OperationalError:  # probably closed
     #             conn.close()  # let django reopen it if needed
-
     # # Option 3 (If Option 2 does not help):
     #     django_db.connections.close_all()
-
     pass
 
 
+@requires_django
 def autodiscover_tasks(apps=None, modules=("tasks",)):
     """
     Imports modules with tasks from django apps.
     This function utilizes the fact that each task registers itself in utils.all_tasks
     :param apps: tuple of app names, leave None for everything in INSTALLED_APPS
-    :param modules: tuple of module names, if apps have their tasks in places other than "tasks.py"
-    :return: utils.all_tasks
+    :param modules: tuple of module names, override if apps have their tasks in places other than "tasks.py"
+    :return: known tasks ()
     """
-    assert DJANGO
     if apps is None:
         apps = django_conf.settings.INSTALLED_APPS
-    for app_name in apps:
-        for module_name in modules:
-            full_module_name = "{0}.{1}".format(app_name, module_name)
-            try:
-                importlib.import_module(full_module_name)
-                # just importing the module is perfectly enough, each task will register itself on import
-                logger.info("Autodiscover: imported \"{0}\"".format(full_module_name))
-            except ImportError as ie:
-                msg = "Autodiscover: module \"{0}\" does not exist: {1}".format(full_module_name, str(ie))
-                if 'tasks' in app_name:
-                    logger.warning(msg)
-                else:
-                    logger.debug(msg)
-    return utils.all_tasks
+    for app_name, module_name in itertools.product(apps, modules):
+        full_module_name = "{0}.{1}".format(app_name, module_name)
+        try:
+            importlib.import_module(full_module_name)
+            # just importing the module is perfectly enough, each task will register itself on import
+            logger.info("Discovered \"{0}\"".format(full_module_name))
+        except ImportError as e:
+            if e.name == full_module_name:  # tasks module does not exist, it's okay
+                logger.debug("App {0} does not have module {1}".format(app_name, module_name))
+            else:  # failed to import something nested, it's bad
+                raise e
+    return utils.known_tasks
 
diff --git a/pikatasks/utils.py b/pikatasks/utils.py
index b5565ba..9ead895 100644
--- a/pikatasks/utils.py
+++ b/pikatasks/utils.py
@@ -3,10 +3,9 @@ import logging
 import ssl
 from . import settings
 
-
 logger = logging.getLogger("pikatasks")
 
-all_tasks = set()  # each registered task will show up here
+known_tasks = set()  # each declared (and imported) task will add itself into this set
 
 
 def get_ssl_options(settings):
diff --git a/pikatasks/worker.py b/pikatasks/worker.py
index f595f1b..1ff26f6 100644
--- a/pikatasks/worker.py
+++ b/pikatasks/worker.py
@@ -11,8 +11,7 @@ from . import utils
 from . import django_compat
 
 
-MASTER_IPC_PERIOD = timedelta(seconds=0.2)  # how often master checks own signals and minion processes
-MINION_IPC_PERIOD = timedelta(seconds=0.2)  # how often minions check their signals
+IPC_PERIOD = timedelta(seconds=0.2)  # how often processes check their signals and other processes
 
 
 class _SignalHandler:
@@ -36,10 +35,10 @@ class _SignalHandler:
             signal.signal(s, signal_callback)
 
 
-def start(tasks=utils.all_tasks, number_of_processes=None):
+def start(tasks="all", number_of_processes=None):
     """
     Use this to launch a worker.
-    :param tasks: list of tasks to process
+    :param tasks: list of tasks to process (or "all" for all registered + auto-discovered)
     :param number_of_processes: number of worker processes
     :return: 
     """
@@ -60,7 +59,7 @@ def start(tasks=utils.all_tasks, number_of_processes=None):
         if django_compat.DJANGO:
             django_compat.close_db_connections()
         p = multiprocessing.Process(
-            target=_task_process,
+            target=_minion_process,
             kwargs=dict(
                 tasks=tasks,
                 parent_pid=os.getpid(),
@@ -76,10 +75,10 @@ def start(tasks=utils.all_tasks, number_of_processes=None):
         deadline_dt = datetime.now() + settings.WORKER_GRACEFUL_STOP_TIMEOUT
         while processes and datetime.now() < deadline_dt:
             for p in processes:
-                os.kill(p.pid, signal.SIGTERM)  # SIGTERM = ask nicely
-            time.sleep((MINION_IPC_PERIOD / 2).total_seconds())
+                os.kill(p.pid, signal.SIGTERM)  # SIGTERM = ask minions nicely to stop
+            time.sleep(IPC_PERIOD.total_seconds())
             remove_ended_processes(expect_exited_processes=True)
-            if datetime.now() > last_reminder_dt + timedelta(seconds=5):
+            if datetime.now() > last_reminder_dt + timedelta(seconds=5):  # log reminder every 5 seconds
                 last_reminder_dt = datetime.now()
                 logger.info("Stopping... Minions still running: {n}. Deadline in: {d}.".format(d=deadline_dt - datetime.now(), n=len(processes)))
 
@@ -113,8 +112,13 @@ def start(tasks=utils.all_tasks, number_of_processes=None):
             logger.error("Some queues are missing: {0}".format(queues_missing))
         return [t for t in tasks if t.task_queue in queues_found]
 
+    def get_all_tasks():
+        if django_compat.DJANGO:
+            django_compat.autodiscover_tasks()
+        return utils.known_tasks
+
     logger.info("Starting pikatasks worker...")
-    tasks = filter_tasks(tasks=tasks)
+    tasks = filter_tasks(tasks=get_all_tasks() if tasks == "all" else tasks)
     logger.info("Tasks: {0}".format(repr([t.task_name for t in tasks])))
     if not tasks:
         raise ValueError("Empty task list.")
@@ -124,7 +128,7 @@ def start(tasks=utils.all_tasks, number_of_processes=None):
         remove_ended_processes(expect_exited_processes=False)
         while len(processes) < (number_of_processes or settings.WORKER_TASK_PROCESSES):
             processes.append(create_minion(tasks))
-        time.sleep(MASTER_IPC_PERIOD.total_seconds())
+        time.sleep(IPC_PERIOD.total_seconds())
     # stopping
     logger.info("Stopping minions...")
     stop_minions()
@@ -136,7 +140,7 @@ def start(tasks=utils.all_tasks, number_of_processes=None):
     logger.info("Stopped pikatasks worker.")
 
 
-def _task_process(tasks, parent_pid):
+def _minion_process(tasks, parent_pid):
     """ This is a single process, that performs tasks. """
     logger = logging.getLogger("pikatasks.worker.minion.pid{0}".format(os.getpid()))
     signal_handler = _SignalHandler(logger=logger, this_process_name="minion")
@@ -144,10 +148,10 @@ def _task_process(tasks, parent_pid):
         raise RuntimeError("Got empty list of tasks")
     conn, channel = None, None
 
-    def control_beat():
+    def process_controller():
+        # this function performs IPC and decides when to finish this process
         # this function registers itself to be called again
         stop = False
-        # check whether the parent process is alive
         if os.getppid() != parent_pid:  # got adopted, new owner is probably init
             logger.error("Master (PID={0}) has disappeared :( Stopping.".format(parent_pid))
             stop = True
@@ -157,7 +161,7 @@ def _task_process(tasks, parent_pid):
         if stop:
             channel.stop_consuming()
             logger.debug("Stopping consuming messages from queues.")
-        conn.add_timeout(MINION_IPC_PERIOD.total_seconds(), control_beat)  # run this function again soon
+        conn.call_later(IPC_PERIOD.total_seconds(), process_controller)  # run this function again soon
 
     try:
         logger.debug("Opening a connection...")
@@ -168,27 +172,20 @@ def _task_process(tasks, parent_pid):
                     callback = getattr(task, "as_callback", None)
                     if not callback or not callable(callback):
                         raise ValueError("Not a valid task: {0}".format(task))
-                    if pika.__version__ >= "1.0":
-                        # multiple breaking changes in version 1.0.0b1
-                        # if this is broken again, also fix the other basic_consume in this project
-                        channel.basic_consume(queue=task.task_queue, on_message_callback=callback)
-                    else:
-                        channel.basic_consume(consumer_callback=callback, queue=task.task_queue)
+                    channel.basic_consume(queue=task.task_queue, on_message_callback=callback)
                     logger.debug("Registered task {t} on queue {q}".format(t=task.task_name, q=task.task_queue))
                 except Exception as e:
                     logger.error("Could not register task \"{t}\". {e.__class__.__qualname__}: {e}".format(t=task.task_name, e=e))
                     if isinstance(e, (AMQPChannelError, AMQPConnectionError,)):
                         raise e  # does not make sense to try registering other tasks
-            control_beat()  # initial
-            channel.start_consuming()
+            process_controller()  # initial
+            channel.start_consuming()  # until stop_consuming is called
     except Exception as e:
         logger.error("{e.__class__.__qualname__}: {e}".format(e=e))
     finally:
         if channel and channel.is_open:
-            logger.info("Closing the channel: {0}".format(channel))
             channel.close()
         if conn and conn.is_open:
-            logger.info("Closing the connection: {0}".format(conn))
             conn.close()
 
     logger.debug("Stopped.")
-- 
GitLab