Problem summary
I am fitting a brms::brm_multiple() model to a large dataset where missing data has been imputed using the mice package. The size of the dataset makes parallel processing very desirable. However, it isn't clear to me how best to use the compute resources, because I don't know how brms divides sampling on the imputed datasets among cores.
How can I choose the following to maximize efficient use of compute resources?
- number of imputations (m)
- number of chains (chains)
- number of cores (cores)
Conceptual example
Let's say that I naively (or deliberately foolishly, for the sake of example) choose m = 5, chains = 10, cores = 24. There are thus 5 x 10 = 50 chains to be allocated among 24 cores reserved on the HPC. Without parallel processing, this would take ~50 time units (excluding compilation time).
I can imagine three parallelization strategies for brm_multiple(), but there may be others:
Scenario 1: Imputed datasets in parallel, associated chains in serial
Here, each of the 5 imputations is allocated to its own core, which runs through the 10 chains in serial. The processing time is 10 time units (a 5x speed improvement vs. non-parallel processing), but poor planning has wasted 19 cores x 10 time units = 190 core time units (ctu; roughly 80% of the reserved compute resources). The efficient solution would be to set cores = m.
Scenario 2: Imputed datasets in serial, associated chains in parallel
Here, the sampling begins by taking the first imputed dataset and running one of the chains for that dataset on each of 10 different cores. This is then repeated for the remaining four imputed datasets. The processing takes 5 time units (a 10x speed improvement over serial processing and a 2x improvement over Scenario 1). However, here too compute resources are wasted: 14 cores x 5 time units = 70 ctu. The efficient solution would be to set cores = chains.
Scenario 3: Free-for-all, wherein each core takes on a pending imputation/chain combination when it becomes available until all are processed.
Here, the sampling begins by allocating all 24 cores, each one to one of the 50 pending chains. After they finish their iterations, a second batch of 24 chains is processed, bringing the total chains processed to 48. But now there are only two chains pending, and 22 cores sit idle for 1 time unit. The total processing time is 3 time units, and the wasted compute resource is 22 ctu. The efficient solution would be to set cores to a divisor of m x chains, so that every batch keeps all cores busy (for example, 25 cores here).
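The bookkeeping for the three scenarios can be sketched in a few lines of R. This is only an idealized cost model, not a description of what brms actually does: it assumes every chain takes exactly one time unit and ignores compilation and scheduling overhead. The function name scenario_costs() and its column names are hypothetical, chosen for this illustration.

```r
# Idealized cost model (assumption: each chain takes exactly one time unit).
# scenario_costs() is a hypothetical helper, not part of brms or mice.
scenario_costs <- function(m, chains, cores) {
  total <- m * chains  # total number of chains to run
  time_units <- c(
    datasets_parallel = chains * ceiling(m / cores),  # Scenario 1
    chains_parallel   = m * ceiling(chains / cores),  # Scenario 2
    free_for_all      = ceiling(total / cores)        # Scenario 3
  )
  data.frame(
    scenario   = names(time_units),
    time_units = unname(time_units),
    # reserved core time minus core time actually spent sampling
    idle_ctu   = unname(cores * time_units - total)
  )
}

costs <- scenario_costs(m = 5, chains = 10, cores = 24)
print(costs)
```

For m = 5, chains = 10, cores = 24 this reproduces the figures above: 10, 5 and 3 time units, with 190, 70 and 22 idle ctu respectively.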
Minimal reproducible example
This code compares the compute time using an example modified from a brms vignette. Here we'll set m = 10, chains = 6, and cores = 4. This makes for a total of 60 chains to be processed. Under these conditions, I would expect the speed improvement (vs. serial processing) to be as follows*:
- Scenario 1: 60/(6 chains x ceiling(10 m / 4 cores)) = 3.3x
- Scenario 2: 60/(ceiling(6 chains / 4 cores) x 10 m) = 3.0x
- Scenario 3: 60/ceiling((6 chains x 10 m) / 4 cores) = 4.0x
*(ceiling/rounding up is used because chains cannot be subdivided among cores)
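The three formulas above can be evaluated directly, and sweeping over the number of cores shows where the scenarios would be expected to diverge. As before, this is a sketch under the assumption that all chains take equal time and that there is no overhead; expected_speedup() is a hypothetical helper name, not a brms function.

```r
# Expected speed-up vs. serial processing under each scenario
# (assumption: equal-length chains, no overhead; hypothetical helper).
expected_speedup <- function(m, chains, cores) {
  total <- m * chains
  c(
    scenario_1 = total / (chains * ceiling(m / cores)),
    scenario_2 = total / (ceiling(chains / cores) * m),
    scenario_3 = total / ceiling(total / cores)
  )
}

# The configuration used in the example: m = 10, chains = 6, cores = 4
at_4_cores <- expected_speedup(m = 10, chains = 6, cores = 4)
print(round(at_4_cores, 1))

# Sweep over core counts: one row per core count, one column per scenario
by_cores <- t(sapply(1:6, function(k) expected_speedup(10, 6, k)))
rownames(by_cores) <- paste0("cores=", 1:6)
print(round(by_cores, 2))
```

Under these assumptions, core counts at which the three columns differ most are the ones where a timing experiment has the best chance of telling the scenarios apart.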
library(brms)
library(mice)
library(tictoc) # convenience functions for timing
# Load data
data("nhanes", package = "mice")
# There are 10 imputations x 6 chains = 60 total chains to be processed
imp <- mice(nhanes, m = 10, print = FALSE, seed = 234023)
# Fit the model first to get compilation out of the way
fit_base <- brm_multiple(bmi ~ age*chl, data = imp, chains = 6,
                         iter = 10000, warmup = 2000)
# Use update() function to avoid re-compiling time
# Serial processing (127 sec on my machine)
tic() # start timing
fit_serial <- update(fit_base, .~., cores = 1L)
t_serial <- toc() # stop timing
t_serial <- unname(t_serial$toc - t_serial$tic) # calculate seconds elapsed
# Parallel processing with 4 cores (82 sec)
tic()
fit_parallel <- update(fit_base, .~., cores = 4L)
t_parallel <- toc()
t_parallel <- unname(t_parallel$toc - t_parallel$tic) # calculate seconds elapsed
# Calculate speed up ratio
t_serial/t_parallel # 1.5x
Clearly I am missing something. I can't distinguish between the scenarios with this approach.