Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions src/maggma/stores/mongolike.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
various utilities.
"""

import json
import warnings
from collections.abc import Iterator
from itertools import chain, groupby
Expand Down Expand Up @@ -615,6 +616,7 @@ def __init__(
serialization_option: Optional[int] = None,
serialization_default: Optional[Callable[[Any], Any]] = None,
encoding: Optional[str] = None,
sanitize_on_write: bool = False,
**kwargs,
):
"""
Expand All @@ -640,6 +642,8 @@ def __init__(
encoding from the platform. This should work in the great majority of cases.
However, if you encounter a UnicodeDecodeError, consider setting the encoding
explicitly to 'utf8' or another encoding as appropriate.
sanitize_on_write: Whether to sanitize documents with jsanitize before writing to the
JSON file.
"""
paths = paths if isinstance(paths, (list, tuple)) else [paths]
self.paths = paths
Expand Down Expand Up @@ -669,6 +673,7 @@ def __init__(
self.default_sort = None
self.serialization_option = serialization_option
self.serialization_default = serialization_default
self.sanitize_on_write = sanitize_on_write

super().__init__(**kwargs)

Expand Down Expand Up @@ -767,12 +772,20 @@ def update_json_file(self):
data = list(self.query())
for d in data:
d.pop("_id")
bytesdata = orjson.dumps(
data,
option=self.serialization_option,
default=self.serialization_default,
)
f.write(bytesdata.decode("utf-8"))
if self.sanitize_on_write:
data = jsanitize(
data,
strict=False,
recursive_msonable=True,
)
json.dump(data, f, indent=2)
else:
bytesdata = orjson.dumps(
data,
option=self.serialization_option,
default=self.serialization_default,
)
f.write(bytesdata.decode("utf-8"))

def __hash__(self):
return hash((*self.paths, self.last_updated_field))
Expand Down
25 changes: 25 additions & 0 deletions tests/stores/test_mongolike.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,31 @@ def test_jsonstore_last_updated(test_dir):
assert jsonstore.last_updated > start_time


def test_jsonstore_sanitize_on_write(test_dir):
class SubFloat(float):
pass

with ScratchDir("."):
jsonstore = JSONStore(
"sanitize.json",
read_only=False,
sanitize_on_write=True,
)
jsonstore.connect()

# This would fail on the normal orjson path, but should succeed when
# sanitize_on_write=True.
jsonstore.update({"wrong_field": SubFloat(1.1), "task_id": 3})
jsonstore.close()

# Confirm the file was written and can be reloaded.
jsonstore = JSONStore("sanitize.json", read_only=True)
jsonstore.connect()
doc = jsonstore.query_one(criteria={"task_id": 3})
assert doc is not None
assert doc["wrong_field"] == pytest.approx(1.1)


def test_eq(mongostore, memorystore, jsonstore):
assert mongostore == mongostore
assert memorystore == memorystore
Expand Down
Loading