diff --git a/src/pymatgen/io/vasp/inputs.py b/src/pymatgen/io/vasp/inputs.py index 4d572bcd17..ca25b04a9e 100644 --- a/src/pymatgen/io/vasp/inputs.py +++ b/src/pymatgen/io/vasp/inputs.py @@ -3101,7 +3101,7 @@ def __str__(self) -> str: def as_dict(self) -> dict: """MSONable dict.""" - dct = {key: val.as_dict() for key, val in self.items()} + dct = {key: val.as_dict() if hasattr(val, "as_dict") else val for key, val in self.items()} dct["@module"] = type(self).__module__ dct["@class"] = type(self).__name__ return dct @@ -3115,10 +3115,13 @@ def from_dict(cls, dct: dict) -> Self: Returns: VaspInput """ - sub_dct: dict[str, dict] = {"optional_files": {}} + sub_dct: dict[str, Any] = {"optional_files": {}} for key, val in dct.items(): if key in ("INCAR", "POSCAR", "POTCAR", "KPOINTS"): sub_dct[key.lower()] = MontyDecoder().process_decoded(val) + elif key == "POTCAR.spec": + sub_dct["potcar"] = val + sub_dct["potcar_spec"] = True elif key not in ["@module", "@class"]: sub_dct["optional_files"][key] = MontyDecoder().process_decoded(val) return cls(**sub_dct) # type: ignore[arg-type] diff --git a/tests/io/vasp/test_inputs.py b/tests/io/vasp/test_inputs.py index 7ed9b093d4..7c8006ad35 100644 --- a/tests/io/vasp/test_inputs.py +++ b/tests/io/vasp/test_inputs.py @@ -2002,6 +2002,24 @@ def test_input_attr(self): vis_potcar_spec.incar["NSW"] = 100 assert vis_potcar_spec.incar["NSW"] == 100 + def test_as_from_dict_potcar_spec(self): + vis_potcar_spec = VaspInput( + self.vasp_input.incar, + self.vasp_input.kpoints, + self.vasp_input.poscar, + "\n".join(self.vasp_input.potcar.symbols), + potcar_spec=True, + ) + # as_dict should not raise when potcar_spec=True + dct = vis_potcar_spec.as_dict() + + # round-trip should preserve potcar_spec structure + roundtripped = VaspInput.from_dict(dct) + assert {*roundtripped} == {"INCAR", "KPOINTS", "POSCAR", "POTCAR.spec"} + assert isinstance(roundtripped.potcar, str) + assert roundtripped["POTCAR.spec"] == vis_potcar_spec["POTCAR.spec"] + assert roundtripped["INCAR"] == vis_potcar_spec["INCAR"] + def test_potcar_summary_stats() -> None: potcar_summary_stats = loadfn(POTCAR_STATS_PATH)