Extract Value from XML Column in PySpark DataFrame

Kontext, 2022-07-15

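At the time of writing, Spark doesn't provide a built-in DataFrame function that parses an XML string column into structured columns. Spark SQL does include XPath functions such as xpath and xpath_string, which can be called from PySpark through expr() for simple lookups, but a Python user-defined function (UDF) gives more control over parsing. This article shows how to implement one.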

XML operations with Python

Several Python packages can read XML data; refer to Read and Write XML Files with Python for more details. This article uses xml.etree.ElementTree from the Python standard library to extract values.
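To show the ElementTree calls used later in isolation, here is a minimal standalone sketch that parses one of this article's sample XML strings, reads the root element's attributes, and collects the record IDs. It needs only the Python standard library, so it can be tried without Spark.

```python
import xml.etree.ElementTree as ET

xml_text = """<test a="100" b="200">
    <records>
        <record id="101" />
        <record id="201" />
    </records>
</test>"""

# fromstring parses the text and returns the root element (<test>)
root = ET.fromstring(xml_text)

# Attributes of an element are exposed as a plain dict of strings
a = root.attrib["a"]  # '100' (attribute values are always strings)
b = root.get("b")     # '200'; get() returns None if the attribute is missing

# findall accepts a limited XPath-like path relative to the element
record_ids = [r.attrib["id"] for r in root.findall("records/record")]

print(root.tag, a, b, record_ids)
```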

Sample DataFrame

Let's first create a sample DataFrame with an XML column (of type StringType):

from pyspark.sql import SparkSession

appName = "Python Example - "
master = "local"

# Create Spark session
spark = SparkSession.builder \
    .appName(appName) \
    .master(master) \
    .getOrCreate()

spark.sparkContext.setLogLevel('WARN')

data = [
    {'id': 1, 'data': """<test a="100" b="200">
        <records>
            <record id="101" />
            <record id="201" />
        </records>
    </test>
"""},
    {'id': 2, 'data': """<test a="200" b="400">
        <records>
            <record id="202" />
            <record id="402" />
        </records>
    </test>
"""}]

df = spark.createDataFrame(data)
print(df.schema)
df.show()
The above script will create a DataFrame with the following schema:
StructType(List(StructField(data,StringType,true),StructField(id,LongType,true)))
It has two columns, one of which (data) contains an XML string. The output of df.show() looks like this:
+--------------------+---+
|                data| id|
+--------------------+---+
|<test a="100" b="...|  1|
|<test a="200" b="...|  2|
+--------------------+---+

Create UDF

Now let's create a Python UDF. We can use the udf decorator directly to mark a Python function as a UDF.

We first need to import ElementTree and the udf function:

import xml.etree.ElementTree as ET
from pyspark.sql.functions import udf

Then we can define the UDF:

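# UDF to extract the a and b attributes of the root element.
# No return type is given, so Spark uses the default StringType.
@udf
def extract_ab(xml):
    doc = ET.fromstring(xml)
    return [doc.attrib['a'], doc.attrib['b']]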

df = df.withColumn('ab', extract_ab(df['data']))
df.show()

The result looks like the following:

+--------------------+---+----------+
|                data| id|        ab|
+--------------------+---+----------+
|<test a="100" b="...|  1|[100, 200]|
|<test a="200" b="...|  2|[200, 400]|
+--------------------+---+----------+

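The added column ab is of StringType, not ArrayType: because no return type was passed to udf, Spark falls back to the default StringType and stores the returned list as a string.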

Extract as array

To extract the id attribute of every record element, such as the one below, we need to change the UDF definition:

<record id="101" />

Let's create another UDF, this time with an explicit return type:

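def extract_rid(xml):
    # Parse the XML string and collect the id of every records/record element
    doc = ET.fromstring(xml)
    records = doc.findall('records/record')
    ids = []
    for r in records:
        ids.append(int(r.attrib["id"]))
    return ids


# Declare the return type explicitly so the new column is array<int>
schema = ArrayType(IntegerType())
udf_extract_rid = udf(extract_rid, schema)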

df = df.withColumn('rids', udf_extract_rid(df["data"]))
print(df.schema)
df.show()

The return type is declared as an array of integers, so make sure ArrayType and IntegerType are imported:

from pyspark.sql.types import ArrayType, IntegerType

Running the script prints the following:

StructType(List(StructField(data,StringType,true),StructField(id,LongType,true),StructField(ab,StringType,true),StructField(rids,ArrayType(IntegerType,true),true)))
+--------------------+---+----------+----------+
|                data| id|        ab|      rids|
+--------------------+---+----------+----------+
|<test a="100" b="...|  1|[100, 200]|[101, 201]|
|<test a="200" b="...|  2|[200, 400]|[202, 402]|
+--------------------+---+----------+----------+
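Note that findall('records/record') only matches record elements that are direct children of a records element under the root. ElementTree's limited XPath support also accepts './/record', which matches record elements at any depth. The sketch below compares the two paths on a hypothetical nested document (not part of the sample data) and shows extract_rid_safe, a hypothetical variant of extract_rid that searches at any depth, skips records without an id attribute, and returns None for null input.

```python
import xml.etree.ElementTree as ET

# Hypothetical document: one record is nested deeper, one has no id
xml_text = """<test a="100" b="200">
    <records>
        <record id="101" />
        <group><record id="301" /></group>
        <record />
    </records>
</test>"""

root = ET.fromstring(xml_text)

# 'records/record' matches only records that are direct children of <records>
direct = [r.get("id") for r in root.findall("records/record")]

# './/record' matches record elements at any depth below the root
anywhere = [r.get("id") for r in root.findall(".//record")]


def extract_rid_safe(xml):
    """Hypothetical variant of extract_rid: any depth, skips records without id."""
    if xml is None:
        return None
    doc = ET.fromstring(xml)
    return [int(r.get("id")) for r in doc.findall(".//record") if r.get("id") is not None]


print(direct, anywhere, extract_rid_safe(xml_text))
```

Like the original extract_rid, this function could be registered with udf(extract_rid_safe, ArrayType(IntegerType())).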

Explode array column

We can explode the array column with the built-in explode function from pyspark.sql.functions, which produces one row per array element:

from pyspark.sql.functions import explode

df.withColumn('rid', explode(df['rids'])).show()

The output now contains one row per record ID, in the new column rid:

+--------------------+---+----------+----------+---+
|                data| id|        ab|      rids|rid|
+--------------------+---+----------+----------+---+
|<test a="100" b="...|  1|[100, 200]|[101, 201]|101|
|<test a="100" b="...|  1|[100, 200]|[101, 201]|201|
|<test a="200" b="...|  2|[200, 400]|[202, 402]|202|
|<test a="200" b="...|  2|[200, 400]|[202, 402]|402|
+--------------------+---+----------+----------+---+
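One detail worth knowing: explode drops rows whose array is empty or null, while explode_outer (also in pyspark.sql.functions) keeps such rows once, with a null in the new column. The plain-Python model below only illustrates those semantics on dicts standing in for rows; it is not how Spark implements either function, and explode_rows is a hypothetical helper.

```python
def explode_rows(rows, array_col, out_col, outer=False):
    """Model of explode/explode_outer: one output row per array element.

    With outer=False, rows whose array is None or empty produce no output.
    With outer=True, such rows are kept once with out_col set to None.
    """
    result = []
    for row in rows:
        values = row.get(array_col)
        if not values:  # None or empty list
            if outer:
                result.append({**row, out_col: None})
            continue
        for v in values:
            result.append({**row, out_col: v})
    return result


rows = [
    {"id": 1, "rids": [101, 201]},
    {"id": 2, "rids": []},    # e.g. an XML document with no <record> elements
    {"id": 3, "rids": None},  # e.g. a UDF that returns None for a null XML value
]

print(explode_rows(rows, "rids", "rid"))
print(explode_rows(rows, "rids", "rid", outer=True))
```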

Complete script

Here is the complete script:

from pyspark.sql.functions import udf, explode
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, IntegerType
import xml.etree.ElementTree as ET


appName = "Python Example - "
master = "local"

# Create Spark session
spark = SparkSession.builder \
    .appName(appName) \
    .master(master) \
    .getOrCreate()

spark.sparkContext.setLogLevel('WARN')

data = [
    {'id': 1, 'data': """<test a="100" b="200">
        <records>
            <record id="101" />
            <record id="201" />
        </records>
    </test>
"""},
    {'id': 2, 'data': """<test a="200" b="400">
        <records>
            <record id="202" />
            <record id="402" />
        </records>
    </test>
"""}]

df = spark.createDataFrame(data)
print(df.schema)
df.show()

# UDF to extract value
@udf
def extract_ab(xml):
    doc = ET.fromstring(xml)
    return [doc.attrib['a'], doc.attrib['b']]


df = df.withColumn('ab', extract_ab(df['data']))
df.show()


def extract_rid(xml):
    doc = ET.fromstring(xml)
    records = doc.findall('records/record')
    ids = []
    for r in records:
        ids.append(int(r.attrib["id"]))
    return ids


schema = ArrayType(IntegerType())
udf_extract_rid = udf(extract_rid, schema)

df = df.withColumn('rids', udf_extract_rid(df["data"]))
print(df.schema)
df.show()

df.withColumn('rid', explode(df['rids'])).show()