C++ code for testing the Collatz conjecture faster

2018-12-31 02:47发布

I wrote these two solutions for Project Euler Q14, in assembly and in C++. They are the same identical brute force approach for testing the Collatz conjecture. The assembly solution was assembled with

nasm -felf64 p14.asm && gcc p14.o -o p14

The C++ was compiled with

g++ p14.cpp -o p14

Assembly, p14.asm

section .data
    fmt db "%d", 10, 0

global main
extern printf

section .text

main:
    mov rcx, 1000000
    xor rdi, rdi        ; max i
    xor rsi, rsi        ; i

l1:
    dec rcx
    xor r10, r10        ; count
    mov rax, rcx

l2:
    test rax, 1
    jpe even

    mov rbx, 3
    mul rbx
    inc rax
    jmp c1

even:
    mov rbx, 2
    xor rdx, rdx
    div rbx

c1:
    inc r10
    cmp rax, 1
    jne l2

    cmp rdi, r10
    cmovl rdi, r10
    cmovl rsi, rcx

    cmp rcx, 2
    jne l1

    mov rdi, fmt
    xor rax, rax
    call printf
    ret

C++, p14.cpp

#include <iostream>

using namespace std;

int sequence(long n) {
    int count = 1;
    while (n != 1) {
        if (n % 2 == 0)
            n /= 2;
        else
            n = n*3 + 1;

        ++count;
    }

    return count;
}

int main() {
    int max = 0, maxi;
    for (int i = 999999; i > 0; --i) {
        int s = sequence(i);
        if (s > max) {
            max = s;
            maxi = i;
        }
    }

    cout << maxi << endl;
}

I know about the compiler optimizations to improve speed and everything, but I don't see many ways to optimize my assembly solution further (speaking programmatically not mathematically).

The C++ code has modulus every term and division every even term, where assembly is only one division per even term.

But the assembly is taking on average 1 second longer than the C++ solution. Why is this? I am asking out of mainly curiosity.

Execution times

My system: 64 bit Linux on ‎1.4 GHz Intel Celeron 2955U (Haswell microarchitecture).

11条回答
查无此人
2楼-- · 2018-12-31 03:01

The simple answer:

  • doing a MOV RBX, 3 and MUL RBX is expensive; just ADD RBX, RBX twice

  • ADD 1 is probably faster than INC here

  • MOV 2 and DIV is very expensive; just shift right

  • 64-bit code is usually noticeably slower than 32-bit code and the alignment issues are more complicated; with small programs like this you have to pack them so you are doing parallel computation to have any chance of being faster than 32-bit code

If you generate the assembly listing for your C++ program, you can see how it differs from your assembly.

查看更多
孤独寂梦人
3楼-- · 2018-12-31 03:02

Even without looking at assembly, the most obvious reason is that /= 2 is probably optimized as >>=1 and many processors have a very quick shift operation. But even if a processor doesn't have a shift operation, the integer division is faster than floating point division.

Edit: your milage may vary on the "integer division is faster than floating point division" statement above. The comments below reveal that the modern processors have prioritized optimizing fp division over integer division. So if someone were looking for the most likely reason for the speedup which this thread's question asks about, then compiler optimizing /=2 as >>=1 would be the best 1st place to look.


On an unrelated note, if n is odd, the expression n*3+1 will always be even. So there is no need to check. You can change that branch to

{
   n = (n*3+1) >> 1;
   count += 2;
}

So the whole statement would then be

if (n & 1)
{
    n = (n*3 + 1) >> 1;
    count += 2;
}
else
{
    n >>= 1;
    ++count;
}
查看更多
冷夜・残月
4楼-- · 2018-12-31 03:04

