Lead and Lag using Spark Scala

Sometimes while processing data we come across situations where we need to find the difference between an item's price today and its price yesterday, or we meet similar interview questions. For these cases we can use the Lead and Lag functions.

Lead() - This function can be used to get the values of the rows that succeed (come after) the current row.

Lag() - This function can be used to get the values of the rows that precede (come before) the current row.

 

These functions are termed non-aggregation functions because they don't perform any aggregation; they only form a new column whose values are shifted up or down relative to the current row.

 

Let's see how we can use these with a practical example.

 

Below is the data from a dataframe called df containing the prices on different dates of a particular product:

+-----+-------+------+----------+
|IT_ID|IT_Name| Price| PriceDate|
+-----+-------+------+----------+
|    1| KitKat|1000.0|2021-01-01|
|    1| KitKat|2000.0|2021-01-02|
|    1| KitKat|1000.0|2021-01-03|
|    1| KitKat|2000.0|2021-01-04|
|    1| KitKat|3000.0|2021-01-05|
|    1| KitKat|1000.0|2021-01-06|
+-----+-------+------+----------+
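
The post does not show how df was created, so here is a minimal sketch of one way to build it. The local SparkSession settings, the app name, and the conversion of PriceDate to a date type are assumptions made for illustration; in spark-shell the existing session is simply reused by getOrCreate().

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, to_date}

// Assumption: a local session; the app name is arbitrary
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("LeadLagExample")
  .getOrCreate()
import spark.implicits._

// Same six rows as in the table above
val df = Seq(
  (1, "KitKat", 1000.0, "2021-01-01"),
  (1, "KitKat", 2000.0, "2021-01-02"),
  (1, "KitKat", 1000.0, "2021-01-03"),
  (1, "KitKat", 2000.0, "2021-01-04"),
  (1, "KitKat", 3000.0, "2021-01-05"),
  (1, "KitKat", 1000.0, "2021-01-06")
).toDF("IT_ID", "IT_Name", "Price", "PriceDate")
  // Assumption: store the date as a DateType column rather than a string
  .withColumn("PriceDate", to_date(col("PriceDate")))

df.show()
```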

 

Let's apply the lead and lag functions to the Price column, with the rows ordered by date:

 

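import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, lag, lead}

// Order the rows by date; df holds a single product, so no partitioning is used
val window = Window.orderBy("PriceDate")
// Price from the previous row (one day earlier)
val lagCol = lag(col("Price"), 1).over(window)
// Price from the next row (one day later)
val leadCol = lead(col("Price"), 1).over(window)
df.withColumn("LagCol", lagCol)
   .withColumn("LeadCol", leadCol)
   .show()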


+-----+-------+------+----------+------+-------+
|IT_ID|IT_Name| Price| PriceDate|LagCol|LeadCol|
+-----+-------+------+----------+------+-------+
|    1| KitKat|1000.0|2021-01-01|  null| 2000.0|
|    1| KitKat|2000.0|2021-01-02|1000.0| 1000.0|
|    1| KitKat|1000.0|2021-01-03|2000.0| 2000.0|
|    1| KitKat|2000.0|2021-01-04|1000.0| 3000.0|
|    1| KitKat|3000.0|2021-01-05|2000.0| 1000.0|
|    1| KitKat|1000.0|2021-01-06|3000.0|   null|
+-----+-------+------+----------+------+-------+

In the above output, the first value in the LagCol column is null because there is no row before the first row, and the last value in the LeadCol column is null because there is no row after the final record.
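
If those nulls are not wanted, lag and lead also accept a third argument that is used as a default when there is no row at the requested offset. The sketch below illustrates this on a three-row subset of the data; the names prices, byDate and withDefaults, the local session settings, and the choice of 0.0 as the default are assumptions for illustration only.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, lag, lead}

// Assumption: a local session; in spark-shell the existing one is reused
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("LeadLagDefaults")
  .getOrCreate()
import spark.implicits._

// A three-row subset of the price data, enough to show both edges
val prices = Seq(
  ("2021-01-01", 1000.0),
  ("2021-01-02", 2000.0),
  ("2021-01-03", 1000.0)
).toDF("PriceDate", "Price")

val byDate = Window.orderBy("PriceDate")

// The third argument replaces the null at the first (lag) and last (lead) row
val withDefaults = prices
  .withColumn("LagCol", lag(col("Price"), 1, 0.0).over(byDate))
  .withColumn("LeadCol", lead(col("Price"), 1, 0.0).over(byDate))

withDefaults.show()
```

A default such as 0.0 is convenient for display, but it would distort a later price difference, so keeping the null is often the safer choice when the column feeds further calculations.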

 

Now suppose we want to find the difference between the price on each day and the price on the previous day. We can write the code as shown below:

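import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, lag, when}

val window = Window.orderBy("PriceDate")
// Previous day's price
val laggingCol = lag(col("Price"), 1).over(window)
// Current price minus the previous day's price (null for the first row)
val priceDifference = col("Price") - col("LastPrice")
// Label the trend; a null difference (first row) is reported as SAME
val change = when(col("PriceDiff").isNull || col("PriceDiff") === 0, "SAME")
    .when(col("PriceDiff") > 0, "UP")
    .otherwise("DOWN")
df.withColumn("LastPrice", laggingCol)
    .withColumn("PriceDiff", priceDifference)
    .withColumn("Change", change).show()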

+-----+-------+------+----------+---------+---------+------+
|IT_ID|IT_Name| Price| PriceDate|LastPrice|PriceDiff|Change|
+-----+-------+------+----------+---------+---------+------+
|    1| KitKat|1000.0|2021-01-01|     null|     null|  SAME|
|    1| KitKat|2000.0|2021-01-02|   1000.0|   1000.0|    UP|
|    1| KitKat|1000.0|2021-01-03|   2000.0|  -1000.0|  DOWN|
|    1| KitKat|2000.0|2021-01-04|   1000.0|   1000.0|    UP|
|    1| KitKat|3000.0|2021-01-05|   2000.0|   1000.0|    UP|
|    1| KitKat|1000.0|2021-01-06|   3000.0|  -2000.0|  DOWN|
+-----+-------+------+----------+---------+---------+------+

 

If we observe the above output, we have calculated the difference between the current day's price and the previous day's price, and we have shown the trend as UP or DOWN (the first row has no previous day, so it is marked SAME). This is how we can use the non-aggregation functions lead and lag.
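
The examples above use a single product, so ordering the window by date is enough. With several products in one dataframe, the previous row by date could belong to a different item, and Spark also warns that a window with no partition moves all data to a single partition. A minimal sketch of one way to handle this is to partition the window by IT_ID. The second item ("Munch") and its prices are hypothetical data invented for illustration, as are the names multi, perItem and result.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, lag}

// Assumption: a local session; in spark-shell the existing one is reused
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("LeadLagPerItem")
  .getOrCreate()
import spark.implicits._

val multi = Seq(
  (1, "KitKat", 1000.0, "2021-01-01"),
  (1, "KitKat", 2000.0, "2021-01-02"),
  (2, "Munch", 500.0, "2021-01-01"), // hypothetical second item
  (2, "Munch", 400.0, "2021-01-02")
).toDF("IT_ID", "IT_Name", "Price", "PriceDate")

// lag now looks back only within rows that share the same IT_ID
val perItem = Window.partitionBy("IT_ID").orderBy("PriceDate")

val result = multi
  .withColumn("LastPrice", lag(col("Price"), 1).over(perItem))
  .withColumn("PriceDiff", col("Price") - col("LastPrice"))

result.orderBy("IT_ID", "PriceDate").show()
```

With this window, the first row of each item gets a null LastPrice, because the look-back never crosses from one item into another.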
