Source code for nbed.config
"""Custom Types and Enums."""
import json
import logging
import os
from enum import Enum
from pathlib import Path
from typing import Annotated, Any
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
Field,
FilePath,
NonNegativeInt,
PositiveFloat,
PositiveInt,
TypeAdapter,
)
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
[docs]
class ProjectorTypes(Enum):
    """Implemented Projectors.

    Selects the projector used to screen out environment orbitals
    (the ``projector`` field of ``NbedConfig``).
    """

    MU = "mu"  # mu projector; its level shift is NbedConfig.mu_level_shift
    HUZ = "huzinaga"  # Huzinaga projector
    # NOTE(review): presumably runs both the mu and Huzinaga projectors — confirm.
    BOTH = "both"
[docs]
class OccupiedLocalizerTypes(Enum):
    """Implemented Occupied Localizers.

    Selects how occupied orbitals are localized (the ``localization`` field
    of ``NbedConfig``). Configs must use the string values below, e.g. "pm"
    rather than "pipek-mezey".
    """

    SPADE = "spade"  # SPADE partitioning
    BOYS = "boys"  # Foster-Boys localization
    IBO = "ibo"  # intrinsic bond orbitals
    PM = "pm"  # Pipek-Mezey localization
[docs]
class VirtualLocalizerTypes(Enum):
    """Implemented Virtual Localizers.

    Selects how virtual orbitals are localized (the ``virtual_localization``
    field of ``NbedConfig``).
    """

    CONCENTRIC = "cl"  # concentric localization
    PROJECTED_AO = "pao"  # projected atomic orbitals
    DISABLE = "disable"  # turn virtual-orbital localization off
# Constrained string type for an XYZ-format geometry:
#   line 1: the number of atoms,
#   line 2: a free-form comment line (may be empty),
#   then one "<element> <x> <y> <z>" line per atom.
# The comment line is matched with ``[^\n]*`` so standard XYZ files, whose
# second line frequently carries a title, are accepted. (Previously only a
# single optional whitespace character was allowed there.)
# NOTE(review): the atom-line group is quantified with ``*`` and the pattern
# has no end anchor, so it may match zero times; in practice only the two
# header lines are enforced by this pattern.
XYZGeometry = Annotated[
    str, Field(pattern="^\\d+\n[^\n]*\n(?:\\w(?:\\s+\\-?\\d\\.\\d+){3}\n?)*")
]
[docs]
def validate_xyz_file(maybe_xyz: Any) -> str:
"""Validates the the filepath given leads to a valid XYZ formatted file.
Args:
maybe_xyz (Any): A path to an existing file.
Returns:
str: an XYZ geometry string.
"""
match maybe_xyz:
case str() | Path():
if os.path.exists(maybe_xyz):
with open(maybe_xyz) as file:
content = file.read()
logger.debug("File content %s", content)
TypeAdapter(XYZGeometry).validate_strings(content)
return content
else:
logger.debug("Input geometry does not match existing file")
return str(maybe_xyz)
case _:
return maybe_xyz
[docs]
class NbedConfig(BaseModel):
    """Config for Nbed.

    Unknown fields are rejected (``extra="forbid"``).

    Args:
        geometry (XYZGeometry): Path to an existing .xyz file containing the
            molecular geometry, or a raw XYZ-format string. A file path is
            replaced by the file's contents before validation.
        n_active_atoms (PositiveInt): The number of atoms to include in the active region.
        basis (str): The name of an atomic orbital basis set to use for chemistry calculations.
        xc_functional (str): The name of an Exchange-Correlation functional to be used for DFT.
        projector (ProjectorTypes): Projector to screen out environment orbitals.
            One of 'mu' (default), 'huzinaga' or 'both'.
        localization (OccupiedLocalizerTypes): Occupied-orbital localization method.
            One of 'spade' (default), 'boys', 'ibo' or 'pm' (Pipek-Mezey).
        convergence (PositiveFloat): The convergence tolerance for energy
            calculations. Must be > 0. Default 1e-6.
        charge (int): Charge of the molecular species (may be negative). Default 0.
        spin (int): Spin of the molecular species. Default 0.
            NOTE(review): presumably 2S (number of unpaired electrons) as in PySCF — confirm.
        unit (str): Molecular geometry unit, e.g. 'angstrom' (default) or 'bohr'.
        symmetry (bool): Symmetry flag for the molecule. Default False.
        restricted_global (bool): Restricted-calculation flag for the global
            system. Default False.
        restricted_active (bool): Restricted-calculation flag for the active
            system. Default False.
        savefile (FilePath | None): Location of file to save output to. Default None.
            Note: pydantic's ``FilePath`` only accepts a path to a file that
            already exists.
        run_ccsd_emb (bool): Whether or not to find the CCSD energy of the embedded
            system for reference. Default False.
        run_fci_emb (bool): Whether or not to find the FCI energy of the embedded
            system for reference. Default False.
        run_dft_in_dft (bool): DFT-in-DFT flag. Default False.
        mm_coords (list | None): Optional coordinates for an MM region. Default None.
        mm_charges (list | None): Optional charges for an MM region. Default None.
        mm_radii (list | None): Optional radii for an MM region. Default None.
            NOTE(review): "MM" presumably means molecular-mechanics point
            charges — confirm against the driver.
        mu_level_shift (PositiveFloat): Level shift parameter to use for the
            mu-projector. Default 1e6.
        init_huzinaga_rhf_with_mu (bool): Hidden flag to seed huzinaga RHF with mu
            shift result (for developers only). Default False.
        virtual_localization (VirtualLocalizerTypes): Virtual-orbital localization
            method. One of 'cl' (concentric, default), 'pao' (projected AO) or 'disable'.
        n_mo_overwrite (tuple[None | NonNegativeInt, None | NonNegativeInt]): Optional
            overwrite values for occupied localizers. Default (None, None).
        occupied_threshold (float): Threshold strictly between 0 and 1. Default 0.95.
        virtual_threshold (float): Threshold strictly between 0 and 1. Default 0.95.
        max_shells (PositiveInt): Maximum number of shells. Default 4.
        norm_cutoff (PositiveFloat): Norm cutoff. Default 0.05.
        overlap_cutoff (PositiveFloat): Overlap cutoff. Default 1e-5.
            NOTE(review): max_shells / norm_cutoff / overlap_cutoff presumably
            tune concentric virtual localization — confirm.
        force_unrestricted (bool): Force-unrestricted flag. Default False.
        max_ram_memory (PositiveInt): Amount of RAM memory in MB available for PySCF
            calculations. Default 4000.
        max_hf_cycles (PositiveInt): Max number of Hartree-Fock iterations allowed
            (for global and local HF). Default 50.
        max_dft_cycles (PositiveInt): Max number of DFT iterations allowed in SCF
            calculations. Default 50.
        build_hamiltonian (bool): Hamiltonian-building flag. Default False.
    """

    # Reject any field name that is not declared below.
    model_config = ConfigDict(extra="forbid")

    # --- Required inputs ---
    # A file path is read in by validate_xyz_file before pattern validation.
    geometry: Annotated[XYZGeometry, BeforeValidator(validate_xyz_file)]
    n_active_atoms: PositiveInt
    basis: str
    xc_functional: str

    # --- Embedding method choices ---
    projector: ProjectorTypes = Field(default=ProjectorTypes.MU)
    localization: OccupiedLocalizerTypes = Field(default=OccupiedLocalizerTypes.SPADE)
    convergence: PositiveFloat = 1e-6

    # --- Molecule description ---
    charge: int = Field(default=0)
    spin: int = Field(default=0)
    unit: str = "angstrom"
    symmetry: bool = False
    restricted_global: bool = False
    restricted_active: bool = False

    # --- Output and optional reference calculations ---
    # FilePath: the path must already exist as a file to validate.
    savefile: FilePath | None = None
    run_ccsd_emb: bool = False
    run_fci_emb: bool = False
    run_dft_in_dft: bool = False

    # --- Optional MM region ---
    mm_coords: list | None = None
    mm_charges: list | None = None
    mm_radii: list | None = None

    # --- Projector / localization tuning ---
    mu_level_shift: PositiveFloat = 1e6
    init_huzinaga_rhf_with_mu: bool = False
    virtual_localization: VirtualLocalizerTypes = Field(
        default=VirtualLocalizerTypes.CONCENTRIC
    )
    n_mo_overwrite: tuple[None | NonNegativeInt, None | NonNegativeInt] = (None, None)
    occupied_threshold: float = Field(default=0.95, gt=0, lt=1)
    virtual_threshold: float = Field(default=0.95, gt=0, lt=1)
    max_shells: PositiveInt = 4
    norm_cutoff: PositiveFloat = 0.05
    overlap_cutoff: PositiveFloat = 1e-5
    force_unrestricted: bool = False

    # --- Resource limits ---
    max_ram_memory: PositiveInt = 4000
    max_hf_cycles: PositiveInt = Field(default=50)
    max_dft_cycles: PositiveInt = Field(default=50)
    build_hamiltonian: bool = False
