Performing Data Wrangling with PySpark¶
I recently completed PySpark Essentials for Data Scientists (Big Data + Python) by Layla AI on udemy.com (course link here), but I never had time to showcase my newly acquired PySpark skills. Today, I finally got to work, and herein is a sneak peek into data wrangling with PySpark.
Our objective is to highlight some of the primary data manipulation techniques in PySpark. As usual, we're using our favorite dataset: the South Sudan 2008 Census Data.
We begin by importing pyspark and SparkSession from the pyspark.sql module and then initializing the SparkSession. This article assumes that you have already installed PySpark with the necessary dependencies; otherwise, please see the installation instructions here. We recommend installing PySpark through Anaconda, especially if you are new to Python.
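If you want a quick sanity check before diving in, the following minimal snippet (assuming only that pyspark is importable in your environment) confirms the installation and prints the active version:
# Confirm that PySpark is installed and check which version is active
import pyspark
print(pyspark.__version__)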
Importing the Modules¶
# Import pyspark and SparkSession
import pyspark
from pyspark.sql import SparkSession
# Initialize PySpark session
spark = SparkSession.builder.appName("data wrangling").getOrCreate()
# Check the number of cores/executors available to this session
cores = spark._jsc.sc().getExecutorMemoryStatus().keySet().size()
print("You are working with", cores, "core(s)")
# Display the spark session information
spark
Importing the Data¶
Here, we import the dataset by creating the path within the current directory. Next, we combine the path with the file name, tell PySpark to infer the column data types, and indicate that our dataset contains column names (a header).
# Import the dataset
path = 'Datasets/'
ss_2008_data_raw = spark.read.csv(path + 'ss_2008_census_data.csv', inferSchema = True, header = True)
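As an aside, inferSchema=True costs an extra pass over the file. A minimal sketch of the alternative is to supply an explicit schema up front; the StructField entries below are illustrative and assume, for simplicity, that the file contains only these four columns:
# Illustrative explicit schema; field names must match the CSV header
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

census_schema = StructType([
    StructField('Region Name', StringType(), True),
    StructField('Variable Name', StringType(), True),
    StructField('Age Name', StringType(), True),
    StructField('2008', IntegerType(), True),
])
ss_2008_data_explicit = spark.read.csv(path + 'ss_2008_census_data.csv',
                                       schema=census_schema, header=True)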
We display the first 10 rows of our dataset with the show() method, indicating that we want PySpark to show column contents without truncating them.
# Inspect the first 10 rows of the original dataset
ss_2008_data_raw.show(10, truncate = False)
Next, we inspect the schema (the column data types), and we see that all the columns are of the string data type; however, the last column, '2008', which holds the population counts, should be an integer. We'll change it later.
# Inspect column data types (printSchema() prints directly; no print() needed)
ss_2008_data_raw.printSchema()
Selecting the Columns¶
Below, we select the columns of interest with the select() method. Next, we inspect the updated dataset with the show() and toPandas() methods. The toPandas() method renders the dataset as a pandas DataFrame; however, it collects the data onto the driver, which is computationally expensive, so it should be avoided unless necessary (which is why we call limit(5) first).
# Select the columns to focus on
ss_2008_data_cleaned = ss_2008_data_raw.select('Region Name', 'Variable Name', 'Age Name', '2008')
# Inspect the updated dataset with the show() function
ss_2008_data_cleaned.show(5, truncate = False)
# Inspect the first 5 rows with the toPandas() method
ss_2008_data_cleaned.limit(5).toPandas()
Changing Column Data Types¶
To convert data from one data type to another, we need to import the data types from pyspark.sql.types and the SQL functions from pyspark.sql.functions. Here, we are using the wildcard, *, to import everything from both modules.
# Import SQL data types and functions
from pyspark.sql.types import *
from pyspark.sql.functions import *
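One caveat worth flagging: the wildcard import from pyspark.sql.functions shadows Python built-ins such as sum(), min(), and max() with their Spark column equivalents. We keep the wildcard here for brevity, but a more defensive sketch would import only what is needed, aliasing any clashes:
# Targeted imports avoid shadowing Python built-ins
from pyspark.sql.functions import col, split
from pyspark.sql.functions import sum as spark_sum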
# Change the population column data type to the integer data type
# we use the backslash to break a long line of code
ss_2008_data_cleaned = ss_2008_data_cleaned.withColumn('2008', \
ss_2008_data_cleaned['2008'].cast(IntegerType()))
# View the schema (printSchema() prints directly; no print() needed)
ss_2008_data_cleaned.printSchema()
# Inspect the first 5 rows with the toPandas() method
ss_2008_data_cleaned.limit(5).toPandas()
Renaming Columns with the withColumnRenamed()¶
Here we rename our columns of interest with the withColumnRenamed() method. Please note that you need to chain several withColumnRenamed() calls together to rename multiple columns. To rename a column, call withColumnRenamed() with the old name followed by the new name. For example, in the chunk below, Region Name is the old column name, and State is its new name.
# Renaming columns
ss_2008_data_cleaned = ss_2008_data_cleaned.withColumnRenamed('Region Name', 'State')\
.withColumnRenamed('2008','Population').withColumnRenamed('Variable Name', 'Gender')\
.withColumnRenamed('Age Name', 'Age Category')
# View the first 4 rows with the toPandas()
ss_2008_data_cleaned.limit(4).toPandas()
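If you are on PySpark 3.4 or newer, the chain above can be collapsed into a single withColumnsRenamed() call that takes a mapping of old names to new names. Here is a sketch, applied to a freshly selected copy of the raw data for illustration:
# One-call rename (requires PySpark 3.4+); keys are old names, values are new
ss_2008_renamed_alt = ss_2008_data_raw\
    .select('Region Name', 'Variable Name', 'Age Name', '2008')\
    .withColumnsRenamed({'Region Name': 'State',
                         'Variable Name': 'Gender',
                         'Age Name': 'Age Category',
                         '2008': 'Population'})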
Checking Row and Column Counts¶
We display the number of rows and the number of columns with the print() function.
# Inspect the number of rows and columns
print('Your dataset has', ss_2008_data_cleaned.count(), 'rows and', len(ss_2008_data_cleaned.columns), 'columns.')
Removing NAs¶
Here we remove NAs (missing values) with the na.drop() method. However, it is imperative to be careful when dropping NAs, as this may negatively impact your data.
# Drop the rows with nas or missing values
ss_2008_census_df = ss_2008_data_cleaned.na.drop()
ss_2008_census_df.limit(5).toPandas()
# Verify the dimensions of the dataset
print('Your dataset has', ss_2008_census_df.count(), 'rows and', len(ss_2008_census_df.columns), 'columns.')
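By default, na.drop() removes a row if any column is null. When only certain columns matter, the how and subset parameters make the drop more surgical; a minimal sketch, assuming we only care about missing population counts:
# Drop a row only when its Population value is missing
ss_2008_pop_complete = ss_2008_data_cleaned.na.drop(how='any', subset=['Population'])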
Splitting Column Values with the split()¶
Earlier, we saw that the Gender column is a three-part column comprising 'Population', 'Total/Female/Male', and '(Number)'. However, we are only interested in the male and female rows. So, we clean this column by splitting it with the split() method and then retaining the middle portion, i.e., the item at index 1.
# Transform the gender column with the split() method
ss_2008_census_df = ss_2008_census_df.withColumn('gender', split(ss_2008_census_df['Gender'], ' ').getItem(1))
# Inspect the first 10 rows
ss_2008_census_df.show(10, truncate=False)
# Filter the gender column to keep only the rows with 'female' and 'male' in them
ss_2008_census_df = ss_2008_census_df.filter(ss_2008_census_df['gender'] != "Total")
# Inspect the first 10 rows
ss_2008_census_df.show(10, truncate=False)
# Re-inspect the number of rows and the number of columns
print('Your dataset has', ss_2008_census_df.count(), 'rows and', len(ss_2008_census_df.columns),'columns.')
Selecting Row Values with the filter()¶
Below, we remove the rows with 'Total' in the Age Category column.
# Modify the Age Category column
ss_2008_census_df_1 = ss_2008_census_df.filter(ss_2008_census_df['Age Category'] != "Total")
ss_2008_census_df_1.show(10, truncate=False)
# Re-inspect the number of rows and columns
print('Your dataset has', ss_2008_census_df_1.count(), 'rows and', len(ss_2008_census_df_1.columns), 'columns.')
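Incidentally, when several values need to be kept or excluded at once, the isin() column method pairs well with filter(). A small sketch, equivalent to the filter above:
# Exclude every Age Category value in the list (here just 'Total')
excluded = ['Total']
ss_2008_census_df_1_alt = ss_2008_census_df.filter(~col('Age Category').isin(excluded))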
Replacing Column Values¶
In the below chunk, we regroup the Age Category values to reduce the number of classes. While there are multiple methods for transforming column values, we opted to use the PySpark
replace() method.
# Combine Age Category values
ss_2008_census_final_df = ss_2008_census_df_1.replace(['0 to 4',
'5 to 9',
'10 to 14',
'15 to 19',
'20 to 24',
'25 to 29',
'30 to 34',
'35 to 39',
'40 to 44',
'45 to 49',
'50 to 54',
'55 to 59',
'60 to 64',
'65+'],
['0-9', '0-9',
'10-19', '10-19',
'20-29', '20-29',
'30-39', '30-39',
'40-49', '40-49',
'50-59', '50-59',
'60+', '60+'
], 'Age Category')
# Inspect the first 5 rows
ss_2008_census_final_df.show(5, False)
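The same regrouping can also be expressed with a dictionary, which keeps each old value next to its new label and is easier to audit; a sketch equivalent to the chunk above:
# replace() also accepts a dict mapping old values to new labels
age_map = {
    '0 to 4': '0-9',     '5 to 9': '0-9',
    '10 to 14': '10-19', '15 to 19': '10-19',
    '20 to 24': '20-29', '25 to 29': '20-29',
    '30 to 34': '30-39', '35 to 39': '30-39',
    '40 to 44': '40-49', '45 to 49': '40-49',
    '50 to 54': '50-59', '55 to 59': '50-59',
    '60 to 64': '60+',   '65+': '60+',
}
ss_2008_census_final_alt = ss_2008_census_df_1.replace(age_map, subset=['Age Category'])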
Converting Column Values into a List¶
# Collect the distinct Age Category values; collect() returns Row objects
ss_2008_census_final_df.select('Age Category').distinct().collect()
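Note that collect() returns a list of Row objects rather than bare strings; a small comprehension extracts a plain Python list:
# Extract the plain values from the collected Row objects
age_categories = [row['Age Category'] for row in
                  ss_2008_census_final_df.select('Age Category').distinct().collect()]
print(age_categories)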
# Re-inspect the number of rows and the number of columns
print('Your dataset has', ss_2008_census_final_df.count(), 'rows and', len(ss_2008_census_final_df.columns), 'columns.')
Summarizing the Dataset¶
In the next three chunks, we group the data by various columns, compute the sum of the population column, and then display the results. We will accomplish this using the groupBy(), agg(), and orderBy() methods.
# Compute the state totals
state_totals = ss_2008_census_final_df.groupBy("State")\
.agg(sum("Population").alias('Total Population')).orderBy(col('Total Population').desc())
# Display the results
state_totals.show(truncate = False)
# Compute the state totals by gender
state_totals_by_gender = ss_2008_census_final_df.groupBy('State', 'Gender')\
.agg(sum("Population").alias('Population')).orderBy(col('Population').desc())
# Display the results
state_totals_by_gender.show(truncate = False)
# Compute the state totals by gender and age category
state_totals_by_gender_n_age = ss_2008_census_final_df.groupBy('State', 'Gender', 'Age Category')\
.agg(sum("Population").alias('Population')).orderBy(col('Population').desc(), col('State'))
# Print the results
state_totals_by_gender_n_age.show(truncate = False)
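As a final sketch, a pivot presents the same totals as a cross-tab, with one column per gender, using groupBy() together with pivot():
# Cross-tab of population by state and gender (one column per gender)
state_gender_pivot = ss_2008_census_final_df.groupBy('State').pivot('Gender').sum('Population')
state_gender_pivot.show(truncate = False)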
Closing Remarks¶
This article has just scratched the surface of data wrangling with PySpark, but we hope it will help you get started with PySpark and data science in general.
With that said, please follow me on Twitter @tongakuot, LinkedIn @tongakuot, and GitHub @alierwai for more data science, Python, R, statistics, mathematics, PySpark, and Shiny tutorials and articles.
Acknowledgements¶
We are grateful to Layla AI and Mike Cohen for their phenomenal courses, PySpark Essentials for Data Scientists and Master Math by Coding in Python, respectively.