In Spark, we can create user-defined functions (UDFs) to convert a column to a StructType. This article shows you how to flatten or explode a StructType column into multiple columns using Spark SQL.
Create a DataFrame with complex data type
Let's first create a DataFrame using the following script:
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
appName = "PySpark Example - Flatten Struct Type"
master = "local"
# Create Spark session
spark = SparkSession.builder \
    .appName(appName) \
    .master(master) \
    .getOrCreate()
spark.sparkContext.setLogLevel("WARN")
data = ['a|100', 'b|200', 'c|300']
df_raw = spark.createDataFrame(data, StringType())
print(df_raw.schema)
df_raw.show()
# Function to parse the raw string into a tuple matching the schema below.
def parse_text(text):
    """
    Parse a raw 'category|count' string into a tuple.
    """
    parts = text.split('|')
    return (parts[0], int(parts[1]))
# Schema
schema = StructType([
    StructField("category", StringType(), False),
    StructField("count", IntegerType(), False)
])
# Define a UDF
my_udf = udf(parse_text, schema)
# Now use UDF to transform Spark DataFrame
df = df_raw.withColumn('cat', my_udf(df_raw['value']))
print(df.schema)
df.show()
The following is the output from running the above script:
StructType([StructField('value', StringType(), True)])
+-----+
|value|
+-----+
|a|100|
|b|200|
|c|300|
+-----+
StructType([StructField('value', StringType(), True), StructField('cat', StructType([StructField('category', StringType(), False), StructField('count', IntegerType(), False)]), True)])
+-----+--------+
|value| cat|
+-----+--------+
|a|100|{a, 100}|
|b|200|{b, 200}|
|c|300|{c, 300}|
+-----+--------+
As we can tell, the Spark DataFrame is created with the following schema:
StructType([StructField('value', StringType(), True), StructField('cat', StructType([StructField('category', StringType(), False), StructField('count', IntegerType(), False)]), True)])
For column/field cat, the type is StructType.
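As a side note, the same struct column can also be built without a UDF by combining Spark's built-in split and struct functions. The following is a minimal sketch (the name df_alt is just for illustration; note that field nullability may differ from the explicit schema above):
from pyspark.sql import functions as F

# Build the same struct with built-in functions instead of a UDF.
# split takes a regex pattern, so the pipe character must be escaped.
df_alt = df_raw.withColumn('cat', F.struct(
    F.split(df_raw['value'], r'\|').getItem(0).alias('category'),
    F.split(df_raw['value'], r'\|').getItem(1).cast('int').alias('count')
))
df_alt.show()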
Flatten or explode StructType
Now we can simply add the following code to explode or flatten column cat.
# Flatten
df = df.select("value", 'cat.*')
print(df.schema)
df.show()
The approach is to use [column name].* in the select function.
The output looks like the following:
StructType([StructField('value', StringType(), True), StructField('category', StringType(), True), StructField('count', IntegerType(), True)])
+-----+--------+-----+
|value|category|count|
+-----+--------+-----+
|a|100| a| 100|
|b|200| b| 200|
|c|300| c| 300|
+-----+--------+-----+
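If you prefer not to hard-code the field list, the struct's fields can also be expanded dynamically from the DataFrame schema. Below is a minimal sketch, assuming it runs against the DataFrame that still contains the cat struct column (struct_fields and df_flat are illustrative names):
# Expand all fields of the 'cat' struct by reading its schema.
struct_fields = [f"cat.{field.name}" for field in df.schema["cat"].dataType.fields]
df_flat = df.select("value", *struct_fields)
df_flat.show()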
Now we've successfully flattened column cat from a complex StructType into columns of simple types.
If you want to drop the original column, refer to Delete or Remove Columns from PySpark DataFrame.
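For instance, a quick one-liner with the built-in drop function removes the raw column:
# Drop the original raw column after flattening.
df = df.drop('value')
df.show()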