PySpark DataFrame - Add Row Number via row_number() Function

Kontext, 2022-08-18

Code description

In Spark SQL, row_number generates sequential numbers, starting from 1, for the records in each partition of the specified window. Examples can be found on this page: Spark SQL - ROW_NUMBER Window Functions.

This code snippet implements the same approach directly with the PySpark DataFrame API instead of Spark SQL. It creates a window that partitions the data by the ACCT column and sorts the records in each partition by the TXN_DT column in descending order. The window frame runs from unbounded preceding to the current row. Two of the ACCT 101 records share the same TXN_DT, so their relative order (row numbers 2 and 3) is not guaranteed and may differ between runs.

Output:

+----+------+-------------------+
|ACCT|   AMT|             TXN_DT|
+----+------+-------------------+
| 101| 10.01|2021-01-01 00:00:00|
| 101|102.01|2021-01-01 00:00:00|
| 102|  93.0|2021-01-01 00:00:00|
| 103| 913.1|2021-01-02 00:00:00|
| 101|900.56|2021-01-03 00:00:00|
+----+------+-------------------+

+----+------+-------------------+------+
|ACCT|   AMT|             TXN_DT|rownum|
+----+------+-------------------+------+
| 101|900.56|2021-01-03 00:00:00|     1|
| 101| 10.01|2021-01-01 00:00:00|     2|
| 101|102.01|2021-01-01 00:00:00|     3|
| 102|  93.0|2021-01-01 00:00:00|     1|
| 103| 913.1|2021-01-02 00:00:00|     1|
+----+------+-------------------+------+

Code snippet

from pyspark.sql import SparkSession, Window
from datetime import datetime
from pyspark.sql.functions import row_number, desc

app_name = "PySpark row_number Window Function"
master = "local"

spark = SparkSession.builder \
    .appName(app_name) \
    .master(master) \
    .getOrCreate()

spark.sparkContext.setLogLevel("WARN")

data = [
    [101, 10.01, datetime.strptime('2021-01-01', '%Y-%m-%d')],
    [101, 102.01, datetime.strptime('2021-01-01', '%Y-%m-%d')],
    [102, 93.0, datetime.strptime('2021-01-01', '%Y-%m-%d')],
    [103, 913.1, datetime.strptime('2021-01-02', '%Y-%m-%d')],
    [101, 900.56, datetime.strptime('2021-01-03', '%Y-%m-%d')]
]

df = spark.createDataFrame(data, ['ACCT', 'AMT', 'TXN_DT'])
df.show()

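# Partition by account, order each partition by transaction date (newest first),
# and use a frame from the start of the partition up to the current row.
window = Window.partitionBy('ACCT').orderBy(desc("TXN_DT")).rowsBetween(
    Window.unboundedPreceding, Window.currentRow)
# Number the rows 1..n within each partition.
df = df.withColumn('rownum', row_number().over(window))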

df.show()
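
To make the window logic above easier to follow, here is a minimal plain-Python sketch (no Spark involved) of what row_number computes over a partitioned, ordered window. The helper name add_row_number and the dict-based row layout are assumptions made for illustration only; they are not part of PySpark. Python's sort is stable, so tied rows keep their input order in this sketch, whereas Spark gives no such guarantee.

```python
from collections import defaultdict
from datetime import datetime


def add_row_number(rows, partition_key, order_key, descending=True):
    """Hypothetical helper: number rows 1..n within each partition."""
    # Group the rows by the partition column (ACCT in the snippet above).
    partitions = defaultdict(list)
    for row in rows:
        partitions[row[partition_key]].append(row)

    numbered = []
    for key in sorted(partitions):
        # Order each partition by the ordering column (TXN_DT, descending).
        ordered = sorted(
            partitions[key], key=lambda r: r[order_key], reverse=descending
        )
        # Assign 1, 2, 3, ... in that order, restarting for every partition.
        for n, row in enumerate(ordered, start=1):
            numbered.append({**row, "rownum": n})
    return numbered


rows = [
    {"ACCT": 101, "AMT": 10.01, "TXN_DT": datetime(2021, 1, 1)},
    {"ACCT": 101, "AMT": 102.01, "TXN_DT": datetime(2021, 1, 1)},
    {"ACCT": 102, "AMT": 93.0, "TXN_DT": datetime(2021, 1, 1)},
    {"ACCT": 103, "AMT": 913.1, "TXN_DT": datetime(2021, 1, 2)},
    {"ACCT": 101, "AMT": 900.56, "TXN_DT": datetime(2021, 1, 3)},
]

result = add_row_number(rows, "ACCT", "TXN_DT")
for r in result:
    print(r["ACCT"], r["AMT"], r["TXN_DT"].date(), r["rownum"])
```

If a deterministic numbering is needed in Spark itself, one option is to add a further column to the orderBy clause so that no two rows in a partition compare as equal.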