From a3f5a3cbf77d4fa4bcb336dacc54216175fe57c3 Mon Sep 17 00:00:00 2001
From: Art Lukyanchyk <artiom.lukyanchyk@hs-hannover.de>
Date: Tue, 21 Aug 2018 18:52:10 +0200
Subject: [PATCH] Improve serialization, split utils and serialization modules

---
 pikatasks/__init__.py      |  11 ++--
 pikatasks/django_compat.py |   4 +-
 pikatasks/serialization.py | 104 +++++++++++++++++++++++++++++++++++++
 pikatasks/settings.py      |   2 +
 pikatasks/utils.py         |  39 +-------------
 5 files changed, 114 insertions(+), 46 deletions(-)
 create mode 100644 pikatasks/serialization.py

diff --git a/pikatasks/__init__.py b/pikatasks/__init__.py
index b4548ce..d0b53c7 100644
--- a/pikatasks/__init__.py
+++ b/pikatasks/__init__.py
@@ -4,6 +4,7 @@ from . import utils
 from . import settings
 from . import django_compat
 from . import worker  # keep it here for pikatasks.worker.start() being possible
+from . import serialization
 from .utils import logger
 import datetime
 import types
@@ -46,7 +47,7 @@ def run(_task_name, **kwargs):
             channel.basic_publish(
                 exchange=settings.CLIENT_EXCHANGE_NAME,
                 routing_key=_task_name,
-                body=utils.serialize(kwargs),
+                body=serialization.serialize(kwargs),
                 properties=pika.BasicProperties()
             )
     except Exception as e:
@@ -70,7 +71,7 @@ def rpc(_task_name, **kwargs):
         nonlocal result, exception
         channel.stop_consuming()
         try:
-            reply = utils.deserialize(body)
+            reply = serialization.deserialize(body)
             reply_result, reply_error = reply.get("result"), reply.get("error")
             if reply_error:
                 exception = RPCRemoteError(reply_error)
