-
Notifications
You must be signed in to change notification settings - Fork 81
Disorder builder #1410
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ColinBundschu
wants to merge
10
commits into
materialsproject:new-builders
Choose a base branch
from
ColinBundschu:new-builders
base: new-builders
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Disorder builder #1410
Changes from 6 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
9a4e661
disorder builder first pass
ColinBundschu ced3f09
disorder integration
ColinBundschu 2965cbd
removed debug print statements
ColinBundschu 594b30a
DisorderDoc now inherits from PropertyDoc
ColinBundschu 0fe9c3c
improvements to post processing
ColinBundschu 5a8126a
code cleanup
ColinBundschu 39df587
Addressed PR comments from Aaron
ColinBundschu fec59c1
Update emmet-builders/emmet/builders/disorder/disorder.py
ColinBundschu 8f794f8
Update emmet-builders/emmet/builders/disorder/disorder.py
ColinBundschu 27ac50b
code cleanup
ColinBundschu File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Empty file.
117 changes: 117 additions & 0 deletions
117
emmet-builders/emmet/builders/disorder/design_metrics.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,117 @@ | ||
| """Design-matrix diagnostics for CE training. | ||
|
|
||
| Vendored from phaseedge.science.design_metrics with imports adjusted. | ||
| """ | ||
|
|
||
| from dataclasses import dataclass | ||
| from typing import Any, cast | ||
|
|
||
| import numpy as np | ||
| from numpy.typing import NDArray | ||
|
|
||
| from emmet.core.disorder import CEDesignMetrics | ||
|
|
||
|
|
||
@dataclass(slots=True)
class MetricOptions:
    """Options controlling how we compute design metrics."""

    # If True, z-score each column of the (possibly weighted) design
    # matrix before computing the SVD-based diagnostics.
    standardize: bool = True
    # Numerical tolerance: singular values / column standard deviations
    # at or below this are treated as zero.
    eps: float = 1e-12
|
|
||
|
|
||
| def _standardize_columns( | ||
| X: NDArray[np.float64], eps: float | ||
| ) -> tuple[NDArray[np.float64], int]: | ||
| """Column z-score standardization.""" | ||
| Xc = X - X.mean(axis=0, keepdims=True) | ||
| std = Xc.std(axis=0, ddof=0, keepdims=True) | ||
|
|
||
| zero_var_mask = std <= eps | ||
| zero_var_count = int(zero_var_mask.sum()) | ||
|
|
||
| std_safe = std.copy() | ||
| std_safe[std_safe <= eps] = 1.0 | ||
|
|
||
| Xz = Xc / std_safe | ||
| return cast(NDArray[np.float64], Xz), zero_var_count | ||
|
|
||
|
|
||
def compute_design_metrics(
    *,
    X: NDArray[np.float64],
    w: NDArray[np.float64] | None = None,
    options: MetricOptions | None = None,
) -> CEDesignMetrics:
    """Compute design-matrix diagnostics for CE training.

    Diagnostics are SVD-based: numerical rank, extreme singular values,
    condition number, log-determinant of X^T X restricted to the
    numerically non-zero spectrum, and leverage statistics.

    Args:
        X: (n_samples, n_features) design matrix.
        w: Optional length-n_samples non-negative sample weights; rows of
            X are scaled by sqrt(w), mirroring weighted least-squares
            training.
        options: Standardization / tolerance settings; defaults to
            ``MetricOptions()``.

    Returns:
        A populated CEDesignMetrics.

    Raises:
        ValueError: If X is not 2D or has a zero dimension, or if w has
            the wrong shape or contains negative entries.
    """
    if X.ndim != 2:
        raise ValueError(f"X must be 2D, got shape {X.shape!r}")
    n, p = map(int, X.shape)
    if n == 0 or p == 0:
        raise ValueError("X must have non-zero shape.")

    opts = options or MetricOptions()
    eps = float(opts.eps)

    # Apply weights as in training: Xw = diag(sqrt(w)) @ X
    if w is not None:
        if w.ndim != 1 or int(w.size) != n:
            raise ValueError(f"w must be length-{n} vector; got shape {w.shape!r}")
        if np.any(w < 0):
            # sqrt of a negative weight is NaN and would silently poison
            # every metric computed below; fail loudly instead.
            raise ValueError("w must contain only non-negative entries.")
        sqrt_w = np.sqrt(w, dtype=np.float64).reshape(-1, 1)
        Xw = X * sqrt_w
    else:
        Xw = X
    weighting_applied = w is not None

    if opts.standardize:
        Xm, zero_var_count = _standardize_columns(Xw, eps)
        std_mode = "column_zscore"
    else:
        Xm = Xw
        zero_var_count = 0
        std_mode = "none"

    # SVD-based metrics (economy SVD); s is sorted in descending order.
    U, s, _ = np.linalg.svd(Xm, full_matrices=False)
    keep = s > eps
    rank = int(keep.sum())

    sigma_max = float(s[0]) if s.size > 0 else 0.0
    sigma_min = float(s[rank - 1]) if rank > 0 else 0.0

    # A numerically singular design has an infinite condition number.
    if rank == 0 or sigma_min <= eps:
        condition_number = float("inf")
    else:
        condition_number = float(sigma_max / sigma_min)

    # log det(X^T X) over the kept spectrum: det = prod(s_i^2), so
    # log det = 2 * sum(log s_i).
    positive = s[keep]
    if positive.size == 0:
        logdet_xtx = float("-inf")
    else:
        logdet_xtx = float(2.0 * np.log(positive).sum())

    # Leverage diagnostics via U_r: h_i is the squared norm of row i of U_r.
    if rank > 0:
        Ur = U[:, :rank]
        lev = np.einsum("ij,ij->i", Ur, Ur, optimize=True)
        leverage_mean = float(lev.mean())
        leverage_max = float(lev.max(initial=0.0))
        leverage_p95 = float(np.percentile(lev, 95.0))
    else:
        leverage_mean = leverage_max = leverage_p95 = float("nan")

    return CEDesignMetrics(
        n_samples=n,
        n_features=p,
        rank=rank,
        sigma_max=sigma_max,
        sigma_min=sigma_min,
        condition_number=condition_number,
        logdet_xtx=logdet_xtx,
        leverage_mean=leverage_mean,
        leverage_max=leverage_max,
        leverage_p95=leverage_p95,
        weighting_applied=weighting_applied,
        standardization=cast(Any, std_mode),
        zero_variance_feature_count=zero_var_count,
    )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,264 @@ | ||
