I'm trying to compute the sum of node values in a Spark GraphX graph. The graph is a tree, and the top node (root) should hold the sum of all its descendants' values: its children, their children, and so on. The tree looks like this, and the expected value at the root is 1850:
    +-----------------+       +----------------+       +--------------+
    | VertexId 20     |       | VertexId 11    |       | VertexId 14  |
    | Value:          +--+--> | Value:         +--+--> | Value: 1000  |
    | Sum of 11 & 911 |  |    | Sum of 14 & 24 |  |    +--------------+
    +-----------------+  |    +----------------+  |
                         |                        |    +--------------+
                         |                        +--> | VertexId 24  |
                         |                             | Value: 550   |
                         |                             +--------------+
                         |    +----------------+
                         +--> | VertexId 911   |
                              | Value: 300     |
                              +----------------+
My first stab at this looks like this:
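    import org.apache.spark.graphx.{Edge, Graph, VertexId, VertexRDD}
    import org.apache.spark.rdd.RDD

    // (vertexId, value); the inner vertices 20 and 11 start at 0
    val vertices: RDD[(VertexId, Int)] =
      sc.parallelize(Array(
        (20L, 0),
        (11L, 0),
        (14L, 1000),
        (24L, 550),
        (911L, 300)
      ))

    // Edges point from child to parent.
    // Note that the last value in the edge is a factor (positive or negative).
    val edges: RDD[Edge[Int]] =
      sc.parallelize(Array(
        Edge(14L, 11L, 1),
        Edge(24L, 11L, 1),
        Edge(11L, 20L, 1),
        Edge(911L, 20L, 1)
      ))

    val dataItemGraph = Graph(vertices, edges)

    // Each edge sends (1, source value, factor) to its destination vertex;
    // sendToDst takes a single message, so the tuple is passed explicitly.
    val sum: VertexRDD[(Int, BigDecimal, Int)] =
      dataItemGraph.aggregateMessages[(Int, BigDecimal, Int)](
        sendMsg = { triplet => triplet.sendToDst((1, BigDecimal(triplet.srcAttr), 1)) },
        mergeMsg = { (a, b) => (a._1, a._2 * a._3 + b._2 * b._3, 1) }
      )

    sum.collect.foreach(println)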
This returns the following:
(20,(1,300,1))
(11,(1,1550,1))
It computes the sum for vertex 11, but that sum is not rolled up to the root node (vertex 20), which only ends up with the 300 from vertex 911. What am I missing, or is there a better way of doing this? The tree can be of arbitrary size, and each vertex can have an arbitrary number of child edges.
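For what it's worth, a single aggregateMessages call is one round of message passing, so each vertex only hears from its direct children, and vertex 11 still holds 0 at the moment it sends to vertex 20. One possible direction is to repeat that round until nothing changes. The following is only a rough sketch under stated assumptions: `rollUp` and `maxRounds` are hypothetical names made up for illustration (not part of GraphX), edges are assumed to point from child to parent with the factor as the edge attribute, and the graph is assumed to be a tree without cycles.

```scala
import org.apache.spark.graphx.Graph

// Hypothetical helper (not a GraphX API): repeats the one-hop aggregation
// until no subtotal changes. Assumes child -> parent edges, the edge
// attribute as factor, and no cycles; maxRounds is only a safety cap.
def rollUp(graph: Graph[Int, Int], maxRounds: Int = 50): Graph[BigDecimal, Int] = {
  // Vertex attribute: (own value, subtotal of the subtree seen so far).
  var g: Graph[(BigDecimal, BigDecimal), Int] =
    graph.mapVertices((_, v) => (BigDecimal(v), BigDecimal(v))).cache()
  var changed = true
  var round = 0
  while (changed && round < maxRounds) {
    // One hop: every child sends its current subtotal, scaled by the factor.
    val fromChildren = g.aggregateMessages[BigDecimal](
      ctx => ctx.sendToDst(ctx.srcAttr._2 * ctx.attr),
      _ + _
    )
    // New subtotal = own value + whatever the children reported this round.
    val next = g.outerJoinVertices(fromChildren) { (_, attr, msg) =>
      (attr._1, attr._1 + msg.getOrElse(BigDecimal(0)))
    }.cache()
    // Stop once a full round leaves every subtotal unchanged.
    changed = g.vertices.join(next.vertices)
      .filter { case (_, (before, after)) => before._2 != after._2 }
      .count() > 0
    g = next
    round += 1
  }
  // Keep only the rolled-up subtotal per vertex.
  g.mapVertices((_, attr) => attr._2)
}
```

Calling `rollUp(dataItemGraph).vertices.collect.foreach(println)` on the sample tree should, if the assumptions hold, leave vertex 20 at 1850 and vertex 11 at 1550. The number of rounds grows with the height of the tree, so for very deep trees the repeated joins may get expensive, and caching or checkpointing would need more care than this sketch takes.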