PySpark - Flatten (Explode) Nested StructType Column

2022-07-09 pyspark spark-sql

In Spark, we can create user-defined functions (UDFs) that convert a column to a StructType. This article shows you how to flatten or explode a *StructType* column into multiple columns using Spark SQL.

Create a DataFrame with complex data type

Let's first create a DataFrame using the following script:

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

appName = "PySpark Example - Flatten Struct Type"
master = "local"

# Create Spark session
spark = SparkSession.builder \
    .appName(appName) \
    .master(master) \
    .getOrCreate()

spark.sparkContext.setLogLevel("WARN")

data = ['a|100', 'b|200', 'c|300']
df_raw = spark.createDataFrame(data, StringType())
print(df_raw.schema)
df_raw.show()

# Define UDF to convert raw string to object.

def parse_text(text):
    """
    Parse a 'category|count' string into a (category, count) tuple.
    """
    parts = text.split('|')
    return (parts[0], int(parts[1]))

# Schema
schema = StructType([
    StructField("category", StringType(), False),
    StructField("count", IntegerType(), False)
])

# Define a UDF
my_udf = udf(parse_text, schema)

# Now use UDF to transform Spark DataFrame

df = df_raw.withColumn('cat', my_udf(df_raw['value']))
print(df.schema)
df.show()

The following is the output from running the above script:

StructType([StructField('value', StringType(), True)])
+-----+
|value|
+-----+
|a|100|
|b|200|
|c|300|
+-----+

StructType([StructField('value', StringType(), True), StructField('cat', StructType([StructField('category', StringType(), False), StructField('count', IntegerType(), False)]), True)])
+-----+--------+
|value|     cat|
+-----+--------+
|a|100|{a, 100}|
|b|200|{b, 200}|
|c|300|{c, 300}|
+-----+--------+

As we can tell, the Spark DataFrame is created with the following schema:

StructType([StructField('value', StringType(), True), StructField('cat', StructType([StructField('category', StringType(), False), StructField('count', IntegerType(), False)]), True)])

For column/field cat, the type is StructType.
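
Because the UDF wraps an ordinary Python function, the parsing logic that produces cat can be checked without a Spark session. The following is a minimal sketch: parse_text mirrors the function in the script above, while parse_text_safe is a hypothetical variant added here for illustration (the original script does not guard against malformed input).

```python
def parse_text(text):
    """Split a 'category|count' string into a (str, int) tuple."""
    parts = text.split('|')
    return (parts[0], int(parts[1]))


def parse_text_safe(text):
    """Hypothetical variant: return None instead of raising on bad input."""
    try:
        return parse_text(text)
    except (AttributeError, IndexError, ValueError):
        # AttributeError: text is None; IndexError: no '|' separator;
        # ValueError: the count part is not an integer.
        return None


rows = ['a|100', 'b|200', 'c|300']
parsed = [parse_text(r) for r in rows]
print(parsed)                     # [('a', 100), ('b', 200), ('c', 300)]
print(parse_text_safe('broken'))  # None
```

Each tuple lines up positionally with the fields declared in the schema, which is how the UDF result ends up as a struct with category and count.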

Flatten or explode StructType

Now we can simply add the following code to explode or flatten column cat.

# Flatten
df = df.select("value", 'cat.*')
print(df.schema)
df.show()

The approach is to pass [column name].* to the select function, which expands every field of the struct into its own top-level column.

The output looks like the following:

StructType([StructField('value', StringType(), True), StructField('category', StringType(), True), StructField('count', IntegerType(), True)])
+-----+--------+-----+
|value|category|count|
+-----+--------+-----+
|a|100|       a|  100|
|b|200|       b|  200|
|c|300|       c|  300|
+-----+--------+-----+

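Now we've successfully flattened column cat from complex *StructType* to columns of simple types. Note that in the printed schema the flattened fields are nullable, even though they were declared non-nullable inside the struct.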

If you want to drop the original column, refer to Delete or Remove Columns from PySpark DataFrame.