iterative code with long lineage RDD causes stacko

Posted 2019-04-13 01:29

I am new to Apache Spark. I am working on a machine learning program that iteratively updates an RDD and then collects about 10 KB of data from the executors to the driver. Unfortunately, I get a StackOverflowError once it runs past 600 iterations. My code is below. The error is thrown at the collectAsMap call once the iteration count exceeds 400. indexedDevF and indexedData are IndexedRDDs (IndexedRDD is a library developed by AMPLab: https://github.com/amplab/spark-indexedrdd).

breakable{
  while(bLow > bHigh + 2*tolerance){
    indexedDevF = indexedDevF
      .innerJoin(indexedData) { (id, a, b) => (b, a) }
      .mapValues { x =>
        x._2 +
          alphaHighDiff * broad_y.value(iHigh) * kernel(x._1, dataiHigh) +
          alphaLowDiff * broad_y.value(iLow) * kernel(x._1, dataiLow)
      }
    if (iteration % 50 == 0 ) {
          indexedDevF.checkpoint()
    }
    indexedDevF.persist()  // essential to get correct answer

    val devFMap = indexedDevF.collectAsMap() // about 0.5 s per call according to the web UI on port 4040; the StackOverflowError is thrown here

    var min_value = Double.PositiveInfinity
    var max_value = -min_value
    var min_i = -1
    var max_i = -1

    i = 0
    while( i < m ){

      if(((y(i) > 0) && (alpha(i) < cEpsilon)) || ((y(i) < 0) && (alpha(i) > epsilon))){
          if( devFMap(i) <= min_value){
              min_value = devFMap(i)
              min_i = i
          }
      }

      if(((y(i) > 0) && (alpha(i) > epsilon)) || ((y(i) < 0) && (alpha(i) < cEpsilon))){
          if( devFMap(i) >= max_value ){
              max_value = devFMap(i)
              max_i = i
          }
      }
      i = i+1
    }

    iHigh = min_i
    iLow = max_i
    bHigh = devFMap(iHigh)
    bLow = devFMap(iLow) 

    dataiHigh = indexedData.get(iHigh.toLong).get
    dataiLow = indexedData.get(iLow.toLong).get 

    eta = 2 - 2 * kernel(dataiHigh, dataiLow)

    alphaHighOld = alpha(iHigh)
    alphaLowOld = alpha(iLow)
    var alphaDiff = alphaLowOld - alphaHighOld
    var lowLabel = y(iLow)
    var sign = y(iHigh) * lowLabel

    var alphaLowLowerBound = 0D
    var alphaLowUpperBound = 0D

    if (sign < 0){
        if (alphaDiff < 0){
            alphaLowLowerBound = 0;
            alphaLowUpperBound = cost + alphaDiff;
        }
        else{
            alphaLowLowerBound = alphaDiff;
            alphaLowUpperBound = cost;
        }
    }
    else{
        var alphaSum = alphaLowOld + alphaHighOld;
        if (alphaSum < cost){
            alphaLowUpperBound = alphaSum;
            alphaLowLowerBound = 0;
        }
        else{
            alphaLowLowerBound = alphaSum - cost;
            alphaLowUpperBound = cost;
        }
    }

    if (eta > 0){
        alphaLowNew = alphaLowOld + lowLabel*(bHigh - bLow)/eta;
        if (alphaLowNew < alphaLowLowerBound)
            alphaLowNew = alphaLowLowerBound;
        else if (alphaLowNew > alphaLowUpperBound) 
            alphaLowNew = alphaLowUpperBound;
    }
    else{
        var slope = lowLabel * (bHigh - bLow);
        var delta = slope * (alphaLowUpperBound - alphaLowLowerBound);
        if (delta > 0){
            if (slope > 0)  
                alphaLowNew = alphaLowUpperBound;
            else
                alphaLowNew = alphaLowLowerBound;
        }
        else
            alphaLowNew = alphaLowOld;
    }

    alphaLowDiff = alphaLowNew - alphaLowOld;
    alphaHighDiff = -sign*(alphaLowDiff);
    alpha(iLow) = alphaLowNew;
    alpha(iHigh) = (alphaHighOld + alphaHighDiff);


    if(iteration % 50 == 0)
      print(".")

    iteration = iteration + 1;


  } // end while
} // end breakable

===================

The original question follows. I found that checkpoint seemed to have no effect and the program still ended with a StackOverflowError. I wrote a simple test to reproduce the problem. Fortunately, someone helped me solve it; see the answer below. However, even with checkpointing working, I still get a StackOverflowError in my actual program :(

for(i <- 1 to 1000){
  a = a.map(x => x+1).persist
  var b = a.collect()
  if(i%100 == 0){
    a.checkpoint()
  }
  print(".")
}

1 Answer
你好瞎i
#2 · 2019-04-13 02:10

The RDD.checkpoint documentation says:

This function must be called before any job has been executed on this RDD

And indeed, if you change your code slightly so that checkpoint is called before collecting a, it works with no StackOverflowError:

for(i <- 1 to 1000){
  a = a.map(x => x+1).persist

  if(i%100 == 0){
    a.checkpoint()
  }

  var b = a.collect()

  print(".")
}
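
For anyone who wants to run the loop above on its own, here is a minimal sketch of a complete program. It is not taken from the question: the object name CheckpointLoop, the run helper, the local[2] master, the temporary checkpoint directory and the starting data 1 to 10 are all assumptions made for illustration. The one thing the snippet above leaves out is sc.setCheckpointDir, which has to be called before checkpoint() or Spark throws an exception.

```scala
import java.nio.file.Files

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD

// Hypothetical wrapper around the loop from the answer; names are illustrative.
object CheckpointLoop {

  // Applies `iterations` map steps and marks the RDD for checkpointing
  // every `interval` steps, always before the first action on that RDD.
  def run(sc: SparkContext, iterations: Int, interval: Int): Array[Int] = {
    var a: RDD[Int] = sc.parallelize(1 to 10) // assumed starting data
    var result = Array.empty[Int]
    for (i <- 1 to iterations) {
      a = a.map(_ + 1).persist()
      if (i % interval == 0) {
        a.checkpoint() // only marks the RDD; nothing is written yet
      }
      result = a.collect() // the action computes the RDD and writes the checkpoint
    }
    result
  }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("checkpoint-loop").setMaster("local[2]")
    val sc = new SparkContext(conf)
    try {
      // Assumption: a local temp directory is good enough for a local-mode test.
      sc.setCheckpointDir(Files.createTempDirectory("spark-checkpoint").toString)
      val out = run(sc, iterations = 1000, interval = 100)
      println(out.mkString(","))
    } finally {
      sc.stop()
    }
  }
}
```

On a real cluster the checkpoint directory would normally be on a shared file system such as HDFS rather than a local temp folder. The interval of 100 is simply the value used in the question; a smaller interval keeps the lineage shorter at the cost of more checkpoint writes.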