From c5becf88ce28a210e413a66e2dfc2763cadfd297 Mon Sep 17 00:00:00 2001 From: Art Lukyanchyk <artiom.lukyanchyk@hs-hannover.de> Date: Fri, 29 Mar 2019 17:19:04 +0100 Subject: [PATCH] Pika 1.0 support. Plus better task autodiscover and other tweaks. --- pikatasks/__init__.py | 29 ++++++++------------ pikatasks/django_compat.py | 55 ++++++++++++++++++++------------------ pikatasks/utils.py | 3 +-- pikatasks/worker.py | 45 +++++++++++++++---------------- 4 files changed, 62 insertions(+), 70 deletions(-) diff --git a/pikatasks/__init__.py b/pikatasks/__init__.py index 1f9f1f2..675b278 100644 --- a/pikatasks/__init__.py +++ b/pikatasks/__init__.py @@ -10,6 +10,9 @@ import datetime import types +assert pika.__version__ >= "1.0.0" + + class RPCError(Exception): """ Something RPC-related went wrong. """ pass @@ -30,9 +33,6 @@ class RPCMessageQueueError(RPCError): pass -all_tasks = set() # each registered task will show up here - - def serialize_to_bytes(data): # pika documentation is very fuzzy about it (and types in general) return serialization.serialize(data).encode("utf-8") @@ -101,12 +101,7 @@ def rpc(_task_name, **kwargs): channel = conn.channel() channel.confirm_delivery() # start waiting for RPC response (required before sending a request with reply-to) - if pika.__version__ >= "1.0": - # multiple breaking changes in version 1.0.0b1 - # if this is broken again, also fix the other basic_consume in this project - channel.basic_consume(queue="amq.rabbitmq.reply-to", on_message_callback=callback_result, auto_ack=True) - else: - channel.basic_consume(consumer_callback=callback_result, queue="amq.rabbitmq.reply-to", no_ack=True) + channel.basic_consume(queue="amq.rabbitmq.reply-to", on_message_callback=callback_result, auto_ack=True) # send a request channel.basic_publish( exchange=settings.CLIENT_EXCHANGE_NAME, @@ -138,14 +133,12 @@ def task(name): It doesn't replace the function with a wrapper. Instead, it adds additional properties to the function. Property .as_callback is a callable ready to be consumed by pika's functions like Channel.basic_consume """ - if isinstance(name, str): - # used with an argument: @task("task_name") - func = None + if isinstance(name, str): # used with an argument as @task(...) + func = None # function is not yet known, need to return the decorator task_name = name - elif isinstance(name, types.FunctionType): - # used without an argument: @task - func = name - task_name = func.__name__ + elif isinstance(name, types.FunctionType): # used as @task + func = name # function is passed as positional argument + task_name = func.__name__ # will use function name as task name else: raise TypeError("Cannot decorate this: {0}".format(repr(name))) @@ -163,7 +156,7 @@ def task(name): channel.basic_ack(delivery_tag=method.delivery_tag) logger.debug("Received task {task_name}".format(**locals())) # don't log the body, private data if django_compat.DJANGO: - django_compat.check_fix_db_connection() + django_compat.check_fix_db_connections() try: task_kwargs = deserialize_from_bytes(body) if not isinstance(task_kwargs, dict): @@ -193,7 +186,7 @@ def task(name): func.as_callback = as_callback func.task_name = task_name func.task_queue = task_name - utils.all_tasks.add(func) + utils.known_tasks.add(func) return func if func: diff --git a/pikatasks/django_compat.py b/pikatasks/django_compat.py index a904041..4b53757 100644 --- a/pikatasks/django_compat.py +++ b/pikatasks/django_compat.py @@ -1,5 +1,5 @@ -import logging import importlib +import itertools from . import utils from .utils import logger @@ -12,6 +12,15 @@ except ImportError: DJANGO = None +def requires_django(callable): + def wrapper(*args, **kwargs): + if not DJANGO: + raise ModuleNotFoundError("Cannot use django compat features without django itself.") + callable(*args, **kwargs) + return wrapper + + +@requires_django def close_db_connections(): """ Closes all Django db connections. @@ -28,24 +37,22 @@ def close_db_connections(): logger.warning("Failed to close django db connections: {e.__class__.__qualname__}: {e}".format(e=e)) +@requires_django def check_worker_db_settings(): - assert DJANGO t = int(django_conf.settings.CONN_MAX_AGE) if not t or t > 20 * 60: raise ValueError("When using django, CONN_MAX_AGE must be set to a sane value. The current value: {t} seconds.".format(t=t)) -def check_fix_db_connection(): +@requires_django +def check_fix_db_connections(): """ I leave multiple options here to help solving possible future issues. This should fix OperationalError when nothing helps (starting with CONN_MAX_AGE). - Theis function has to be run *before* executing *each* task. + This function has to be run *before* executing *each* task. """ - assert DJANGO - # Option 1: django_db.close_old_connections() - # # Option 2 (If Option 1 does not help): # for name in django_db.connections: # conn = django_db.connections[name] @@ -53,36 +60,32 @@ def check_fix_db_connection(): # cursor = conn.cursor() # test # except django_db.OperationalError: # probably closed # conn.close() # let django reopen it if needed - # # Option 3 (If Option 2 does not help): # django_db.connections.close_all() - pass +@requires_django def autodiscover_tasks(apps=None, modules=("tasks",)): """ Imports modules with tasks from django apps. This function utilizes the fact that each task registers itself in utils.all_tasks :param apps: tuple of app names, leave None for everything in INSTALLED_APPS - :param modules: tuple of module names, if apps have their tasks in places other than "tasks.py" - :return: utils.all_tasks + :param modules: tuple of module names, override if apps have their tasks in places other than "tasks.py" + :return: known tasks () """ - assert DJANGO if apps is None: apps = django_conf.settings.INSTALLED_APPS - for app_name in apps: - for module_name in modules: - full_module_name = "{0}.{1}".format(app_name, module_name) - try: - importlib.import_module(full_module_name) - # just importing the module is perfectly enough, each task will register itself on import - logger.info("Autodiscover: imported \"{0}\"".format(full_module_name)) - except ImportError as ie: - msg = "Autodiscover: module \"{0}\" does not exist: {1}".format(full_module_name, str(ie)) - if 'tasks' in app_name: - logger.warning(msg) - else: - logger.debug(msg) - return utils.all_tasks + for app_name, module_name in itertools.product(apps, modules): + full_module_name = "{0}.{1}".format(app_name, module_name) + try: + importlib.import_module(full_module_name) + # just importing the module is perfectly enough, each task will register itself on import + logger.info("Discovered \"{0}\"".format(full_module_name)) + except ImportError as e: + if e.name == full_module_name: # tasks module does not exist, it's okay + logger.debug("App {0} does not have module {1}".format(app_name, module_name)) + else: # failed to import something nested, it's bad + raise e + return utils.known_tasks diff --git a/pikatasks/utils.py b/pikatasks/utils.py index b5565ba..9ead895 100644 --- a/pikatasks/utils.py +++ b/pikatasks/utils.py @@ -3,10 +3,9 @@ import logging import ssl from . import settings - logger = logging.getLogger("pikatasks") -all_tasks = set() # each registered task will show up here +known_tasks = set() # each declared (and imported) task will add itself into this set def get_ssl_options(settings): diff --git a/pikatasks/worker.py b/pikatasks/worker.py index f595f1b..1ff26f6 100644 --- a/pikatasks/worker.py +++ b/pikatasks/worker.py @@ -11,8 +11,7 @@ from . import utils from . import django_compat -MASTER_IPC_PERIOD = timedelta(seconds=0.2) # how often master checks own signals and minion processes -MINION_IPC_PERIOD = timedelta(seconds=0.2) # how often minions check their signals +IPC_PERIOD = timedelta(seconds=0.2) # how often processes check their signals and other processes class _SignalHandler: @@ -36,10 +35,10 @@ class _SignalHandler: signal.signal(s, signal_callback) -def start(tasks=utils.all_tasks, number_of_processes=None): +def start(tasks="all", number_of_processes=None): """ Use this to launch a worker. - :param tasks: list of tasks to process + :param tasks: list of tasks to process (or "all" for all registered + auto-discovered) :param number_of_processes: number of worker processes :return: """ @@ -60,7 +59,7 @@ def start(tasks=utils.all_tasks, number_of_processes=None): if django_compat.DJANGO: django_compat.close_db_connections() p = multiprocessing.Process( - target=_task_process, + target=_minion_process, kwargs=dict( tasks=tasks, parent_pid=os.getpid(), @@ -76,10 +75,10 @@ def start(tasks=utils.all_tasks, number_of_processes=None): deadline_dt = datetime.now() + settings.WORKER_GRACEFUL_STOP_TIMEOUT while processes and datetime.now() < deadline_dt: for p in processes: - os.kill(p.pid, signal.SIGTERM) # SIGTERM = ask nicely - time.sleep((MINION_IPC_PERIOD / 2).total_seconds()) + os.kill(p.pid, signal.SIGTERM) # SIGTERM = ask minions nicely to stop + time.sleep(IPC_PERIOD.total_seconds()) remove_ended_processes(expect_exited_processes=True) - if datetime.now() > last_reminder_dt + timedelta(seconds=5): + if datetime.now() > last_reminder_dt + timedelta(seconds=5): # log reminder every 5 seconds last_reminder_dt = datetime.now() logger.info("Stopping... Minions still running: {n}. Deadline in: {d}.".format(d=deadline_dt - datetime.now(), n=len(processes))) @@ -113,8 +112,13 @@ def start(tasks=utils.all_tasks, number_of_processes=None): logger.error("Some queues are missing: {0}".format(queues_missing)) return [t for t in tasks if t.task_queue in queues_found] + def get_all_tasks(): + if django_compat.DJANGO: + django_compat.autodiscover_tasks() + return utils.known_tasks + logger.info("Starting pikatasks worker...") - tasks = filter_tasks(tasks=tasks) + tasks = filter_tasks(tasks=get_all_tasks() if tasks == "all" else tasks) logger.info("Tasks: {0}".format(repr([t.task_name for t in tasks]))) if not tasks: raise ValueError("Empty task list.") @@ -124,7 +128,7 @@ def start(tasks=utils.all_tasks, number_of_processes=None): remove_ended_processes(expect_exited_processes=False) while len(processes) < (number_of_processes or settings.WORKER_TASK_PROCESSES): processes.append(create_minion(tasks)) - time.sleep(MASTER_IPC_PERIOD.total_seconds()) + time.sleep(IPC_PERIOD.total_seconds()) # stopping logger.info("Stopping minions...") stop_minions() @@ -136,7 +140,7 @@ def start(tasks=utils.all_tasks, number_of_processes=None): logger.info("Stopped pikatasks worker.") -def _task_process(tasks, parent_pid): +def _minion_process(tasks, parent_pid): """ This is a single process, that performs tasks. """ logger = logging.getLogger("pikatasks.worker.minion.pid{0}".format(os.getpid())) signal_handler = _SignalHandler(logger=logger, this_process_name="minion") @@ -144,10 +148,10 @@ def _task_process(tasks, parent_pid): raise RuntimeError("Got empty list of tasks") conn, channel = None, None - def control_beat(): + def process_controller(): + # this function performs IPC and decides when to finish this process # this function registers itself to be called again stop = False - # check whether the parent process is alive if os.getppid() != parent_pid: # got adopted, new owner is probably init logger.error("Master (PID={0}) has disappeared :( Stopping.".format(parent_pid)) stop = True @@ -157,7 +161,7 @@ def _task_process(tasks, parent_pid): if stop: channel.stop_consuming() logger.debug("Stopping consuming messages from queues.") - conn.add_timeout(MINION_IPC_PERIOD.total_seconds(), control_beat) # run this function again soon + conn.call_later(IPC_PERIOD.total_seconds(), process_controller) # run this function again soon try: logger.debug("Opening a connection...") @@ -168,27 +172,20 @@ def _task_process(tasks, parent_pid): callback = getattr(task, "as_callback", None) if not callback or not callable(callback): raise ValueError("Not a valid task: {0}".format(task)) - if pika.__version__ >= "1.0": - # multiple breaking changes in version 1.0.0b1 - # if this is broken again, also fix the other basic_consume in this project - channel.basic_consume(queue=task.task_queue, on_message_callback=callback) - else: - channel.basic_consume(consumer_callback=callback, queue=task.task_queue) + channel.basic_consume(queue=task.task_queue, on_message_callback=callback) logger.debug("Registered task {t} on queue {q}".format(t=task.task_name, q=task.task_queue)) except Exception as e: logger.error("Could not register task \"{t}\". {e.__class__.__qualname__}: {e}".format(t=task.task_name, e=e)) if isinstance(e, (AMQPChannelError, AMQPConnectionError,)): raise e # does not make sense to try registering other tasks - control_beat() # initial - channel.start_consuming() + process_controller() # initial + channel.start_consuming() # until stop_consuming is called except Exception as e: logger.error("{e.__class__.__qualname__}: {e}".format(e=e)) finally: if channel and channel.is_open: - logger.info("Closing the channel: {0}".format(channel)) channel.close() if conn and conn.is_open: - logger.info("Closing the connection: {0}".format(conn)) conn.close() logger.debug("Stopped.") -- GitLab