Suppose I have a base class and two classes derived from it:
class Base
{
protected:
    double value;
public:
    // Must be defined, not merely declared: destroying any Base (or a class
    // derived from it) needs the definition, and it anchors the vtable.
    virtual ~Base() = default;
    Base(double value) : value(value) {}
    Base(const Base& B) : value(B.value) {}
    Base& operator=(const Base&) = default;

    /// Adds the stored values and returns a plain Base. The argument is taken
    /// as `const Base&`, so any two derived objects can be mixed and the
    /// result is always sliced to Base -- the problem the question asks about.
    Base operator+ (const Base& B) const {
        return Base(value+B.value);
    }
};
class final Derived1 : public Base {
public:
Derived1(double value) : Base(value) {}
};
class final Derived2 : public Base {
public:
Derived2(double value) : Base(value) {}
};
I want to accomplish the following:
int main(int argc, char *argv[])
{
    // Specification of the desired behaviour: the goal is for the compiler
    // to accept the same-type additions and reject the mixed one.
    Derived1 a(4.0);
    Derived2 b(3.0);

    a + a;  // wanted: yields a Derived1
    b + b;  // wanted: yields a Derived2
    a + b;  // wanted: compile-time error (mixed derived types)

    return 0;
}
In other words, I want to guarantee that the inherited `operator+` only operates on objects of the same type as the calling instance.
How do I do this cleanly? I found myself redefining the operator for each class:
class final Derived1 : public Base {
...
Derived1 operator+ (const Derived1& D1) const {
return Derived1(value+D1.value);
}
...
};
class final Derived2 : public Base {
...
Derived2 operator+ (const Derived2& D1) const {
return Derived2(value+D1.value);
}
...
};
But that's just a pain. Moreover, it doesn't seem like proper code reuse to me.
What is the proper technique to use here?
If you can make sure `Derived1` and `Derived2` are leaf classes (i.e. no other class can derive from them), you can do this with the curiously recurring template pattern (CRTP):
template <typename T>
class BaseWithAddition : public Base {
T operator+(T const& rhs) const {
return T(value + rhs.value);
}
};
class final Derived1 : public BaseWithAddition<Derived1> {
// blah blah
};
class final Derived2 : public BaseWithAddition<Derived2> {
// blah blah
};
(`final` is a C++11 feature that prevents further derivation; it is written after the class name, as in `class Derived1 final : ...`.)
If you allow derivation from `Derived1` and `Derived2`, then you get trouble:
class Derived3 : public Derived1 {};
Derived3 d3;
Derived1 d1;
Derived1& d3_disguised = d3;
d1 + d3_disguised; // oooops, this is allowed
There's no way to prevent this at compile time. And even if you want to allow it, it's not easy to get decent semantics for this operation without multiple dispatch.
You can use a constrained function template (SFINAE via `std::enable_if`) to add values. Unfortunately, this trick does not work with operators.
It fails to compile if the types are not the same, and it returns the proper type:
#include <type_traits>
class Base;

/// Adds two objects of the *same* type Derived, which must be Base or derived
/// from it. Because both parameters deduce Derived, mixing two different types
/// fails template argument deduction, i.e. it is a compile-time error.
/// The default argument may appear on this first declaration only.
template <class Derived>
Derived add(const Derived& l, const Derived& r,
            typename std::enable_if<std::is_base_of<Base, Derived>::value>::type* = nullptr);

class Base
{
protected:
    double value;  // read directly by add(), hence the friend declaration below
public:
    Base(double value) : value(value) {}
    virtual ~Base() = default;
    // ... remaining members as in the question ...

    // Redeclaration as friend: it must NOT repeat the default argument.
    template <class Derived>
    friend Derived add(const Derived& l, const Derived& r,
                       typename std::enable_if<std::is_base_of<Base, Derived>::value>::type*);
};

