Efficient Haskell equivalent to NumPy's argsort

Posted 2019-04-09 13:54

Is there a standard Haskell equivalent to NumPy's argsort function?

I'm using HMatrix, so I would like a function compatible with Vector R, which is an alias for Data.Vector.Storable.Vector Double. The argSort function below is the implementation I'm currently using:

{-# LANGUAGE NoImplicitPrelude #-}

module Main where

import qualified Data.List as L
import qualified Data.Vector as V
import qualified Data.Vector.Storable as VS
import           Prelude (($), Double, IO, Int, compare, print, snd)

a :: VS.Vector Double
a = VS.fromList [40.0, 20.0, 10.0, 11.0]

argSort :: VS.Vector Double -> V.Vector Int
argSort xs = V.fromList (L.map snd $ L.sortBy (\(x0, _) (x1, _) -> compare x0 x1) (L.zip (VS.toList xs) [0..]))

main :: IO ()
main = print $ argSort a -- yields [2,3,1,0]

I'm using explicit qualified imports just to make it clear where every type and function is coming from.

This implementation is not terribly efficient since it converts the input vector to a list and the result back to a vector. Does something like this (but more efficient) exist somewhere?

Update

@leftaroundabout had a good solution. This is the solution I ended up with:

module LAUtil.Sorting
  ( IndexVector
  , argSort
  )
  where

import           Control.Monad
import           Control.Monad.ST
import           Data.Ord
import qualified Data.Vector.Algorithms.Intro as VAI
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Unboxed.Mutable as VUM
import           Numeric.LinearAlgebra

type IndexVector = VU.Vector Int

argSort :: Vector R -> IndexVector
argSort xs = runST $ do
    let l = VS.length xs
    t0 <- VUM.new l
    forM_ [0..l - 1] $
        \i -> VUM.unsafeWrite t0 i (i, (VS.!) xs i)
    VAI.sortBy (comparing snd) t0
    t1 <- VUM.new l
    forM_ [0..l - 1] $
        \i -> VUM.unsafeRead t0 i >>= \(x, _) -> VUM.unsafeWrite t1 i x
    VU.freeze t1

This is more directly usable with Numeric.LinearAlgebra since the input data is a Storable vector. The indices are returned in an unboxed vector.
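A variant of the same idea can skip the intermediate vector of (index, value) pairs and sort a vector of indices directly, comparing the values they point at. The following is only a sketch under assumptions: the name argSortByIndex is made up for illustration, it relies on the vector and vector-algorithms packages, and it has not been benchmarked against the version above.

```haskell
import           Data.Ord (comparing)
import qualified Data.Vector.Algorithms.Intro as VAI
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Unboxed as VU

-- Hypothetical helper: sort the indices 0..n-1 by the value each index refers to.
argSortByIndex :: VS.Vector Double -> VU.Vector Int
argSortByIndex xs = VU.create $ do
    -- Start from the identity permutation [0, 1, .., n-1].
    idx <- VU.thaw (VU.enumFromN 0 (VS.length xs))
    -- Introsort the indices in place, comparing the looked-up values.
    VAI.sortBy (comparing (xs VS.!)) idx
    return idx
```

For the example vector from the question, argSortByIndex (VS.fromList [40.0, 20.0, 10.0, 11.0]) gives [2,3,1,0]. Introsort is not stable, so the relative order of indices with equal values is not guaranteed.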

2 Answers
霸刀☆藐视天下
#2 · 2019-04-09 14:12

What worked better for me was using Data.Map, since it is subject to list fusion; I got a speed-up. Here n = length xs.

{-# LANGUAGE BangPatterns #-}

import qualified Data.List as L
import qualified Data.Map as M

-- n is expected to be the length of xs.
out :: Int -> [Double] -> [Int]
out n !xs = let !a   = M.toAscList (M.fromList $! zip xs [0..n])
                !res = a `seq` L.map snd a
            in res

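However, this is only applicable to periodic lists, because the Map keeps a single entry (the last index) for each distinct value, so repeated values are collapsed. For example, the following evaluates to True: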

out 12 [1,2,3,4,1,2,3,4,1,2,3,4] == out 12 [1,2,3,4,1,3,2,4,1,2,3,4]
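If repeated values have to be kept, one possible workaround is to store a list of indices per key instead of a single index. This is a sketch under assumptions, not a tuned implementation: the name argSortMap is hypothetical, and it uses M.fromListWith from the containers package to accumulate the indices.

```haskell
import qualified Data.Map.Strict as M

-- Hypothetical helper: a Map-based argsort that keeps duplicate values.
argSortMap :: [Double] -> [Int]
argSortMap xs = concatMap reverse (M.elems grouped)
  where
    -- For a repeated key, fromListWith calls (++) new old, so each index
    -- list is built newest-first; the reverse above restores input order.
    grouped = M.fromListWith (++) (zip xs [[i] | i <- [0 ..]])
```

With this version the two lists from the example no longer produce the same result, and indices of equal values appear in their original order. For instance, argSortMap [1,2,3,4,1,2,3,4,1,2,3,4] gives [0,4,8,1,5,9,2,6,10,3,7,11].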
唯我独甜
#3 · 2019-04-09 14:25

Use vector-algorithms:

import Data.Ord (comparing)

import qualified Data.Vector.Unboxed as VU
import qualified Data.Vector.Algorithms.Intro as VAlgo

argSort :: (Ord a, VU.Unbox a) => VU.Vector a -> VU.Vector Int
argSort xs = VU.map fst $ VU.create $ do
    xsi <- VU.thaw $ VU.indexed xs
    VAlgo.sortBy (comparing snd) xsi
    return xsi

Note these are Unboxed rather than Storable vectors. The latter need to make some tradeoffs to allow impure C FFI operations and can't properly handle heterogeneous tuples. You can of course always convert to and from storable vectors.
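To illustrate the conversion, here is a minimal sketch that wraps the argSort above for storable vectors using convert from Data.Vector.Generic. The wrapper name argSortStorable is an assumption made for this example; each convert copies the data, so the sketch trades some allocation for convenience.

```haskell
import           Data.Ord (comparing)
import qualified Data.Vector.Algorithms.Intro as VAlgo
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Unboxed as VU

-- The unboxed argSort from this answer, repeated so the block is self-contained.
argSort :: (Ord a, VU.Unbox a) => VU.Vector a -> VU.Vector Int
argSort xs = VU.map fst $ VU.create $ do
    xsi <- VU.thaw $ VU.indexed xs
    VAlgo.sortBy (comparing snd) xsi
    return xsi

-- Hypothetical wrapper: storable in, storable out, with a copy on each side.
argSortStorable :: VS.Vector Double -> VS.Vector Int
argSortStorable = VG.convert . argSort . VG.convert
```

With the vector from the question, argSortStorable (VS.fromList [40.0, 20.0, 10.0, 11.0]) gives [2,3,1,0].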
