Thursday, September 27, 2018

Serialising a RandomForestClassificationModel from PySpark to a SequenceFile on HDFS

Prior to Spark 2.0, org.apache.spark.ml.classification.RandomForestClassificationModel doesn't have a save() method in its Scala API, as it doesn't implement the MLWritable interface.

If it did, we could call it from PySpark via the model's underlying Java object, just as we can for models that do implement MLWritable, such as LogisticRegressionModel:
# assumes: from pyspark.ml.classification import LogisticRegressionModel
lrModel._java_obj.save(model_path)
loaded_lrModel_JVM = sqlContext._jvm.org.apache.spark.ml.classification.LogisticRegressionModel.load(model_path)
loaded_lrModel = LogisticRegressionModel(loaded_lrModel_JVM)

(Note that this issue doesn't apply to the older, deprecated org.apache.spark.mllib.tree.model.RandomForestModel from Spark MLlib, which does have a save() method in v1.6.)
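For comparison, here is a minimal sketch of that older API's persistence, assuming a trained pyspark.mllib.tree.RandomForestModel named rfModelMllib (the variable name is just for illustration):
# sketch only: the RDD-based mllib model can save/load itself directly
from pyspark.mllib.tree import RandomForestModel

rfModelMllib.save(sc, "hdfs:///some/path/mllibRfModel")
loadedMllibModel = RandomForestModel.load(sc, "hdfs:///some/path/mllibRfModel")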

This is a problem because my current client is constrained to Spark v1.6.

An alternative is rdd.saveAsObjectFile(), which uses the Hadoop API to serialise a model to a SequenceFile; sc.objectFile() reads it back in again.


Here is the relatively simple Scala approach:
import org.apache.spark.ml.classification.RandomForestClassificationModel

// Save: wrap the trained model in a single-element RDD and write it out as a SequenceFile of serialised objects
sc.parallelize(Seq(model), 1).saveAsObjectFile("hdfs:///some/path/rfModel")

// Load: read the object file back and take its single element
val rfModel = sc.objectFile[RandomForestClassificationModel]("hdfs:///some/path/rfModel").first()


The PySpark approach is more complex: the Python-side sc.parallelize() would try to pickle the Py4J proxy for the JVM model, which fails, so instead we go through the Java gateway and the underlying JavaSparkContext:
# Save
from pyspark.ml.classification import RandomForestClassificationModel

# Build a java.util.ArrayList via the Py4J gateway and parallelize it on the
# JavaSparkContext, so the model object never passes through Python pickling
gateway = sc._gateway
java_list = gateway.jvm.java.util.ArrayList()
java_list.add(rfModel._java_obj)
modelRdd = sc._jsc.parallelize(java_list)
modelRdd.saveAsObjectFile("hdfs:///some/path/rfModel")

# Load
# Read the object file back on the JVM side, take the single Java model object,
# then wrap it in the PySpark RandomForestClassificationModel wrapper
rfObjectFileLoaded = sc._jsc.objectFile("hdfs:///some/path/rfModel")
rfModelLoaded_JavaObject = rfObjectFileLoaded.first()
rfModelLoaded = RandomForestClassificationModel(rfModelLoaded_JavaObject)
predictions = rfModelLoaded.transform(test_input_df)
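As a quick sanity check (a sketch only, assuming the original rfModel and test_input_df from above are still in scope), the reloaded model should produce the same predictions as the original:
# compare predictions from the original and the reloaded model
original_preds = rfModel.transform(test_input_df).select("prediction").collect()
reloaded_preds = rfModelLoaded.transform(test_input_df).select("prediction").collect()
assert original_preds == reloaded_preds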



Reference source of RandomForestClassifier v1.6 vs. v2.2:
https://github.com/apache/spark/blob/v1.6.2/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
https://github.com/apache/spark/blob/v2.2.0/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala

Reference for MLWritable:
https://spark.apache.org/docs/1.6.2/api/java/org/apache/spark/ml/util/MLWritable.html
https://spark.apache.org/docs/2.0.0/api/java/org/apache/spark/ml/util/MLWritable.html
