PySpark - Flatten (Explode) Nested StructType Column
In Spark, we can create user-defined functions (UDFs) to convert a column to a StructType. This article shows you how to flatten or explode a StructType column into multiple columns using Spark SQL.
Create a DataFrame with complex data type
Let's first create a DataFrame using the following script:
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

appName = "PySpark Example - Flatten Struct Type"
master = "local"

# Create Spark session
spark = SparkSession.builder \
    .appName(appName) \
    .master(master) \
    .getOrCreate()
spark.sparkContext.setLogLevel("WARN")

data = ['a|100', 'b|200', 'c|300']
df_raw = spark.createDataFrame(data, StringType())
print(df_raw.schema)
df_raw.show()

# Define UDF to convert the raw string to a (category, count) tuple.
def parse_text(text):
    """Function to parse str"""
    parts = text.split('|')
    return (parts[0], int(parts[1]))

# Schema for the struct returned by the UDF
schema = StructType([
    StructField("category", StringType(), False),
    StructField("count", IntegerType(), False)
])

# Define a UDF
my_udf = udf(parse_text, schema)

# Now use the UDF to transform the Spark DataFrame
df = df_raw.withColumn('cat', my_udf(df_raw['value']))
print(df.schema)
df.show()
StructType([StructField('value', StringType(), True)])
+-----+
|value|
+-----+
|a|100|
|b|200|
|c|300|
+-----+

StructType([StructField('value', StringType(), True), StructField('cat', StructType([StructField('category', StringType(), False), StructField('count', IntegerType(), False)]), True)])
+-----+--------+
|value|     cat|
+-----+--------+
|a|100|{a, 100}|
|b|200|{b, 200}|
|c|300|{c, 300}|
+-----+--------+
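As a side note, the same struct column can also be built with Spark's built-in functions instead of a UDF, which usually performs better because it avoids serializing rows to Python. The following is a minimal sketch using the same column and field names as above (this snippet is not part of the original script):

from pyspark.sql import functions as F

# Split 'a|100' on the pipe and assemble a struct; the pipe must be
# escaped because split() takes a regular expression pattern.
df_builtin = df_raw.withColumn('cat', F.struct(
    F.split('value', r'\|').getItem(0).alias('category'),
    F.split('value', r'\|').getItem(1).cast('int').alias('count')
))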
As the output shows, the Spark DataFrame is created with the following schema:
StructType([StructField('value', StringType(), True), StructField('cat', StructType([StructField('category', StringType(), False), StructField('count', IntegerType(), False)]), True)])
For column/field cat, the type is StructType.
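For a more readable view of the nesting, you can also print the schema as a tree with printSchema (the call is not in the original script; its output is shown as comments):

df.printSchema()
# root
#  |-- value: string (nullable = true)
#  |-- cat: struct (nullable = true)
#  |    |-- category: string (nullable = false)
#  |    |-- count: integer (nullable = false)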
Flatten or explode StructType
Now we can simply add the following code to explode or flatten column cat.
# Flatten
df = df.select("value", "cat.*")
print(df.schema)
df.show()
The approach is to use [column name].* in the select function: Spark expands the struct reference into one output column per field.
The output looks like the following:
StructType([StructField('value', StringType(), True), StructField('category', StringType(), True), StructField('count', IntegerType(), True)])
+-----+--------+-----+
|value|category|count|
+-----+--------+-----+
|a|100|       a|  100|
|b|200|       b|  200|
|c|300|       c|  300|
+-----+--------+-----+
Now we've successfully flattened column cat from a complex StructType to columns of simple types.
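If you only need some of the nested fields, you can also reference them individually instead of expanding the whole struct. A hypothetical variation, run against the DataFrame that still contains the cat struct (called df_with_struct here; in the script above, that is df before the flattening step):

from pyspark.sql.functions import col

# 'cat.count' addresses the nested field; alias it to a top-level name
df_counts = df_with_struct.select("value", col("cat.count").alias("count"))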
If you want to drop the original column, refer to Delete or Remove Columns from PySpark DataFrame.
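For example, a minimal sketch of dropping the raw column inline after flattening (drop is a standard DataFrame method):

# Remove the raw 'value' column now that it has been parsed
df = df.drop('value')
df.show()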