Skip to content

Commit 66aa4e8

Browse files
committed
Lobpcg with unittests
1 parent 0ed1894 commit 66aa4e8

8 files changed

Lines changed: 2115 additions & 5 deletions

File tree

source/source_estate/elecstate_print.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ void print_scf_iterinfo(const std::string& ks_solver,
5555
{"elpa", "EL"},
5656
{"dav", "DA"},
5757
{"dav_subspace", "DS"},
58+
{"lobpcg", "LB"},
5859
{"scalapack_gvx", "GV"},
5960
{"cusolver", "CU"},
6061
{"bpcg", "BP"},

source/source_hsolver/diago_lobpcg.cpp

Lines changed: 1089 additions & 0 deletions
Large diffs are not rendered by default.

source/source_hsolver/diago_lobpcg.h

Lines changed: 490 additions & 0 deletions
Large diffs are not rendered by default.

source/source_hsolver/hsolver_pw.cpp

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@
66
#include "source_estate/elecstate_pw.h"
77
#include "source_hamilt/hamilt.h"
88
#include "source_hsolver/diag_comm_info.h"
9+
910
#include "source_hsolver/diago_bpcg.h"
1011
#include "source_hsolver/diago_cg.h"
1112
#include "source_hsolver/diago_dav_subspace.h"
1213
#include "source_hsolver/diago_david.h"
14+
#include "source_hsolver/diago_lobpcg.h"
1315
#include "source_hsolver/diago_iter_assist.h"
16+
1417
#include "source_io/module_parameter/parameter.h"
1518
#include "source_psi/psi.h"
1619
#include "source_estate/elecstate_tools.h"
@@ -82,7 +85,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
8285
this->nproc_in_pool = nproc_in_pool_in;
8386

8487
// report if the specified diagonalization method is not supported
85-
const std::initializer_list<std::string> _methods = {"cg", "dav", "dav_subspace", "bpcg"};
88+
const std::initializer_list<std::string> _methods = {"cg", "dav", "dav_subspace", "bpcg", "lobpcg"};
8689
if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods))
8790
{
8891
ModuleBase::WARNING_QUIT("HSolverPW::solve", "This type of eigensolver is not supported!");
@@ -393,6 +396,11 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
393396
hm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec);
394397
};
395398

399+
double tolerance = this->diag_thr;
400+
int max_iter = this->diag_iter_max;
401+
std::cout << "DS default tolerance: " << tolerance << ", max_iter: " << max_iter << std::endl;
402+
403+
396404
Diago_DavSubspace<T, Device> dav_subspace(pre_condition,
397405
psi.get_nbands(),
398406
psi.get_k_first() ? psi.get_current_ngk()
@@ -467,6 +475,52 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
467475
ntry_max,
468476
notconv_max));
469477
}
478+
else if (this->method == "lobpcg"){
479+
// hpsi_func (X, HX, ld, nvec) -> HX = H(X), X and HX blockvectors of size ld x nvec
480+
auto hpsi_func = [hm, cur_nbasis](T* psi_in, T* hpsi_out, const int ld_psi, const int nvec) {
481+
482+
// Convert the raw-pointer data structure to a psi::Psi object
483+
auto psi_iter_wrapper = psi::Psi<T, Device>(psi_in, 1, nvec, ld_psi, cur_nbasis);
484+
485+
psi::Range bands_range(true, 0, 0, nvec - 1);
486+
487+
using hpsi_info = typename hamilt::Operator<T, Device>::hpsi_info;
488+
hpsi_info info(&psi_iter_wrapper, bands_range, hpsi_out);
489+
hm->ops->hPsi(info);
490+
};
491+
492+
auto spsi_func = [hm](T* psi_in, T* spsi_out, const int ld_psi, const int nvec) {
493+
hm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec);
494+
};
495+
const int ndim = psi.get_current_ngk(); /// dimension of matrix
496+
const int nband = psi.get_nbands(); /// number of eigenpairs sought
497+
const int nmax = nband + 20;
498+
const int ld_psi = psi.get_nbasis(); /// leading dimension of psi
499+
500+
bool gen_eig = false;
501+
502+
double tolerance = this->diag_thr;
503+
int max_iter = this->diag_iter_max;
504+
// print default tolerance and max_iter for LOBPCG
505+
std::cout << "LOBPCG default tolerance: " << tolerance << ", max_iter: " << max_iter << std::endl;
506+
max_iter = 1000; // LOBPCG is not stable enough yet, so max_iter is overridden to 1000 to avoid divergence. TODO: further test and optimize LOBPCG in the future.
507+
if (tolerance > 1e-6)tolerance = 1e-6;
508+
std::cout << "LOBPCG current tolerance: " << tolerance << ", max_iter: " << max_iter << std::endl;
509+
510+
DiagoLOBPCG<T, Device> lobpcg(pre_condition.data(), nband, ndim, nmax);
511+
bool ok = lobpcg.diag(hpsi_func, spsi_func, gen_eig,
512+
eigenvalue, psi.get_pointer(), ld_psi, tolerance, max_iter);
513+
}
514+
// now print lowest 5 eigenvalues for debugging
515+
if (this->rank_in_pool == 0)
516+
{
517+
std::cout << "Lowest 5 eigenvalues for current k-point: ";
518+
for (int i = 0; i < std::min(5, psi.get_nbands()); i++)
519+
{
520+
std::cout << eigenvalue[i] << " ";
521+
}
522+
std::cout << std::endl;
523+
}
470524
ModuleBase::timer::tick("HSolverPW", "solve_psik");
471525
return;
472526
}

