diff --git a/pikatasks/__init__.py b/pikatasks/__init__.py index d0b53c783b6cc4bd18b42faa6bc6d45182b35c1e..1f9f1f26581537b24798bbc96887f8ec08534cd1 100644 --- a/pikatasks/__init__.py +++ b/pikatasks/__init__.py @@ -33,6 +33,16 @@ class RPCMessageQueueError(RPCError): all_tasks = set() # each registered task will show up here +def serialize_to_bytes(data): + # pika documentation is very fuzzy about it (and types in general) + return serialization.serialize(data).encode("utf-8") + + +def deserialize_from_bytes(b): + # seems like bytes are received from a channel + return serialization.deserialize(b.decode("utf-8")) + + def run(_task_name, **kwargs): """ Runs a task remotely. @@ -47,7 +57,7 @@ def run(_task_name, **kwargs): channel.basic_publish( exchange=settings.CLIENT_EXCHANGE_NAME, routing_key=_task_name, - body=serialization.serialize(kwargs), + body=serialize_to_bytes(kwargs), properties=pika.BasicProperties() ) except Exception as e: @@ -71,7 +81,7 @@ def rpc(_task_name, **kwargs): nonlocal result, exception channel.stop_consuming() try: - reply = serialization.deserialize(body) + reply = deserialize_from_bytes(body) reply_result, reply_error = reply.get("result"), reply.get("error") if reply_error: exception = RPCRemoteError(reply_error) @@ -101,7 +111,7 @@ def rpc(_task_name, **kwargs): channel.basic_publish( exchange=settings.CLIENT_EXCHANGE_NAME, routing_key=_task_name, - body=serialization.serialize(kwargs), + body=serialize_to_bytes(kwargs), properties=pika.BasicProperties( reply_to="amq.rabbitmq.reply-to", expiration=str(int(settings.RPC_TIMEOUT.total_seconds() * 1000)), # in milliseconds, int as str; expire the message so this RPC will not get remotely executed like years later @@ -155,7 +165,7 @@ def task(name): if django_compat.DJANGO: django_compat.check_fix_db_connection() try: - task_kwargs = serialization.deserialize(body) + task_kwargs = deserialize_from_bytes(body) if not isinstance(task_kwargs, dict): raise TypeError("Bad kwargs type: {0}".format(task_kwargs.__class__.__qualname__)) func_result = func(**task_kwargs) @@ -171,7 +181,7 @@ def task(name): channel.basic_publish( exchange="", # empty string as specified in RabbitMQ documentation, section direct reply-to routing_key=properties.reply_to, - body=serialization.serialize(reply)) + body=serialize_to_bytes(reply)) except Exception as e: logger.error("Could not reply to the {properties.reply_to}. {e.__class__.__name__}: {e}".format(**locals())) else: diff --git a/pikatasks/serialization.py b/pikatasks/serialization.py index c8ca7d6991871117d6e74954d99f9a5903e06f23..d73c3c0949323e720b34c349a2854b2bdbaaa0e0 100644 --- a/pikatasks/serialization.py +++ b/pikatasks/serialization.py @@ -6,12 +6,12 @@ from . import settings from .utils import logger -def serialize(stuff): - return json.dumps(stuff, default=json_serialize_tweaks).encode("utf-8") +def serialize(data): + return json.dumps(data, default=json_serialize_tweaks) -def deserialize(bytes): - return json.loads(bytes.decode("utf-8"), object_hook=json_deserialize_tweaks) +def deserialize(text): + return json.loads(text, object_hook=json_deserialize_tweaks) def datetime_to_str(dt):