Skip to content

Commit 2801b0e

Browse files
committed
Update multi-channel precon
1 parent b323d06 commit 2801b0e

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

cxx/rl/precon.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ template <int ND> auto KSpaceSingle(GridOpts<ND> const &gridOpts, TrajectoryN<ND
2525
{
2626
Log::Print("Precon", "Starting preconditioner calculation, λ {}", λ);
2727
TrajectoryN<ND> newTraj(traj.points() * 2.f, MulToEven(traj.matrix(), 2), traj.voxelSize() / 2.f);
28-
auto nufft = TOps::MakeNUFFT<ND>(gridOpts, newTraj, 1, basis);
29-
CxN<ND + 2> psf(nufft->ishape);
30-
Cx3 W(nufft->oshape);
28+
TOps::NUFFT<ND> nufft(gridOpts, newTraj, 1, basis);
29+
CxN<ND + 2> psf(nufft.ishape);
30+
Cx3 W(nufft.oshape);
3131

3232
W.setConstant(1.f);
33-
nufft->adjoint(W, psf);
33+
nufft.adjoint(W, psf);
3434
CxN<ND + 2> ones(AddBack(traj.matrix(), psf.dimension(ND), psf.dimension(ND + 1)));
3535
ones.setConstant(1.f);
3636
TOps::Pad<ND + 2> padX(ones.dimensions(), psf.dimensions());
@@ -48,8 +48,10 @@ template <int ND> auto KSpaceSingle(GridOpts<ND> const &gridOpts, TrajectoryN<ND
4848
// I do not understand this scaling factor but it's in Frank's code and works
4949
float scale =
5050
std::pow(Product(FirstN<ND>(psf.dimensions())), 1.5f) / Product(traj.matrix()) / Product(FirstN<ND>(ones.dimensions()));
51-
Re3 weights = nufft->forward(xcor).abs() * scale;
51+
Re3 weights(nufft.oshape);
52+
weights.device(Threads::TensorDevice()) = nufft.forward(xcor).abs() * scale;
5253
if constexpr (ND == 3) { Log::Tensor("w", LastN<2>(weights.dimensions()), weights.data(), {"s", "t"}); }
54+
Log::Print("Precon", "Before inversion min {} max {}", Minimum(weights), Maximum(weights));
5355
weights.device(Threads::TensorDevice()) = (weights + weights.constant(λ)) / weights.constant(1.f + λ);
5456
weights.device(Threads::TensorDevice()) = (weights == 0.f).select(weights.constant(1.f), weights);
5557
weights.device(Threads::TensorDevice()) = weights.constant(1.f) / weights;
@@ -98,8 +100,9 @@ auto KSpaceMulti(Cx5 const &smaps, GridOpts<3> const &gridOpts, Trajectory const
98100
Sz5 const smap1Shape = AddBack(FirstN<3>(smapShape), nB, 1);
99101
Sz5 const xcor1Shape = AddBack(FirstN<3>(psfShape), nB, 1);
100102

101-
auto padXC = TOps::Pad<5>(smap1Shape, xcor1Shape);
102-
Cx5 smap1(smap1Shape), xcorTemp(xcor1Shape), xcor1(xcor1Shape), xcor(psfShape);
103+
auto padXC = TOps::Pad<5>(smap1Shape, xcor1Shape);
104+
Cx5 smap1(smap1Shape), xcorTemp(xcor1Shape), xcor1(xcor1Shape), xcor(psfShape);
105+
Sz3 const slSz{1, nSamp, nTrace};
103106
for (Index si = 0; si < nC; si++) {
104107
float const ni = Norm2<true>(smaps.chip<1>(si));
105108
xcor1.setZero();
@@ -118,9 +121,14 @@ auto KSpaceMulti(Cx5 const &smaps, GridOpts<3> const &gridOpts, Trajectory const
118121
} else {
119122
xcor.device(Threads::TensorDevice()) = xcor1 * psf;
120123
}
121-
weights.slice(Sz3{si, 0, 0}, Sz3{1, nSamp, nTrace}).device(Threads::TensorDevice()) =
122-
(nufft.forward(xcor).abs() * scale / ni + λ) / (1.f + λ);
124+
Sz3 const slSt{si, 0, 0};
125+
weights.slice(slSt, slSz).device(Threads::TensorDevice()) = nufft.forward(xcor).abs() * scale;
123126
}
127+
Log::Print("Precon", "Before inversion min {} max {}", Minimum(weights), Maximum(weights));
128+
weights.device(Threads::TensorDevice()) = (weights + weights.constant(λ)) / weights.constant(1.f + λ);
129+
weights.device(Threads::TensorDevice()) = (weights == 0.f).select(weights.constant(1.f), weights);
130+
weights.device(Threads::TensorDevice()) = weights.constant(1.f) / weights;
131+
124132
float const norm = Norm<true>(weights);
125133
if (!std::isfinite(norm)) {
126134
throw Log::Failure("Precon", "Pre-conditioner norm was not finite ({})", norm);

0 commit comments

Comments
 (0)