Dask Performance Boosts For Model Training

Tips and tricks to speed up distributed training with Dask.


A distributed computing library like Dask is a great tool for speeding up machine learning model training. We’ll show a few methods for achieving performance boosts by knowing when and how to use techniques like batching, scattering, and tree-reduction.

Distributing Model Training

In supervised machine learning, models are often trained by defining a problem-specific loss function and trying to minimize it. Intuitively, the loss function quantifies how far the model's predictions are from the training data labels for a given set of parameter values; the best-known example is the sum of squared errors minimized when fitting a linear regression model. In practice, the minimization is done with first- or second-order optimization algorithms (gradient descent, SGD, Newton’s method, etc.) that search for the parameter values minimizing the loss function. This broad description of training applies to many popular models, including:

  • Linear regression
  • Logistic regression
  • Support Vector Machines
  • Neural Networks
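To make this description concrete, here is a minimal, non-distributed gradient-descent sketch for linear regression with a sum-of-squared-errors loss. The synthetic data, learning rate, and iteration count are arbitrary choices for illustration, not values from this post.

```python
import numpy as np

def sse_loss_and_grad(theta, X, y):
    # Sum-of-squared-errors loss and its gradient for a linear model X @ theta
    residual = X @ theta - y
    return residual @ residual, 2 * X.T @ residual

def gradient_descent(X, y, lr=0.001, n_iter=500):
    # Plain (full-batch) gradient descent starting from theta = 0
    theta = np.zeros(X.shape[1])
    for _ in range(n_iter):
        _, grad = sse_loss_and_grad(theta, X, y)
        theta -= lr * grad  # step against the gradient
    return theta

# Noise-free synthetic data, so the exact parameters are recoverable
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
true_theta = np.array([1.0, -2.0, 0.5])
y = X @ true_theta
theta_hat = gradient_descent(X, y)
```

The learning rate has to be small relative to the curvature of the loss (here the Hessian is 2XᵀX); a step size that is too large makes the iterates oscillate or diverge.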

The loss function in these situations can often be written as the sum of the losses for the individual training examples, possibly plus some regularization terms:

$$L(\theta) = \sum_{i=1}^{n} \ell(\theta; x_i, y_i) + R(\theta)$$

where R(θ) is a regularization term. Linearity of the derivative then gives a convenient expression for the gradient with respect to the parameters θ:

$$\nabla_\theta L(\theta) = \sum_{i=1}^{n} \nabla_\theta \ell(\theta; x_i, y_i) + \nabla_\theta R(\theta)$$

As such, we can easily distribute computation of the loss and its gradient (and Hessian, etc.) by letting multiple threads, cores, or cluster workers each compute the loss and gradient for individual training examples concurrently, then gathering and summing the results.
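A quick numerical check of this decomposition, assuming (for illustration only) squared-error per-example losses and a ridge penalty R(θ) = λ‖θ‖²: summing per-example gradients over disjoint slices of the data, as separate workers would, and adding ∇R(θ) once reproduces the full-batch gradient. The function names are hypothetical.

```python
import numpy as np

def example_grad(theta, x_i, y_i):
    # Gradient of the squared error (x_i . theta - y_i)^2 for one example
    return 2 * (x_i @ theta - y_i) * x_i

def full_grad(theta, X, y, lam=0.0):
    # Full-batch gradient of the sum of squared errors plus R(theta) = lam * ||theta||^2
    return 2 * X.T @ (X @ theta - y) + 2 * lam * theta

def split_grad(theta, X, y, lam=0.0, n_parts=4):
    # Each simulated "worker" k sums the per-example gradients of its own slice
    partials = [
        sum(example_grad(theta, x_i, y_i) for x_i, y_i in zip(X[k::n_parts], y[k::n_parts]))
        for k in range(n_parts)
    ]
    # Partial sums are combined, and the regularization gradient is added only once
    return sum(partials) + 2 * lam * theta
```

Note that the regularization term is not split across workers: it contributes once to the total, not once per example.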

Dask for Distributed Computing in Python

Dask is a popular library for scalable distributed computing in Python. It lets you distribute tasks on clusters ranging from a few workers running locally on your laptop up to thousands of nodes. Other libraries offer similar functionality, but this post focuses specifically on how to speed up training in Dask.

Let’s look at a simple example. Here examples is a list of training examples and params holds the parameters for the current training iteration:

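def single_example_loss(example, params):
    ...  # problem-specific: compute the loss and gradient for one example
    return loss, gradient

def compute_loss(examples, params, client):
    futures = client.map(single_example_loss, examples, params=params)  # one task per example
    results = client.gather(futures)  # pull every (loss, gradient) pair back to the client
    return sum(r[0] for r in results), sum(r[1] for r in results)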

This is the basic recipe for distributing the loss computation with Dask. The speedup from this basic implementation, however, may be disappointing, for reasons explained below. The rest of this post discusses ways to improve on it.
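A runnable version of this basic pattern might look like the following. The squared-error single_example_loss for a linear model, with each example an (x, y) pair, is a hypothetical stand-in for a real per-example loss; any connected dask.distributed Client can be passed in.

```python
import numpy as np

def single_example_loss(example, params):
    # Hypothetical per-example loss: squared error of a linear model
    x, y = example
    residual = x @ params - y
    return residual ** 2, 2 * residual * x

def compute_loss(examples, params, client):
    # One Dask task per example; params is forwarded to every call as a keyword
    futures = client.map(single_example_loss, examples, params=params)
    results = client.gather(futures)
    # Sum the per-example losses and gradients on the client
    return sum(r[0] for r in results), sum(r[1] for r in results)
```

For example, with a small in-process cluster (settings chosen only for illustration), create client = Client(LocalCluster(n_workers=2, processes=False)) and call compute_loss(list(zip(X, y)), theta, client).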

Transmitting Data (scatter)

In the above example, the parameter array may be very large. We may also need to send other large pieces of data to the workers, such as the vector in a Hessian-vector product. Large data shared by all workers can instead be scattered to them beforehand using client.scatter():

params_future = client.scatter(params, broadcast=True)  # copy params to every worker once
futures = client.map(single_example_loss, examples, params=params_future)

client.scatter(…, broadcast=True) efficiently sends a copy of the data to every worker up front, which can eliminate unnecessary network time during the subsequent client.map().

Be aware of the following:

  • Scattering is usually unnecessary for small pieces of data; in that case it may even slow things down slightly because of the extra overhead it introduces.
  • There are some known bugs in Dask involving client.scatter() that can cause a training routine to crash, so be aware of them. Scattering is not always the optimal choice.
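A self-contained sketch of the scattered variant, again using a hypothetical squared-error loss (examples are (x, y) pairs and params is a NumPy array). The only change from the unscattered version is that the array is scattered once with broadcast=True and the resulting Future is passed to client.map(), which resolves it back to the array on each worker. compute_loss_scattered is a name chosen for this sketch.

```python
import numpy as np

def single_example_loss(example, params):
    # Hypothetical per-example loss: squared error of a linear model
    x, y = example
    residual = x @ params - y
    return residual ** 2, 2 * residual * x

def compute_loss_scattered(examples, params, client):
    # Ship params to every worker once, then pass the Future instead of the array
    params_future = client.scatter(params, broadcast=True)
    futures = client.map(single_example_loss, examples, params=params_future)
    results = client.gather(futures)
    return sum(r[0] for r in results), sum(r[1] for r in results)
```

Whether this helps depends on the size of params; for a small array the extra scatter call can cost more than it saves.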

Submitting Tasks (batching)

Suppose there are 4000 training examples. The example above, which calls client.map(single_example_loss, examples, params=…) to compute the loss for a single iteration, creates 4000 Dask tasks/futures. We might instead write a function that computes the loss in batches:

def batch_loss(example_list, params):
   accum_loss, accum_grad = 0, 0
   for example in example_list:
       loss, grad = single_example_loss(example, params)
       accum_loss += loss
       accum_grad += grad
   return accum_loss, accum_grad

Now, we can do all of the computations for a single iteration with:

n_workers = len(client.has_what())  # number of workers in the cluster
futures = client.map(
    batch_loss,
    [examples[i::n_workers] for i in range(n_workers)],  # one strided batch per worker
    params=params,
)

If we have 40 workers, this reduces the number of Dask tasks/futures created from 4000 to 40, each processing a batch of 100 examples. Every Dask task submitted carries a non-trivial amount of overhead, so submitting in batches like this can provide nice speedups.
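A sketch that puts the batching pieces together, again with a hypothetical squared-error single_example_loss. make_batches and compute_loss_batched are names made up for this example, and n_batches is passed explicitly rather than read from the cluster. For 4000 examples and n_batches=40, make_batches returns 40 batches of 100 examples each, and every example lands in exactly one batch.

```python
import numpy as np

def make_batches(examples, n_batches):
    # Strided split: batch i holds examples i, i + n_batches, i + 2 * n_batches, ...
    return [examples[i::n_batches] for i in range(n_batches)]

def single_example_loss(example, params):
    # Hypothetical per-example loss: squared error of a linear model
    x, y = example
    residual = x @ params - y
    return residual ** 2, 2 * residual * x

def batch_loss(example_list, params):
    # Accumulate loss and gradient over one batch inside a single task
    accum_loss, accum_grad = 0, 0
    for example in example_list:
        loss, grad = single_example_loss(example, params)
        accum_loss += loss
        accum_grad += grad
    return accum_loss, accum_grad

def compute_loss_batched(examples, params, client, n_batches):
    # One Dask task per batch instead of one per example
    futures = client.map(batch_loss, make_batches(examples, n_batches), params=params)
    results = client.gather(futures)
    return sum(r[0] for r in results), sum(r[1] for r in results)
```

The strided split keeps batch sizes within one example of each other; contiguous chunks would work equally well for a sum.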

Gathering Results (tree-summation)

The standard way to gather the results and compute the final loss is client.gather(), as used above. To try to improve on this slightly, we might instead sum the results on the client as they arrive:

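from dask.distributed import as_completed
total_loss, total_grad = 0, 0
for _, (loss, grad) in as_completed(futures, with_results=True):
    total_loss, total_grad = total_loss + loss, total_grad + grad  # accumulate as results arrive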

It turns out that neither of these approaches works particularly well in certain situations. Dask has some latency that prevents the client from quickly collecting many large results from many workers (remember that each gradient has one entry per model parameter, and there might be lots of parameters). Even if all workers finish quickly, the client might take several additional minutes to collect these results.

Figure: Gathering results as a single output using tree-summation.

An alternative (and better!) solution is to let the workers do the post-processing themselves by pairing up: results are summed in pairs on the workers, the partial sums are paired up again, and so on until only one worker holds the final result. Only this single final result is then sent to the client. This needs a small post-processing helper function:

def accumulate(result1, result2):
   return result1[0] + result2[0], result1[1] + result2[1]

Then, gather results using tree-summation (letting workers pair up):

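mapped = client.map(...)

while len(mapped) > 1:
    new_mapped = []
    for i in range(0, len(mapped) - 1, 2):
        # Sum each pair of results on a worker rather than on the client
        fut = client.submit(accumulate, mapped[i], mapped[i + 1])
        new_mapped.append(fut)
    if len(mapped) % 2 == 1:
        # An odd result out is carried over to the next round unchanged
        new_mapped.append(mapped[-1])
    mapped = new_mapped

# Only the single final (loss, gradient) pair is transferred to the client
total_loss, total_grad = mapped[0].result()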

This tree-summation approach will usually provide large speedups over the other two result-gathering approaches.
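The tree-summation loop can also be packaged as a reusable function, sketched below. It assumes each future resolves to a (loss, gradient) pair and that the input list is non-empty; tree_sum is a name chosen for this sketch. With n initial results the loop runs ceil(log2(n)) rounds of client.submit() calls, and only the final pair is transferred to the client.

```python
def accumulate(result1, result2):
    # Add two (loss, gradient) pairs elementwise
    return result1[0] + result2[0], result1[1] + result2[1]

def tree_sum(client, futures):
    # Pairwise-reduce a non-empty list of (loss, gradient) futures on the workers
    mapped = list(futures)
    while len(mapped) > 1:
        new_mapped = [
            client.submit(accumulate, mapped[i], mapped[i + 1])
            for i in range(0, len(mapped) - 1, 2)
        ]
        if len(mapped) % 2 == 1:
            new_mapped.append(mapped[-1])  # odd one out waits for the next round
        mapped = new_mapped
    # Only the final pair is sent back to the client
    return mapped[0].result()
```

Because addition is associative, the pairing order does not change the result (up to floating-point rounding).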

HazardNet, one of our time-to-event modeling frameworks currently under development, uses deep learning to model event risk. Because the model has a large number of parameters, it benefited hugely from the combination of batching and tree-summation: we saw estimated speedups of 5–10x over the standard client.gather() approach.


We gave a brief overview of three methods for speeding up distributed computations with Dask:

  • Scattering: Sends data to all workers ahead of time. This is useful when there are large pieces of data that every worker needs in order to complete its tasks.
  • Batching: Reduces overhead caused by submitting many tasks to the Dask scheduler. This is useful when the number of tasks is very large.
  • Tree-summation: Accumulates results among the workers before sending the single accumulated result to the client. This is very effective when there are a large number of tasks and each individual result being accumulated is large.

These approaches could be used for many different types of distributed computations, but are especially useful in the context of optimization. You can learn more about setting up efficient data input pipelines and distributed training in the context of deep learning with TensorFlow and much more on When Machines Learn.