On a rather unrelated note: more performance hacks!

  • [the first «conjecture» has been finally debunked by @ShreevatsaR; removed]

  • When traversing the sequence, we can only get 3 possible cases in the 2-neighborhood of the current element N (shown first):

    1. [even] [odd]
    2. [odd] [even]
    3. [even] [even]

    To leap past these 2 elements means to compute (N >> 1) + N + 1, ((N << 1) + N + 1) >> 1 and N >> 2, respectively.

    Let`s prove that for both cases (1) and (2) it is possible to use the first formula, (N >> 1) + N + 1.

    Case (1) is obvious. Case (2) implies (N & 1) == 1, so if we assume (without loss of generality) that N is 2-bit long and its bits are ba from most- to least-significant, then a = 1, and the following holds:

    (N << 1) + N + 1:     (N >> 1) + N + 1:
    
            b10                    b1
             b1                     b
           +  1                   + 1
           ----                   ---
           bBb0                   bBb
    

    where B = !b. Right-shifting the first result gives us exactly what we want.

    Q.E.D.: (N & 1) == 1 ⇒ (N >> 1) + N + 1 == ((N << 1) + N + 1) >> 1.

    As proven, we can traverse the sequence 2 elements at a time, using a single ternary operation. Another 2× time reduction.

The resulting algorithm looks like this:

uint64_t sequence(uint64_t size, uint64_t *path) {
    uint64_t n, i, c, maxi = 0, maxc = 0;

    for (n = i = (size - 1) | 1; i > 2; n = i -= 2) {
        c = 2;
        while ((n = ((n & 3)? (n >> 1) + n + 1 : (n >> 2))) > 2)
            c += 2;
        if (n == 2)
            c++;
        if (c > maxc) {
            maxi = i;
            maxc = c;
        }
    }
    *path = maxc;
    return maxi;
}

int main() {
    uint64_t maxi, maxc;

    maxi = sequence(1000000, &maxc);
    printf("%llu, %llu\n", maxi, maxc);
    return 0;
}

Here we compare n > 2 because the process may stop at 2 instead of 1 if the total length of the sequence is odd.

[EDIT:]

Let`s translate this into assembly!

MOV RCX, 1000000;



DEC RCX;
AND RCX, -2;
XOR RAX, RAX;
MOV RBX, RAX;

@main:
  XOR RSI, RSI;
  LEA RDI, [RCX + 1];

  @loop:
    ADD RSI, 2;
    LEA RDX, [RDI + RDI*2 + 2];
    SHR RDX, 1;
    SHRD RDI, RDI, 2;    ror rdi,2   would do the same thing
    CMOVL RDI, RDX;      Note that SHRD leaves OF = undefined with count>1, and this doesn't work on all CPUs.
    CMOVS RDI, RDX;
    CMP RDI, 2;
  JA @loop;

  LEA RDX, [RSI + 1];
  CMOVE RSI, RDX;

  CMP RAX, RSI;
  CMOVB RAX, RSI;
  CMOVB RBX, RCX;

  SUB RCX, 2;
JA @main;



MOV RDI, RCX;
ADD RCX, 10;
PUSH RDI;
PUSH RCX;

@itoa:
  XOR RDX, RDX;
  DIV RCX;
  ADD RDX, '0';
  PUSH RDX;
  TEST RAX, RAX;
JNE @itoa;

  PUSH RCX;
  LEA RAX, [RBX + 1];
  TEST RBX, RBX;
  MOV RBX, RDI;
JNE @itoa;

POP RCX;
INC RDI;
MOV RDX, RDI;

@outp:
  MOV RSI, RSP;
  MOV RAX, RDI;
  SYSCALL;
  POP RAX;
  TEST RAX, RAX;
JNE @outp;

LEA RAX, [RDI + 59];
DEC RDI;
SYSCALL;

Use these commands to compile:

nasm -f elf64 file.asm
ld -o file file.o

