diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 6e53c41ec97ecb0b42d871e8d24bc8458618e1f2..cf481b167291406154f2b39de9541eeb35284c26 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,4 +1,4 @@ -image: python:3.5-buster +image: python:3.7-buster test: script: diff --git a/README.md b/README.md index e58a6d511646092806effe58df5ff5b14777ad1f..09688df4f2ac8c687cc9a6f511dffbcae753e421 100644 --- a/README.md +++ b/README.md @@ -24,17 +24,17 @@ with postgrestutils.Session() as s: By default constructing a new `postgrestutils.Session` will take the settings discussed in [setup](#setup) into account. Hence there is no need to specify `base_uri` or `token` explicitly unless you are using more than one API or database role in your project. -Additionally `postgrestutils.Session` takes `schema: Optional[str] = None`, `parse_dt: bool = True` and `count: postgrestutils.Count = postgrestutils.Count.NONE` (some of which are explained later on). +Additionally `postgrestutils.Session` takes `schema: Optional[str] = None`, `parser: postgrestutils.Parser = postgrestutils.Parser.DATETIME | postgrestutils.Parser.DATE` and `count: postgrestutils.Count = postgrestutils.Count.NONE` (some of which are explained later on). These options are session defaults and may be overridden on a per-request basis, e.g. ```python import postgrestutils -with postgrestutils.Session(parse_dt=False) as s: - print(s.get('random_datetime')) # makes request using parse_dt=False - print(s.get('random_datetime', parse_dt=True)) # makes request using parse_dt=True - print(s.get('random_datetime')) # makes request using parse_dt=False +with postgrestutils.Session(parser=postgrestutils.Parser.NONE) as s: + print(s.get('random_datetime')) # makes request using parser=postgrestutils.Parser.NONE + print(s.get('random_datetime', parser=postgrestutils.Parser.DATETIME)) # makes request using parser=postgrestutils.Parser.DATETIME + print(s.get('random_datetime')) # makes request using parser=postgrestutils.Parser.NONE ``` ### Setup diff --git a/postgrestutils/__init__.py b/postgrestutils/__init__.py index 7ee8e39de7b95475163e8b11ed867575ca2af942..56d0de8ca591725c9b925fe429745771fb3a6769 100644 --- a/postgrestutils/__init__.py +++ b/postgrestutils/__init__.py @@ -1,30 +1,20 @@ +# pyright: reportImportCycles=false import copy -import enum import re -from typing import Optional +from typing import Any, Dict, Iterator, List, Optional, Union from urllib.parse import urljoin import requests -from requests import HTTPError # re-export to allow for exception handling +from requests import HTTPError + +from postgrestutils.typing import Count, JsonDict, Parser from . import app_settings -from .utils import datetime_parser, logger +from .utils import logger, parse_with_custom_parsers default_app_config = "postgrestutils.apps.PostgrestUtilsConfig" - REPR_OUTPUT_SIZE = 20 - -Count = enum.Enum( - "Count", - ( - ("NONE", None), - ("EXACT", "exact"), - ("PLANNED", "planned"), - ("ESTIMATED", "estimated"), - ), -) - DEFAULT_SCHEMA = object() @@ -49,49 +39,49 @@ class Session: base_uri: Optional[str] = None, token: Optional[str] = None, schema: Optional[str] = None, - parse_dt: bool = True, count: Count = Count.NONE, + parser: Parser = Parser.DATETIME | Parser.DATE, ): """ :param base_uri: base uri of the PostgREST instance to use :param token: JWT for the corresponding database role and PostgREST instance :param schema: the database schema to use - :param parse_dt: whether to parse datetime strings as returned by PostgREST to python datetime objects :param count: counting strategy as explained in the README + :param parser: parsers to use on the JSON returned by PostgREST """ self.session = None self.base_uri = base_uri or app_settings.BASE_URI self.token = token or app_settings.JWT self.schema = schema - self.parse_dt = parse_dt self.count = count + self.parser = parser def __enter__(self): self.session = requests.Session() self._configure_session_defaults() return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): + assert self.session is not None self.session.close() def get( self, endpoint: str, - parse_dt: Optional[bool] = None, count: Optional[Count] = None, schema: Optional[str] = None, - **kwargs - ) -> dict: + parser: Optional[Parser] = None, + **kwargs: Dict[str, Any] + ) -> JsonDict: """ Get a single object from the specified endpoint. This will most likely require specifying params so that only a single object is found. :param endpoint: specifies which endpoint to request when multiple results are returned) - :param parse_dt: whether to parse datetime strings as returned by - PostgREST to python datetime objects :param count: counting strategy as explained in the README + :param parser: parsers to use on the JSON returned by PostgREST :param kwargs: pass kwargs directly to requests's `.get()` method :return: single element as dict :raises: `ObjectDoesNotExist`/`MultipleObjectsReturned` if no or more @@ -103,8 +93,8 @@ class Session: self, endpoint, True, - parse_dt if parse_dt is not None else self.parse_dt, count if count is not None else self.count, + parser if parser is not None else self.parser, **kwargs, ) # populate the cache @@ -117,15 +107,13 @@ class Session: def filter( self, endpoint: str, - parse_dt: Optional[bool] = None, count: Optional[Count] = None, schema: Optional[str] = None, - **kwargs + parser: Optional[Parser] = None, + **kwargs: Dict[str, Any] ) -> "JsonResultSet": """ :param endpoint: specifies which endpoint to request - :param parse_dt: whether to parse datetime strings as returned by - PostgREST to python datetime objects :param count: counting strategy as explained in the README :param kwargs: pass kwargs directly to request's `.get()` method :return: `JsonResultSet`, a lazy python object @@ -136,22 +124,23 @@ class Session: self, endpoint, False, - parse_dt if parse_dt is not None else self.parse_dt, count if count is not None else self.count, + parser if parser is not None else self.parser, **kwargs, ) def _configure_session_defaults(self): + assert self.session is not None self.session.headers["Accept"] = "application/json" if self.token: self.session.headers["Authorization"] = "Bearer {}".format(self.token) if self.schema is not None: self.session.headers["Accept-Profile"] = self.schema - def _set_schema_header(self, schema, kwargs: dict): - if schema is DEFAULT_SCHEMA: - schema = None - kwargs.setdefault("headers", dict())["Accept-Profile"] = schema + def _set_schema_header(self, schema: str, kwargs: Dict[str, Any]): + kwargs.setdefault("headers", dict())["Accept-Profile"] = ( + None if schema is DEFAULT_SCHEMA else schema + ) class JsonResultSet: @@ -168,31 +157,32 @@ class JsonResultSet: client: Session, endpoint: str, singular: bool, - parse_dt: bool, count: Count, - **kwargs + parser: Parser = Parser.DATETIME, + **kwargs: Dict[str, Any] ): - self._len_cache = None # type: Optional[int] - self._result_cache = None # type: Optional[list] + self._len_cache: Optional[int] = None + self._result_cache: Optional[List[JsonDict]] = None - self.client = client # type: Session - self.endpoint = endpoint # type: str - self.singular = singular # type: bool - self.parse_dt = parse_dt # type: bool - self.count = count # type: Count - self.request_kwargs = kwargs + self.client: Session = client + self.endpoint: str = endpoint + self.singular: bool = singular + self.count: Count = count + self.request_kwargs: Dict[str, Any] = kwargs + self.parser: Parser = parser - def __repr__(self): + def __repr__(self) -> str: data = list(self[: REPR_OUTPUT_SIZE + 1]) if len(data) > REPR_OUTPUT_SIZE: data[-1] = "...(remaining elements truncated)..." return "<{} {}>".format(self.__class__.__name__, data) - def __iter__(self): + def __iter__(self) -> Iterator[JsonDict]: self._fetch_all() + assert self._result_cache is not None return iter(self._result_cache) - def __len__(self): + def __len__(self) -> int: """ NOTE: Since singular requests (using `.get()`) return a python dict rather than a `JsonResultSet`, `self.singular` should be ignored here. @@ -201,18 +191,18 @@ class JsonResultSet: self._fetch_len() else: self._fetch_all() + assert self._len_cache is not None return self._len_cache - def __getitem__(self, key): + def __getitem__(self, key: Any) -> Union[JsonDict, List[JsonDict]]: """ NOTE: Since singular requests (using `.get()`) return a python dict rather than a `JsonResultSet`, `self.singular` should be ignored here. """ if not isinstance(key, (int, slice)): raise TypeError( - "{self.__class__.__name__} indices must be integers or slices, not {key.__class__.__name__}".format( - self=self, key=key - ) + "{self.__class__.__name__} indices must be integers or slices, not" + " {key.__class__.__name__}".format(self=self, key=key) ) if (isinstance(key, int) and key < 0) or ( isinstance(key, slice) @@ -272,6 +262,7 @@ class JsonResultSet: # Have to request something so just fetch the first item request_kwargs["headers"]["Range"] = "0-0" + assert self.client.session is not None resp = self.client.session.get( urljoin(self.client.base_uri, self.endpoint), **request_kwargs ) @@ -299,6 +290,7 @@ class JsonResultSet: "Accept" ] = "application/vnd.pgrst.object+json" + assert self.client.session is not None resp = self.client.session.get( urljoin(self.client.base_uri, self.endpoint), **request_kwargs ) @@ -307,7 +299,7 @@ class JsonResultSet: # fetched all elements anyway, caching their length is very cheap self._len_cache = len(self._result_cache) - def _fetch_range(self, range): + def _fetch_range(self, range: str) -> List[JsonDict]: """ Fetch a range of elements from the PostgREST API. NOTE: This method should ignore `self.singular`, see `__getitem__()` for @@ -317,12 +309,13 @@ class JsonResultSet: request_kwargs.setdefault("headers", dict())["Range-Unit"] = "items" request_kwargs["headers"]["Range"] = range + assert self.client.session is not None resp = self.client.session.get( urljoin(self.client.base_uri, self.endpoint), **request_kwargs ) return self._parse_response(resp) - def _parse_response(self, resp): + def _parse_response(self, resp: requests.Response) -> List[JsonDict]: """ Parse response as json and return the result if it was successful. Attempt to detect common error cases in order to raise meaningful error @@ -334,7 +327,7 @@ class JsonResultSet: """ try: resp.raise_for_status() - except requests.HTTPError as e: + except HTTPError as e: # try getting a more detailed exception if status_code = 406 if resp.status_code == 406: try: @@ -351,14 +344,14 @@ class JsonResultSet: request=e.request, ) - if self.parse_dt: - json_result = resp.json(object_hook=datetime_parser) + if self.parser: + json_result = parse_with_custom_parsers(resp, self.parser) else: json_result = resp.json() # always return a list even if it contains a single element only return [json_result] if self.singular else json_result - def _try_parse_406(self, resp): + def _try_parse_406(self, resp: requests.Response): """ Try parsing a 406 `HTTPError` to raise a more detailed error message. :param resp: the HTTP response to parse @@ -366,7 +359,8 @@ class JsonResultSet: observed row count. """ detail_regex = re.compile( - r"Results contain (?P<row_count>\d+) rows, application/vnd\.pgrst\.object\+json requires 1 row" + r"Results contain (?P<row_count>\d+) rows," + r" application/vnd\.pgrst\.object\+json requires 1 row" ) try: json = resp.json() diff --git a/postgrestutils/_django_utils.py b/postgrestutils/_django_utils.py index 8d57cb6a465bc4d95d59e9a3495e7770cb5de900..bdea8c670bf9492cac292e8ae85b5f43f8c4c309 100644 --- a/postgrestutils/_django_utils.py +++ b/postgrestutils/_django_utils.py @@ -4,14 +4,10 @@ It should never be imported directly. Instead import the .utils module which will re-export this module's items if django is available. """ - -from datetime import datetime -from typing import Union - from django.conf import settings from django.utils import dateparse, timezone as django_tz -import postgrestutils +from postgrestutils.typing import DatetimeOrStr from . import app_settings from .signals import user_account_fetched @@ -19,6 +15,8 @@ from .signals import user_account_fetched def autofetch(sender, **kwargs): """Fetch user account on login based on the AUTOFETCH configuration""" + import postgrestutils + payload = {"select": app_settings.AUTOFETCH} if settings.DEBUG: @@ -32,7 +30,7 @@ def autofetch(sender, **kwargs): user_account_fetched.send(sender=None, request=kwargs["request"], account=account) -def _try_django_parse_dt(value: str) -> Union[datetime, str]: +def try_django_parse_dt(value: str) -> DatetimeOrStr: """ Attempt to parse `value` as a `datetime` using django utilities. :param value: the string to parse diff --git a/postgrestutils/typing.py b/postgrestutils/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..ed6ab26014866ada7ff40d01166ad7f9d0ad06b4 --- /dev/null +++ b/postgrestutils/typing.py @@ -0,0 +1,23 @@ +import datetime +import enum +from typing import Any, Callable, Dict, Tuple, Union + +ParserFunction = Callable[[Any], Any] +ParserMapping = Tuple["Parser", ParserFunction] +DatetimeOrStr = Union[datetime.datetime, str] +DatetimeParser = Callable[[str], DatetimeOrStr] +DateOrStr = Union[datetime.date, str] +JsonDict = Dict[str, Any] + + +class Count(enum.Enum): + NONE = None + EXACT = "exact" + PLANNED = "planned" + ESTIMATED = "estimated" + + +class Parser(enum.Flag): + NONE = 0 + DATE = enum.auto() + DATETIME = enum.auto() diff --git a/postgrestutils/utils.py b/postgrestutils/utils.py index cf21a3fca42b2e96933464be13d075069ae16be9..f7f1d4b036ffb00c89a9602146fd07355ea9ca61 100644 --- a/postgrestutils/utils.py +++ b/postgrestutils/utils.py @@ -1,12 +1,26 @@ +import datetime +import functools import importlib.util import logging import re -from datetime import datetime, timedelta, timezone -from typing import Dict, Union + +# from datetime import datetime, timedelta, timezone +from typing import Any, Dict, Match, Tuple, Type + +from requests.models import Response + +from postgrestutils.typing import ( + DateOrStr, + DatetimeOrStr, + DatetimeParser, + JsonDict, + Parser, + ParserMapping, +) _DJANGO = importlib.util.find_spec("django") is not None if _DJANGO: - from ._django_utils import _try_django_parse_dt, autofetch + from ._django_utils import autofetch, try_django_parse_dt # noqa: 401 logger = logging.getLogger("postgrestutils") @@ -26,18 +40,21 @@ JSON_TIMESTAMP_REGEX = re.compile( ) -def _clean_parts(parts: Dict[str, str]): - cleaned = {} # type: Dict[str, Union[int, str]] - for key, value in parts.items(): +def _split_datetime_parts( + match: Match[str], +) -> Tuple[Dict[str, int], Dict[str, str]]: + datetime: Dict[str, int] = dict() + timezone: Dict[str, str] = dict() + for key, value in match.groupdict().items(): if value: if "offset" not in key: - cleaned[key] = int(value) + datetime[key] = int(value) else: - cleaned[key] = value - return cleaned + timezone[key] = value + return datetime, timezone -def _try_python_parse_dt(value: str) -> Union[datetime, str]: +def _try_python_parse_dt(value: str) -> DatetimeOrStr: """ Attempt to parse value as a datetime using only python utilities. :param value: the string to parse @@ -45,40 +62,76 @@ def _try_python_parse_dt(value: str) -> Union[datetime, str]: """ match = JSON_TIMESTAMP_REGEX.match(value) if match: - parts = _clean_parts(match.groupdict()) - if ( - parts.get("offsetsign") - and parts.get("offsethours") - and parts.get("offsetminutes") - ): - sign = -1 if parts.pop("offsetsign", "+") == "-" else 1 - tz = timezone( + dt, tz = _split_datetime_parts(match) + if tz: + sign = -1 if tz.get("offsetsign", "+") == "-" else 1 + tz = datetime.timezone( offset=sign - * timedelta( - hours=int(parts.pop("offsethours")), - minutes=int(parts.pop("offsetminutes")), + * datetime.timedelta( + hours=int(tz.pop("offsethours")), + minutes=int(tz.pop("offsetminutes")), ) ) - parsed_dt = datetime(**parts).replace(tzinfo=tz).astimezone() + parsed_dt = datetime.datetime(**dt).replace(tzinfo=tz).astimezone() else: # naive datetime so we assume local time - local_tz = datetime.now(timezone.utc).astimezone().tzinfo - parsed_dt = datetime(**parts).replace(tzinfo=local_tz) + local_tz = datetime.datetime.now().astimezone().tzinfo + parsed_dt = datetime.datetime(**dt).replace(tzinfo=local_tz) return parsed_dt return value -_try_parse_dt = _try_django_parse_dt if _DJANGO else _try_python_parse_dt +def _lazy_try_parse_dt() -> DatetimeParser: + if _DJANGO: + return try_django_parse_dt + else: + return _try_python_parse_dt -def datetime_parser(json_dict: dict) -> dict: - """ - A function to use as `object_hook` when deserializing JSON that parses - datetime strings to timezone-aware datetime objects. - :param json_dict: the original `json_dict` to process - :return: the modified `json_dict` - """ +_try_parse_dt = _lazy_try_parse_dt() + + +def _try_parse_date(value: str) -> DateOrStr: + try: + return datetime.datetime.strptime(value, "%Y-%m-%d").date() + except ValueError: + return value + + +def _json_object_hook( + json_dict: JsonDict, + parsers_dict: Dict[Type[Any], Tuple[ParserMapping, ...]] = dict(), + parser_flags: Parser = Parser.NONE, +) -> JsonDict: + def _try_applying_parsers( + json_dict: JsonDict, + type_to_parse: Any, + parser_mappings: Tuple[ParserMapping, ...], + ): + for parser, func in parser_mappings: + if parser & parser_flags: + maybe_parsed_value = func(value) + + # check if parsing succeeded + if not isinstance(maybe_parsed_value, type_to_parse): + json_dict[key] = maybe_parsed_value + break + for key, value in json_dict.items(): - if isinstance(value, str): - json_dict[key] = _try_parse_dt(value) + type_to_parse = type(value) + parsers = parsers_dict.get(type_to_parse) + + if parsers is not None: + _try_applying_parsers(json_dict, type_to_parse, parsers) return json_dict + + +def parse_with_custom_parsers(resp: Response, parser: Parser): + parsers_dict = { + str: ((Parser.DATE, _try_parse_date), (Parser.DATETIME, _try_parse_dt)), + } + object_hook = functools.partial( + _json_object_hook, parsers_dict=parsers_dict, parser_flags=parser + ) + + return resp.json(object_hook=object_hook) diff --git a/py.typed b/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pyproject.toml b/pyproject.toml index 8e790bf46af6d8c1541a7becf53614793bccaf18..2b3a6fabf4f90f75953fe246e19e0ef317636845 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,3 +9,7 @@ default_section = "THIRDPARTY" known_first_party = "postgrestutils" known_django = "django" sections = "FUTURE,STDLIB,DJANGO,THIRDPARTY,FIRSTPARTY,LOCALFOLDER" + +[tool.pyright] +reportUnusedImport = false +reportPrivateUsage = false diff --git a/setup.py b/setup.py index a6fe38fc7358f7425a838202610e7f39c651abcd..846fabfcf60f15de573501a6fc793b1460b49ccd 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ os.chdir(os.path.normpath(os.path.join(os.path.abspath(__file__), os.pardir))) setup( name="postgrestutils", - version="1.0.0", + version="2.0.0", packages=find_packages(), include_package_data=True, license="BSD", @@ -22,13 +22,14 @@ setup( zip_safe=False, install_requires=["requests>=2.19.1,<3.0.0"], extras_require={"dev": ["requests-mock"]}, + package_data={"postgrestutils": ["py.typed"]}, classifiers=[ "Environment :: Web Environment", "Intended Audience :: Developers", "License :: OSI Approved :: BSD License", "Operating System :: OS Independent", "Programming Language :: Python", - "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.7", "Topic :: Internet :: WWW/HTTP", "Topic :: Internet :: WWW/HTTP :: Dynamic Content", ], diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_postgrestclient.py b/tests/integration/test_postgrestutils.py similarity index 87% rename from tests/test_postgrestclient.py rename to tests/integration/test_postgrestutils.py index 099f4fc0912e62b5cdc073d875ad6d10d902e44d..2768af37622a879edd8dfdf318867b233ca8fc79 100644 --- a/tests/test_postgrestclient.py +++ b/tests/integration/test_postgrestutils.py @@ -91,7 +91,7 @@ class TestPgrestClientGet(TestCase): super().setUp() self.data = SUPERHERO_TEST_DATA[0] - def test_single_object_returned(self, mock): + def test_single_object_returned(self, mock: Mocker): mock.register_uri( "GET", "http://example.com/superhero?id=eq.1000000000", @@ -110,7 +110,7 @@ class TestPgrestClientGet(TestCase): self.assertEqual(res, self.data) self.assertTrue(mock.called_once) - def test_object_does_not_exist(self, mock): + def test_object_does_not_exist(self, mock: Mocker): mock.register_uri( "GET", "http://example.com/superhero?id=eq.1337", @@ -131,7 +131,7 @@ class TestPgrestClientGet(TestCase): s.get("superhero", params=params) self.assertTrue(mock.called_once) - def test_multiple_objects_returned(self, mock): + def test_multiple_objects_returned(self, mock: Mocker): mock.register_uri( "GET", "http://example.com/superhero", @@ -151,7 +151,7 @@ class TestPgrestClientGet(TestCase): s.get("superhero") self.assertTrue(mock.called_once) - def test_datetime_parser(self, mock): + def test_datetime_parser(self, mock: Mocker): expected = { "id": 1337, "random": datetime.datetime( @@ -176,7 +176,7 @@ class TestPgrestClientGet(TestCase): self.assertEqual(res, expected) self.assertTrue(mock.called_once) - def test_without_datetime_parser(self, mock): + def test_without_datetime_parser(self, mock: Mocker): test_json = {"id": 1337, "random": "2020-05-20T08:35:06.659425+00:00"} mock.register_uri( "GET", @@ -191,7 +191,58 @@ class TestPgrestClientGet(TestCase): ) with default_session() as s: params = {"select": "id,random", "id": "eq.1337"} - res = s.get("random_datetime", params=params, parse_dt=False) + res = s.get( + "random_datetime", + params=params, + parser=postgrestutils.Parser.NONE, + ) + + self.assertEqual(res, test_json) + self.assertTrue(mock.called_once) + + def test_date_parser(self, mock: Mocker): + expected = { + "id": 1337, + "random": datetime.date(2020, 5, 20), + } + mock.register_uri( + "GET", + "http://example.com/random_date", + request_headers={ + **DEFAULT_HEADERS, + **{"Accept": "application/vnd.pgrst.object+json"}, + }, + status_code=200, + reason="OK", + json={"id": 1337, "random": "2020-05-20"}, + ) + with default_session() as s: + params = {"id": "eq.1337"} + res = s.get("random_date", params=params) + + self.assertEqual(res, expected) + self.assertTrue(mock.called_once) + + def test_without_date_parser(self, mock: Mocker): + test_json = {"id": 1337, "random": "2020-05-20"} + mock.register_uri( + "GET", + "http://example.com/random_date", + request_headers={ + **DEFAULT_HEADERS, + **{"Accept": "application/vnd.pgrst.object+json"}, + }, + status_code=200, + reason="OK", + json=test_json, + ) + with default_session() as s: + params = {"select": "id,random", "id": "eq.1337"} + res = s.get( + "random_date", + params=params, + parser=postgrestutils.Parser.NONE, + ) self.assertEqual(res, test_json) self.assertTrue(mock.called_once) @@ -203,7 +254,7 @@ class TestPgrestClientFilterStrategyNone(TestCase): super().setUp() self.data = SUPERHERO_TEST_DATA - def test_fetch_all_first(self, mock): + def test_fetch_all_first(self, mock: Mocker): mock.register_uri( "GET", "http://example.com/superhero", @@ -234,7 +285,7 @@ class TestPgrestClientFilterStrategyNone(TestCase): self.assertEqual(res[0], self.data[0]) # should utilize cache self.assertTrue(mock.called_once) # should not have been called again - def test_fetch_len_first(self, mock): + def test_fetch_len_first(self, mock: Mocker): mock.register_uri( "GET", "http://example.com/superhero", @@ -264,7 +315,7 @@ class TestPgrestClientFilterStrategyNone(TestCase): self.assertEqual(list(res), self.data) # should utilize cache self.assertTrue(mock.called_once) # should not have been called again - def test_cache_fetching_unbounded_slice(self, mock): + def test_cache_fetching_unbounded_slice(self, mock: Mocker): mock.register_uri( "GET", "http://example.com/superhero", @@ -305,7 +356,7 @@ class TestPgrestClientFilterCountingStrategies(TestCase): postgrestutils.Count.ESTIMATED, ) - def test_fetch_all_first(self, mock): + def test_fetch_all_first(self, mock: Mocker): # in order to fetch all mock.register_uri( "GET", @@ -339,7 +390,7 @@ class TestPgrestClientFilterCountingStrategies(TestCase): self.assertEqual(res[0], self.data[0]) # should utilize cache self.assertTrue(mock.called_once) # should not have been called again - def test_fetch_len_first(self, mock): + def test_fetch_len_first(self, mock: Mocker): # in order to fetch all mock.register_uri( "GET", @@ -432,8 +483,12 @@ class TestPgrestClientSessionDefaults(TestCase): super().setUp() self.data = SUPERHERO_TEST_DATA - def test_override_parse_dt_session_option(self, mock): - test_json = {"id": 1337, "random": "2020-05-20T08:35:06.659425+00:00"} + def test_override_parser_session_option(self, mock: Mocker): + test_json = { + "id": 1337, + "random_datetime": "2020-05-20T08:35:06.659425+00:00", + "random_date": "2020-05-20", + } mock.register_uri( "GET", "http://example.com/random_datetime", @@ -445,25 +500,30 @@ class TestPgrestClientSessionDefaults(TestCase): reason="OK", json=test_json, ) - with default_session(parse_dt=False) as s: - params = {"select": "id,random", "id": "eq.1337"} + with default_session(parser=postgrestutils.Parser.NONE) as s: + params = {"select": "id,random_datetime,random_date", "id": "eq.1337"} res = s.get("random_datetime", params=params) self.assertEqual(res, test_json) self.assertTrue(mock.called_once) mock.reset() - res2 = s.get("random_datetime", params=params, parse_dt=True) + res2 = s.get( + "random_datetime", + params=params, + parser=postgrestutils.Parser.DATETIME | postgrestutils.Parser.DATE, + ) expected = { "id": 1337, - "random": datetime.datetime( + "random_datetime": datetime.datetime( 2020, 5, 20, 8, 35, 6, 659425, tzinfo=datetime.timezone.utc ), + "random_date": datetime.date(2020, 5, 20), } self.assertEqual(res2, expected) self.assertTrue(mock.called_once) - def test_override_count_session_option(self, mock): + def test_override_count_session_option(self, mock: Mocker): # in order to fetch all mock.register_uri( "GET", @@ -517,7 +577,7 @@ class TestPgrestClientSessionDefaults(TestCase): # should have cached all elements self.assertEqual(res2._result_cache, self.data) - def test_override_schema_session_option(self, mock): + def test_override_schema_session_option(self, mock: Mocker): # in order to fetch all mock.register_uri( "GET",