From 43b8ec91b2c24a516543bf39eeec801bdb58e408 Mon Sep 17 00:00:00 2001 From: kirk0830 Date: Sun, 29 Sep 2024 19:08:10 +0800 Subject: [PATCH] Refactor: move example files into examples/ dir and add validation --- SIAB/io/read_input.py | 48 ++++++++++++++++++++++++- SIAB/spillage/api.py | 3 ++ examples/jy.json | 56 +++++++++++++++++++++++++++++ SIAB_INPUT.json => examples/pw.json | 2 +- SIAB_INPUT_new => examples/pw_new | 0 SIAB_INPUT_old => examples/pw_old | 0 6 files changed, 107 insertions(+), 2 deletions(-) create mode 100644 examples/jy.json rename SIAB_INPUT.json => examples/pw.json (98%) rename SIAB_INPUT_new => examples/pw_new (100%) rename SIAB_INPUT_old => examples/pw_old (100%) diff --git a/SIAB/io/read_input.py b/SIAB/io/read_input.py index 6538b491..8fa03c4b 100644 --- a/SIAB/io/read_input.py +++ b/SIAB/io/read_input.py @@ -13,7 +13,7 @@ def read_siab_plaintext(fname: str = ""): float_pattern = r"^\d+\.\d*$" int_pattern = r"^\d+$" scalar_keywords = ["Ecut", "sigma", "element"] - result = {} + result = {"fit_basis": "pw"} if fname == "": raise ValueError("No filename provided") with open(fname, "r") as f: @@ -471,6 +471,11 @@ def skip_ppread(user_settings: dict): """ skip = True + # case 0 + # if element is not set, it is not possible to skip the pseudopotential read-in + if "element" not in user_settings: + print("AUTOSET: `element` is not specified => AUTOSET", flush=True) + return False # case 1 # if nbands is specified as auto, occ, all, it must requires the number of valence # electrons, therefore it is not possible to skip the pseudopotential read-in @@ -536,10 +541,51 @@ def skip_ppread(user_settings: dict): # in a list in input, therefore the index is known). return skip +def _validate_param(user_settings: dict): + """validate the input parameters + + Parameters + ---------- + user_settings: dict + the user settings + + Returns + ------- + None + """ + # check if the shape assigned to orbitals is valid + shape2index = {rs["shape"]: i for i, rs in enumerate(user_settings["reference_systems"])} + for iorb, orb in enumerate(user_settings["orbitals"]): + shape = orb["shape"] + shape = [shape] if not isinstance(shape, list) else shape + for s in shape: + assert isinstance(s, (str, int)), f"shape {s} is not a valid shape" + if isinstance(s, str): + assert s in shape2index, f"shape {s} is not found in reference systems" + + # check if the nbands set for reference system is smaller than + # bands needed for fitting orbitals + shape2index = {rs["shape"]: i for i, rs in enumerate(user_settings["reference_systems"])} + for iorb, orb in enumerate(user_settings["orbitals"]): + shape = orb["shape"] + shape = [shape] if not isinstance(shape, list) else shape + for s in shape: + if isinstance(s, str): + assert orb["nbands_ref"] <= user_settings["reference_systems"][shape2index[s]]["nbands"], \ + f"ERROR: `nbands_ref` for orbital {iorb} is larger than the number of bands set for\ + reference system `{s}`" + elif isinstance(s, int): + assert orb["nbands_ref"] <= user_settings["reference_systems"][s]["nbands"], \ + f"ERROR: `nbands_ref` for orbital {iorb} is larger than the number of bands set for\ + reference system `{s}`" + + def parse(user_settings: dict): """unpack the SIAB input to structure (shape as key and bond lengths are list as value), input setting of abacus, orbital generation settings, environmental settings and general description """ + _validate_param(user_settings) + # move the information fetch from pseudopotential from front.py here... # get value from the dict returned by function from_pseudopotential diff --git a/SIAB/spillage/api.py b/SIAB/spillage/api.py index e2a2d10f..addf6bb9 100644 --- a/SIAB/spillage/api.py +++ b/SIAB/spillage/api.py @@ -686,6 +686,9 @@ def _nzeta_infer(folder, nband): for isk in range(nspin*len(wk)): # loop over (ispin, ik) w = wk[isk % len(wk)] # spin-up and spin-down share the wk wfc, _, _, _ = read_wfc_lcao_txt(os.path.join(outdir, f"{fwfc}{isk+1}.txt")) + assert wfc.shape[1] >= nband, \ + f"ERROR: number of bands for orbgen is larger than calculated: {nband} > {wfc.shape[1]}" + # the complete return list is (wfc.T, e, occ, k) ovlp = read_triu(os.path.join(outdir, f"data-{isk}-S")) diff --git a/examples/jy.json b/examples/jy.json new file mode 100644 index 00000000..0485ab04 --- /dev/null +++ b/examples/jy.json @@ -0,0 +1,56 @@ +{ + "environment": "", + "mpi_command": "mpirun -np 8", + "abacus_command": "abacus", + + "pseudo_dir": "/root/abacus-develop/pseudopotentials/sg15_oncv_upf_2020-02-06/", + "pseudo_name": "Si_ONCV_PBE-1.0.upf", + "ecutwfc": 100, + "bessel_nao_smooth": 0, + "bessel_nao_rcut": [6, 7, 8, 9, 10], + "smearing_sigma": 0.01, + + "fit_basis": "jy", + "optimizer": "bfgs", + "max_steps": 3000, + "spill_guess": "atomic", + "nthreads_rcut": 4, + "jY_type": "reduced", + + "reference_systems": [ + { + "shape": "dimer", + "nbands": 8, + "nspin": 1, + "lmaxmax": 2, + "bond_lengths": [1.62, 1.82, 2.22, 2.72, 3.22] + }, + { + "shape": "trimer", + "nbands": 10, + "nspin": 1, + "lmaxmax": 2, + "bond_lengths": [1.9, 2.1, 2.6] + } + ], + "orbitals": [ + { + "zeta_notation": "auto", + "shape": "dimer", + "nbands_ref": 4, + "orb_ref": "none" + }, + { + "zeta_notation": "auto", + "shape": "dimer", + "nbands_ref": 10, + "orb_ref": 0 + }, + { + "zeta_notation": "auto", + "shape": "trimer", + "nbands_ref": 20, + "orb_ref": 1 + } + ] +} \ No newline at end of file diff --git a/SIAB_INPUT.json b/examples/pw.json similarity index 98% rename from SIAB_INPUT.json rename to examples/pw.json index 89d452cc..4c80198f 100644 --- a/SIAB_INPUT.json +++ b/examples/pw.json @@ -12,7 +12,7 @@ "pseudo_rcut": 10, "pseudo_mesh": 1, - "basis_type": "jy", + "fit_basis": "pw", "optimizer": "bfgs", "max_steps": 1000, diff --git a/SIAB_INPUT_new b/examples/pw_new similarity index 100% rename from SIAB_INPUT_new rename to examples/pw_new diff --git a/SIAB_INPUT_old b/examples/pw_old similarity index 100% rename from SIAB_INPUT_old rename to examples/pw_old