source/source_hsolver/test/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,14 @@ if (ENABLE_MPI)
9292
AddTest(
9393
TARGET MODULE_HSOLVER_pw
9494
LIBS parameter ${math_libs} psi device base container
95-
SOURCES test_hsolver_pw.cpp ../hsolver_pw.cpp ../hsolver_lcaopw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp ../para_linear_transform.cpp
95+
SOURCES test_hsolver_pw.cpp ../hsolver_pw.cpp ../hsolver_lcaopw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp ../para_linear_transform.cpp ../diago_lobpcg.cpp
9696
../../source_estate/elecstate_tools.cpp ../../source_estate/occupy.cpp ../../source_base/module_fft/fft_bundle.cpp ../../source_base/module_fft/fft_cpu.cpp
9797
)
9898

9999
AddTest(
100100
TARGET MODULE_HSOLVER_sdft
101101
LIBS parameter ${math_libs} psi device base container
102-
SOURCES test_hsolver_sdft.cpp ../hsolver_pw_sdft.cpp ../hsolver_pw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp ../para_linear_transform.cpp
102+
SOURCES test_hsolver_sdft.cpp ../hsolver_pw_sdft.cpp ../hsolver_pw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp ../para_linear_transform.cpp ../diago_lobpcg.cpp
103103
../../source_estate/elecstate_tools.cpp ../../source_estate/occupy.cpp ../../source_base/module_fft/fft_bundle.cpp ../../source_base/module_fft/fft_cpu.cpp
104104
)
105105

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
#include <gtest/gtest.h>
2+
#include <complex>
3+
#include <vector>
4+
#include <cmath>
5+
#include <iostream>
6+
#include <random>
7+
#include "source_hsolver/diago_lobpcg.h"
8+
9+
// Define complex double type
10+
using Complex = std::complex<double>;
11+
12+
// Declare the LAPACK zheev and BLAS zgemm routines used by the tests below
13+
extern "C" {
14+
void zheev_(const char* jobz, const char* uplo, const int* n, Complex* a, const int* lda, double* w, Complex* work, const int* lwork, double* rwork, int* info);
15+
void zgemm_(const char* transa, const char* transb, const int* m, const int* n, const int* k,
16+
const Complex* alpha, const Complex* a, const int* lda,
17+
const Complex* b, const int* ldb,
18+
const Complex* beta, Complex* c, const int* ldc);
19+
}
20+
21+
class DiagoLobpcgTest : public testing::Test {
22+
protected:
23+
std::vector<Complex> matrix;
24+
25+
// Generate matrix
26+
// type 0: Deterministic (Original)
27+
// type 1: Random Diagonally Dominant Complex Hermitian
28+
void GenerateMatrix(int n, int type) {
29+
matrix.resize(n * n);
30+
if (type == 0) {
31+
for (int j = 0; j < n; ++j) {
32+
for (int i = 0; i < n; ++i) {
33+
if (i == j) {
34+
matrix[j * n + i] = static_cast<double>(i + 1); // Diagonal 1..n
35+
} else {
36+
// Off-diagonal
37+
double val = 1.0 / (std::abs(i - j) + 1.0);
38+
matrix[j * n + i] = val * 0.1;
39+
}
40+
}
41+
}
42+
} else if (type == 1) {
43+
// Random Hermitian matrix
44+
// Use specific seed for reproducibility
45+
std::mt19937 gen(42);
46+
// Diagonal elements: spaced out to ensure good conditioning for basic LOBPCG
47+
// Off-diagonal: small random values
48+
std::uniform_real_distribution<> val_dist(-0.5, 0.5);
49+
50+
for (int j = 0; j < n; ++j) {
51+
// Diagonal (real)
52+
// j+1 plus small noise. Keeps eigenvalues well separated approx 1.0 apart.
53+
matrix[j * n + j] = static_cast<double>(j + 1) + val_dist(gen) * 0.5;
54+
55+
// Off-diagonal (complex)
56+
for (int i = j + 1; i < n; ++i) {
57+
Complex val(val_dist(gen) * 0.05, val_dist(gen) * 0.05);
58+
// A(i, j) at matrix[j*n + i]
59+
matrix[j * n + i] = val;
60+
// A(j, i) at matrix[i*n + j]
61+
matrix[i * n + j] = std::conj(val);
62+
}
63+
}
64+
}
65+
}
66+
67+
void VerifyLobpcg(int n, int nband, double check_tol = 1e-3, double cg_tol = 1e-5) {
68+
// ---------------------------------------------------------
69+
// 1. Solve with LAPACK (Gold Standard)
70+
// ---------------------------------------------------------
71+
std::vector<Complex> mat_lapack = matrix; // Deep copy
72+
std::vector<double> ev_lapack(n);
73+
74+
Complex work_query;
75+
std::vector<double> rwork(3 * n - 2);
76+
int lwork_query = -1;
77+
int info = 0;
78+
int n_val = n;
79+
80+
char jobz = 'N';
81+
char uplo = 'U';
82+
83+
// Query workspace
84+
zheev_(&jobz, &uplo, &n_val, mat_lapack.data(), &n_val, ev_lapack.data(), &work_query, &lwork_query, rwork.data(), &info);
85+
86+
int lwork = static_cast<int>(work_query.real()) + 1;
87+
std::vector<Complex> work(lwork);
88+
89+
// Compute
90+
zheev_(&jobz, &uplo, &n_val, mat_lapack.data(), &n_val, ev_lapack.data(), work.data(), &lwork, rwork.data(), &info);
91+
92+
ASSERT_EQ(info, 0) << "LAPACK zheev computation failed with info=" << info;
93+
94+
// Output LAPACK eigenvalues for debug
95+
// std::cout << "LAPACK computed eigenvalues (first 5): ";
96+
// for(int i=0; i<5 && i<n; ++i) std::cout << ev_lapack[i] << " ";
97+
// std::cout << std::endl;
98+
99+
// ---------------------------------------------------------
100+
// 2. Solve with LOBPCG
101+
// ---------------------------------------------------------
102+
std::vector<double> precondition(n, 1.0); // Identity Preconditioner
103+
104+
int n_max = nband + 5;
105+
hsolver::DiagoLOBPCG<Complex> lobpcg(precondition.data(), nband, n, n_max);
106+
107+
std::vector<double> ev_lobpcg(nband);
108+
std::vector<Complex> psi(n * nband);
109+
110+
// Initialize psi with values
111+
for(auto &val : psi) val = static_cast<double>(rand()) / RAND_MAX;
112+
113+
auto hpsi_func = [&](Complex* in, Complex* out, const int ld, const int nvec) {
114+
char transa = 'N';
115+
char transb = 'N';
116+
int m_ = n;
117+
int n_ = nvec;
118+
int k_ = n;
119+
Complex alpha = 1.0;
120+
Complex beta = 0.0;
121+
int lda = n;
122+
123+
zgemm_(&transa, &transb, &m_, &n_, &k_,
124+
&alpha, matrix.data(), &lda,
125+
in, &ld,
126+
&beta, out, &ld);
127+
};
128+
129+
int max_iter = 2000;
130+
bool converged = lobpcg.diag(
131+
hpsi_func,
132+
nullptr,
133+
false,
134+
ev_lobpcg.data(),
135+
psi.data(),
136+
n,
137+
cg_tol,
138+
max_iter
139+
);
140+
141+
EXPECT_TRUE(converged) << "LOBPCG did not converge in " << max_iter << " iterations";
142+
143+
// Output LOBPCG eigenvalues for debug
144+
// std::cout << "LOBPCG computed eigenvalues (first 5): ";
145+
// for(int i=0; i<5 && i<nband; ++i) std::cout << ev_lobpcg[i] << " ";
146+
// std::cout << std::endl;
147+
148+
// ---------------------------------------------------------
149+
// 3. Compare Results
150+
// ---------------------------------------------------------
151+
for(int i = 0; i < nband; ++i) {
152+
EXPECT_NEAR(ev_lobpcg[i], ev_lapack[i], check_tol)
153+
<< "Mismatch at eigenvalue index " << i
154+
<< " LAPACK: " << ev_lapack[i] << " LOBPCG: " << ev_lobpcg[i];
155+
}
156+
// output anyway even if test passed, for debug
157+
std::cout << "Eigenvalues comparison (LOBPCG vs LAPACK):" << std::endl;
158+
std::cout << "LAPACK eigenvalues: ";
159+
for(int i=0; i<nband; ++i) {
160+
std::cout << "Index " << i << ": " << ev_lapack[i] << " vs " << ev_lapack[i] << std::endl;
161+
}
162+
std::cout << std::endl;
163+
std::cout << "LOBPCG eigenvalues: ";
164+
for(int i=0; i<nband; ++i) {
165+
std::cout << "Index " << i << ": " << ev_lobpcg[i] << " vs " << ev_lapack[i] << std::endl;
166+
}
167+
}
168+
};
169+
170+
// Deterministic 100 x 100 matrix (type 0): the 10 lowest LOBPCG eigenvalues
// must agree with LAPACK under the default tolerances.
TEST_F(DiagoLobpcgTest, CompareWithLapack) {
    const int dim = 100;
    const int bands = 10;
    const int deterministic_type = 0;

    GenerateMatrix(dim, deterministic_type);
    VerifyLobpcg(dim, bands);
}
176+
177+
// Same deterministic matrix family (type 0) at a larger size: 200 x 200 with
// the 20 lowest eigenvalues checked under the default tolerances.
TEST_F(DiagoLobpcgTest, LargeScale) {
    const int dim = 200;
    const int bands = 20;
    const int deterministic_type = 0;

    GenerateMatrix(dim, deterministic_type);
    VerifyLobpcg(dim, bands);
}
183+
184+
// Random complex Hermitian 50 x 50 matrix (type 1): the 10 lowest eigenvalues
// are compared with a looser check tolerance (0.1) and solver tolerance (1e-2).
TEST_F(DiagoLobpcgTest, RandomMatrix) {
    const int dim = 50;
    const int bands = 10;
    const int random_type = 1;
    const double check_tol = 0.1;
    const double cg_tol = 1e-2;

    GenerateMatrix(dim, random_type);
    VerifyLobpcg(dim, bands, check_tol, cg_tol);
}

0 commit comments

Comments
 (0)