| """Builder function for creating DisorderDoc from DisorderedTaskDoc instances. | ||
|
|
||
| Follows the functional builder pattern used in emmet-builders (see vasp/materials.py). | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import Any | ||
|
|
||
| import numpy as np | ||
| from smol.cofe import ClusterExpansion | ||
| from smol.moca.ensemble import Ensemble | ||
|
|
||
| from emmet.core.disorder import CationBinCount, DisorderDoc, DisorderedTaskDoc, WLDensityOfStates, WLSpecParams | ||
| from emmet.core.tasks import CoreTaskDoc | ||
|
|
||
| from .mixture import sublattices_from_composition_maps | ||
| from .prototype_spec import PrototypeSpec | ||
| from .train_ce import run_train_ce | ||
| from .wl_sampling import run_wl_block | ||
|
|
||
# Default CE training hyper-parameters
_DEFAULT_BASIS_SPEC: dict[str, Any] = {"basis": "sinusoid", "cutoffs": {2: 10, 3: 8, 4: 5}}
_DEFAULT_REGULARIZATION: dict[str, Any] = {"type": "ridge", "alpha": 1e-3, "l1_ratio": 0.5}
_DEFAULT_WEIGHTING: dict[str, Any] = {"scheme": "balance_by_comp", "alpha": 1.0}
_DEFAULT_CV_SEED: int = 42  # seed for cross-validation fold assignment

# Default WL sampling hyper-parameters
_DEFAULT_WL_STEPS: int = 1_000_000  # MC steps per WL block
_DEFAULT_WL_CHECK_PERIOD: int = 5000  # steps between WL flatness checks
_DEFAULT_WL_UPDATE_PERIOD: int = 1  # modification-factor update period
_DEFAULT_WL_SEED: int = 0
_DEFAULT_WL_CONVERGENCE_THRESHOLD: float = 1e-7  # stop once mod_factor drops below this
_DEFAULT_BIN_WIDTH: float = 0.1  # starting energy bin width for auto-tuning
_DEFAULT_MIN_BINS: int = 50  # halve bin_width if fewer bins than this
_DEFAULT_MAX_BINS: int = 200  # double bin_width if more bins than this
|
|
||
|
|
||
def _make_wl_spec(
    *,
    bin_width: float,
    steps: int,
    initial_comp_map: dict[str, Any],
    check_period: int,
    update_period: int,
    seed: int,
    collect_cation_stats: bool = False,
    production_mode: bool = False,
) -> WLSpecParams:
    """Build a WLSpecParams carrying this builder's fixed WL settings.

    All WL blocks here use swap moves, no per-bin sampling, and allow
    cross-sublattice swaps; only the flags for production/statistics
    collection and the bin width vary between calls.
    """
    return WLSpecParams(
        bin_width=bin_width,
        steps=steps,
        initial_comp_map=initial_comp_map,
        step_type="swap",
        check_period=check_period,
        update_period=update_period,
        seed=seed,
        samples_per_bin=0,
        collect_cation_stats=collect_cation_stats,
        production_mode=production_mode,
        reject_cross_sublattice_swaps=False,
    )


def build_disorder_doc(
    disordered_documents: list[DisorderedTaskDoc],
    ordered_task_doc: CoreTaskDoc,
    *,
    basis_spec: dict[str, Any] | None = None,
    regularization: dict[str, Any] | None = None,
    weighting: dict[str, Any] | None = None,
    cv_seed: int | None = _DEFAULT_CV_SEED,
    wl_steps: int = _DEFAULT_WL_STEPS,
    wl_check_period: int = _DEFAULT_WL_CHECK_PERIOD,
    wl_update_period: int = _DEFAULT_WL_UPDATE_PERIOD,
    wl_seed: int = _DEFAULT_WL_SEED,
    wl_convergence_threshold: float = _DEFAULT_WL_CONVERGENCE_THRESHOLD,
    initial_bin_width: float = _DEFAULT_BIN_WIDTH,
    min_bins: int = _DEFAULT_MIN_BINS,
    max_bins: int = _DEFAULT_MAX_BINS,
) -> DisorderDoc:
    """Train a Cluster Expansion on disordered task documents from one ordered
    material and run Wang-Landau sampling to convergence.

    Args:
        disordered_documents: All DisorderedTaskDoc instances sharing the same
            ordered_task_id, prototype, supercell_diag, and versions.
        ordered_task_doc: The CoreTaskDoc for the parent ordered material.
            Its structure is used to populate search metadata (chemsys,
            elements, composition, symmetry, etc.).
        basis_spec: CE basis specification (cutoffs, basis type).
        regularization: Regularization settings for the CE fit.
        weighting: Optional weighting scheme for the CE fit.
        cv_seed: Random seed for cross-validation folds.
        wl_steps: Number of MC steps per WL block.
        wl_check_period: How often (in steps) to check WL flatness.
        wl_update_period: Update period for the WL modification factor.
        wl_seed: Random seed for WL sampling.
        wl_convergence_threshold: Stop when mod_factor drops below this.
        initial_bin_width: Starting energy bin width for WL sampling.
        min_bins: Minimum acceptable number of WL bins (halve bin_width if fewer).
        max_bins: Maximum acceptable number of WL bins (double bin_width if more).

    Returns:
        A fully populated DisorderDoc.

    Raises:
        ValueError: If ``disordered_documents`` is empty or its entries are
            inconsistent with one another.
        RuntimeError: If no bin width yielding a bin count within
            ``[min_bins, max_bins]`` can be found by halving/doubling.
    """
    if not disordered_documents:
        raise ValueError("disordered_documents must be non-empty.")

    if basis_spec is None:
        basis_spec = dict(_DEFAULT_BASIS_SPEC)
    if regularization is None:
        regularization = dict(_DEFAULT_REGULARIZATION)
    if weighting is None:
        weighting = dict(_DEFAULT_WEIGHTING)

    # --- validate consistency across documents ---
    first = disordered_documents[0]
    consistency_fields = (
        ("ordered_task_id", "Ordered task IDs"),
        ("supercell_diag", "Supercell diagonals"),
        ("prototype", "Prototypes"),
        ("prototype_params", "Prototype parameters"),
        ("versions", "Versions"),
    )
    for doc in disordered_documents[1:]:
        for attr, label in consistency_fields:
            if getattr(doc, attr) != getattr(first, attr):
                raise ValueError(f"{label} do not match across documents.")

    # --- extract training data ---
    structures_pm = [doc.reference_structure for doc in disordered_documents]
    n_prims = int(np.prod(first.supercell_diag))
    # Normalize total energies to per-primitive-cell values for the CE fit.
    y_cell = [doc.output.energy / float(n_prims) for doc in disordered_documents]

    prototype_spec = PrototypeSpec(
        prototype=first.prototype, params=first.prototype_params
    )

    # The primitive cell uses placeholder element symbols (e.g. "Es", "Fm")
    # for active sublattices, while DisorderedTaskDoc.composition_map uses
    # sublattice labels (e.g. "A", "B"). Build the mapping to translate.
    prim = prototype_spec.primitive_cell
    sublattice_labels = prim.get_array("sublattice")
    chem_symbols = prim.get_chemical_symbols()
    active_subs = prototype_spec.active_sublattices
    # element_symbol -> sublattice_label (e.g. "Es" -> "A")
    elem_to_label: dict[str, str] = {}
    for sym, lab in zip(chem_symbols, sublattice_labels):
        if sym in active_subs and sym not in elem_to_label:
            elem_to_label[sym] = str(lab)
    # sublattice_label -> element_symbol (e.g. "A" -> "Es")
    label_to_elem = {v: k for k, v in elem_to_label.items()}

    # Remap composition maps from sublattice labels to element symbols
    composition_maps = [
        {label_to_elem.get(site, site): comp for site, comp in doc.composition_map.items()}
        for doc in disordered_documents
    ]
    sublattices = sublattices_from_composition_maps(composition_maps)

    # --- train cluster expansion ---
    ce_train_output = run_train_ce(
        structures_pm=structures_pm,
        y_cell=y_cell,
        prototype_spec=prototype_spec,
        supercell_diag=first.supercell_diag,
        sublattices=sublattices,
        basis_spec=basis_spec,
        regularization=regularization,
        weighting=weighting,
        cv_seed=cv_seed,
    )

    # --- build ensemble from trained CE ---
    ce = ClusterExpansion.from_dict(ce_train_output.payload)
    ensemble = Ensemble.from_cluster_expansion(
        ce, supercell_matrix=np.diag(first.supercell_diag)
    )

    # --- auto-tune bin width ---
    # Halve/double bin_width until the bin count lands in [min_bins, max_bins].
    # The loop is bounded so that a window narrower than a factor of two
    # (which halving/doubling can jump over) raises instead of hanging.
    bin_width = initial_bin_width
    wl_spec = _make_wl_spec(
        bin_width=bin_width,
        steps=wl_steps,
        initial_comp_map=composition_maps[0],
        check_period=wl_check_period,
        update_period=wl_update_period,
        seed=wl_seed,
    )
    wl_block = run_wl_block(
        spec=wl_spec,
        ensemble=ensemble,
        tip=None,
        prototype_spec=prototype_spec,
        supercell_diag=first.supercell_diag,
    )
    num_bins = len(wl_block["state"].bin_indices)
    for _ in range(64):  # 64 doublings span any realistic energy scale
        if min_bins <= num_bins <= max_bins:
            break
        bin_width = bin_width / 2.0 if num_bins < min_bins else bin_width * 2.0
        wl_spec = _make_wl_spec(
            bin_width=bin_width,
            steps=wl_steps,
            initial_comp_map=composition_maps[0],
            check_period=wl_check_period,
            update_period=wl_update_period,
            seed=wl_seed,
        )
        wl_block = run_wl_block(
            spec=wl_spec,
            ensemble=ensemble,
            tip=None,
            prototype_spec=prototype_spec,
            supercell_diag=first.supercell_diag,
        )
        num_bins = len(wl_block["state"].bin_indices)
    else:
        raise RuntimeError(
            f"Could not find a bin_width yielding between {min_bins} and "
            f"{max_bins} WL bins; last tried {bin_width} ({num_bins} bins)."
        )

    # --- WL convergence loop ---
    # Keep extending the same WL run (tip=wl_block) until the modification
    # factor has decayed below the convergence threshold.
    while wl_block["state"].mod_factor > wl_convergence_threshold:
        wl_block = run_wl_block(
            spec=wl_spec,
            ensemble=ensemble,
            tip=wl_block,
            prototype_spec=prototype_spec,
            supercell_diag=first.supercell_diag,
        )

    # --- Production-mode block to collect cation stats ---
    prod_spec = _make_wl_spec(
        bin_width=bin_width,
        steps=wl_steps,
        initial_comp_map=composition_maps[0],
        check_period=wl_check_period,
        update_period=wl_update_period,
        seed=wl_seed,
        collect_cation_stats=True,
        production_mode=True,
    )
    prod_block = run_wl_block(
        spec=prod_spec,
        ensemble=ensemble,
        tip=wl_block,
        prototype_spec=prototype_spec,
        supercell_diag=first.supercell_diag,
    )

    # --- assemble DisorderDoc ---
    wl_final = prod_block["state"]
    return DisorderDoc.from_structure(
        meta_structure=ordered_task_doc.structure,
        ordered_task_id=first.ordered_task_id,
        prototype=first.prototype,
        prototype_params=first.prototype_params,
        supercell_diag=first.supercell_diag,
        sublattices=sublattices,
        composition_maps=composition_maps,
        training_stats=ce_train_output.stats,
        design_metrics=ce_train_output.design_metrics,
        wl_dos=WLDensityOfStates(
            bin_indices=wl_final.bin_indices,
            entropy=wl_final.entropy,
            bin_size=wl_final.bin_size,
            mod_factor=wl_final.mod_factor,
            steps_counter=wl_final.steps_counter,
        ),
        wl_occupancy=list(prod_block["occupancy"]),
        # NOTE: records the tuned (non-production) spec used for convergence.
        wl_spec_params=wl_spec,
        cation_counts=[
            CationBinCount(**row) for row in prod_block["cation_counts"]
        ],
        disordered_task_ids=[doc.task_id for doc in disordered_documents],
        versions=first.versions,
    )
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.