Recursive lambda callbacks without Y Combinator

Posted 2019-01-24 16:37

Question:

I wish to create a callback that recursively returns itself as a callback.

The suggested method to recurse is for the function to have a reference to itself:

std::function<void (int)> recursive_function = [&] (int recurse) {
    std::cout << recurse << std::endl;

    if (recurse > 0) {
        recursive_function(recurse - 1);
    }
};

This fails as soon as you return it from a function:

#include <functional>
#include <iostream>

volatile bool no_optimize = true;

std::function<void (int)> get_recursive_function() {
    std::function<void (int)> recursive_function = [&] (int recurse) {
        std::cout << recurse << std::endl;

        if (recurse > 0) {
            recursive_function(recurse - 1);
        }
    };

    if (no_optimize) {
        return recursive_function;
    }

    return [] (int) {};
}

int main(int, char **) {
    get_recursive_function()(10);
}

which gives a segmentation fault after outputting 10 because the reference becomes invalid.

How do I do this? I have successfully used what I think is a Y Combinator (which I'll post as an answer), but it is hugely confusing. Is there a better way?


Other attempts

I have tried the boring approach of wrapping it in another layer of callbacks:

#include <functional>
#include <iostream>
#include <memory>

volatile bool no_optimize = true;

std::function<void (int)> get_recursive_function() {
    // Closure to allow self-reference
    auto recursive_function = [] (int recurse) {
        // Actual function that does the work.
        std::function<void (int)> function = [&] (int recurse) {
            std::cout << recurse << std::endl;

            if (recurse > 0) {
                function(recurse - 1);
            }
        };

        function(recurse);
    };

    if (no_optimize) {
        return recursive_function;
    }

    return [] (int) {};
}

int main(int, char **) {
    get_recursive_function()(10);
}

but this fails in the actual scenario, where the function is being delayed and called by an outer loop:

#include <functional>
#include <iostream>
#include <memory>
#include <queue>

volatile bool no_optimize = true;

std::queue<std::function<void (void)>> callbacks;

std::function<void (int)> get_recursive_function() {
    // Closure to allow self-reference
    auto recursive_function = [] (int recurse) {
        // Actual function that does the work.
        std::function<void (int)> function = [&] (int recurse) {
            std::cout << recurse << std::endl;

            if (recurse > 0) {
                callbacks.push(std::bind(function, recurse - 1));
            }
        };

        function(recurse);
    };

    if (no_optimize) {
        return recursive_function;
    }

    return [] (int) {};
}

int main(int, char **) {
    callbacks.push(std::bind(get_recursive_function(), 10));

    while (!callbacks.empty()) {
        callbacks.front()();
        callbacks.pop();
    }
}

which gives 10, then 9 and then segmentation faults.

Answer 1:

As you rightly pointed out, there is an invalid reference from the lambda capture [&].

Your returns are functors of various sorts, so I assume that the exact type of the return is not important, just that it behaves as a function would i.e. be callable.

If recursive_function is wrapped in a struct or class, you can map the call operator to the recursive_function member. A problem arises if the lambda captures this: the pointer is captured when the object is created, so if the object is then copied or moved, the captured this may no longer be valid. Instead, an appropriate this can be passed in to the function at execution time (this issue may not be a problem in practice, but it depends heavily on when and how you call the function).

#include <functional>
#include <iostream>

volatile bool no_optimize = true;

struct recursive {
    std::function<void (recursive*, int)> recursive_function = [] (recursive* me, int recurse) {
        std::cout << recurse << std::endl;

        if (recurse > 0) {
            me->recursive_function(me, recurse - 1);
        }
    };

    void operator()(int n)
    {
        if (no_optimize) {
            recursive_function(this, n);
        }
    }
};

recursive get_recursive_function() {
    return recursive();
}

int main(int, char **) {
    get_recursive_function()(10);
}

Alternatively, if the recursive_function can be static then declaring it as such in the original code sample may also do the trick for you.
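A minimal sketch of the static variant, assuming a single shared instance of the function is acceptable. The visited vector is a hypothetical stand-in for std::cout so the behaviour is easy to check, and the assert documents an assumption of this sketch rather than anything in the original code.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical log of the values seen, used here instead of std::cout.
std::vector<int> visited;

std::function<void (int)> get_recursive_function() {
    // Static storage duration: the object outlives this call, and a lambda
    // may name it without capturing anything.
    static std::function<void (int)> recursive_function = [] (int recurse) {
        assert(recurse >= 0);  // this sketch assumes a non-negative depth
        visited.push_back(recurse);

        if (recurse > 0) {
            recursive_function(recurse - 1);
        }
    };

    // The returned copy still recurses through the static object.
    return recursive_function;
}
```

The trade-off is that every caller shares the same static object, so this does not help if each returned function needs its own captured state.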

I wanted to add some generality to the answer above, i.e. make it a template:

#include <functional>
#include <iostream>
#include <utility>

volatile bool no_optimize = true;

template <typename Signature>
struct recursive;

template <typename R, typename... Args>
struct recursive<R (Args...)> {
    std::function<R (recursive const&, Args... args)> recursive_function;

    recursive() = default;

    recursive(decltype(recursive_function) const& func) : recursive_function(func)
    {
    }

    template <typename... T>
    R operator()(T&&... args) const
    {
        // Forward with the deduced pack T, not the signature pack Args.
        return recursive_function(*this, std::forward<T>(args)...);
    }
};

recursive<void (int)> get_recursive_function()
{
    using result_type = recursive<void (int)>;

    if (!no_optimize) {
        return result_type();
    }

    result_type result ([](result_type const& me, int a) {
        std::cout << a << std::endl;

        if (a > 0) {
            me(a - 1);
        }
    });

    return result;
}

int main(int, char **) {
    get_recursive_function()(10);
}

How does this work? Basically it moves the recursion from inside the function (i.e. calling itself) to an object (i.e. the function operator on the object itself) to implement the recursion. In the get_recursive_function the result type recursive<void (int)> is used as the first argument to the recursive function. It is const& because I've implemented the operator() as const in line with most standard algorithms and the default for a lambda function. It does require some "co-operation" from the implementer of the function (i.e. the use of the me parameter; itself being *this) to get the recursion working, but for that price you get a recursive lambda that is not dependent on a stack reference.
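To illustrate that the me parameter works for other signatures too, here is a sketch that reuses the template with a value-returning function. get_factorial is a hypothetical name that is not part of the answer, and the template is repeated (forwarding with the deduced pack T) only to keep the example self-contained.

```cpp
#include <cassert>
#include <functional>
#include <utility>

template <typename Signature>
struct recursive;

template <typename R, typename... Args>
struct recursive<R (Args...)> {
    std::function<R (recursive const&, Args...)> recursive_function;

    recursive() = default;

    recursive(decltype(recursive_function) const& func) : recursive_function(func)
    {
    }

    template <typename... T>
    R operator()(T&&... args) const
    {
        // *this is whichever object is being called, so copies stay valid.
        return recursive_function(*this, std::forward<T>(args)...);
    }
};

// Hypothetical example: recursion that returns a value through me.
recursive<long (int)> get_factorial()
{
    return recursive<long (int)>([] (recursive<long (int)> const& me, int n) -> long {
        assert(n >= 0);  // assumption of this sketch: no negative input
        return n <= 1 ? 1L : n * me(n - 1);
    });
}
```

Because me is supplied at call time, the returned object can be copied freely; each copy passes its own *this into the stored function.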



Answer 2:

All problems in programming can be solved with another layer of indirection, except too many layers of indirection.

My goal is to create a type recursive<void(int)> that lets you easily create a recursive lambda. To do this, you pass in a lambda with signature void(recursive<void(int)>, int) -- the first argument is what you invoke in order to do your recursive call.

I then tie it up in knots and make it a fully recursive function with signature void(int).

Here is my implementation of recursive<Signature>:

#include <functional>
#include <memory>
#include <type_traits>
#include <utility>

template<class Sig>
struct recursive;
template<class R, class... As>
struct recursive< R(As...) > {
  using base_type = std::function<R(recursive, As...)>;
private:
  std::shared_ptr< base_type > base;
public:

  template<typename...Ts>
  auto operator()(Ts&&... ts) const
  -> typename std::result_of< base_type( recursive, Ts... ) >::type
  {
    return (*base)(*this, std::forward<Ts>(ts)...);
  }

  recursive(recursive const&)=default;
  recursive(recursive&&)=default;
  recursive& operator=(recursive const&)=default;
  recursive& operator=(recursive &&)=default;
  recursive() = default;
  template<typename L, typename=typename std::result_of< L(recursive, As...) >::type>
  explicit recursive( L&& f ):
    base( std::make_shared<base_type>(std::forward<L>(f)))
  {}
  explicit operator bool() const { return base && *base; }
};

That is, admittedly, pretty complicated. I did a bunch of things to make it more efficient, such as perfect forwarding. It also, unlike std::function, double-checks that the type of the lambda you pass to it matches the signature it wants.

I believe, but have not confirmed, that I made it friendly to lambdas with the signature void(auto&&,int). Anyone know of a fully compliant C++1y online compiler?

The above is just boilerplate. What matters is how it looks at point of use:

std::function<void (int)> get_recursive_function() {
  auto f =
    [] (recursive<void(int)> self, int recurse) {
      std::cout << recurse << std::endl;

      if (recurse > 0) {
        self(recurse - 1);
      }
    };
  return recursive< void(int) >( f );
}

Here we do the popular auto f = lambda syntax. No need to directly store it in a std::function.

Then, we explicitly cast it to a recursive<void(int)>, which ties it up in knots and removes the recursive<void(int)> argument to f from the front, and exposes the signature void(int).

This does require that your lambda take recursive<void(int)> self as its first parameter, and do recursion through it, but that doesn't seem too harsh. If I wrote it just right it might work with auto&& self as the first parameter, but I am unsure.

recursive<?> works for any signature, naturally.


And, with delayed calls in outer loop it still works. Note that I got rid of that global variable (it would work with it as a global variable, it just felt dirty to leave it in).
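A sketch of the delayed-call scenario, under some assumptions: the recursive template below is a condensed stand-in for the one in this answer (no result_of checks, a plain base_type constructor), and callbacks, visited and run_callbacks are hypothetical names modelled on the question's code.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <queue>
#include <utility>
#include <vector>

// Condensed stand-in for recursive<Signature>: the callable lives behind a
// shared_ptr, so every copy of the handle keeps it alive.
template <class Sig>
struct recursive;

template <class R, class... As>
struct recursive<R(As...)> {
    using base_type = std::function<R(recursive, As...)>;

    recursive() = default;
    explicit recursive(base_type f)
        : base(std::make_shared<base_type>(std::move(f))) {}

    R operator()(As... as) const {
        assert(base && "calling an empty recursive handle");
        return (*base)(*this, std::forward<As>(as)...);
    }

private:
    std::shared_ptr<base_type> base;
};

// Hypothetical stand-ins for the question's queue and for std::cout.
std::queue<std::function<void()>> callbacks;
std::vector<int> visited;

std::function<void(int)> get_recursive_function() {
    return recursive<void(int)>(
        [](recursive<void(int)> self, int recurse) {
            visited.push_back(recurse);

            if (recurse > 0) {
                // The queued closure owns a copy of self, so nothing
                // dangles once this call has returned.
                callbacks.push([self, recurse] { self(recurse - 1); });
            }
        });
}

void run_callbacks() {
    while (!callbacks.empty()) {
        callbacks.front()();
        callbacks.pop();
    }
}
```

Each queued closure shares ownership of the stored function, which is what the reference-capturing attempts in the question were missing.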

In C++1y, we can eliminate the type erasure & shared_ptr overhead you see above (with the recursive object holding a shared_ptr<function<?>>). You do have to provide the return type explicitly, as I cannot get result_of to untangle my mess:

struct wrap {};

template<class R, class F>
struct recursive {
  using base_type = F;
private:
  F base;
public:

  template<class... Ts>
  R operator()(Ts&&... ts) const
  {
    // base is stored by value here, so call it directly (no dereference).
    return base(*this, std::forward<Ts>(ts)...);
  }

  recursive(recursive const&)=default;
  recursive(recursive&&)=default;
  recursive& operator=(recursive const&)=default;
  recursive& operator=(recursive &&)=default;
  recursive() = delete;
  template<typename L>
  recursive( wrap, L&& f ):
    base( std::forward<L>(f) )
  {}
};

template<class T>using decay_t = typename std::decay<T>::type;

template<class R, class F>
recursive<R, decay_t<F>> recurse( F&& f ) { return recursive<R, decay_t<F>>(wrap{}, std::forward<F>(f)); }

Then a slightly different implementation of get_recursive_function (where I added some state for fun):

std::function<void (int)> get_recursive_function(int amt) {
  auto f =
    [amt] (auto&& self, int count) {
      std::cout << count << std::endl;

      if (count > 0) {
        self(count - amt);
      }
    };
  return recurse<void>( std::move(f) );
}

int main() {
  auto f = get_recursive_function(2);
  f(10);
}

The use of std::function in the return value of get_recursive_function is optional -- you could use auto in C++1y. There is still some overhead compared to a perfect version (where the lambda could access its own operator()), because the operator() probably doesn't know that it is being recursively invoked on the same object when it invokes self.
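As a sketch of the auto variant, assuming a C++14 compiler: the wrapper below is a trimmed copy of the one in this answer (calling the stored lambda directly), and visited is a hypothetical log standing in for std::cout.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>
#include <vector>

struct wrap {};

template <class R, class F>
struct recursive {
    template <class L>
    recursive(wrap, L&& f) : base(std::forward<L>(f)) {}

    template <class... Ts>
    R operator()(Ts&&... ts) const {
        // The lambda is stored by value, so there is no type erasure.
        return base(*this, std::forward<Ts>(ts)...);
    }

private:
    F base;
};

template <class R, class F>
recursive<R, typename std::decay<F>::type> recurse(F&& f) {
    return recursive<R, typename std::decay<F>::type>(wrap{}, std::forward<F>(f));
}

// Hypothetical log of the values seen, used here instead of std::cout.
std::vector<int> visited;

// Deduced return type: no std::function anywhere in the call path.
auto get_recursive_function(int amt) {
    assert(amt > 0);  // a non-positive step would never reach zero
    return recurse<void>([amt](auto&& self, int count) {
        visited.push_back(count);

        if (count > 0) {
            self(count - amt);
        }
    });
}
```

The price of dropping std::function is that the concrete type is unnameable, so it cannot be stored in a container of callbacks without wrapping it again.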

It would be tempting to allow operator()( blah ) within the body of a lambda to allow a recursive invocation of the lambda. It would probably break very little code.



Answer 3:

My current workaround, which is unfortunately complicated, is:

#include <functional>
#include <iostream>
#include <queue>

volatile bool no_optimize = true;

std::queue<std::function<void (void)>> callbacks;

(I think this is a Y Combinator, but I am unsure.)

std::function<void (int)> y_combinator(
        std::function<void (std::function<void (int)>, int)> almost_recursive_function
) {
    auto bound_almost_recursive_function = [almost_recursive_function] (int input) {
        y_combinator(almost_recursive_function)(input);
    };

    return [almost_recursive_function, bound_almost_recursive_function] (int input) {
        almost_recursive_function(bound_almost_recursive_function, input);
    };
}

This is the base function; it does not call itself but an argument it is passed. This argument should be the recursive function itself.

std::function<void (std::function<void (int)>, int)> get_almost_recursive_function() {
    auto almost_recursive_function = (
        [] (std::function<void (int)> bound_self, int recurse) {
            std::cout << recurse << std::endl;

            if (recurse > 0) {
                callbacks.push(std::bind(bound_self, recurse - 1));
            }
        }
    );

    if (no_optimize) {
        return almost_recursive_function;
    }

    return [] (std::function<void (int)>, int) {};
}

So the wanted function can be made by applying the combinator to the almost-recursive function.

std::function<void (int)> get_recursive_function() {
    return y_combinator(get_almost_recursive_function());
}

When main is run, this outputs 10, 9, ..., 0, as wanted.

int main(int, char **) {
    callbacks.push(std::bind(get_recursive_function(), 10));

    while (!callbacks.empty()) {
        callbacks.front()();
        callbacks.pop();
    }
}


Answer 4:

Since you already deal with std::function, which adds a bit of overhead all over the place, you can add memory persistence that only adds an indirection at the call site, using a unique_ptr:

#include <functional>
#include <iostream>
#include <memory>

std::unique_ptr<std::function<void (int)>> CreateRecursiveFunction() {
    auto result = std::make_unique<std::function<void (int)>>();

    auto ptr = result.get();
    *result = [ptr] (int recurse) { // c++1y can also capture a reference to the std::function directly [&func = *result]
        std::cout << recurse << std::endl;
        if (recurse > 0) {
            (*ptr)(recurse - 1); // with c++1y func( recurse - 1 )
        }
    };
    return result;
}
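A usage sketch for the function above, assuming C++14 for std::make_unique. count_down and visited are hypothetical names, and the log replaces std::cout so the result can be checked. The point it tries to show is that the unique_ptr is the sole owner, so it has to outlive every call made through it.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <vector>

// Hypothetical log of the values seen, used here instead of std::cout.
std::vector<int> visited;

std::unique_ptr<std::function<void (int)>> CreateRecursiveFunction() {
    auto result = std::make_unique<std::function<void (int)>>();

    // The raw pointer stays valid for as long as the unique_ptr owns the
    // heap-allocated std::function, even if the unique_ptr itself is moved.
    auto ptr = result.get();
    *result = [ptr] (int recurse) {
        visited.push_back(recurse);
        if (recurse > 0) {
            (*ptr)(recurse - 1);
        }
    };
    return result;
}

// Hypothetical caller: keeps the owner alive for the whole recursion.
std::vector<int> count_down(int start) {
    visited.clear();
    auto fn = CreateRecursiveFunction();
    assert(fn && *fn);
    (*fn)(start);
    return visited;
}
```

Copying the std::function out of the unique_ptr would not make it independent: the copy still recurses through ptr, so it dangles once the owner is destroyed.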