template <class Derived>
Derived add(const Derived& l, const Derived& r,
            typename std::enable_if<std::is_base_of<Base, Derived>::value>::type*)
{
    // Construct explicitly rather than relying on an implicit double -> Derived conversion.
    return Derived(l.value + r.value);
}
And the proof that it works:
// Demonstrates add(): same-type calls compile and return that type; a mixed
// call is rejected by the compiler (so this program intentionally does not build).
int main() {
    Derived1 d11(0.0), d12(0.0);
    Derived2 d21(0.0), d22(0.0);
    add(d11, d12);   // Derived deduced as Derived1 -> returns a Derived1
    add(d21, d22);   // Derived deduced as Derived2 -> returns a Derived2
    add(d12, d22);   // here it fails to compile: Derived1 vs Derived2 deduction conflict
}
As long as `value` is defined only in the base class, and the operation doesn't need to access any derived members, you might be able to get away with defining only the base operator and letting implicit conversions to `Base` handle the rest. As for errors with mismatched types, it might be worth a small sacrifice to use an enum-based system to track the types, and then do a simple comparison to check for invalid combinations.
// Runtime type tag stored in every Base so operator+ can detect mixed types.
enum eTypeEnum {BASE, DER1, DER2};

/// Base class carrying a runtime type tag. operator+ rejects mixed-tag
/// additions by returning a Base whose value is the sentinel -1 (tagged BASE);
/// note that a genuine sum of -1 is indistinguishable from that error.
class Base {
public:
    virtual ~Base() = default;
    Base(double value) : eType(BASE), value(value) {}
    // Copy the tag as well as the value, otherwise a copied Derived1 would
    // carry an uninitialized tag.
    Base(const Base& B) : eType(B.eType), value(B.value) {}
    Base& operator=(const Base&) = default;

    /// Adds two objects with the same tag; the result keeps that tag so that
    /// chains of same-type additions (a+a+a) keep working.
    Base operator+ (const Base& B) const {
        if (eType != B.eType) return Base(-1.0); //error condition
        return Base(value + B.value, eType);
    }
    double getVal() const { return value; }
protected:
    // Builds a value with an explicit tag (used by operator+ and subclasses).
    Base(double value, eTypeEnum type) : eType(type), value(value) {}
    eTypeEnum eType;
    double value;
};

class Derived1 : public Base {
public:
    Derived1(double value) : Base(value, DER1) {}
};

class Derived2 : public Base {
public:
    Derived2(double value) : Base(value, DER2) {}
};
// Exercises the tag-checked operator+; the trailing comments give the values
// printed (they match the "Outputs" listing below). Any tag mismatch prints -1.
int main() {
    int tmp;
    Derived1 a(4.0);
    Derived2 b(3.0);
    Base c(2.0);
    std::cout << "aa:" << (a+a).getVal();        // 8
    std::cout << "\nbb:" << (b+b).getVal();      // 6
    std::cout << "\nba:" << (b+a).getVal();      // -1 (DER2 vs DER1)
    std::cout << "\nab:" << (a+b).getVal();      // -1 (DER1 vs DER2)
    std::cout << "\ncc:" << (c+c).getVal();      // 4
    std::cout << "\nac:" << (a+c).getVal();      // -1 (DER1 vs BASE)
    std::cout << "\nbc:" << (b+c).getVal();      // -1 (DER2 vs BASE)
    std::cout << "\nabc:" << (a+b+c).getVal();   // 1: (a+b) gives Base(-1), then -1 + 2
    std::cout << std::endl;
    std::cin >> tmp;  // presumably a pause before exit; the value read is unused
    return 0;
}
Outputs:
aa:8
bb:6
ba:-1
ab:-1
cc:4
ac:-1
bc:-1
abc:1
The only issue I see is that when chaining multiple operations together, the conversion to `Base` breaks the error handling. Here, `a+b+c` evaluates as `(a+b)+c`, so the `a+b` part hits the error condition (returning -1), but that result is converted to a plain `Base`, which lets `(-1)+c` return `1`.