Skip to content

Commit dd396c2

Browse files
committed
Update to work with shared pointers
1 parent f0bb5eb commit dd396c2

6 files changed

Lines changed: 182 additions & 54 deletions

File tree

include/gauxc/util/c_load_balancer.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ namespace GauXC::detail {
1818
static inline LoadBalancer* get_load_balancer_ptr(C::GauXCLoadBalancer lb) noexcept {
1919
return static_cast<LoadBalancer*>(lb.ptr);
2020
}
21+
static inline std::shared_ptr<LoadBalancer>* get_load_balancer_shared(C::GauXCLoadBalancer lb) noexcept {
22+
return static_cast<std::shared_ptr<LoadBalancer>*>(lb.ptr);
23+
}
2124
static inline LoadBalancerFactory* get_load_balancer_factory_ptr(C::GauXCLoadBalancerFactory lbf) noexcept {
2225
return static_cast<LoadBalancerFactory*>(lbf.ptr);
2326
}

include/gauxc/util/c_molecular_weights.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ namespace GauXC::detail {
1818
static inline MolecularWeights* get_molecular_weights_ptr(C::GauXCMolecularWeights mw) noexcept {
1919
return static_cast<MolecularWeights*>(mw.ptr);
2020
}
21+
static inline std::shared_ptr<MolecularWeights>* get_molecular_weights_shared(C::GauXCMolecularWeights mw) noexcept {
22+
return static_cast<std::shared_ptr<MolecularWeights>*>(mw.ptr);
23+
}
2124
static inline MolecularWeightsFactory* get_molecular_weights_factory_ptr(C::GauXCMolecularWeightsFactory mwf) noexcept {
2225
return static_cast<MolecularWeightsFactory*>(mwf.ptr);
2326
}

include/gauxc/util/c_xc_integrator.hpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,51 @@
1515
#include <gauxc/xc_integrator.hpp>
1616
#include <gauxc/xc_integrator/integrator_factory.hpp>
1717
#include <gauxc/util/c_matrix.hpp>
18+
#include <gauxc/util/c_load_balancer.hpp>
19+
#include <gauxc/util/c_functional.hpp>
1820

1921
namespace GauXC::detail {
2022

2123
static inline XCIntegrator<CMatrix>* get_xc_integrator_ptr(C::GauXCIntegrator integrator) noexcept {
2224
return static_cast<XCIntegrator<CMatrix>*>(integrator.ptr);
2325
}
24-
26+
static inline std::shared_ptr<XCIntegrator<CMatrix>>* get_xc_integrator_shared(C::GauXCIntegrator integrator) noexcept {
27+
return static_cast<std::shared_ptr<XCIntegrator<CMatrix>>*>(integrator.ptr);
28+
}
2529
static inline XCIntegratorFactory<CMatrix>* get_xc_integrator_factory_ptr(C::GauXCIntegratorFactory factory) noexcept {
2630
return static_cast<XCIntegratorFactory<CMatrix>*>(factory.ptr);
2731
}
32+
static inline XCIntegrator<CMatrix> get_integrator_instance(
33+
C::GauXCIntegratorFactory factory,
34+
C::GauXCFunctional functional,
35+
C::GauXCLoadBalancer lb
36+
) {
37+
if (lb.owned)
38+
return get_xc_integrator_factory_ptr(factory)->get_instance(
39+
*get_functional_ptr(functional),
40+
*get_load_balancer_ptr(lb)
41+
);
42+
else
43+
return get_xc_integrator_factory_ptr(factory)->get_instance(
44+
*get_functional_ptr(functional),
45+
**get_load_balancer_shared(lb)
46+
);
47+
}
48+
static inline std::shared_ptr<XCIntegrator<CMatrix>> get_shared_integrator_instance(
49+
C::GauXCIntegratorFactory factory,
50+
C::GauXCFunctional functional,
51+
C::GauXCLoadBalancer lb
52+
) {
53+
if (lb.owned)
54+
return get_xc_integrator_factory_ptr(factory)->get_shared_instance(
55+
*get_functional_ptr(functional),
56+
*get_load_balancer_ptr(lb)
57+
);
58+
else
59+
return get_xc_integrator_factory_ptr(factory)->get_shared_instance(
60+
*get_functional_ptr(functional),
61+
**get_load_balancer_shared(lb)
62+
);
63+
}
2864

2965
} // namespace GauXC::detail

src/c_load_balancer.cxx

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,12 @@ void gauxc_load_balancer_delete(
2525
GauXCLoadBalancer lb
2626
) {
2727
status->code = 0;
28-
if(lb.ptr != nullptr && lb.owned)
29-
delete detail::get_load_balancer_ptr(lb);
28+
if(lb.ptr != nullptr) {
29+
if (lb.owned)
30+
delete detail::get_load_balancer_ptr(lb);
31+
else
32+
delete detail::get_load_balancer_shared(lb);
33+
}
3034
lb.ptr = nullptr;
3135
}
3236
GauXCLoadBalancerFactory gauxc_load_balancer_factory_new(

src/c_molecular_weights.cxx

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,43 @@ GauXCMolecularWeights gauxc_molecular_weights_factory_get_instance(
4747
return mw;
4848
}
4949

50+
GauXCMolecularWeights gauxc_molecular_weights_factory_get_shared_instance(
51+
GauXCStatus* status,
52+
GauXCMolecularWeightsFactory mwf
53+
) {
54+
status->code = 0;
55+
auto mw_instance_ptr = detail::get_molecular_weights_factory_ptr(mwf)->get_shared_instance();
56+
GauXCMolecularWeights mw;
57+
mw.ptr = new std::shared_ptr<MolecularWeights>( std::move(mw_instance_ptr) );
58+
mw.owned = false;
59+
return mw;
60+
}
5061

5162
void gauxc_molecular_weights_modify_weights(
5263
GauXCStatus* status,
5364
GauXCMolecularWeights mw,
5465
GauXCLoadBalancer lb
5566
) {
5667
status->code = 0;
57-
detail::get_molecular_weights_ptr(mw)->modify_weights(
58-
*detail::get_load_balancer_ptr(lb)
59-
);
68+
if (mw.owned) {
69+
if (lb.owned)
70+
detail::get_molecular_weights_ptr(mw)->modify_weights(
71+
*detail::get_load_balancer_ptr(lb)
72+
);
73+
else
74+
detail::get_molecular_weights_ptr(mw)->modify_weights(
75+
**detail::get_load_balancer_shared(lb)
76+
);
77+
} else {
78+
if (lb.owned)
79+
detail::get_molecular_weights_shared(mw)->get()->modify_weights(
80+
*detail::get_load_balancer_ptr(lb)
81+
);
82+
else
83+
detail::get_molecular_weights_shared(mw)->get()->modify_weights(
84+
**detail::get_load_balancer_shared(lb)
85+
);
86+
}
6087
}
6188

6289

@@ -65,8 +92,12 @@ void gauxc_molecular_weights_delete(
6592
GauXCMolecularWeights mw
6693
) {
6794
status->code = 0;
68-
if(mw.ptr != nullptr && mw.owned)
69-
delete detail::get_molecular_weights_ptr(mw);
95+
if(mw.ptr != nullptr) {
96+
if (mw.owned)
97+
delete detail::get_molecular_weights_ptr(mw);
98+
else
99+
delete detail::get_molecular_weights_shared(mw);
100+
}
70101
mw.ptr = nullptr;
71102
}
72103

src/c_xc_integrator.cxx

Lines changed: 97 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ void gauxc_integrator_delete(
2828
if(integrator.owned)
2929
delete detail::get_xc_integrator_ptr(integrator);
3030
else
31-
delete static_cast<std::shared_ptr< XCIntegrator<detail::CMatrix> >*>(integrator.ptr);
31+
delete detail::get_xc_integrator_shared(integrator);
3232
}
3333
integrator.ptr = nullptr;
3434
}
@@ -71,10 +71,7 @@ GauXCIntegrator gauxc_integrator_factory_get_instance(
7171
GauXCLoadBalancer lb
7272
) {
7373
status->code = 0;
74-
auto integrator_instance = detail::get_xc_integrator_factory_ptr(factory)->get_instance(
75-
*detail::get_functional_ptr(functional),
76-
*detail::get_load_balancer_ptr(lb)
77-
);
74+
auto integrator_instance = detail::get_integrator_instance(factory, functional, lb);
7875
GauXCIntegrator integrator;
7976
integrator.ptr = new XCIntegrator<detail::CMatrix>( std::move(integrator_instance) );
8077
integrator.owned = true;
@@ -88,12 +85,9 @@ GauXCIntegrator gauxc_integrator_factory_get_shared_instance(
8885
GauXCLoadBalancer lb
8986
) {
9087
status->code = 0;
91-
auto integrator_instance_ptr = detail::get_xc_integrator_factory_ptr(factory)->get_shared_instance(
92-
*detail::get_functional_ptr(functional),
93-
*detail::get_load_balancer_ptr(lb)
94-
);
88+
auto integrator_instance = detail::get_shared_integrator_instance(factory, functional, lb);
9589
GauXCIntegrator integrator;
96-
integrator.ptr = new std::shared_ptr< XCIntegrator<detail::CMatrix> >( std::move(integrator_instance_ptr) );
90+
integrator.ptr = new std::shared_ptr< XCIntegrator<detail::CMatrix> >( std::move(integrator_instance) );
9791
integrator.owned = false;
9892
return integrator;
9993
}
@@ -105,9 +99,15 @@ void gauxc_integrator_integrate_den(
10599
double* den_out
106100
) {
107101
status->code = 0;
108-
auto& xc_integrator = *detail::get_xc_integrator_ptr(integrator);
109-
auto& dm = *detail::get_matrix_ptr(density_matrix);
110-
*den_out = xc_integrator.integrate_den( dm );
102+
if (integrator.owned) {
103+
auto& xc_integrator = *detail::get_xc_integrator_ptr(integrator);
104+
auto& dm = *detail::get_matrix_ptr(density_matrix);
105+
*den_out = xc_integrator.integrate_den( dm );
106+
} else {
107+
auto& xc_integrator = *detail::get_xc_integrator_shared(integrator)->get();
108+
auto& dm = *detail::get_matrix_ptr(density_matrix);
109+
*den_out = xc_integrator.integrate_den( dm );
110+
}
111111
}
112112

113113
void gauxc_integrator_eval_exc_rks(
@@ -117,9 +117,15 @@ void gauxc_integrator_eval_exc_rks(
117117
double* exc_out
118118
) {
119119
status->code = 0;
120-
auto& xc_integrator = *detail::get_xc_integrator_ptr(integrator);
121-
auto& dm = *detail::get_matrix_ptr(density_matrix);
122-
*exc_out = xc_integrator.eval_exc( dm );
120+
if (integrator.owned) {
121+
auto& xc_integrator = *detail::get_xc_integrator_ptr(integrator);
122+
auto& dm = *detail::get_matrix_ptr(density_matrix);
123+
*exc_out = xc_integrator.eval_exc( dm );
124+
} else {
125+
auto& xc_integrator = *detail::get_xc_integrator_shared(integrator)->get();
126+
auto& dm = *detail::get_matrix_ptr(density_matrix);
127+
*exc_out = xc_integrator.eval_exc( dm );
128+
}
123129
}
124130

125131
void gauxc_integrator_eval_exc_uks(
@@ -130,10 +136,17 @@ void gauxc_integrator_eval_exc_uks(
130136
double* exc_out
131137
) {
132138
status->code = 0;
133-
auto& xc_integrator = *detail::get_xc_integrator_ptr(integrator);
134-
auto& dm_s = *detail::get_matrix_ptr(density_matrix_s);
135-
auto& dm_z = *detail::get_matrix_ptr(density_matrix_z);
136-
*exc_out = xc_integrator.eval_exc( dm_s, dm_z );
139+
if (integrator.owned) {
140+
auto& xc_integrator = *detail::get_xc_integrator_ptr(integrator);
141+
auto& dm_s = *detail::get_matrix_ptr(density_matrix_s);
142+
auto& dm_z = *detail::get_matrix_ptr(density_matrix_z);
143+
*exc_out = xc_integrator.eval_exc( dm_s, dm_z );
144+
} else {
145+
auto& xc_integrator = *detail::get_xc_integrator_shared(integrator)->get();
146+
auto& dm_s = *detail::get_matrix_ptr(density_matrix_s);
147+
auto& dm_z = *detail::get_matrix_ptr(density_matrix_z);
148+
*exc_out = xc_integrator.eval_exc( dm_s, dm_z );
149+
}
137150
}
138151

139152
void gauxc_integrator_eval_exc_gks(
@@ -146,12 +159,21 @@ void gauxc_integrator_eval_exc_gks(
146159
double* exc_out
147160
) {
148161
status->code = 0;
149-
auto& xc_integrator = *detail::get_xc_integrator_ptr(integrator);
150-
auto& dm_s = *detail::get_matrix_ptr(density_matrix_s);
151-
auto& dm_z = *detail::get_matrix_ptr(density_matrix_z);
152-
auto& dm_tx = *detail::get_matrix_ptr(density_matrix_tx);
153-
auto& dm_ty = *detail::get_matrix_ptr(density_matrix_ty);
154-
*exc_out = xc_integrator.eval_exc( dm_s, dm_z, dm_tx, dm_ty );
162+
if (integrator.owned) {
163+
auto& xc_integrator = *detail::get_xc_integrator_ptr(integrator);
164+
auto& dm_s = *detail::get_matrix_ptr(density_matrix_s);
165+
auto& dm_z = *detail::get_matrix_ptr(density_matrix_z);
166+
auto& dm_tx = *detail::get_matrix_ptr(density_matrix_tx);
167+
auto& dm_ty = *detail::get_matrix_ptr(density_matrix_ty);
168+
*exc_out = xc_integrator.eval_exc( dm_s, dm_z, dm_tx, dm_ty );
169+
} else {
170+
auto& xc_integrator = *detail::get_xc_integrator_shared(integrator)->get();
171+
auto& dm_s = *detail::get_matrix_ptr(density_matrix_s);
172+
auto& dm_z = *detail::get_matrix_ptr(density_matrix_z);
173+
auto& dm_tx = *detail::get_matrix_ptr(density_matrix_tx);
174+
auto& dm_ty = *detail::get_matrix_ptr(density_matrix_ty);
175+
*exc_out = xc_integrator.eval_exc( dm_s, dm_z, dm_tx, dm_ty );
176+
}
155177
}
156178

157179
void gauxc_integrator_eval_exc_vxc_rks(
@@ -162,10 +184,17 @@ void gauxc_integrator_eval_exc_vxc_rks(
162184
GauXCMatrix vxc_matrix
163185
) {
164186
status->code = 0;
165-
auto& xc_integrator = *detail::get_xc_integrator_ptr(integrator);
166-
auto& dm = *detail::get_matrix_ptr(density_matrix);
167-
auto& vxc = *detail::get_matrix_ptr(vxc_matrix);
168-
std::tie(*exc_out, vxc) = xc_integrator.eval_exc_vxc( dm );
187+
if (integrator.owned) {
188+
auto& xc_integrator = *detail::get_xc_integrator_ptr(integrator);
189+
auto& dm = *detail::get_matrix_ptr(density_matrix);
190+
auto& vxc = *detail::get_matrix_ptr(vxc_matrix);
191+
std::tie(*exc_out, vxc) = xc_integrator.eval_exc_vxc( dm );
192+
} else {
193+
auto& xc_integrator = *detail::get_xc_integrator_shared(integrator)->get();
194+
auto& dm = *detail::get_matrix_ptr(density_matrix);
195+
auto& vxc = *detail::get_matrix_ptr(vxc_matrix);
196+
std::tie(*exc_out, vxc) = xc_integrator.eval_exc_vxc( dm );
197+
}
169198
}
170199

171200
void gauxc_integrator_eval_exc_vxc_uks(
@@ -178,12 +207,21 @@ void gauxc_integrator_eval_exc_vxc_uks(
178207
GauXCMatrix vxc_matrix_z
179208
) {
180209
status->code = 0;
181-
auto& xc_integrator = *detail::get_xc_integrator_ptr(integrator);
182-
auto& dm_s = *detail::get_matrix_ptr(density_matrix_s);
183-
auto& dm_z = *detail::get_matrix_ptr(density_matrix_z);
184-
auto& vxc_s = *detail::get_matrix_ptr(vxc_matrix_s);
185-
auto& vxc_z = *detail::get_matrix_ptr(vxc_matrix_z);
186-
std::tie(*exc_out, vxc_s, vxc_z) = xc_integrator.eval_exc_vxc( dm_s, dm_z );
210+
if (integrator.owned) {
211+
auto& xc_integrator = *detail::get_xc_integrator_ptr(integrator);
212+
auto& dm_s = *detail::get_matrix_ptr(density_matrix_s);
213+
auto& dm_z = *detail::get_matrix_ptr(density_matrix_z);
214+
auto& vxc_s = *detail::get_matrix_ptr(vxc_matrix_s);
215+
auto& vxc_z = *detail::get_matrix_ptr(vxc_matrix_z);
216+
std::tie(*exc_out, vxc_s, vxc_z) = xc_integrator.eval_exc_vxc( dm_s, dm_z );
217+
} else {
218+
auto& xc_integrator = *detail::get_xc_integrator_shared(integrator)->get();
219+
auto& dm_s = *detail::get_matrix_ptr(density_matrix_s);
220+
auto& dm_z = *detail::get_matrix_ptr(density_matrix_z);
221+
auto& vxc_s = *detail::get_matrix_ptr(vxc_matrix_s);
222+
auto& vxc_z = *detail::get_matrix_ptr(vxc_matrix_z);
223+
std::tie(*exc_out, vxc_s, vxc_z) = xc_integrator.eval_exc_vxc( dm_s, dm_z );
224+
}
187225
}
188226

189227
void gauxc_integrator_eval_exc_vxc_gks(
@@ -200,16 +238,29 @@ void gauxc_integrator_eval_exc_vxc_gks(
200238
GauXCMatrix vxc_matrix_y
201239
) {
202240
status->code = 0;
203-
auto& xc_integrator = *detail::get_xc_integrator_ptr(integrator);
204-
auto& dm_s = *detail::get_matrix_ptr(density_matrix_s);
205-
auto& dm_z = *detail::get_matrix_ptr(density_matrix_z);
206-
auto& dm_x = *detail::get_matrix_ptr(density_matrix_x);
207-
auto& dm_y = *detail::get_matrix_ptr(density_matrix_y);
208-
auto& vxc_s = *detail::get_matrix_ptr(vxc_matrix_s);
209-
auto& vxc_z = *detail::get_matrix_ptr(vxc_matrix_z);
210-
auto& vxc_x = *detail::get_matrix_ptr(vxc_matrix_x);
211-
auto& vxc_y = *detail::get_matrix_ptr(vxc_matrix_y);
212-
std::tie(*exc_out, vxc_s, vxc_z, vxc_x, vxc_y) = xc_integrator.eval_exc_vxc( dm_s, dm_z, dm_x, dm_y );
241+
if (integrator.owned) {
242+
auto& xc_integrator = *detail::get_xc_integrator_ptr(integrator);
243+
auto& dm_s = *detail::get_matrix_ptr(density_matrix_s);
244+
auto& dm_z = *detail::get_matrix_ptr(density_matrix_z);
245+
auto& dm_x = *detail::get_matrix_ptr(density_matrix_x);
246+
auto& dm_y = *detail::get_matrix_ptr(density_matrix_y);
247+
auto& vxc_s = *detail::get_matrix_ptr(vxc_matrix_s);
248+
auto& vxc_z = *detail::get_matrix_ptr(vxc_matrix_z);
249+
auto& vxc_x = *detail::get_matrix_ptr(vxc_matrix_x);
250+
auto& vxc_y = *detail::get_matrix_ptr(vxc_matrix_y);
251+
std::tie(*exc_out, vxc_s, vxc_z, vxc_x, vxc_y) = xc_integrator.eval_exc_vxc( dm_s, dm_z, dm_x, dm_y );
252+
} else {
253+
auto& xc_integrator = *detail::get_xc_integrator_shared(integrator)->get();
254+
auto& dm_s = *detail::get_matrix_ptr(density_matrix_s);
255+
auto& dm_z = *detail::get_matrix_ptr(density_matrix_z);
256+
auto& dm_x = *detail::get_matrix_ptr(density_matrix_x);
257+
auto& dm_y = *detail::get_matrix_ptr(density_matrix_y);
258+
auto& vxc_s = *detail::get_matrix_ptr(vxc_matrix_s);
259+
auto& vxc_z = *detail::get_matrix_ptr(vxc_matrix_z);
260+
auto& vxc_x = *detail::get_matrix_ptr(vxc_matrix_x);
261+
auto& vxc_y = *detail::get_matrix_ptr(vxc_matrix_y);
262+
std::tie(*exc_out, vxc_s, vxc_z, vxc_x, vxc_y) = xc_integrator.eval_exc_vxc( dm_s, dm_z, dm_x, dm_y );
263+
}
213264
}
214265

215266
} // extern "C"

0 commit comments

Comments
 (0)