Thursday, August 16, 2018

Calling Scala API methods from PySpark when using the spark.ml library

While implementing Logistic Regression on an older Spark 1.6 cluster, I was surprised by how many Python API methods were missing; in particular, saving and loading a serialised model was not possible through `pyspark.ml`.

However, using Py4J we can reach directly into the underlying Spark Scala API, where those methods do exist.
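To make that concrete (a minimal sketch, assuming a live `SparkContext` named `sc` and an already-fitted model `lrModel`): every `pyspark.ml` wrapper keeps a Py4J handle to its Scala counterpart in `_java_obj`, and the whole JVM class hierarchy is reachable through `sc._jvm`:

```python
# The Py4J gateway exposes any JVM class by its fully qualified name
print(sc._jvm.java.lang.System.getProperty("java.version"))

# Each pyspark.ml wrapper holds its Scala twin in _java_obj;
# any public method of the Scala class can be called on it directly
print(lrModel._java_obj.getClass().getName())
# org.apache.spark.ml.classification.LogisticRegressionModel
```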

Say we have a `pyspark.ml.classification.LogisticRegressionModel` object; we can call the Scala `save` method through its underlying Java object like so:

lrModel._java_obj.save(model_path)

And load, which works a little differently because `load` is a static method on the Scala companion object:

loaded_lrModel_JVM = sqlContext._jvm.org.apache.spark.ml.classification.LogisticRegressionModel.load(model_path)

from pyspark.ml.classification import LogisticRegressionModel

loaded_lrModel = LogisticRegressionModel(loaded_lrModel_JVM)
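Putting it all together, here is a minimal end-to-end sketch on Spark 1.6 (assuming a running `sc` and `sqlContext`; `model_path` is a hypothetical path, adjust for your cluster):

```python
from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel
from pyspark.mllib.linalg import Vectors  # in Spark 1.6, spark.ml still uses mllib vectors

# Tiny training set with the default "label" / "features" column names
df = sqlContext.createDataFrame(
    [(0.0, Vectors.dense(0.0, 1.1)),
     (1.0, Vectors.dense(2.0, 1.0))],
    ["label", "features"])

lrModel = LogisticRegression(maxIter=10).fit(df)

model_path = "/tmp/lr_model"  # hypothetical path

# Save via the Scala API, reached through the model's Py4J handle
lrModel._java_obj.save(model_path)

# Load via the static method on the Scala companion object, then re-wrap in Python
loaded_lrModel_JVM = sqlContext._jvm.org.apache.spark.ml.classification \
    .LogisticRegressionModel.load(model_path)
loaded_lrModel = LogisticRegressionModel(loaded_lrModel_JVM)

# The reconstructed wrapper behaves like the original model
loaded_lrModel.transform(df).show()
```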

This approach also helps future-proof SparkML development: `spark.mllib` is effectively deprecated in favour of `spark.ml`, so sticking with `spark.ml` now means a later Spark 2.x upgrade should involve minimal breaking changes to the API.
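For comparison, once the cluster is on Spark 2.x the Py4J detour becomes unnecessary, since model persistence is exposed natively on the Python side:

```python
# Spark 2.x equivalent: save/load are plain Python API calls
lrModel.save(model_path)
loaded_lrModel = LogisticRegressionModel.load(model_path)
```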
