No stream fusion with unsafeUpdate_ in unboxed vector

Posted 2019-06-21 23:33

Is it possible to keep stream fusion when processing a vector if unsafeUpdate_ is used to update some of its elements? In the test I ran, the answer seems to be no. For the code below, a temporary vector is allocated in the upd function, as the Core confirms:

module Main where
import Data.Vector.Unboxed as U

upd :: Vector Int -> Vector Int
upd v = U.unsafeUpdate_ v (U.fromList [0]) (U.fromList [2])

sum :: Vector Int -> Int
sum = U.sum . upd

main = print $ Main.sum $ U.fromList [1..3]

In the Core, sum calls the $wupd worker, which allocates a new byte array, as seen below:

$wupd :: Vector Int -> Vector Int
$wupd =
  \ (w :: Vector Int) ->
    case w `cast` ... of _ { Vector ipv ipv1 ipv2 ->
    case main11 `cast` ... of _ { Vector ipv3 ipv4 ipv5 ->
    case main7 `cast` ... of _ { Vector ipv6 ipv7 ipv8 ->
    runSTRep
      (\ (@ s) (s :: State# s) ->
         case >=# ipv1 0 of _ {
           False -> case main6 ipv1 of wild { };
           True ->
             case newByteArray# (*# ipv1 8) (s `cast` ...)
             of _ { (# ipv9, ipv10 #) ->
             case (copyByteArray# ipv2 (*# ipv 8) ipv10 0 (*# ipv1 8) ipv9)
                  `cast` ...

The Core for the sum function contains a nice, tight loop, but just before that loop there is a call to $wupd, and therefore a temporary vector is allocated.

Is there a way to avoid the temporary in this example? The way I think about it, updating a vector at index i amounts to traversing a stream, passing every element through unchanged except the one at index i, which is replaced with another element. So updating a vector at an arbitrary location shouldn't break stream fusion, right?

2 answers
倾城 Initia
#2 · 2019-06-22 00:09

I can't be 100% sure, because with vector it's turtles all the way down (you never really reach the actual implementation, there's always another indirection), but as far as I understand it, the update variants force a new temporary through cloning:

unsafeUpdate_ :: (Vector v a, Vector v Int) => v a -> v Int -> v a -> v a
{-# INLINE unsafeUpdate_ #-}
unsafeUpdate_ v is w
  = unsafeUpdate_stream v (Stream.zipWith (,) (stream is) (stream w))

unsafeUpdate_stream :: Vector v a => v a -> Stream (Int,a) -> v a
{-# INLINE unsafeUpdate_stream #-}
unsafeUpdate_stream = modifyWithStream M.unsafeUpdate

and modifyWithStream calls clone (and new),

modifyWithStream :: Vector v a
                 => (forall s. Mutable v s a -> Stream b -> ST s ())
                 -> v a -> Stream b -> v a
{-# INLINE modifyWithStream #-}
modifyWithStream p v s = new (New.modifyWithStream p (clone v) s)

new :: Vector v a => New v a -> v a
{-# INLINE_STREAM new #-}
new m = m `seq` runST (unsafeFreeze =<< New.run m)

-- | Convert a vector to an initialiser which, when run, produces a copy of
-- the vector.
clone :: Vector v a => v a -> New v a
{-# INLINE_STREAM clone #-}
clone v = v `seq` New.create (
  do
    mv <- M.new (length v)
    unsafeCopy mv v
    return mv)

and I see no way that vector would get rid of that unsafeCopy again.
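
For a single, statically known index, one possible workaround is to avoid the update family altogether and express the replacement as an index-aware map. This is a minimal sketch under assumptions, not something taken from the vector source quoted above: updFused and sumUpdated are hypothetical names, and the sketch relies on U.imap being a streaming combinator, so it should fuse with U.sum. Whether it actually does with a given GHC and vector version is best confirmed by inspecting the Core again.

```haskell
module Main where

import qualified Data.Vector.Unboxed as U

-- Hypothetical replacement for upd: overwrite index 0 with 2 by mapping
-- over (index, element) pairs instead of calling unsafeUpdate_.
-- Every other element is passed through unchanged.
updFused :: U.Vector Int -> U.Vector Int
updFused = U.imap (\i x -> if i == 0 then 2 else x)
{-# INLINE updFused #-}

-- Sum of the updated vector; the intent is that imap and sum fuse into
-- one loop with no intermediate vector (an assumption to verify in Core).
sumUpdated :: U.Vector Int -> Int
sumUpdated = U.sum . updFused

main :: IO ()
main = print (sumUpdated (U.fromList [1 .. 3]))  -- prints 7
```

The trade-off is an index comparison on every element, which is cheap for one or two indices but does not scale to a large or dynamic set of updates; in that case the copy made by unsafeUpdate_ may well be the better deal.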

你好瞎i
#3 · 2019-06-22 00:09

If you need to change only one or a few elements, the repa and yarr libraries offer nice solutions. They preserve fusion (though I'm not sure about repa) and are idiomatic Haskell.

Repa, using fromFunction:

upd arr = fromFunction (extent arr) ix
  where ix (Z :. 0) = 2
        ix i = index arr i

Yarr, using Delayed:

upd arr = Delayed (extent arr) (touchArray arr) (force arr) ix
  where ix 0 = return 2
        ix i = index arr i