Intro
Discrete convolutions are common in signal processing and machine learning applications. Most approaches to computing them are limited to the resources of a single machine, or require manual parallelization.
If we instead treat convolution as a data science problem and use the cluster-scale compute capacity of Spark, we can achieve significant performance gains with fairly minimal complexity.
This guide will give you a single-node proof-of-concept, which you can integrate with your own use-case.
Credit to atctwo for giving me this problem and checking my work.
This implementation was compared against the NumPy convolve function for correctness.
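For reference, the discrete convolution of an input sequence $x$ with a kernel $k$ is defined as
$(x * k)[n] = \sum_{m} x[m] \, k[n-m]$
i.e. each output sample is a weighted sum of nearby input samples, with the kernel providing the weights.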
The following import statement will cover everything we need:
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.types import StructType, StructField, DoubleType, LongType
import pyspark.sql.functions as sf
Spark Setup
We’re going to start by creating a local Spark instance, with one worker per core.
When we use local[*] as the ‘master’, Spark will use your local machine as the executor, and create one worker thread per core (or whatever number you enter in the brackets).
If you would like to connect to an external cluster instead, you can replace it with the required connection details.
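As a rough sketch, connecting to a standalone cluster might look like the following (the host name here is a placeholder; 7077 is the standalone master’s default port):
spark = SparkSession \
    .builder \
    .master("spark://my-cluster-host:7077") \
    .getOrCreate()
For this guide, we’ll stick with local mode: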
spark = SparkSession \
.builder \
.master("local[*]") \
.getOrCreate()
Now, if you try to use this with its default settings, you are likely to run into heap space or ‘max result size’ errors. Therefore, we can add some config options to raise both limits.
spark = SparkSession \
.builder \
.master("local[*]") \
.config("spark.driver.memory", "16g") \
.config("spark.driver.maxResultSize", "0") \
.config("spark.sql.shuffle.partitions", "1000") \
.getOrCreate()
‘Memory’ refers to the maximum RAM used by all threads on the driver / your local machine, i.e. if you have 10 workers and a memory value of 6 GB, the maximum total memory usage will be 6 GB plus JVM overhead.
Note: When connecting to a cluster, each node will have its own memory limit, separate from the config value set here.
By default, Spark will abort execution if the serialized results returned to the driver by a single action exceed 1 GiB. Therefore we set ‘maxResultSize’ to ‘0’ to allow unlimited result sizes.
The ‘shuffle partitions’ value is the default number of partitions Spark uses when shuffling data for joins and aggregations. Increasing the value allows Spark to load smaller chunks into memory at a time, but too high a value causes Spark to waste time on scheduling overhead.
Loading & Batching Our Data
Spark doesn’t have native support for complex numbers, therefore we’re going to store the real and imaginary portions in separate columns.
The preferred file format is typically Apache Parquet, as Spark has the ability to read the data in chunks, reducing memory usage. See the Apache Docs for more data sources.
If needed you could also parse the data into a pandas DataFrame separately, and then load that into a Spark DataFrame instead (Guide), but the full dataset may then need to fit in your RAM.
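As a minimal sketch of that route, assuming a hypothetical data.csv containing the same three columns described below:
import pandas as pd

# Hypothetical CSV with id, real and imag columns.
# Note: pandas reads the whole file into RAM before Spark sees it.
pdf = pd.read_csv("data.csv")
input_df = spark.createDataFrame(pdf)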
Input Schema
| Field | Type | Purpose |
| --- | --- | --- |
| id | Long | Sequential sample index |
| real | Double/Float | Real portion |
| imag | Double/Float | Imaginary portion |
You can load from a file:
input_df = spark.read.parquet("data.parquet")
Or start from a list, pandas DataFrame, etc.:
data = [
(0, 1.0, 2.0), (1, 3.0, 4.0), (2, 5.0, 6.0), (3, 7.0, 8.0),
(4, 9.0, 10.0), (5, 11.0, 12.0), (6, 13.0, 14.0), (7, 15.0, 16.0),
(8, 17.0, 18.0), (9, 19.0, 20.0), (10, 21.0, 22.0), (11, 23.0, 24.0)
]
schema = StructType([
StructField("id", LongType(), False),
StructField("real", DoubleType(), False),
StructField("imag", DoubleType(), False)
])
input_df = spark.createDataFrame(data, schema)
Batching
For our use-case, we are going to split the data into batches, as this can significantly speed up the summation step.
Without batching, when it came time to sum up each output value, Spark would need to retrieve the products of every multiplication across the whole dataset. Since these are distributed across different workers, this results in a large amount of data transfer between nodes, which in turn can significantly slow down our program.
By splitting the data into chunks small enough to fit into a single worker, we eliminate this communication entirely. In my testing it delivered ~4x speed-ups, but your mileage may vary.
We’ll define the batch size with a reasonable starting value:
BATCH_SIZE = 2048
Now we will:
- Split the events into batches.
- Calculate the index relative to the batch.
- Tell Spark to treat the batch as the unit of work.
base = input_df \
.withColumn(
"batch_id",
sf.floor((sf.col("id")) / BATCH_SIZE)
) \
.withColumn(
"id_in_batch",
(sf.col("id")) % BATCH_SIZE
) \
.repartition("batch_id")
We must also keep in mind that a batch may be smaller than BATCH_SIZE (the last one usually is), so we will count the actual length of each batch for use at the end.
input_counts = base.groupBy("batch_id").agg(sf.count("id").alias("input_len"))
Loading our Kernel
A kernel (or convolution matrix) is the typically small array that is multiplied with each element of the input to produce a weighted output.
You can either load it from a file, or as below, just define it in-line. Since our data is made of complex numbers, our kernel should be the same:
KERN = [
(0.229397 + 0.229397j),
(-0.080817 - 0.080817j),
(0.132327 + 0.132327j),
(-0.212555 - 0.212555j),
(0.313648 + 0.313648j),
]
Since we are splitting the real and imaginary portions of the complex numbers, let’s iterate through the list to grab the index along with the two components:
rows = [(i, v.real, v.imag) for i, v in enumerate(KERN)]
Now we can tell Spark what to call the parts of our data, and then load it into a DataFrame:
kern_schema = StructType([
StructField("k_id", LongType(), True),
StructField("k_real", DoubleType(), True),
StructField("k_imag", DoubleType(), True)
])
kernel_df = spark.createDataFrame(rows, kern_schema)
Note: prefixing the kernel fields with k_ will differentiate them from the input data after the join. You could also use an alias.
In Spark we can ‘broadcast’ such small datasets to all nodes, allowing them to quickly reference local copies of the data, rather than having to communicate with a central source.
Since we will be computing the kernel against every single batch, we are going to take advantage of that:
kdf = sf.broadcast(kernel_df)
We will exclusively refer to this broadcasted ‘kdf’ frame from now on, to make sure we get the performance benefits.
Cross Joining Our Data With The Kernel
We can now cross join the input DataFrame with our broadcasted kernel. This will produce a new row for every combination of an input row with a kernel value.
cross_joined = base.crossJoin(kdf)
This will give us a new combined schema of:
| Field | Type |
| --- | --- |
| id | Long |
| id_in_batch | Long |
| batch_id | Long |
| real | Double/Float |
| imag | Double/Float |
| k_id | Long |
| k_real | Double/Float |
| k_imag | Double/Float |
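If you want to confirm that the broadcast is actually being used, you can inspect the physical plan; you should typically see a broadcast-style join (e.g. BroadcastNestedLoopJoin) rather than a plain CartesianProduct:
cross_joined.explain()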
Calculating Row Values
Now, because every input row has been duplicated once per kernel element, our dataset is significantly larger than when we started.
Therefore, it’s important that we try to drop data as early as possible.
One way we can achieve that is by performing as many calculations as possible within a single select, which lets us drop the input columns at the same time.
Additionally, due to Spark’s lack of native support for complex numbers, we must manually implement complex multiplication using the real and imaginary components of the input and kernel.
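For reference, complex multiplication expands as $(a+bi)(c+di) = (ac-bd) + (ad+bc)i$, which is exactly what the two expressions below compute from the real/imag and k_real/k_imag columns.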
conv_df = cross_joined \
.select(
sf.col("batch_id"),
sf.col("id"),
(sf.col("real") * sf.col("k_real") - sf.col("imag") * sf.col("k_imag")) \
.alias("real"),
(sf.col("real") * sf.col("k_imag") + sf.col("imag") * sf.col("k_real")) \
.alias("imag"),
(sf.col("id_in_batch") + sf.col("k_id")) \
.alias("i_conv"),
)
The field i_conv contains the destination index within the result batch, after convolution.
This gives an output index range of $0\to((BATCH\_SIZE-1)+(len(KERN)-1))$ (0 to 2051 for a 2048-sample batch and a 5-tap kernel), which is expected, and we will compensate for it later.
Summing Each Output Index
Now that we’ve multiplied each input with each element of the kernel, we need to sum up each output value.
First we group the products by batch, and by index within the batch. Then we use a simple aggregation to get the total for each cell.
summed_conv_df = conv_df \
.groupby("batch_id" ,"i_conv") \
.agg(
sf.sum("real").alias("real"),
sf.sum("imag").alias("imag"),
)
Handling The Edges Of Our Array
Discrete convolution produces an output that is longer than the input (by len(KERN) - 1 samples), therefore we are going to select the center of the output as our return value. This cuts off the over-run and gives us an output of equal size to the input, matching NumPy’s mode="same" behavior.
First, we need to make sure we only grab populated cells, therefore we are going to use a Window to recalculate the sequential ID, ordered by i_conv.
window_spec_batch = Window.partitionBy("batch_id").orderBy("i_conv")
# Assign a sequential row number within each batch
processed_batches = summed_conv_df.withColumn(
"row_num_in_convolved_batch", sf.row_number().over(window_spec_batch)
)
Now we will join on our batch sizes, so that we can properly trim the over-run from the end:
KERNEL_HALF_LEN = len(KERN) // 2 # Integer division
batches_with_len = processed_batches.join(input_counts, on="batch_id", how="left")
filtered = batches_with_len \
.filter(
(sf.col("row_num_in_convolved_batch") > KERNEL_HALF_LEN) &
# (input_len + KERNEL_HALF_LEN) corresponds to the end of the "valid" part
(sf.col("row_num_in_convolved_batch") <= (sf.col("input_len") + KERNEL_HALF_LEN))
)
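To make the bounds concrete with the in-line example data: 12 samples convolved with a 5-tap kernel produce $12+5-1=16$ rows in the batch, and KERNEL_HALF_LEN is 2, so the filter keeps row numbers 3 through 14 (row numbers are 1-based), leaving exactly 12 rows.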
Now we simply need to:
- Select our output
- Align the index to 0
- (Optional) Sort the data by batch and index
output = filtered \
.select(
sf.col("batch_id"),
(sf.col("row_num_in_convolved_batch") - (KERNEL_HALF_LEN + 1)).alias("i"),
sf.col("real"),
sf.col("imag")
) \
.orderBy("batch_id", "i")
If you wished, you could use Window in a similar way to change the index to be across the full dataset.
Example:
full_window = Window.orderBy("batch_id", "i")
# Note: a window without partitionBy pulls all rows into a single partition,
# and row_number() is 1-based, so we subtract 1 to keep a 0-based index
fully_ordered = output \
    .withColumn("id", sf.row_number().over(full_window) - 1) \
    .select(
        sf.col("id").alias("i"),
        sf.col("real"),
        sf.col("imag")
    )
Putting It All Together
Here is the full script, with in-line example data:
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.types import StructType, StructField, DoubleType, LongType
import pyspark.sql.functions as sf
BATCH_SIZE = 2048
spark = SparkSession \
.builder \
.master("local[*]") \
.config("spark.driver.memory", "16g") \
.config("spark.driver.maxResultSize", "0") \
.config("spark.sql.shuffle.partitions", "1000") \
.getOrCreate()
print("Spark session started.")
# Example data
data = [
(0, 1.0, 2.0), (1, 3.0, 4.0), (2, 5.0, 6.0), (3, 7.0, 8.0),
(4, 9.0, 10.0), (5, 11.0, 12.0), (6, 13.0, 14.0), (7, 15.0, 16.0),
(8, 17.0, 18.0), (9, 19.0, 20.0), (10, 21.0, 22.0), (11, 23.0, 24.0)
]
schema = StructType([
StructField("id", LongType(), False),
StructField("real", DoubleType(), False),
StructField("imag", DoubleType(), False)
])
# input_df = spark.read.parquet("data.parquet") # Load from a file
input_df = spark.createDataFrame(data, schema) # Load from a list
# Add batch_id and id_in_batch columns, then repartition by batch_id
base = input_df \
.withColumn(
"batch_id",
sf.floor((sf.col("id")) / BATCH_SIZE)
) \
.withColumn(
"id_in_batch",
(sf.col("id")) % BATCH_SIZE
) \
.repartition("batch_id")
# Calculate the actual size of each batch
input_counts = base.groupBy("batch_id").agg(sf.count("id").alias("input_len"))
# Define the kernel (convolution filter)
KERN = [
(0.229397 + 0.229397j),
(-0.080817 - 0.080817j),
(0.132327 + 0.132327j),
(-0.212555 - 0.212555j),
(0.313648 + 0.313648j),
]
# Convert kernel to a list of tuples (k_id, k_real, k_imag)
rows = [(i, v.real, v.imag) for i, v in enumerate(KERN)]
kern_schema = StructType([
StructField("k_id", LongType(), True),
StructField("k_real", DoubleType(), True),
StructField("k_imag", DoubleType(), True)
])
# Create the kernel DataFrame and broadcast it for efficient joins
kernel_df = spark.createDataFrame(rows, kern_schema)
kdf = sf.broadcast(kernel_df)
# Cross join input data with the broadcasted kernel
cross_joined = base.crossJoin(kdf)
# Calculate the complex multiplication and the convolution index within the batch
conv_df = cross_joined \
.select(
sf.col("batch_id"),
sf.col("id"),
# Manual complex multiplication
(sf.col("real") * sf.col("k_real") - sf.col("imag") * sf.col("k_imag")) \
.alias("real"),
(sf.col("real") * sf.col("k_imag") + sf.col("imag") * sf.col("k_real")) \
.alias("imag"),
# Calculate the destination index within the batch after convolution
(sf.col("id_in_batch") + sf.col("k_id")) \
.alias("i_conv"),
)
# Sum the total for each value in each batch
summed_conv_df = conv_df \
.groupby("batch_id" ,"i_conv") \
.agg(
sf.sum("real").alias("real"),
sf.sum("imag").alias("imag"),
)
# To handle any missing cells in i_conv,
# we will give each element a new sequential index
window_spec_batch = Window.partitionBy("batch_id").orderBy("i_conv")
processed_batches = summed_conv_df.withColumn(
"row_num_in_convolved_batch", sf.row_number().over(window_spec_batch)
)
KERNEL_HALF_LEN = len(KERN) // 2 # Integer division
# Retrieve the true length of each batch
batches_with_len = processed_batches.join(input_counts, on="batch_id", how="left")
# Filter to keep only the "valid" part of the convolution output
# This removes the edges where the kernel partially overlaps the input
filtered = batches_with_len \
.filter(
(sf.col("row_num_in_convolved_batch") > KERNEL_HALF_LEN) &
# (input_len + KERNEL_HALF_LEN) corresponds to the end of the "valid" part
(sf.col("row_num_in_convolved_batch") <= (sf.col("input_len") + KERNEL_HALF_LEN))
)
# Select the final output columns and calculate the final index (0-based)
output = filtered \
.select(
sf.col("batch_id"),
        # We subtract (KERNEL_HALF_LEN + 1) because row_num_in_convolved_batch is 1-based
        # and the first KERNEL_HALF_LEN rows of each batch were trimmed off above.
(sf.col("row_num_in_convolved_batch") - (KERNEL_HALF_LEN + 1)).alias("i"),
sf.col("real"),
sf.col("imag")
) \
.orderBy("batch_id", "i")
print("Pipeline created.")
print("Calculating convolution...")
# Pipelines run lazily, so to trigger the calculation we can display the results.
# You could also save this to a file, or as part of another pipeline etc.
output.show(n=20)
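Finally, as a minimal sketch of the correctness check mentioned in the intro, assuming the in-line example data above (a single batch that easily fits in driver memory), you could compare the collected output against NumPy’s convolve in mode="same":
import numpy as np

# Rebuild the input and kernel as complex NumPy arrays
x = np.array([complex(r, i) for (_, r, i) in data])
k = np.array(KERN)
expected = np.convolve(x, k, mode="same")

# Collect the (already ordered) Spark output and compare element-wise
rows_out = output.collect()
got = np.array([complex(row["real"], row["imag"]) for row in rows_out])

print(np.allclose(got, expected))  # True if the pipeline matches NumPy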