Thursday, August 16, 2018

Calling Scala API methods from PySpark when using the spark.ml library

While implementing logistic regression on an older Spark 1.6 cluster, I was surprised by how many Python API methods were missing. In particular, there was no way to save and load a serialised model from PySpark.

However, using Py4J calls we can reach directly into the Spark Scala API.

Say we have a `pyspark.ml.classification.LogisticRegressionModel` object. We can call save on the underlying JVM object like so:

lrModel._java_obj.save(model_path)

And load, which is different due to being a static method on the Scala class:

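from pyspark.ml.classification import LogisticRegressionModel

# load() is static, so call it on the Scala class through the JVM gateway
loaded_lrModel_JVM = sqlContext._jvm.org.apache.spark.ml.classification.LogisticRegressionModel.load(model_path)

# wrap the returned JVM object in the Python model class
loaded_lrModel = LogisticRegressionModel(loaded_lrModel_JVM)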

This helps future-proof development against `spark.ml`, since `spark.mllib` is effectively deprecated, and a later upgrade to Spark 2.x should involve minimal breaking changes to the API.
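
To keep the workaround upgrade-friendly, something like the following could wrap both routes. This is only a sketch: `save_model` and `load_model` are hypothetical helper names, not part of PySpark. It assumes the Python wrapper exposes `_java_obj` as in the calls above, and that the Python model class can be constructed from a JVM object. A native Python `save`/`load` is used when the installed PySpark provides one, and the Py4J route is the fallback.

```python
def save_model(model, model_path):
    """Save via the Python API if present, else via the wrapped JVM object.

    Returns "native" or "py4j" to show which route was taken.
    """
    if callable(getattr(model, "save", None)):
        model.save(model_path)
        return "native"
    # Fallback for older PySpark: call the Scala method through Py4J
    model._java_obj.save(model_path)
    return "py4j"


def load_model(py_cls, jvm_cls, model_path):
    """Load via the Python API if present, else via the static Scala load().

    py_cls  -- Python wrapper class, e.g. LogisticRegressionModel
    jvm_cls -- matching Scala class reached through the JVM gateway
    """
    if callable(getattr(py_cls, "load", None)):
        return py_cls.load(model_path)
    # Fallback: static load() on the Scala class, then wrap the JVM object
    return py_cls(jvm_cls.load(model_path))
```

With the names used above, `jvm_cls` would be `sqlContext._jvm.org.apache.spark.ml.classification.LogisticRegressionModel` and `py_cls` would be `LogisticRegressionModel`.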

Monday, August 6, 2018

Find maximum row per group in Spark DataFrame

This is a great Spark resource on getting the max row, thanks zero323!
https://stackoverflow.com/questions/35218882/find-maximum-row-per-group-in-spark-dataframe

Using a join (returns more than one row per group when there are ties):

import pyspark.sql.functions as F
from pyspark.sql.functions import count, col

cnts = df.groupBy("id_sa", "id_sb").agg(count("*").alias("cnt")).alias("cnts")
maxs = cnts.groupBy("id_sa").agg(F.max("cnt").alias("mx")).alias("maxs")

cnts.join(maxs, 
  (col("cnt") == col("mx")) & (col("cnts.id_sa") == col("maxs.id_sa"))
).select(col("cnts.id_sa"), col("cnts.id_sb"))

Using window functions (keeps exactly one row per group, so ties are dropped; which tied row survives is arbitrary):

from pyspark.sql.functions import row_number
from pyspark.sql.window import Window

w = Window().partitionBy("id_sa").orderBy(col("cnt").desc())

(cnts
  .withColumn("rn", row_number().over(w))
  .where(col("rn") == 1)
  .select("id_sa", "id_sb"))

Using struct ordering (structs compare field by field, so ties on cnt are broken by the largest id_sb):

from pyspark.sql.functions import struct

(cnts
  .groupBy("id_sa")
  .agg(F.max(struct(col("cnt"), col("id_sb"))).alias("max"))
  .select(col("id_sa"), col("max.id_sb")))
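
To make the tie behaviour concrete, here is a small plain-Python model of the join and struct-ordering approaches. It is an illustration only, not Spark code: the toy rows and the function names `max_rows_join` and `max_rows_struct` are made up for this sketch, and Python tuples stand in for Spark structs on the assumption that both compare field by field.

```python
# Toy stand-in for the cnts DataFrame: (id_sa, id_sb, cnt)
cnts = [
    ("a", "x", 3),
    ("a", "y", 3),  # ties with ("a", "x") on cnt
    ("a", "z", 1),
    ("b", "x", 2),
]


def max_rows_join(rows):
    """Mimic the join approach: keep every row whose cnt equals the group max."""
    mx = {}
    for id_sa, _, cnt in rows:
        mx[id_sa] = max(cnt, mx.get(id_sa, cnt))
    return [(id_sa, id_sb) for id_sa, id_sb, cnt in rows if cnt == mx[id_sa]]


def max_rows_struct(rows):
    """Mimic max(struct(cnt, id_sb)): compare cnt first, then id_sb."""
    best = {}
    for id_sa, id_sb, cnt in rows:
        key = (cnt, id_sb)
        if id_sa not in best or key > best[id_sa]:
            best[id_sa] = key
    return [(id_sa, key[1]) for id_sa, key in sorted(best.items())]


join_result = max_rows_join(cnts)      # both tied rows for "a" are kept
struct_result = max_rows_struct(cnts)  # one row per group, tie goes to "y"
```

The window-function version sits between the two: like struct ordering it returns one row per group, but with only `cnt` in the `orderBy` nothing determines which of the tied rows gets `rn == 1`.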