Data Loaders
This tutorial explores different data loading strategies for use with JAX. While JAX doesn’t include a built-in data loader, it integrates with popular data loading libraries, including Hugging Face Datasets, PyTorch DataLoader, TensorFlow Datasets (TFDS), and Grain.
In this tutorial, you’ll learn how to efficiently load data using these libraries for a simple image classification task based on the MNIST dataset.
You should be familiar with how to write a training loop from the MNIST Example. For this tutorial, we’ll use a dummy training step that takes 4-D image arrays and 1-D label vectors. Our goal in data loading is to produce these arrays by implementing the get_batches generator used in the training loop below.
import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Float, Int, Array  # install with: pip install jaxtyping
from flax import nnx
batch_size = 32
def train(model: nnx.Module, images: Float[Array, "batch channels height width"], labels: Int[Array, "batch"]):
    pass  # dummy training step
def train_loop(model):
    # get_batches and train_ds stand in for the library-specific versions defined below.
    for images, labels in get_batches(train_ds):
        train(model, images, labels)
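Before wiring up a real dataset, a minimal sketch with synthetic data shows the contract get_batches has to satisfy: full batches of (batch, 1, 28, 28) float32 images scaled to [0, 1], plus a matching 1-D label vector. The helpers make_fake_mnist and get_batches_sketch are hypothetical, and random pixels stand in for MNIST; each library section below swaps in a real data source while aiming for the same output layout.

```python
import numpy as np

BATCH_SIZE = 32  # same value as batch_size in the tutorial


def make_fake_mnist(num_examples: int, seed: int = 0):
    """Hypothetical MNIST stand-in: uint8 images (N, 28, 28) and int32 labels (N,)."""
    rng = np.random.default_rng(seed)
    images = rng.integers(0, 256, size=(num_examples, 28, 28), dtype=np.uint8)
    labels = rng.integers(0, 10, size=(num_examples,), dtype=np.int32)
    return images, labels


def get_batches_sketch(images, labels, batch_size: int = BATCH_SIZE):
    """Hypothetical get_batches: yields (B, 1, 28, 28) float32 images and (B,) labels."""
    # Stop early enough that every batch is full (the incomplete final batch is dropped).
    for start in range(0, len(labels) - batch_size + 1, batch_size):
        stop = start + batch_size
        # Add a channel axis and rescale pixel values from [0, 255] to [0, 1].
        batch_images = images[start:stop, None].astype(np.float32) / 255.0
        yield batch_images, labels[start:stop]
```

With 100 fake examples this yields three full batches and drops the last four examples. In a real loader you would typically convert each yielded pair with jnp.asarray so that the whole batch reaches the device in a single transfer.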
Loading Hugging Face Datasets
In the previous MNIST Example, we saw how to use Hugging Face’s datasets library with JAX. Specifically, we downloaded the ‘mnist’ dataset and used subsets of the data from the ‘train’ and ‘test’ splits. The splits object returned by load_dataset is a DatasetDict, a dict subclass mapping split names to Dataset objects. Each Dataset is backed by a cached Arrow file, which makes loading fast.
from datasets import load_dataset
splits = load_dataset('mnist')
train_ds = splits['train'].shuffle(seed=0)
test_ds = splits['test']
isinstance(splits, dict)
True
When you take slices of these Dataset objects, you get dictionaries mapping feature names to lists of observations.
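def get_list_type(x):
    # Describe a list by the type of its first element.
    return f"list[{type(x[0])}]"
jax.tree.map(get_list_type, train_ds[1:32], is_leaf=lambda x: type(x) is list)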
{'image': "list[<class 'PIL.PngImagePlugin.PngImageFile'>]",
'label': "list[<class 'int'>]"}
To convert images to JAX arrays, we can use jnp.array. This materializes the array on the default device (a GPU, if one is available).
img_array = jnp.array(train_ds[1]['image'], dtype=jnp.float32)
img_array.shape, img_array.max()
((28, 28), Array(255., dtype=float32))
We can see that these arrays don’t yet have a channel dimension, and that the pixel values range from 0 to 255. We need to add a channel dimension and rescale the values to [0, 1] before passing them to the training loop. Doing this for every batch gives us the batch iterator we saw in the MNIST tutorial.
def get_hf_batches(ds):
    """Yield batches of normalized (images, labels) JAX arrays."""
    for i in range(0, len(ds), batch_size):
        batch = ds[i : i + batch_size]
        if len(batch['label']) < batch_size:  # drop incomplete final batch
            break
        # Stack the PIL images on the host, add a channel axis, and rescale to [0, 1].
        images = np.stack([np.asarray(img, dtype=np.float32) for img in batch['image']])
        images = images[:, None] / 255.0
        # One host-to-device transfer per batch instead of one per image.
        yield jnp.asarray(images), jnp.asarray(batch['label'])
Loading Data with PyTorch DataLoaders
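If you’re coming to JAX from PyTorch, you might prefer to use PyTorch’s data utilities instead. The process is similar. This time, the “image to normalized array” transformation is already written for us: torchvision’s ToTensor converts a PIL image with values in [0, 255] into a float tensor of shape (channels, height, width) with values in [0, 1].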
# !pip install torch torchvision
from torch.utils import data
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
mnist_dataset = MNIST("data", download=True, transform=ToTensor())
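To make the transformation concrete, here is a NumPy approximation of what ToTensor does to a uint8 image: move the channel axis first and divide by 255. The function to_tensor_like is hypothetical and covers only uint8 input; the real ToTensor also accepts PIL images directly and returns a torch tensor.

```python
import numpy as np


def to_tensor_like(image: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy approximation of ToTensor for uint8 images.

    Maps (H, W) or (H, W, C) uint8 values in [0, 255] to a (C, H, W) float32
    array in [0.0, 1.0]. Unlike ToTensor, it returns NumPy, not a torch tensor.
    """
    if image.dtype != np.uint8:
        raise TypeError("this sketch only handles uint8 images")
    if image.ndim == 2:
        image = image[:, :, None]  # grayscale (H, W) -> (H, W, 1)
    chw = np.transpose(image, (2, 0, 1))  # channels-last -> channels-first
    return chw.astype(np.float32) / 255.0
```

For a 28×28 grayscale MNIST digit this produces a (1, 28, 28) array, which is why the PyTorch batches below already have a channel dimension.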
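torchvision’s MNIST class loads one split at a time: the 60,000-image training split by default, or the 10,000-image test split with train=False. Here we’ll split the training images into our own train and test subsets.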
train_ds, test_ds = data.random_split(mnist_dataset, [0.8, 0.2])
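Passing fractions to random_split amounts to shuffling the indices and cutting them into disjoint pieces. The sketch below does that with NumPy; random_split_indices is a hypothetical helper, not PyTorch’s implementation. In particular, PyTorch hands any leftover items out one at a time, round-robin, starting with the first subset, whereas this sketch simply gives them to the last subset.

```python
import numpy as np


def random_split_indices(num_items: int, fractions, seed: int = 0):
    """Hypothetical sketch: shuffle indices and cut them into disjoint subsets."""
    if not np.isclose(sum(fractions), 1.0):
        raise ValueError("fractions must sum to 1")
    rng = np.random.default_rng(seed)
    perm = rng.permutation(num_items)
    # Floor each subset size, then give any leftover items to the last subset.
    sizes = [int(np.floor(f * num_items)) for f in fractions]
    sizes[-1] += num_items - sum(sizes)
    boundaries = np.cumsum(sizes)[:-1]
    return np.split(perm, boundaries)
```

For MNIST’s 60,000 training images, [0.8, 0.2] gives 48,000 and 12,000 indices with no overlap.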
To package each dataset into batches, we’ll use a DataLoader. Setting num_workers > 0 enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. Experiment with different values to find the optimal setting for your hardware and workload.
Note: When setting num_workers > 0, you may see the following RuntimeWarning: “os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.” This warning can be safely ignored here, since the data loader’s worker processes do not use JAX.
train_dataloader = data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
Iterating over a DataLoader yields batches of PyTorch tensors. We’ll need to convert them to JAX arrays before passing them to the training step.
jax.tree.map(lambda a: a.shape, next(iter(train_dataloader)))
[torch.Size([32, 1, 28, 28]), torch.Size([32])]
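def get_pt_batches(ds):
    # Build the loader from the dataset that was passed in (previously ds was ignored).
    dataloader = data.DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=0)
    for images, labels in dataloader:
        # DLPack lets JAX import each tensor's buffer, avoiding a copy when possible.
        yield jax.dlpack.from_dlpack(images), jax.dlpack.from_dlpack(labels)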
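DLPack is the protocol that lets jax.dlpack.from_dlpack take over a PyTorch tensor’s buffer. NumPy implements the same protocol, so a torch-free sketch can show the zero-copy behavior. It assumes NumPy 1.22 or newer, and the NumPy source array merely plays the role of a CPU tensor.

```python
import numpy as np

# Assumes NumPy >= 1.22, where arrays implement __dlpack__ and np.from_dlpack exists.
source = np.arange(6, dtype=np.float32)  # stands in for a CPU torch tensor
imported = np.from_dlpack(source)  # consumer side, analogous to jax.dlpack.from_dlpack

# Both arrays refer to the same buffer, so a write to the source is visible
# through the imported array: no data was copied.
source[0] = 42.0
first_value = imported[0]
same_buffer = np.shares_memory(source, imported)
```

Because an import shares the existing buffer rather than moving it, tensors that live in CPU memory generally become CPU-resident arrays; moving them to an accelerator is a separate step, for example with jax.device_put.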
Loading Data with TensorFlow Datasets (TFDS)
This section demonstrates how to load the MNIST dataset using TFDS. Although TFDS does not require TensorFlow to load datasets, it currently does require TensorFlow to download them. By default, TensorFlow reserves most of the GPU memory when it initializes, which can prevent JAX from allocating arrays. To avoid this, we explicitly hide the GPUs from TensorFlow.
Once you’ve downloaded the datasets with an initial call to tfds.data_source, you no longer need TensorFlow. The exposed API looks almost identical to Hugging Face’s. Specifically, TFDS gives us a dictionary mapping from split names to datasets.
import tensorflow_datasets as tfds
import tensorflow as tf
from itertools import batched  # requires Python 3.12+
# Hide all GPUs from TensorFlow so it doesn't reserve GPU memory that JAX needs.
tf.config.set_visible_devices([], device_type='GPU')
splits = tfds.data_source('mnist')
splits
{Split('train'): ArrayRecordDataSource(name=mnist, split=Split('train'), decoders=None),
Split('test'): ArrayRecordDataSource(name=mnist, split=Split('test'), decoders=None)}
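The batching below relies on itertools.batched, which was added in Python 3.12. On older interpreters, a small generator built on itertools.islice behaves the same way; batched_fallback is a hypothetical name for this sketch. Like itertools.batched, it keeps a shorter final batch rather than dropping it.

```python
from itertools import islice


def batched_fallback(iterable, n: int):
    """Yield successive tuples of length n; the last one may be shorter.

    Hypothetical stand-in for itertools.batched on Python < 3.12.
    """
    if n < 1:
        raise ValueError("n must be at least one")
    iterator = iter(iterable)
    # islice pulls at most n items; an empty tuple means the iterator is exhausted.
    while batch := tuple(islice(iterator, n)):
        yield batch
```

For MNIST’s 60,000 training records and a batch size of 32, this produces exactly 1,875 full batches, so no partial batch occurs.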
Indexing each split gives you a dictionary with separate keys for each feature (in this case, ‘image’ and ‘label’). For now, we’ll normalize these and aggregate them into batches in pure Python; in the next section, we’ll see how the Grain data loader can make this process faster.
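def get_tfds_batches(ds):
    for batch in batched(ds, batch_size):
        # TFDS MNIST images are channels-last: (28, 28, 1).
        images = jnp.array([a['image'] for a in batch], dtype=jnp.float32) / 255
        # Move channels first to match (batch, channels, height, width).
        images = jnp.transpose(images, (0, 3, 1, 2))
        labels = jnp.array([a['label'] for a in batch])
        yield images, labels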
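TFDS describes MNIST images with shape (28, 28, 1), so batches stacked from its records are channels-last, (batch, 28, 28, 1), while the dummy train step is annotated as (batch, channels, height, width). If your model expects the annotated layout, the axes need reordering. The NumPy sketch below uses a hypothetical all-zeros batch; jnp.transpose accepts the same axes argument.

```python
import numpy as np

# Hypothetical batch shaped like stacked TFDS MNIST records: (batch, height, width, channels).
nhwc = np.zeros((32, 28, 28, 1), dtype=np.float32)
nhwc[0, 5, 7, 0] = 1.0  # mark one pixel so we can follow it through the reordering

# Reorder axes to (batch, channels, height, width); NumPy returns a view, not a copy.
nchw = np.transpose(nhwc, (0, 3, 1, 2))

# np.moveaxis expresses the same reordering as "move axis 3 to position 1".
same = np.moveaxis(nhwc, 3, 1)
```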
Loading Data with Grain
In the Hugging Face and TFDS examples above, we did our data processing (datatype conversion, batching, and normalization) in pure Python. Because of Python’s global interpreter lock (GIL), these processing steps run sequentially. The Grain library lets you do this loading and processing in parallel. You can use Grain to accelerate Hugging Face datasets or TFDS, but it also works with raw ArrayRecord or Parquet files.
To start, we need to tell Grain in what order to iterate over the dataset, using an IndexSampler.
import grain
splits = tfds.data_source('mnist')
sampler = grain.samplers.IndexSampler(
    num_records=len(splits['train']),
    num_epochs=2,
    shuffle=True,
    seed=0)
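The sampler’s arguments describe a record order fixed before any data is read: every epoch visits each record once, and a fixed seed makes the shuffled order reproducible. The NumPy sketch below illustrates that contract with a hypothetical epoch_index_order helper; it is not Grain’s shuffling algorithm.

```python
import numpy as np


def epoch_index_order(num_records: int, num_epochs: int, shuffle: bool, seed: int):
    """Hypothetical sketch of the record order an index sampler describes.

    Each epoch is a permutation of range(num_records); with shuffle=True a new
    permutation is drawn per epoch from a seeded RNG, so the same seed repeats it.
    """
    rng = np.random.default_rng(seed)
    order = []
    for _ in range(num_epochs):
        epoch = rng.permutation(num_records) if shuffle else np.arange(num_records)
        order.extend(int(i) for i in epoch)
    return order
```

With num_epochs=2, as in the sampler above, the data loader sees every MNIST training record twice, in a different shuffled order each time.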
We describe our data transformations by subclassing the grain.transforms.Map class.
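class ScalePixelVals(grain.transforms.Map):
    def map(self, x: dict) -> dict:
        # Runs once per record (possibly in worker processes), so use NumPy rather than JAX.
        # Also move TFDS's channels-last (28, 28, 1) layout to channels-first (1, 28, 28).
        x['image'] = x['image'].astype(np.float32).transpose(2, 0, 1) / 255
        return x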
Finally, we package the results together with a grain.DataLoader.
data_loader = grain.DataLoader(
    data_source=splits['train'],
    operations=[
        ScalePixelVals(),
        grain.transforms.Batch(batch_size)],
    sampler=sampler,
    worker_count=0)  # 0 runs processing in the main process; increase it for parallel loading
def get_grain_batches():
    for elt in data_loader:
        yield elt['image'], elt['label']
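To see why get_grain_batches can index elt['image'] and elt['label'], here is a plain-Python sketch of what a Map followed by a Batch conceptually does: transform each record, then stack the records’ values key by key along a new leading axis. The helpers scale_pixels and map_then_batch are hypothetical, and they model neither sampling nor worker processes.

```python
import numpy as np


def scale_pixels(record: dict) -> dict:
    """Per-record transform in the spirit of ScalePixelVals.map (returns a new dict)."""
    out = dict(record)  # copy so the caller's record is left untouched
    out["image"] = record["image"].astype(np.float32) / 255
    return out


def _stack_records(records):
    """Turn a list of dicts into one dict of arrays stacked along axis 0."""
    return {key: np.stack([r[key] for r in records]) for key in records[0]}


def map_then_batch(records, map_fn, batch_size: int):
    """Hypothetical sketch of a Map op followed by a Batch op.

    A shorter final batch is kept rather than dropped.
    """
    buffer = []
    for record in records:
        buffer.append(map_fn(record))
        if len(buffer) == batch_size:
            yield _stack_records(buffer)
            buffer = []
    if buffer:
        yield _stack_records(buffer)
```

Five records with a batch size of 2 produce batches of 2, 2, and 1 records, each a dict with an "image" array and a "label" array.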
Summary
This notebook has introduced strategies for loading data on the CPU for use with JAX, demonstrating how to integrate popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. The Hugging Face and TFDS examples preprocess data sequentially in pure Python, while PyTorch’s DataLoader (with num_workers > 0) and Grain (with worker_count > 0) can spread loading and preprocessing across worker processes. Weighing these trade-offs against your dataset size and preprocessing cost will help you select the approach that best suits your project’s requirements.