How to do memoization or memoisation in Julia 1.0

Posted 2020-06-03 09:13

I have been trying to memoise the Fibonacci function in Julia. This is what I came up with.

The original, unmemoised code (for comparison)

function fib(x)
    if x < 3
        return 1
    else
        return fib(x-2) + fib(x-1)
    end
end

This is my first attempt

memory=Dict()

function memfib(x)
    global memory
    if haskey(memory,x)
        return memory[x]
    else
        if x < 3
            return memory[x] = 1
        else
            return memory[x] = memfib(x-2) + memfib(x-1)
        end
    end
end

My second attempt

memory=Dict()

function membetafib(x)
    global memory
    return haskey(memory,x) ? memory[x] : x < 3 ? memory[x]=1 : memory[x] = membetafib(x-2) + membetafib(x-1)
end

My third attempt

memory=Dict()

function memgammafib!(memory,x)
    return haskey(memory,x) ? memory[x] : x < 3 ? memory[x]=1 : memory[x] = memgammafib!(memory,x-2) + memgammafib!(memory,x-1)
end
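A possible usage sketch for the third attempt. It restates memgammafib! with concrete Dict{Int,Int} types and an early return in place of the nested ternary (my restatement, not the asker's exact code). Because the cache is an argument, the caller controls its lifetime: the hypothetical one-argument method memgammafib starts from a fresh cache on every call, while passing the same dictionary repeatedly reuses earlier results.

```julia
# The third attempt, restated: the cache is passed in explicitly.
function memgammafib!(memory::Dict{Int,Int}, x::Int)
    haskey(memory, x) && return memory[x]
    # Store the value before returning it, so later calls find it.
    return memory[x] = x < 3 ? 1 : memgammafib!(memory, x - 2) + memgammafib!(memory, x - 1)
end

# Hypothetical convenience method: a fresh, private cache per top-level call.
memgammafib(x::Int) = memgammafib!(Dict{Int,Int}(), x)

# Alternatively, keep one cache and reuse it across calls.
fibcache = Dict{Int,Int}()
result = memgammafib!(fibcache, 80)   # fills fibcache with keys 1 through 80
```

After that call fibcache holds all 80 intermediate values, so a later memgammafib!(fibcache, 50) is a single dictionary lookup.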

Are there other ways of doing this that I am not aware of?

Tags: julia

2 answers

Answer by 地球回转人心会变 · 2020-06-03 09:53

The simplest way is to use get! with a do-block:

const fibmem = Dict{Int,Int}()
function fib(n)
    get!(fibmem, n) do
        n < 3 ? 1 : fib(n-1) + fib(n-2)
    end
end
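To make the get! semantics concrete, here is a small sketch with a hypothetical, deliberately trivial function (cachedsquare, squares and misses are illustrative names). get!(f, dict, key) returns dict[key] if the key is present; otherwise it calls f(), stores the result under key and returns it. In fib above, the do-block itself calls fib, so every intermediate value also ends up in fibmem.

```julia
# Cache for a hypothetical, deliberately trivial function.
const squares = Dict{Int,Int}()
const misses = Ref(0)   # counts how often the do-block actually runs

function cachedsquare(n::Int)
    # On a miss, the do-block runs and its result is stored under n.
    get!(squares, n) do
        misses[] += 1
        n^2
    end
end

a = cachedsquare(4)   # miss: the do-block runs and stores 4 => 16
b = cachedsquare(4)   # hit: the stored value is returned directly
```

After both calls, misses[] is 1: the second call never touched the do-block.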

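Note the const qualifier on fibmem and its concrete Dict{Int,Int} type. Because a const global can never be rebound, the compiler knows fibmem's type inside fib and can infer the types of everything that uses it, which makes the code faster. (No global declaration is needed here; in fact the global lines in the question are unnecessary too, because memory[x] = ... mutates the dictionary rather than reassigning the name memory.)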
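If you would rather not have the cache in global scope at all, one variation on the same get! idea (my own sketch, not part of this answer) is to capture the dictionary in a let block and define the function as global from inside it; the name letfib is hypothetical. The trade-off is that the cache can no longer be inspected or cleared from outside.

```julia
# The cache is a local variable of the let block, not a global.
let cache = Dict{Int,Int}()
    # `global` makes letfib visible outside the let block; the function
    # captures `cache`, which is never reassigned, so its type stays known.
    global letfib(n::Int) = get!(cache, n) do
        n < 3 ? 1 : letfib(n - 1) + letfib(n - 2)
    end
end

letfib(80)
```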

Answer by 你好瞎i · 2020-06-03 09:58

As pointed out in the comments, the Memoize.jl package is certainly the easiest option, though it requires you to mark the method with its @memoize macro when you define it.
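To illustrate why marking the method at definition time matters, here is a hypothetical memoize_wrapper helper (not Memoize.jl's API) that wraps an existing function after the fact. The outer call gets cached, but the recursive calls inside naivefib still go to the unwrapped naivefib, so the first call runs the full exponential recursion. Annotating the definition itself avoids exactly this.

```julia
# Hypothetical helper (NOT Memoize.jl's API): wraps a one-argument function
# with a Dict cache of concrete key/value types.
function memoize_wrapper(f, ::Type{K}, ::Type{V}) where {K,V}
    cache = Dict{K,V}()
    return x -> get!(() -> f(x), cache, x)
end

# Naive fib with a call counter so we can see what gets cached.
const naive_calls = Ref(0)
function naivefib(x)
    naive_calls[] += 1
    x < 3 ? 1 : naivefib(x - 2) + naivefib(x - 1)
end

wrapped = memoize_wrapper(naivefib, Int, Int)

naive_calls[] = 0
first_result = wrapped(20)     # miss: the whole naive recursion runs
first_calls = naive_calls[]

naive_calls[] = 0
second_result = wrapped(20)    # hit: naivefib is not called at all
second_calls = naive_calls[]
```

Here first_calls is 13529 (that is, 2*fib(20) - 1) while second_calls is 0.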

By far the most powerful approach, however, is to use Cassette.jl, which lets you add memoization to pre-existing functions, for example:

fib(x) = x < 3 ? 1 : fib(x-2) + fib(x-1)

using Cassette
Cassette.@context MemoizeCtx
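function Cassette.overdub(ctx::MemoizeCtx, ::typeof(fib), x)
    # Return the cached value if there is one; otherwise compute it.
    get(ctx.metadata, x) do
        # recurse runs fib's body while still overdubbing the calls inside
        # it, so the recursive calls to fib are memoized too.
        result = Cassette.recurse(ctx, fib, x)
        ctx.metadata[x] = result
        return result
    end
end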

A brief description of what is going on:

  • MemoizeCtx is the Cassette "context" that we are defining
  • overdub is run instead of the original fib method
    • We use it to check whether the argument is already in the metadata dictionary.
    • recurse(...) tells Cassette to call the function itself, skipping our top-level overdub method but still overdubbing the calls made inside it.

Now we can run the function with memoization:

Cassette.overdub(MemoizeCtx(metadata=Dict{Int,Int}()), fib, 80)

Now what's even cooler is that we can take an existing function that calls fib and memoize the calls to fib inside it:

function foo()
    println("calling fib")
    @show fib(80)
    println("done.")
end
Cassette.overdub(MemoizeCtx(metadata=Dict{Int,Int}()), foo)

(Cassette is still pretty hard on the compiler, so this may take a while to run the first time, but it will be fast after that.)
