Extract Value from XML Column in PySpark DataFrame
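Spark's DataFrame API has historically not offered a dedicated function for extracting values from an XML string column. Spark SQL does ship XPath functions such as xpath and xpath_string, which can be called through expr, but a user-defined function (UDF) written in Python gives full control over the parsing. This article shows how to implement one in PySpark.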
XML operations with Python
Several Python packages can be used to read XML data; refer to Read and Write XML Files with Python for more details. This article uses xml.etree.ElementTree to extract values.
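As a quick orientation, here is a minimal plain-Python sketch (no Spark involved) of the three ElementTree features used later: ET.fromstring parses a string and returns the root element, .attrib is a dict of that element's attributes, and findall takes a path relative to the element. The XML has the same shape as the sample data below. Note that attribute values always come back as strings.

```python
import xml.etree.ElementTree as ET

xml = """<test a="100" b="200">
    <records>
        <record id="101" />
        <record id="201" />
    </records>
</test>"""

# Parse the string; fromstring returns the root element (<test>)
root = ET.fromstring(xml)

# Attributes of the root element are exposed as a plain dict of strings
print(root.tag)     # test
print(root.attrib)  # {'a': '100', 'b': '200'}

# findall takes a path relative to root and returns a list of elements
ids = [r.get("id") for r in root.findall("records/record")]
print(ids)          # ['101', '201']
```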
Sample DataFrame
Let's first create a sample DataFrame with an XML column (the column type is StringType).
from pyspark.sql import SparkSession

appName = "Python Example - "
master = "local"

# Create Spark session
spark = SparkSession.builder \
    .appName(appName) \
    .master(master) \
    .getOrCreate()

spark.sparkContext.setLogLevel('WARN')

data = [
    {'id': 1, 'data': """<test a="100" b="200">
    <records>
        <record id="101" />
        <record id="201" />
    </records>
</test>
"""},
    {'id': 2, 'data': """<test a="200" b="400">
    <records>
        <record id="202" />
        <record id="402" />
    </records>
</test>
"""}]

df = spark.createDataFrame(data)
print(df.schema)
df.show()
StructType(List(StructField(data,StringType,true),StructField(id,LongType,true)))
+--------------------+---+
| data| id|
+--------------------+---+
|<test a="100" b="...| 1|
|<test a="200" b="...| 2|
+--------------------+---+
Create UDF
Now let's create a Python UDF. We can use the udf decorator directly to mark a Python function as a UDF.
First, import ElementTree:
import xml.etree.ElementTree as ET
Then we can use it to define a UDF:
from pyspark.sql.functions import udf

# UDF to extract value.
# No return type is given, so udf defaults to StringType.
@udf
def extract_ab(xml):
    doc = ET.fromstring(xml)
    return [doc.attrib['a'], doc.attrib['b']]

df = df.withColumn('ab', extract_ab(df['data']))
df.show()
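Because a UDF body is an ordinary Python function, one option is to keep the parsing logic undecorated so it can be checked without a Spark session, and wrap it with udf separately. The sketch below is a hypothetical variant, not the article's UDF: parse_ab is an illustrative name, and it assumes you would rather get null for a null or malformed value than have the Spark task fail (Spark passes null column values to Python UDFs as None, and ET.fromstring raises ET.ParseError on malformed XML).

```python
import xml.etree.ElementTree as ET

def parse_ab(xml):
    """Return the a and b attributes of the root element as strings.

    Hypothetical helper: same logic as extract_ab, but undecorated so it
    runs without Spark, and returning None instead of raising on bad input.
    """
    if xml is None:
        return None                    # null column value
    try:
        doc = ET.fromstring(xml)
    except ET.ParseError:
        return None                    # malformed XML becomes null
    # .get returns None for a missing attribute instead of raising KeyError
    return [doc.get('a'), doc.get('b')]

sample = '<test a="100" b="200"><records><record id="101" /></records></test>'
print(parse_ab(sample))  # ['100', '200']
```

In a Spark job it could then be wrapped with something like udf(parse_ab), keeping the local check and the Spark registration separate.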
The result looks like the following:
+--------------------+---+----------+
| data| id| ab|
+--------------------+---+----------+
|<test a="100" b="...| 1|[100, 200]|
|<test a="200" b="...| 2|[200, 400]|
+--------------------+---+----------+
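The added column ab is StringType (not ArrayType): when no return type is specified, udf defaults to StringType, so the returned list is stored as its string representation.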
Extract as array
To extract the id attribute of every record element, such as the one below, as an array, we need a different UDF definition:
<record id="101" />
Let's create another UDF, this time declaring the return type explicitly:
def extract_rid(xml):
    doc = ET.fromstring(xml)
    # Find all <record> elements under <records>, relative to the root
    records = doc.findall('records/record')
    ids = []
    for r in records:
        ids.append(int(r.attrib["id"]))
    return ids

# Declare the return type explicitly: an array of integers
schema = ArrayType(IntegerType())
udf_extract_rid = udf(extract_rid, schema)

df = df.withColumn('rids', udf_extract_rid(df["data"]))
print(df.schema)
df.show()
The return type is declared as an array of integers. Make sure ArrayType and IntegerType are imported:
from pyspark.sql.types import ArrayType, IntegerType
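The loop in extract_rid can also be written as a list comprehension. The plain-Python sketch below (parse_rids is a hypothetical name) performs the same extraction and can be checked without Spark. Note that the path 'records/record' only matches record elements that are direct children of a records element directly under the root; a path such as './/record' would instead match record elements at any depth.

```python
import xml.etree.ElementTree as ET

def parse_rids(xml):
    """Return the id attribute of every records/record element as ints."""
    doc = ET.fromstring(xml)
    # Path is relative to the root <test> element
    return [int(r.attrib["id"]) for r in doc.findall("records/record")]

sample = """<test a="200" b="400">
    <records>
        <record id="202" />
        <record id="402" />
    </records>
</test>"""

print(parse_rids(sample))  # [202, 402]
```") == []
assert parse_rids('<test><other><record id="1"/></other></test>') == []
</test>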
Running the script prints the following:
StructType(List(StructField(data,StringType,true),StructField(id,LongType,true),StructField(ab,StringType,true),StructField(rids,ArrayType(IntegerType,true),true)))
+--------------------+---+----------+----------+
| data| id| ab| rids|
+--------------------+---+----------+----------+
|<test a="100" b="...| 1|[100, 200]|[101, 201]|
|<test a="200" b="...| 2|[200, 400]|[202, 402]|
+--------------------+---+----------+----------+
Explode array column
We can explode the array column using the built-in explode function:
from pyspark.sql.functions import explode

df.withColumn('rid', explode(df['rids'])).show()
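Conceptually, explode produces one output row per array element and copies the other columns onto each of those rows; a row whose array is null or empty produces no output (explode_outer keeps such rows with a null instead). The plain-Python sketch below mimics that behaviour on a list of dicts. explode_rows is a hypothetical helper for illustration only, not a Spark API.

```python
def explode_rows(rows, array_col, new_col):
    """Mimic Spark's explode on a list of dict rows (illustration only)."""
    out = []
    for row in rows:
        # Rows with a null or empty array yield nothing, like explode
        for value in row.get(array_col) or []:
            new_row = dict(row)       # copy the other columns
            new_row[new_col] = value  # one row per array element
            out.append(new_row)
    return out

rows = [
    {"id": 1, "rids": [101, 201]},
    {"id": 2, "rids": [202, 402]},
    {"id": 3, "rids": []},            # dropped, as explode would do
]

for r in explode_rows(rows, "rids", "rid"):
    print(r["id"], r["rid"])
# 1 101
# 1 201
# 2 202
# 2 402
```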
The output now includes a new column rid, with one row per array element:
+--------------------+---+----------+----------+---+
| data| id| ab| rids|rid|
+--------------------+---+----------+----------+---+
|<test a="100" b="...| 1|[100, 200]|[101, 201]|101|
|<test a="100" b="...| 1|[100, 200]|[101, 201]|201|
|<test a="200" b="...| 2|[200, 400]|[202, 402]|202|
|<test a="200" b="...| 2|[200, 400]|[202, 402]|402|
+--------------------+---+----------+----------+---+
Complete script
The following is the complete script content:
from pyspark.sql.functions import udf, explode
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, IntegerType
import xml.etree.ElementTree as ET

appName = "Python Example - "
master = "local"

# Create Spark session
spark = SparkSession.builder \
    .appName(appName) \
    .master(master) \
    .getOrCreate()

spark.sparkContext.setLogLevel('WARN')

data = [
    {'id': 1, 'data': """<test a="100" b="200">
    <records>
        <record id="101" />
        <record id="201" />
    </records>
</test>
"""},
    {'id': 2, 'data': """<test a="200" b="400">
    <records>
        <record id="202" />
        <record id="402" />
    </records>
</test>
"""}]

df = spark.createDataFrame(data)
print(df.schema)
df.show()

# UDF to extract value
@udf
def extract_ab(xml):
    doc = ET.fromstring(xml)
    return [doc.attrib['a'], doc.attrib['b']]

df = df.withColumn('ab', extract_ab(df['data']))
df.show()

def extract_rid(xml):
    doc = ET.fromstring(xml)
    records = doc.findall('records/record')
    ids = []
    for r in records:
        ids.append(int(r.attrib["id"]))
    return ids

schema = ArrayType(IntegerType())
udf_extract_rid = udf(extract_rid, schema)
df = df.withColumn('rids', udf_extract_rid(df["data"]))
print(df.schema)
df.show()

df.withColumn('rid', explode(df['rids'])).show()