How to enforce template parameter class to derive

Posted 2019-03-03 13:37

Question:

I have a couple of template classes:

#include <type_traits>

template < class Cost >
class Transition {
  public:
    virtual Cost getCost() = 0;
};

template < class TransitionCl, class Cost >
class State {
    protected:
        State(){
            static_assert(
                std::is_base_of< Transition< Cost >, TransitionCl >::value,
                "TransitionCl class in State must be derived from Transition< Cost >"
            );
        }
    public:
        virtual void apply( const TransitionCl& ) = 0;
};

I would rather not have to pass Cost into State, because State is completely independent of Cost, but I want to ensure that TransitionCl implements the Transition interface.

Is there a way to make Cost anonymous in the second template so that I don't have to pass it when declaring a new State class?

For reference, I'm using g++ -std=c++14 [source file]

EDIT: I posted a rephrased (and hopefully clearer) version of the question and received the best answer here

Answer 1:

From what I understand of your question, here is a quick solution:

#include <type_traits>

template < class Cost >
class Transition {
  public:
    virtual Cost getCost() = 0;
};

template <typename T>
class Der : public Transition<T>
{
public:
  T getCost() override {
    return T{};  // a non-void function must return a value
  }
};

// Primary template: declared only, so State<X> is usable only when X
// matches the partial specialization below.
template < class TransitionType >
class State;

// Partial specialization: Cost is deduced from the argument, e.g. Der<int>.
template <template <typename> class TransitionCl, typename Cost>
class State <TransitionCl<Cost>> {
public:
        State(){
            static_assert(
                std::is_base_of< Transition< Cost >, TransitionCl<Cost> >::value,
                "TransitionCl class in State must be derived from Transition< Cost >"
            );
        }
};

int main()
{
  Der<int> d;
  State<decltype(d)> s;

  return 0;
}

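In the above example, you don't have to pass the Cost type when creating a State object: the partial specialization deduces it from the template argument (here Der<int>). Note that this only works when the transition class is itself an instantiation of a class template with a single type parameter.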
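The deduction step this answer relies on can be isolated in a small trait. The sketch below is illustrative only: `cost_of` is a hypothetical name, not part of the original answer. It shows how a partial specialization on `TT<C>` lets the compiler recover `Cost` from a type such as `Der<int>`.

```cpp
#include <type_traits>

template <class Cost>
class Transition {
public:
    virtual Cost getCost() = 0;
    virtual ~Transition() {}
};

template <class T>
class Der : public Transition<T> {
public:
    T getCost() override { return T{}; }
};

// Hypothetical helper. The primary template is declared but never defined,
// so only types of the form SomeTemplate<X> get a cost_of<...>::type.
template <class T>
struct cost_of;

// Partial specialization: for Der<int>, the compiler deduces TT = Der, C = int.
template <template <class> class TT, class C>
struct cost_of<TT<C>> {
    using type = C;
};

static_assert(std::is_same<cost_of<Der<int>>::type, int>::value,
              "Cost deduced as int");
static_assert(std::is_same<cost_of<Der<double>>::type, double>::value,
              "Cost deduced as double");
```

Because the primary template is only declared, `cost_of<TD>::type` for a non-template class such as `class TD : public Transition<int>` fails to compile with an incomplete-type error. The same restriction applies to `State` above, which is what the update below addresses.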

===== UPDATE ======

#include <iostream>
#include <type_traits>
#include <utility>

template <typename Cost>
class Transition
{
public:
    virtual Cost getCost() = 0;
    virtual ~Transition() {}
};

class TD: public Transition<int>
{
public:
    int getCost() override {
        std::cout << "getCost override" << std::endl;
        return 42;
    }
};

namespace detail {
    template <typename T>
    struct is_base_of_cust {
        // This is a bit hacky: it recovers Cost from the return type of
        // getCost(), a public member function of the Transition class.
        using CostType = decltype(std::declval<T>().getCost());
        static const bool value = std::is_base_of<Transition<CostType>, T>::value;
    };
}

template <class TransitionCl>
class State
{
protected:
    State() {
        static_assert(
            detail::is_base_of_cust<TransitionCl>::value,
            "TransitionCl class in State must be derived from Transition<Cost>"
        );
    }
public:
    virtual void apply(const TransitionCl&) = 0;
    virtual ~State() {}
};


class StateImpl: public State<TD>
{
public:
    void apply(const TD&) override {
        std::cout << "StateImpl::apply" << std::endl;
    }
};


int main() {
    StateImpl impl;
    return 0;
}
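The `is_base_of_cust` trait above produces a hard compile error, not `false`, when `T` has no accessible `getCost()`. A possible refinement, sketched here assuming you are still limited to C++14 (so `std::void_t` from C++17 is unavailable and a local one is defined), makes the check SFINAE-friendly. The names `make_void`, `void_t`, `derives_from_transition` and `Impostor` are hypothetical.

```cpp
#include <type_traits>
#include <utility>

template <typename Cost>
class Transition {
public:
    virtual Cost getCost() = 0;
    virtual ~Transition() {}
};

// Local void_t for C++14. The struct indirection is the usual workaround
// for CWG 1558 (unused alias-template parameters may not trigger SFINAE).
template <typename... Ts>
struct make_void { using type = void; };
template <typename... Ts>
using void_t = typename make_void<Ts...>::type;

// Primary template: chosen when T has no usable getCost().
template <typename T, typename = void>
struct derives_from_transition : std::false_type {};

// Chosen only when getCost() is well-formed on T. declval<T&> also works
// for abstract or non-default-constructible T.
template <typename T>
struct derives_from_transition<T, void_t<decltype(std::declval<T&>().getCost())>>
    : std::is_base_of<Transition<decltype(std::declval<T&>().getCost())>, T> {};

class TD : public Transition<int> {
public:
    int getCost() override { return 42; }
};

// Has getCost() but does not derive from Transition.
struct Impostor {
    int getCost() { return 0; }
};

static_assert(derives_from_transition<TD>::value, "TD derives from Transition<int>");
static_assert(!derives_from_transition<Impostor>::value, "Impostor is not a Transition");
static_assert(!derives_from_transition<int>::value, "int has no getCost()");
```

If State's `static_assert` used `derives_from_transition<TransitionCl>::value` instead, a type without `getCost()` should trigger the custom message rather than an error from inside the trait.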


Answer 2:

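One way is to use the return type of getCost() (but it may give you an uglier error message if TransitionCl doesn't have such a public member function). Writing std::declval<TransitionCl&>() (from <utility>) rather than TransitionCl() also avoids requiring TransitionCl to be default-constructible:

std::is_base_of< Transition< decltype(std::declval<TransitionCl&>().getCost()) >, TransitionCl >::value,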
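A self-contained sketch of the getCost()-based check, placed in State's constructor as in the question. It uses `std::declval<TransitionCl&>()` rather than `TransitionCl()`, so a transition class without a default constructor is also accepted. `WeightedTransition` and `MyState` are hypothetical example classes.

```cpp
#include <type_traits>
#include <utility>

template <class Cost>
class Transition {
public:
    virtual Cost getCost() = 0;
    virtual ~Transition() {}
};

template <class TransitionCl>
class State {
protected:
    State() {
        // Recover Cost from getCost()'s return type; declval is only used
        // in an unevaluated context, so no TransitionCl object is created.
        static_assert(
            std::is_base_of<
                Transition<decltype(std::declval<TransitionCl&>().getCost())>,
                TransitionCl>::value,
            "TransitionCl class in State must be derived from Transition<Cost>");
    }
public:
    virtual void apply(const TransitionCl&) = 0;
    virtual ~State() {}
};

// Hypothetical transition with no default constructor.
class WeightedTransition : public Transition<double> {
public:
    explicit WeightedTransition(double w) : weight_(w) {}
    double getCost() override { return weight_; }
private:
    double weight_;
};

// Hypothetical concrete state; no Cost parameter is needed.
class MyState : public State<WeightedTransition> {
public:
    void apply(const WeightedTransition&) override {}
};
```

Constructing a `MyState` instantiates `State<WeightedTransition>::State()`, which is where the `static_assert` is checked.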

Another option is adding a typedef to the base class:

template < class Cost >
class Transition {
  public:
    typedef Cost Cost_Type;    // <-------
    virtual Cost getCost() = 0;
};

Then you can remove State's typename Cost parameter and use the typedef in your static_assert instead:

std::is_base_of< Transition< typename TransitionCl::Cost_Type >, TransitionCl >::value,
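A minimal, self-contained sketch of the typedef approach, assuming the `Cost_Type` typedef is added to `Transition` as shown. `TD` and `StateImpl` follow the names used in Answer 1's update and are only illustrative.

```cpp
#include <type_traits>

template <class Cost>
class Transition {
public:
    typedef Cost Cost_Type;  // inherited by every class derived from Transition<Cost>
    virtual Cost getCost() = 0;
    virtual ~Transition() {}
};

template <class TransitionCl>
class State {
protected:
    State() {
        static_assert(
            std::is_base_of<Transition<typename TransitionCl::Cost_Type>,
                            TransitionCl>::value,
            "TransitionCl class in State must be derived from Transition<Cost>");
    }
public:
    virtual void apply(const TransitionCl&) = 0;
    virtual ~State() {}
};

class TD : public Transition<int> {
public:
    int getCost() override { return 42; }
};

class StateImpl : public State<TD> {
public:
    void apply(const TD&) override {}
};
```

One caveat: a derived class can declare its own `Cost_Type` that hides the inherited one, which would make the check fail, and a type with no `Cost_Type` at all still produces a hard error instead of the `static_assert` message.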