Skip to content

Commit 79ca54e

Browse files
authored
Merge pull request #136 from WireCell/idft-use
Replace hard-wired DFT functions with IDFT
2 parents 97428a4 + 847fb83 commit 79ca54e

177 files changed

Lines changed: 13731 additions & 2874 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

aux/inc/WireCellAux/DftTools.h

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
/**
2+
This provides std::vector and Eigen::Array typed interface to an
3+
IDFT.
4+
*/
5+
6+
#ifndef WIRECELL_AUX_DFTTOOLS
7+
#define WIRECELL_AUX_DFTTOOLS
8+
9+
#include "WireCellIface/IDFT.h"
10+
#include <vector>
11+
#include <Eigen/Core>
12+
13+
namespace WireCell::Aux {
14+
15+
using complex_t = IDFT::complex_t;
16+
17+
// std::vector based functions
18+
19+
using real_vector_t = std::vector<float>;
20+
using complex_vector_t = std::vector<complex_t>;
21+
22+
// 1D with vectors
23+
24+
// Perform forward c2c transform on vector.
25+
inline complex_vector_t fwd(const IDFT::pointer& dft, const complex_vector_t& seq)
26+
{
27+
complex_vector_t ret(seq.size());
28+
dft->fwd1d(seq.data(), ret.data(), ret.size());
29+
return ret;
30+
}
31+
32+
// Perform forward r2c transform on vector.
33+
inline complex_vector_t fwd_r2c(const IDFT::pointer& dft, const real_vector_t& vec)
34+
{
35+
complex_vector_t cvec(vec.size());
36+
std::transform(vec.begin(), vec.end(), cvec.begin(),
37+
[](float re) { return Aux::complex_t(re,0.0); } );
38+
return fwd(dft, cvec);
39+
}
40+
41+
// Perform inverse c2c transform on vector.
42+
inline complex_vector_t inv(const IDFT::pointer& dft, const complex_vector_t& spec)
43+
{
44+
complex_vector_t ret(spec.size());
45+
dft->inv1d(spec.data(), ret.data(), ret.size());
46+
return ret;
47+
}
48+
49+
// Perform inverse c2r transform on vector.
50+
inline real_vector_t inv_c2r(const IDFT::pointer& dft, const complex_vector_t& spec)
51+
{
52+
auto cvec = inv(dft, spec);
53+
real_vector_t rvec(cvec.size());
54+
std::transform(cvec.begin(), cvec.end(), rvec.begin(),
55+
[](const Aux::complex_t& c) { return std::real(c); });
56+
return rvec;
57+
}
58+
59+
// 1D high-level interface
60+
61+
/// Convovle in1 and in2. Returned vecgtor has size sum of sizes
62+
/// of in1 and in2 less one element in order to assure no periodic
63+
/// aliasing. Caller need not (should not) pad either input.
64+
/// Caller is free to truncate result as required.
65+
real_vector_t convolve(const IDFT::pointer& dft,
66+
const real_vector_t& in1,
67+
const real_vector_t& in2);
68+
69+
70+
/// Replace response res1 in meas with response res2.
71+
///
72+
/// This will compute the FFT of all three, in frequency space will form:
73+
///
74+
/// meas * resp2 / resp1
75+
///
76+
/// apply the inverse FFT and return its real part.
77+
///
78+
/// The output vector is long enough to assure no periodic
79+
/// aliasing. In general, caller should NOT pre-pad any input.
80+
/// Any subsequent truncation of result is up to caller.
81+
real_vector_t replace(const IDFT::pointer& dft,
82+
const real_vector_t& meas,
83+
const real_vector_t& res1,
84+
const real_vector_t& res2);
85+
86+
87+
// Eigen array based functions
88+
89+
/// 2D array types. Note, use Array::cast<complex_t>() if you
90+
/// need to convert rom real or arr.real() to convert to real.
91+
using real_array_t = Eigen::ArrayXXf;
92+
using complex_array_t = Eigen::ArrayXXcf;
93+
94+
// 2D with Eigen arrays. Use eg arr.cast<complex_>() to provde
95+
// from real or arr.real()() to convert result to real.
96+
97+
// Transform both dimesions.
98+
complex_array_t fwd(const IDFT::pointer& dft, const complex_array_t& arr);
99+
complex_array_t inv(const IDFT::pointer& dft, const complex_array_t& arr);
100+
101+
// As above but internally convert input or output. These are
102+
// just syntactic sugar hiding a .cast<complex_t>() or a .real()
103+
// call.
104+
complex_array_t fwd_r2c(const IDFT::pointer& dft, const real_array_t& arr);
105+
real_array_t inv_c2r(const IDFT::pointer& dft, const complex_array_t& arr);
106+
107+
// Transform a 2D array along one axis.
108+
//
109+
// The axis identifies the logical array "dimension" over which
110+
// the transform is applied. For example, axis=1 means the
111+
// transforms are applied along columns (ie, on a per-row basis).
112+
// Note: this is the same convention as held by numpy.fft.
113+
//
114+
// The axis is interpreted in the "logical" sense Eigen arrays
115+
// indexed as array(irow, icol). Ie, the dimension traversing
116+
// rows is axis 0 and the dimension traversing columns is axis 1.
117+
// Note: internal storage order of an Eigen array may differ from
118+
// the logical order and indeed that of the array template type
119+
// order. Neither is pertinent in setting the axis.
120+
complex_array_t fwd(const IDFT::pointer& dft, const complex_array_t& arr, int axis);
121+
complex_array_t inv(const IDFT::pointer& dft, const complex_array_t& arr, int axis);
122+
123+
124+
// Fixme: possible additions
125+
// - superposition of 2 reals for 2x speedup
126+
// - r2c / c2r for 1b
127+
128+
}
129+
130+
#endif

