Skip to content

Commit

Permalink
Refactor: substitute wll to svd
Browse files Browse the repository at this point in the history
  • Loading branch information
kirk0830 committed Oct 8, 2024
1 parent 0445e4e commit 1bb9159
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 5 deletions.
13 changes: 9 additions & 4 deletions SIAB/spillage/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ def _nzeta_mean_conf(nbands, folders):

return [nz/len(folders) for nz in nzeta]

def _nzeta_infer(folder, nband):
def _nzeta_infer(folder, nband, kernel = 'svd'):
"""infer nzeta based on one structure whose calculation result is stored
in the folder
Expand All @@ -674,7 +674,8 @@ def _nzeta_infer(folder, nband):
import numpy as np
from SIAB.spillage.datparse import read_wfc_lcao_txt, read_triu, \
read_running_scf_log, read_input_script
from SIAB.spillage.lcao_wfc_analysis import _wll
from SIAB.spillage.lcao_wfc_analysis import _wll, _svd_on_wfc
infer_map = {"svd": _svd_on_wfc, "wll": _wll}

# read INPUT and running_*.log
params = read_input_script(os.path.join(folder, "INPUT"))
Expand All @@ -700,13 +701,13 @@ def _nzeta_infer(folder, nband):
# the complete return list is (wfc.T, e, occ, k)
ovlp = read_triu(os.path.join(outdir, f"data-{isk}-S"))

nz = _wll_fold(_wll(wfc, ovlp, running["natom"], running["nzeta"]), nband)
nz = _wll_fold(_wll(wfc, ovlp, running["natom"], running["nzeta"]), nband)/running["natom"][0]
nzeta = np.resize(nzeta, np.maximum(nzeta.shape, nz.shape)) + nz * w / nspin

# count the number of atoms
assert len(running["natom"]) == 1, f"multiple atom types are not supported: {running['natom']}"

return nzeta/running["natom"][0]
return nzeta

def _wll_fold(wll, nband):
"""One of strategy for inferring nzeta from wll matrix. This function
Expand All @@ -721,6 +722,10 @@ def _wll_fold(wll, nband):
if specified as int, it is the highest band index to be considered.
if specified as list or range, it is the list of band indexes to be
considered
Returns
-------
np.ndarray: the folded wll matrix in shape of 1*(lmax+1)
"""

nband = range(nband) if isinstance(nband, int) else nband
Expand Down
148 changes: 147 additions & 1 deletion SIAB/spillage/lcao_wfc_analysis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from SIAB.spillage.index import _lin2comp

import numpy as np
import scipy.linalg as la

def _wll(C, S, natom, nzeta):
'''
Expand Down Expand Up @@ -46,7 +47,120 @@ def _wll(C, S, natom, nzeta):

return wll

def _mean_rotinv(a):
'''average over m for each l with rotational invariance:
out = sqrt{sum_m |C_m|^2 * 4pi/(2l+1)}
'''
l = (a.shape[0] - 1) // 2
out = np.sum(np.abs(a)**2, axis=0) / (2*l + 1)
return np.sqrt(out)

def _mean_max(a):
'''average over m for each l with maximum:
out = max_m |C_m|
'''
return np.max(np.abs(a), axis=0)

def _wfc_interp(C, nbands, natom, nzeta, view = 'reduce'):
'''interpret wavefunction coefficients in different view, rearrange
and concatenate the wavefunction coefficients of all bands into a
matrix
Parameters
----------
C : 2D array of shape (nao, nbands)
Wave function coefficients in LCAO basis. The datatype is
complex for multi-k calculations and float for gamma-only.
natom : list of int
Number of atoms for each type.
nzeta : list of list of int
nzeta[i][l] specifies the number of zeta orbitals of the
angular momentum l of type i. len(nzeta) must equal len(natom).
method : str
Method to concatenate the wave function coefficients. Options are
'decompose' and 'reduce'. decompose: concatentate the C matrix of
atoms in direction of basis. reduce: concatenate the C matrix of
atoms in direction of bands.
'''
assert view in ['decompose', 'reduce']

nao, nbands_max = C.shape
assert nbands <= nbands_max, 'nbands selected is larger than the total nbands'

ntyp = len(nzeta)
if view == 'decompose':
Ct = [[np.zeros(shape=(2*l+1, natom[it]*nz, nbands)) # nz = nzeta[it][l]
for l, nz in enumerate(nzeta[it])] for it in range(ntyp)]
for i, (it, ia, l, iz, m) in enumerate(_lin2comp(natom, nzeta)):
iaiz = ia*nzeta[it][l] + iz
for ib in range(nbands):
Ct[it][l][m, iaiz, ib] = C[i, ib]
else:
Ct = [[np.zeros(shape=(2*l+1, nz, natom[it]*nbands)) # nz = nzeta[it][l]
for l, nz in enumerate(nzeta[it])] for it in range(ntyp)]
for i, (it, ia, l, iz, m) in enumerate(_lin2comp(natom, nzeta)):
for ib in range(nbands):
iaib = ia*nbands + ib
Ct[it][l][m, iz, iaib] = C[i, ib]
return Ct

def _svd_on_wfc(C,
S,
nbands,
natom,
nzeta,
fold_m = 'rotational-invariant',
svd_view = 'reduce'):
'''perform svd on the wave function coefficients, return the
singular value for each atomtype and each l, each zeta function
Parameters
----------
C : 2D array of shape (nao, nbands)
Wave function coefficients in LCAO basis. The datatype is
complex for multi-k calculations and float for gamma-only.
S : 2D array of shape (nao, nao)
Overlap matrix.
nbands : int
Number of bands selected for the analysis.
natom : list of int
Number of atoms for each type.
nzeta : list of list of int
nzeta[i][l] specifies the number of zeta orbitals of the
angular momentum l of type i. len(nzeta) must equal len(natom).
fold_m : str
Method to average over m for each l. Options are
'rotational-invariant' and 'max'.
svd_view : str
Method to concatenate the wave function coefficients. Options are
'decompose' and 'reduce'.
Returns
-------
sigma : list of list of float
Singular values for each atomtype and each l, each zeta function.
'''
mean_map = {'rotational-invariant': _mean_rotinv,
'max': _mean_max}

C = la.sqrtm(S) @ C

lmax = [len(nz) - 1 for nz in nzeta]

ntyp = len(natom)
Ct = _wfc_interp(C, nbands, natom, nzeta, svd_view)

mat_tlm = [[np.zeros(shape=(2*l+1, nz)) # nz = nzeta[it][l]
for l, nz in enumerate(nzeta[it])] for it in range(ntyp)]
for it in range(ntyp):
for l in range(lmax[it]+1):
for m in range(2*l+1):
mat_tlm[it][l][m] = la.svd(Ct[it][l][m], compute_uv=False)

mean = mean_map[fold_m]
out = [[mean(mat_tlm[it][l]) for l in range(len(nzeta[it]))] for it in range(ntyp)]

return out

############################################################
# Test
Expand All @@ -73,7 +187,7 @@ def test_wll_gamma(self):
for ib, wb in enumerate(wll):
self.assertAlmostEqual(np.sum(wb.real), 1.0, places=6)

# return # suppress output
return # suppress output

for ib, wb in enumerate(wll):
wl_row_sum = np.sum(wb.real, 0)
Expand All @@ -83,6 +197,38 @@ def test_wll_gamma(self):
print(f" sum = {np.sum(wl_row_sum):6.3f}")
print('')

def test_svd_on_wfc(self):
import os
here = os.path.dirname(os.path.abspath(__file__))
outdir = os.path.join(here, 'testfiles/Si/jy-7au/monomer-gamma/OUT.ABACUS/')

wfc = read_wfc_lcao_txt(outdir + 'WFC_NAO_GAMMA1.txt')[0]
S = read_triu(outdir + 'data-0-S')
dat = read_running_scf_log(outdir + 'running_scf.log')
nbands = 25

sigma = _svd_on_wfc(wfc, S, nbands,
dat['natom'],
dat['nzeta'],
'rotational-invariant',
'reduce')
self.assertEqual(len(sigma), len(dat['natom'])) # number of atom types
for i, (nt, nz) in enumerate(zip(dat['natom'], dat['nzeta'])):
self.assertEqual(len(sigma[i]), len(nz)) # number of l orbitals

# return # suppress output
for i, (nt, nz) in enumerate(zip(dat['natom'], dat['nzeta'])):
print(f"Atom type {i+1}")
for l, s in enumerate(sigma[i]):
print(f"l = {l}")
for ix, x in enumerate(s):
print(f"{x:6.3f} ", end='')
if ix % 5 == 4 and ix != len(s) - 1:
print('')
print('')
print(f'sum = {np.sum(s):6.3f}\n')
print('')


def test_wll_multi_k(self):
import os
Expand Down

0 comments on commit 1bb9159

Please sign in to comment.