Skip to content

Commit

Permalink
connect svd pop analysis method to main program, change func name
Browse files Browse the repository at this point in the history
  • Loading branch information
kirk0830 committed Oct 9, 2024
1 parent 1bb9159 commit 26199ea
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 26 deletions.
28 changes: 17 additions & 11 deletions SIAB/spillage/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,12 +648,12 @@ def _nzeta_mean_conf(nbands, folders):
nbands = [nbands] * len(folders) if isinstance(nbands, int) else nbands

for folder, nband in zip(folders, nbands):
nzeta_ = np.array(_nzeta_infer(folder, nband))
nzeta_ = np.array(_nzeta_infer(folder, nband, 'wll'))
nzeta = np.resize(nzeta, np.maximum(nzeta.shape, nzeta_.shape)) + nzeta_

return [nz/len(folders) for nz in nzeta]

def _nzeta_infer(folder, nband, kernel = 'svd'):
def _nzeta_infer(folder, nband, pop = 'svd'):
"""infer nzeta based on one structure whose calculation result is stored
in the folder
Expand All @@ -665,6 +665,9 @@ def _nzeta_infer(folder, nband, kernel = 'svd'):
if specified as int, it is the highest band index to be considered.
if specified as list or range, it is the list of band indexes to be
considered
pop: str, optional
the population analysis method used to infer nzeta, can be 'svd' or 'wll',
default is 'svd'
Returns
-------
Expand All @@ -674,8 +677,11 @@ def _nzeta_infer(folder, nband, kernel = 'svd'):
import numpy as np
from SIAB.spillage.datparse import read_wfc_lcao_txt, read_triu, \
read_running_scf_log, read_input_script
from SIAB.spillage.lcao_wfc_analysis import _wll, _svd_on_wfc
infer_map = {"svd": _svd_on_wfc, "wll": _wll}
from SIAB.spillage.lcao_wfc_analysis import _wll, _rad_svd

def _wll_kernel(C, S, nbands, natom, nzeta, **kwargs):
return _wll_fold(_wll(C, S, natom, nzeta), nbands) / natom[0]
infer_kernel = {"svd": _rad_svd, "wll": _wll_kernel}

# read INPUT and running_*.log
params = read_input_script(os.path.join(folder, "INPUT"))
Expand All @@ -701,7 +707,7 @@ def _nzeta_infer(folder, nband, kernel = 'svd'):
# the complete return list is (wfc.T, e, occ, k)
ovlp = read_triu(os.path.join(outdir, f"data-{isk}-S"))

nz = _wll_fold(_wll(wfc, ovlp, running["natom"], running["nzeta"]), nband)/running["natom"][0]
nz = infer_kernel[pop](wfc, ovlp, nband, running["natom"], running["nzeta"])
nzeta = np.resize(nzeta, np.maximum(nzeta.shape, nz.shape)) + nz * w / nspin

# count the number of atoms
Expand Down Expand Up @@ -1004,13 +1010,13 @@ def test_nzeta_infer(self):

# gamma case is easy, multi-k case is more difficult
fpath = os.path.join(here, "testfiles/Si/jy-7au/monomer-gamma/")
nzeta = _nzeta_infer(fpath, 4)
nzeta = _nzeta_infer(fpath, 4, 'wll')
ref = [1, 1, 0]
self.assertTrue(all([abs(nz - ref[i]) < 1e-8 for i, nz in enumerate(nzeta)]))
nzeta = _nzeta_infer(fpath, 5)
nzeta = _nzeta_infer(fpath, 5, 'wll')
ref = [2, 1, 0]
self.assertTrue(all([abs(nz - ref[i]) < 1e-8 for i, nz in enumerate(nzeta)]))
nzeta = _nzeta_infer(fpath, 10)
nzeta = _nzeta_infer(fpath, 10, 'wll')
ref = [2, 1, 1]
self.assertTrue(all([abs(nz - ref[i]) < 1e-8 for i, nz in enumerate(nzeta)]))

Expand Down Expand Up @@ -1064,18 +1070,18 @@ def test_nzeta_infer(self):
degen = np.array([2*i + 1 for i in range(3)], dtype=float)

nbnd = 4
nzeta = _nzeta_infer(fpath, nbnd)
nzeta = _nzeta_infer(fpath, nbnd, 'wll')
ref = np.sum(np.array([np.sum(refdata[i, :nbnd, :], 0) / degen * wk[i] for i in range(4)]), 0)
self.assertTrue(all([abs(nz - ref[i]) < 1e-3 for i, nz in enumerate(nzeta)]))
# because we use the data only has ndigits=3, so we can only compare to 1e-3

nbnd = 5
nzeta = _nzeta_infer(fpath, nbnd)
nzeta = _nzeta_infer(fpath, nbnd, 'wll')
ref = np.sum(np.array([np.sum(refdata[i, :nbnd, :], 0) / degen * wk[i] for i in range(4)]), 0)
self.assertTrue(all([abs(nz - ref[i]) < 1e-3 for i, nz in enumerate(nzeta)]))

nbnd = 10
nzeta = _nzeta_infer(fpath, nbnd)
nzeta = _nzeta_infer(fpath, nbnd, 'wll')
ref = np.sum(np.array([np.sum(refdata[i, :nbnd, :], 0) / degen * wk[i] for i in range(4)]), 0)
self.assertTrue(all([abs(nz - ref[i]) < 1e-3 for i, nz in enumerate(nzeta)]))

Expand Down
33 changes: 18 additions & 15 deletions SIAB/spillage/lcao_wfc_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def _mean_max(a):
'''
return np.max(np.abs(a), axis=0)

def _wfc_interp(C, nbands, natom, nzeta, view = 'reduce'):
'''interpret wavefunction coefficients in different view, rearrange
def _wfc_reinterp(C, nbands, natom, nzeta, view = 'reduce'):
'''reinterpret wavefunction coefficients in different view, rearrange
and concatenate the wavefunction coefficients of all bands into a
matrix
Expand Down Expand Up @@ -104,15 +104,16 @@ def _wfc_interp(C, nbands, natom, nzeta, view = 'reduce'):
Ct[it][l][m, iz, iaib] = C[i, ib]
return Ct

def _svd_on_wfc(C,
S,
nbands,
natom,
nzeta,
fold_m = 'rotational-invariant',
svd_view = 'reduce'):
def _rad_svd(C,
S,
nbands,
natom,
nzeta,
fold_m = 'rotational-invariant',
reinterp_view = 'reduce'):
'''perform svd on the wave function coefficients, return the
singular value for each atomtype and each l, each zeta function
singular value of zeta function of each atomtype each l, which
represents the weight.
Parameters
----------
Expand All @@ -131,8 +132,8 @@ def _svd_on_wfc(C,
fold_m : str
Method to average over m for each l. Options are
'rotational-invariant' and 'max'.
svd_view : str
Method to concatenate the wave function coefficients. Options are
reinterp_view : str
Method to reinterpret the wave function coefficients. Options are
'decompose' and 'reduce'.
Returns
Expand All @@ -143,19 +144,21 @@ def _svd_on_wfc(C,
mean_map = {'rotational-invariant': _mean_rotinv,
'max': _mean_max}

# orthogonalize the wave function coefficients
C = la.sqrtm(S) @ C

lmax = [len(nz) - 1 for nz in nzeta]

ntyp = len(natom)
Ct = _wfc_interp(C, nbands, natom, nzeta, svd_view)

C = _wfc_reinterp(C, nbands, natom, nzeta, reinterp_view)

mat_tlm = [[np.zeros(shape=(2*l+1, nz)) # nz = nzeta[it][l]
for l, nz in enumerate(nzeta[it])] for it in range(ntyp)]
for it in range(ntyp):
for l in range(lmax[it]+1):
for m in range(2*l+1):
mat_tlm[it][l][m] = la.svd(Ct[it][l][m], compute_uv=False)
mat_tlm[it][l][m] = la.svd(C[it][l][m], compute_uv=False)

mean = mean_map[fold_m]
out = [[mean(mat_tlm[it][l]) for l in range(len(nzeta[it]))] for it in range(ntyp)]
Expand Down Expand Up @@ -207,7 +210,7 @@ def test_svd_on_wfc(self):
dat = read_running_scf_log(outdir + 'running_scf.log')
nbands = 25

sigma = _svd_on_wfc(wfc, S, nbands,
sigma = _rad_svd(wfc, S, nbands,
dat['natom'],
dat['nzeta'],
'rotational-invariant',
Expand Down

0 comments on commit 26199ea

Please sign in to comment.