aux/inc/WireCellAux/FftwDFT.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#ifndef WIRECELLAUX_FFTWDFT
2+
#define WIRECELLAUX_FFTWDFT
3+
4+
#include "WireCellIface/IDFT.h"
5+
6+
namespace WireCell::Aux {
7+
8+
/**
9+
The FftwDFT component provides IDFT based on FFTW3.
10+
11+
All instances share a common thread-safe plan cache. There is
12+
no benefit to using more than one instance in a process.
13+
14+
See IDFT.h for important comments.
15+
*/
16+
class FftwDFT : public IDFT {
17+
public:
18+
19+
FftwDFT();
20+
virtual ~FftwDFT();
21+
22+
// 1d
23+
24+
virtual
25+
void fwd1d(const complex_t* in, complex_t* out,
26+
int size) const;
27+
28+
virtual
29+
void inv1d(const complex_t* in, complex_t* out,
30+
int size) const;
31+
32+
virtual
33+
void fwd1b(const complex_t* in, complex_t* out,
34+
int nrows, int ncols, int axis) const;
35+
36+
virtual
37+
void inv1b(const complex_t* in, complex_t* out,
38+
int nrows, int ncols, int axis) const;
39+
40+
virtual
41+
void fwd2d(const complex_t* in, complex_t* out,
42+
int nrows, int ncols) const;
43+
virtual
44+
void inv2d(const complex_t* in, complex_t* out,
45+
int nrows, int ncols) const;
46+
47+
virtual
48+
void transpose(const scalar_t* in, scalar_t* out,
49+
int nrows, int ncols) const;
50+
virtual
51+
void transpose(const complex_t* in, complex_t* out,
52+
int nrows, int ncols) const;
53+
54+
};
55+
}
56+
57+
#endif

aux/inc/WireCellAux/Semaphore.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/** Implement a semaphore component interace. */
2+
3+
#ifndef WIRECELLAUX_SEMAPHORE
4+
#define WIRECELLAUX_SEMAPHORE
5+
6+
#include "WireCellIface/IConfigurable.h"
7+
#include "WireCellIface/ISemaphore.h"
8+
#include "WireCellUtil/Semaphore.h"
9+
10+
11+
namespace WireCell::Aux {
12+
class Semaphore : public ISemaphore,
13+
public IConfigurable
14+
{
15+
public:
16+
Semaphore();
17+
virtual ~Semaphore();
18+
19+
// IConfigurable interface
20+
virtual void configure(const WireCell::Configuration& config);
21+
virtual WireCell::Configuration default_configuration() const;
22+
23+
// ISemaphore
24+
virtual void acquire() const;
25+
virtual void release() const;
26+
27+
private:
28+
29+
mutable FastSemaphore m_sem;
30+
31+
};
32+
} // namespace WireCell::Pytorch
33+
34+
#endif // WIRECELLPYTORCH_TORCHSERVICE

