A Guide to Distributed TensorFlow: Part 1

How to set up efficient input data pipelines for deep learning using TFRecord and the tf.data.Dataset API


This post is the first in a two-part series on large-scale, distributed training of TensorFlow models using Kubeflow. In this series, we discuss the foundational concepts of a distribution strategy that supports data parallelism and the tools and technologies involved in setting up the distributed computations, and we walk you through a concrete example of the overall workflow. This post covers the implementation details of an input data pipeline built with TFRecord and the tf.data.Dataset API for distributed model training.

Deep Learning

Deep learning has gained a lot of attention in recent years. Exciting research papers appear in open forums at a tremendous pace, and some of them set remarkable performance benchmarks on problems that were previously considered close to impossible. Industry practitioners are rapidly adopting the technology during this age of implementation¹, and some even argue that deep learning models can exhibit human-level reasoning. It's safe to say that this technology is at the forefront of major groundbreaking AI developments.

If you’re a data scientist, you’ve probably had some exposure to the basics of how neural networks work. However, you might not have the background of an ML or DevOps engineer who typically oversees end-to-end machine learning pipelines and large-scale, cluster-wide computations. If you’re an independent data scientist, you might not have access to that kind of expert advice or resources.

Wouldn’t it be nice if you could just take your experimental neural network models and scale them up with massive amounts of data, without having to rely on anyone to do it for you?

This post is intended to save you effort, cost, and development time. We’ll describe a common use case and then, in as much detail as possible, talk about each of the necessary ingredients.

Problem setup

Let’s say we have raw sensor measurements from a fleet of thousands of machines or assets. Let’s assume that all this data sits in cloud storage and that we have the necessary credentials for read/write access to it. Specifically, we will consider each asset’s data to be a Parquet or CSV file of several hundred megabytes. Lastly, let’s also assume that we have access to an operational cluster with private nodes, powered by Google Kubernetes Engine.

Our goal is to train a recurrent autoencoder for unsupervised latent state representation. This is a neural network that exploits the temporal dynamics of our data to engineer useful features for other, more complex downstream learning tasks. For this article, some familiarity with Dask will be helpful, since we will not cover the instructions for setting up a Dask scheduler and workers on our on-premise cluster.

The workflow we describe here involves training a model on data gathered from cloud storage. Training is distributed using the Google Kubernetes Engine.


Google’s open-source deep learning library, TensorFlow, is a go-to choice for many industry practitioners who want to build and deploy production-grade neural networks. We will use several of its excellent (but experimental) features to scale up our model.

We assume you have a solid background in developing neural networks, so we work with an abstract Model class that encapsulates the necessary functionality of a recurrent autoencoder. The model has an encoder that takes an input of shape (batch_size, num_timesteps, num_dimensions) and compresses it into a latent state of shape (batch_size, num_latent_state_dimensions) by processing the data sequentially along the time dimension. A decoder then tries to reconstruct the input from this latent state. Because training minimizes the reconstruction error, the encoder is forced to learn rich representations of our input data, which can serve as features for other learning algorithms. We will use a subclassed Keras Model in this workflow.
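
The article deliberately keeps the model abstract. For readers who want something concrete to run the pipeline against, here is a minimal sketch of a subclassed Keras recurrent autoencoder. The class name, the GRU layers and the RepeatVector-based decoder are our own assumptions, not the authors' architecture.

```python
import tensorflow as tf


class RecurrentAutoencoder(tf.keras.Model):
    """Hypothetical sketch: GRU encoder -> latent state -> GRU decoder."""

    def __init__(self, num_timesteps, num_dimensions, num_latent_state_dimensions=16):
        super().__init__()
        # Encoder: processes the sequence step by step; its final hidden state is the latent state.
        self.encoder = tf.keras.layers.GRU(num_latent_state_dimensions)
        # Repeat the latent state once per timestep so the decoder can unroll over time.
        self.repeat = tf.keras.layers.RepeatVector(num_timesteps)
        # Decoder: emits one hidden vector per timestep from the repeated latent state.
        self.decoder = tf.keras.layers.GRU(num_latent_state_dimensions, return_sequences=True)
        # Project every decoder step back to the input dimensionality.
        self.projection = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(num_dimensions))

    def encode(self, inputs):
        # (batch_size, num_timesteps, num_dimensions) -> (batch_size, num_latent_state_dimensions)
        return self.encoder(inputs)

    def call(self, inputs):
        latent = self.encode(inputs)
        decoded = self.decoder(self.repeat(latent))
        # Reconstruction with shape (batch_size, num_timesteps, num_dimensions).
        return self.projection(decoded)
```

Training then amounts to fitting the model to reproduce its own input, for example `model.compile(optimizer="adam", loss="mse")` followed by `model.fit(...)` on a dataset whose elements are (inputs, inputs) pairs; after training, `model.encode(...)` yields the latent features.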

Input pipeline

For distributed training with large amounts of data, we need efficient input pipelines that don’t have to read all of the data into memory. The recommended way to build them is the tf.data API. TensorFlow treats the input pipeline itself as a graph of operations: every data transformation is defined on symbolic tensors, and concrete tensors are produced only when the training loop iterates over the dataset.
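
A quick way to see this lazy, graph-based behavior: the Python body of a function passed to `Dataset.map` runs while TensorFlow traces it into a graph, not once per element. This is a minimal sketch assuming TensorFlow 2.x with eager execution; the `standardize` function and its constants are purely illustrative.

```python
import tensorflow as tf

# Counts how often the Python body of `standardize` actually runs.
trace_count = {"n": 0}


def standardize(x):
    # This Python code runs while tf.data traces the function into a graph, with `x`
    # as a symbolic tensor; it does not run again for every element.
    trace_count["n"] += 1
    return (x - 4.5) / 3.0


# Building the pipeline only records the transformations; no element is computed yet.
dataset = (
    tf.data.Dataset.range(10)
    .map(lambda x: tf.cast(x, tf.float32))
    .map(standardize)
)

# Iterating over the dataset (which model.fit does for us) produces concrete tensors.
values = [float(v) for v in dataset]
print("traced", trace_count["n"], "time(s) for", len(values), "elements")
```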

Since we have to read and process data for every asset, it is much faster to do this concurrently. Reading and preprocessing data online, that is, on the fly during training, has limitations that can lead to performance bottlenecks: if the preprocessing functions are computationally intensive, the GPUs that do the actual computations sit idle while they wait for data. We also can’t fully exploit our resources to parallelize the preprocessing itself. A more efficient alternative, which allows for maximum GPU utilization during training, is to apply the preprocessing function ahead of time and save the preprocessed data offline. That way, the input function used during model training only has to read in the preprocessed data. Each approach has its trade-offs; in this article, we focus on the second one. Open-source parallel computing libraries such as Dask let us parallelize this (read -> preprocess -> save) workflow.
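
For contrast, the online approach would move preprocessing into the tf.data pipeline itself. The sketch below is our own illustration of that first approach, not what this article implements: `make_online_dataset` is a hypothetical helper, the per-window standardization is an assumed stand-in for real preprocessing, and parallel `map` plus `prefetch` are standard tf.data options for overlapping input processing with training.

```python
import numpy as np
import tensorflow as tf


def make_online_dataset(raw_windows, batch_size=32):
    """Hypothetical online pipeline: preprocessing runs inside tf.data during training."""
    dataset = tf.data.Dataset.from_tensor_slices(raw_windows)

    def standardize(window):
        # Illustrative preprocessing: standardize each channel along the time axis.
        mean = tf.reduce_mean(window, axis=0, keepdims=True)
        std = tf.math.reduce_std(window, axis=0, keepdims=True)
        return (window - mean) / (std + 1e-8)

    return (
        dataset
        # Run the preprocessing function on several CPU threads in parallel.
        .map(standardize, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        # Prepare upcoming batches while the current one is being consumed.
        .prefetch(tf.data.AUTOTUNE)
    )


raw = np.random.rand(10, 20, 3).astype(np.float32)
online_dataset = make_online_dataset(raw, batch_size=4)
```

Even with parallel calls and prefetching, this work is repeated on every epoch and still has to keep pace with the GPUs, which is the cost the offline approach avoids.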

There’s one last detail before we look at the code that does this. TensorFlow can also read data stored in the TFRecord format very efficiently. A TFRecord file is essentially a sequence of binary records, where each record holds a single serialized instance of our input data.
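
To make the format concrete, here is a minimal round trip, assuming TensorFlow 2.x: write three `tf.train.Example` records holding a hypothetical `asset_id` feature to a local TFRecord file, then read them back with `tf.data.TFRecordDataset`.

```python
import os
import tempfile

import tensorflow as tf

path = os.path.join(tempfile.mkdtemp(), "example.tfrecord")

# Each record is one serialized tf.train.Example protocol buffer.
with tf.io.TFRecordWriter(path) as writer:
    for i in range(3):
        example = tf.train.Example(
            features=tf.train.Features(
                feature={"asset_id": tf.train.Feature(int64_list=tf.train.Int64List(value=[i]))}
            )
        )
        writer.write(example.SerializeToString())

# Reading back yields one scalar string tensor (the raw record bytes) per record.
raw_records = list(tf.data.TFRecordDataset(path))
parsed = [tf.train.Example.FromString(record.numpy()) for record in raw_records]
asset_ids = [p.features.feature["asset_id"].int64_list.value[0] for p in parsed]
```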

Now, the overall pipeline becomes:

  1. Create batches of asset data records
  2. For every batch, preprocess and serialize the asset data in a distributed fashion using Dask (or another parallel computing library)
  3. Save the serialized binary records of each batch to a TFRecord file
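
Step 1 needs nothing more than plain Python. A minimal sketch, where `make_batches`, `tfrecord_path` and the `gs://my-bucket/...` location are hypothetical: each batch later becomes exactly one TFRecord file in step 3, so the number of output files equals the number of batches.

```python
def make_batches(asset_ids, batch_size):
    """Split a list of asset identifiers into consecutive batches (step 1)."""
    return [asset_ids[i : i + batch_size] for i in range(0, len(asset_ids), batch_size)]


def tfrecord_path(output_dir, batch_index):
    """Hypothetical naming scheme: one TFRecord file per batch (step 3)."""
    return f"{output_dir}/batch-{batch_index:05d}.tfrecord"


batches = make_batches([f"asset-{i}" for i in range(7)], batch_size=3)
# "gs://my-bucket/tfrecords" is a placeholder bucket, not a real location.
paths = [tfrecord_path("gs://my-bucket/tfrecords", i) for i in range(len(batches))]
```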
Minimal example of a data preprocessing function.
This function distributes `fn(..)` across the cluster workers, so the data instances of each asset are serialized in a distributed way. We’re using Dask’s vanilla map-and-gather strategy.

Create batches of assets, distribute the workload via `map_fn(..)` and save results using TFRecordWriter.
Array serializing functions.

In the minimal example code snippets above, preprocess_fn(..) accepts a single asset identifier and the cloud storage location, reads the asset-level data and performs some simple preprocessing steps: standardizing and reformatting. We distribute this function at the batch level using map_fn(..), and we further distribute serialize_fn(..) at the asset level. distributed_preprocess_fn(..) simply orchestrates the process and repeatedly applies these functions to all batches of assets. serialize_fn(..) converts arrays to bytes and then to tf.train.Feature objects so that each instance can be serialized. Following this minimal example workflow, we can efficiently preprocess and serialize the data for training.
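
The four functions described above could fit together roughly as follows. This is a minimal sketch under our own assumptions: each asset is a CSV file named `<asset_id>.csv` under `storage_path` with numeric columns only; preprocessing means per-channel standardization followed by reshaping into non-overlapping windows of `num_timesteps` rows; each window is stored as raw float32 bytes plus its shape; and `client` is a `dask.distributed.Client`. The function names follow the text, but every signature here is hypothetical.

```python
import functools

import numpy as np
import pandas as pd
import tensorflow as tf


def preprocess_fn(asset_id, storage_path, num_timesteps=50):
    """Read one asset's data, standardize it and reformat it into fixed-length windows."""
    # Assumption: one CSV per asset; pandas can also read gs:// paths when gcsfs is installed.
    values = pd.read_csv(f"{storage_path}/{asset_id}.csv").to_numpy(dtype=np.float32)
    # Standardize each sensor channel to zero mean and unit variance.
    values = (values - values.mean(axis=0)) / (values.std(axis=0) + 1e-8)
    # Drop the remainder and reshape to (num_windows, num_timesteps, num_dimensions).
    num_windows = len(values) // num_timesteps
    return values[: num_windows * num_timesteps].reshape(num_windows, num_timesteps, -1)


def serialize_fn(windows):
    """Turn every window of one asset into a serialized tf.train.Example."""
    records = []
    for window in windows:
        feature = {
            # Raw array bytes plus the shape needed to rebuild the array when parsing.
            "data": tf.train.Feature(bytes_list=tf.train.BytesList(value=[window.tobytes()])),
            "shape": tf.train.Feature(int64_list=tf.train.Int64List(value=list(window.shape))),
        }
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        records.append(example.SerializeToString())
    return records


def map_fn(client, fn, items):
    """Dask's plain map-and-gather: run fn on every item across the workers."""
    return client.gather(client.map(fn, items))


def distributed_preprocess_fn(client, asset_ids, storage_path, output_dir, batch_size=100):
    """Preprocess and serialize assets batch by batch; write one TFRecord file per batch."""
    tf.io.gfile.makedirs(output_dir)  # works for local paths and gs:// buckets
    for batch_index, start in enumerate(range(0, len(asset_ids), batch_size)):
        batch = asset_ids[start : start + batch_size]
        windows = map_fn(client, functools.partial(preprocess_fn, storage_path=storage_path), batch)
        records = map_fn(client, serialize_fn, windows)
        path = f"{output_dir}/batch-{batch_index:05d}.tfrecord"
        with tf.io.TFRecordWriter(path) as writer:
            for asset_records in records:
                for record in asset_records:
                    writer.write(record)
```

With a running Dask cluster, one would call something like `distributed_preprocess_fn(Client(scheduler_address), asset_ids, storage_path, output_dir)`. Gathering the preprocessed arrays before mapping `serialize_fn` moves them through the client process; passing the futures returned by `client.map` directly into a second `client.map` would keep them on the workers instead.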

During the actual model training, we need to create a tf.data.Dataset and apply additional (but very simple) transformations to make it ready for training.

Load previously saved TFRecords

Here, we simply load and parse the previously saved records from a location that contains multiple TFRecord files. We then repeat the dataset for the specified number of training epochs, so that data remains available throughout the training routine, shuffle the instances and batch them.
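
A minimal sketch of this loading step, assuming each record stores one window of a known `window_shape` as raw float32 bytes under a hypothetical `"data"` feature. `load_dataset` and its parameters are our own names; the repeat, shuffle and batch order follows the description above, while parallel reads and `prefetch` are our additions.

```python
import tensorflow as tf


def load_dataset(file_pattern, window_shape, num_epochs, batch_size, shuffle_buffer=10_000):
    """Hypothetical loader: parse saved TFRecords, then repeat, shuffle and batch them."""
    feature_spec = {"data": tf.io.FixedLenFeature([], tf.string)}

    def parse_fn(serialized):
        features = tf.io.parse_single_example(serialized, feature_spec)
        window = tf.io.decode_raw(features["data"], tf.float32)
        # A fixed reshape gives every element a static shape, which Keras layers expect.
        return tf.reshape(window, window_shape)

    files = tf.io.gfile.glob(file_pattern)  # e.g. "gs://bucket/records/*.tfrecord"
    dataset = tf.data.TFRecordDataset(files, num_parallel_reads=tf.data.AUTOTUNE)
    dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.repeat(num_epochs).shuffle(shuffle_buffer).batch(batch_size)
    return dataset.prefetch(tf.data.AUTOTUNE)
```

Because shuffling happens after `repeat`, instances from neighboring epochs can mix inside the shuffle buffer; calling `shuffle` before `repeat` keeps epoch boundaries intact instead.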

Now, our input data in the tf.data.Dataset format is ready to be passed into our model and we’re all set with our data pipeline to begin model training.

The Next Step

In this post, we’ve explored some basic tools and techniques for preparing data for a distributed computation. In the forthcoming Part 2 of this series, we’ll learn how to use Kubeflow to distribute our computational workload across multiple cluster nodes, and we’ll provide concrete examples in a sample implementation. You can find this and more on When Machines Learn.

¹ The phrase “age of implementation” was coined by Kai-Fu Lee in his New York Times bestseller AI Superpowers. The book emphasizes the broad applications of deep learning to many practical problems today.