How to do automatic differentiation on hmatrix?

Posted 2020-03-03 06:23

Question:

Sooooo... as it turns out, going from fake matrices to hmatrix datatypes is nontrivial :)

Preamble for reference:

{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ParallelListComp #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}

import           Numeric.LinearAlgebra.HMatrix
import           Numeric.AD

reconstruct :: (Container Vector a, Num (Vector a)) 
            => [a] -> [Matrix a] -> Matrix a
reconstruct as φs = sum [ a `scale` φ | a <- as | φ <- φs ]

preserveInfo :: (Container Vector a, Num (Vector a))
     => Matrix a -> [a] -> [Matrix a] -> a
preserveInfo img as φs = sumElements (errImg * errImg)
    where errImg = img - (reconstruct as φs)

And the call to the gradientDescent function:

gradientDescentOverAs :: forall m a. (Floating a, Ord a, Num (Vector a))
                      => Matrix a -> [Matrix a] -> [a] -> [[a]]
gradientDescentOverAs img φs as0 = gradientDescent go as0
  where go as = preserveInfo img as φs

Edit: this is not the code from the original question but a version boiled down as much as possible. GHC requires some constraints on the go sub-function, but the answer proposed in the linked question doesn't apply here.
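For comparison, here is a minimal sketch of the call shape gradientDescent expects, assuming the ad package's Numeric.AD module and a toy quadratic objective that is not part of the original problem (the name descend is hypothetical). The objective receives AD's own number type, so constant Doubles are lifted with auto, and passing it as an inline lambda lets GHC check it directly against gradientDescent's rank-2 argument type instead of inferring a type for a separate go.

```haskell
import Numeric.AD (auto, gradientDescent)

-- Hypothetical toy objective (x - cx)^2 + (y - cy)^2, minimised from (0, 0).
-- The lambda is checked directly against gradientDescent's rank-2 argument
-- type, so x and y are AD numbers and the plain Double targets need 'auto'.
descend :: (Double, Double) -> [[Double]]
descend (cx, cy) =
  gradientDescent
    (\[x, y] -> let sq e = e * e in sq (x - auto cx) + sq (y - auto cy))
    [0, 0]
```

gradientDescent returns a list of successive iterates that may end early, so something like last (take 200 (descend (3, -1))) is a safe way to read off a point close to [3, -1].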

Edit 2, quoting myself from below:

I've come to believe it can't be done. Matrix requires its elements to be in the Element class, and the only instances there are Double, Float and their Complex forms. None of these is accepted by gradientDescent, which evaluates the objective on its own AD number type (Reverse s a) rather than on plain Double.

So basically this is the same question as the one linked above, but for the hmatrix datatypes instead of my hand-rolled ones.
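One possible workaround, sketched here under the assumption that converting to lists is acceptable: unpack the hmatrix inputs once with toLists, write the objective over lists with a polymorphic element type, and let gradientDescent differentiate with respect to the coefficients only. The names Grid, reconstructL, preserveInfoL and exampleFit are hypothetical list counterparts of the code in the question, and this trades hmatrix's performance for AD compatibility.

```haskell
import Numeric.AD (auto, gradientDescent)
import Numeric.LinearAlgebra (Matrix, fromLists, toLists)

-- Hypothetical list-of-rows stand-in for an hmatrix Matrix. The element
-- type stays polymorphic so that 'ad' can plug in its own number type.
type Grid a = [[a]]

-- List version of 'reconstruct': the sum of a_i * phi_i.
-- zipWith truncates, so the infinite zero grid acts as the identity of the sum.
reconstructL :: Num a => [a] -> [Grid a] -> Grid a
reconstructL as phis =
  foldr (zipWith (zipWith (+))) (repeat (repeat 0))
        [ map (map (a *)) phi | (a, phi) <- zip as phis ]

-- List version of 'preserveInfo': the sum of squared reconstruction errors.
preserveInfoL :: Num a => Grid a -> [a] -> [Grid a] -> a
preserveInfoL img as phis = sum [ e * e | row <- errImg, e <- row ]
  where
    errImg = zipWith (zipWith (-)) img (reconstructL as phis)

-- Unpack the hmatrix inputs once with toLists, then differentiate only
-- with respect to the coefficients. 'auto' lifts the constant image and
-- basis entries into the AD number type used inside the lambda.
gradientDescentOverAs :: Matrix Double -> [Matrix Double] -> [Double] -> [[Double]]
gradientDescentOverAs img phis as0 =
  gradientDescent
    (\as -> preserveInfoL (map (map auto) img') as (map (map (map auto)) phis'))
    as0
  where
    img'  = toLists img
    phis' = map toLists phis

-- Hypothetical 2x2 example: fit img with a1 * identity + a2 * swap.
exampleFit :: [Double]
exampleFit = last (take 200 (gradientDescentOverAs img phis [0, 0]))
  where
    img  = fromLists [[1, 2], [3, 4]]
    phis = [fromLists [[1, 0], [0, 1]], fromLists [[0, 1], [1, 0]]]
```

In the example the reconstruction is [[a1, a2], [a2, a1]], so the least-squares optimum is a1 = (1 + 4) / 2 and a2 = (2 + 3) / 2, and exampleFit should approach [2.5, 2.5].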

Edit 3:

Relevant: an email conversation between Edward Kmett and Dominic Steinitz on the topic: https://mail.haskell.org/pipermail/haskell-cafe/2013-April/107561.html

Answer 1:

I found this series of blog posts very helpful: https://idontgetoutmuch.wordpress.com/2014/09/09/fun-with-extended-kalman-filters-4/ (it demonstrates both hmatrix with static size guarantees and the jacobian function from ad).
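A rough sketch of that combination, assuming the ad and hmatrix packages and a made-up function f(x, y) = (x * y, x + sin y) (the names f and jacobianMatrix are hypothetical, not taken from the blog posts): the function is differentiated on lists, and only the resulting Jacobian is packed into an hmatrix Matrix with fromLists.

```haskell
import Numeric.AD (jacobian)
import Numeric.LinearAlgebra (Matrix, fromLists, toLists)

-- Hypothetical example function f(x, y) = (x * y, x + sin y), written over
-- lists because 'ad' needs a Traversable container of its own number type.
f :: Floating a => [a] -> [a]
f [x, y] = [x * y, x + sin y]
f _      = error "f expects exactly two inputs"

-- Jacobian from 'ad' (one row per output), converted to an hmatrix Matrix
-- afterwards so it can be used with the rest of the linear algebra code.
jacobianMatrix :: [Double] -> Matrix Double
jacobianMatrix = fromLists . jacobian f
```

At (x, y) = (2, 0) the rows are [y, x] = [0, 2] and [1, cos y] = [1, 1], so toLists (jacobianMatrix [2, 0]) gives [[0.0, 2.0], [1.0, 1.0]].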

HTH