I am trying to write a Scala function that can infer Spark DataTypes based on a provided input string:
/**
* Example:
* ========
* toSparkType("string") => StringType
* toSparkType("boolean") => BooleanType
* toSparkType("date") => DateType
* etc.
*/
def toSparkType(inputType: String): DataType = {
  var dt: DataType = null
  if (matchesStringRegex(inputType)) {
    dt = StringType
  } else if (matchesBooleanRegex(inputType)) {
    dt = BooleanType
  } else if (matchesDateRegex(inputType)) {
    dt = DateType
  } else if (...) {
    ...
  }
  dt
}
My goal is to support a large subset, if not all, of the available DataTypes. As I started implementing this function, I got to thinking: "Spark/Scala probably already have a helper/util method that will do this for me." After all, I know I can do something like:
var structType = new StructType()
structType = structType.add("some_new_string_col", "string", true, Metadata.empty)
structType = structType.add("some_new_boolean_col", "boolean", true, Metadata.empty)
structType = structType.add("some_new_date_col", "date", true, Metadata.empty)
And either Scala and/or Spark will implicitly convert my "string" argument to StringType, etc. So I ask: what magic can I do with either Spark or Scala to help me implement my converter method?
Spark/Scala probably already have a helper/util method that will do this for me.
You're right. Spark already has its own schema and data type inference code that it uses to infer the schema from underlying data sources (csv, json, etc.), so you can look at that to implement your own. (The actual implementation is marked private to Spark and is tied to RDD and internal classes, so it cannot be used directly from code outside of Spark, but it should give you a good idea of how to go about it.)
Given that csv is a flat format (while json can have a nested structure), csv schema inference is relatively more straightforward and should help you with the task you're trying to achieve above. So I will explain how csv inference works (json inference just needs to take the possibly nested structure into account, but the data type inference is quite analogous).
With that prologue, the thing you want to have a look at is the CSVInferSchema object. In particular, look at the infer method, which takes an RDD[Array[String]] and infers the data type for each element of the array across the whole RDD. The way it does this is: it marks each field as NullType to begin with, and then, as it iterates over the next row of values (Array[String]) in the RDD, it updates the already inferred DataType to a new DataType if the new DataType is more specific. This is happening here:
val rootTypes: Array[DataType] =
  tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes)
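To see the shape of that aggregation without Spark's internals, here is a minimal, self-contained sketch over a plain Scala collection instead of an RDD. The simplified inferField below is my own stand-in that only distinguishes integers from strings, not Spark's private helper:

import org.apache.spark.sql.types._
import scala.util.Try

// Hypothetical, simplified stand-in for Spark's private inferField:
// once a column has been seen as a String it stays a String, otherwise
// try Int and fall back to String.
def inferField(typeSoFar: DataType, field: String): DataType =
  if (typeSoFar == StringType) StringType
  else if (Try(field.toInt).isSuccess) IntegerType
  else StringType

val rows = Seq(Array("1", "a"), Array("2", "b"), Array("3", "c"))
val startTypes: Array[DataType] = Array.fill[DataType](2)(NullType)

// Same fold-and-refine pattern as tokenRdd.aggregate(startType)(...),
// but over a local Seq instead of an RDD.
val rootTypes: Array[DataType] = rows.foldLeft(startTypes) { (types, row) =>
  types.zip(row).map { case (t, field) => inferField(t, field) }
}
// rootTypes: Array(IntegerType, StringType)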
Now inferRowType calls inferField for each field in the row. The inferField implementation is what you're probably looking for: it takes the type inferred so far for a particular field and the string value of that field for the current row as parameters. It then returns either the existing inferred type, or, if the newly inferred type is more specific, the new type.
Relevant section of the code is as follows:
typeSoFar match {
  case NullType => tryParseInteger(field, options)
  case IntegerType => tryParseInteger(field, options)
  case LongType => tryParseLong(field, options)
  case _: DecimalType => tryParseDecimal(field, options)
  case DoubleType => tryParseDouble(field, options)
  case TimestampType => tryParseTimestamp(field, options)
  case BooleanType => tryParseBoolean(field, options)
  case StringType => StringType
  case other: DataType =>
    throw new UnsupportedOperationException(s"Unexpected data type $other")
}
Please note that if the typeSoFar is NullType, it first tries to parse the value as an Integer, but the tryParseInteger call is a chain of calls to lower-precedence type parsers. So if it is not able to parse the value as an Integer, it will invoke tryParseLong, which on failure will invoke tryParseDecimal, which on failure will invoke tryParseDouble, which on failure will invoke tryParseTimestamp, which on failure will invoke tryParseBoolean, which on failure will finally fall back to StringType.
So you can use pretty much the same logic to implement whatever your use case is. (If you do not need to merge across rows, then you simply implement all the tryParse* methods in the same spirit and invoke tryParseInteger. No need to write your own regex.)
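For instance, a minimal sketch of that try-parse chain applied to a single string value might look like this (a hypothetical inferSparkType with a much smaller set of types, not Spark's private implementation):

import org.apache.spark.sql.types._
import scala.util.Try

// Try the most specific parse first and fall back step by step, ending at
// StringType, mirroring the tryParseInteger -> ... -> StringType chain.
def inferSparkType(value: String): DataType =
  if (Try(value.toInt).isSuccess) IntegerType
  else if (Try(value.toLong).isSuccess) LongType
  else if (Try(value.toDouble).isSuccess) DoubleType
  else if (Try(java.sql.Timestamp.valueOf(value)).isSuccess) TimestampType
  else if (Try(value.toBoolean).isSuccess) BooleanType
  else StringType

inferSparkType("42")                   // IntegerType
inferSparkType("3000000000")           // LongType
inferSparkType("2017-01-01 00:00:00")  // TimestampType
inferSparkType("true")                 // BooleanType
inferSparkType("hello")                // StringType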
Hope this helps.
Yes, of course Spark has the magic you need.
In Spark 2.x it's the CatalystSqlParser object, defined here.
For example:
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
CatalystSqlParser.parseDataType("string") // StringType
CatalystSqlParser.parseDataType("int") // IntegerType
And so on.
But as I understand it, it's not part of the public API and so may change in future versions without any warning.
So you may just implement your method as:
def toSparkType(inputType: String): DataType = CatalystSqlParser.parseDataType(inputType)
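As far as I can tell, it also understands more complex DDL-style type strings (nested and parameterized types), though since the API is not public this behaviour could vary between Spark versions:

import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

CatalystSqlParser.parseDataType("array<int>")             // ArrayType(IntegerType, true)
CatalystSqlParser.parseDataType("map<string, double>")    // MapType(StringType, DoubleType, true)
CatalystSqlParser.parseDataType("decimal(10,2)")          // DecimalType(10,2)
CatalystSqlParser.parseDataType("struct<a:int,b:string>") // StructType(...)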
From Scala alone, it doesn't seem you can do what you wish magically; check for instance this example:
import com.scalakata._

@instrument class Playground {
  val x = 5
  def f[T](v: T) = v
  f(x)
  val y = "boolean"
  f(y)
  def manOf[T: Manifest](t: T): Manifest[T] = manifest[T]
  println(manOf(y))
}
which I composed after reading "I want to get the type of a variable at runtime".
Now from Spark, since I do not have an installation in place right now, I couldn't compose an example, but there is nothing obvious to use, so I would suggest you continue writing toSparkType() as you have started, but take a look at the source code for pyspark.sql.types first.
You see, the problem is that you are always passing a string.
If you have string literals written as DataType names, i.e. "StringType", "IntegerType", use this function:
import scala.reflect.runtime.{universe => ru}

def StrtoDatatype(str: String): org.apache.spark.sql.types.DataType = {
  val m = ru.runtimeMirror(getClass.getClassLoader)
  val module = m.staticModule(s"org.apache.spark.sql.types.$str")
  m.reflectModule(module).instance.asInstanceOf[org.apache.spark.sql.types.DataType]
}
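For example (note that the string has to match the object name in org.apache.spark.sql.types exactly):

StrtoDatatype("StringType")   // StringType
StrtoDatatype("IntegerType")  // IntegerType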
If you have string literals like string, int, etc., use this one:
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

def sqlStrtoDatatype(str: String): org.apache.spark.sql.types.DataType = {
  CatalystSqlParser.parseDataType(str)
}