Julia macro for transforming `f(dim1,dim2,..) = va

Posted on 2019-07-25 16:29

Question:

I am trying to write a Julia macro that transforms this:

[par1!( par2(d1,d2)+par3(d1,d2)   ,d1,d2,dfix3) for d1 in DIM1, d2 in DIM2]

(not very inspiring) into something much more readable, like this:

@meq par1!(d1 in DIM1, d2 in DIM2, dfix3) =  par2(d1,d2)+par3(d1,d2)

where par1!() is a function that sets some multi-dimensional data and par2() is a getData()-type function.
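To make the two forms concrete, here is a minimal runnable sketch of the "before" version. The storage dictionary `DATA`, the sets `DIM1`/`DIM2` and the bodies of `par1!`, `par2` and `par3` are all hypothetical stand-ins. It assumes `par1!(value, d1, d2, d3)` stores `value` and returns it, so the comprehension also yields the stored values.

```julia
# Hypothetical stand-ins: DATA, DIM1, DIM2 and these function bodies are assumptions.
const DIM1 = 1:2
const DIM2 = 3:4
const DATA = Dict{Tuple{Int,Int,Symbol},Int}()

par2(d1, d2) = d1 * d2          # "getData()"-style getter
par3(d1, d2) = d1 + d2          # another getter
# Setter: stores `value` under the key (d1, d2, d3) and returns it.
par1!(value, d1, d2, d3) = (DATA[(d1, d2, d3)] = value)

dfix3 = :fixed                  # a dimension that is not looped over

# The hard-to-read comprehension the macro should generate:
result = [par1!(par2(d1, d2) + par3(d1, d2), d1, d2, dfix3) for d1 in DIM1, d2 in DIM2]
```

The first loop variable runs along the rows, so `result[i, j]` corresponds to `d1 = DIM1[i]` and `d2 = DIM2[j]`.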

I am trying to implement it with a macro, but since this is my first experience with Julia macros, I'm not sure how to "assemble" the final expression from the various pieces. Here is what I have done so far:

macro meq(eq)
    # dump(eq)
    lhs_par               = eq.args[1].args[1]
    rhs                   = eq.args[2]
    lhs_dims              = eq.args[1].args[2:end]
    loop_counters         = [d.args[2] for d in lhs_dims if typeof(d) == Expr]
    loop_sets             = [d.args[3] for d in lhs_dims if typeof(d) == Expr]
    loop_wholeElements    = [d for d in lhs_dims if typeof(d) == Expr]
    lhs_dims_placeholders = []
    for d in lhs_dims
        if typeof(d) == Expr
            push!(lhs_dims_placeholders,d.args[2])
        else
            push!(lhs_dims_placeholders,d)
        end
    end
    outExp =  quote
      [$(lhs_par)($(rhs),$(lhs_dims_placeholders ...)) for  $(loop_wholeElements ...) ]
    end
    #show(outExp)
    return outExp
end

However, the macro above doesn't compile: it fails with the syntax error “invalid iteration specification”, caused by the `for $(loop_wholeElements ...)` part. I don't know how to handle the expressions in lhs_dims_placeholders and loop_wholeElements in order to assemble the expanded expression.

EDIT:

The example above, with d1, d2 and dfix3, is only a specific case; the macro should be able to handle whichever dimensions are looped over. I think the macro above does that, but I don't know how to build the final expression. :-(
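As far as I can tell, “invalid iteration specification” is raised by the parser: the `for` in a comprehension must be followed by literal iteration specs when the code is parsed, so `for $(loop_wholeElements ...)` is rejected before anything gets spliced in. Inspecting how a literal comprehension is represented shows what a macro has to build instead. The names `f`, `i`, `j` below are placeholders.

```julia
# A two-dimensional comprehension, written literally (f, i, j are placeholders).
ex = :([f(i, j) for i in 1:2, j in 3:4])

gen = ex.args[1]          # the :comprehension holds a single :generator
println(ex.head)          # comprehension
println(gen.head)         # generator
println(gen.args[1])      # the body, f(i, j)
println(gen.args[2:end])  # one iteration spec per loop variable, stored as `i = 1:2`, not `i in 1:2`

# The same expression built by hand, which is what a macro can return:
built = Expr(:comprehension, Expr(:generator, :(f(i, j)), :(i = 1:2), :(j = 3:4)))
@assert built == ex
```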

Answer 1:

Instead of hard-coding the `args` matching by hand, we can use MacroTools.jl for template matching:

julia> using MacroTools

julia> macro meq(ex)
           @capture(ex, f_(d1_ in dim1_, d2_ in dim2_, dfix3_) = body__)
           ret = :([$f($(body[]), $d1, $d2, $dfix3) for $d1 in $dim1, $d2 in $dim2])
       end
@meq (macro with 1 method)

julia> prettify(@macroexpand @meq par1!(d1 in DIM1, d2 in DIM2, dfix3) =  par2(d1,d2)+par3(d1,d2))
:([(Main.par1!)((Main.par2)(lobster, redpanda) + (Main.par3)(lobster, redpanda), lobster, redpanda, Main.dfix3) for lobster = Main.DIM1, redpanda = Main.DIM2])

UPDATE:

The desired final expression is a comprehension, but for some reason Julia doesn't recognize `for $expr` (where `$expr` expands to something like `XXX in XXX`) as part of a comprehension. The workaround is to build the comprehension's `Expr` directly:

julia> using MacroTools

julia> par1(a, b, c, d) = a + b + c + d
par1 (generic function with 1 method)

julia> par2(a, b) = a + b
par2 (generic function with 1 method)

julia> macro meq(ex)
           @capture(ex, par_(dims__) = rhs_)
           loopElements = []
           dimsPlaceholders = []
           for d in dims
               @capture(d, di_ in DIMi_) || (push!(dimsPlaceholders, d); continue)
               # push!(loopElements, x) 
               push!(loopElements, :($di = $DIMi))
               push!(dimsPlaceholders, di)
           end
           ret = Expr(:comprehension, :($par($(rhs),$(dimsPlaceholders...))), loopElements...)
       end
@meq (macro with 1 method)

julia> prettify(@macroexpand @meq par1!(d1 in DIM1, d2 in DIM2, dfix3) =  par2(d1,d2)+par3(d1,d2))
:($(Expr(:comprehension, :((Main.par1!)(begin 
            (Main.par2)(bee, wildebeest) + (Main.par3)(bee, wildebeest)
        end, bee, wildebeest, Main.dfix3)), :(bee = Main.DIM1), :(wildebeest = Main.DIM2))))

julia> @meq par1(m in 1:2, n in 4:5, 3) =  par2(m,n) + par2(m,n)
2×2 Array{Int64,2}:
 18  21
 21  24

Note that the scope of d1 and d2 in the generated expression will be wrong if we use push!(loopElements, x) rather than push!(loopElements, :($di = $DIMi)). I'm not sure why; hopefully someone more knowledgeable can give a thorough explanation.
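For comparison, a dependency-free sketch of the same idea, assuming loop dimensions are always written as `var in SET` (or `var ∈ SET`) and every other argument is fixed. It rewrites each `var in SET` into the `var = SET` form that the parser itself produces for generators, which may be why the `:($di = $DIMi)` version behaves correctly. It also wraps the result in `esc`, so that names resolve in the caller's scope instead of being qualified with `Main.` and renamed by macro hygiene, as in the `prettify` output above. This is an illustration, not the answer's code.

```julia
# Hypothetical, dependency-free variant of @meq (no MacroTools).
macro meq(eq)
    # Expect `f(args...) = rhs`.
    (eq isa Expr && eq.head == :(=) && eq.args[1] isa Expr && eq.args[1].head == :call) ||
        error("@meq expects an expression of the form `f(args...) = rhs`")
    lhs, rhs = eq.args
    # A short-form definition wraps the right-hand side in a `begin ... end` block;
    # unwrap it when it holds a single expression.
    if rhs isa Expr && rhs.head == :block
        stmts = filter(a -> !(a isa LineNumberNode), rhs.args)
        length(stmts) == 1 && (rhs = stmts[1])
    end
    placeholders = Any[]   # arguments passed to f after the computed value
    iters = Any[]          # iteration specs for the generator
    for d in lhs.args[2:end]
        if d isa Expr && d.head == :call && d.args[1] in (:in, :∈)
            var, set = d.args[2], d.args[3]
            push!(iters, Expr(:(=), var, set))  # the generator wants `var = set`
            push!(placeholders, var)
        else
            push!(placeholders, d)              # fixed argument such as dfix3
        end
    end
    call = Expr(:call, lhs.args[1], rhs, placeholders...)
    # esc: resolve f, the sets and the loop variables in the caller's scope
    return esc(Expr(:comprehension, Expr(:generator, call, iters...)))
end

# Usage, with the same toy functions as in the REPL session above:
par1(a, b, c, d) = a + b + c + d
par2(a, b) = a + b
result = (@meq par1(m in 1:2, n in 4:5, 3) = par2(m, n) + par2(m, n))
```

With these definitions `result` should match the 2×2 array shown in the REPL session.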



Answer 2:

If you don't want to rely on an external package, the solution I provided on the Julia Discourse forum should also work:

return :([$(Expr(:generator,:($(Expr(:call,lhs_par,rhs,lhs_dims_placeholders...))),loop_wholeElements...))])

The key is to use Expr(:generator, ...) to build the loop expression.

Also, rhs can be replaced with rhs.args[n] to eliminate the `begin ... end` block around the right-hand side and insert the expression directly.
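To see which index `n` to use, we can inspect how a short-form definition is parsed. This sketch assumes Julia 1.x parsing, where the right-hand side is wrapped in a block whose first entry is a `LineNumberNode`. That is why the EDIT below uses `rhs.args[2]`.

```julia
eq = :(par1!(d1 in DIM1, d2 in DIM2, dfix3) = par2(d1, d2) + par3(d1, d2))
rhs = eq.args[2]

@assert rhs.head == :block                # the right-hand side is wrapped in a block...
@assert rhs.args[1] isa LineNumberNode    # ...whose first entry is source-location info
@assert rhs.args[2] == :(par2(d1, d2) + par3(d1, d2))  # the actual expression
```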

This produces the following expression:

:([(par1!(par2(d1, d2) + par3(d1, d2), d1, d2, dfix3) for d1 in DIM1, d2 in DIM2)])
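Note that `:([$(gen)])` wraps the generator in a vector literal (`:vect`) rather than producing a comprehension, so evaluating it gives a one-element vector holding a lazy `Base.Generator` instead of an array of results. Here is a small comparison, using a hypothetical helper `make_gen` and toy ranges:

```julia
# Hypothetical helper: a fresh generator expression over two toy ranges.
make_gen() = Expr(:generator, :(i + j), :(i = 1:2), :(j = 10:10:20))

wrapped = eval(Expr(:vect, make_gen()))           # same shape as :([$(gen)])
matrix  = eval(Expr(:comprehension, make_gen()))  # same shape as the EDIT below

println(length(wrapped), " ", wrapped[1] isa Base.Generator)  # 1 true
println(matrix)                                   # the 2×2 matrix [11 21; 12 22]
```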

EDIT:

Alright, so I went ahead and tested this:

return Expr(:comprehension,Expr(:generator,Expr(:call,lhs_par,rhs.args[2],lhs_dims_placeholders...),loop_wholeElements...))
end

I then computed the result like this:

meq(:(par1!(d1 = 1:2, d2 = 1:2, 3) =  par2(d1,d2)+par3(d1,d2))) |> eval
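When the dimensions are written as `d1 = 1:2` inside a call, as in the line above, the parser stores each one as a `:kw` node. The variable is then `d.args[1]` and the range is `d.args[2]`, so if the placeholders are still built with the question's `d.args[2]`, they pick up the range rather than the loop variable. Here is a minimal sketch of turning such nodes into iteration specs, with a toy body `d1 + d2`:

```julia
lhs = :(par1!(d1 = 1:2, d2 = 1:2, 3))
d = lhs.args[2]

@assert d.head == :kw          # `d1 = 1:2` inside a call is a keyword-style node
@assert d.args[1] == :d1       # loop variable
@assert d.args[2] == :(1:2)    # range, not the variable

# Re-head each :kw node as `=` to get iteration specs for Expr(:generator, ...).
iters = [Expr(:(=), a.args...) for a in lhs.args[2:end] if a isa Expr && a.head == :kw]
result = eval(Expr(:comprehension, Expr(:generator, :(d1 + d2), iters...)))
```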