Spark doesn't provide a built-in function to extract values from an XML string column in a DataFrame. However, we can use a user-defined function (UDF) to extract them in PySpark. This article shows you how to implement that.
XML operations with Python
There are several Python packages that can be used to read XML data. Refer to Read and Write XML Files with Python for more details. This article uses xml.etree.ElementTree to extract values.
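As a quick refresher, the snippet below shows the two ElementTree calls this article relies on: fromstring to parse an XML string and attrib/findall to read attributes. The sample string here is made up, shaped like the records used later in the article:

```python
import xml.etree.ElementTree as ET

# Hypothetical sample shaped like the records used later in this article
xml = '<test a="100" b="200"><records><record id="101"/><record id="201"/></records></test>'

root = ET.fromstring(xml)  # parse the string into an Element
print(root.attrib['a'])    # read a root attribute -> '100'
# findall accepts a simple path expression relative to the current element
print([r.attrib['id'] for r in root.findall('records/record')])  # -> ['101', '201']
```

Note that attribute values always come back as strings, which is why the UDF later in this article converts them with int().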
Sample DataFrame
Let's first create a sample DataFrame with an XML column (of type StringType).
from pyspark.sql import SparkSession

appName = "Python Example - "
master = "local"

# Create Spark session
spark = SparkSession.builder \
    .appName(appName) \
    .master(master) \
    .getOrCreate()
spark.sparkContext.setLogLevel('WARN')
data = [
    {'id': 1, 'data': """<test a="100" b="200">
        <records>
            <record id="101" />
            <record id="201" />
        </records>
        </test>
        """},
    {'id': 2, 'data': """<test a="200" b="400">
        <records>
            <record id="202" />
            <record id="402" />
        </records>
        </test>
        """}]
df = spark.createDataFrame(data)
print(df.schema)
df.show()
The above script will create a DataFrame with the following schema:
StructType(List(StructField(data,StringType,true),StructField(id,LongType,true)))
It has two columns, one of which is an XML string:
+--------------------+---+
| data| id|
+--------------------+---+
|<test a="100" b="...| 1|
|<test a="200" b="...| 2|
+--------------------+---+
Create UDF
Now let's create a Python UDF. We can directly use the udf decorator (imported from pyspark.sql.functions) to mark a Python function as a UDF. We first need to import ElementTree:
import xml.etree.ElementTree as ET
Then we can use it to define a UDF:
from pyspark.sql.functions import udf

# UDF to extract values; no return type is specified,
# so it defaults to StringType
@udf
def extract_ab(xml):
    doc = ET.fromstring(xml)
    return [doc.attrib['a'], doc.attrib['b']]

df = df.withColumn('ab', extract_ab(df['data']))
df.show()
The result looks like the following:
+--------------------+---+----------+
| data| id| ab|
+--------------------+---+----------+
|<test a="100" b="...| 1|[100, 200]|
|<test a="200" b="...| 2|[200, 400]|
+--------------------+---+----------+
The added column ab is of StringType, not ArrayType, because udf defaults to a StringType return type when none is specified; the Python list is converted to its string representation.
Extract as array
If we want to extract all the id attributes of the record elements (e.g. <record id="101" />) as an array, we need to change the UDF definition. Let's create another UDF, this time registering it explicitly with the udf function and a return type:
def extract_rid(xml):
    doc = ET.fromstring(xml)
    records = doc.findall('records/record')
    ids = []
    for r in records:
        ids.append(int(r.attrib["id"]))
    return ids

schema = ArrayType(IntegerType())
udf_extract_rid = udf(extract_rid, schema)
df = df.withColumn('rids', udf_extract_rid(df["data"]))
print(df.schema)
df.show()
The return type is defined as an array of integers. Make sure ArrayType and IntegerType are imported:
from pyspark.sql.types import ArrayType, IntegerType
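Because extract_rid is a plain Python function before it is wrapped by udf, it can be sanity-checked locally without a Spark session. A minimal sketch, using a hypothetical one-line sample string matching the DataFrame rows above:

```python
import xml.etree.ElementTree as ET

def extract_rid(xml):
    # Same logic as the UDF body: collect the id attribute of each record
    doc = ET.fromstring(xml)
    return [int(r.attrib["id"]) for r in doc.findall('records/record')]

# Hypothetical sample matching the first DataFrame row
sample = '<test a="100" b="200"><records><record id="101"/><record id="201"/></records></test>'
print(extract_rid(sample))  # -> [101, 201]
```

Testing the function this way is much faster than launching a Spark job for each change.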
Run the script and it prints out the following:
StructType(List(StructField(data,StringType,true),StructField(id,LongType,true),StructField(ab,StringType,true),StructField(rids,ArrayType(IntegerType,true),true)))
+--------------------+---+----------+----------+
|                data| id|        ab|      rids|
+--------------------+---+----------+----------+
|<test a="100" b="...|  1|[100, 200]|[101, 201]|
|<test a="200" b="...|  2|[200, 400]|[202, 402]|
+--------------------+---+----------+----------+
Explode array column
We can explode the array column using the built-in explode function (imported from pyspark.sql.functions):
df.withColumn('rid', explode(df['rids'])).show()
Now the script prints out the following content with a new column rid added:
+--------------------+---+----------+----------+---+
| data| id| ab| rids|rid|
+--------------------+---+----------+----------+---+
|<test a="100" b="...| 1|[100, 200]|[101, 201]|101|
|<test a="100" b="...| 1|[100, 200]|[101, 201]|201|
|<test a="200" b="...| 2|[200, 400]|[202, 402]|202|
|<test a="200" b="...| 2|[200, 400]|[202, 402]|402|
+--------------------+---+----------+----------+---+
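If it helps to see the semantics outside Spark, explode behaves like the following plain-Python analogy (the row data here is made up to mirror the table above): each input row is repeated once per element of its array column. Note that explode drops rows whose array is empty or null; explode_outer keeps them with a null value instead.

```python
# Plain-Python analogy of explode: one output row per array element
rows = [
    {'id': 1, 'rids': [101, 201]},
    {'id': 2, 'rids': [202, 402]},
]
exploded = [{'id': r['id'], 'rid': v} for r in rows for v in r['rids']]
print(exploded)
# -> [{'id': 1, 'rid': 101}, {'id': 1, 'rid': 201},
#     {'id': 2, 'rid': 202}, {'id': 2, 'rid': 402}]
```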
Complete script
The following is the complete script content:
from pyspark.sql.functions import udf, explode
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, IntegerType
import xml.etree.ElementTree as ET

appName = "Python Example - "
master = "local"

# Create Spark session
spark = SparkSession.builder \
    .appName(appName) \
    .master(master) \
    .getOrCreate()
spark.sparkContext.setLogLevel('WARN')

data = [
    {'id': 1, 'data': """<test a="100" b="200">
        <records>
            <record id="101" />
            <record id="201" />
        </records>
        </test>
        """},
    {'id': 2, 'data': """<test a="200" b="400">
        <records>
            <record id="202" />
            <record id="402" />
        </records>
        </test>
        """}]
df = spark.createDataFrame(data)
print(df.schema)
df.show()

# UDF to extract values; defaults to a StringType return type
@udf
def extract_ab(xml):
    doc = ET.fromstring(xml)
    return [doc.attrib['a'], doc.attrib['b']]

df = df.withColumn('ab', extract_ab(df['data']))
df.show()

# UDF with an explicit ArrayType(IntegerType()) return type
def extract_rid(xml):
    doc = ET.fromstring(xml)
    records = doc.findall('records/record')
    ids = []
    for r in records:
        ids.append(int(r.attrib["id"]))
    return ids

schema = ArrayType(IntegerType())
udf_extract_rid = udf(extract_rid, schema)
df = df.withColumn('rids', udf_extract_rid(df["data"]))
print(df.schema)
df.show()

# Explode the array column into one row per element
df.withColumn('rid', explode(df['rids'])).show()