How to quickly set up multi-GPU training for hyperparameter optimisation with PyTorch Lightning
Probably most people reading this post have at least once trained an ML model that took a considerable amount of time. This was certainly the case for me when I was fine-tuning RoBERTa for high-accuracy text classification. Utilising a GPU for training can shave an order of magnitude off the training time, depending on the data and the model. This is frequently the case for convolutional neural networks (CNNs) working on image or video-stream data, provided that transferring a batch to the GPU takes less time than passing it through the model.
One head is good, two are better. This is certainly true when more than one GPU is available. For me, one of the most appealing features of PyTorch Lightning is its seamless multi-GPU training capability, which requires minimal code modification. PyTorch Lightning is a wrapper on top of PyTorch that aims at standardising routine sections of ML model implementation. The extra speed boost from additional GPUs comes in especially handy for time-consuming tasks such as hyperparameter tuning. In this piece I would like to share my experience of using PyTorch Lightning and Optuna, a Python library for automated hyperparameter tuning.
For this example I selected the Intel Image Classification dataset, available from Kaggle. The complete code can be found in this notebook. The dataset is split into training and test subsets, with 14034 images in the training subset. There are six classes in the dataset: mountain, glacier, sea, street, buildings and forest. Here are some sample images:
Most of the images are 150 by 150 pixels, with a small number of outliers in width. PyTorch Lightning provides a base class, LightningDataModule, to download, prepare and serve a dataset of choice to a model. It defines methods that are called during training, validation and testing. My child class looks like this:
where transforms are imported from torchvision. A random horizontal flip is used for augmentation of the training data only; ToTensor() in the preprocessing transforms scales all channels to the [0, 1] interval (by dividing by 255). This class also splits the training data into training and validation sets, with 20% of the images going to the latter. Another great thing about PyTorch Lightning is that there is no need to specify device(s) when initialising IntelDataModule, as everything is done automatically by the trainer later.
To start, let's put together a basic CNN. It will serve as a benchmark to compare the optimised model against. To reduce code repetition, let's introduce a basic building block for our model:
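A minimal sketch of such a block is shown below. The function name conv_block comes from later in the post, but the argument names, the "same" padding, the ReLU default and the small activation lookup table are assumptions made for illustration.

```python
import torch
from torch import nn

# Hypothetical lookup table so the activation can be chosen by name.
ACTIVATIONS = {"relu": nn.ReLU, "elu": nn.ELU, "gelu": nn.GELU}


def conv_block(in_channels, out_channels, kernel_size=3, stride=1,
               batch_norm=True, activation="relu"):
    """Convolution, optional batch norm, then an activation."""
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
                        padding=kernel_size // 2,  # keeps H x W when stride == 1
                        bias=not batch_norm)]      # batch norm makes a conv bias redundant
    if batch_norm:
        layers.append(nn.BatchNorm2d(out_channels))
    layers.append(ACTIVATIONS[activation]())
    return nn.Sequential(*layers)


block = conv_block(3, 32, stride=2)  # halves the spatial size: 150 x 150 -> 75 x 75
```

The bias is disabled when batch norm follows the convolution, because the batch norm shift already plays that role.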
This is just a convolutional layer with an optional batch norm and an activation function. The benchmark model looks like this:
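The author's benchmark code is not shown here. As an illustration only, a small fully convolutional classifier could look like the sketch below. The class name BenchmarkCNN, the channel widths and the number of blocks are hypothetical. A 1x1 convolution to six output channels, followed by global average pooling, takes the place of the usual Linear head.

```python
import torch
from torch import nn

ACTIVATIONS = {"relu": nn.ReLU, "elu": nn.ELU, "gelu": nn.GELU}


def conv_block(in_channels, out_channels, kernel_size=3, stride=1,
               batch_norm=True, activation="relu"):
    # Conv -> optional batch norm -> activation, the building block described in the post.
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
                        padding=kernel_size // 2, bias=not batch_norm)]
    if batch_norm:
        layers.append(nn.BatchNorm2d(out_channels))
    layers.append(ACTIVATIONS[activation]())
    return nn.Sequential(*layers)


class BenchmarkCNN(nn.Module):
    """Fully convolutional classifier: there is no nn.Linear anywhere."""

    def __init__(self, num_classes=6):
        super().__init__()
        self.features = nn.Sequential(
            conv_block(3, 32, stride=2),    # 150 -> 75
            conv_block(32, 64, stride=2),   # 75 -> 38
            conv_block(64, 128, stride=2),  # 38 -> 19
        )
        # A 1x1 convolution plays the role of the final dense layer...
        self.classifier = nn.Conv2d(128, num_classes, kernel_size=1)
        # ...and global average pooling collapses the spatial dimensions.
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        x = self.classifier(self.features(x))
        return self.pool(x).flatten(1)  # (batch, num_classes) logits


model = BenchmarkCNN().eval()
with torch.no_grad():
    print(model(torch.rand(2, 3, 150, 150)).shape)  # torch.Size([2, 6])
    print(model(torch.rand(1, 3, 150, 160)).shape)  # torch.Size([1, 6])
```

No layer is tied to a fixed feature-map size, so the same network also accepts the slightly wider outlier images.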
The model does not have any Linear layers and is fully convolutional, a nice trick from the book “Hands-on Machine Learning with Scikit-Learn, Keras and TensorFlow” by Aurelien Geron that adds a bit of speed on GPU. The ..._step and ..._epoch_end methods of the base class LightningModule define the actions taken at each training/validation/testing step and at the end of each epoch, respectively. This is the place to pass the outputs to metrics and to log quantities of interest. I chose not to log after each step but rather at the end of an epoch, which is also where the metrics are calculated; computing them once per epoch rather than at every step can save a little computation time. Training the model in PyTorch Lightning is performed with a trainer class and is as simple as this:
Most of the arguments are self-explanatory, with the deterministic parameter controlling reproducibility. For demo purposes I chose a fixed learning rate of 0.0001 and 20 epochs. This benchmark model achieves 0.8213 accuracy on the test dataset; here is the confusion matrix:
The model seems to often confuse glaciers with mountains. Since glacier images often have mountains in the background, it is hard to blame it. Overall, it is a decent, but probably not the best, result. One way to improve it is hyperparameter tuning.
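A confusion matrix like the one discussed above can be computed from test-set predictions with plain PyTorch. The sketch below shows one way to do it. The labels are made up, and the class order assumes torchvision's ImageFolder, which assigns indices to class folders in sorted name order.

```python
import torch


def confusion_matrix(y_true, y_pred, num_classes):
    """Rows are true classes, columns are predicted classes."""
    # Encode each (true, predicted) pair as a single index, then count occurrences.
    idx = y_true * num_classes + y_pred
    counts = torch.bincount(idx, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


# Hypothetical labels in ImageFolder order:
# 0 buildings, 1 forest, 2 glacier, 3 mountain, 4 sea, 5 street.
y_true = torch.tensor([2, 2, 3, 3, 4])
y_pred = torch.tensor([3, 2, 3, 3, 4])
cm = confusion_matrix(y_true, y_pred, num_classes=6)
print(cm[2, 3].item())  # glaciers predicted as mountains: 1
print((cm.trace() / cm.sum()).item())  # overall accuracy
```

The off-diagonal entry cm[2, 3] is exactly the glacier-to-mountain confusion described in the text.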
Automated selection of hyperparameters, including layer types, can be simplified with nn.ModuleDict modules from PyTorch:
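The original ModuleDict code is not shown. The sketch below illustrates the idea: candidate layers are stored under string keys, so a tuner only has to pick a key. The names candidate_layers and ACTIVATION_CHOICES are hypothetical. Residual here is a standard ResNet basic block (two 3x3 convolutions plus a shortcut), and it may differ in detail from the author's module.

```python
import torch
from torch import nn


class Residual(nn.Module):
    """ResNet-style basic block: two 3x3 convolutions plus a shortcut connection."""

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        if stride != 1 or in_channels != out_channels:
            # 1x1 projection so the shortcut matches the body's output shape.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.body(x) + self.shortcut(x))


# Stateless activations can be shared, so one dict of instances is enough.
ACTIVATION_CHOICES = nn.ModuleDict({"relu": nn.ReLU(), "elu": nn.ELU(), "gelu": nn.GELU()})


def candidate_layers(in_channels, out_channels, stride=1):
    """Every layer type the tuner may choose from, keyed by name (hypothetical helper)."""
    return nn.ModuleDict({
        "conv": nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ),
        "residual": Residual(in_channels, out_channels, stride=stride),
    })


layer = candidate_layers(32, 64, stride=2)["residual"]  # a sampled string picks the layer
```

Building every candidate and keeping only one wastes a little construction time, but it keeps the selection logic down to a dictionary lookup.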
where Residual is a fully convolutional module from the “Deep Residual Learning for Image Recognition” paper. A tunable model to be optimised looks like this:
In the above code, the ..._epoch_end methods are identical to those of the Benchmark model and are dropped for brevity. Some additional ..._step_end methods need to be defined, though. These methods are called after every device has processed its share of the batch, and this is where the results are aggregated. The backbone is the component of the Model that is going to be optimised. It is created with a function that selects hyperparameters from pre-defined sets and passes them to conv_block:
The probability distributions for each parameter are controlled by Optuna, a Swiss-army-knife library for tuning the hyperparameters of PyTorch, TensorFlow and scikit-learn models, among others. It is simple to set up and easy to use. It requires a user-defined objective function that takes a trial object, which is responsible for hyperparameter selection. The objective creates a model from the selected hyperparameters and uses it to return the score to be optimised. After the create_backbone function above builds a tunable component for a Model instance, the model is trained and its validation accuracy is used as the objective score:
Each model in the above function is trained on two GPUs, which is controlled by the gpus=2 and accelerator='dp' parameters. The value 'dp' selects DataParallel mode, in which each batch is split across the GPUs of a single machine and the results are gathered afterwards. After defining the objective, running the tuning procedure is straightforward:
The sampler object in the above code determines the strategy used to sample hyperparameters. TPESampler samples more points from the regions of hyperparameter space that received higher scores, and it updates those regions after each trial. Similar to scikit-learn's GridSearchCV, the best parameters and score can be extracted using study.best_params and study.best_value. With optimised parameters, the accuracy improved by 1.27 percentage points, to 0.8340:
I certainly like PyTorch Lightning for its noble goal of simplifying and standardising ML model creation. The seamless scalability of distributed training, which one gets almost for free, is particularly useful. An obvious place to apply the training speed gain is hyperparameter optimisation, which Optuna helps to implement. While it is one of many such libraries, Optuna is simple to set up for models from almost any framework under the sun. I hope this small example helps you scale the mountains that stand between your good and awesome models!