Convert PySpark Row List to Pandas Data Frame
In Spark, it is easy to convert a Spark data frame to a Pandas data frame with one line of code:
df_pd = df.toPandas()
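Depending on your Spark version, you may be able to speed up toPandas by enabling the Arrow-based conversion first. The following is a minimal sketch, assuming Spark 3.x (Spark 2.3/2.4 use the spark.sql.execution.arrow.enabled key instead) and the spark session and df data frame defined in the example further below:
# Optional: Arrow-based columnar transfer can make toPandas faster on larger data frames
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
df_pd = df.toPandas()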
On this page, I am going to show you how to convert a list of PySpark Row objects to a Pandas data frame.
Prepare the data frame
The following code snippet creates a data frame with the following schema:
root
|-- Category: string (nullable = false)
|-- ItemID: integer (nullable = false)
|-- Amount: decimal(10,2) (nullable = true)
from pyspark.sql import SparkSession
from pyspark.sql.functions import collect_list, struct
from pyspark.sql.types import ArrayType, StructField, StructType, StringType, IntegerType, DecimalType
from decimal import Decimal
import pandas as pd

appName = "Python Example - PySpark Row List to Pandas Data Frame"
master = "local"

# Create Spark session
spark = SparkSession.builder \
    .appName(appName) \
    .master(master) \
    .getOrCreate()

# List
data = [('Category A', 1, Decimal(12.40)),
        ('Category B', 2, Decimal(30.10)),
        ('Category C', 3, Decimal(100.01)),
        ('Category A', 4, Decimal(110.01)),
        ('Category B', 5, Decimal(70.85))]

# Create a schema for the data frame
schema = StructType([
    StructField('Category', StringType(), False),
    StructField('ItemID', IntegerType(), False),
    StructField('Amount', DecimalType(scale=2), True)
])

# Convert list to RDD
rdd = spark.sparkContext.parallelize(data)

# Create data frame
df = spark.createDataFrame(rdd, schema)
df.printSchema()
df.show()

df_pd = df.toPandas()
df_pd.info()
The above code converts the list to a Spark data frame first and then converts it to a Pandas data frame.
The information of the Pandas data frame looks like the following:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5 entries, 0 to 4
Data columns (total 3 columns):
Category 5 non-null object
ItemID 5 non-null int32
Amount 5 non-null object
dtypes: int32(1), object(2)
memory usage: 172.0+ bytes
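Note that the Amount column shows up as object dtype because each value is a Python Decimal object. If exact decimal precision is not required downstream, you can cast it to a float column in Pandas; a small sketch using the df_pd data frame from above:
# Assumption: losing exact decimal precision is acceptable here
df_pd["Amount"] = df_pd["Amount"].astype(float)
df_pd.info()   # Amount now shows as float64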
Aggregate the data frame
It’s very common to do aggregations in Spark. For example, the following code snippet groups the above Spark data frame by the Category attribute.
# Aggregate but still keep all the raw attributes
df_agg = df.groupby("Category").agg(collect_list(struct("*")).alias('Items'))
df_agg.printSchema()
The schema of the new Spark data frame has two attributes: Category and Items.
root
|-- Category: string (nullable = false)
|-- Items: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- Category: string (nullable = false)
| | |-- ItemID: integer (nullable = false)
| | |-- Amount: decimal(10,2) (nullable = true)
The Items attribute is an array or list of pyspark.sql.Row objects.
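To verify this, you can pull one aggregated row back to the driver and inspect it; a quick sketch using the df_agg data frame created above:
first_row = df_agg.first()          # one aggregated row on the driver
print(type(first_row["Items"]))     # a Python list
print(first_row["Items"][0])        # a pyspark.sql.Row with Category, ItemID and Amount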
Convert pyspark.sql.Row list to Pandas data frame
Now we can convert the Items attribute using the foreach function.
def to_pandas(row):
    print('Create a pandas data frame for category: ' + row["Category"])
    items = [item.asDict() for item in row["Items"]]
    df_pd_items = pd.DataFrame(items)
    print(df_pd_items)

# Convert Items for each Category to a pandas data frame
df_agg.foreach(to_pandas)
In the above code snippet, the Row list is converted to a list of dictionaries first, and then that list is converted to a Pandas data frame using the pd.DataFrame function. Because each list element is a dictionary with keys, we don’t need to specify the columns argument for pd.DataFrame. Keep in mind that foreach runs to_pandas on the executors and returns nothing to the driver, so each per-category Pandas data frame only exists inside the function; in local mode the printed output simply appears in the console.
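If you need the per-category Pandas data frames back on the driver rather than only processed inside foreach, here is a minimal sketch, assuming the aggregated result is small enough to collect into driver memory:
# Collect the aggregated rows and build one Pandas data frame per category
pandas_frames = {}
for row in df_agg.collect():
    items = [item.asDict() for item in row["Items"]]
    pandas_frames[row["Category"]] = pd.DataFrame(items)

print(pandas_frames['Category A'])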