diff --git a/src/maggma/stores/mongolike.py b/src/maggma/stores/mongolike.py index b6073d77a..b56941826 100644 --- a/src/maggma/stores/mongolike.py +++ b/src/maggma/stores/mongolike.py @@ -4,6 +4,7 @@ various utilities. """ +import json import warnings from collections.abc import Iterator from itertools import chain, groupby @@ -615,6 +616,7 @@ def __init__( serialization_option: Optional[int] = None, serialization_default: Optional[Callable[[Any], Any]] = None, encoding: Optional[str] = None, + sanitize_on_write: bool = False, **kwargs, ): """ @@ -640,6 +642,8 @@ def __init__( encoding from the platform. This should work in the great majority of cases. However, if you encounter a UnicodeDecodeError, consider setting the encoding explicitly to 'utf8' or another encoding as appropriate. + sanitize_on_write: Whether to sanitize documents with jsanitize before writing to the + JSON file. """ paths = paths if isinstance(paths, (list, tuple)) else [paths] self.paths = paths @@ -669,6 +673,7 @@ def __init__( self.default_sort = None self.serialization_option = serialization_option self.serialization_default = serialization_default + self.sanitize_on_write = sanitize_on_write super().__init__(**kwargs) @@ -767,12 +772,20 @@ def update_json_file(self): data = list(self.query()) for d in data: d.pop("_id") - bytesdata = orjson.dumps( - data, - option=self.serialization_option, - default=self.serialization_default, - ) - f.write(bytesdata.decode("utf-8")) + if self.sanitize_on_write: + data = jsanitize( + data, + strict=False, + recursive_msonable=True, + ) + json.dump(data, f, indent=2) + else: + bytesdata = orjson.dumps( + data, + option=self.serialization_option, + default=self.serialization_default, + ) + f.write(bytesdata.decode("utf-8")) def __hash__(self): return hash((*self.paths, self.last_updated_field)) diff --git a/tests/stores/test_mongolike.py b/tests/stores/test_mongolike.py index 997b7edde..259a59060 100644 --- a/tests/stores/test_mongolike.py +++ b/tests/stores/test_mongolike.py @@ -551,6 +551,31 @@ def test_jsonstore_last_updated(test_dir): assert jsonstore.last_updated > start_time +def test_jsonstore_sanitize_on_write(test_dir): + class SubFloat(float): + pass + + with ScratchDir("."): + jsonstore = JSONStore( + "sanitize.json", + read_only=False, + sanitize_on_write=True, + ) + jsonstore.connect() + + # This would fail on the normal orjson path, but should succeed when + # sanitize_on_write=True. + jsonstore.update({"wrong_field": SubFloat(1.1), "task_id": 3}) + jsonstore.close() + + # Confirm the file was written and can be reloaded. + jsonstore = JSONStore("sanitize.json", read_only=True) + jsonstore.connect() + doc = jsonstore.query_one(criteria={"task_id": 3}) + assert doc is not None + assert doc["wrong_field"] == pytest.approx(1.1) + + def test_eq(mongostore, memorystore, jsonstore): assert mongostore == mongostore assert memorystore == memorystore