Question:
I'm spending some time learning how to use templates in C++. I have never used them before, and I'm not always sure what can or cannot be achieved in a given situation. As an exercise I'm wrapping some of the BLAS and LAPACK functions that I use in my work, and I'm currently wrapping ?GELS, which computes the solution of a linear system of equations:
A x + b = 0
The ?GELS function (for real values only) exists under two names: SGELS for single-precision vectors and DGELS for double precision.
My idea for the interface is a function solve, used this way:
const std::size_t rows = /* number of rows for A */;
const std::size_t cols = /* number of cols for A */;
std::array< double, rows * cols > A = { /* values */ };
std::array< double, ??? > b = { /* values */ }; // ??? it can be either
// rows or cols. It depends on user
// problem, in general
// max( dim(x), dim(b) ) =
// max( cols, rows )
solve< double, rows, cols >(A, b);
// the solution x is stored in b, thus b
// must be "large" enough to accommodate x
Depending on the user's requirements, the problem may be overdetermined or underdetermined, which means:
- if it is overdetermined, dim(b) > dim(x) (the solution is a pseudo-inverse);
- if it is underdetermined, dim(b) < dim(x) (the solution is an LSQ minimization);
- or the normal case, in which dim(b) = dim(x) (the solution is the inverse of A, without considering singular cases).
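To make the three cases concrete, here is a minimal sketch (not part of the actual wrapper) that classifies a system from its dimensions at compile time. It assumes that A has rows equations and cols unknowns, so that dim(b) = rows and dim(x) = cols. The names SystemKind, classify and required_b_size are hypothetical.

```cpp
#include <cstddef>

// Hypothetical names, used only for this illustration.
enum class SystemKind { Overdetermined, Underdetermined, Square };

// A has `rows` equations and `cols` unknowns: dim(b) = rows, dim(x) = cols.
constexpr SystemKind classify(std::size_t rows, std::size_t cols) {
    return rows > cols ? SystemKind::Overdetermined
         : rows < cols ? SystemKind::Underdetermined
                       : SystemKind::Square;
}

// Smallest size of b that can hold both the right-hand side and the solution.
constexpr std::size_t required_b_size(std::size_t rows, std::size_t cols) {
    return rows > cols ? rows : cols;
}

static_assert(classify(5, 4) == SystemKind::Overdetermined, "dim(b) > dim(x)");
static_assert(classify(2, 4) == SystemKind::Underdetermined, "dim(b) < dim(x)");
static_assert(classify(4, 4) == SystemKind::Square, "dim(b) == dim(x)");
static_assert(required_b_size(2, 4) == 4, "b must hold max(rows, cols) values");
```

Because both functions are constexpr, their results can be used as template arguments or inside a static_assert.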
Since ?GELS stores the result in the input vector b, the std::array should have enough space to accommodate the solution, as described in the code comments (max(rows, cols)).
I want to determine at compile time which kind of solution to adopt (it is a parameter change in the ?GELS call). I have two functions (I'm simplifying for the sake of the question) that handle the precision and already know the dimension of b and the number of rows/cols:
namespace wrap {
template <std::size_t rows, std::size_t cols, std::size_t dimb>
void solve(std::array<float, rows * cols> & A, std::array<float, dimb> & b) {
SGELS(/* Called in the right way */);
}
template <std::size_t rows, std::size_t cols, std::size_t dimb>
void solve(std::array<double, rows * cols> & A, std::array<double, dimb> & b) {
DGELS(/* Called in the right way */);
}
}; /* namespace wrap */
These are part of an internal wrapper. The user-facing function determines the size required for the b vector through templates:
#include <array>
#include <cstddef>
#include <type_traits>
/** This struct makes the max between rows and cols */
template < std::size_t rows, std::size_t cols >
struct biggest_dim {
static std::size_t const value = std::conditional< rows >= cols, std::integral_constant< std::size_t, rows >,
std::integral_constant< std::size_t, cols > >::type::value;
};
/** A type for the b array is selected using "biggest_dim" */
template < typename REAL_T, std::size_t rows, std::size_t cols >
using b_array_t = std::array< REAL_T, biggest_dim< rows, cols >::value >;
/** Here we have the function that allows only the call with b of
* the correct size to continue */
template < typename REAL_T, std::size_t rows, std::size_t cols >
void solve(std::array< REAL_T, cols * rows > & A, b_array_t< REAL_T, cols, rows > & b) {
static_assert(std::is_floating_point< REAL_T >::value, "Only float/double accepted");
wrap::solve< rows, cols, biggest_dim< rows, cols >::value >(A, b);
}
This works. But I want to go one step further, and I really don't have a clue how to do it.
If the user tries to call solve with a b that is too small, the compiler raises an error that is extremely difficult to read.
I'm trying to insert a static_assert that helps the user understand the mistake. But every approach that comes to mind requires two functions with the same signature (something like overloading on templates?), for which I cannot find a SFINAE strategy (and they do not compile at all).
Do you think it is possible to raise a static assertion at compile time when b has the wrong dimension, without changing the user interface?
I hope the question is clear enough.
@Caninonos: For me the user interface is how the user calls the solver, that is:
solve< type, number of rows, number of cols > (matrix A, vector b)
This is a constraint that I put on my exercise in order to improve my skills. That means I don't know whether a solution is actually achievable. The type of b must match the function call, and that is easy if I add another template parameter and change the user interface, which violates my constraint.
Minimal complete and working example
This is a minimal, complete and working example. As requested, I removed any reference to linear algebra concepts; it is now purely a problem about numbers. The cases are:
- N1 = 2, N2 = 2. Since N3 = max(N1, N2) = 2, everything works.
- N1 = 2, N2 = 1. Since N3 = max(N1, N2) = N1 = 2, everything works.
- N1 = 1, N2 = 2. Since N3 = max(N1, N2) = N2 = 2, everything works.
- N1 = 1, N2 = 2. Since N3 = N1 = 1 < N2, it correctly raises a compilation error. I want to intercept the compilation error with a static assertion that explains that the dimension N3 is wrong. As it stands, the error is difficult to read and understand.
You can view and test it online here
Answer 1:
First, some improvements that simplify the design a bit and help readability:
- There is no need for biggest_dim. std::max is constexpr since C++14; you should use it instead.
- There is no need for b_array_t. You can just write std::array< REAL_T, std::max(N1, N2)>.
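As a small illustration of both points, the sketch below uses std::max directly as the array size, assuming C++14 or later. The alias name rhs_array is hypothetical and only stands in for b_array_t.

```cpp
#include <algorithm>  // std::max (constexpr since C++14)
#include <array>
#include <cstddef>
#include <type_traits>

// Hypothetical alias: the array size is computed by std::max, no helper struct needed.
template <typename Real, std::size_t N1, std::size_t N2>
using rhs_array = std::array<Real, std::max(N1, N2)>;

static_assert(std::is_same<rhs_array<double, 2, 3>, std::array<double, 3>>::value,
              "max(2, 3) == 3");
static_assert(std::is_same<rhs_array<float, 4, 1>, std::array<float, 4>>::value,
              "max(4, 1) == 4");
```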
And now to your problem. One nice way in C++17 is:
template < typename REAL_T, std::size_t N1, std::size_t N2, std::size_t N3>
void solve(std::array< REAL_T, N1 * N2 > & A, std::array< REAL_T, N3> & b) {
if constexpr (N3 == std::max(N1, N2))
wrap::internal< N1, N2, N3 >(A, b);
else
static_assert(N3 == std::max(N1, N2), "invalid 3rd dimension");
// don't write static_assert(false)
// this would make the program ill-formed (*)
}
Or, as pointed out by @max66:
template < typename REAL_T, std::size_t N1, std::size_t N2, std::size_t N3>
void solve(std::array< REAL_T, N1 * N2 > & A, std::array< REAL_T, N3> & b) {
static_assert(N3 == std::max(N1, N2), "invalid 3rd dimension");
if constexpr (N3 == std::max(N1, N2))
wrap::internal< N1, N2, N3 >(A, b);
}
Tadaa!! Simple, elegant, nice error message.
The difference between the constexpr if version and a plain static_assert, i.e.:
void solve(...)
{
static_assert(...);
wrap::internal(...);
}
is that with just the static_assert the compiler will still try to instantiate wrap::internal even when the static_assert fails, polluting the error output. With the constexpr if, the call to wrap::internal is not part of the body when the condition fails, so the error output is clean.
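For reference, here is a self-contained sketch of the static_assert plus constexpr if variant, assuming C++17. The body of wrap::internal is a stand-in (an assumption, since the real SGELS/DGELS call is not shown here), and the functions are marked constexpr only so that the sketch can be checked at compile time.

```cpp
#include <algorithm>
#include <array>
#include <cstddef>

namespace wrap {
// Stand-in for the real SGELS/DGELS wrapper (assumption for this sketch):
// it only reports how many entries of b would hold the solution.
template <std::size_t N1, std::size_t N2, std::size_t N3, typename Real>
constexpr std::size_t internal(std::array<Real, N1 * N2>&, std::array<Real, N3>&) {
    return N2;
}
}  // namespace wrap

template <typename Real, std::size_t N1, std::size_t N2, std::size_t N3>
constexpr std::size_t solve(std::array<Real, N1 * N2>& A, std::array<Real, N3>& b) {
    static_assert(N3 == std::max(N1, N2), "invalid 3rd dimension");
    if constexpr (N3 == std::max(N1, N2)) {
        // Instantiated only when the sizes agree, so a failed assertion is
        // not followed by errors coming from inside wrap::internal.
        return wrap::internal<N1, N2, N3>(A, b);
    } else {
        return 0;
    }
}

// Hypothetical driver: a 2 x 3 system, so b needs max(2, 3) = 3 entries.
constexpr std::size_t solve_demo() {
    std::array<double, 2 * 3> A{};
    std::array<double, 3> b{};
    return solve<double, 2, 3>(A, b);
}

static_assert(solve_demo() == 3, "the stand-in reports N2 solution entries");
```

Changing b in solve_demo to std::array<double, 2> should leave "invalid 3rd dimension" as the only diagnostic from solve.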
(*) The reason I didn't just write static_assert(false, "error msg") is that it would make the program ill-formed, no diagnostic required. See "constexpr if and static_assert".
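A common way around this, sketched here under the assumption of C++17, is to make the asserted condition depend on a template parameter, so that it is only evaluated when the branch is actually instantiated. The names dependent_false and gels_prefix are hypothetical.

```cpp
#include <type_traits>

// Hypothetical helper: always false, but only known once T is substituted.
template <typename T>
struct dependent_false : std::false_type {};

// Hypothetical function mapping a value type to the ?GELS prefix letter.
template <typename T>
constexpr char gels_prefix() {
    if constexpr (std::is_same_v<T, float>) {
        return 'S';
    } else if constexpr (std::is_same_v<T, double>) {
        return 'D';
    } else {
        // Evaluated only if this branch is instantiated, i.e. for unsupported T.
        static_assert(dependent_false<T>::value, "only float and double are supported");
    }
}

static_assert(gels_prefix<float>() == 'S', "SGELS handles float");
static_assert(gels_prefix<double>() == 'D', "DGELS handles double");
```

Instantiating gels_prefix<int>() should then fail with the message from the last branch.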
You can also make float / double deducible, if you want, by moving that template parameter after the non-deducible ones:
template < std::size_t N1, std::size_t N2, std::size_t N3, typename REAL_T>
void solve(std::array< REAL_T, N1 * N2 > & A, std::array< REAL_T, N3> & b) {
So the call becomes:
solve< n1_3, n2_3>(A_3, b_3);
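A complete sketch of this parameter ordering, assuming C++17, could look like the following. The body is a placeholder rather than a real BLAS/LAPACK call, and deduction_demo is a hypothetical driver.

```cpp
#include <algorithm>
#include <array>
#include <cstddef>

// N1 and N2 must be spelled out by the caller; N3 and Real are deduced
// from the argument types.
template <std::size_t N1, std::size_t N2, std::size_t N3, typename Real>
constexpr bool solve(std::array<Real, N1 * N2>&, std::array<Real, N3>&) {
    static_assert(N3 == std::max(N1, N2), "invalid 3rd dimension");
    return true;  // placeholder for the call into the real wrapper
}

constexpr bool deduction_demo() {
    std::array<float, 2 * 3> A{};
    std::array<float, 3> b{};
    return solve<2, 3>(A, b);  // Real = float and N3 = 3 are deduced
}

static_assert(deduction_demo(), "only the two matrix dimensions were given explicitly");
```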
Answer 2:
Why don't you try to combine tag dispatch with some static_asserts? Below is one way of achieving what you want, I hope. All three correct cases are piped to the correct blas calls, type and dimension mismatches are handled, and the restriction to float and double is also enforced, all in a user-friendly way thanks to static_assert.
EDIT. I am not sure about your C++ version requirement, but the code below is C++11 friendly.
#include <algorithm>
#include <iostream>
#include <type_traits>
template <class value_t, int nrows, int ncols> struct Matrix {};
template <class value_t, int rows> struct Vector {};
template <class value_t> struct blas;
template <> struct blas<float> {
static void overdet(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
static void underdet(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
static void normal(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
};
template <> struct blas<double> {
static void overdet(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
static void underdet(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
static void normal(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
};
class overdet {};
class underdet {};
class normal {};
template <class T1, class T2, int nrows, int ncols, int dim>
void solve(const Matrix<T1, nrows, ncols> &lhs, Vector<T2, dim> &rhs) {
static_assert(std::is_same<T1, T2>::value,
"lhs and rhs must have the same value types");
static_assert(dim >= nrows && dim >= ncols,
"rhs does not have enough space");
static_assert(std::is_same<T1, float>::value ||
std::is_same<T1, double>::value,
"Only float or double are accepted");
solve_impl(lhs, rhs,
typename std::conditional<(nrows < ncols), underdet,
typename std::conditional<(nrows > ncols), overdet,
normal>::type>::type{});
}
template <class value_t, int nrows, int ncols, int dim>
void solve_impl(const Matrix<value_t, nrows, ncols> &lhs,
Vector<value_t, dim> &rhs, underdet) {
/* get the pointers and dimension information from lhs and rhs */
blas<value_t>::underdet(
/* trans, m, n, nrhs, A, lda, B, ldb, work, lwork, info */);
}
template <class value_t, int nrows, int ncols, int dim>
void solve_impl(const Matrix<value_t, nrows, ncols> &lhs,
Vector<value_t, dim> &rhs, overdet) {
/* get the pointers and dimension information from lhs and rhs */
blas<value_t>::overdet(
/* trans, m, n, nrhs, A, lda, B, ldb, work, lwork, info */);
}
template <class value_t, int nrows, int ncols, int dim>
void solve_impl(const Matrix<value_t, nrows, ncols> &lhs,
Vector<value_t, dim> &rhs, normal) {
/* get the pointers and dimension information from lhs and rhs */
blas<value_t>::normal(
/* trans, m, n, nrhs, A, lda, B, ldb, work, lwork, info */);
}
int main() {
/* valid types */
Matrix<float, 2, 4> A1;
Matrix<float, 4, 4> A2;
Matrix<float, 5, 4> A3;
Vector<float, 4> b1;
Vector<float, 5> b2;
solve(A1, b1);
solve(A2, b1);
solve(A3, b2);
Matrix<int, 4, 4> A4;
Vector<int, 4> b3;
// solve(A4, b3); // static_assert for float & double
Matrix<float, 4, 4> A5;
Vector<int, 4> b4;
// solve(A5, b4); // static_assert for different types
// solve(A3, b1); // static_assert for dimension problem
return 0;
}
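The tag selection performed inside solve can also be checked in isolation. The sketch below reuses the same three tag types; the alias name system_tag is hypothetical and is not part of the code above.

```cpp
#include <type_traits>

// Same empty tag types as in the code above.
class overdet {};
class underdet {};
class normal {};

// Hypothetical alias that isolates the nested std::conditional used in solve.
template <int nrows, int ncols>
using system_tag = typename std::conditional<
    (nrows < ncols), underdet,
    typename std::conditional<(nrows > ncols), overdet, normal>::type>::type;

static_assert(std::is_same<system_tag<2, 4>, underdet>::value, "fewer rows than columns");
static_assert(std::is_same<system_tag<5, 4>, overdet>::value, "more rows than columns");
static_assert(std::is_same<system_tag<4, 4>, normal>::value, "square system");
```

These are the same shapes as A1, A3 and A2 in main, so each call there should reach the matching solve_impl overload.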
Answer 3:
You have to consider why the interface offers this (convoluted) mess of parameters. The author had several things in mind. First, you can solve problems of the form A x + b == 0 and A^T x + b == 0 with one function. Second, the given A and b can actually point to memory inside matrices larger than the ones the algorithm needs. This can be seen from the LDA and LDB parameters.
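To illustrate what a leading dimension does, here is a minimal sketch of column-major addressing. It assumes the usual LAPACK convention that consecutive columns are LDA elements apart in memory; col_major_index and the constants are hypothetical names.

```cpp
#include <cstddef>

// Position of element (row, col) in column-major storage whose consecutive
// columns are `ld` elements apart (ld is the leading dimension).
constexpr std::size_t col_major_index(std::size_t row, std::size_t col, std::size_t ld) {
    return col * ld + row;
}

// A 4 x 4 block that starts at row 1, column 1 of a 5 x 5 matrix keeps
// the leading dimension of the full storage, here 5.
constexpr std::size_t lda = 5;
constexpr std::size_t block_start = col_major_index(1, 1, lda);

static_assert(block_start == 6, "one full column of 5 plus one row");
// Element (2, 3) of the block is element (3, 4) of the full matrix.
static_assert(block_start + col_major_index(2, 3, lda) == col_major_index(3, 4, lda),
              "block-relative and absolute addressing agree");
```

A routine that receives the pointer to the block start together with lda = 5 can therefore work on the sub-matrix without copying it.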
It is the subaddressing that makes things complicated. If you want a simple but perhaps useful enough API, you could choose to ignore that part:
#include <array>
#include <cstddef>
#include <type_traits>
using ::std::size_t;
using ::std::array;
template<typename T, size_t rows, size_t cols>
using matrix = array<T, rows * cols>;
enum class TransposeMode : bool {
None = false, Transposed = true
};
// See https://stackoverflow.com/questions/14637356/
template<typename T> struct always_false_t : std::false_type {};
template<typename T> constexpr bool always_false_v = always_false_t<T>::value;
template < typename T, size_t rowsA, size_t colsA, size_t rowsB, size_t colsB
, TransposeMode mode = TransposeMode::None >
void solve(matrix<T, rowsA, colsA>& A, matrix<T, rowsB, colsB>& B)
{
// Since the algorithm works in place, b needs to be able to store
// both input and output
static_assert(rowsB >= rowsA && rowsB >= colsA, "b is too small");
// LDA = rowsA, LDB = rowsB
if constexpr (::std::is_same_v<T, float>) {
// SGELS(mode == TransposeMode::None ? 'N' : 'T', ....);
} else if constexpr (::std::is_same_v<T, double>) {
// DGELS(mode == TransposeMode::None ? 'N' : 'T', ....);
} else {
static_assert(always_false_v<T>, "Unknown type");
}
}
Now, addressing the subaddressing made possible by LDA and LDB. I propose that you make that part of your data type, not directly part of the template signature. You want your own matrix type that can reference storage in a matrix. Perhaps something like this:
// Since we store elements in column-major order, we can always
// pretend that our matrix has fewer columns than it actually has.
// We cannot equally pretend that it has fewer rows than allocated,
// otherwise the addressing into the array is off.
// Thus, we need only four parameters in total:
// offset (= colSkipped * actualRows + rowSkipped), actualRows, rows, cols
// We store the offset implicitly by adjusting our begin pointer
template<typename T, size_t rows, size_t cols, size_t actualRows>
class matrix_view { // Name derived from string_view :)
static_assert(actualRows >= rows);
T* start;
matrix_view(T* start) : start(start) {}
template<typename U, size_t r, size_t c, size_t ac>
friend class matrix_view;
public:
template<typename U>
matrix_view(matrix<U, rows, cols>& ref)
: start(ref.data()) { }
template<size_t rowSkipped, size_t colSkipped, size_t newRows, size_t newCols>
auto submat() {
static_assert(colSkipped + newCols <= cols, "can only shrink");
static_assert(rowSkipped + newRows <= rows, "can only shrink");
auto newStart = start + colSkipped * actualRows + rowSkipped;
using newType = matrix_view<T, newRows, newCols, actualRows>;
return newType{ newStart };
}
T* data() {
return start;
}
};
Now you need to adapt your interface to this new data type, which basically just means introducing a few new parameters. The checks stay basically the same.
// Using this instead of just a typedef allows us to use deduction guides
// Replaces: using matrix = std::array from above
template<typename T, size_t rows, size_t cols>
class matrix {
public:
std::array<T, rows * cols> storage;
auto data() { return storage.data(); }
auto data() const { return storage.data(); }
};
extern void dgels(char TRANS
, integer M, integer N , integer NRHS
, double* A, integer LDA
, double* B, integer LDB); // Mock, missing a few parameters at the end
// Replaces the solve method from above
template < typename T, size_t rowsA, size_t colsA, size_t actualRowsA
, size_t rowsB, size_t colsB, size_t actualRowsB
, TransposeMode mode = TransposeMode::None >
void solve(matrix_view<T, rowsA, colsA, actualRowsA> A, matrix_view<T, rowsB, colsB, actualRowsB> B)
{
static_assert(rowsB >= rowsA && rowsB >= colsA, "b is too small");
char transMode = mode == TransposeMode::None ? 'N' : 'T';
// LDA = actualRowsA, LDB = actualRowsB
if constexpr (::std::is_same_v<T, float>) {
// assumes an sgels mock declared like dgels above, with float* arguments
sgels(transMode, rowsA, colsA, colsB, A.data(), actualRowsA, B.data(), actualRowsB);
} else if constexpr (::std::is_same_v<T, double>) {
dgels(transMode, rowsA, colsA, colsB, A.data(), actualRowsA, B.data(), actualRowsB);
// DGELS(, ....);
} else {
static_assert(always_false_v<T>, "Unknown type");
}
}
Example usage:
int main() {
matrix<float, 5, 5> A;
matrix<float, 4, 1> b;
auto viewA = matrix_view{A}.submat<1, 1, 4, 4>();
auto viewb = matrix_view{b};
solve(viewA, viewb);
// solve(viewA, viewb.submat<1, 0, 2, 1>()); // Error: b is too small
// solve(matrix_view{A}, viewb.submat<0, 0, 5, 1>()); // Error: can only shrink (b is 4x1 and can not be viewed as 5x1)
}