My goal is to group objects based on time overlap.
Each object in my rdd contains a start_time and an end_time.
I'm probably going about this inefficiently, but my plan is to assign an overlap id to each object based on whether it overlaps in time with any of the other objects. I already have the logic for checking time overlap. Then I hope to group by that overlap_id.
So, first:
mapped_rdd = rdd.map(assign_overlap_id)
final_rdd = mapped_rdd.reduceByKey(combine_objects)
This brings me to my question: how can I write the assign_overlap_id function?
def assign_overlap_id(x):
    ...
    ...
    return (overlap_id, x)
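For reference, here is a minimal local (plain Python, no Spark) sketch of what assign_overlap_id is meant to compute. It assumes half-open intervals and that groups chain transitively, so (0, 3) and (3, 4) can share an id through (1, 4). It sorts by start time and sweeps once, opening a new id whenever an interval starts at or after the furthest end seen so far. Note that an interval's id depends on the other intervals, so it cannot be computed by a per-element rdd.map on its own. The helper names assign_overlap_ids and group_by_overlap are hypothetical.

```python
def assign_overlap_ids(intervals):
    """Hypothetical helper: return (overlap_id, (start, end)) pairs.

    Intervals are treated as half-open [start, end), so intervals that only
    touch, like (0, 3) and (3, 4), do not overlap directly (they can still
    share a group through a chain of other intervals).
    """
    result = []
    overlap_id = -1
    group_end = None  # furthest end time seen in the current group
    # After sorting by start, a group can only end where the next start
    # is at or beyond every end seen so far
    for start, end in sorted(intervals):
        if group_end is None or start >= group_end:
            overlap_id += 1  # no overlap with the current group: open a new one
            group_end = end
        else:
            group_end = max(group_end, end)  # extend the current group
        result.append((overlap_id, (start, end)))
    return result


def group_by_overlap(intervals):
    """Hypothetical helper: collect intervals per overlap id (map + reduceByKey analogue)."""
    groups = {}
    for overlap_id, interval in assign_overlap_ids(intervals):
        groups.setdefault(overlap_id, []).append(interval)
    return groups


data = [(0, 3), (1, 4), (2, 5), (3, 4), (5, 8), (7, 10)]
print(group_by_overlap(data))
```

For the intervals (0, 3), (1, 4), (2, 5), (3, 4), (5, 8), (7, 10), group_by_overlap returns {0: [(0, 3), (1, 4), (2, 5), (3, 4)], 1: [(5, 8), (7, 10)]}.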
Naive solution using Spark SQL and DataFrames:
Scala:
import org.apache.spark.sql.functions.udf
case class Interval(start_time: Long, end_time: Long)
val rdd = sc.parallelize(
  Interval(0, 3) :: Interval(1, 4) ::
  Interval(2, 5) :: Interval(3, 4) ::
  Interval(5, 8) :: Interval(7, 10) :: Nil
)
val df = sqlContext.createDataFrame(rdd)
// Simple check whether two given intervals overlap. Intervals are treated
// as half-open: (0, 3) and (3, 4) do not overlap, while containment and
// identical intervals do. Every row also matches itself in the self-join.
def overlaps(start_first: Long, end_first: Long,
             start_second: Long, end_second: Long): Boolean = {
  start_first < end_second && start_second < end_first
}
// Register the UDF and data frame aliases.
// It looks like Spark SQL doesn't support
// aliases in the FROM clause [1], so we have to
// register df twice
sqlContext.udf.register("overlaps", overlaps _)
df.registerTempTable("df1")
df.registerTempTable("df2")
// Join and filter
sqlContext.sql("""
  SELECT * FROM df1 JOIN df2
  WHERE overlaps(df1.start_time, df1.end_time, df2.start_time, df2.end_time)
""").show
And the same thing using PySpark:
from pyspark.sql.types import BooleanType
rdd = sc.parallelize([
    (0, 3), (1, 4),
    (2, 5), (3, 4),
    (5, 8), (7, 10)
])
df = sqlContext.createDataFrame(rdd, ('start_time', 'end_time'))
# Half-open intervals overlap iff each one starts before the other ends
def overlaps(start_first, end_first, start_second, end_second):
    return start_first < end_second and start_second < end_first
sqlContext.udf.register('overlaps', overlaps, BooleanType())
df.registerTempTable("df1")
df.registerTempTable("df2")
sqlContext.sql("""
    SELECT * FROM df1 JOIN df2
    WHERE overlaps(df1.start_time, df1.end_time, df2.start_time, df2.end_time)
""").show()
Low-level transformations with grouping by window
A somewhat smarter approach is to generate candidate pairs using a window of some fixed width, so that only intervals falling into the same window are compared. Here is a rather simplified solution:
Scala:
// Generates the list of "buckets" for a given interval
// (here: every time point in the 10-aligned window covering it)
def genRange(interval: Interval) = interval match {
  case Interval(start_time, end_time) =>
    (start_time / 10L * 10L) to (((end_time / 10) + 1) * 10) by 1
}
// For each interval generate pairs (bucket, interval)
val pairs = rdd.flatMap((i: Interval) => genRange(i).map(r => (r, i)))
// Join (in the worst-case scenario it is still O(n^2),
// but in practice it should be better than a naive
// Cartesian product)
val candidates = pairs.
  join(pairs).
  map({
    case (k, (Interval(s1, e1), Interval(s2, e2))) => (s1, e1, s2, e2)
  }).distinct
// For each candidate pair check if there is overlap
candidates.filter { case (s1, e1, s2, e2) => overlaps(s1, e1, s2, e2) }
Python:
def genRange(start_time, end_time):
    # Every time point in the 10-aligned window covering the interval
    return range(start_time // 10 * 10, (end_time // 10 + 1) * 10)
pairs = rdd.flatMap(lambda se: ((r, se) for r in genRange(*se)))
candidates = (pairs
    .join(pairs)
    # (k, ((s1, e1), (s2, e2))) -> (s1, e1, s2, e2)
    .map(lambda kv: kv[1][0] + kv[1][1])
    .distinct())
candidates.filter(lambda c: overlaps(*c))
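The same bucket, join, filter pipeline can be sketched in plain Python to show why it is safe. This version assumes half-open intervals and uses integer bucket ids start // width through end // width (width 10, as in genRange) rather than every individual time point. Two overlapping intervals share some time point, and that point's bucket lies in both ranges, so no real overlap is lost. bucketed_overlapping_pairs and BUCKET_WIDTH are hypothetical names, and the result is a list of index pairs.

```python
from collections import defaultdict

BUCKET_WIDTH = 10  # hypothetical; mirrors the width of 10 used by genRange


def overlaps(start_first, end_first, start_second, end_second):
    # Half-open intervals overlap iff each one starts before the other ends
    return start_first < end_second and start_second < end_first


def bucketed_overlapping_pairs(intervals, width=BUCKET_WIDTH):
    """Hypothetical helper: return sorted (i, j) index pairs of overlapping intervals."""
    # "pairs" step: bucket id -> indices of the intervals touching that bucket
    by_bucket = defaultdict(list)
    for idx, (start, end) in enumerate(intervals):
        for bucket in range(start // width, end // width + 1):
            by_bucket[bucket].append(idx)
    # "join" + "distinct": only intervals sharing a bucket become candidates
    candidates = set()
    for members in by_bucket.values():
        for i in members:
            for j in members:
                if i != j:
                    candidates.add((i, j))
    # "filter": keep only the candidates that really overlap
    return sorted(
        (i, j) for i, j in candidates if overlaps(*intervals[i], *intervals[j])
    )


data = [(0, 3), (1, 4), (2, 5), (3, 4), (5, 8), (7, 10), (25, 27)]
print(bucketed_overlapping_pairs(data))
```

Here (25, 27) lands only in bucket 2, so it is never compared with the first six intervals. The cost still depends on how many intervals share a bucket, so very long intervals or a badly chosen width can push it back towards O(n^2), as the comment in the Scala version says.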
While this windowed approach can be sufficient for some datasets, for a production-ready solution you should consider implementing a state-of-the-art algorithm such as NCList (Nested Containment List).
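Whatever produces the overlapping pairs (the self-join, the windowed join, or an interval index such as NCList), the question still asks for an overlap_id per object. One way to get there, swapped in here rather than taken from the approaches above, is to treat each overlapping pair as a graph edge and label connected components with union-find. This sketch assumes the pairs are given as row indices 0..n-1 and fit in memory on one machine; overlap_ids_from_pairs is a hypothetical name.

```python
def overlap_ids_from_pairs(n, pairs):
    """Hypothetical helper: label n rows so that rows linked by a pair share an id."""
    parent = list(range(n))

    def find(x):
        # Path halving keeps the union-find trees shallow
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    for i, j in pairs:
        root_i, root_j = find(i), find(j)
        if root_i != root_j:
            parent[root_i] = root_j  # merge the two components

    # Renumber the roots as 0, 1, 2, ... in order of first appearance
    ids, labels = {}, []
    for i in range(n):
        labels.append(ids.setdefault(find(i), len(ids)))
    return labels


# Index pairs for (0, 3), (1, 4), (2, 5), (3, 4), (5, 8), (7, 10), (25, 27)
pairs = [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (4, 5)]
print(overlap_ids_from_pairs(7, pairs))
```

The resulting labels can serve as the overlap_id keys for the grouping step in the question. On a cluster, GraphX's connectedComponents addresses the same problem without collecting the pairs to one machine.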
[1] http://docs.datastax.com/en/datastax_enterprise/4.6/datastax_enterprise/spark/sparkSqlSupportedSyntax.html