Malaria is a parasitic disease caused by Plasmodium parasites and transmitted through the bites of infected mosquitoes. If left untreated, malaria can be fatal. In 2019, there were an estimated 229 million cases of malaria worldwide (WHO).
Timely diagnosis is key to the effective treatment and control of malaria. Diagnosis is done either through microscopy or with rapid diagnostic tests. The use of machine learning techniques for malaria diagnosis is currently being explored, and in this study we explore the use of Convolutional Neural Networks (CNNs) to detect malaria parasites. The objectives of this study are:
- To obtain and prepare a malaria parasite dataset for use with a CNN.
- To identify the optimal hyperparameters for using a CNN to detect malaria parasites.
Malaria parasite dataset download
To find a suitable dataset for this project, we first searched online and then narrowed the search down to the Kaggle and FastAI websites. A suitable dataset was found on Kaggle and is readily available for download (see the references). In the code section below, we use the opendatasets library to download the dataset directly from Kaggle.
To download this dataset successfully, the Kaggle API credentials file (kaggle.json) was downloaded from Kaggle and uploaded into the Colab environment.
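The download step can be sketched as follows. This assumes the `opendatasets` package is installed (`pip install opendatasets`) and that `kaggle.json` sits in the working directory. `DATASET_URL` is a hypothetical placeholder because the original link is not reproduced here. The helper names `has_kaggle_credentials` and `download_dataset` are illustrative only.

```python
import os

# Hypothetical placeholder: substitute the Kaggle URL of the malaria dataset.
DATASET_URL = "https://www.kaggle.com/<owner>/<dataset-name>"


def has_kaggle_credentials(directory: str = ".") -> bool:
    """Return True if a Kaggle API token file (kaggle.json) exists in `directory`."""
    return os.path.isfile(os.path.join(directory, "kaggle.json"))


def download_dataset(url: str = DATASET_URL) -> None:
    """Download a Kaggle dataset into the current directory with opendatasets."""
    # Imported lazily so this cell still loads if the package is not installed yet.
    import opendatasets as od

    if not has_kaggle_credentials():
        # Without kaggle.json, opendatasets asks for a Kaggle username and API key.
        print("kaggle.json not found; you will be asked for your Kaggle credentials.")
    od.download(url)
```

In Colab, upload `kaggle.json` first, then call `download_dataset(...)` with the real dataset URL.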
Python libraries import
To prepare our code environment, we import the prerequisite Python libraries in the code section below.
Malaria parasite dataset exploration
After downloading the dataset, we explore it to understand what kind of data we have, including the number of infected and non-infected images. We also set up a DATA_DIR variable pointing to the folder that contains both classes of images.
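One way to get the per-class image counts is shown below. It assumes DATA_DIR uses the usual layout of one subfolder per class, which is also what torchvision's `ImageFolder` expects. `count_images_per_class` is a hypothetical helper, not part of any library.

```python
import os
from typing import Dict

IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"}


def count_images_per_class(data_dir: str) -> Dict[str, int]:
    """Count image files in each class subfolder of `data_dir`."""
    counts = {}
    for class_name in sorted(os.listdir(data_dir)):
        class_dir = os.path.join(data_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # skip stray files at the top level
        # Only files with an image extension are counted; other files are ignored.
        counts[class_name] = sum(
            1 for name in os.listdir(class_dir)
            if os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS
        )
    return counts
```

Calling `count_images_per_class(DATA_DIR)` returns a dictionary mapping each class folder name to its number of images.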
Transformation of images to PyTorch Tensor dataset
After looking at a random selection of sample images, we noted that the images vary in size. We therefore apply some transformations to standardize them before converting them into a PyTorch tensor dataset. Each image is resized to 108 × 108 pixels and then padded with 2 pixels on every side, giving a final size of 112 × 112 pixels (2 + 108 + 2 = 112). We then use this transformation to generate our dataset.
After transformation, each image has the shape (3, 112, 112), i.e. three color channels of 112 × 112 pixels. A sample image from the dataset is shown below.
Dataset preparation for training
Since the downloaded dataset does not include separate test or validation images, we split the data into training, validation and test datasets. We set aside 5,000 images for validation and 1,000 images for testing; the remaining images are used to train the model.
Using a random seed of 43 for reproducibility, we create the training, validation and test datasets with PyTorch's random_split function.
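The split can be sketched as follows, assuming `dataset` is the transformed dataset. Passing a seeded `torch.Generator` to `random_split` makes the split reproducible without touching the global random state. The original may instead have called `torch.manual_seed(43)` before `random_split`, which is also valid. `split_dataset` is a hypothetical helper name.

```python
import torch
from torch.utils.data import Dataset, random_split

VAL_SIZE = 5000   # images held out for validation
TEST_SIZE = 1000  # images held out for testing
SEED = 43


def split_dataset(dataset: Dataset, val_size: int = VAL_SIZE,
                  test_size: int = TEST_SIZE, seed: int = SEED):
    """Split `dataset` into train/validation/test subsets reproducibly."""
    train_size = len(dataset) - val_size - test_size  # all remaining images train the model
    if train_size <= 0:
        raise ValueError("dataset is too small for the requested split sizes")
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [train_size, val_size, test_size], generator=generator)
```

Usage: `train_ds, val_ds, test_ds = split_dataset(dataset)`.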
We can now create data loaders to load the data in batches of 64 or 128.
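A minimal sketch of the loaders, assuming the three subsets from the split. Only the training loader shuffles, because evaluation order does not affect the metrics. `make_loaders` is a hypothetical helper. `num_workers` defaults to 0 here and can be raised in Colab.

```python
import torch
from torch.utils.data import DataLoader, Dataset


def make_loaders(train_ds: Dataset, val_ds: Dataset, test_ds: Dataset,
                 batch_size: int = 64, num_workers: int = 0):
    """Build train/validation/test loaders; the post tries batch_size 64 and 128."""
    pin = torch.cuda.is_available()  # pinned host memory speeds up copies to the GPU
    train_dl = DataLoader(train_ds, batch_size, shuffle=True,
                          num_workers=num_workers, pin_memory=pin)
    val_dl = DataLoader(val_ds, batch_size, num_workers=num_workers, pin_memory=pin)
    test_dl = DataLoader(test_ds, batch_size, num_workers=num_workers, pin_memory=pin)
    return train_dl, val_dl, test_dl
```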
We then use the make_grid function from torchvision.utils to view the first batch of images from train_loader.
Base model class and training on GPU
To build the CNN classification model for malaria parasites, we start by defining a generic base model class containing the training_step, validation_step, validation_epoch_end and epoch_end methods. We later extend this class to define the final classification model.
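A base class with those four methods might look like the sketch below. The class name `ImageClassificationBase` and the `accuracy` helper are assumptions; the original names are not shown. `epoch_end` expects the training loop to have added a `train_loss` entry to `result`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Fraction of predictions (argmax over class scores) that match the labels."""
    preds = outputs.argmax(dim=1)
    return (preds == labels).float().mean()


class ImageClassificationBase(nn.Module):
    """Generic training/validation logic; subclasses only define the network."""

    def training_step(self, batch):
        images, labels = batch
        return F.cross_entropy(self(images), labels)  # loss to backpropagate

    def validation_step(self, batch):
        images, labels = batch
        out = self(images)
        return {"val_loss": F.cross_entropy(out, labels).detach(),
                "val_acc": accuracy(out, labels)}

    def validation_epoch_end(self, outputs):
        # Average the per-batch metrics over all validation batches.
        epoch_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        epoch_acc = torch.stack([x["val_acc"] for x in outputs]).mean()
        return {"val_loss": epoch_loss.item(), "val_acc": epoch_acc.item()}

    def epoch_end(self, epoch, result):
        print(f"Epoch [{epoch}], train_loss: {result['train_loss']:.4f}, "
              f"val_loss: {result['val_loss']:.4f}, val_acc: {result['val_acc']:.4f}")
```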
Below, we define the MalariaCNNModel class, which uses nn.Sequential to chain the layers and activation functions into a single network architecture.
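An illustrative network for 3 × 112 × 112 inputs and two classes is shown below. The layer sizes are assumptions and do not reproduce the author's architecture. To keep the sketch self-contained, it subclasses `nn.Module` directly. In the post, the class extends the generic base class, which is what supplies `training_step` and the other methods.

```python
import torch
import torch.nn as nn


class MalariaCNNModel(nn.Module):
    """Hypothetical CNN: four conv/pool stages followed by a small classifier."""

    def __init__(self, num_classes: int = 2):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),     # 32 x 112 x 112
            nn.ReLU(),
            nn.MaxPool2d(2),                                # 32 x 56 x 56
            nn.Conv2d(32, 64, kernel_size=3, padding=1),    # 64 x 56 x 56
            nn.ReLU(),
            nn.MaxPool2d(2),                                # 64 x 28 x 28
            nn.Conv2d(64, 128, kernel_size=3, padding=1),   # 128 x 28 x 28
            nn.ReLU(),
            nn.MaxPool2d(2),                                # 128 x 14 x 14
            nn.Conv2d(128, 256, kernel_size=3, padding=1),  # 256 x 14 x 14
            nn.ReLU(),
            nn.MaxPool2d(2),                                # 256 x 7 x 7
            nn.Flatten(),                                   # 12544 features
            nn.Linear(256 * 7 * 7, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),                    # one raw score (logit) per class
        )

    def forward(self, xb: torch.Tensor) -> torch.Tensor:
        return self.network(xb)
```

Because no softmax is applied at the end, the outputs can be passed straight to `F.cross_entropy`, which expects raw logits.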
Configure GPU utilities
We train the model on a GPU. First, we check whether a GPU is available in our environment. We then define a get_default_device() function that returns the GPU if one is available and the CPU otherwise.
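A minimal sketch of the device helpers described above follows. `to_device` is a hypothetical helper name; it moves a tensor, or a list or tuple of tensors, to the chosen device.

```python
import torch


def get_default_device() -> torch.device:
    """Use the GPU if one is available, otherwise fall back to the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def to_device(data, device: torch.device):
    """Recursively move tensors (or lists/tuples of tensors) to `device`."""
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)
```

Usage: `device = get_default_device()`, then `model = to_device(model, device)`. This works because `nn.Module` also has a `.to()` method.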
Loading data to device
We define a DeviceDataLoader class that wraps a data loader and moves each batch of data to the default device as it is loaded.
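A sketch of such a wrapper is shown below. It assumes each batch is an (images, labels) pair of tensors. Moving batches lazily, one at a time, avoids copying the whole dataset to GPU memory at once.

```python
import torch


class DeviceDataLoader:
    """Wrap a DataLoader so every batch it yields is already on `device`."""

    def __init__(self, dl, device: torch.device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        for batch in self.dl:
            # Assumes each batch is a sequence of tensors, e.g. (images, labels).
            yield [t.to(self.device, non_blocking=True) for t in batch]

    def __len__(self):
        return len(self.dl)  # number of batches
```

Usage: `train_loader = DeviceDataLoader(train_loader, device)`.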
We also define helper functions for plotting the losses and accuracies.
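The plotting helpers might look like the sketch below. It assumes `history` is a list with one dictionary per epoch containing `train_loss`, `val_loss` and `val_acc` keys; that format is an assumption about how the training loop records results. Each helper returns its axes so the plot can be customised or inspected further.

```python
import matplotlib.pyplot as plt


def plot_losses(history):
    """Plot training and validation loss for each epoch."""
    fig, ax = plt.subplots()
    ax.plot([r["train_loss"] for r in history], "-bx", label="Training")
    ax.plot([r["val_loss"] for r in history], "-rx", label="Validation")
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.set_title("Loss vs. number of epochs")
    ax.legend()
    return ax


def plot_accuracies(history):
    """Plot validation accuracy for each epoch."""
    fig, ax = plt.subplots()
    ax.plot([r["val_acc"] for r in history], "-x")
    ax.set_xlabel("epoch")
    ax.set_ylabel("accuracy")
    ax.set_title("Accuracy vs. number of epochs")
    return ax
```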
We move our training and validation data loaders to the available device.
Training the model
To find the optimal settings for this CNN model, we vary the batch size, the number of epochs and the learning rate, and record these values each time the model is run.
We instantiate the model and move it to the device. We then record the model's initial loss and accuracy using the initial, untrained weights. Since there are only two output classes and an untrained model is effectively guessing, we expect the initial accuracy to be around 50%.
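The 50% expectation can be made concrete. Assume the randomly initialised network assigns roughly equal probability to both classes, and assume the classes are roughly balanced. Then the expected initial accuracy and cross-entropy loss are approximately:

```latex
p(\hat{y} = c \mid x) \approx \tfrac{1}{2}, \quad c \in \{0, 1\}
\quad\Longrightarrow\quad
\text{accuracy} \approx 50\%, \qquad
\mathcal{L}_{\mathrm{CE}} = -\ln p(\hat{y} = y \mid x) \approx -\ln \tfrac{1}{2} = \ln 2 \approx 0.693
```

An initial loss far from 0.693 would suggest that the initial predictions are not close to uniform.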
We set the model hyperparameters and log them using jovian.log_hyperparams. We repeat this with three different sets of hyperparameters in order to identify the optimal settings.
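Recording each run's settings alongside its results makes the comparison across runs explicit. This plain-Python sketch keeps the records in a list. `record_run` and `best_run` are hypothetical helpers, and the jovian logging call is not reproduced.

```python
from typing import Dict, List


def record_run(runs: List[Dict], *, epochs: int, lr: float, batch_size: int,
               val_loss: float, val_acc: float) -> Dict:
    """Append one run's hyperparameters and final validation metrics to `runs`."""
    run = {"epochs": epochs, "lr": lr, "batch_size": batch_size,
           "val_loss": val_loss, "val_acc": val_acc}
    # The post logs the hyperparameters with jovian.log_hyperparams; omitted here.
    runs.append(run)
    return run


def best_run(runs: List[Dict]) -> Dict:
    """Return the run with the highest validation accuracy."""
    return max(runs, key=lambda r: r["val_acc"])
```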
Finally, we train the model with these hyperparameters. The validation loss and validation accuracy obtained with each set of parameters are displayed below.
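A training loop in this style is sketched below. The post does not state which optimizer was used, so defaulting to Adam is an assumption; pass `opt_func=torch.optim.SGD` to swap it. `evaluate` here weights every sample equally. The base class's `validation_epoch_end` instead averages per batch, so the two can differ slightly when the last batch is smaller.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def evaluate(model, loader):
    """Sample-weighted validation loss and accuracy."""
    model.eval()
    loss_sum, correct, n = 0.0, 0, 0
    for images, labels in loader:
        out = model(images)
        loss_sum += F.cross_entropy(out, labels, reduction="sum").item()
        correct += (out.argmax(dim=1) == labels).sum().item()
        n += labels.numel()
    return {"val_loss": loss_sum / n, "val_acc": correct / n}


def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.Adam):
    """Train `model` and return one metrics dict per epoch."""
    optimizer = opt_func(model.parameters(), lr=lr)
    history = []
    for epoch in range(epochs):
        model.train()
        train_losses = []
        for images, labels in train_loader:
            loss = F.cross_entropy(model(images), labels)
            loss.backward()        # compute gradients
            optimizer.step()       # update the weights
            optimizer.zero_grad()  # reset gradients before the next batch
            train_losses.append(loss.item())
        result = evaluate(model, val_loader)
        result["train_loss"] = sum(train_losses) / len(train_losses)
        history.append(result)
        print(f"Epoch [{epoch}], train_loss: {result['train_loss']:.4f}, "
              f"val_loss: {result['val_loss']:.4f}, val_acc: {result['val_acc']:.4f}")
    return history
```

Usage: `history = fit(5, 0.001, model, train_loader, val_loader)`. The returned `history` is in the format the plotting helpers expect.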
Loss vs number of epochs graph
Accuracy vs number of epochs graph
Testing the model
Finally, we evaluate the model on the test data and randomly select three test images to display alongside the model's predictions. Since the validation accuracy was above 95% in most runs, we expected the randomly selected images to be classified correctly, and the results below confirm this.
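Single-image prediction can be sketched as follows. The model expects a batch, so the image gets a leading batch dimension, and the highest-scoring class index is mapped back to a name. `predict_image` is a hypothetical helper. The class names would come from the dataset, for example `dataset.classes` on an `ImageFolder`.

```python
import torch


@torch.no_grad()
def predict_image(img: torch.Tensor, model: torch.nn.Module, classes, device=None) -> str:
    """Return the predicted class name for a single (C, H, W) image tensor."""
    if device is None:
        device = torch.device("cpu")
    model.eval()
    xb = img.unsqueeze(0).to(device)  # add a batch dimension: (1, C, H, W)
    logits = model(xb)
    return classes[logits.argmax(dim=1).item()]
```

Usage: `img, label = test_ds[0]`, then compare `dataset.classes[label]` with `predict_image(img, model, dataset.classes, device)`.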
Evaluated on the test data, the model achieves an accuracy of 95.9%.
In this project, we built a CNN model to predict the presence or absence of malaria parasites in images from a sample malaria dataset. In particular, we trained the model with various hyperparameters (number of training epochs, batch size and learning rate) to identify the settings that give the best performance.
Based on our observations, the model achieves an accuracy above 96% with a learning rate of 0.001. With a learning rate of 0.01, the accuracy is only about 50%, roughly chance level for two classes. This implies that the learning rate is a key determinant of model performance and must be chosen carefully to achieve the best predictions for malaria parasite detection.
With a batch size of 64, the model achieved an accuracy of 95.78%, whereas with a batch size of 128 the accuracy was consistently above 96%. This suggests that batch size has only a slight impact on overall model accuracy.
With both 5 and 10 epochs, the model achieved an accuracy above 96%. This suggests that, for this model and dataset, training for more than 5 epochs does not significantly improve the final accuracy.
- WHO malaria fact sheet, available at: https://www.who.int/news-room/fact-sheets/detail/malaria
- Kaggle malaria parasite dataset, available at: malaria parasite dataset
- Jovian AI Zero to GANs course materials, available at: https://jovian.ai/learn/deep-learning-with-pytorch-zero-to-gans