[docs]
def overwrite_config_kwargs(config: NbedConfig, **config_kwargs) -> NbedConfig:
    """Return ``config`` with selected fields replaced, re-validating the result.

    Args:
        config (NbedConfig): A config model.
        config_kwargs (dict): Field names mapped to their replacement values.

    Returns:
        NbedConfig: The same ``config`` object when no overrides are given,
            otherwise a newly validated config model.

    Raises:
        ValidationError: If the merged values fail validation, e.g. because a
            key-word argument is not a field of the model.
    """
    # No overrides requested: hand back the very same object.
    if not config_kwargs:
        return config

    logger.info("Overwriting select field with additonal config.")
    # Later mappings win, so each override replaces the dumped value while
    # untouched fields keep their current values.
    merged = {**config.model_dump(), **config_kwargs}
    return NbedConfig(**merged)
[docs]
def parse_config(
config: NbedConfig | str | None = None,
**config_kwargs,
):
"""Parse the various config options and return a valid model.
Args:
config (NbedConfig): A validated config model or path to a '.json' config file.
**config_kwargs: Allows arbitrary keyword arguments for manual configuration.
Returns:
NbedConfig: A valid config model.
"""
match config:
case NbedConfig():
logger.info("Using validated config.")
config = overwrite_config_kwargs(config, **config_kwargs)
case str() | Path():
logger.info("Using config file %s", config)
logger.info("Validating config from file.")
with open(FilePath(config)) as f:
data = json.load(f)
config = NbedConfig(**data)
config = overwrite_config_kwargs(config, **config_kwargs)
case None:
logger.info("Validating config from passed arguments.")
logger.debug(f"{config_kwargs=}")
config = NbedConfig(**config_kwargs)
case _:
logger.warning("Unknown input to config argument will be ignored.")
logger.debug(f"{config=}")
logger.debug(f"{config_kwargs=}")
config = NbedConfig(**config_kwargs)
return config