C++: select argmax over vector of classes w.r.t. a

Posted 2019-07-08 09:07

Question:

I have trouble describing my problem, so I'll give an example:

I have a class with a few member variables, for example:

class A {
public:
  float a, b, c, d;
};

Now, I maintain a vector<A> that contains many of these objects. What I need to do very, very often is find the object in this vector for which one of its members is maximal compared with the others, i.e. code that looks something like:

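std::size_t maxi = 0;           // index of the best object so far
float maxa = vec[0].a;          // assumes vec is non-empty
for (std::size_t i = 1; i < vec.size(); i++) {
  float res = vec[i].a;         // the only line that changes between uses
  if (res > maxa) {
    maxa = res;
    maxi = i;
  }
}
return vec[maxi];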

However, sometimes I need the object with maximal a, sometimes the one with maximal b, sometimes the one with maximal 0.8*a + 0.2*b, and sometimes I want a maximal a*VAR + b, where VAR is some variable assigned beforehand, etc. In other words, I need to evaluate an expression for every object and take the max. I find myself copy-pasting this loop everywhere and changing only the single line that defines res.

Is there some nice way to avoid this insanity in C++? What's the neatest way to handle this?

Thank you!

Answer 1:

// Turns a unary function f into a "less than" comparator: x < y iff f(x) < f(y).
template <typename F>
struct CompareBy
{
    bool operator()(const typename F::argument_type& x,
                    const typename F::argument_type& y)
    { return f(x) < f(y); }

    CompareBy(const F& f) : f(f) {}

 private:
    F f;
};


// Adapts a pointer to member (e.g. &A::a) into a unary function object.
template <typename T, typename U>
struct Member : std::unary_function<U, T>
{
    Member(T U::*ptr) : ptr(ptr) {}
    const T& operator()(const U& x) const { return x.*ptr; }

private:
    T U::*ptr;
};

// Helper functions so the template arguments are deduced automatically.
template <typename F>
CompareBy<F> by(const F& f) { return CompareBy<F>(f); }

template <typename T, typename U>
Member<T, U> mem_ptr(T U::*ptr) { return Member<T, U>(ptr); }

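You need to include <functional> for this to work (std::unary_function and std::ptr_fun were deprecated in C++11 and removed in C++17, so compile in an earlier standard mode). Then use std::max_element from <algorithm>: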

std::max_element(v.begin(), v.end(), by(mem_ptr(&A::a)));

or

double combination(A x) { return 0.2 * x.a + 0.8 * x.b; }

and

std::max_element(v.begin(), v.end(), by(std::ptr_fun(combination)));

or even

struct combination : std::unary_function<A, double>
{
    combination(double x, double y) : x(x), y(y) {}
    // Weighted sum x*a + y*b; const because evaluating it changes nothing.
    double operator()(const A& u) const { return x * u.a + y * u.b; }

private:
    double x, y;
};

with

std::max_element(v.begin(), v.end(), by(combination(0.2, 0.8)));

This lets you compare by a single member or by a linear combination of the a and b members. I split the comparator into two parts because mem_ptr is useful on its own and worth reusing. std::max_element returns an iterator to the maximum element: dereference it to get the element itself, or use std::distance(v.begin(), i) (from <iterator>) to get its index.
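A minimal sketch of the iterator-to-index step, assuming C++11 or later; it uses a lambda instead of by/mem_ptr, and the helper name index_of_max_a is hypothetical. Note that std::max_element returns end() for an empty range, so the result must be checked before it is dereferenced.

```cpp
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <vector>

struct A { float a, b, c, d; };

// Hypothetical helper: index of the element with the largest 'a',
// or -1 if the vector is empty.
std::ptrdiff_t index_of_max_a(const std::vector<A>& v) {
    std::vector<A>::const_iterator it = std::max_element(
        v.begin(), v.end(),
        [](const A& x, const A& y) { return x.a < y.a; });
    if (it == v.end()) return -1;          // empty range: nothing to dereference
    return std::distance(v.begin(), it);   // iterator -> index
}
```

On ties, std::max_element returns the first of the equal maxima, which matches the strict `>` in the question's loop.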

See http://codepad.org/XQTx0vql for the complete code.



Answer 2:

I know this thread is old, but I find it quite useful to implement a proper argmax function in C++.

However, as far as I can see, all the examples above rely on std::max_element, which compares pairs of elements (either through a functor or by calling operator<), so a comparator built from the expression evaluates it twice per comparison. That can be slow if each evaluation is expensive. It works well for numbers and simple classes, but what if the functor is much more complex, say computing the heuristic value of a chess position, or anything else that generates a huge search tree?

A real argmax, like the loop in the question, computes each element's value only once and keeps the best value seen so far for comparison with the rest.
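To make the difference concrete, here is a small sketch (C++11; the names g_evaluations, key, max_by_comparator and max_by_key are all hypothetical) that counts how often a key function runs. For n elements, std::max_element performs exactly n - 1 comparisons, and a comparator written as key(l) < key(r) evaluates the key twice in each, so 2(n - 1) evaluations in total; the single-pass loop evaluates it n times.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical counter so we can see how often the "expensive" key runs.
static int g_evaluations = 0;

double key(int x) {              // stand-in for an expensive heuristic
    ++g_evaluations;
    return -(x - 7) * (x - 7);   // maximal at x == 7
}

// Comparator-based: key() is called twice per comparison.
std::vector<int>::const_iterator max_by_comparator(const std::vector<int>& v) {
    return std::max_element(v.begin(), v.end(),
                            [](int l, int r) { return key(l) < key(r); });
}

// Single pass: key() is called exactly once per element.
std::vector<int>::const_iterator max_by_key(const std::vector<int>& v) {
    std::vector<int>::const_iterator best = v.begin();
    if (best == v.end()) return best;
    double best_value = key(*best);
    for (std::vector<int>::const_iterator it = best + 1; it != v.end(); ++it) {
        double value = key(*it);
        if (value > best_value) { best_value = value; best = it; }
    }
    return best;
}
```

With ten elements the comparator version calls key 18 times and the single-pass version 10 times; the gap matters only when each call is costly.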

EDIT: OK, I got annoyed and had too much free time, so I wrote two versions: one for pre-C++11 compilers and one for C++11 that uses rvalue references. First, the C++11 version:

#include <iostream>
#include <algorithm>
#include <iterator>
#include <vector>
#include <utility>
#include <cstdlib>
#include <ctime>

// Single-pass argmax: calls functor once per element and returns an iterator to
// the first element with the largest value (or `end` if the range is empty).
template<typename IteratorT, typename HeuristicFunctorT>
IteratorT argmax(IteratorT && it, const IteratorT & end, const HeuristicFunctorT & functor) {
    if (it == end) return end;
    IteratorT best(it++);
    typename HeuristicFunctorT::result_type best_value(functor(*best));

    for(; it != end; ++it) {
        typename HeuristicFunctorT::result_type value(functor(*it));

        if (value > best_value) {
            best_value = value;
            best = it;
        }
    }

    return best;
}

// Overload for lvalue iterators: copies begin and forwards to the version above.
template<typename IteratorT, typename HeuristicFunctorT>
inline IteratorT argmax(const IteratorT & begin, const IteratorT & end, const HeuristicFunctorT & functor) {
    return argmax(IteratorT(begin), end, functor);
}

// Heuristic: sum of both components, computed in long long so that adding two
// rand() values cannot overflow. Declares result_type itself, which is all
// argmax needs (std::unary_function was removed in C++17).
class IntPairFunctor {
public:
    typedef long long result_type;
    long long operator() (const std::pair<int, int> & v) const {
        return static_cast<long long>(v.first) + v.second;
    }
};

std::pair<int, int> rand_pair() {
    return std::make_pair(std::rand(), std::rand());
}

int main() {
    std::srand(static_cast<unsigned>(std::time(NULL)));

    std::vector< std::pair<int, int> > ints;

    std::generate_n(std::back_inserter(ints), 1000, rand_pair);

    std::vector< std::pair<int, int> >::iterator m (argmax(ints.begin(), ints.end(), IntPairFunctor()));

    // std::pair has no operator<<, so print its members explicitly.
    std::cout << "argmax: (" << m->first << ", " << m->second << ")" << std::endl;
}

The pre-C++11 version is simpler; here is just the template:

template<typename IteratorT, typename HeuristicFunctorT>
IteratorT argmax(IteratorT it, const IteratorT & end, const HeuristicFunctorT & functor) {
    if (it == end) return end;  // empty range: nothing to compare
    IteratorT best(it++);
    typename HeuristicFunctorT::result_type best_value(functor(*best));

    for(; it != end; ++it) {
        typename HeuristicFunctorT::result_type value(functor(*it));

        if (value > best_value) {
            best_value = value;
            best = it;
        }
    }

    return best;
}

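Note that neither version requires explicit template arguments; the only requirement is that the heuristic functor defines a nested result_type, for example by deriving from std::unary_function (deprecated in C++11 and removed in C++17).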
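On C++11 and later the nested result_type is not needed, because `auto` can deduce the value type from the call itself, which also lets plain lambdas serve as the heuristic. The following is a hypothetical variant of the answer's argmax (same name, our own signature), sketched under that assumption; it returns `last` for an empty range and needs only operator< on the key values.

```cpp
#include <utility>

// Sketch: single-pass argmax that deduces the key type with `auto`
// instead of requiring HeuristicFunctorT::result_type.
template <typename ForwardIt, typename Key>
ForwardIt argmax(ForwardIt first, ForwardIt last, Key key) {
    if (first == last) return last;      // empty range
    ForwardIt best = first;
    auto best_value = key(*first);       // each element's key is computed once
    for (++first; first != last; ++first) {
        auto value = key(*first);
        if (best_value < value) {        // strict: keeps the first maximum
            best_value = std::move(value);
            best = first;
        }
    }
    return best;
}
```

Usage follows the same pattern as before, e.g. `argmax(v.begin(), v.end(), [](const A& x) { return 0.8 * x.a + 0.2 * x.b; })`.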



Answer 3:

This is what functors and the STL are made for:

// A class whose objects perform custom comparisons
class my_comparator
{
public:
    explicit my_comparator(float c1, float c2) : c1(c1), c2(c2) {}
    // std::max_element calls this on pairs of elements
    bool operator() (const A &x, const A &y) const
    {
        return (x.a*c1 + x.b*c2) < (y.a*c1 + y.b*c2);
    }
private:
    const float c1, c2;
};


// Returns the "max" element in vec
*std::max_element(vec.begin(), vec.end(), my_comparator(0.8,0.2));


Answer 4:

Is the expression always linear? You could pass in an array of four coefficients. If you need to support arbitrary expressions, you'll need a functor, but if it's just an affine combination of the four fields then there's no need for all that complexity.
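A minimal sketch of the coefficient idea, assuming C++11; the names Weights, score and argmax_linear are hypothetical. Adding a constant offset to every score would not change which object wins, so the four weights are all that is needed.

```cpp
#include <array>
#include <cstddef>
#include <vector>

struct A { float a, b, c, d; };

typedef std::array<float, 4> Weights;   // hypothetical: one weight per member

// Score of one object under a linear combination of its members.
float score(const A& x, const Weights& w) {
    return w[0] * x.a + w[1] * x.b + w[2] * x.c + w[3] * x.d;
}

// Index of the object with the highest score; assumes v is non-empty.
std::size_t argmax_linear(const std::vector<A>& v, const Weights& w) {
    std::size_t best = 0;
    float best_score = score(v[0], w);
    for (std::size_t i = 1; i < v.size(); ++i) {
        float s = score(v[i], w);
        if (s > best_score) { best_score = s; best = i; }
    }
    return best;
}
```

For example, `Weights{{0.8f, 0.2f, 0, 0}}` selects by 0.8*a + 0.2*b without writing a new functor.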



Answer 5:

You can use the std::max_element algorithm with a custom comparator.

It's easy to write the comparator if your compiler supports lambda expressions.
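For example, the question's a*VAR + b case can be written as a lambda that captures VAR by value. A sketch assuming C++11; the function name max_by_a_var_b is hypothetical.

```cpp
#include <algorithm>
#include <vector>

struct A { float a, b, c, d; };

// Returns an iterator to the object maximizing a*VAR + b (end() if v is empty).
std::vector<A>::const_iterator max_by_a_var_b(const std::vector<A>& v, float VAR) {
    return std::max_element(v.begin(), v.end(),
        [VAR](const A& x, const A& y) {   // VAR is captured by value
            return x.a * VAR + x.b < y.a * VAR + y.b;
        });
}
```

Because VAR is a runtime value, the same function covers every choice of VAR without a separate functor class.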

If it doesn't, you can write a custom comparator functor. For the simple case of just comparing a single member, you can write a generic "member comparator" function object, which would look something like this:

template <typename MemberPointer>
struct member_comparator
{
    MemberPointer p_;

    member_comparator(MemberPointer p) : p_(p) { }

    template <typename T>
    bool operator()(const T& lhs, const T& rhs) const
    {
        return lhs.*p_ < rhs.*p_;
    }
};

template <typename MemberPointer>
member_comparator<MemberPointer> make_member_comparator(MemberPointer p)
{
    return member_comparator<MemberPointer>(p);
}

used as:

// returns an iterator to the element that has the maximum 'd' member:
std::max_element(v.begin(), v.end(), make_member_comparator(&A::d));


Answer 6:

You could use the std::max_element STL algorithm, providing a custom comparison predicate each time.

With C++0x (now C++11) you can even use a lambda for maximum conciseness:

auto maxElement = *std::max_element(vec.begin(), vec.end(), [](const A& Left, const A& Right) {
    return (0.8*Left.a + 0.2*Left.b) < (0.8*Right.a + 0.2*Right.b);
});


Answer 7:

Sample of using max_element/min_element with a custom functor:

#include <algorithm>
#include <iostream>
#include <vector>

using namespace std;

struct A{
  float a, b, c, d;
};


struct CompareA {
  bool operator()(A const & Left, A const & Right) const {
    return Left.a < Right.a;
  }
};


int main() {
  vector<A> vec;
  vec.resize(3);
  vec[0].a = 1;
  vec[1].a = 2;
  vec[2].a = 1.5;

  vector<A>::iterator it = std::max_element(vec.begin(), vec.end(), CompareA());
  cout << "Largest A: " << it->a << endl;
  it = std::min_element(vec.begin(), vec.end(), CompareA());
  cout << "Smallest A: " << it->a << endl;
}
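If both the smallest and the largest element are needed, C++11's std::minmax_element finds them in one pass and returns a pair of iterators. A sketch reusing the CompareA idea from above (redefined so the block compiles on its own; min_and_max_a is a hypothetical name). One difference worth knowing: on ties, minmax_element returns the last of the equal maxima, whereas max_element returns the first.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

struct A { float a, b, c, d; };

struct CompareA {
    bool operator()(const A& l, const A& r) const { return l.a < r.a; }
};

// first -> element with the smallest 'a', second -> element with the largest 'a'
// (both end() if vec is empty).
std::pair<std::vector<A>::const_iterator, std::vector<A>::const_iterator>
min_and_max_a(const std::vector<A>& vec) {
    return std::minmax_element(vec.begin(), vec.end(), CompareA());
}
```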


Tags: c++ class vector