See the C and an improved/bugfixed version of the asm by Peter Cordes on Godbolt. (editor's note: Sorry for putting my stuff in your answer, but my answer hit the 30k char limit from Godbolt links + text!)

查看更多
只若初见
5楼-- · 2018-12-31 03:07

Claiming that the C++ compiler can produce more optimal code than a competent assembly language programmer is a very bad mistake. And especially in this case. The human always can make the code better that the compiler can, and this particular situation is good illustration of this claim.

The timing difference you're seeing is because the assembly code in the question is very far from optimal in the inner loops.

(The below code is 32-bit, but can be easily converted to 64-bit)

For example, the sequence function can be optimized to only 5 instructions:

    .seq:
        inc     esi                 ; counter
        lea     edx, [3*eax+1]      ; edx = 3*n+1
        shr     eax, 1              ; eax = n/2
        cmovc   eax, edx            ; if CF eax = edx
        jnz     .seq                ; jmp if n<>1

The whole code looks like:

include "%lib%/freshlib.inc"
@BinaryType console, compact
options.DebugMode = 1
include "%lib%/freshlib.asm"

start:
        InitializeAll
        mov ecx, 999999
        xor edi, edi        ; max
        xor ebx, ebx        ; max i

    .main_loop:

        xor     esi, esi
        mov     eax, ecx

    .seq:
        inc     esi                 ; counter
        lea     edx, [3*eax+1]      ; edx = 3*n+1
        shr     eax, 1              ; eax = n/2
        cmovc   eax, edx            ; if CF eax = edx
        jnz     .seq                ; jmp if n<>1

        cmp     edi, esi
        cmovb   edi, esi
        cmovb   ebx, ecx

        dec     ecx
        jnz     .main_loop

        OutputValue "Max sequence: ", edi, 10, -1
        OutputValue "Max index: ", ebx, 10, -1

        FinalizeAll
        stdcall TerminateAll, 0

In order to compile this code, FreshLib is needed.

In my tests, (1 GHz AMD A4-1200 processor), the above code is approximately four times faster than the C++ code from the question (when compiled with -O0: 430 ms vs. 1900 ms), and more than two times faster (430 ms vs. 830 ms) when the C++ code is compiled with -O3.

The output of both programs is the same: max sequence = 525 on i = 837799.

查看更多
忆尘夕之涩
6楼-- · 2018-12-31 03:08

You did not post the code generated by the compiler, so there' some guesswork here, but even without having seen it, one can say that this:

test rax, 1
jpe even

... has a 50% chance of mispredicting the branch, and that will come expensive.

The compiler almost certainly does both computations (which costs neglegibly more since the div/mod is quite long latency, so the multiply-add is "free") and follows up with a CMOV. Which, of course, has a zero percent chance of being mispredicted.

查看更多
梦醉为红颜
7楼-- · 2018-12-31 03:08

For the Collatz problem, you can get a significant boost in performance by caching the "tails". This is a time/memory trade-off. See: memoization (https://en.wikipedia.org/wiki/Memoization). You could also look into dynamic programming solutions for other time/memory trade-offs.

Example python implementation:

import sys

inner_loop = 0

def collatz_sequence(N, cache):
    global inner_loop

    l = [ ]
    stop = False
    n = N

    tails = [ ]

    while not stop:
        inner_loop += 1
        tmp = n
        l.append(n)
        if n <= 1:
            stop = True  
        elif n in cache:
            stop = True
        elif n % 2:
            n = 3*n + 1
        else:
            n = n // 2
        tails.append((tmp, len(l)))

    for key, offset in tails:
        if not key in cache:
            cache[key] = l[offset:]

    return l

def gen_sequence(l, cache):
    for elem in l:
        yield elem
        if elem in cache:
            yield from gen_sequence(cache[elem], cache)
            raise StopIteration

if __name__ == "__main__":
    le_cache = {}

    for n in range(1, 4711, 5):
        l = collatz_sequence(n, le_cache)
        print("{}: {}".format(n, len(list(gen_sequence(l, le_cache)))))

    print("inner_loop = {}".format(inner_loop))
查看更多
登录 后发表回答