Thrust filter by key value

Posted 2019-07-16 01:33

In my application I have a class like this:

class sample{
    thrust::device_vector<int>   edge_ID;
    thrust::device_vector<float> weight;
    thrust::device_vector<int>   layer_ID;

/*functions, zip_iterators etc. */

};

At a given index every vector stores the corresponding data of the same edge.

I want to write a function that filters out all the edges of a given layer, something like this:

void filter(const sample& src, sample& dest, const int& target_layer){
      for(...){
        if( src.layer_ID[x] == target_layer)/*copy values to dest*/;
      }
}

The best way I've found to do this is by using thrust::copy_if(...).

It would look like this:

void filter(const sample& src, sample& dest, const int& target_layer){
     thrust::copy_if(src.begin(),
                     src.end(),
                     dest.begin(),
                     comparing_functor() );
}

And this is where we reach my problem:

The comparing_functor() is a unary function, which means I can't pass my target_layer value to it as an argument.

Does anyone know how to get around this, or have an idea for implementing this while keeping the data structure of the class intact?

1 Answer
We Are One
#2 · 2019-07-16 02:17

You can pass specific values to functors for use in the predicate test in addition to the data that is ordinarily passed to them. Here's a worked example:

#include <iostream>
#include <iterator>   // std::ostream_iterator
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <thrust/sequence.h>
#include <thrust/copy.h>

#define DSIZE 10
#define FVAL 5

struct test_functor
{
  const int a;

  test_functor(int _a) : a(_a) {}

  // a is set on the host and copied to the device along with the functor
  __host__ __device__
  bool operator()(const int& x) const {
    return (x == a);
  }
};

int main(){
  int target_layer = FVAL;
  thrust::host_vector<int> h_vals(DSIZE);
  thrust::sequence(h_vals.begin(), h_vals.end());
  thrust::device_vector<int> d_vals = h_vals;
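  thrust::device_vector<int> d_result(DSIZE);
  // copy_if returns the end of the copied range; use it to drop the unused tail
  thrust::device_vector<int>::iterator d_end =
      thrust::copy_if(d_vals.begin(), d_vals.end(), d_result.begin(), test_functor(target_layer));
  thrust::host_vector<int> h_result(d_result.begin(), d_end);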
  std::cout << "Data :" << std::endl;
  thrust::copy(h_vals.begin(), h_vals.end(), std::ostream_iterator<int>( std::cout, " "));
  std::cout << std::endl;
  std::cout << "Filter Value: " << target_layer << std::endl;
  std::cout << "Results :" << std::endl;
  thrust::copy(h_result.begin(), h_result.end(), std::ostream_iterator<int>( std::cout, " "));
  std::cout << std::endl;
  return 0;
}
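
To keep the three vectors of the class in step, the same idea can be combined with zip iterators and the stencil overload of thrust::copy_if, which tests the predicate against a separate sequence (here layer_ID) while copying the zipped tuples. What follows is only a sketch under assumptions: the struct repeats just the three members shown in the question, the functor name layer_equals is made up for illustration, and filter is written as a free function that resizes dest itself.

```cuda
#include <cstddef>
#include <thrust/device_vector.h>
#include <thrust/copy.h>
#include <thrust/distance.h>
#include <thrust/tuple.h>
#include <thrust/iterator/zip_iterator.h>

// Assumed layout: only the three vectors shown in the question.
struct sample {
    thrust::device_vector<int>   edge_ID;
    thrust::device_vector<float> weight;
    thrust::device_vector<int>   layer_ID;
};

// Hypothetical predicate: carries the layer to keep as a data member.
struct layer_equals {
    const int target;
    explicit layer_equals(int t) : target(t) {}

    __host__ __device__
    bool operator()(const int& layer) const { return layer == target; }
};

// Copies every edge of src whose layer_ID equals target_layer into dest.
void filter(const sample& src, sample& dest, int target_layer) {
    const std::size_t n = src.layer_ID.size();

    // Worst case: every edge matches, so reserve room for all of them.
    dest.edge_ID.resize(n);
    dest.weight.resize(n);
    dest.layer_ID.resize(n);

    auto in_first = thrust::make_zip_iterator(thrust::make_tuple(
        src.edge_ID.begin(), src.weight.begin(), src.layer_ID.begin()));
    auto out_first = thrust::make_zip_iterator(thrust::make_tuple(
        dest.edge_ID.begin(), dest.weight.begin(), dest.layer_ID.begin()));

    // Stencil overload: the predicate sees layer_ID[i], the copy moves tuple i.
    auto out_last = thrust::copy_if(in_first, in_first + n,
                                    src.layer_ID.begin(),
                                    out_first,
                                    layer_equals(target_layer));

    // Shrink dest to the number of edges actually copied.
    const std::size_t kept = thrust::distance(out_first, out_last);
    dest.edge_ID.resize(kept);
    dest.weight.resize(kept);
    dest.layer_ID.resize(kept);
}
```

Using the stencil overload means the predicate stays a plain unary functor on int and does not need to unpack a tuple. If the class already exposes begin() and end() as zip iterators, those could presumably be passed in place of in_first and out_first.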