aux/inc/WireCellAux/SimpleTensor.h

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
#define WIRECELL_AUX_SIMPLETENSOR
33

44
#include "WireCellIface/ITensor.h"
5+
56
#include <boost/multi_array.hpp>
7+
#include <cstring>
68

79
namespace WireCell {
810

@@ -13,14 +15,26 @@ namespace WireCell {
1315
public:
1416
typedef ElementType element_t;
1517

16-
SimpleTensor(const shape_t& shape)
18+
// Create simple tensor, allocating space for data. If
19+
// data given it must have at least as many elements as
20+
// implied by shape and that span will be copied into
21+
// allocated memory.
22+
SimpleTensor(const shape_t& shape,
23+
const element_t* data=nullptr,
24+
const Configuration& md = Configuration())
1725
{
1826
size_t nbytes = element_size();
19-
for (const auto& s : shape) {
27+
m_shape = shape;
28+
for (const auto& s : m_shape) {
2029
nbytes *= s;
2130
}
22-
m_store.resize(nbytes);
23-
m_shape = shape;
31+
if (data) {
32+
const std::byte* bytes = reinterpret_cast<const std::byte*>(data);
33+
m_store.assign(bytes, bytes+nbytes);
34+
}
35+
else {
36+
m_store.resize(nbytes);
37+
}
2438
}
2539
virtual ~SimpleTensor() {}
2640

aux/inc/WireCellAux/TensorTools.h

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#ifndef WIRECELL_AUX_TENSORTOOLS
2+
#define WIRECELL_AUX_TENSORTOOLS
3+
4+
#include "WireCellIface/ITensor.h"
5+
#include "WireCellIface/IDFT.h"
6+
#include "WireCellUtil/Exceptions.h"
7+
8+
#include <Eigen/Core>
9+
#include <complex>
10+
11+
namespace WireCell::Aux {
12+
13+
bool is_row_major(const ITensor::pointer& ten) {
14+
if (ten->order().empty() or ten->order()[0] == 1) {
15+
return true;
16+
}
17+
return false;
18+
}
19+
20+
template<typename scalar_t>
21+
bool is_type(const ITensor::pointer& ten) {
22+
return (ten->element_type() == typeid(scalar_t));
23+
}
24+
25+
26+
// Extract the underlying data array from the tensor as a vector.
27+
// Caution: this ignores storage order hints and 1D or 2D will be
28+
// flattened assuming C-ordering, aka row-major (if 2D). It
29+
// throws ValueError on type mismatch.
30+
template<typename element_type>
31+
std::vector<element_type> asvec(const ITensor::pointer& ten)
32+
{
33+
if (ten->element_type() != typeid(element_type)) {
34+
THROW(ValueError() << errmsg{"element type mismatch"});
35+
}
36+
const element_type* data = (const element_type*)ten->data();
37+
const size_t nelems = ten->size()/sizeof(element_type);
38+
return std::vector<element_type>(data, data+nelems);
39+
}
40+
41+
// Extract the tensor data as an Eigen array.
42+
template<typename element_type>
43+
Eigen::Array<element_type, Eigen::Dynamic, Eigen::Dynamic> // this default is column-wise
44+
asarray(const ITensor::pointer& tens)
45+
{
46+
if (tens->element_type() != typeid(element_type)) {
47+
THROW(ValueError() << errmsg{"element type mismatch"});
48+
}
49+
using ROWM = Eigen::Array<element_type, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
50+
using COLM = Eigen::Array<element_type, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>;
51+
52+
auto shape = tens->shape();
53+
int nrows, ncols;
54+
if (shape.size() == 1) {
55+
nrows = 1;
56+
ncols = shape[0];
57+
}
58+
else {
59+
nrows = shape[0];
60+
ncols = shape[1];
61+
}
62+
63+
// Eigen::Map is a non-const view of data but a copy happens
64+
// on return. We need to temporarily break const correctness.
65+
const element_type* cdata = reinterpret_cast<const element_type*>(tens->data());
66+
element_type* mdata = const_cast<element_type*>(cdata);
67+
68+
if (is_row_major(tens)) {
69+
return Eigen::Map<ROWM>(mdata, nrows, ncols);
70+
}
71+
// column-major
72+
return Eigen::Map<COLM>(mdata, nrows, ncols);
73+
}
74+
75+
}
76+
77+
#endif

0 commit comments

Comments
 (0)