diff --git a/pikatasks/__init__.py b/pikatasks/__init__.py index b4548cee0e719631dfd24aeec9c5b05d2d1855ae..d0b53c783b6cc4bd18b42faa6bc6d45182b35c1e 100644 --- a/pikatasks/__init__.py +++ b/pikatasks/__init__.py @@ -4,6 +4,7 @@ from . import utils from . import settings from . import django_compat from . import worker # keep it here for pikatasks.worker.start() being possible +from . import serialization from .utils import logger import datetime import types @@ -46,7 +47,7 @@ def run(_task_name, **kwargs): channel.basic_publish( exchange=settings.CLIENT_EXCHANGE_NAME, routing_key=_task_name, - body=utils.serialize(kwargs), + body=serialization.serialize(kwargs), properties=pika.BasicProperties() ) except Exception as e: @@ -70,7 +71,7 @@ def rpc(_task_name, **kwargs): nonlocal result, exception channel.stop_consuming() try: - reply = utils.deserialize(body) + reply = serialization.deserialize(body) reply_result, reply_error = reply.get("result"), reply.get("error") if reply_error: exception = RPCRemoteError(reply_error) @@ -100,7 +101,7 @@ def rpc(_task_name, **kwargs): channel.basic_publish( exchange=settings.CLIENT_EXCHANGE_NAME, routing_key=_task_name, - body=utils.serialize(kwargs), + body=serialization.serialize(kwargs), properties=pika.BasicProperties( reply_to="amq.rabbitmq.reply-to", expiration=str(int(settings.RPC_TIMEOUT.total_seconds() * 1000)), # in milliseconds, int as str; expire the message so this RPC will not get remotely executed like years later @@ -154,7 +155,7 @@ def task(name): if django_compat.DJANGO: django_compat.check_fix_db_connection() try: - task_kwargs = utils.deserialize(body) + task_kwargs = serialization.deserialize(body) if not isinstance(task_kwargs, dict): raise TypeError("Bad kwargs type: {0}".format(task_kwargs.__class__.__qualname__)) func_result = func(**task_kwargs) @@ -170,7 +171,7 @@ def task(name): channel.basic_publish( exchange="", # empty string as specified in RabbitMQ documentation, section direct reply-to routing_key=properties.reply_to, - body=utils.serialize(reply)) + body=serialization.serialize(reply)) except Exception as e: logger.error("Could not reply to the {properties.reply_to}. {e.__class__.__name__}: {e}".format(**locals())) else: diff --git a/pikatasks/django_compat.py b/pikatasks/django_compat.py index 521a657862396fcb6853b113e2b002f6d11041c3..ba0fad615e67a542a2a4203bd7edbb7df34f4c6a 100644 --- a/pikatasks/django_compat.py +++ b/pikatasks/django_compat.py @@ -1,6 +1,7 @@ import logging import importlib from . import utils +from .utils import logger try: import django @@ -11,9 +12,6 @@ except ImportError: DJANGO = None -logger = logging.getLogger("pika-tasks") - - def close_db_connections(): """ Closes all Django db connections. diff --git a/pikatasks/serialization.py b/pikatasks/serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..c8ca7d6991871117d6e74954d99f9a5903e06f23 --- /dev/null +++ b/pikatasks/serialization.py @@ -0,0 +1,104 @@ +import json +import datetime +import collections +import itertools +from . import settings +from .utils import logger + + +def serialize(stuff): + return json.dumps(stuff, default=json_serialize_tweaks).encode("utf-8") + + +def deserialize(bytes): + return json.loads(bytes.decode("utf-8"), object_hook=json_deserialize_tweaks) + + +def datetime_to_str(dt): + """ + :param dt: datetime (timezone-aware) + :return: serializable str that can be parsed by str_to_datetime + """ + if not dt.tzinfo: + logger.warning("Naive datetime received by serialize_datetime() and will be treated as local: {dt}. Avoid using naive datetime objects.".format(dt=dt)) + utc_dt = dt.astimezone(datetime.timezone.utc) + return datetime.datetime.strftime(utc_dt, settings.DATETIME_FORMAT) + + +def str_to_datetime(text, return_utc=False): + """ + :param text: str created by datetime_to_str + :param utc: set to True if you want this function to return UTC datetime + :return: datetime (timezone-aware) + """ + dt = datetime.datetime.strptime(text, settings.DATETIME_FORMAT) + assert dt.tzinfo, "ok, now that's weird, no tzinfo, but there must have been %z in the DATETIME_FORMAT" + if return_utc: + return dt.astimezone(datetime.timezone.utc) + else: + return dt.astimezone() # not pytz, just old c++ stuff, but still a timezone with correct utc offset + + +def date_to_str(d): + return d.strftime(settings.DATE_FORMAT) + + +def str_to_date(text): + return datetime.datetime.strptime(text, settings.DATE_FORMAT).date() + + +JSON_PYTHON_DATA_TYPE = "__pythonic_type" +JSON_PYTHON_DATA_VALUE = "__pythonic_value" +JSON_ITER_MAX_YIELD = 1000 + + +def json_serialize_tweaks(obj): + """ use this as default argument of json.dump """ + if isinstance(obj, set): + return { + JSON_PYTHON_DATA_TYPE: "set", + JSON_PYTHON_DATA_VALUE: list(obj), + } + elif isinstance(obj, collections.Iterable): + # iterators and other iterables will become lists + elements = list(itertools.islice(obj, JSON_ITER_MAX_YIELD)) # protect from trolls with itertools.repeat() + if len(elements is JSON_ITER_MAX_YIELD): + logger.warning("Will not automatically yield more than {n} elements from {obj}.".format(n=JSON_ITER_MAX_YIELD, obj=obj)) + return elements + elif isinstance(obj, datetime.datetime): + return { + JSON_PYTHON_DATA_TYPE: "datetime", + JSON_PYTHON_DATA_VALUE: datetime_to_str(obj), + } + elif isinstance(obj, datetime.date): + return { + JSON_PYTHON_DATA_TYPE: "date", + JSON_PYTHON_DATA_VALUE: date_to_str(obj), + } + elif isinstance(obj, datetime.timedelta): + return { + JSON_PYTHON_DATA_TYPE: "timedelta", + JSON_PYTHON_DATA_VALUE: obj.total_seconds(), + } + else: + raise TypeError + + +def json_deserialize_tweaks(obj): + """ use this as object_hook argument of json.load """ + if set(obj.keys()) != {JSON_PYTHON_DATA_TYPE, JSON_PYTHON_DATA_VALUE}: + return obj # have nothing to do with it + assert len(obj) is 2 + type_name = obj[JSON_PYTHON_DATA_TYPE] + value = obj[JSON_PYTHON_DATA_VALUE] + if type_name == "set": + return set(value) + elif type_name == "datetime": + return str_to_datetime(value) + elif type_name == "date": + return str_to_date(value) + elif type_name == "timedelta": + return datetime.timedelta(seconds=value) + else: + raise TypeError("Don't know how to deserialize \"{0}\".".format(type_name)) + diff --git a/pikatasks/settings.py b/pikatasks/settings.py index d1a04cc8af2be81f85b753c33093a2a0e465d075..68de782996141b22947fe1b614770e8ff7a9393e 100644 --- a/pikatasks/settings.py +++ b/pikatasks/settings.py @@ -35,6 +35,8 @@ WORKER_GRACEFUL_STOP_TIMEOUT = timedelta(seconds=60) # stuff you probably don't want to touch: BLOCKED_CONNECTION_TIMEOUT = timedelta(seconds=20) # weird stuff to avoid deadlocks, see pika documentation +DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f%Z%z" # serialization, both %Z%z for compatibility in case we ever want to improve timezone-related stuff +DATE_FORMAT = "%Y-%m-%d" # serialization # merge these settings with django.conf.settings diff --git a/pikatasks/utils.py b/pikatasks/utils.py index 9f59caa8f2e29f1b9ebd96932cb77f6ad9d66d44..b5565ba5cb6c215b5b8742cc296031f1fc2ef52f 100644 --- a/pikatasks/utils.py +++ b/pikatasks/utils.py @@ -1,50 +1,13 @@ -import json import pika import logging import ssl -from datetime import datetime, timezone from . import settings -logger = logging.getLogger("pika-tasks") +logger = logging.getLogger("pikatasks") all_tasks = set() # each registered task will show up here -DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f%Z%z" # used for serialization, I leave here both %Z%z for compatibility in case we ever want to improve timezone-related stuff - - -def serialize(stuff): - return json.dumps(stuff).encode("utf-8") - - -def deserialize(bytes): - return json.loads(bytes.decode("utf-8")) - - -def serialize_datetime(dt): - """ - :param dt: datetime (timezone-aware) - :return: str that can be deserialized by deserialize_datetime() - """ - if not dt.tzinfo: - logger.warning("Naive datetime received by serialize_datetime() and will be treated as local: {dt}. Avoid using naive datetime objects.".format(dt=dt)) - utc_dt = dt.astimezone(timezone.utc) - return datetime.strftime(utc_dt, DATETIME_FORMAT) - - -def deserialize_datetime(text, utc=False): - """ - :param text: str created by serialize_datetime() - :param utc: set to True if you want this function to return UTC datetime - :return: datetime (timezone-aware) - """ - dt = datetime.strptime(text, DATETIME_FORMAT) - assert dt.tzinfo, "ok, now that's weird, no tzinfo, but there must have been %z in the DATETIME_FORMAT" - if utc: - return dt.astimezone(timezone.utc) - else: - return dt.astimezone() # not pytz, just old c++ stuff, but still a timezone with correct utc offset - def get_ssl_options(settings): """ Create pika.SSLOptions based on pikatasks settings. """