nVidia Thrust: device_ptr Const-Correctness

Posted 2019-06-27 16:59

In my project, which makes extensive use of nVidia CUDA, I sometimes use Thrust for the things it does very, very well. Reduce is one algorithm that is particularly well implemented in that library, and one use of reduce is to normalise a vector of non-negative elements by dividing each element by the sum of all elements.

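```cuda
#include <thrust/device_ptr.h>
#include <thrust/functional.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/reduce.h>
#include <thrust/transform.h>

template <typename T>
void normalise(T const* const d_input, const unsigned int size, T* d_output)
{
    // Sum all elements; the result is returned to the host.
    const thrust::device_ptr<T> X = thrust::device_pointer_cast(const_cast<T*>(d_input));
    T sum = thrust::reduce(X, X + size);

    // Divide every element by the sum.
    thrust::constant_iterator<T> denominator(sum);
    thrust::device_ptr<T> Y = thrust::device_pointer_cast(d_output);
    thrust::transform(X, X + size, denominator, Y, thrust::divides<T>());
}
```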

(T is typically float or double)

In general, I don't want to depend on Thrust throughout my entire code base, so I try to make sure that functions like the example above accept only raw CUDA device pointers. This means that once they are compiled by NVCC, I can link them statically into other code that is built without NVCC.
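Because normalise is a function template, code built without NVCC can only link against it if the NVCC-compiled translation unit explicitly instantiates it for each element type the callers need. A minimal sketch, assuming float and double are the only types required; the file names normalise.h and normalise.cu in the comments are hypothetical, and the function body is the one from the question.

```cuda
// ---- What a plain C++ header (hypothetically normalise.h) would declare ----
// No Thrust or CUDA headers appear here, so code built without NVCC can include it.
template <typename T>
void normalise(T const* const d_input, const unsigned int size, T* d_output);

// ---- What the NVCC-compiled file (hypothetically normalise.cu) would contain ----
#include <thrust/device_ptr.h>
#include <thrust/functional.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/reduce.h>
#include <thrust/transform.h>

// Body taken from the question, including its const_cast.
template <typename T>
void normalise(T const* const d_input, const unsigned int size, T* d_output)
{
    const thrust::device_ptr<T> X = thrust::device_pointer_cast(const_cast<T*>(d_input));
    T sum = thrust::reduce(X, X + size);

    thrust::constant_iterator<T> denominator(sum);
    thrust::device_ptr<T> Y = thrust::device_pointer_cast(d_output);
    thrust::transform(X, X + size, denominator, Y, thrust::divides<T>());
}

// Explicit instantiations: they make NVCC emit ordinary object-code symbols for
// normalise<float> and normalise<double>, which the host-only code links against.
template void normalise<float>(float const*, unsigned int, float*);
template void normalise<double>(double const*, unsigned int, double*);
```

On the host-only side, a call such as normalise<float>(d_in, n, d_out) with pointers obtained from cudaMalloc compiles against the declaration alone; the final link also needs the CUDA runtime library, since the instantiated code calls into it.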

This code worries me, however. I want the function to be const-correct, but I can't seem to find a const version of thrust::device_pointer_cast(...). Does such a thing exist? In this version of the code, I have resorted to a const_cast so that I can use const in the function signature, and that makes me sad.

On a side note, it feels odd to copy the result of reduce to the host only to send it back to the device for the next step. Is there a better way to do this?
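One way to keep the sum on the GPU, offered here as an assumption rather than something proposed in this thread, is to let CUB write the total into device memory and have the division read it from there. The sketch assumes CUB is available (it ships with CUDA 11 and later); normalise_on_device, divide_by_device_value and the caller-supplied d_sum scalar are hypothetical names.

```cuda
#include <cub/device/device_reduce.cuh>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>

// Hypothetical functor: divides each element by a value that stays in device memory.
template <typename T>
struct divide_by_device_value
{
    const T* d_denominator; // device pointer to the precomputed sum

    // Only ever executed on the device here, where d_denominator is valid.
    __host__ __device__ T operator()(const T x) const { return x / *d_denominator; }
};

// Hypothetical variant of normalise whose sum is never copied to the host.
// d_sum must point to one element of caller-allocated device memory.
template <typename T>
cudaError_t normalise_on_device(T const* const d_input, const unsigned int size,
                                T* d_output, T* d_sum)
{
    // First call: with d_temp == nullptr, CUB only reports the temporary storage size.
    void* d_temp = nullptr;
    size_t temp_bytes = 0;
    cudaError_t err = cub::DeviceReduce::Sum(d_temp, temp_bytes, d_input, d_sum, size);
    if (err != cudaSuccess) return err;
    err = cudaMalloc(&d_temp, temp_bytes);
    if (err != cudaSuccess) return err;

    // Second call: performs the reduction and writes the total to *d_sum on the device.
    err = cub::DeviceReduce::Sum(d_temp, temp_bytes, d_input, d_sum, size);
    if (err == cudaSuccess)
    {
        // The transform reads *d_sum on the GPU, so no host round trip is needed.
        const thrust::device_ptr<const T> X = thrust::device_pointer_cast(d_input);
        const thrust::device_ptr<T> Y = thrust::device_pointer_cast(d_output);
        thrust::transform(X, X + size, Y, divide_by_device_value<T>{d_sum});
        err = cudaGetLastError();
    }
    cudaFree(d_temp);
    return err;
}
```

The trade-off is that every thread now loads the denominator from global memory instead of receiving it by value, in exchange for removing the device-to-host copy that thrust::reduce performs when it returns its result.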

Tags: c++ cuda thrust
1 answer
forever°为你锁心
Post #2 · 2019-06-27 17:36

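If you want const-correctness, you need to be const-correct everywhere. d_input is a pointer to const T, so X should be a device_ptr<const T> as well. thrust::device_pointer_cast is a function template that deduces its element type from the pointer you pass it, so giving it a T const* already yields a thrust::device_ptr<const T> and no const_cast is needed: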

```cuda
const thrust::device_ptr<const T> X = thrust::device_pointer_cast(d_input);
```
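Applying that one-line change to the question's function gives the sketch below. Apart from the declaration of X (and the includes the snippet needs), it is unchanged from the question.

```cuda
#include <thrust/device_ptr.h>
#include <thrust/functional.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/reduce.h>
#include <thrust/transform.h>

template <typename T>
void normalise(T const* const d_input, const unsigned int size, T* d_output)
{
    // device_pointer_cast deduces its element type from the argument, so passing
    // a T const* yields a thrust::device_ptr<const T> without any const_cast.
    const thrust::device_ptr<const T> X = thrust::device_pointer_cast(d_input);

    // reduce only reads from X; the sum is returned to the host as a plain T.
    const T sum = thrust::reduce(X, X + size);

    // X is used only as an input range; the results are written through Y.
    thrust::constant_iterator<T> denominator(sum);
    thrust::device_ptr<T> Y = thrust::device_pointer_cast(d_output);
    thrust::transform(X, X + size, denominator, Y, thrust::divides<T>());
}
```

The function signature and the call site stay exactly as in the question; only the type of X changes.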