@@ -100,7 +101,7 @@ def rpc(_task_name, **kwargs):
             channel.basic_publish(
                 exchange=settings.CLIENT_EXCHANGE_NAME,
                 routing_key=_task_name,
-                body=utils.serialize(kwargs),
+                body=serialization.serialize(kwargs),
                 properties=pika.BasicProperties(
                     reply_to="amq.rabbitmq.reply-to",
                     expiration=str(int(settings.RPC_TIMEOUT.total_seconds() * 1000)),  # in milliseconds, int as str; expire the message so this RPC will not get remotely executed like years later
@@ -154,7 +155,7 @@ def task(name):
             if django_compat.DJANGO:
                 django_compat.check_fix_db_connection()
             try:
-                task_kwargs = utils.deserialize(body)
+                task_kwargs = serialization.deserialize(body)
                 if not isinstance(task_kwargs, dict):
                     raise TypeError("Bad kwargs type: {0}".format(task_kwargs.__class__.__qualname__))
                 func_result = func(**task_kwargs)
@@ -170,7 +171,7 @@ def task(name):
                     channel.basic_publish(
                         exchange="",  # empty string as specified in RabbitMQ documentation, section direct reply-to
                         routing_key=properties.reply_to,
-                        body=utils.serialize(reply))
+                        body=serialization.serialize(reply))
                 except Exception as e:
                     logger.error("Could not reply to the {properties.reply_to}. {e.__class__.__name__}: {e}".format(**locals()))
             else:
diff --git a/pikatasks/django_compat.py b/pikatasks/django_compat.py
index 521a657..ba0fad6 100644
--- a/pikatasks/django_compat.py
+++ b/pikatasks/django_compat.py
@@ -1,6 +1,7 @@
 import logging
 import importlib
 from . import utils
+from .utils import logger
 
 try:
     import django
@@ -11,9 +12,6 @@ except ImportError:
     DJANGO = None
 
 
-logger = logging.getLogger("pika-tasks")
-
-
 def close_db_connections():
     """
     Closes all Django db connections.
diff --git a/pikatasks/serialization.py b/pikatasks/serialization.py
new file mode 100644
index 0000000..c8ca7d6
--- /dev/null
+++ b/pikatasks/serialization.py
@@ -0,0 +1,104 @@
+import json
+import datetime
+import collections
+import itertools
+from . import settings
+from .utils import logger
+
+
+def serialize(stuff):  # -> UTF-8 encoded JSON bytes; objects json cannot encode natively go through json_serialize_tweaks
+    return json.dumps(stuff, default=json_serialize_tweaks).encode("utf-8")
+
+
+def deserialize(bytes):  # inverse of serialize: UTF-8 JSON bytes -> Python object; every decoded dict passes through json_deserialize_tweaks
+    return json.loads(bytes.decode("utf-8"), object_hook=json_deserialize_tweaks)
+
+
+def datetime_to_str(dt):
+    """
+    :param dt: datetime (timezone-aware); a naive value is logged and then interpreted as local time by astimezone()
+    :return: str (UTC, formatted with settings.DATETIME_FORMAT) that can be parsed by str_to_datetime
+    """
+    if not dt.tzinfo:
+        logger.warning("Naive datetime received by datetime_to_str() and will be treated as local: {dt}. Avoid using naive datetime objects.".format(dt=dt))
+    utc_dt = dt.astimezone(datetime.timezone.utc)
+    return utc_dt.strftime(settings.DATETIME_FORMAT)
+
+
+def str_to_datetime(text, return_utc=False):
+    """
+    :param text: str created by datetime_to_str
+    :param return_utc: set to True to get the result in UTC instead of the local timezone
+    :return: datetime (timezone-aware)
+    """
+    parsed = datetime.datetime.strptime(text, settings.DATETIME_FORMAT)
+    assert parsed.tzinfo, "ok, now that's weird, no tzinfo, but there must have been %z in the DATETIME_FORMAT"
+    # astimezone(None) converts to the system local timezone:
+    # not pytz, just old c++ stuff, but still a timezone with correct utc offset
+    target_tz = datetime.timezone.utc if return_utc else None
+    return parsed.astimezone(target_tz)
+
+
+def date_to_str(d):  # formats d (a date, as passed by json_serialize_tweaks) with settings.DATE_FORMAT; parsed back by str_to_date
+    return d.strftime(settings.DATE_FORMAT)
+
+
+def str_to_date(text):  # inverse of date_to_str: parses text in settings.DATE_FORMAT and returns a datetime.date
+    return datetime.datetime.strptime(text, settings.DATE_FORMAT).date()
+
+
+JSON_PYTHON_DATA_TYPE = "__pythonic_type"  # key naming the Python type ("set", "datetime", "date", "timedelta") in a tagged JSON object
+JSON_PYTHON_DATA_VALUE = "__pythonic_value"  # key holding the JSON-encodable payload of a tagged JSON object
+JSON_ITER_MAX_YIELD = 1000  # json_serialize_tweaks takes at most this many elements from an arbitrary iterable
+
+
+def json_serialize_tweaks(obj):
+    """
+    Use this as the ``default`` argument of json.dump(s).
+
+    json calls it for objects it cannot encode natively. Sets, datetimes, dates and timedeltas
+    become {JSON_PYTHON_DATA_TYPE: <type name>, JSON_PYTHON_DATA_VALUE: <payload>} dicts that
+    json_deserialize_tweaks turns back into Python objects. Any other iterable becomes a plain
+    list of at most JSON_ITER_MAX_YIELD elements.
+
+    :param obj: object that json could not encode on its own
+    :return: JSON-encodable replacement for obj
+    :raises TypeError: when obj is of an unsupported type
+    """
+    from collections.abc import Iterable  # local import: the collections.Iterable alias was removed in Python 3.10
+    if isinstance(obj, set):
+        return {JSON_PYTHON_DATA_TYPE: "set", JSON_PYTHON_DATA_VALUE: list(obj)}
+    elif isinstance(obj, Iterable):
+        # iterators and other iterables will become lists
+        elements = list(itertools.islice(obj, JSON_ITER_MAX_YIELD))  # protect from trolls with itertools.repeat()
+        if len(elements) == JSON_ITER_MAX_YIELD:  # islice stopped at the cap, so the source may have had more
+            logger.warning("Will not automatically yield more than {n} elements from {obj}.".format(n=JSON_ITER_MAX_YIELD, obj=obj))
+        return elements
+    elif isinstance(obj, datetime.datetime):  # must precede the date check: datetime is a subclass of date
+        return {JSON_PYTHON_DATA_TYPE: "datetime", JSON_PYTHON_DATA_VALUE: datetime_to_str(obj)}
+    elif isinstance(obj, datetime.date):
+        return {JSON_PYTHON_DATA_TYPE: "date", JSON_PYTHON_DATA_VALUE: date_to_str(obj)}
+    elif isinstance(obj, datetime.timedelta):
+        return {JSON_PYTHON_DATA_TYPE: "timedelta", JSON_PYTHON_DATA_VALUE: obj.total_seconds()}
+    else:
+        raise TypeError("Object of type {0} is not JSON serializable".format(obj.__class__.__qualname__))
+
+
+def json_deserialize_tweaks(obj):
+    """ Use this as the ``object_hook`` argument of json.load(s): undoes the tagging done by json_serialize_tweaks.
+    :return: obj unchanged unless its keys are exactly the two tag keys, else the decoded set/datetime/date/timedelta
+    :raises TypeError: when the tagged type name is unknown
+    """
+    if set(obj) != {JSON_PYTHON_DATA_TYPE, JSON_PYTHON_DATA_VALUE}:
+        return obj  # have nothing to do with it
+    type_name, value = obj[JSON_PYTHON_DATA_TYPE], obj[JSON_PYTHON_DATA_VALUE]
+    if type_name == "set":
+        return set(value)
+    if type_name == "datetime":
+        return str_to_datetime(value)
+    if type_name == "date":
+        return str_to_date(value)
+    if type_name == "timedelta":
+        return datetime.timedelta(seconds=value)
+    raise TypeError("Don't know how to deserialize \"{0}\".".format(type_name))
+
diff --git a/pikatasks/settings.py b/pikatasks/settings.py
index d1a04cc..68de782 100644
--- a/pikatasks/settings.py
+++ b/pikatasks/settings.py
@@ -35,6 +35,8 @@ WORKER_GRACEFUL_STOP_TIMEOUT = timedelta(seconds=60)
 
 # stuff you probably don't want to touch:
 BLOCKED_CONNECTION_TIMEOUT = timedelta(seconds=20)  # weird stuff to avoid deadlocks, see pika documentation
+DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f%Z%z"  # strftime/strptime format used by pikatasks.serialization for datetimes; both %Z and %z are kept for compatibility in case we ever want to improve timezone-related stuff
+DATE_FORMAT = "%Y-%m-%d"  # strftime/strptime format used by pikatasks.serialization for dates
 
 
 # merge these settings with django.conf.settings
diff --git a/pikatasks/utils.py b/pikatasks/utils.py
index 9f59caa..b5565ba 100644
--- a/pikatasks/utils.py
+++ b/pikatasks/utils.py
@@ -1,50 +1,13 @@
-import json
 import pika
 import logging
 import ssl
-from datetime import datetime, timezone
 from . import settings
 
 
-logger = logging.getLogger("pika-tasks")
+logger = logging.getLogger("pikatasks")
 
 all_tasks = set()  # each registered task will show up here
 
-DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f%Z%z"  # used for serialization, I leave here both %Z%z for compatibility in case we ever want to improve timezone-related stuff
-
-
-def serialize(stuff):
-    return json.dumps(stuff).encode("utf-8")
-
-
-def deserialize(bytes):
-    return json.loads(bytes.decode("utf-8"))
-
-
-def serialize_datetime(dt):
-    """
-    :param dt: datetime (timezone-aware)
-    :return: str that can be deserialized by deserialize_datetime()
-    """
-    if not dt.tzinfo:
-        logger.warning("Naive datetime received by serialize_datetime() and will be treated as local: {dt}. Avoid using naive datetime objects.".format(dt=dt))
-    utc_dt = dt.astimezone(timezone.utc)
-    return datetime.strftime(utc_dt, DATETIME_FORMAT)
-
-
-def deserialize_datetime(text, utc=False):
-    """
-    :param text: str created by serialize_datetime()
-    :param utc: set to True if you want this function to return UTC datetime
-    :return: datetime (timezone-aware)
-    """
-    dt = datetime.strptime(text, DATETIME_FORMAT)
-    assert dt.tzinfo, "ok, now that's weird, no tzinfo, but there must have been %z in the DATETIME_FORMAT"
-    if utc:
-        return dt.astimezone(timezone.utc)
-    else:
-        return dt.astimezone()  # not pytz, just old c++ stuff, but still a timezone with correct utc offset
-
 
 def get_ssl_options(settings):
     """ Create pika.SSLOptions based on pikatasks settings. """
-- 
GitLab