Skip to content
Snippets Groups Projects
Commit a3f5a3cb authored by Art's avatar Art :lizard:
Browse files

Improve serialization, split utils and serialization modules

parent f527af0d
No related branches found
No related tags found
No related merge requests found
...@@ -4,6 +4,7 @@ from . import utils ...@@ -4,6 +4,7 @@ from . import utils
from . import settings from . import settings
from . import django_compat from . import django_compat
from . import worker # keep it here for pikatasks.worker.start() being possible from . import worker # keep it here for pikatasks.worker.start() being possible
from . import serialization
from .utils import logger from .utils import logger
import datetime import datetime
import types import types
...@@ -46,7 +47,7 @@ def run(_task_name, **kwargs): ...@@ -46,7 +47,7 @@ def run(_task_name, **kwargs):
channel.basic_publish( channel.basic_publish(
exchange=settings.CLIENT_EXCHANGE_NAME, exchange=settings.CLIENT_EXCHANGE_NAME,
routing_key=_task_name, routing_key=_task_name,
body=utils.serialize(kwargs), body=serialization.serialize(kwargs),
properties=pika.BasicProperties() properties=pika.BasicProperties()
) )
except Exception as e: except Exception as e:
...@@ -70,7 +71,7 @@ def rpc(_task_name, **kwargs): ...@@ -70,7 +71,7 @@ def rpc(_task_name, **kwargs):
nonlocal result, exception nonlocal result, exception
channel.stop_consuming() channel.stop_consuming()
try: try:
reply = utils.deserialize(body) reply = serialization.deserialize(body)
reply_result, reply_error = reply.get("result"), reply.get("error") reply_result, reply_error = reply.get("result"), reply.get("error")
if reply_error: if reply_error:
exception = RPCRemoteError(reply_error) exception = RPCRemoteError(reply_error)
...@@ -100,7 +101,7 @@ def rpc(_task_name, **kwargs): ...@@ -100,7 +101,7 @@ def rpc(_task_name, **kwargs):
channel.basic_publish( channel.basic_publish(
exchange=settings.CLIENT_EXCHANGE_NAME, exchange=settings.CLIENT_EXCHANGE_NAME,
routing_key=_task_name, routing_key=_task_name,
body=utils.serialize(kwargs), body=serialization.serialize(kwargs),
properties=pika.BasicProperties( properties=pika.BasicProperties(
reply_to="amq.rabbitmq.reply-to", reply_to="amq.rabbitmq.reply-to",
expiration=str(int(settings.RPC_TIMEOUT.total_seconds() * 1000)), # in milliseconds, int as str; expire the message so this RPC will not get remotely executed like years later expiration=str(int(settings.RPC_TIMEOUT.total_seconds() * 1000)), # in milliseconds, int as str; expire the message so this RPC will not get remotely executed like years later
...@@ -154,7 +155,7 @@ def task(name): ...@@ -154,7 +155,7 @@ def task(name):
if django_compat.DJANGO: if django_compat.DJANGO:
django_compat.check_fix_db_connection() django_compat.check_fix_db_connection()
try: try:
task_kwargs = utils.deserialize(body) task_kwargs = serialization.deserialize(body)
if not isinstance(task_kwargs, dict): if not isinstance(task_kwargs, dict):
raise TypeError("Bad kwargs type: {0}".format(task_kwargs.__class__.__qualname__)) raise TypeError("Bad kwargs type: {0}".format(task_kwargs.__class__.__qualname__))
func_result = func(**task_kwargs) func_result = func(**task_kwargs)
...@@ -170,7 +171,7 @@ def task(name): ...@@ -170,7 +171,7 @@ def task(name):
channel.basic_publish( channel.basic_publish(
exchange="", # empty string as specified in RabbitMQ documentation, section direct reply-to exchange="", # empty string as specified in RabbitMQ documentation, section direct reply-to
routing_key=properties.reply_to, routing_key=properties.reply_to,
body=utils.serialize(reply)) body=serialization.serialize(reply))
except Exception as e: except Exception as e:
logger.error("Could not reply to the {properties.reply_to}. {e.__class__.__name__}: {e}".format(**locals())) logger.error("Could not reply to the {properties.reply_to}. {e.__class__.__name__}: {e}".format(**locals()))
else: else:
......
import logging import logging
import importlib import importlib
from . import utils from . import utils
from .utils import logger
try: try:
import django import django
...@@ -11,9 +12,6 @@ except ImportError: ...@@ -11,9 +12,6 @@ except ImportError:
DJANGO = None DJANGO = None
logger = logging.getLogger("pika-tasks")
def close_db_connections(): def close_db_connections():
""" """
Closes all Django db connections. Closes all Django db connections.
......
import json
import datetime
import collections
import itertools
from . import settings
from .utils import logger
def serialize(stuff):
    """
    Encode a Python object as UTF-8 encoded JSON.

    Objects the json module cannot encode on its own are handed to json_serialize_tweaks.

    :param stuff: object to encode
    :return: bytes
    """
    text = json.dumps(stuff, default=json_serialize_tweaks)
    return text.encode("utf-8")
def deserialize(bytes):
    """
    Decode UTF-8 encoded JSON back into Python objects.

    Every decoded JSON object (dict) is passed through json_deserialize_tweaks.

    :param bytes: UTF-8 encoded JSON document
    :return: decoded object
    """
    text = bytes.decode("utf-8")
    return json.loads(text, object_hook=json_deserialize_tweaks)
def datetime_to_str(dt):
    """
    Convert a datetime to UTC and format it using settings.DATETIME_FORMAT.

    :param dt: datetime (timezone-aware); a naive datetime triggers a warning and is treated as local time
    :return: serializable str that can be parsed by str_to_datetime
    """
    if not dt.tzinfo:
        # The message names this function (the previous text referred to the old serialize_datetime()).
        logger.warning("Naive datetime received by datetime_to_str() and will be treated as local: {dt}. Avoid using naive datetime objects.".format(dt=dt))
    utc_dt = dt.astimezone(datetime.timezone.utc)
    return utc_dt.strftime(settings.DATETIME_FORMAT)
def str_to_datetime(text, return_utc=False):
    """
    Parse a string formatted with settings.DATETIME_FORMAT into a datetime.

    :param text: str created by datetime_to_str
    :param return_utc: set to True if you want this function to return UTC datetime
    :return: datetime (timezone-aware)
    """
    parsed = datetime.datetime.strptime(text, settings.DATETIME_FORMAT)
    assert parsed.tzinfo, "ok, now that's weird, no tzinfo, but there must have been %z in the DATETIME_FORMAT"
    if not return_utc:
        # No argument: convert to the system local timezone (not a pytz zone, but it carries the correct utc offset).
        return parsed.astimezone()
    return parsed.astimezone(datetime.timezone.utc)
def date_to_str(d):
    """
    Format a date using settings.DATE_FORMAT.

    :param d: object with a strftime method (presumably datetime.date)
    :return: str that can be parsed by str_to_date
    """
    date_format = settings.DATE_FORMAT
    return d.strftime(date_format)
def str_to_date(text):
    """
    Parse a string formatted with settings.DATE_FORMAT.

    :param text: str created by date_to_str
    :return: datetime.date
    """
    parsed = datetime.datetime.strptime(text, settings.DATE_FORMAT)
    return parsed.date()
JSON_PYTHON_DATA_TYPE = "__pythonic_type"  # key that stores the name of the original Python type
JSON_PYTHON_DATA_VALUE = "__pythonic_value"  # key that stores the JSON-friendly representation
JSON_ITER_MAX_YIELD = 1000  # max number of elements taken from an arbitrary iterable


def json_serialize_tweaks(obj):
    """
    Use this as the default argument of json.dump / json.dumps.

    Sets, datetimes, dates and timedeltas become a two-key dict
    ({JSON_PYTHON_DATA_TYPE: ..., JSON_PYTHON_DATA_VALUE: ...}).
    Any other iterable becomes a plain list of at most JSON_ITER_MAX_YIELD elements.

    :param obj: object that json could not serialize natively
    :return: JSON-serializable replacement for obj
    :raises TypeError: if obj is of an unsupported type
    """
    # The collections.Iterable alias was removed in Python 3.10; the ABC lives in collections.abc.
    from collections.abc import Iterable

    if isinstance(obj, set):
        return {
            JSON_PYTHON_DATA_TYPE: "set",
            JSON_PYTHON_DATA_VALUE: list(obj),
        }
    elif isinstance(obj, Iterable):
        # iterators and other iterables will become lists
        elements = list(itertools.islice(obj, JSON_ITER_MAX_YIELD))  # protect from trolls with itertools.repeat()
        if len(elements) == JSON_ITER_MAX_YIELD:
            logger.warning("Will not automatically yield more than {n} elements from {obj}.".format(n=JSON_ITER_MAX_YIELD, obj=obj))
        return elements
    elif isinstance(obj, datetime.datetime):
        # must be checked before datetime.date: datetime is a subclass of date
        return {
            JSON_PYTHON_DATA_TYPE: "datetime",
            JSON_PYTHON_DATA_VALUE: datetime_to_str(obj),
        }
    elif isinstance(obj, datetime.date):
        return {
            JSON_PYTHON_DATA_TYPE: "date",
            JSON_PYTHON_DATA_VALUE: date_to_str(obj),
        }
    elif isinstance(obj, datetime.timedelta):
        return {
            JSON_PYTHON_DATA_TYPE: "timedelta",
            JSON_PYTHON_DATA_VALUE: obj.total_seconds(),
        }
    else:
        raise TypeError("Object of type {0} is not JSON serializable".format(obj.__class__.__name__))
def json_deserialize_tweaks(obj):
    """
    Use this as the object_hook argument of json.load / json.loads.

    Dicts whose keys are exactly JSON_PYTHON_DATA_TYPE and JSON_PYTHON_DATA_VALUE
    are converted back to the Python object they describe; any other dict is returned untouched.

    :param obj: dict decoded from a JSON object
    :return: obj itself, or the restored set / datetime / date / timedelta
    :raises TypeError: if the stored type name is unknown
    """
    # The exact key-set match already guarantees there are two keys, so no separate length check is needed.
    if set(obj) != {JSON_PYTHON_DATA_TYPE, JSON_PYTHON_DATA_VALUE}:
        return obj  # have nothing to do with it
    type_name = obj[JSON_PYTHON_DATA_TYPE]
    value = obj[JSON_PYTHON_DATA_VALUE]
    if type_name == "set":
        return set(value)
    elif type_name == "datetime":
        return str_to_datetime(value)
    elif type_name == "date":
        return str_to_date(value)
    elif type_name == "timedelta":
        return datetime.timedelta(seconds=value)
    else:
        raise TypeError("Don't know how to deserialize \"{0}\".".format(type_name))
...@@ -35,6 +35,8 @@ WORKER_GRACEFUL_STOP_TIMEOUT = timedelta(seconds=60) ...@@ -35,6 +35,8 @@ WORKER_GRACEFUL_STOP_TIMEOUT = timedelta(seconds=60)
# stuff you probably don't want to touch: # stuff you probably don't want to touch:
BLOCKED_CONNECTION_TIMEOUT = timedelta(seconds=20) # weird stuff to avoid deadlocks, see pika documentation BLOCKED_CONNECTION_TIMEOUT = timedelta(seconds=20) # weird stuff to avoid deadlocks, see pika documentation
DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f%Z%z" # serialization, both %Z%z for compatibility in case we ever want to improve timezone-related stuff
DATE_FORMAT = "%Y-%m-%d" # serialization
# merge these settings with django.conf.settings # merge these settings with django.conf.settings
......
import json
import pika import pika
import logging import logging
import ssl import ssl
from datetime import datetime, timezone
from . import settings from . import settings
logger = logging.getLogger("pika-tasks") logger = logging.getLogger("pikatasks")
all_tasks = set() # each registered task will show up here all_tasks = set() # each registered task will show up here
DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f%Z%z" # used for serialization, I leave here both %Z%z for compatibility in case we ever want to improve timezone-related stuff
def serialize(stuff):
    """
    Encode a JSON-compatible object as UTF-8 encoded JSON.

    :param stuff: object to encode
    :return: bytes
    """
    text = json.dumps(stuff)
    return text.encode("utf-8")
def deserialize(bytes):
    """
    Decode UTF-8 encoded JSON.

    :param bytes: UTF-8 encoded JSON document
    :return: decoded object
    """
    text = bytes.decode("utf-8")
    return json.loads(text)
def serialize_datetime(dt):
    """
    Convert a datetime to UTC and format it using DATETIME_FORMAT.

    :param dt: datetime (timezone-aware); a naive datetime triggers a warning and is treated as local time
    :return: str that can be deserialized by deserialize_datetime()
    """
    is_naive = not dt.tzinfo
    if is_naive:
        logger.warning("Naive datetime received by serialize_datetime() and will be treated as local: {dt}. Avoid using naive datetime objects.".format(dt=dt))
    in_utc = dt.astimezone(timezone.utc)
    return in_utc.strftime(DATETIME_FORMAT)
def deserialize_datetime(text, utc=False):
    """
    Parse a string formatted with DATETIME_FORMAT into a datetime.

    :param text: str created by serialize_datetime()
    :param utc: set to True if you want this function to return UTC datetime
    :return: datetime (timezone-aware)
    """
    parsed = datetime.strptime(text, DATETIME_FORMAT)
    assert parsed.tzinfo, "ok, now that's weird, no tzinfo, but there must have been %z in the DATETIME_FORMAT"
    if not utc:
        # No argument: convert to the system local timezone (not a pytz zone, but it carries the correct utc offset).
        return parsed.astimezone()
    return parsed.astimezone(timezone.utc)
def get_ssl_options(settings): def get_ssl_options(settings):
""" Create pika.SSLOptions based on pikatasks settings. """ """ Create pika.SSLOptions based on pikatasks settings. """
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment