C++ Dynamically detect class of parameter and cast

Posted 2019-07-24 06:36

Question:

I have two classes, one which inherits from the other. The relevant part of the base class is as follows (obviously this class has ctors, a dtor, etc., and particularly an operator[], but I thought those irrelevant to the matter at hand):

#include <array>
#include <ostream>

template < class T, unsigned int N >
class Vector
{
public:
    template < class U, unsigned int M > friend Vector< U, M > operator+ ( const Vector< U, M >&, const Vector< U, M >& );

    template < class U, unsigned int M > friend std::ostream& operator<< ( std::ostream&, const Vector< U, M >& );
};

The derived class (again, obviously I've taken out those parts which I thought irrelevant):

#include "Vector.h"

template < class T, unsigned int N >
class Polynomial
    : public Vector< T, N >
{
public:
    template < class U, unsigned int M > friend std::ostream& operator<< ( std::ostream&, const Polynomial< U, M >& );
};

(Note: The friend functions use different template parameter names (U, M) than the classes (T, N), because otherwise gcc complains about "shadowing". The logic is the same, though.)

Vectors print out one way (e.g. < 3, 5, 1 >); Polynomials print out another (e.g. 3 x^2 + 5 x + 1).

This causes a problem, though. When I go to add two Polynomials together, the compiler uses template < class U, unsigned int M > Vector< U, M > operator+ ( const Vector< U, M >&, const Vector< U, M >& ), which of course returns a Vector. Therefore, if I try to do something like std::cout << poly1 + poly2;, the resultant display is in the wrong format.
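A minimal reproduction of the problem, assuming the elided storage is a public std::array member named data (the real class may hide it behind operator[]), that a Polynomial stores its highest-power coefficient first, and a hypothetical to_text helper that captures what operator<< prints. Template argument deduction looks through the Polynomial-to-Vector base class, so Vector's operator+ is selected and its Vector result is printed in vector format:

```cpp
#include <array>
#include <ostream>
#include <sstream>
#include <string>

template <class T, unsigned int N>
class Vector
{
public:
    std::array<T, N> data{};  // assumed storage; the question elides it

    template <class U, unsigned int M>
    friend Vector<U, M> operator+(const Vector<U, M>&, const Vector<U, M>&);
};

template <class T, unsigned int N>
class Polynomial : public Vector<T, N> {};

// Deduction of U and M looks through the Polynomial -> Vector base class,
// so this template is chosen for Polynomial + Polynomial and returns a Vector.
template <class U, unsigned int M>
Vector<U, M> operator+(const Vector<U, M>& lhs, const Vector<U, M>& rhs)
{
    Vector<U, M> result;
    for (unsigned int i = 0; i < M; ++i)
        result.data[i] = lhs.data[i] + rhs.data[i];
    return result;
}

// Vector format, e.g. "< 3, 5, 1 >".
template <class U, unsigned int M>
std::ostream& operator<<(std::ostream& os, const Vector<U, M>& v)
{
    os << "< ";
    for (unsigned int i = 0; i < M; ++i)
        os << v.data[i] << (i + 1 < M ? ", " : " ");
    return os << ">";
}

// Polynomial format, e.g. "3 x^2 + 5 x + 1" (data[0] assumed to be the highest power).
template <class U, unsigned int M>
std::ostream& operator<<(std::ostream& os, const Polynomial<U, M>& p)
{
    for (unsigned int i = 0; i < M; ++i)
    {
        const unsigned int power = M - 1 - i;
        os << p.data[i];
        if (power > 1) os << " x^" << power;
        else if (power == 1) os << " x";
        if (i + 1 < M) os << " + ";
    }
    return os;
}

// Hypothetical helper: captures what operator<< prints.
template <class X>
std::string to_text(const X& x)
{
    std::ostringstream os;
    os << x;
    return os.str();
}
```

With p1 = {3, 5, 1} and p2 = {1, 1, 1}, to_text(p1) yields "3 x^2 + 5 x + 1", but to_text(p1 + p2) yields "< 4, 6, 2 >", which is exactly the wrong-format symptom described above.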

I would like to modify template < class U, unsigned int M > Vector< U, M > operator+ ( const Vector< U, M >&, const Vector< U, M >& ) such that it will detect the actual data types of its parameters, and cast the return value accordingly (e.g. return a Polynomial if two Polynomials are passed to it). I would like to do this, if possible, without operator+ knowing about each and every possible subclass of Vector (I think this is probably a legitimate desire?), and without making a new operator+ function for each subclass (since I also have several other overloaded operators, and would like to avoid copying almost exactly the same code ten times for each derived class).

I know that this is possible (and, in fact, relatively easy) in Python. Does C++ support such a thing?

Answer 1:

If you calculate the result as a Vector<T,N>, you cannot simply (legally) cast it to Polynomial<T,N>: the object really is a Vector<T,N>, so downcasting it (e.g. with static_cast<Polynomial<T,N>&>) would be undefined behavior. To achieve the desired effect, you need some deeper modifications: a way to detect everything derived from Vector<T,N>, an implementation that can deliver the desired result type, and a free operator+ that ties the two together. Let's build it.

a) Detect every class derived from Vector<T,N>

For that, let Vector derive from an empty tag class. Thanks to the empty base optimization (EBO), it typically adds no storage, and std::is_base_of can detect it, which in turn can drive std::enable_if:

struct VectorBase {};

template< class T, unsigned int N >
class Vector : public VectorBase
{
  // ...
};

Now you can check whether any class U is derived from some Vector< T, N > with std::is_base_of< VectorBase, U >::value. To be absolutely correct, you need to exclude VectorBase itself (!std::is_same< U, VectorBase >::value), because std::is_base_of< X, X >::value is true for any class X; but that is probably not needed for your use case.

b) An implementation that delivers the desired return type. Before we get to that, note that this friend declaration:

template< class T, unsigned int N >
class Vector
{
    template < class U, unsigned int M >
    friend Vector< U, M > operator+ ( const Vector< U, M >&, const Vector< U, M >& );
};

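should be replaced by a non-template friend that is defined inside the class (a "hidden friend", found only through argument-dependent lookup; a non-template friend that is merely declared would have to be defined separately for every T and N):

template< class T, unsigned int N >
class Vector
{
    friend Vector< T, N > operator+ ( const Vector< T, N >& lhs, const Vector< T, N >& rhs )
    {
        Vector< T, N > result;
        // element-wise addition goes here...
        return result;
    }
};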
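A self-contained sketch of such an in-class (hidden) friend, with element-wise addition filled in over an assumed public std::array member named data. It also shows the limitation that motivates the next step: because argument-dependent lookup also searches the base class Vector<T,N>, Polynomial + Polynomial compiles, but its result type is still Vector<T,N>:

```cpp
#include <array>
#include <type_traits>

template <class T, unsigned int N>
class Vector
{
public:
    std::array<T, N> data{};  // assumed storage

    // Hidden friend: a non-template function generated for each Vector<T, N>,
    // found only through argument-dependent lookup.
    friend Vector operator+(const Vector& lhs, const Vector& rhs)
    {
        Vector result;
        for (unsigned int i = 0; i < N; ++i)
            result.data[i] = lhs.data[i] + rhs.data[i];
        return result;
    }
};

template <class T, unsigned int N>
class Polynomial : public Vector<T, N> {};

// Polynomial + Polynomial finds the friend via the base class,
// but the only candidate returns Vector<T, N>.
static_assert(std::is_same<decltype(Polynomial<int, 3>{} + Polynomial<int, 3>{}),
                           Vector<int, 3>>::value,
              "the hidden friend returns the base type");
```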

for the general case. That operator still returns a Vector<T,N>, though, so you also need an implementation whose return type can be chosen by the caller (e.g. Polynomial<T,N>):

template< class T, unsigned int N >
class Vector
{
public:
    template< typename R >
    static R add( const Vector< T, N >& lhs, const Vector< T, N >& rhs )
    {
        static_assert( std::is_base_of< Vector< T, N >, R >::value,
                       "R needs to be derived from Vector<T,N>" );
        R result; // R must be default-constructible
        // implement it here (e.g. element-wise lhs + rhs into result)...
        return result;
    }
};
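A runnable sketch of add<R>, again assuming a public std::array member named data. The static_assert checks std::is_base_of< Vector<T,N>, R >, which guarantees that R has the same element type and size as the operands, so writing into result.data is safe:

```cpp
#include <array>
#include <type_traits>

struct VectorBase {};

template <class T, unsigned int N>
class Vector : public VectorBase
{
public:
    std::array<T, N> data{};  // assumed storage

    // Builds the sum directly in an object of the requested type R.
    // R must derive from Vector<T, N> and be default-constructible.
    template <typename R>
    static R add(const Vector<T, N>& lhs, const Vector<T, N>& rhs)
    {
        static_assert(std::is_base_of<Vector<T, N>, R>::value,
                      "R needs to be derived from Vector<T,N>");
        R result;
        for (unsigned int i = 0; i < N; ++i)
            result.data[i] = lhs.data[i] + rhs.data[i];
        return result;
    }
};

template <class T, unsigned int N>
class Polynomial : public Vector<T, N> {};
```

Called directly, add lets the caller pick the result type: Polynomial<int, 3>::add<Polynomial<int, 3>>(a, b) returns a Polynomial<int, 3>, because Polynomial inherits the static member template from Vector<int, 3>.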

c) Provide an operator+ that calls add and is protected by SFINAE, i.e. std::enable_if removes it from overload resolution unless V is derived from VectorBase:

// as a free function:
template< typename V >
typename std::enable_if< std::is_base_of< VectorBase, V >::value, V >::type
operator+( const V& lhs, const V& rhs )
{
  return V::template add<V>( lhs, rhs );
}

Minus some small typos (I haven't tested it), this strategy should work for you.
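Putting a), b) and c) together, here is a self-contained sketch that should compile as written. As before, the public std::array member data, the coefficient order used for printing, and the to_text helper are assumptions made for illustration. With it, Polynomial + Polynomial returns a Polynomial, so streaming the sum uses the polynomial format:

```cpp
#include <array>
#include <ostream>
#include <sstream>
#include <string>
#include <type_traits>

struct VectorBase {};

template <class T, unsigned int N>
class Vector : public VectorBase
{
public:
    std::array<T, N> data{};  // assumed storage

    template <typename R>
    static R add(const Vector& lhs, const Vector& rhs)
    {
        static_assert(std::is_base_of<Vector, R>::value,
                      "R needs to be derived from Vector<T,N>");
        R result;
        for (unsigned int i = 0; i < N; ++i)
            result.data[i] = lhs.data[i] + rhs.data[i];
        return result;
    }
};

template <class T, unsigned int N>
class Polynomial : public Vector<T, N> {};

// Free operator+, enabled only for types derived from VectorBase.
// The result type is the argument type V, so Polynomial + Polynomial
// yields a Polynomial and Vector + Vector yields a Vector.
template <typename V>
typename std::enable_if<std::is_base_of<VectorBase, V>::value, V>::type
operator+(const V& lhs, const V& rhs)
{
    return V::template add<V>(lhs, rhs);
}

// Vector format, e.g. "< 3, 5, 1 >".
template <class T, unsigned int N>
std::ostream& operator<<(std::ostream& os, const Vector<T, N>& v)
{
    os << "< ";
    for (unsigned int i = 0; i < N; ++i)
        os << v.data[i] << (i + 1 < N ? ", " : " ");
    return os << ">";
}

// Polynomial format; data[0] is assumed to be the coefficient of x^(N-1).
template <class T, unsigned int N>
std::ostream& operator<<(std::ostream& os, const Polynomial<T, N>& p)
{
    for (unsigned int i = 0; i < N; ++i)
    {
        const unsigned int power = N - 1 - i;
        os << p.data[i];
        if (power > 1) os << " x^" << power;
        else if (power == 1) os << " x";
        if (i + 1 < N) os << " + ";
    }
    return os;
}

// Hypothetical helper for checking output without std::cout.
template <class X>
std::string to_text(const X& x)
{
    std::ostringstream os;
    os << x;
    return os.str();
}
```

For example, summing p1 = {3, 5, 1} and p2 = {1, 1, 1} prints as 4 x^2 + 6 x + 2, while the same data held in two Vector<int, 3> objects prints as < 4, 6, 2 >. Mixing types (Vector + Polynomial) finds no viable operator+ in this sketch, because V cannot be deduced as two different types.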