Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 37 additions & 7 deletions src/pymatgen/core/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -3076,6 +3076,29 @@ def from_id(cls, id_, source: Literal["Materials Project", "COD"] = "Materials P

raise ValueError(f"Invalid source: {source}")

@staticmethod
def _filter_kwargs(func: Callable, kwargs: dict) -> dict:
"""Filter kwargs to only those accepted by func, warning about any removed.

Args:
func: The callable to inspect.
kwargs: The kwargs dict to filter.

Returns:
dict of kwargs supported by func.
"""
params = inspect.signature(func).parameters
if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()):
return kwargs
supported = {k: v for k, v in kwargs.items() if k in params}
unsupported = kwargs.keys() - supported.keys()
if unsupported:
warnings.warn(
f"The following kwargs are not supported by {func.__qualname__} and will be ignored: {unsupported}",
stacklevel=3,
)
return supported

@classmethod
def from_str( # type:ignore[override]
cls,
Expand Down Expand Up @@ -3108,16 +3131,18 @@ def from_str( # type:ignore[override]
if fmt_low == "cif":
from pymatgen.io.cif import CifParser

parser = CifParser.from_str(input_string, **kwargs)
parser = CifParser.from_str(input_string, **cls._filter_kwargs(CifParser.from_str, kwargs))
struct = parser.parse_structures(primitive=primitive)[0]
elif fmt_low == "poscar":
from pymatgen.io.vasp import Poscar

struct = Poscar.from_str(input_string, default_names=None, read_velocities=False, **kwargs).structure
struct = Poscar.from_str(
input_string, default_names=None, read_velocities=False, **cls._filter_kwargs(Poscar.from_str, kwargs)
).structure
elif fmt_low == "cssr":
from pymatgen.io.cssr import Cssr

cssr = Cssr.from_str(input_string, **kwargs)
cssr = Cssr.from_str(input_string, **cls._filter_kwargs(Cssr.from_str, kwargs))
struct = cssr.structure # type:ignore[assignment]
elif fmt_low == "json":
dct = orjson.loads(input_string)
Expand All @@ -3129,11 +3154,11 @@ def from_str( # type:ignore[override]
elif fmt_low == "xsf":
from pymatgen.io.xcrysden import XSF

struct = XSF.from_str(input_string, **kwargs).structure # type:ignore[assignment]
struct = XSF.from_str(input_string, **cls._filter_kwargs(XSF.from_str, kwargs)).structure # type:ignore[assignment]
elif fmt_low == "mcsqs":
from pymatgen.io.atat import Mcsqs

struct = Mcsqs.structure_from_str(input_string, **kwargs)
struct = Mcsqs.structure_from_str(input_string, **cls._filter_kwargs(Mcsqs.structure_from_str, kwargs))
elif fmt == "aims":
from pymatgen.io.aims.inputs import AimsGeometryIn

Expand All @@ -3142,6 +3167,11 @@ def from_str( # type:ignore[override]
elif fmt == "fleur-inpgen":
from pymatgen.io.fleur import FleurInput

if kwargs:
warnings.warn(
f"kwargs {set(kwargs)} cannot be validated for fleur-inpgen (external package) and will be passed through as-is.",
stacklevel=2,
)
struct = FleurInput.from_string(input_string, inpgen_input=True, **kwargs).structure
elif fmt == "fleur":
from pymatgen.io.fleur import FleurInput
Expand All @@ -3150,11 +3180,11 @@ def from_str( # type:ignore[override]
elif fmt == "res":
from pymatgen.io.res import ResIO

struct = ResIO.structure_from_str(input_string, **kwargs)
struct = ResIO.structure_from_str(input_string, **cls._filter_kwargs(ResIO.structure_from_str, kwargs))
elif fmt == "pwmat":
from pymatgen.io.pwmat import AtomConfig

struct = AtomConfig.from_str(input_string, **kwargs).structure
struct = AtomConfig.from_str(input_string, **cls._filter_kwargs(AtomConfig.from_str, kwargs)).structure
else:
raise ValueError(f"Invalid {fmt=}, valid options are {get_args(FileFormats)}")

Expand Down
58 changes: 58 additions & 0 deletions tests/core/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import math
import os
import warnings
from fractions import Fraction
from pathlib import Path
from shutil import which
Expand Down Expand Up @@ -1678,6 +1679,63 @@ def test_to_from_file_str(self):
with pytest.raises(ValueError, match="Unrecognized extension in filename="):
self.struct.from_file(filename=filename)

def test_filter_kwargs_passthrough_when_var_keyword(self):
# func that accepts **kwargs: all kwargs returned unchanged, no warning
def accepts_var(**kwargs):
pass

result = Structure._filter_kwargs(accepts_var, {"foo": 1, "bar": 2})
assert result == {"foo": 1, "bar": 2}

def test_filter_kwargs_filters_unsupported_and_warns(self):
def strict(a, b=0):
pass

with pytest.warns(UserWarning, match="unsupported_key"):
result = Structure._filter_kwargs(strict, {"a": 1, "unsupported_key": 99})
assert result == {"a": 1}
assert "unsupported_key" not in result

def test_filter_kwargs_all_supported_no_warning(self):
def strict(a, b=0):
pass

with warnings.catch_warnings():
warnings.simplefilter("error")
result = Structure._filter_kwargs(strict, {"a": 1, "b": 2})
assert result == {"a": 1, "b": 2}

def test_filter_kwargs_empty_no_warning(self):
def strict(a):
pass

with warnings.catch_warnings():
warnings.simplefilter("error")
result = Structure._filter_kwargs(strict, {})
assert result == {}

@pytest.mark.parametrize("fmt", ["cif", "poscar", "cssr", "xsf", "res", "pwmat"])
def test_from_str_unsupported_kwarg_warns(self, fmt):
struct_str = self.struct.to(fmt=fmt)
with pytest.warns(UserWarning, match="unsupported_kwarg"):
result = Structure.from_str(struct_str, fmt=fmt, unsupported_kwarg="bad")
assert result.formula == self.struct.formula

@pytest.mark.parametrize("fmt", ["cif", "poscar", "cssr", "xsf", "res", "pwmat"])
def test_from_str_no_warning_without_extra_kwargs(self, fmt):
struct_str = self.struct.to(fmt=fmt)
with warnings.catch_warnings():
warnings.simplefilter("error", UserWarning)
Structure.from_str(struct_str, fmt=fmt)

def test_from_str_cif_supported_kwarg_no_warning(self):
# frac_tolerance is a real CifParser.from_str kwarg — should not warn
cif_str = self.struct.to(fmt="cif")
with warnings.catch_warnings():
warnings.simplefilter("error", UserWarning)
result = Structure.from_str(cif_str, fmt="cif", frac_tolerance=0.01)
assert result.formula == self.struct.formula

def test_from_spacegroup(self):
s1 = Structure.from_spacegroup("Fm-3m", Lattice.cubic(3), ["Li", "O"], [[0.25, 0.25, 0.25], [0, 0, 0]])
assert s1.formula == "Li8 O4"
Expand Down
Loading