diff --git a/README.md b/README.md
index 001350bd01229be4eea60aa00b09e9df4e491388..c9c010c9346212bc0a551263cfcb7e475b380e8c 100644
--- a/README.md
+++ b/README.md
@@ -30,18 +30,19 @@ PIKATASKS_BROKER_HOST = "localhost"
 PIKATASKS_BROKER_PORT = "5671"
 PIKATASKS_SSL_ENABLED = False
 PIKATASKS_VIRTUAL_HOST = "foo"
-PIKATASKS_USERNAME = "lancelot"
-PIKATASKS_PASSWORD = "swalloWcoc0nut"
+PIKATASKS_USERNAME = "rabbitmq_user"
+PIKATASKS_PASSWORD = "password"
 ```
 ##### Implement a task (server):
 ```python
-@pikatasks.task(name="hello")
+@pikatasks.task
 def hello(something):
     msg = "Hello, " + something + "!"
     print(msg)
     return msg
 ```
-Note: you will need a queue with exactly the same name as the task. See section: Queues and Permissions.
+ - The task name (and the queue name) will be the same as the function name. If you want to specify a custom task (and queue) name, use `@pikatasks.task("my_task_name")`.
+ - Note: you will need a queue with exactly the same name as the task. See section: Queues and Permissions.
 
 ##### Start a server:
 ```python
@@ -86,13 +87,13 @@ With AMQ, messages first arrive to `exchanges`, then broker distributes them to
 You are done after creating queues for each of your tasks. Don't need anything else for the development. Note: exchange `amq.default` will be used.
 
 ##### Client:
- * Create a new exchange for your client. Let's call it `client-out`, and its type should be `direct`. This exchange will be used for sending tasks.
+ * Create a new exchange for your client. Let's call it `client.out`, and its type should be `direct`. This exchange will be used for sending tasks.
  * Decide which tasks should the client use. Let's say these are `task1` and `task2` (you should have the corresponding queues already).
- * For each of the tasks, create a new binding for the exchange `client-out`, with `routing key == queue name == task name`
-   * e.g. `exchange = client-out`, `routing key = task1`, `queue = task1`
+ * For each of the tasks, create a new binding for the exchange `client.out`, with `routing key == queue name == task name`
+   * e.g. `exchange = client.out`, `routing key = task1`, `queue = task1`
  * User permissions:
    * Configure: empty string (no config permissions)
-   * Write: `^client-out$` (replace with the name of your exchange)
+   * Write: `^client\.out$` (replace with the name of your exchange)
    * Read: empty string (no read permissions, RPC results/replies will still work)
 
 ##### Worker:
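For orientation, a minimal client-side sketch of invoking the `hello` task from above. The `rpc(_task_name, **kwargs)` signature and the `RPCError` class both appear in `pikatasks/__init__.py` in this diff; that the call blocks and returns the task's result is an assumption here:

```python
import pikatasks

# Hypothetical client call: task name first, then the task's keyword arguments.
# Assumes the broker settings from the README section above.
try:
    greeting = pikatasks.rpc("hello", something="world")
    print(greeting)  # expected: "Hello, world!"
except pikatasks.RPCError as e:
    print("Task failed on the worker side:", e)
```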
diff --git a/pikatasks/__init__.py b/pikatasks/__init__.py
index 9484f389c2a14a62f212c90d1f322be724fcaa40..99f7153d6b83a6e0d864c046a25511dd9ad0b8a9 100644
--- a/pikatasks/__init__.py
+++ b/pikatasks/__init__.py
@@ -6,6 +6,7 @@ from . import django_compat
 from . import worker  # keep it here for pikatasks.worker.start() being possible
 from .utils import logger
 import datetime
+import types
 
 
 class RPCError(Exception):
@@ -122,23 +123,32 @@ def rpc(_task_name, **kwargs):
 
 def task(name):
     """
-    Use this to decorate your tasks.
+    Use this to decorate your tasks. Usage: @task or @task("task_name")
     It doesn't replace the function with a wrapper. Instead, it adds additional properties to the function.
     Property .as_callback is a callable ready to be consumed by pika's functions like Channel.basic_consume
-    :param name: name of the task == name of the queue
     """
-    assert isinstance(name, str)
+    if isinstance(name, str):
+        func = None
+        task_name = name
+    elif isinstance(name, types.FunctionType):
+        func = name
+        task_name = func.__name__
+    else:
+        raise AssertionError("Bad arguments for the @task decorator")
 
     def decorator(func):
         """ Creates an actual decorator. """
         def as_callback(channel, method, properties, body):
-            """ Creates a callback to be used by pika. """
-            nonlocal name, func
+            """
+            Creates a callback to be used by pika.
+            More info: http://pika.readthedocs.io/en/0.10.0/modules/channel.html#pika.channel.Channel.basic_consume
+            """
+            nonlocal task_name, func
             task_started_time = datetime.datetime.utcnow()
             func_result, func_error = None, None
             channel.basic_ack(delivery_tag=method.delivery_tag)
-            logger.debug("Received task {name}".format(**locals()))  # don't log the body, private data
+            logger.debug("Received task {task_name}".format(**locals()))  # don't log the body, private data
             if django_compat.DJANGO:
                 django_compat.check_fix_db_connection()
             try:
@@ -148,11 +158,11 @@ def task(name):
             except Exception as e:
                 ec = e.__class__.__name__
                 logger.error(traceback.format_exc())
-                logger.error("Task {name} function raised {ec}: {e}".format(**locals()))
-                func_error = "Task {name} raised {ec} (see worker log for details).".format(**locals())  # sort of anonymized
+                logger.error("Task {task_name} function raised {ec}: {e}".format(**locals()))
+                func_error = "Task {task_name} raised {ec} (see worker log for details).".format(**locals())  # sort of anonymized
             if properties.reply_to:
                 try:
-                    logger.debug("Sending the result of {name} to {properties.reply_to}.".format(**locals()))
+                    logger.debug("Sending the result of {task_name} to {properties.reply_to}.".format(**locals()))
                     reply = {"error": func_error} if func_error else {"result": func_result}
                     channel.basic_publish(
                         exchange="",  # empty string as specified in RabbitMQ documentation, section direct reply-to
@@ -162,15 +172,17 @@ def task(name):
                     logger.error("Could not reply to the {properties.reply_to}. {e.__class__.__name__}: {e}".format(**locals()))
             else:
                 if func_result:
-                    logger.warning("Task {name} returned a result but the client doesn't want to receive it.".format(**locals()))
-            logger.info("Finished task {name} in {t}.".format(name=name, t=datetime.datetime.utcnow() - task_started_time))
+                    logger.warning("Task {task_name} returned a result but the client doesn't want to receive it.".format(**locals()))
+            task_time = datetime.datetime.utcnow() - task_started_time
+            logger.info("Finished task {task_name} in {task_time}.".format(**locals()))
         func.as_callback = as_callback
-        func.task_name = name
-        func.task_queue = name
-        global all_tasks
-        all_tasks.add(func)
+        func.task_name = task_name
+        func.task_queue = task_name
+        utils.all_tasks.add(func)
         return func
-    return decorator
-
+    if func:
+        return decorator(func)
+    else:
+        return decorator
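A short usage sketch of the two decorator forms this change enables; the attribute names come straight from the new `task()` implementation above (the task names are made up):

```python
import pikatasks

@pikatasks.task  # bare form: task and queue are named after the function
def ping():
    return "pong"

@pikatasks.task("custom_name")  # explicit form: custom task (and queue) name
def some_task(x):
    return x

assert ping.task_name == ping.task_queue == "ping"
assert some_task.task_queue == "custom_name"
# Both functions are also registered in utils.all_tasks and gain .as_callback,
# which worker processes pass to pika's Channel.basic_consume.
```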
{e.__class__.__name__}: {e}".format(**locals())) else: if func_result: - logger.warning("Task {name} returned a result but the client doesn't want to receive it.".format(**locals())) - logger.info("Finished task {name} in {t}.".format(name=name, t=datetime.datetime.utcnow() - task_started_time)) + logger.warning("Task {task_name} returned a result but the client doesn't want to receive it.".format(**locals())) + task_time = datetime.datetime.utcnow() - task_started_time + logger.info("Finished task {task_name} in {task_time}.".format(**locals())) func.as_callback = as_callback - func.task_name = name - func.task_queue = name - global all_tasks - all_tasks.add(func) + func.task_name = task_name + func.task_queue = task_name + utils.all_tasks.add(func) return func - return decorator - + if func: + return decorator(func) + else: + return decorator diff --git a/pikatasks/django_compat.py b/pikatasks/django_compat.py index 824dd4faa538c97d1b0f63dfe20634d355f52113..46319dfb10840ac2e0f05198a931328c7849c25c 100644 --- a/pikatasks/django_compat.py +++ b/pikatasks/django_compat.py @@ -19,15 +19,12 @@ def close_db_connections(): - https://code.djangoproject.com/ticket/20562 - https://code.djangoproject.com/ticket/15802 """ - if DJANGO: - logger.debug("Closing django db connections.") - check_worker_db_settings() - try: - django_db.connections.close_all() - except Exception as e: - logger.warning("Failed to close django db connections: {e.__class__.__name__}: {e}".format(e=e)) - else: - logger.debug("No django, no db connections to close.") + logger.debug("Closing django db connections.") + check_worker_db_settings() + try: + django_db.connections.close_all() + except Exception as e: + logger.warning("Failed to close django db connections: {e.__class__.__name__}: {e}".format(e=e)) def check_worker_db_settings(): diff --git a/pikatasks/utils.py b/pikatasks/utils.py index 6bcdd0d200f6acc7c5620b4e1014a22182444c26..33cc9df8a494e4732b0cd6dd97388d70d062a6bf 100644 --- a/pikatasks/utils.py +++ b/pikatasks/utils.py @@ -2,12 +2,19 @@ import json import pika import logging import ssl +from datetime import datetime from . import settings logger = logging.getLogger("pika-tasks") +all_tasks = set() # each registered task will show up here + + +DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" # need it to de/serialize datetime stored in JSON + + def serialize(stuff): return json.dumps(stuff).encode("utf-8") @@ -16,6 +23,15 @@ def deserialize(binary): return json.loads(binary.decode("utf-8")) +def serialize_datetime(dt): + assert isinstance(dt, datetime) + return datetime.strftime(dt, DATETIME_FORMAT) + + +def deserialize_datetime(text): + return datetime.strptime(text, DATETIME_FORMAT) + + def get_ssl_options(settings): """ Create pika.SSLOptions based on pikatasks settings. """ context = ssl.SSLContext(settings.SSL_VERSION) diff --git a/pikatasks/worker.py b/pikatasks/worker.py index 1aca29eb9c017da25990b0012bbe85ffb45e4446..88a778d5c7c7887b4f534c4c9a79856e18e93f04 100644 --- a/pikatasks/worker.py +++ b/pikatasks/worker.py @@ -12,123 +12,150 @@ from . import utils from . import django_compat -MSG_STOP = "MSG_STOP" -MSG_CHECK_FREQ = 2 # seconds +SIGNAL_CHECK_FREQUENCY = 2 # seconds -logger = logging.getLogger("pikatasks.worker") +class _SignalHandler: + """ Instance of this class will intercept KILL_SIGNALS. Use instance.kill_is_requested for checks. 
""" + STOP_SIGNALS = [signal.SIGTERM, signal.SIGINT] -def start(tasks, number_of_processes=None): + def __init__(self, logger, this_process_name, signals=STOP_SIGNALS): + self.stop_is_requested = False # use this for checks + + def signal_callback(signal_num, *args, **kwargs): + nonlocal this_process_name + signal_name = signal.Signals(signal_num).name + pid = os.getpid() + if self.stop_is_requested: + logger.debug("Already stopping this {this_process_name} (PID={pid}) (ignoring {signal_name})".format(**locals())) + else: + logger.debug("Requested to stop {this_process_name} (PID={pid}) using {signal_name}".format(**locals())) + self.stop_is_requested = True + + for s in signals: + signal.signal(s, signal_callback) + + +def start(tasks=utils.all_tasks, number_of_processes=None): """ Use this to launch a worker. :param tasks: list of tasks to process :param number_of_processes: number of worker processes :return: """ + + logger = logging.getLogger("pikatasks.worker.master") + + def remove_ended_processes(processes, expect_exited_processes): + alive = [p for p in processes if p.is_alive()] + exited = set(processes) - set(alive) + for p in exited: + if not expect_exited_processes: + logger.error("Minion (PID={0}) disappeared unexpectedly.".format(p.pid)) + processes.remove(p) + + def create_minion(tasks): + if django_compat.DJANGO: + django_compat.close_db_connections() + p = multiprocessing.Process( + target=_task_process, + kwargs=dict( + tasks=tasks, + parent_pid=os.getpid(), + ) + ) + p.start() + logger.info("Started new minion (PID={0}).".format(p.pid)) + return p + + def stop_minion(processes): + deadline = datetime.now() + settings.WORKER_GRACEFUL_STOP_TIMEOUT + while datetime.now() < deadline: + for p in processes: + os.kill(p.pid, signal.SIGTERM) + time.sleep(SIGNAL_CHECK_FREQUENCY) + remove_ended_processes(processes, expect_exited_processes=True) + if processes: + logger.info("Stopping... Minions still running: {n}. Deadline in: {d}.".format(d=deadline - datetime.now(), n=len(processes))) + else: + break + + def force_kill_minion(processes): + for p in processes: + logger.warning("Killing minion (PID={pid})".format(pid=p.pid)) + os.kill(p.pid, signal.SIGKILL) + + def queue_exists(queue_name): + conn = None + try: + with utils.get_connection() as conn: + channel = conn.channel() + channel.queue_declare(queue=queue_name, passive=True) + exists = True + except AMQPChannelError as e: + logger.warning("Failed to {queue_name}. {e.__class__.__name__}: {e}".format(**locals())) + exists = False + finally: + if conn and conn.is_open: + conn.close() + return exists + + def filter_tasks(tasks): + queues_wanted = [t.task_queue for t in tasks] + queues_found = [q for q in queues_wanted if queue_exists(q)] + queues_missing = set(queues_wanted) - set(queues_found) + if queues_missing: + logger.error("Some queues are missing: {0}".format(queues_missing)) + return [t for t in tasks if t.task_queue in queues_found] + logger.info("Starting pikatasks worker...") + tasks = filter_tasks(tasks=tasks) + logger.info("Tasks: {0}".format(repr([t.task_name for t in tasks]))) processes = list() - control_queue = multiprocessing.Queue() - assert tasks, "Received empty task list." 
diff --git a/pikatasks/worker.py b/pikatasks/worker.py
index 1aca29eb9c017da25990b0012bbe85ffb45e4446..88a778d5c7c7887b4f534c4c9a79856e18e93f04 100644
--- a/pikatasks/worker.py
+++ b/pikatasks/worker.py
@@ -12,123 +12,150 @@ from . import utils
 from . import django_compat
 
 
-MSG_STOP = "MSG_STOP"
-MSG_CHECK_FREQ = 2  # seconds
+SIGNAL_CHECK_FREQUENCY = 2  # seconds
 
 
-logger = logging.getLogger("pikatasks.worker")
+class _SignalHandler:
+    """ An instance of this class intercepts STOP_SIGNALS. Use instance.stop_is_requested for checks. """
+    STOP_SIGNALS = [signal.SIGTERM, signal.SIGINT]
 
-def start(tasks, number_of_processes=None):
+    def __init__(self, logger, this_process_name, signals=STOP_SIGNALS):
+        self.stop_is_requested = False  # use this for checks
+
+        def signal_callback(signal_num, *args, **kwargs):
+            nonlocal this_process_name
+            signal_name = signal.Signals(signal_num).name
+            pid = os.getpid()
+            if self.stop_is_requested:
+                logger.debug("Already stopping this {this_process_name} (PID={pid}) (ignoring {signal_name})".format(**locals()))
+            else:
+                logger.debug("Requested to stop {this_process_name} (PID={pid}) using {signal_name}".format(**locals()))
+                self.stop_is_requested = True
+
+        for s in signals:
+            signal.signal(s, signal_callback)
+
+
+def start(tasks=utils.all_tasks, number_of_processes=None):
     """
     Use this to launch a worker.
     :param tasks: list of tasks to process
     :param number_of_processes: number of worker processes
     :return:
     """
+
+    logger = logging.getLogger("pikatasks.worker.master")
+
+    def remove_ended_processes(processes, expect_exited_processes):
+        alive = [p for p in processes if p.is_alive()]
+        exited = set(processes) - set(alive)
+        for p in exited:
+            if not expect_exited_processes:
+                logger.error("Minion (PID={0}) disappeared unexpectedly.".format(p.pid))
+            processes.remove(p)
+
+    def create_minion(tasks):
+        if django_compat.DJANGO:
+            django_compat.close_db_connections()
+        p = multiprocessing.Process(
+            target=_task_process,
+            kwargs=dict(
+                tasks=tasks,
+                parent_pid=os.getpid(),
+            )
+        )
+        p.start()
+        logger.info("Started new minion (PID={0}).".format(p.pid))
+        return p
+
+    def stop_minions(processes):
+        deadline = datetime.now() + settings.WORKER_GRACEFUL_STOP_TIMEOUT
+        while datetime.now() < deadline:
+            for p in processes:
+                os.kill(p.pid, signal.SIGTERM)
+            time.sleep(SIGNAL_CHECK_FREQUENCY)
+            remove_ended_processes(processes, expect_exited_processes=True)
+            if processes:
+                logger.info("Stopping... Minions still running: {n}. Deadline in: {d}.".format(d=deadline - datetime.now(), n=len(processes)))
+            else:
+                break
+
+    def force_kill_minions(processes):
+        for p in processes:
+            logger.warning("Killing minion (PID={pid})".format(pid=p.pid))
+            os.kill(p.pid, signal.SIGKILL)
+
+    def queue_exists(queue_name):
+        conn = None
+        try:
+            with utils.get_connection() as conn:
+                channel = conn.channel()
+                channel.queue_declare(queue=queue_name, passive=True)
+            exists = True
+        except AMQPChannelError as e:
+            logger.warning("Failed to find queue {queue_name}. {e.__class__.__name__}: {e}".format(**locals()))
+            exists = False
+        finally:
+            if conn and conn.is_open:
+                conn.close()
+        return exists
+
+    def filter_tasks(tasks):
+        queues_wanted = [t.task_queue for t in tasks]
+        queues_found = [q for q in queues_wanted if queue_exists(q)]
+        queues_missing = set(queues_wanted) - set(queues_found)
+        if queues_missing:
+            logger.error("Some queues are missing: {0}".format(queues_missing))
+        return [t for t in tasks if t.task_queue in queues_found]
+
     logger.info("Starting pikatasks worker...")
+    tasks = filter_tasks(tasks=tasks)
+    logger.info("Tasks: {0}".format(repr([t.task_name for t in tasks])))
     processes = list()
-    control_queue = multiprocessing.Queue()
-    assert tasks, "Received empty task list."
+    if not tasks:
+        raise ValueError("Empty task list.")
     # the main loop (exits with SIGINT) watches worker processes
-    try:
-        while True:
-            _remove_ended_processes(processes)
-            while len(processes) < (number_of_processes or settings.WORKER_TASK_PROCESSES):
-                processes.append(_create_worker_process(tasks, control_queue))
-            time.sleep(settings.WORKER_CHECK_SUBPROCESSES_PERIOD.total_seconds())
-    except KeyboardInterrupt:
-        _start_ignoring_sigint()  # in case user gets impatient and continues slamming ctrl+c
-        logger.info("Received SIGINT")
+    signal_handler = _SignalHandler(logger=logger, this_process_name="master")
+    while not signal_handler.stop_is_requested:
+        remove_ended_processes(processes, expect_exited_processes=False)
+        while len(processes) < (number_of_processes or settings.WORKER_TASK_PROCESSES):
+            processes.append(create_minion(tasks))
+        time.sleep(settings.WORKER_CHECK_SUBPROCESSES_PERIOD.total_seconds())
     # stopping
-    logger.info("Stopping worker processes...")
-    _stop_worker_processes(processes, control_queue)
+    logger.info("Stopping minions...")
+    stop_minions(processes)
     if processes:
-        logger.error("{n} worker processes failed to stop gracefully.".format(n=len(processes)))
-        _terminate_worker_processes(processes)
+        logger.error("{n} minions failed to stop gracefully.".format(n=len(processes)))
+        force_kill_minions(processes)
     else:
-        logger.info("All worker processes have stopped gracefully.")
+        logger.info("All minions have stopped gracefully.")
     logger.info("Stopped pikatasks worker.")
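How the reworked supervision loop is meant to be driven, as a sketch: `start()` now blocks until SIGTERM or SIGINT arrives (intercepted by `_SignalHandler`), then SIGTERMs the minions and escalates to SIGKILL once `WORKER_GRACEFUL_STOP_TIMEOUT` expires. The task module import is hypothetical:

```python
# run_worker.py (hypothetical entry point)
import pikatasks
import myapp.tasks  # noqa: F401  # importing registers its @task functions in utils.all_tasks

# Blocks until SIGTERM/SIGINT; with no arguments it serves
# every task registered via the @task decorator.
pikatasks.worker.start()
```

A practical consequence of this change: a plain `kill <pid>` (SIGTERM) now stops the worker as gracefully as Ctrl+C did before.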
""" - _start_ignoring_sigint() # no interruptions in the middle of the task, graceful exit is controlled by the main process - own_pid = os.getpid() - subprocess_logger = logging.getLogger("pikatasks.worker.subprocess_{0}".format(own_pid)) - log_prefix = "(pikatasks subprocess PID={0}) ".format(own_pid) + logger = logging.getLogger("pikatasks.worker.minion.pid{0}".format(os.getpid())) + signal_handler = _SignalHandler(logger=logger, this_process_name="minion") assert tasks conn, channel = None, None def control_beat(): # this function registers itself to be called again stop = False - # check whether parent process is alive + # check whether the parent process is alive if os.getppid() != parent_pid: # got adopted, new owner is probably init - logger.error("Parent process disappeared :( Stopping.") + logger.error("Master (PID={0}) has disappeared :( Stopping.".format(parent_pid)) + stop = True + if signal_handler.stop_is_requested: + logger.info("Minion (PID={0}) is requested to stop.".format(os.getpid())) stop = True - try: - # check whether graceful stop is requested - msg = control_queue.get_nowait() - if msg == MSG_STOP: - stop = True - else: - subprocess_logger.error(log_prefix + "Don't know what to do with the control message: {msg}".format(msg=msg)) - except Empty: - pass if stop: channel.stop_consuming() - subprocess_logger.debug(log_prefix + "Stopping consuming messages from queues.") - conn.add_timeout(MSG_CHECK_FREQ, control_beat) # run this function again soon + logger.debug("Stopping consuming messages from queues.") + conn.add_timeout(SIGNAL_CHECK_FREQUENCY, control_beat) # run this function again soon try: - subprocess_logger.debug(log_prefix + "Opening a connection...") + logger.debug("Opening a connection...") with utils.get_connection() as conn: channel = conn.channel() for task in tasks: @@ -143,15 +170,20 @@ def _task_process(tasks, control_queue, parent_pid): channel.basic_consume(consumer_callback=callback, queue=task.task_queue) logger.debug("Registered task {t} on queue {q}".format(t=task.task_name, q=task.task_queue)) except Exception as e: - logger.error("Could not register task {t}. {e.__class__.__name__}: {e}".format(t=task, e=e)) + logger.error("Could not register task \"{t}\". {e.__class__.__name__}: {e}".format(t=task.task_name, e=e)) + if isinstance(e, (AMQPChannelError, AMQPConnectionError,)): + raise e # does not make sense to try registering other tasks control_beat() # initial channel.start_consuming() - channel.close() except Exception as e: - subprocess_logger.error(log_prefix + "{e.__class__.__name__}: {e}".format(e=e)) + logger.error("{e.__class__.__name__}: {e}".format(e=e)) finally: + if channel and channel.is_open: + logger.info("Closing the channel: {0}".format(channel)) + channel.close() if conn and conn.is_open: + logger.info("Closing the connection: {0}".format(conn)) conn.close() - subprocess_logger.debug(log_prefix + "Stopped.") + logger.debug("Stopped.")