How To Use Spark Partitions With Postgres
In this post, we will explore how to use partitions to reduce job execution time by increasing parallelism. Specifically, we will focus on building Spark partitions from tables in a Postgres database.
Spark is an analytics engine for processing large datasets across many distributed nodes. Distributed processing helps solve two problems. First, a single node can hold only a finite amount of memory; for example, a commodity motherboard may support only 256GB of memory. To process a larger dataset, we need a way to distribute the data across many nodes. Second, we can speed up processing by harnessing the computing power of many nodes through parallel execution.
In Spark, a partition is a chunk of data stored on a node. All of the rows in a partition are guaranteed to reside on the same physical node. Partitions form the basic unit of parallelism in Spark: each partition is processed by a single task.
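To make the idea of partitions as independent units of work concrete, here is a minimal pure-Python sketch (no Spark involved) of computing an average delay per tail number. Each partition computes a partial sum and count on its own, and the partial results are merged at the end. This mirrors, in simplified form, the kind of partial aggregation Spark uses for `groupBy().avg()`. The sample rows, tail numbers, and function names are hypothetical, and threads merely stand in for Spark executors.

```python
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor

# Hypothetical sample rows: (tail_num, arr_delay_minutes). Illustrative only.
ROWS = [
    ("N100", 10.0), ("N200", -5.0), ("N100", 20.0),
    ("N300", 0.0), ("N200", 15.0), ("N100", 30.0),
]


def split_into_partitions(rows, num_partitions):
    """Deal rows round-robin into num_partitions lists (stand-ins for partitions)."""
    partitions = [[] for _ in range(num_partitions)]
    for i, row in enumerate(rows):
        partitions[i % num_partitions].append(row)
    return partitions


def partial_aggregate(partition):
    """Per-partition task: running [sum, count] of delays for each tail number."""
    acc = defaultdict(lambda: [0.0, 0])
    for tail_num, delay in partition:
        acc[tail_num][0] += delay
        acc[tail_num][1] += 1
    return dict(acc)


def merge(partials):
    """Combine partial sums and counts, then divide once per tail number."""
    totals = defaultdict(lambda: [0.0, 0])
    for partial in partials:
        for tail_num, (s, c) in partial.items():
            totals[tail_num][0] += s
            totals[tail_num][1] += c
    return {tail_num: s / c for tail_num, (s, c) in totals.items()}


def average_delay(rows, num_partitions):
    partitions = split_into_partitions(rows, num_partitions)
    # Each partition is independent, so its task can run concurrently with the others.
    # (Python threads only illustrate the structure; they give no real CPU speedup.)
    with ThreadPoolExecutor(max_workers=num_partitions) as pool:
        partials = list(pool.map(partial_aggregate, partitions))
    return merge(partials)


if __name__ == "__main__":
    print(average_delay(ROWS, num_partitions=3))
```

Changing `num_partitions` changes how the work is divided but not the result. That is why the partial results carry sums and counts: averaging the per-partition averages would give the wrong answer whenever partitions hold different numbers of rows for a tail number.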
To gain a better intuition for how partitions work with Spark and Postgres, we will build on the [Spark/Postgres/Jupyter Docker environment]({% post_url 2021-01-28-spark-postgres-jupyter %}) described in my previous post.
The large dataset we will use is airline flight delay data from 2008.
Environment Setup
First, let's create a local development environment for testing:
```shell
# Clone the repo
git clone https://github.com/omahoco/spark-postgres.git
# Create and start the environment
cd spark-postgres
make all
```
Loading Delayed Airline Data
The flight delay dataset has a couple of interesting columns we can use for our example.
Column | Description | Type |
---|---|---|
FlightDate | The date the flight departed | Timestamp |
TailNum | The aircraft's tail number, used here as the flight identifier | Text |
ArrDelay | Number of minutes late | Decimal |
To load our sample dataset into Postgres:
```shell
make load_airline_data
```
Calculating Average Delay Using One Partition
Next, open Jupyter at http://localhost:9999 and calculate the average delay for each flight (`TailNum`) using the code below.
```python
import pyspark
from pyspark.sql import SparkSession

# Connect to the standalone Spark master running in the Docker environment
conf = pyspark.SparkConf().setAppName('DelayedFlight1Partition').setMaster('spark://spark:7077')
sc = pyspark.SparkContext(conf=conf)
session = SparkSession(sc)

jdbc_url = 'jdbc:postgresql://postgres/postgres'
connection_properties = {
    'user': 'postgres',
    'password': 'postgres',
    'driver': 'org.postgresql.Driver',
    'stringtype': 'unspecified'}

# No partitioning options are given, so the whole table is read into one partition
df = session.read.jdbc(jdbc_url, 'public.airline_delayed_flight', properties=connection_properties)
df.groupBy("TailNum").avg("ArrDelay").show()

sc.stop()
```
The code above doesn't specify the number of partitions to use, so Spark defaults to a single partition. You can confirm this for yourself with `df.rdd.getNumPartitions()`. Spark builds the partition by issuing the following query to Postgres:

```sql
SELECT "tailnum","arrdelay" FROM public.airline_delayed_flight
```
The partition resides on a single node and must fit into that node's available memory. We can prove this to ourselves by making the dataset ten times larger, which should exhaust memory:

```shell
for i in {1..10}; do make load_airline_data; done
```
Rerunning the single-partition code now fails after about six minutes with the following error:

```text
Job aborted due to stage failure: Task 0 in stage 0.0 failed 4 times, most recent failure: Lost task 0.3 in stage 0.0 (TID 3, 172.18.0.6, executor 2): ExecutorLostFailure (executor 2 exited caused by one of the running tasks) Reason: Remote RPC client disassociated. Likely due to containers exceeding thresholds, or network issues. Check driver logs for WARN messages.
```
We can investigate further in the SparkContext Web UI at http://localhost:4040, which shows that the executor spends most of its time performing garbage collection.
Calculating Average Delay With Multiple Partitions
Now, let's try using six partitions to process the dataset. Spark SQL supports partitioned reads through the `partitionColumn` option. The partition column must be numeric, date, or timestamp. The `numPartitions` option works with `lowerBound` and `upperBound` to calculate the partition stride; the bounds only decide how the ranges are split and do not filter any rows.
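For the bounds used in the code below, the stride works out as follows. 2008 was a leap year, so the range spans 366 days. This is a simplified worked calculation, not a description of Spark's internal arithmetic.

```latex
\begin{aligned}
\text{stride} &= \frac{\texttt{upperBound} - \texttt{lowerBound}}{\texttt{numPartitions}}
  = \frac{\text{2009-01-01} - \text{2008-01-01}}{6}
  = \frac{366\ \text{days}}{6} = 61\ \text{days} \\
b_k &= \texttt{lowerBound} + k \cdot \text{stride}, \qquad k = 1, \dots, 5
\end{aligned}
```

The five interior boundaries b_1 through b_5 fall on 2008-03-02, 2008-05-02, 2008-07-02, 2008-09-01, and 2008-11-01.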
```python
import datetime

import pyspark
from pyspark.sql import SparkSession

conf = pyspark.SparkConf().setAppName('DelayedFlight6Partitions').setMaster('spark://spark:7077')
sc = pyspark.SparkContext(conf=conf)
session = SparkSession(sc)

jdbc_url = 'jdbc:postgresql://postgres/postgres'
# Split FlightDate into 6 ranges between lowerBound and upperBound;
# each range is read by its own query into its own partition
jdbcDF = session.read \
    .format("jdbc") \
    .option("url", jdbc_url) \
    .option("dbtable", "public.airline_delayed_flight") \
    .option("user", "postgres") \
    .option("password", "postgres") \
    .option("partitionColumn", "FlightDate") \
    .option("lowerBound", datetime.datetime(2008, 1, 1)) \
    .option("upperBound", datetime.datetime(2009, 1, 1)) \
    .option("numPartitions", 6) \
    .load()
jdbcDF.groupBy("TailNum").avg("ArrDelay").show()

sc.stop()
```
The code above creates six partitions. Each partition uses a separate database connection and query to retrieve its range. The first and last ranges are open-ended, so rows below `lowerBound` or above `upperBound` are still read, and the first range also picks up rows where `flightdate` is NULL.
Range | Query |
---|---|
1 | WHERE "flightdate" < '2008-03-02 00:00:00' or "flightdate" is null |
2 | WHERE "flightdate" >= '2008-03-02 00:00:00' AND "flightdate" < '2008-05-02 00:00:00' |
3 | WHERE "flightdate" >= '2008-05-02 00:00:00' AND "flightdate" < '2008-07-02 00:00:00' |
4 | WHERE "flightdate" >= '2008-07-02 00:00:00' AND "flightdate" < '2008-09-01 00:00:00' |
5 | WHERE "flightdate" >= '2008-09-01 00:00:00' AND "flightdate" < '2008-11-01 00:00:00' |
6 | WHERE "flightdate" >= '2008-11-01 00:00:00' |
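The ranges in the table can be reproduced with a short pure-Python sketch. It splits the interval from `lowerBound` to `upperBound` into `numPartitions` equal strides, leaves the first range open below (also catching NULLs) and the last range open above, and prefixes each clause with a column-pruned SELECT, assuming the partitioned read selects the same columns as the single-partition query shown earlier. This is a simplified model of the behavior in the table, not Spark's implementation; the function names are hypothetical, and only datetime bounds are handled.

```python
from datetime import datetime


def partition_where_clauses(column, lower, upper, num_partitions):
    """Return one WHERE clause per partition (None means no WHERE clause).

    Simplified model of the ranges in the table above; datetime bounds only.
    """
    if num_partitions <= 1:
        # A single partition reads the whole table with one unfiltered query.
        return [None]
    # Postgres folds unquoted identifiers to lower case, hence "flightdate".
    quoted = f'"{column.lower()}"'
    # Equal-width ranges: the bounds' span divided into num_partitions strides.
    stride = (upper - lower) // num_partitions
    clauses = []
    current = lower
    for i in range(num_partitions):
        start = f"'{current:%Y-%m-%d %H:%M:%S}'"
        current += stride
        end = f"'{current:%Y-%m-%d %H:%M:%S}'"
        if i == 0:
            # Open below, and NULL values land in this partition too.
            clauses.append(f"{quoted} < {end} or {quoted} is null")
        elif i == num_partitions - 1:
            # Open above, so rows after upperBound are still read.
            clauses.append(f"{quoted} >= {start}")
        else:
            clauses.append(f"{quoted} >= {start} AND {quoted} < {end}")
    return clauses


def partition_queries(table, columns, clauses):
    """Build one SELECT per partition, selecting only the columns the job needs."""
    column_list = ",".join(f'"{c.lower()}"' for c in columns)
    base = f"SELECT {column_list} FROM {table}"
    return [base if clause is None else f"{base} WHERE {clause}" for clause in clauses]


if __name__ == "__main__":
    clauses = partition_where_clauses(
        "FlightDate", datetime(2008, 1, 1), datetime(2009, 1, 1), 6)
    for query in partition_queries(
            "public.airline_delayed_flight", ["TailNum", "ArrDelay"], clauses):
        print(query)
```

Running it prints one query per partition, and the WHERE clauses match the six rows of the table. With `num_partitions` set to 1, it returns the single unfiltered query used by the first example.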