How can I pass a std::function as a function pointer?

Posted 2020-03-30 04:01

Question:

I am trying to write a class template that internally uses a C function (the implementation of BFGS optimization provided by the R environment) with the following interface:

// R's BFGS minimizer, as quoted in the question (abbreviated: each `...`
// stands for parameters omitted here, so this is not a compilable
// declaration). fn and gr are the objective and gradient callbacks; ex is an
// opaque user-data pointer, presumably handed back to fn and gr as their own
// `ex` argument - the answers below rely on that.
void vmmin(int n, double *x, double *Fmin, 
           optimfn fn, optimgr gr, ... ,
           void *ex, ... );

where fn and gr are pointers to functions of type

// Objective callback: a function type (not a pointer type). The question
// quotes it from R's C API as `typedef double optimfn(int, double *, void *);`.
using optimfn = double(int n, double *par, void *ex);

and

// Gradient callback: a function type (not a pointer type). The question
// quotes it from R's C API as
// `typedef void optimgr(int, double *, double *, void *);`.
using optimgr = void(int n, double *par, double *gr, void *ex);

respectively. My C++ class template looks like this:

// The question's attempt: wrap T's two member functions in std::function
// objects and hand them to vmmin. As written this does not compile (see the
// ERROR comment below).
template <typename T>
class optim {
 public:
  // ...
  // Intended to run vmmin using func.fr as the objective and func.grr as the
  // gradient. NOTE(review): dpar and ex are presumably the starting point and
  // vmmin's user-data pointer; their use is elided ("...") in this excerpt.
  void minimize(T& func, arma::vec &dpar, void *ex) {
    // std::bind stores its bound arguments by value, so fn and gr each hold
    // their own copy of func rather than referring to the caller's object
    // (std::ref(func) would bind by reference instead).
    std::function<optimfn> fn = 
      std::bind(&T::fr, func, std::placeholders::_1, 
                std::placeholders::_2, std::placeholders::_3);
    std::function<optimgr> gr = 
      std::bind(&T::grr, func, std::placeholders::_1,
                std::placeholders::_2, std::placeholders::_3,
                std::placeholders::_4);
    // ERROR: cannot convert std::function to function pointer
    // A std::function is an object that may carry state (here, a copy of
    // func); a plain function pointer has nowhere to keep that state, so no
    // such conversion exists. The usual workaround routes the state through
    // vmmin's void *ex and passes free/static trampoline functions instead.
    vmmin(... , fn, gr, ...);
    // ...
  }  
};

so that it can be instantiated by any class with two specified member functions, e.g.:

class Rosen {
  // ... (private state; class members are private by default)

 public:
  // ...
  // Objective function value for the n parameters in `par`.
  double fr(int n, double *par, void *ex);
  // Gradient of the objective; presumably written into `gr` (n entries).
  void grr(int n, double *par, double *gr, void *ex);
};

// main.cc
// Intended usage (fragment): dpar and ex are assumed to be declared
// elsewhere; optim<Rosen> picks up Rosen::fr and Rosen::grr.
Rosen func;
optim<Rosen> obj;
obj.minimize(func, dpar, ex);

Is this possible? Or is there a better way of doing this, such as passing the two member functions separately as function pointers? (If the objective function and the corresponding gradient are simple, it is absolutely fine to write two free functions. However, most of the time my problems are far more complicated and I have to implement them as a class.)

Answer 1:

Let me say up front:

I do not endorse using the following library.

#include<new>
#include<tuple>
#include<type_traits>
#include<utility>

// func_traits
//
// Exposes the signature of a callable:
//   return_type          - what it returns
//   ptr_type             - the matching plain function-pointer type
//   arg<i>::type         - its i-th (0-based) parameter type
//   cast_return_type<R2> - a function-pointer type with the same parameters
//                          but return type R2
//
// Primary template: for a class type (lambda, std::function, functor) the
// traits are those of its operator(). References to such types are accepted.
template <typename T>
struct func_traits : public func_traits<decltype(&std::remove_reference_t<T>::operator())> {};

// operator() declared const: non-mutable lambdas, std::function.
template <typename Callable, typename Ret, typename... Args>
struct func_traits<Ret(Callable::*)(Args...) const> {
    using ptr_type = Ret (*) (Args...);
    using return_type =  Ret;

    template<std::size_t i>
    struct arg
    {
        using type = typename std::tuple_element<i, std::tuple<Args...>>::type;
    };

    template<typename Ret2>
    using cast_return_type = Ret2 (*) (Args...);
};

// The other forms of operator() have the same traits. Without these, a
// mutable lambda (non-const operator()) or a noexcept lambda (noexcept is
// part of the function type since C++17) matches no specialization and the
// primary template fails to compile.
template <typename Callable, typename Ret, typename... Args>
struct func_traits<Ret(Callable::*)(Args...)>
    : public func_traits<Ret(Callable::*)(Args...) const> {};

template <typename Callable, typename Ret, typename... Args>
struct func_traits<Ret(Callable::*)(Args...) noexcept>
    : public func_traits<Ret(Callable::*)(Args...) const> {};

template <typename Callable, typename Ret, typename... Args>
struct func_traits<Ret(Callable::*)(Args...) const noexcept>
    : public func_traits<Ret(Callable::*)(Args...) const> {};

// A reference to a function has the same traits as the matching function
// pointer.
template <typename Ret, typename... Args>
struct func_traits<Ret (&)(Args...)>
    : func_traits<std::add_pointer_t<Ret(Args...)>> {};

// Traits of a plain function pointer Ret(*)(Args...).
template <typename Ret, typename... Args>
struct func_traits<Ret (*) (Args...)>
{
    using return_type = Ret;
    using ptr_type = return_type (*)(Args...);

    // Same parameter list, different return type.
    template <typename NewRet>
    using cast_return_type = NewRet (*)(Args...);

    // arg<Index>::type is the parameter type at position Index (0-based).
    template <std::size_t Index>
    struct arg
    {
        using type = std::tuple_element_t<Index, std::tuple<Args...>>;
    };
};



// constexpr counter
//
// Compile-time counter built on friend injection ("stateful
// metaprogramming"). NOTE(review): this relies on a technique that C++ core
// issue CWG 2118 says ought to be ill-formed; whether it works differs
// between compilers and compiler versions.
//
// flag<N> declares - but does not define - the friend function
// adl_flag(flag<N>). A definition appears only once write<N> is
// instantiated, so "can adl_flag(flag<N>{}) be constant-evaluated?" records
// whether slot N has been taken.
template <int N>
struct flag
{
    friend constexpr int adl_flag(flag<N>);
    // Converts a flag<N> to N.
    constexpr operator int() { return N; }
};

// Instantiating write<N> defines the friend adl_flag(flag<N>) declared by
// flag<N>, marking counter slot N as taken; `value` exposes N.
template <int N>
struct write
{
    friend constexpr int adl_flag(flag<N>) { return N; }
    static constexpr int value = N;
};

// read(0, flag<N>{}) yields the first slot M >= N whose adl_flag(flag<M>)
// has no definition yet:
//  - this overload is viable only when adl_flag(flag<N>{}) is a constant
//    expression (slot N already taken); it then recurses to N + 1 through
//    its default argument R. Otherwise the second template argument fails
//    substitution and the overload drops out.
//  - the float overload below is the fallback. Passing 0 to `float` is a
//    conversion, so it loses to this `int` overload whenever this one is
//    viable.
template <int N, int = adl_flag(flag<N>{})>
constexpr int read(int, flag<N>, int R = read(0, flag<N + 1>{}))
{
    return R;
}

// Fallback: slot N is free, so N is the result.
template <int N>
constexpr int read(float, flag<N>)
{
    return N;
}

// Returns the next counter value. read finds the first free slot M
// (starting at N), instantiating write<M> claims it, and M is returned.
// Intended to yield 0, 1, 2, ... on successive evaluations. That depends on
// the compiler re-evaluating this default argument every time, which is the
// friend-injection behaviour questioned by CWG 2118.
template <int N = 0>
constexpr int counter(int R = write<read(0, flag<N>{})>::value)
{
    return R;
}


// fnptr
//
// Turns a callable (lambda, std::function, functor, function pointer) into a
// plain, C-compatible function pointer. The callable is stored in a
// function-local static of the conversion function, and the returned pointer
// refers to a captureless lambda that forwards to it. The `nonce` default
// (counter()) is intended to give every `fnptr<>` spelling a distinct
// instantiation, and therefore its own statics.
template<int nonce = counter()>
class fnptr
{
    //these are to make sure fnptr is never constructed
    //technically the first one should be enough, but compilers are not entirely standard conformant
    explicit fnptr() = delete;
    fnptr(const fnptr&) {}
    ~fnptr() = delete;

    // Stores `c` and returns a Ret(*)(Args...) that calls it. The second
    // parameter only selects the target signature; its value is ignored.
    // Calling this again for the same (nonce, Callable, signature) replaces
    // the stored callable, so every pointer previously returned for that
    // combination starts invoking the new one.
    template<typename Callable, typename Ret, typename... Args>
    static auto cast(Callable&& c, Ret(*fp)(Args...)) -> decltype(fp)
    {
        using callable_type = std::remove_reference_t<Callable>;
        static callable_type clb{std::forward<Callable>(c)};
        static bool full = false;
        if(full)
        {
            // Closure types are generally not assignable, so rebind by
            // destroying and re-constructing in place.
            clb.~callable_type();
            new (&clb) callable_type{std::forward<Callable>(c)};
        }
        else
            full = true;
        // Captureless, so it converts to the function-pointer return type;
        // `clb` is a static local and needs no capture.
        return [](Args... args) noexcept(noexcept(clb(std::forward<Args>(args)...))) -> Ret
        {
            return Ret(clb(std::forward<Args>(args)...));
        };
    }

public:
    /// Converts `c` to a `Signature*`. The callable's arguments and return
    /// value only need to be implicitly convertible to/from Signature's.
    template<typename Signature, typename Callable>
    static Signature* cast(Callable&& c)
    {
        return cast(std::forward<Callable>(c), static_cast<Signature*>(nullptr));
    }

    /// Converts a function pointer to a `Signature*` with implicitly
    /// convertible arguments/return. The previous version of this overload
    /// did not compile (unbalanced noexcept-specifier), used fp's parameter
    /// list instead of Signature's, and returned a closure rather than a
    /// pointer. It now goes through the same path as any other callable.
    template<typename Signature, typename Ret, typename... Args>
    static Signature* cast(Ret (*fp)(Args...))
    {
        return cast(fp, static_cast<Signature*>(nullptr));
    }

    /// Converts `c` to a function pointer with the callable's own signature
    /// (func_traits<Callable>::ptr_type).
    template<typename Callable>
    static auto get(Callable&& c)
    {
        return cast(std::forward<Callable>(c), typename func_traits<Callable>::ptr_type{nullptr});
    }

    /// A function pointer is returned unchanged.
    template<typename Ret, typename... Args>
    static auto get(Ret (*fp)(Args...))
    {
        return fp;
    }
};

And use it like this:

#include<functional>
#include<iostream>

using optimfn = double (int, double*, void*);
using optimgr = void (int, double*, double*, void*);

// Calls each callback once with n = 42 and every pointer argument aimed at
// the same scratch double. The scratch value is zero-initialized so that a
// callback which reads it does not read an indeterminate value.
void test(optimfn* fn, optimgr* gr)
{
    double d{};
    fn(42, &d, &d);
    gr(42, &d, &d, &d);
}

int main()
{
    std::function<optimfn> fn = [](int, double*, void*){
        std::cout << "I'm fn" << std::endl;
        return 0.;
    };
    std::function<optimgr> gr = [](int, double*, double*, void*){
        std::cout << "I'm gr" << std::endl;
    };

    test(fnptr<>::get(fn), fnptr<>::get(gr));
}

Live example

func_traits

This is just a helper traits type that exposes the signature of any callable (return type, parameter types and the matching function-pointer type) in an easily accessible form.

constexpr counter

This is half of the evil in what's going on. For details, see the question "Is stateful metaprogramming ill-formed (yet)?".

fnptr

This is the actual meat of the code. It takes any callable with an appropriate signature, implicitly declares an anonymous C function at every point where it is called, and coerces the callable into that C function.

It has the funky syntax fnptr<>::get and fnptr<>::cast<Ret(Args...)>. This is intentional.

get will declare the anonymous C function with the same signature as the callable object.

cast works on any compatible callable type: if the return type and arguments are implicitly convertible, it can be cast.

Caveats

fnptr implicitly declares an anonymous C function at each point in the code where it is called. This is unlike std::function, which is an actual variable.

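If the same fnptr call is executed again (for example, inside a loop), all hell breaks loose: each execution replaces the single stored callable, so every pointer returned from that call ends up invoking the most recently stored callable.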

// Anti-example (fragment). The loop body is compiled once, so the
// fnptr<>::get call names a single instantiation with a single static
// callable. Each iteration replaces that callable, and in practice all ten
// pointers in v are the same function, calling the last stored lambda
// (the one with i == 9).
std::vector<int(*)()> v;
for(int i = 0; i < 10; i++)
    v.push_back(fnptr<>::get([i]{return i;}));  // This will implode

You have been warned.



Answer 2:

Basically, you need a free function that has the correct signature, takes the void * parameter with the "user data" (without which it won't work), somehow extracts a pointer/reference to the std::function out of that, and calls it with the other arguments. Simple example to illustrate what I mean:

void call_it(int value, void * user) {
  // `user` is the address of the std::function<void(int)> registered
  // together with this callback: recover it and forward `value`.
  auto & target = *static_cast<std::function<void(int)> *>(user);
  target(value);
}
// pass it as callback:
// (fragment) registerCallback stands for the C API and my_std_function for
// the std::function<void(int)> that call_it casts `user` back to. That object
// must stay alive for as long as the callback can fire.
registerCallback(call_it, static_cast<void *>(&my_std_function));

Of course you need to make sure that the pointer remains valid!

With the code below, you don't need to write such call_it functions for every possible signature. The above example would read:

// (fragment) Same registration through the generic trampoline. Parameter 1
// (the void *) carries the user data, Single::Extract<void,int> turns it back
// into the std::function<void(int)>, and the instantiated trampoline has the
// signature void(int, void *).
registerCallback(trampoline<1, Single::Extract<void,int>, void, int, void *>,
                 Single::wrap(my_std_function));

And your case would be:

// obj and ex passed as parameters
// (fragment; `...` marks elided vmmin arguments) Bind obj and ex into lambdas
// so the std::function signatures no longer include the void *. That slot is
// needed to carry the functions themselves.
std::function<double(int, double *)> fn =
  [ex, &obj] (int a, double * b) { return obj.fr(a, b, ex); };
std::function<void(int, double *, double *)> gr =
  [ex, &obj] (int a, double * b, double * c) { obj.grr(a, b, c, ex); };
// Heap-allocated array holding &fn and &gr; released by free_wrap_result.
void * fns = Multi<2>::wrap(fn, gr);
// In each trampoline, parameter 2 (resp. 3) is the void * user data, and
// Extract<0> / Extract<1> pick fn / gr out of the array.
vmmin(... ,
      trampoline<2, Multi<2>::Extract<0, double, int, double *>, double, int, double *, void *>,
      trampoline<3, Multi<2>::Extract<1, void, int, double *, double *>, void, int, double *, double *, void *>,
      ..., fns, ...); // fns passed as ex
// NOTE(review): the array leaks if control leaves before this line (e.g. an
// exception or a non-local exit out of vmmin).
Multi<2>::free_wrap_result(fns);

(My "scratch area" on ideone for forking and testing.) Now, templates to the rescue:

template<
    std::size_t N, ///> index of parameter with the user data
    typename Extractor,
    typename R,
    typename... Args>
R trampoline (Args... args) {
  // A tuple of references (tuple<Args&...>) to every incoming argument.
  auto every_arg = std::forward_as_tuple(args...);
  // Everything except the user-data argument, in the original order.
  auto call_args = tuple_remove<N>(every_arg);
  // Recover the target function from argument N and call it with the rest.
  return std::apply(Extractor{}.get_function(std::get<N>(every_arg)),
                    call_args);
}

std::apply is a C++17 feature, though you should easily be able to find a pre-C++17 implementation on this site. N is the (zero-based) index of the parameter that contains the "user data" (i.e. the pointer to the actual function). The Extractor is a default-constructible type with a get_function member function which, given a void *, returns something "callable" for std::apply to work with. This design is inspired by your actual problem: if you have only one "user data" pointer that will be passed to two (or more) different callbacks, you need a way to "extract" a different function in each callback.

An "extractor" for a single function:

// Extractor for a single std::function: the user-data pointer is simply the
// address of that std::function.
struct Single {
  template<typename R, typename... Args>
  struct Extract {
    // Static: needs no instance, and the call syntax trampoline uses,
    // Extractor{}.get_function(ptr), still works.
    static std::function<R(Args...)> & get_function(void * ptr) {
      return *static_cast<std::function<R(Args...)> *>(ptr);
    }
  };
  // Returns the user-data pointer to register alongside the callback. No
  // ownership is transferred: `fn` must outlive every callback invocation.
  template<typename R, typename... Args>
  static void * wrap(std::function<R(Args...)> & fn) {
    return &fn;
  }
};

And one for multiple functions:

// Extractor for Num std::functions sharing one user-data pointer: the pointer
// refers to a heap-allocated std::array<void *, Num> of their addresses.
template<std::size_t Num>
struct Multi {
  template<std::size_t I, typename R, typename... Args>
  struct Extract {
    // Catch a bad index at compile time; indexing the array with it would be
    // undefined behaviour at run time.
    static_assert(I < Num, "Multi::Extract: function index out of range");
    // Static: needs no instance, and Extractor{}.get_function(ptr) still
    // works.
    static std::function<R(Args...)> & get_function(void * ptr) {
      auto & slots = *static_cast<std::array<void *, Num> *>(ptr);
      return *static_cast<std::function<R(Args...)> *>(slots[I]);
    }
  };
  // Allocates the address array. No ownership of fns is taken; they must
  // outlive the callbacks. The result must be released with free_wrap_result.
  template<typename... Fns>
  static void * wrap(Fns &... fns) {
    static_assert(sizeof...(fns) == Num, "Don't lie!");
    return new std::array<void *, Num>{{ static_cast<void *>(&fns)... }};
  }
  // Releases a pointer obtained from wrap.
  static void free_wrap_result(void * ptr) {
    delete static_cast<std::array<void *, Num> *>(ptr);
  }
};

Note that wrap allocates here, so every call must be matched by a corresponding deallocation via free_wrap_result. This is still very unidiomatic and should probably be converted to RAII.


tuple_remove still needs to be written:

template<
    std::size_t N,
    typename... Args,
    std::size_t... Is>
auto tuple_remove_impl(
    std::tuple<Args...> const & t,
    std::index_sequence<Is...>) {
  // Each position contributes either an empty tuple (Ignore, for Is == N) or
  // a one-element tuple (Use<Is>); tuple_cat stitches the survivors together.
  return std::tuple_cat(if_t<N == Is, Ignore, Use<Is>>::from(t)...);
}
// Returns the elements of `t` other than the N-th, as a new tuple.
template<
    std::size_t N,
    typename... Args>
auto tuple_remove (std::tuple<Args...> const & t) {
  // Without this check an out-of-range N matches no position and the whole
  // tuple comes back unchanged, silently.
  static_assert(N < sizeof...(Args), "tuple_remove: index out of range");
  return tuple_remove_impl<N>(t, std::index_sequence_for<Args...>{});
}

if_t (see further down) is just my shorthand for std::conditional; Use and Ignore still need to be implemented. (In an actual source file, if_t, Ignore and Use must be declared before tuple_remove_impl.)

// Drops a position in tuple_remove: always contributes an empty tuple, which
// std::tuple_cat simply skips. The tuple is taken by const reference; it is
// never read, so copying it (as the by-value version did) was wasted work and
// rejected move-only element types.
struct Ignore {
  template<typename Tuple>
  static std::tuple<> from(const Tuple &) {
    return {};
  }
};
// Keeps position N in tuple_remove: returns element N as a one-element tuple.
// make_tuple decays the element, so a reference element comes back as a
// copied value. The source tuple is taken by const reference; the by-value
// version copied the whole tuple once per kept element.
template<std::size_t N>
struct Use {
  template<typename Tuple>
  static auto from(const Tuple & t) {
    return std::make_tuple(std::get<N>(t));
  }
};

tuple_remove exploits the fact that std::tuple_cat accepts empty std::tuple<> arguments: since they contribute no elements, they are effectively ignored.


Shorthand for std::conditional:

// Shorthand for std::conditional: `Then` when Condition holds, else `Else`.
template<bool Condition, typename Then, typename Else>
using if_t = std::conditional_t<Condition, Then, Else>;


Answer 3:

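An alternative solution is to have the optim class do its magic with two (possibly pure) virtual functions, and then derive a new class Rosen that implements them. Note that vmmin still needs plain function pointers, so the virtual functions cannot be passed directly (taking their address yields a pointer to member, not a function pointer). They have to be reached through static member functions that recover the object from the void *ex argument. This could look like: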

class optim {
    public:
        // ...

        virtual double fn(int n, double *par, void *ex) = 0;
        virtual void gr(int n, double *par, double *gr, void *ex) = 0;

        void minimize(arma::vec &dpar, void *ex) {
            vmmin(... , &fn, &gr, ...);
            // ...
        }
};

class Rosen : public optim {
    public:
        // ...
        double fn(int n, double *par, void *ex);
        void gr(int n, double *par, double *gr, void *ex);

    private:
        // ...
};

// main.cc
// Usage (fragment): Rosen supplies fn/gr by overriding optim's virtuals;
// dpar and ex are assumed to be declared elsewhere.
Rosen obj;
obj.minimize(dpar, ex);