Commit c5becf88 authored by Art

Pika 1.0 support. Plus better task autodiscover and other tweaks.

parent 2d658fe2
@@ -10,6 +10,9 @@ import datetime
import types
+
+assert pika.__version__ >= "1.0.0"
+
class RPCError(Exception):
    """ Something RPC-related went wrong. """
    pass
@@ -30,9 +33,6 @@ class RPCMessageQueueError(RPCError):
    pass

-all_tasks = set()  # each registered task will show up here
def serialize_to_bytes(data):
    # pika documentation is very fuzzy about it (and types in general)
    return serialization.serialize(data).encode("utf-8")
@@ -101,12 +101,7 @@ def rpc(_task_name, **kwargs):
    channel = conn.channel()
    channel.confirm_delivery()
    # start waiting for RPC response (required before sending a request with reply-to)
-    if pika.__version__ >= "1.0":
-        # multiple breaking changes in version 1.0.0b1
-        # if this is broken again, also fix the other basic_consume in this project
-        channel.basic_consume(queue="amq.rabbitmq.reply-to", on_message_callback=callback_result, auto_ack=True)
-    else:
-        channel.basic_consume(consumer_callback=callback_result, queue="amq.rabbitmq.reply-to", no_ack=True)
+    channel.basic_consume(queue="amq.rabbitmq.reply-to", on_message_callback=callback_result, auto_ack=True)
    # send a request
    channel.basic_publish(
        exchange=settings.CLIENT_EXCHANGE_NAME,
@@ -138,14 +133,12 @@ def task(name):
    It doesn't replace the function with a wrapper. Instead, it adds additional properties to the function.
    Property .as_callback is a callable ready to be consumed by pika's functions like Channel.basic_consume
    """
-    if isinstance(name, str):
-        # used with an argument: @task("task_name")
-        func = None
+    if isinstance(name, str):  # used with an argument as @task(...)
+        func = None  # function is not yet known, need to return the decorator
        task_name = name
-    elif isinstance(name, types.FunctionType):
-        # used without an argument: @task
-        func = name
-        task_name = func.__name__
+    elif isinstance(name, types.FunctionType):  # used as @task
+        func = name  # function is passed as positional argument
+        task_name = func.__name__  # will use function name as task name
    else:
        raise TypeError("Cannot decorate this: {0}".format(repr(name)))
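For readers skimming the diff: the two branches above correspond to the two supported usage styles of the decorator. A minimal sketch, assuming the package is imported as pikatasks (the task names and bodies are illustrative, not part of this commit):

import pikatasks

@pikatasks.task  # no argument: the task name defaults to the function name, "ping"
def ping(**kwargs):
    return "pong"

@pikatasks.task("billing.create_invoice")  # with an argument: explicit task/queue name
def create_invoice(**kwargs):
    return {"ok": True}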
@@ -163,7 +156,7 @@ def task(name):
        channel.basic_ack(delivery_tag=method.delivery_tag)
        logger.debug("Received task {task_name}".format(**locals()))  # don't log the body, private data
        if django_compat.DJANGO:
-            django_compat.check_fix_db_connection()
+            django_compat.check_fix_db_connections()
        try:
            task_kwargs = deserialize_from_bytes(body)
            if not isinstance(task_kwargs, dict):
@@ -193,7 +186,7 @@ def task(name):
        func.as_callback = as_callback
        func.task_name = task_name
        func.task_queue = task_name
-        utils.all_tasks.add(func)
+        utils.known_tasks.add(func)
        return func
    if func:
......
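The headline change: the pika version branching is gone and pika >= 1.0.0 is asserted at import time instead. For reference, a sketch of the renamed basic_consume keywords in plain pika 1.x, assuming a local broker and a throwaway queue name:

import pika

# pika >= 1.0 renamed the keyword arguments on basic_consume:
#   consumer_callback -> on_message_callback, no_ack -> auto_ack
def on_message(channel, method, properties, body):
    print(body)
    channel.basic_ack(delivery_tag=method.delivery_tag)

conn = pika.BlockingConnection(pika.ConnectionParameters("localhost"))  # broker address assumed
channel = conn.channel()
channel.queue_declare(queue="demo")
# old (pika 0.x): channel.basic_consume(consumer_callback=on_message, queue="demo", no_ack=False)
channel.basic_consume(queue="demo", on_message_callback=on_message, auto_ack=False)
channel.start_consuming()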
import logging
import importlib
+import itertools
from . import utils
from .utils import logger
@@ -12,6 +12,15 @@ except ImportError:
    DJANGO = None
+
+def requires_django(callable):
+    def wrapper(*args, **kwargs):
+        if not DJANGO:
+            raise ModuleNotFoundError("Cannot use django compat features without django itself.")
+        return callable(*args, **kwargs)  # return the result so decorated functions keep their return values
+    return wrapper
+
+@requires_django
def close_db_connections():
    """
    Closes all Django db connections.
@@ -28,24 +37,22 @@ def close_db_connections():
        logger.warning("Failed to close django db connections: {e.__class__.__qualname__}: {e}".format(e=e))

+@requires_django
def check_worker_db_settings():
-    assert DJANGO
    t = int(django_conf.settings.CONN_MAX_AGE)
    if not t or t > 20 * 60:
        raise ValueError("When using django, CONN_MAX_AGE must be set to a sane value. The current value: {t} seconds.".format(t=t))
-def check_fix_db_connection():
+@requires_django
+def check_fix_db_connections():
    """
    I leave multiple options here to help solve possible future issues.
    This should fix OperationalError when nothing helps (starting with CONN_MAX_AGE).
-    Theis function has to be run *before* executing *each* task.
+    This function has to be run *before* executing *each* task.
    """
-    assert DJANGO
    # Option 1:
    django_db.close_old_connections()
    # # Option 2 (If Option 1 does not help):
    # for name in django_db.connections:
    #     conn = django_db.connections[name]
@@ -53,36 +60,32 @@ def check_fix_db_connection():
    #     try:
    #         cursor = conn.cursor()  # test
    #     except django_db.OperationalError:  # probably closed
    #         conn.close()  # let django reopen it if needed
    # # Option 3 (If Option 2 does not help):
    # django_db.connections.close_all()
    pass
+@requires_django
def autodiscover_tasks(apps=None, modules=("tasks",)):
    """
    Imports modules with tasks from django apps.
-    This function utilizes the fact that each task registers itself in utils.all_tasks
+    This function utilizes the fact that each task registers itself in utils.known_tasks
    :param apps: tuple of app names, leave None for everything in INSTALLED_APPS
-    :param modules: tuple of module names, if apps have their tasks in places other than "tasks.py"
-    :return: utils.all_tasks
+    :param modules: tuple of module names, override if apps have their tasks in places other than "tasks.py"
+    :return: known tasks (utils.known_tasks)
    """
-    assert DJANGO
    if apps is None:
        apps = django_conf.settings.INSTALLED_APPS
-    for app_name in apps:
-        for module_name in modules:
+    for app_name, module_name in itertools.product(apps, modules):
        full_module_name = "{0}.{1}".format(app_name, module_name)
        try:
            importlib.import_module(full_module_name)
            # just importing the module is perfectly enough, each task will register itself on import
-            logger.info("Autodiscover: imported \"{0}\"".format(full_module_name))
-        except ImportError as ie:
-            msg = "Autodiscover: module \"{0}\" does not exist: {1}".format(full_module_name, str(ie))
-            if 'tasks' in app_name:
-                logger.warning(msg)
-            else:
-                logger.debug(msg)
-    return utils.all_tasks
+            logger.info("Discovered \"{0}\"".format(full_module_name))
+        except ImportError as e:
+            if e.name == full_module_name:  # tasks module does not exist, it's okay
+                logger.debug("App {0} does not have module {1}".format(app_name, module_name))
+            else:  # failed to import something nested, it's bad
+                raise e
+    return utils.known_tasks
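For context, autodiscovery only needs to import each app's tasks module; the @task decorator does the registration as a side effect. A hedged sketch, with the app name and task assumed:

# myapp/tasks.py -- picked up because "tasks" is in the default modules tuple
import pikatasks

@pikatasks.task
def send_welcome_email(user_id=None, **kwargs):
    ...

# at startup (worker.start(tasks="all") does this for you under django):
from pikatasks import django_compat
django_compat.autodiscover_tasks()  # imports "<app>.tasks" for every app in INSTALLED_APPS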
@@ -3,10 +3,9 @@ import logging
import ssl
from . import settings

logger = logging.getLogger("pikatasks")
-all_tasks = set()  # each registered task will show up here
+known_tasks = set()  # each declared (and imported) task will add itself into this set

def get_ssl_options(settings):
......
@@ -11,8 +11,7 @@ from . import utils
from . import django_compat

-MASTER_IPC_PERIOD = timedelta(seconds=0.2)  # how often master checks own signals and minion processes
-MINION_IPC_PERIOD = timedelta(seconds=0.2)  # how often minions check their signals
+IPC_PERIOD = timedelta(seconds=0.2)  # how often processes check their signals and other processes

class _SignalHandler:
@@ -36,10 +35,10 @@ class _SignalHandler:
            signal.signal(s, signal_callback)

-def start(tasks=utils.all_tasks, number_of_processes=None):
+def start(tasks="all", number_of_processes=None):
    """
    Use this to launch a worker.
-    :param tasks: list of tasks to process
+    :param tasks: list of tasks to process (or "all" for all registered + auto-discovered)
    :param number_of_processes: number of worker processes
    :return:
    """
@@ -60,7 +59,7 @@ def start(tasks=utils.all_tasks, number_of_processes=None):
    if django_compat.DJANGO:
        django_compat.close_db_connections()
    p = multiprocessing.Process(
-        target=_task_process,
+        target=_minion_process,
        kwargs=dict(
            tasks=tasks,
            parent_pid=os.getpid(),
@@ -76,10 +75,10 @@ def start(tasks=utils.all_tasks, number_of_processes=None):
        deadline_dt = datetime.now() + settings.WORKER_GRACEFUL_STOP_TIMEOUT
        while processes and datetime.now() < deadline_dt:
            for p in processes:
-                os.kill(p.pid, signal.SIGTERM)  # SIGTERM = ask nicely
-            time.sleep((MINION_IPC_PERIOD / 2).total_seconds())
+                os.kill(p.pid, signal.SIGTERM)  # SIGTERM = ask minions nicely to stop
+            time.sleep(IPC_PERIOD.total_seconds())
            remove_ended_processes(expect_exited_processes=True)
-            if datetime.now() > last_reminder_dt + timedelta(seconds=5):
+            if datetime.now() > last_reminder_dt + timedelta(seconds=5):  # log reminder every 5 seconds
                last_reminder_dt = datetime.now()
                logger.info("Stopping... Minions still running: {n}. Deadline in: {d}.".format(d=deadline_dt - datetime.now(), n=len(processes)))
@@ -113,8 +112,13 @@ def start(tasks=utils.all_tasks, number_of_processes=None):
            logger.error("Some queues are missing: {0}".format(queues_missing))
        return [t for t in tasks if t.task_queue in queues_found]

+    def get_all_tasks():
+        if django_compat.DJANGO:
+            django_compat.autodiscover_tasks()
+        return utils.known_tasks
+
    logger.info("Starting pikatasks worker...")
-    tasks = filter_tasks(tasks=tasks)
+    tasks = filter_tasks(tasks=get_all_tasks() if tasks == "all" else tasks)
    logger.info("Tasks: {0}".format(repr([t.task_name for t in tasks])))
    if not tasks:
        raise ValueError("Empty task list.")
@@ -124,7 +128,7 @@ def start(tasks=utils.all_tasks, number_of_processes=None):
        remove_ended_processes(expect_exited_processes=False)
        while len(processes) < (number_of_processes or settings.WORKER_TASK_PROCESSES):
            processes.append(create_minion(tasks))
-        time.sleep(MASTER_IPC_PERIOD.total_seconds())
+        time.sleep(IPC_PERIOD.total_seconds())
    # stopping
    logger.info("Stopping minions...")
    stop_minions()
@@ -136,7 +140,7 @@ def start(tasks=utils.all_tasks, number_of_processes=None):
    logger.info("Stopped pikatasks worker.")

-def _task_process(tasks, parent_pid):
+def _minion_process(tasks, parent_pid):
    """ This is a single process that performs tasks. """
    logger = logging.getLogger("pikatasks.worker.minion.pid{0}".format(os.getpid()))
    signal_handler = _SignalHandler(logger=logger, this_process_name="minion")
@@ -144,10 +148,10 @@ def _task_process(tasks, parent_pid):
        raise RuntimeError("Got empty list of tasks")
    conn, channel = None, None

-    def control_beat():
+    def process_controller():
        # this function performs IPC and decides when to finish this process
        # this function registers itself to be called again
        stop = False
        # check whether the parent process is alive
        if os.getppid() != parent_pid:  # got adopted, new owner is probably init
            logger.error("Master (PID={0}) has disappeared :( Stopping.".format(parent_pid))
            stop = True
@@ -157,7 +161,7 @@ def _task_process(tasks, parent_pid):
        if stop:
            channel.stop_consuming()
            logger.debug("Stopping consuming messages from queues.")
-        conn.add_timeout(MINION_IPC_PERIOD.total_seconds(), control_beat)  # run this function again soon
+        conn.call_later(IPC_PERIOD.total_seconds(), process_controller)  # run this function again soon
    try:
        logger.debug("Opening a connection...")
@@ -168,27 +172,20 @@ def _task_process(tasks, parent_pid):
            callback = getattr(task, "as_callback", None)
            if not callback or not callable(callback):
                raise ValueError("Not a valid task: {0}".format(task))
-            if pika.__version__ >= "1.0":
-                # multiple breaking changes in version 1.0.0b1
-                # if this is broken again, also fix the other basic_consume in this project
-                channel.basic_consume(queue=task.task_queue, on_message_callback=callback)
-            else:
-                channel.basic_consume(consumer_callback=callback, queue=task.task_queue)
+            channel.basic_consume(queue=task.task_queue, on_message_callback=callback)
            logger.debug("Registered task {t} on queue {q}".format(t=task.task_name, q=task.task_queue))
        except Exception as e:
            logger.error("Could not register task \"{t}\". {e.__class__.__qualname__}: {e}".format(t=task.task_name, e=e))
            if isinstance(e, (AMQPChannelError, AMQPConnectionError,)):
                raise e  # does not make sense to try registering other tasks
-        control_beat()  # initial
-        channel.start_consuming()
+        process_controller()  # initial
+        channel.start_consuming()  # until stop_consuming is called
    except Exception as e:
        logger.error("{e.__class__.__qualname__}: {e}".format(e=e))
    finally:
        if channel and channel.is_open:
            logger.info("Closing the channel: {0}".format(channel))
            channel.close()
        if conn and conn.is_open:
            logger.info("Closing the connection: {0}".format(conn))
            conn.close()
    logger.debug("Stopped.")
......
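Putting the worker changes together, launching with the new default-style argument might look like this (the script name and process count are assumptions, not from this commit):

# run_worker.py
from pikatasks import worker

# "all" routes through get_all_tasks(): registered tasks, plus django autodiscovery when available
worker.start(tasks="all", number_of_processes=4)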