I am trying to write a Julia macro that transforms this:
[par1!( par2(d1,d2)+par3(d1,d2) ,d1,d2,dfix3) for d1 in DIM1, d2 in DIM2]
(not very inspiring) into something much more readable, like this:
@meq par1!(d1 in DIM1, d2 in DIM2, dfix3) = par2(d1,d2)+par3(d1,d2)
where par1!() is a function that sets some multi-dimensional data and par2() is a getData()-type function.
I am trying to implement it with a macro, but as this is my first experience with Julia macros, I'm not sure how to "assemble" the final expression from the various pieces.
Here is what I have done so far:
macro meq(eq)
    # dump(eq)
    lhs_par = eq.args[1].args[1]
    rhs = eq.args[2]
    lhs_dims = eq.args[1].args[2:end]
    loop_counters = [d.args[2] for d in lhs_dims if typeof(d) == Expr]
    loop_sets = [d.args[3] for d in lhs_dims if typeof(d) == Expr]
    loop_wholeElements = [d for d in lhs_dims if typeof(d) == Expr]
    lhs_dims_placeholders = []
    for d in lhs_dims
        if typeof(d) == Expr
            # looped dimension: keep only the loop counter
            push!(lhs_dims_placeholders, d.args[2])
        else
            # fixed dimension: pass it through unchanged
            push!(lhs_dims_placeholders, d)
        end
    end
    outExp = quote
        [$(lhs_par)($(rhs),$(lhs_dims_placeholders ...)) for $(loop_wholeElements ...) ]
    end
    return outExp
end
However, the macro above doesn't compile: it fails with a syntax error (“invalid iteration specification”) caused by the for $(loop_wholeElements ...) part. I don't know how to treat the expressions in lhs_dims_placeholders and loop_wholeElements in order to “assemble” the expanded expression.
The example posted, with d1, d2 and dfix3, is only a specific case; the macro should be able to handle whichever dimensions are looped over.
I think the macro up there does that, but I don't know how to build the final expression.. :-(
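For reference, the indices used in the macro follow from how Julia parses the input. The sketch below only inspects the quoted expression (nothing is evaluated, so par1!, DIM1 and the other names need not exist); it assumes Julia 1.x, where the right-hand side of a short-form definition is wrapped in a block.

```julia
# Quote the macro's input so its structure can be inspected.
ex = :(par1!(d1 in DIM1, d2 in DIM2, dfix3) = par2(d1, d2) + par3(d1, d2))

lhs = ex.args[1]   # the call par1!(d1 in DIM1, d2 in DIM2, dfix3)
rhs = ex.args[2]   # right-hand side (a :block on Julia 1.x, an assumption here)

fname    = lhs.args[1]        # the function name, :par1!
looped   = lhs.args[2]        # `d1 in DIM1` is itself a call to `in`
in_parts = looped.args        # holds :in, :d1 and :DIM1, in that order
fixed    = lhs.args[4]        # a fixed dimension stays a plain Symbol

# Print the full tree of one looped dimension.
dump(looped)
```

So d.args[2] is the loop counter and d.args[3] is the set being looped over, while fixed dimensions are not Expr objects at all.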
Instead of manually matching hard-coded args, we could use MacroTools.jl as a handy tool for template matching:
julia> using MacroTools
julia> macro meq(ex)
           @capture(ex, f_(d1_ in dim1_, d2_ in dim2_, dfix3_) = body__)
           ret = :([$f($(body[]), $d1, $d2, $dfix3) for $d1 in $dim1, $d2 in $dim2])
       end
@meq (macro with 1 method)
julia> prettify(@macroexpand @meq par1!(d1 in DIM1, d2 in DIM2, dfix3) = par2(d1,d2)+par3(d1,d2))
:([(Main.par1!)((Main.par2)(lobster, redpanda) + (Main.par3)(lobster, redpanda), lobster, redpanda, Main.dfix3) for lobster = Main.DIM1, redpanda = Main.DIM2])
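The desired final expression is a comprehension, but it seems the parser cannot recognize for $expr (where expr would expand to something like XXX in XXX) as a valid iteration specification, since it has to see the iteration written out when the quoted code is parsed. The workaround is to build the comprehension Expr by hand: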
julia> using MacroTools
julia> par1(a, b, c, d) = a + b + c + d
par1 (generic function with 1 method)
julia> par2(a, b) = a + b
par2 (generic function with 1 method)
julia> macro meq(ex)
           @capture(ex, par_(dims__) = rhs_)
           loopElements = []
           dimsPlaceholders = []
           for d in dims
               @capture(d, di_ in DIMi_) || (push!(dimsPlaceholders, d); continue)
               # push!(loopElements, x)
               push!(loopElements, :($di = $DIMi))
               push!(dimsPlaceholders, di)
           end
           ret = Expr(:comprehension, :($par($(rhs),$(dimsPlaceholders...))), loopElements...)
       end
@meq (macro with 1 method)
julia> prettify(@macroexpand @meq par1!(d1 in DIM1, d2 in DIM2, dfix3) = par2(d1,d2)+par3(d1,d2))
:($(Expr(:comprehension, :((Main.par1!)(begin
(Main.par2)(bee, wildebeest) + (Main.par3)(bee, wildebeest)
end, bee, wildebeest, Main.dfix3)), :(bee = Main.DIM1), :(wildebeest = Main.DIM2))))
julia> @meq par1(m in 1:2, n in 4:5, 3) = par2(m,n) + par2(m,n)
2×2 Array{Int64,2}:
18 21
21 24
Note that the variable scope of d1, d2 in the generated expression will be wrong if we use push!(loopElements, x) rather than push!(loopElements, :($di = $DIMi)). Let's wait for someone knowledgeable to give a thorough explanation.
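One thing that can be checked directly, offered as a partial observation rather than the thorough explanation asked for: the text d1 in DIM1 parses to two different nodes depending on where it appears. The sketch below assumes Julia 1.x and only compares quoted expressions.

```julia
# As a function-call argument, `d1 in DIM1` is a call to the function `in`.
arg = :(f(d1 in DIM1)).args[2]

# As the iterator of a comprehension, the same text becomes an `=` node.
iter = :([d1 for d1 in DIM1]).args[1].args[2]

println(arg.head, " ", arg.args)
println(iter.head, " ", iter.args)

# Rebuilding the argument as an `=` node, which is what :($di = $DIMi) does.
converted = Expr(:(=), arg.args[2], arg.args[3])
println(converted == iter)
```

A comprehension built by hand therefore seems to need the = form for each iterator; passing the captured in call through unchanged hands it an ordinary function call instead of an iteration specification.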
If you do not want to rely on an external package for this, the solution I provided on the Julia Discourse should also work:
return :([$(Expr(:generator,:($(Expr(:call,lhs_par,rhs,lhs_dims_placeholders...))),loop_wholeElements...))])
The key is to use a :generator Expr to build the loop expression.
Also, rhs can be replaced with rhs.args[n] in order to eliminate the quote block and insert the expression directly.
This then produces the exact expression:
:([(par1!(par2(d1, d2) + par3(d1, d2), d1, d2, dfix3) for d1 in DIM1, d2 in DIM2)])
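A caveat worth checking, sketched under the assumption of Julia 1.x: a :generator interpolated inside [ ... ] appears to give a one-element vector that holds a generator, whereas the same :generator wrapped in a :comprehension Expr evaluates to the array itself. The expression x^2 over 1:3 is only a stand-in for the real call.

```julia
# A generator node built by hand: (x^2 for x = 1:3)
gen = Expr(:generator, :(x^2), :(x = 1:3))

# Interpolating it inside [ ... ] gives a :vect node containing the generator.
wrapped_ex = :([$gen])
wrapped = eval(wrapped_ex)

# Wrapping it in :comprehension gives the usual array comprehension.
direct_ex = Expr(:comprehension, gen)
direct = eval(direct_ex)

println(wrapped_ex.head, " -> ", typeof(wrapped))
println(direct_ex.head, " -> ", typeof(direct))
```

If the array itself is wanted, the :comprehension form is the safer choice; with the :vect form the result has to be unpacked and collected first.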
Alright, so I went ahead and tested this:
return Expr(:comprehension,Expr(:generator,Expr(:call,lhs_par,rhs.args[2],lhs_dims_placeholders...),loop_wholeElements...))
and then computed the result like this:
meq(:(par1!(d1 = 1:2, d2 = 1:2, 3) = par2(d1,d2)+par3(d1,d2))) |> eval
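Putting the pieces together, here is a minimal package-free sketch of the whole macro, assuming Julia 1.x. It keeps the d in DIM input syntax from the question, converts each looped dimension to the = form, and escapes the result so that all names resolve in the caller's scope. par1 and par2 are hypothetical stand-ins used only to exercise it.

```julia
# Hypothetical helper functions, only here to exercise the macro.
par1(a, b, c, d) = a + b + c + d
par2(a, b) = a + b

macro meq(eq)
    lhs, rhs = eq.args[1], eq.args[2]
    par = lhs.args[1]                 # function being called on the left
    iters = Any[]                     # iteration specs, e.g. :(d1 = DIM1)
    placeholders = Any[]              # arguments passed after the rhs
    for d in lhs.args[2:end]
        if d isa Expr && d.head == :call && d.args[1] == :in
            # looped dimension `di in DIMi`: iterate over it, pass the counter
            push!(iters, Expr(:(=), d.args[2], d.args[3]))
            push!(placeholders, d.args[2])
        else
            # fixed dimension: pass it through unchanged
            push!(placeholders, d)
        end
    end
    call = Expr(:call, par, rhs, placeholders...)
    # esc() turns off hygiene so the names refer to the caller's definitions
    esc(Expr(:comprehension, Expr(:generator, call, iters...)))
end

result = @meq par1(m in 1:2, n in 4:5, 3) = par2(m, n) + par2(m, n)
println(result)
```

Escaping the whole expression is the simplest choice here because the macro introduces no variables of its own; a macro that did would need to escape more selectively.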