Numeric.AD and typing problem

Posted 2019-02-21 23:10

Question:

I'm trying to use Numeric.AD with a custom Expr type. I want to calculate the symbolic gradient of a user-supplied expression. A first attempt with a hard-coded expression works nicely:

calcGrad0 :: [Expr Double]
calcGrad0 = grad df vars
  where
   df [x,y] = eval (env [x,y]) (EVar "x"*EVar "y")
   env vs = zip varNames vs
   varNames = ["x","y"]
   vars = map EVar varNames

This works:

>calcGrad0
[Const 0.0 :+ (Const 0.0 :+ (EVar "y" :* Const 1.0)),Const 0.0 :+ (Const 0.0 :+ (EVar "x" :* Const 1.0))]

However, if I pull the expression out as a parameter:

calcGrad1 :: [Expr Double]
calcGrad1 = calcGrad1' (EVar "x"*EVar "y")
calcGrad1' e = grad df vars
  where
   df [x,y] = eval (env [x,y]) e
   env vs = zip varNames vs
   varNames = ["x","y"]
   vars = map EVar varNames

I get

Could not deduce (a ~ AD s (Expr a1))
from the context (Num a1, Floating a)
  bound by the inferred type of
           calcGrad1' :: (Num a1, Floating a) => Expr a -> [Expr a1]
  at Symbolics.hs:(60,1)-(65,29)
or from (Mode s)
  bound by a type expected by the context:
             Mode s => [AD s (Expr a1)] -> AD s (Expr a1)
  at Symbolics.hs:60:16-27
  `a' is a rigid type variable bound by
      the inferred type of
      calcGrad1' :: (Num a1, Floating a) => Expr a -> [Expr a1]
      at Symbolics.hs:60:1
Expected type: [AD s (Expr a1)] -> AD s (Expr a1)
  Actual type: [a] -> a
In the first argument of `grad', namely `df'
In the expression: grad df vars

How do I get ghc to accept this?

Answer 1:

My guess is that you are forgetting to apply lift to convert an Expr into an AD s (Expr a).
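
Here is a minimal sketch of the lift suggestion. It assumes the ad 4.x API, where the lift function of older releases was renamed auto. The AD s type in the error message comes from the older ad 3.x API, where lift plays the same role. The Expr type, its Num instance and eval below are hypothetical stand-ins, because the question does not show the real definitions. The idea is to keep the user-supplied expression at a concrete type, Expr (Expr Double), and map auto over its constants inside the function passed to grad. That way the expression only ever mentions the AD type that grad itself chooses:

```haskell
{-# LANGUAGE DeriveFunctor #-}
import Data.Maybe (fromMaybe)
import Numeric.AD (auto, grad)

-- Hypothetical stand-in for the asker's Expr type (not shown in the question).
data Expr a
  = Const a
  | EVar String
  | Expr a :+ Expr a
  | Expr a :* Expr a
  deriving (Show, Functor)

infixl 6 :+
infixl 7 :*

-- Using + and * on expressions builds syntax trees.
instance Num a => Num (Expr a) where
  (+)         = (:+)
  (*)         = (:*)
  negate e    = Const (-1) :* e
  fromInteger = Const . fromInteger
  abs         = error "abs: not supported in this sketch"
  signum      = error "signum: not supported in this sketch"

-- Evaluate an expression in any Num type, looking variables up in env.
eval :: Num a => [(String, a)] -> Expr a -> a
eval _   (Const c) = c
eval env (EVar v)  = fromMaybe (error ("unbound variable: " ++ v)) (lookup v env)
eval env (a :+ b)  = eval env a + eval env b
eval env (a :* b)  = eval env a * eval env b

-- The expression has a fixed, concrete type, as a parsed user input would.
-- fmap auto lifts each constant into the AD type chosen inside grad, so the
-- function passed to grad does not depend on a type variable fixed outside it.
calcGrad1' :: Expr (Expr Double) -> [Expr Double]
calcGrad1' e = grad (\vs -> eval (zip varNames vs) (fmap auto e)) vars
  where
    varNames = ["x", "y"]
    vars     = map EVar varNames

calcGrad1 :: [Expr Double]
calcGrad1 = calcGrad1' (EVar "x" * EVar "y")
```

To check it in GHCi, evaluate the symbolic result numerically, for example map (eval [("x", 3), ("y", 5)]) calcGrad1. The two entries are the partial derivatives y and x of x*y at that point, so you should see 5 and 3.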

If you are interested in using the ad package for symbolic differentiation, Lennart Augustsson's traced package works well.



Answer 2:

When GHC cannot infer the type of an otherwise valid function, as in your case, the solution is to give the function an explicit type signature. I do not know the interface to this library, but my guess is that the correct signature is calcGrad1' :: (Num a, Floating a) => Expr a -> [Expr a].
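
The sketch below applies the type-signature idea under the same assumptions: the ad 4.x API and hypothetical Expr and eval definitions. A signature with one ordinary type variable is probably not enough. The element type of e would still be fixed outside grad, while grad needs to evaluate e at its own AD type. One signature that should type-check is rank-2 and requires RankNTypes. It makes the argument polymorphic in its number type, so grad can instantiate that type itself:

```haskell
{-# LANGUAGE RankNTypes #-}
import Data.Maybe (fromMaybe)
import Numeric.AD (grad)

-- Hypothetical stand-in for the asker's Expr type (not shown in the question).
data Expr a
  = Const a
  | EVar String
  | Expr a :+ Expr a
  | Expr a :* Expr a
  deriving Show

infixl 6 :+
infixl 7 :*

-- Using + and * on expressions builds syntax trees.
instance Num a => Num (Expr a) where
  (+)         = (:+)
  (*)         = (:*)
  negate e    = Const (-1) :* e
  fromInteger = Const . fromInteger
  abs         = error "abs: not supported in this sketch"
  signum      = error "signum: not supported in this sketch"

-- Evaluate an expression in any Num type, looking variables up in env.
eval :: Num a => [(String, a)] -> Expr a -> a
eval _   (Const c) = c
eval env (EVar v)  = fromMaybe (error ("unbound variable: " ++ v)) (lookup v env)
eval env (a :+ b)  = eval env a + eval env b
eval env (a :* b)  = eval env a * eval env b

-- Rank-2 signature: the caller must pass an expression that works for every
-- Num type b, so grad can use e at its own AD type.
calcGrad1' :: (forall b. Num b => Expr b) -> [Expr Double]
calcGrad1' e = grad (\vs -> eval (zip varNames vs) e) vars
  where
    varNames = ["x", "y"]
    vars     = map EVar varNames

calcGrad1 :: [Expr Double]
calcGrad1 = calcGrad1' (EVar "x" * EVar "y")
```

The trade-off is that calcGrad1' only accepts expressions that are polymorphic in their number type, such as ones built directly with + and *. An expression parsed from user input typically has a concrete type like Expr Double. In that case, lifting its constants into the AD type (lift, or auto in newer ad releases) is likely the more practical route.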