Skip to content

Commit d7c5612

Browse files
committed
Sort out precon regularization
1 parent 442380c commit d7c5612

7 files changed

Lines changed: 33 additions & 29 deletions

File tree

cxx/riesling/inputs.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@ auto ReconArgs::Get() -> rl::ReconOpts { return rl::ReconOpts{.decant = decant.G
5050

5151
PreconArgs::PreconArgs(args::Subparser &parser)
5252
: type(parser, "P", "Pre-conditioner (none/single/multi/filename)", {"precon", 'p'}, "single")
53-
, λ(parser, "BIAS", "Pre-conditioner regularization (1)", {"precon-lambda"}, 1.e-3f)
53+
, max(parser, "M", "Maximum value, threshold above (1)", {"precon-max"}, 1.f)
5454
{
5555
}
5656

57-
auto PreconArgs::Get() -> rl::PreconOpts { return rl::PreconOpts{.type = type.Get(), .λ = λ.Get()}; }
57+
auto PreconArgs::Get() -> rl::PreconOpts { return rl::PreconOpts{.type = type.Get(), .max = max.Get()}; }
5858

5959
LSMRArgs::LSMRArgs(args::Subparser &parser)
6060
: its(parser, "N", "Max iterations (4)", {"max-its", 'i'}, 4)

cxx/riesling/inputs.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ struct ReconArgs
4747
struct PreconArgs
4848
{
4949
args::ValueFlag<std::string> type;
50-
args::ValueFlag<float> λ;
50+
args::ValueFlag<float> max;
5151

5252
PreconArgs(args::Subparser &parser);
5353
auto Get() -> rl::PreconOpts;

cxx/riesling/util/precon.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ void main_precon(args::Subparser &parser)
1212
GridArgs<3> gridArgs(parser);
1313
args::Positional<std::string> trajFile(parser, "INPUT", "File to read trajectory from");
1414
args::Positional<std::string> preFile(parser, "OUTPUT", "File to save pre-conditioner to");
15-
args::ValueFlag<float> preλ(parser, "BIAS", "Pre-conditioner regularization (1)", {"precon-lambda"}, 1.e-3f);
15+
args::ValueFlag<float> preMax(parser, "M", "Maximum value, threshold above (1)", {"precon-max"}, 1.f);
1616
args::ValueFlag<std::string> sfile(parser, "S", "Load SENSE kernels from file", {"sense"});
1717
args::ValueFlag<std::string> basisFile(parser, "B", "Read basis from file", {"basis", 'b'});
1818
ParseCommand(parser, trajFile);
@@ -25,10 +25,10 @@ void main_precon(args::Subparser &parser)
2525
HD5::Reader senseReader(sfile.Get());
2626
Cx5 const skern = senseReader.readTensor<Cx5>(HD5::Keys::Data);
2727
Cx5 const smaps = SENSE::KernelsToMaps(skern, traj.matrixForFOV(gridArgs.fov.Get()), gridArgs.osamp.Get());
28-
auto const M = KSpaceMulti(smaps, gridArgs.Get(), traj, preλ.Get(), basis.get());
28+
auto const M = KSpaceMulti(smaps, gridArgs.Get(), traj, preMax.Get(), basis.get());
2929
writer.writeTensor(HD5::Keys::Weights, M.dimensions(), M.data(), {"channel", "sample", "trace"});
3030
} else {
31-
auto const M = KSpaceSingle(gridArgs.Get(), traj, preλ.Get(), basis.get());
31+
auto const M = KSpaceSingle(gridArgs.Get(), traj, preMax.Get(), basis.get());
3232
writer.writeTensor(HD5::Keys::Weights, M.dimensions(), M.data(), {"sample", "trace"});
3333
}
3434

cxx/rl/precon-opts.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ namespace rl {
66
struct PreconOpts
77
{
88
std::string type = "single";
9-
float λ = 1.e-3f;
9+
float max = 1.f;
1010
};
1111

1212
} // namespace rl

cxx/rl/precon.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@ namespace rl {
1818
* of the cropping during the NUFFT. I also tested simply grid adj * grid, which gave reasonable results but would do a double
1919
* convolution with the gridding kernel.
2020
*/
21-
template <int ND> auto KSpaceSingle(GridOpts<ND> const &gridOpts, TrajectoryN<ND> const &traj, float const λ, Basis::CPtr basis)
22-
-> Re2
21+
template <int ND>
22+
auto KSpaceSingle(GridOpts<ND> const &gridOpts, TrajectoryN<ND> const &traj, float const max, Basis::CPtr basis) -> Re2
2323
{
24-
Log::Print("Precon", "Starting preconditioner calculation λ {}", λ);
24+
Log::Print("Precon", "Starting preconditioner calculation");
2525
TrajectoryN<ND> newTraj(traj.points() * 2.f, MulToEven(traj.matrix(), 2), traj.voxelSize() / 2.f);
2626
auto nufft = TOps::MakeNUFFT<ND>(gridOpts, newTraj, 1, basis);
2727
Cx3 W(nufft->oshape);
@@ -30,7 +30,7 @@ template <int ND> auto KSpaceSingle(GridOpts<ND> const &gridOpts, TrajectoryN<ND
3030
CxN<ND + 2> ones(AddBack(traj.matrix(), psf.dimension(ND), psf.dimension(ND + 1)));
3131
ones.setConstant(1.f);
3232
TOps::Pad<ND + 2> padX(ones.dimensions(), psf.dimensions());
33-
CxN<ND + 2> xcor(padX.oshape);
33+
CxN<ND + 2> xcor(padX.oshape);
3434
xcor.device(Threads::TensorDevice()) = padX.forward(ones);
3535
FFT::Forward(xcor, FirstN<ND>(Sz3{0, 1, 2}));
3636
xcor.device(Threads::TensorDevice()) = xcor * xcor.conjugate();
@@ -40,9 +40,8 @@ template <int ND> auto KSpaceSingle(GridOpts<ND> const &gridOpts, TrajectoryN<ND
4040
float scale =
4141
std::pow(Product(FirstN<ND>(psf.dimensions())), 1.5f) / Product(traj.matrix()) / Product(FirstN<ND>(ones.dimensions()));
4242
Re3 weights = nufft->forward(xcor).abs() * scale;
43-
44-
// weights.device(Threads::TensorDevice()) = (weights == 0.f).select(weights.constant(1.f), (1.f + λ) / (weights + λ));
45-
weights.device(Threads::TensorDevice()) = (weights < 1.f).select(weights.constant(1.f), 1.f / weights);
43+
Log::Print("Precon", "Thresholding to {}", max);
44+
weights.device(Threads::TensorDevice()) = (weights < 1.f / max).select(weights.constant(max), 1.f / weights);
4645

4746
float const norm = Norm<true>(weights);
4847
if (!std::isfinite(norm)) {
@@ -70,9 +69,9 @@ auto KSpaceMulti(Cx5 const &smaps, GridOpts<3> const &gridOpts, Trajectory const
7069
Index const nTrace = traj.nTraces();
7170
Re3 weights(nC, nSamp, nTrace);
7271

73-
auto nufft = TOps::NUFFT<3>(gridOpts, newTraj, 1, basis);
74-
Sz5 const psfShape = nufft.ishape;
75-
Sz5 const smapShape = smaps.dimensions();
72+
auto nufft = TOps::NUFFT<3>(gridOpts, newTraj, 1, basis);
73+
Sz5 const psfShape = nufft.ishape;
74+
Sz5 const smapShape = smaps.dimensions();
7675
Index const nB = smapShape[3];
7776
if (nB > 1 && nB != psfShape[3]) {
7877
throw Log::Failure("Precon", "SENSE maps had basis dimension {}, expected {}", nB, psfShape[3]);
@@ -159,7 +158,7 @@ template <int ND, int NB> auto MakeKSpacePrecon(PreconOpts const &opts,
159158
Log::Print("Precon", "Using no preconditioning");
160159
return nullptr;
161160
} else if (opts.type == "single") {
162-
Re2 const w = KSpaceSingle(gridOpts, traj, opts.λ);
161+
Re2 const w = KSpaceSingle(gridOpts, traj, opts.max);
163162
return std::make_shared<TOps::TensorScale<3 + NB, 1, NB>>(shape, w.cast<Cx>());
164163
} else if (opts.type == "multi") {
165164
throw Log::Failure("Precon", "Multichannel preconditioner requested without SENSE maps");
@@ -193,10 +192,10 @@ template <int ND, int NB> auto MakeKSpacePrecon(PreconOpts const &opts,
193192
Log::Print("Precon", "Using no preconditioning");
194193
return nullptr;
195194
} else if (opts.type == "single") {
196-
Re2 const w = KSpaceSingle(gridOpts, traj, opts.λ);
195+
Re2 const w = KSpaceSingle(gridOpts, traj, opts.max);
197196
return std::make_shared<TOps::TensorScale<3 + NB, 1, NB>>(shape, w.cast<Cx>());
198197
} else if (opts.type == "multi") {
199-
Re3 const w = KSpaceMulti(smaps, gridOpts, traj, opts.λ);
198+
Re3 const w = KSpaceMulti(smaps, gridOpts, traj, opts.max);
200199
return std::make_shared<TOps::TensorScale<3 + NB, 0, NB>>(shape, w.cast<Cx>());
201200
} else {
202201
return LoadKSpacePrecon<NB>(opts.type, traj.nSamples(), traj.nTraces(), shape);

cxx/rl/precon.hpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,13 @@
99
namespace rl {
1010

1111
template <int ND>
12-
auto KSpaceSingle(GridOpts<ND> const &gridOpts, TrajectoryN<ND> const &traj, float const λ, Basis::CPtr basis = nullptr) -> Re2;
13-
auto KSpaceMulti(
14-
Cx5 const &smaps, GridOpts<3> const &gridOpts, Trajectory const &traj, float const λ, Basis::CPtr basis = nullptr) -> Re3;
12+
auto KSpaceSingle(GridOpts<ND> const &gridOpts, TrajectoryN<ND> const &traj, float const max = 1.f, Basis::CPtr basis = nullptr)
13+
-> Re2;
14+
auto KSpaceMulti(Cx5 const &smaps,
15+
GridOpts<3> const &gridOpts,
16+
Trajectory const &traj,
17+
float const max = 1.f,
18+
Basis::CPtr basis = nullptr) -> Re3;
1519

1620
template <int ND, int NB> auto MakeKSpacePrecon(PreconOpts const &opts,
1721
GridOpts<ND> const &gridOpts,

cxx/test/precon.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@ TEST_CASE("Preconditioner", "[precon]")
1313
{
1414
Index const M = GENERATE(16);
1515
Sz3 const matrix{M, M, M};
16-
Re3 points(3, 3, 1);
16+
Re3 points(3, 4, 1);
1717
points.setZero();
1818
points(0, 0, 0) = -0.25f * M;
19-
points(0, 2, 0) = 0.25f * M;
19+
points(0, 3, 0) = 0.25f * M;
2020
Trajectory const traj(points, matrix);
21-
auto const sc = KSpaceSingle(GridOpts<3>(), traj, 0.);
21+
auto const sc = KSpaceSingle(GridOpts<3>(), traj);
2222
INFO("Weights\n" << sc);
23-
CHECK(sc(0, 0) == Approx(1.f).margin(1.e-1f));
24-
CHECK(sc(1, 0) == Approx(1.f).margin(1.e-1f));
25-
CHECK(sc(2, 0) == Approx(1.f).margin(1.e-1f));
23+
CHECK(sc(0, 0) == Approx(1.f).margin(1.e-3f));
24+
CHECK(sc(1, 0) == Approx(0.5f).margin(1.e-3f));
25+
CHECK(sc(2, 0) == Approx(0.5f).margin(1.e-3f));
26+
CHECK(sc(3, 0) == Approx(1.f).margin(1.e-3f));
2627
}

0 commit comments

Comments
 (0)