PySpark - Time Overlap for Object in RDD

Posted 2019-02-15 18:29

My goal is to group objects based on time overlap.

Each object in my rdd contains a start_time and end_time.

I'm probably going about this inefficiently, but my plan is to assign an overlap id to each object based on whether it has any time overlap with any of the other objects. I already have the logic for checking time overlap. Then I hope to group by that overlap_id.

So first,

mapped_rdd = rdd.map(assign_overlap_id)
final_rdd = mapped_rdd.reduceByKey(combine_objects)

This brings me to my question: how can I go about writing the assign_overlap_id function?

def assign_overlap_id(x):
  ...
  ...
  return (overlap_id, x)

1 Answer
神经病院院长
2019-02-15 19:22

Naive solution using Spark SQL and Data Frames:

Scala:

import org.apache.spark.sql.functions.udf

case class Interval(start_time: Long, end_time: Long)

val rdd = sc.parallelize(
    Interval(0, 3) :: Interval(1, 4) ::
    Interval(2, 5) :: Interval(3, 4) ::
    Interval(5, 8) :: Interval(7, 10) :: Nil
)

val df = sqlContext.createDataFrame(rdd)

// Simple check whether two intervals overlap
// (strict inequalities, so intervals that merely touch do not count)
def overlaps(start_first: Long, end_first: Long,
        start_second: Long, end_second: Long): Boolean = {
    start_first < end_second && start_second < end_first
}

// Register udf and data frame aliases
// It looks like Spark SQL doesn't support
// aliases in FROM clause [1] so we have to
// register df twice
sqlContext.udf.register("overlaps", overlaps)
df.registerTempTable("df1")
df.registerTempTable("df2")

// Join and filter
sqlContext.sql("""
     SELECT * FROM df1 JOIN df2
     WHERE overlaps(df1.start_time, df1.end_time, df2.start_time, df2.end_time)
""").show

And the same thing using PySpark:

from pyspark.sql.functions import udf
from pyspark.sql.types import BooleanType

rdd = sc.parallelize([
    (0, 3), (1, 4), 
    (2, 5), (3, 4),
    (5, 8), (7, 10)
])

df = sqlContext.createDataFrame(rdd, ('start_time', 'end_time'))

def overlaps(start_first, end_first, start_second, end_second):
    # Strict inequalities, so intervals that merely touch do not count
    return start_first < end_second and start_second < end_first

sqlContext.registerFunction('overlaps', overlaps, BooleanType())
df.registerTempTable("df1")
df.registerTempTable("df2")

sqlContext.sql("""
     SELECT * FROM df1 JOIN df2
     WHERE overlaps(df1.start_time, df1.end_time, df2.start_time, df2.end_time)
""").show()

Low-level transformations with grouping by window

A slightly smarter approach is to generate candidate pairs using a window of some specified width. Here is a rather simplified solution:

Scala:

// Generates list of "buckets" for a given interval
def genRange(interval: Interval) = interval match {
    case Interval(start_time, end_time) => {
      (start_time / 10L * 10L) to (((end_time / 10) + 1) * 10) by 1
    }
}


// For each interval generate pairs (bucket, interval)
val pairs = rdd.flatMap( (i: Interval) => genRange(i).map((r) => (r, i)))

// Join (in the worst case it is still O(n^2),
// but in practice it should be better than
// a naive Cartesian product)
val candidates = pairs.
    join(pairs).
    map({
        case (k, (Interval(s1, e1), Interval(s2, e2))) => (s1, e1, s2, e2)
   }).distinct


// For each candidate pair check if there is overlap
candidates.filter { case (s1, e1, s2, e2) => overlaps(s1, e1, s2, e2) }

Python:

def genRange(start_time, end_time):
    # Integer buckets covering the interval, with the lower bound rounded
    # down and the upper bound rounded up to a multiple of 10
    return range(start_time // 10 * 10, (end_time // 10 + 1) * 10)

pairs = rdd.flatMap(lambda se: [(r, se) for r in genRange(se[0], se[1])])
candidates = (pairs
    .join(pairs)
    .map(lambda kv: kv[1][0] + kv[1][1])  # ((s1, e1), (s2, e2)) -> (s1, e1, s2, e2)
    .distinct())

candidates.filter(lambda t: overlaps(t[0], t[1], t[2], t[3]))
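
Neither variant assigns the overlap_id asked about in the question; they only produce overlapping pairs. One way to close that gap is to treat each overlapping pair as a graph edge and label connected components. Below is a minimal sketch that assumes the filtered pairs are few enough to collect to the driver; for larger data a distributed connected-components implementation (e.g. GraphX or GraphFrames) would be a better fit:

# Assumption: the overlapping pairs fit comfortably in driver memory
def find(parent, x):
    # Representative of x, with path compression
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x

def union(parent, a, b):
    ra, rb = find(parent, a), find(parent, b)
    if ra != rb:
        parent[rb] = ra

edges = candidates.filter(lambda t: overlaps(t[0], t[1], t[2], t[3])).collect()

parent = {}
for s1, e1, s2, e2 in edges:
    for node in ((s1, e1), (s2, e2)):
        parent.setdefault(node, node)
    union(parent, (s1, e1), (s2, e2))

# overlap_id = representative interval of the connected component;
# intervals that overlap nothing simply become their own group
overlap_ids = {node: find(parent, node) for node in parent}
bc = sc.broadcast(overlap_ids)

grouped = (rdd
    .map(lambda iv: (bc.value.get(iv, iv), iv))
    .groupByKey())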

While the windowed approach can be sufficient on some datasets, for a production-ready solution you should consider implementing a state-of-the-art algorithm like NCList.

  1. http://docs.datastax.com/en/datastax_enterprise/4.6/datastax_enterprise/spark/sparkSqlSupportedSyntax.html