How does a compiler optimise this factorial function?

Posted 2019-03-09 03:43

Question:

So I have been having a look at some of the magic that is -O3 in GCC (well, actually I'm compiling using Clang, but it's the same with GCC, and I'm guessing a large part of the optimiser was pulled over from GCC to Clang).

Consider this C program:

int foo(int n) {
    if (n == 0) return 1;
    return n * foo(n-1);
}

int main() {
    return foo(10);
}

The first thing that WOW-ed me (and that also came up in this question: https://stackoverflow.com/a/414774/1068248) was how int foo(int), a basic factorial function, compiles into a tight loop. This is the ARM assembly for it:

    .globl  _foo
    .align  2
    .code   16
    .thumb_func _foo
_foo:
    mov r1, r0
    movs    r0, #1
    cbz r1, LBB0_2
LBB0_1:
    muls    r0, r1, r0
    subs    r1, #1
    bne LBB0_1
LBB0_2:
    bx  lr

Blimey, I thought. That's pretty interesting! A completely tight loop computing the factorial. WOW. It's not tail call optimisation since, well, it's not a tail call, but it appears to have done a very similar optimisation.
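Reading the assembly back into C may make the loop easier to follow. Below is a hand translation with one statement per instruction. The name foo_as_compiled is hypothetical, and the sketch assumes 0 <= n <= 12, the range where n! fits in a 32-bit int (the hardware muls simply wraps, but signed overflow is undefined in C).

```c
/* Hand translation of the Thumb code for _foo (hypothetical name).
   r1 is the countdown counter, r0 the running product.
   Assumes 0 <= n <= 12 so the product never overflows a 32-bit int. */
int foo_as_compiled(int n)
{
    int r1 = n;             /* mov  r1, r0      : copy the argument     */
    int r0 = 1;             /* movs r0, #1      : product starts at 1   */
    if (r1 != 0) {          /* cbz  r1, LBB0_2  : foo(0) returns 1      */
        do {                /* LBB0_1:                                  */
            r0 *= r1;       /* muls r0, r1, r0  : multiply into product */
            r1 -= 1;        /* subs r1, #1      : decrement, set flags  */
        } while (r1 != 0);  /* bne  LBB0_1      : loop until r1 is 0    */
    }
    return r0;              /* bx   lr          : result is in r0       */
}
```

Note there is no call and no stack frame at all: the recursion has become a counter and an accumulator held in two registers.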

Now look at main:

    .globl  _main
    .align  2
    .code   16
    .thumb_func _main
_main:
    movw    r0, #24320
    movt    r0, #55
    bx  lr

That just blew my mind, to be honest. It totally bypasses foo and simply returns 3628800, which is 10!.
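The constant is built in two halves because a Thumb-2 instruction can't carry an arbitrary 32-bit immediate: movw sets the low 16 bits of r0 to 24320 (clearing the top half), then movt sets the top 16 bits to 55, so r0 = 55 × 65536 + 24320 = 3628800 (0x375F00). A small C model of the pair, with a hypothetical function name:

```c
#include <stdint.h>

/* Models the two instructions in _main (hypothetical name):
     movw r0, #24320 -> writes the low 16 bits, clears the high 16 bits
     movt r0, #55    -> writes the high 16 bits, keeps the low 16 bits */
uint32_t movw_movt(uint16_t low, uint16_t high)
{
    uint32_t r0 = low;                             /* movw */
    r0 = (r0 & 0xFFFFu) | ((uint32_t)high << 16);  /* movt */
    return r0;
}
```

Calling movw_movt(24320, 55) reproduces the value main returns.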

This really makes me realise how often your compiler can do a much better job than you at optimising your code. But it raises the question: how does it manage to do such a good job? So, can anyone explain (possibly by linking to relevant code) how the following optimisations work:

  1. The optimisation that turns foo into a tight loop.

  2. The optimisation where main simply returns the result directly rather than actually executing foo.

Another interesting side effect of this question would be to show some more of the interesting optimisations GCC/Clang can do.

Answer 1:

If you compile with gcc -O3 -fdump-tree-all, you can see that the first dump in which the recursion has been turned into a loop is foo.c.035t.tailr1. This means the same optimisation that handles other tail calls also handles this slightly extended case. Recursion in the form of n * foo(...) or n + foo(...) is not that hard to handle manually (see below), and since it's possible to describe exactly how, the compiler can perform that optimisation automatically.
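The n + foo(...) case works the same way, with an accumulator that starts at 0 (the identity for +) instead of 1 (the identity for *). A sketch using a hypothetical sum-of-1..n function; it relies on integer addition being associative, so regrouping the additions does not change the result (assuming n >= 0 and no overflow).

```c
/* Recursive form: the addition happens after the call returns,
   so the recursive call is not a tail call. (Hypothetical name.) */
int sum_rec(int n)
{
    if (n == 0) return 0;
    return n + sum_rec(n - 1);
}

/* After the accumulator rewrite (hypothetical name): acc carries the
   partial sum, so nothing is left to do after each step and the
   recursion becomes a loop. Valid for n >= 0. */
int sum_loop(int n)
{
    int acc = 0;
    while (n != 0) {
        acc += n;
        n -= 1;
    }
    return acc;
}
```

The recursive version computes n + ((n-1) + (... + 0)), while the loop computes ((0 + n) + (n-1)) + ...; associativity is what makes those equal.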

The optimisation of main is much simpler: inlining can turn this into 10 * 9 * 8 * 7 * 6 * 5 * 4 * 3 * 2 * 1 * 1, and if all the operands of a multiplication are constants, then the multiplication can be performed at compile time.
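A sketch of the end state of that inlining, with a hypothetical function name. It is not necessarily the literal intermediate form GCC or Clang produce, but it shows why the result can be known at compile time: once every operand is a literal, the product is an integer constant expression.

```c
/* What main's body effectively becomes once foo(10) has been inlined
   down to the n == 0 base case (hypothetical name). */
int main_after_inlining(void)
{
    return 10 * 9 * 8 * 7 * 6 * 5 * 4 * 3 * 2 * 1 * 1;
}

/* The same product is an integer constant expression, so this C11
   _Static_assert can only compile if the compiler evaluates it during
   translation, i.e. without running any code. */
_Static_assert(10 * 9 * 8 * 7 * 6 * 5 * 4 * 3 * 2 * 1 * 1 == 3628800,
               "10! is folded to 3628800 at compile time");
```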

Update: Here's how you can manually remove the recursion from foo, using steps that a compiler could also apply automatically. I'm not saying this is the method used by GCC, but it's one realistic possibility.

First, create a helper function. It behaves exactly as foo(n), except that its results are multiplied by an extra parameter f.

int foo_helper(int n, int f); /* declare before use: foo calls it */

int foo(int n)
{
    return foo_helper(n, 1);
}

int foo_helper(int n, int f)
{
    if (n == 0) return f * 1;
    return f * n * foo(n-1);
}

Then, turn recursive calls of foo into recursive calls of foo_helper, and rely on the factor parameter to get rid of the multiplication.

int foo_helper(int n, int f); /* declare before use: foo calls it */

int foo(int n)
{
    return foo_helper(n, 1);
}

int foo_helper(int n, int f)
{
    if (n == 0) return f;
    return foo_helper(n-1, f * n);
}
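Why the rewrite in this step is safe: it rests on the invariant foo_helper(n, f) = f * foo(n). A short induction sketch, assuming n >= 0 and treating int multiplication as exact (ignoring overflow). The regrouping (f · n) · (n−1)! = f · (n · (n−1)!) is where associativity of multiplication is used.

```latex
\begin{aligned}
&\textbf{Claim: } \mathrm{foo\_helper}(n, f) = f \cdot n! \quad \text{for all } n \ge 0 \\
&\textbf{Base: } \mathrm{foo\_helper}(0, f) = f = f \cdot 0! \\
&\textbf{Step: } \mathrm{foo\_helper}(n, f) = \mathrm{foo\_helper}(n-1,\ f \cdot n)
   = (f \cdot n) \cdot (n-1)! = f \cdot \bigl(n \cdot (n-1)!\bigr) = f \cdot n! \\
&\textbf{Hence: } \mathrm{foo}(n) = \mathrm{foo\_helper}(n, 1) = 1 \cdot n! = n!
\end{aligned}
```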

Turn this into a loop:

int foo_helper(int n, int f); /* declare before use: foo calls it */

int foo(int n)
{
    return foo_helper(n, 1);
}

int foo_helper(int n, int f)
{
restart:
    if (n == 0) return f;
    {
        int newn = n-1;
        int newf = f * n;
        n = newn;
        f = newf;
        goto restart;
    }
}
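The temporaries newn and newf in this loop are not just style. A call foo_helper(n-1, f * n) evaluates both arguments from the old n and f before either parameter is rebound, and the loop has to preserve that. A small comparison, with hypothetical function names:

```c
/* Faithful to the tail call: both new values are computed from the
   old n and f before either variable is overwritten. */
int loop_with_temporaries(int n)
{
    int f = 1;
    while (n != 0) {
        int newn = n - 1;
        int newf = f * n;
        n = newn;
        f = newf;
    }
    return f;
}

/* Deliberately wrong: assigning n first means f is multiplied by the
   already-decremented n, so the last iteration multiplies by 0. */
int loop_without_temporaries(int n)
{
    int f = 1;
    while (n != 0) {
        n = n - 1;
        f = f * n;
    }
    return f;
}
```

Swapping the two assignments (f first, then n) would happen to work here; the temporaries make the update correct regardless of order, which is what a mechanical transformation needs.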

Finally, inline foo_helper:

int foo(int n)
{
    int f = 1;
restart:
    if (n == 0) return f;
    {
        int newn = n-1;
        int newf = f * n;
        n = newn;
        f = newf;
        goto restart;
    }
}

(Naturally, this is not the most sensible way to manually write the function.)