Spark's StringIndexer is quite useful, but it's common to need to retrieve the correspondences between the generated index values and the original strings, and it seems like there should be a built-in way to accomplish this. I'll illustrate using this simple example from the Spark documentation:
from pyspark.ml.feature import StringIndexer
df = sqlContext.createDataFrame(
    [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")],
    ["id", "category"])
indexer = StringIndexer(inputCol="category", outputCol="categoryIndex")
indexed_df = indexer.fit(df).transform(df)
This simplified case gives us:
+---+--------+-------------+
| id|category|categoryIndex|
+---+--------+-------------+
|  0|       a|          0.0|
|  1|       b|          2.0|
|  2|       c|          1.0|
|  3|       a|          0.0|
|  4|       a|          0.0|
|  5|       c|          1.0|
+---+--------+-------------+
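As far as I can tell from the documentation, the indices in the output above follow from StringIndexer's default of ordering labels by descending frequency, so the most frequent string gets index 0.0. Here is a minimal pure-Python sketch of that ordering, just to show which mapping I am after. The helper name `frequency_index` is hypothetical, the alphabetical tie-breaking is my own assumption, and none of this is Spark's actual implementation:

```python
from collections import Counter


def frequency_index(values):
    """Map each distinct string to a float index, most frequent first.

    Hypothetical helper for illustration only; not part of Spark.
    """
    counts = Counter(values)
    # Sort by descending count; ties are broken alphabetically,
    # which is an assumption of this sketch.
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: float(i) for i, label in enumerate(ordered)}


# Same category column as in the example DataFrame
categories = ["a", "b", "c", "a", "a", "c"]
label_to_index = frequency_index(categories)

# Invert to get the index -> string lookup
index_to_label = {index: label for label, index in label_to_index.items()}
```

For the example data this gives a, c, b in that order, which matches the categoryIndex column above.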
All fine and dandy, but for many use cases I want to know the mapping between my original strings and the index labels. The simplest way I can think of offhand is something like this:
In [8]: indexed_df.select('category','categoryIndex').distinct().show()
+--------+-------------+
|category|categoryIndex|
+--------+-------------+
|       b|          2.0|
|       c|          1.0|
|       a|          0.0|
+--------+-------------+
The result of which I could store as a dictionary or similar if I wanted:
In [12]: mapping = {row.categoryIndex: row.category for row in
                    indexed_df.select('category','categoryIndex').distinct().collect()}
In [13]: mapping
Out[13]: {0.0: u'a', 1.0: u'c', 2.0: u'b'}
My question is this: since this is such a common task, and I'm guessing (but could of course be wrong) that the StringIndexer is somehow storing this mapping anyway, is there a way to accomplish the above more simply?
My solution is more or less straightforward, but on a large DataFrame the distinct() and collect() calls mean an extra distributed pass over the data that (perhaps) I can avoid. Ideas?