I have two files in which I store:
- an IP range to country lookup
- a list of requests coming from different IPs
The IPs are stored as integers (using inet_aton()).
I loaded both files into DataFrames and registered them as temp tables.
GeoLocTable - ipstart, ipend, ...additional geolocation data
Recordstable - INET_ATON, ...3 more fields
I then tried joining them using Spark SQL with a statement like this:
"select a.*, b.* from Recordstable a left join GeoLocTable b on a.INET_ATON between b.ipstart and b.ipend"
There are about 850K records in RecordsTable and about 2.5M records in GeoLocTable. The join as it exists runs for about 2 hours with about 20 executors.
I have tried caching and broadcasting GeoLocTable, but it does not really seem to help. I have also bumped spark.sql.autoBroadcastJoinThreshold up to 300000000 and spark.sql.shuffle.partitions up to 600.
The Spark UI shows a BroadcastNestedLoopJoin being performed. Is this the best I should be expecting? I tried searching for the conditions under which this type of join is chosen, but the documentation seems sparse.
PS - I am using PySpark to work with Spark.
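For context, the whole attempt looks roughly like this in PySpark (the file-loading code is omitted, and records_df and geo_df are placeholder names, not the actual variables):

# Settings mentioned above (equivalent to passing --conf to spark-submit)
sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "300000000")
sqlContext.setConf("spark.sql.shuffle.partitions", "600")

# records_df and geo_df stand in for the DataFrames loaded from the two files
records_df.registerTempTable("Recordstable")
geo_df.registerTempTable("GeoLocTable")

joined = sqlContext.sql("""
    select a.*, b.*
    from Recordstable a
    left join GeoLocTable b
    on a.INET_ATON between b.ipstart and b.ipend
""")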
The source of the problem is pretty simple. When you execute a join and the join condition is not equality-based, the only thing Spark can do right now is expand it to a Cartesian product followed by a filter, which is pretty much what happens inside BroadcastNestedLoopJoin. So logically you have a huge nested loop which tests all 850K * 2.5M pairs of records.
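You can see this directly in the physical plan; with the temp tables from the question registered, calling explain() on the range join shows the nested loop operator instead of a hash or sort-merge join:

# The BETWEEN (non-equi) condition rules out hash and sort-merge joins
sqlContext.sql("""
    select a.*, b.*
    from Recordstable a
    left join GeoLocTable b
    on a.INET_ATON between b.ipstart and b.ipend
""").explain()
# The printed plan includes a BroadcastNestedLoopJoin operator,
# matching what the Spark UI shows.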
This approach is obviously extremely inefficient. Since it looks like the lookup table fits into memory, the simplest improvement is to use a local, sorted data structure instead of a Spark DataFrame. Assuming your data looks like this:
geo_loc_table = sc.parallelize([
    (1, 10, "foo"), (11, 36, "bar"), (37, 59, "baz"),
]).toDF(["ipstart", "ipend", "loc"])

records_table = sc.parallelize([
    (1, 11), (2, 38), (3, 50)
]).toDF(["id", "inet"])
We can project and sort the reference data by ipstart and create a broadcast variable:
# Collect the sorted ipstart values to the driver and broadcast them
geo_start_bd = sc.broadcast(geo_loc_table
    .select("ipstart")
    .orderBy("ipstart")
    .rdd.flatMap(lambda x: x)
    .collect())
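For the sample data above, geo_start_bd.value is just the sorted list [1, 11, 37].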
Next we'll use a UDF and the bisect module to augment records_table:
from bisect import bisect_right
from pyspark.sql.functions import udf
from pyspark.sql.types import LongType

# https://docs.python.org/3/library/bisect.html#searching-sorted-lists
def find_le(x):
    'Find rightmost value less than or equal to x'
    i = bisect_right(geo_start_bd.value, x)
    if i:
        return geo_start_bd.value[i - 1]
    return None

records_table_with_ipstart = records_table.withColumn(
    "ipstart", udf(find_le, LongType())("inet")
)
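With the sample records this fills in ipstart values of 11, 37, and 37, turning the range lookup into an ordinary equi-join key.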
and finally join both datasets:
records_table_with_ipstart.join(geo_loc_table, ["ipstart"], "left")
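With the sample data this produces one match per record; roughly (column and row order may vary by Spark version):

records_table_with_ipstart.join(geo_loc_table, ["ipstart"], "left").show()

## +-------+---+----+-----+---+
## |ipstart| id|inet|ipend|loc|
## +-------+---+----+-----+---+
## |     11|  1|  11|   36|bar|
## |     37|  2|  38|   59|baz|
## |     37|  3|  50|   59|baz|
## +-------+---+----+-----+---+

Note that the UDF only guarantees inet >= ipstart; if your IP ranges can have gaps, you may also want to filter out rows where inet is greater than ipend.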