The COVID-19 pandemic is stretching hospital resources to the breaking point in many countries around the world. It’s no surprise that many people hope AI could speed up patient screening and ease the strain on clinical staff. This case study is not about COVID-19 but about diabetes, which is also a growing concern.
With 70 million people living with diabetes, India faces a growing problem with diabetic retinopathy. The disease damages the tissue at the back of the eye (the retina) and can lead to total blindness, and 18 per cent of diabetic Indians already have the ailment. With 415 million diabetics at risk of blindness worldwide, the disease is a global concern.
But the good news is that permanent vision loss is not inevitable. If caught early, the disease can be treated; if not, it can lead to total blindness.
One common way to detect diabetic retinopathy (DR) is to have a specialist examine photographs of the back of the eye and determine whether the disease is present and how severe it is. Severity is determined by the type of damage present, and interpreting these photographs requires specialized training.
Recent advances in Machine Learning and Computer Vision can improve the DR screening process. Deep Learning algorithms can interpret signs of DR in the retinal photographs, helping doctors screen more patients.
The data comes from the Kaggle competition APTOS 2019 Blindness Detection. The dataset contains a large set of retina images taken using fundus photography under a variety of lighting conditions. There are a total of 3662 retina images in the dataset. A clinician has rated each image on a scale of 0 to 4:
0 — No DR, 1 — Mild, 2 — Moderate, 3 — Severe, 4 — Proliferative DR
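The evaluation metric for a multi-class classification problem could be classification accuracy or an F-score, but the Kaggle competition defined its own metric: quadratic weighted kappa.
Quadratic weighted kappa measures the agreement between two ratings, here the clinician’s grade and the model’s prediction. It is 1 for perfect agreement and 0 for agreement no better than chance, and it can become negative when agreement is worse than chance. Because the weights grow with the square of the distance between grades, predicting 4 for a grade-0 image is penalized far more than predicting 1. A more detailed explanation is available here.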
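To make the metric concrete, here is a minimal NumPy sketch of quadratic weighted kappa, assuming integer grades in the range 0 to 4. It builds the observed agreement matrix O, the quadratic weight matrix W with W[i, j] = (i - j)^2 / (N - 1)^2, and the chance-expected matrix E from the two rating histograms, then returns 1 - sum(W*O) / sum(W*E). The function name is our own, not part of the competition kit.

```python
import numpy as np

def quadratic_weighted_kappa(y_true, y_pred, num_classes=5):
    """Quadratic weighted kappa between two integer ratings in [0, num_classes)."""
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)

    # Observed matrix O: O[i, j] counts images graded i by the clinician and j by the model.
    observed = np.zeros((num_classes, num_classes), dtype=float)
    np.add.at(observed, (y_true, y_pred), 1)

    # Quadratic disagreement weights: 0 on the diagonal, 1 for the farthest pair of grades.
    idx = np.arange(num_classes)
    weights = (idx[:, None] - idx[None, :]) ** 2 / (num_classes - 1) ** 2

    # Expected matrix E if the two ratings were independent, scaled to the same total as O.
    hist_true = np.bincount(y_true, minlength=num_classes)
    hist_pred = np.bincount(y_pred, minlength=num_classes)
    expected = np.outer(hist_true, hist_pred) / len(y_true)

    return 1.0 - (weights * observed).sum() / (weights * expected).sum()

print(quadratic_weighted_kappa([0, 1, 2, 3, 4], [0, 1, 2, 3, 4]))  # 1.0 (perfect agreement)
print(quadratic_weighted_kappa([0, 0, 4, 4], [4, 4, 0, 0]))        # -1.0 (worse than chance)
```

scikit-learn offers the same quantity as `cohen_kappa_score(y_true, y_pred, weights="quadratic")`. Note that the score is undefined when both raters use a single grade, since the denominator becomes zero.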
The dataset is imbalanced: there are far more images of healthy retinas than of diseased ones, and only 5% of the images belong to class 3 (severe DR).
To correct for this imbalance, we use class weighting.
Weight for class 0: 1.01
Weight for class 1: 4.95
Weight for class 2: 1.83
Weight for class 3: 9.49
Weight for class 4: 6.21
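The weights above can be reproduced with the formula weight_c = (1 / count_c) * (total / 2). The per-class counts in this sketch are an assumption on our part: they are not stated in the article, but they sum to 3662 and yield exactly the listed weights.

```python
import numpy as np

# Assumed image counts per grade (0 to 4); not given in the article,
# but they sum to 3662 and reproduce the weights listed above.
counts = np.array([1805, 370, 999, 193, 295])
total = counts.sum()

# weight_c = (1 / count_c) * (total / 2): rarer classes get proportionally larger weights.
class_weight = {c: float(total / (2.0 * n)) for c, n in enumerate(counts)}

for c, w in class_weight.items():
    print(f"Weight for class {c}: {w:.2f}")
```

The factor total / 2 matches the formula in TensorFlow’s binary imbalanced-data tutorial. The common "balanced" heuristic for K classes divides by K * count_c instead, which here would scale every weight by 2/5 while keeping their ratios unchanged. The resulting dict is in the form Keras’ `fit(class_weight=...)` expects.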
Let’s visualize the data with t-SNE using a perplexity of 40. Class 0 is separable, but the other classes are not.
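A minimal scikit-learn sketch of the t-SNE projection with perplexity 40. What the article fed into t-SNE is not stated; here we assume a generic (n_samples, n_features) array, such as downsampled and flattened images, and use random stand-in features so the snippet runs on its own.

```python
import numpy as np
from sklearn.manifold import TSNE

def tsne_embedding(features, perplexity=40, seed=0):
    """Project an (n_samples, n_features) array to 2-D with t-SNE."""
    tsne = TSNE(n_components=2, perplexity=perplexity, init="pca", random_state=seed)
    return tsne.fit_transform(features)

# Random stand-in features; real inputs would be per-image feature vectors (an assumption).
rng = np.random.default_rng(0)
features = rng.normal(size=(200, 64)).astype(np.float32)
embedding = tsne_embedding(features)
print(embedding.shape)  # (200, 2)
```

Perplexity must be smaller than the number of samples. To reproduce the plot, scatter `embedding[:, 0]` against `embedding[:, 1]` and color each point by its grade.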
- Define the key configuration parameters.
2. Load the data
The tf.data API enables you to build complex input pipelines from simple, reusable pieces. To construct the dataset, we use tf.data.Dataset.from_tensor_slices() and then transform it into new datasets by chaining methods.
Training images count: 2929
Validating images count: 733
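A sketch of building the dataset with `from_tensor_slices()` and splitting it by chaining `shuffle()`, `take()` and `skip()`. The 80/20 split is our assumption; it is consistent with the 2929/733 counts above. The file names and grades are hypothetical stand-ins for the competition’s label file.

```python
import tensorflow as tf

# Hypothetical stand-ins for the 3662 image paths and their 0-4 grades;
# in practice these would be read from the competition's label file.
filenames = [f"train_images/{i}.png" for i in range(3662)]
grades = [i % 5 for i in range(3662)]

dataset = tf.data.Dataset.from_tensor_slices((filenames, grades))
# Shuffle once with a fixed seed. reshuffle_each_iteration=False keeps the
# take/skip split identical on every epoch, so the two subsets never overlap.
dataset = dataset.shuffle(len(filenames), seed=42, reshuffle_each_iteration=False)

train_size = int(0.8 * len(filenames))  # assumed 80/20 split -> 2929
train_ds = dataset.take(train_size)
val_ds = dataset.skip(train_size)

print("Training images count:", int(train_ds.cardinality()))
print("Validating images count:", int(val_ds.cardinality()))
```

A stratified split (for example with scikit-learn’s `train_test_split(..., stratify=grades)`) would keep rare grades such as class 3 equally represented in both subsets.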
Since there are 5 labels, we convert them into one-hot tensors; for example, 2 becomes [0, 0, 1, 0, 0]. We also have to map each filename to its label, which we do with the following methods.
Let’s check the shape of an image and its label.
Image shape: (320, 320, 3)
Label: [1. 0. 0. 0. 0.]
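The mapping step can be sketched as a function passed to `Dataset.map()` that reads and decodes the file, resizes it to 320 × 320, scales pixels to [0, 1] and one-hot encodes the grade with `tf.one_hot`. The function name and the [0, 1] scaling are our assumptions. To keep the snippet self-contained, it writes one synthetic PNG to a temporary directory.

```python
import os
import tempfile
import tensorflow as tf

IMAGE_SIZE = 320  # matches the (320, 320, 3) image shape shown above
NUM_CLASSES = 5   # grades 0 to 4

def parse_example(filename, label):
    """Map (filename, integer grade) to (resized image in [0, 1], one-hot label)."""
    raw = tf.io.read_file(filename)
    # decode_image handles PNG and JPEG; expand_animations=False guarantees a 3-D tensor.
    image = tf.io.decode_image(raw, channels=3, expand_animations=False)
    image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE]) / 255.0
    return image, tf.one_hot(label, NUM_CLASSES)

# Usage with a synthetic 64x80 image standing in for a fundus photograph.
tmp_path = os.path.join(tempfile.mkdtemp(), "retina_0.png")
dummy = tf.cast(tf.random.uniform([64, 80, 3], maxval=256, dtype=tf.int32), tf.uint8)
tf.io.write_file(tmp_path, tf.io.encode_png(dummy))

ds = tf.data.Dataset.from_tensor_slices(([tmp_path], [0]))
ds = ds.map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
image, label = next(iter(ds))
print("Image shape:", image.shape)  # (320, 320, 3)
print("Label:", label.numpy())      # [1. 0. 0. 0. 0.]
```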
We use buffered prefetching so that reading data from disk does not block training on I/O. For data augmentation, we use the tf.image API.
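A minimal sketch of augmentation with `tf.image` followed by batching and `prefetch()`. The specific transforms, their strengths and the batch size of 16 are hypothetical; the article does not list them. Augmentation is applied only to the training set.

```python
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE
BATCH_SIZE = 16  # hypothetical; the article does not state its batch size

def augment(image, label):
    """Random, label-preserving transforms (the choice of ops is an assumption)."""
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.image.random_contrast(image, lower=0.9, upper=1.1)
    # Brightness and contrast changes can push pixels outside [0, 1]; clip back.
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label

def prepare(ds, training=False):
    """Augment (training only), batch, and prefetch so input loading overlaps training."""
    if training:
        ds = ds.map(augment, num_parallel_calls=AUTOTUNE)
    return ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)

# Usage with random stand-in images already scaled to [0, 1].
images = tf.random.uniform([32, 64, 64, 3])
labels = tf.one_hot(tf.zeros([32], dtype=tf.int32), 5)
train_ds = prepare(tf.data.Dataset.from_tensor_slices((images, labels)), training=True)
batch_images, batch_labels = next(iter(train_ds))
print(batch_images.shape, batch_labels.shape)  # (16, 64, 64, 3) (16, 5)
```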
Visualize the dataset after image augmentation.
- Define Callbacks
The checkpoint callback saves the best weights of the model, so the next time we want to use the model we do not have to retrain it. The early stopping callback stops training when the monitored metric stops improving, for example because the model starts overfitting or stagnates. The ReduceLROnPlateau callback reduces the learning rate when a metric stops improving.
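The three callbacks can be configured with the standard Keras classes. The monitored metric (`val_loss`), patience values, reduction factor and checkpoint file name below are hypothetical choices, not the article’s settings.

```python
import tensorflow as tf

# Hypothetical file name; newer Keras requires weights-only checkpoints to end in ".weights.h5".
CHECKPOINT_PATH = "best_model.weights.h5"

callbacks = [
    # Keep only the weights with the lowest validation loss seen so far.
    tf.keras.callbacks.ModelCheckpoint(
        CHECKPOINT_PATH, monitor="val_loss", save_best_only=True, save_weights_only=True
    ),
    # Stop after 5 epochs without val_loss improvement and restore the best weights.
    tf.keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=5, restore_best_weights=True
    ),
    # Multiply the learning rate by 0.2 after 2 epochs without val_loss improvement.
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=0.2, patience=2, min_lr=1e-6
    ),
]
```

The ReduceLROnPlateau patience is kept shorter than the early-stopping patience so the learning rate gets a chance to drop before training is stopped.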
We initialize the model with pre-trained ImageNet weights.
For our use case, we use accuracy as the metric, which is the fraction of correct predictions. Since this is a multi-class problem with 5 classes and one-hot labels, we use categorical cross-entropy as the loss function. We also pass the class weights discussed earlier.
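The article does not name its backbone, so this sketch uses DenseNet121 as a hypothetical choice. It loads ImageNet weights, adds a new 5-way softmax head and compiles with categorical cross-entropy and accuracy. The normalization layer applies the ImageNet channel statistics that DenseNet’s own preprocessing uses, assuming inputs already scaled to [0, 1]. The dropout rate and learning rate are also assumptions.

```python
import tensorflow as tf

NUM_CLASSES = 5
# ImageNet channel statistics (DenseNet's "torch"-style preprocessing) for inputs in [0, 1].
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def build_model(input_shape=(320, 320, 3), weights="imagenet"):
    """Pre-trained backbone plus a new 5-way softmax head (DenseNet121 is hypothetical)."""
    backbone = tf.keras.applications.DenseNet121(
        include_top=False, weights=weights, input_shape=input_shape
    )
    inputs = tf.keras.Input(shape=input_shape)
    x = tf.keras.layers.Normalization(
        mean=IMAGENET_MEAN, variance=[s ** 2 for s in IMAGENET_STD]
    )(inputs)
    x = backbone(x)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dropout(0.3)(x)
    outputs = tf.keras.layers.Dense(NUM_CLASSES, activation="softmax")(x)
    model = tf.keras.Model(inputs, outputs)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss="categorical_crossentropy",  # one-hot labels, 5 mutually exclusive grades
        metrics=["accuracy"],             # fraction of correct predictions
    )
    return model

# Training call, assuming train_ds, val_ds, class_weight and callbacks from the earlier steps:
# model = build_model()
# history = model.fit(train_ds, validation_data=val_ds, epochs=30,
#                     class_weight=class_weight, callbacks=callbacks)
```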
Let’s plot the model accuracy and loss for the training and validation sets. The model reaches an accuracy of 83%. Validation accuracy is lower than training accuracy, which indicates overfitting.
The confusion matrix shows that images from classes 1, 3 and 4 are often misclassified as class 2. The model may not be detecting the spots and hemorrhages that are present in classes 3 and 4 (the severe cases of DR).
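A matplotlib sketch of the accuracy and loss curves. It reads the `History.history` dict that Keras’ `fit()` returns, whose keys are `accuracy`, `val_accuracy`, `loss` and `val_loss` when the model is compiled with `metrics=["accuracy"]`. The numbers in the usage example are made up.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend; omit this line in a notebook
import matplotlib.pyplot as plt

def plot_history(history_dict):
    """Side-by-side train/validation accuracy and loss curves from History.history."""
    epochs = range(1, len(history_dict["loss"]) + 1)
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    for ax, metric, title in zip(axes, ["accuracy", "loss"], ["Model accuracy", "Model loss"]):
        ax.plot(epochs, history_dict[metric], label="train")
        ax.plot(epochs, history_dict["val_" + metric], label="validation")
        ax.set_title(title)
        ax.set_xlabel("epoch")
        ax.set_ylabel(metric)
        ax.legend()
    fig.tight_layout()
    return fig

# Made-up values; in practice pass history.history from model.fit().
fake_history = {
    "accuracy": [0.60, 0.72, 0.80, 0.86],
    "val_accuracy": [0.58, 0.68, 0.74, 0.76],
    "loss": [1.10, 0.80, 0.60, 0.45],
    "val_loss": [1.15, 0.90, 0.78, 0.76],
}
fig = plot_history(fake_history)
```

A training curve that keeps rising while the validation curve flattens is the overfitting signature described above.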
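A sketch of how the confusion matrix is obtained from softmax outputs: take the argmax of the one-hot targets and of the predicted probabilities, then count with scikit-learn’s `confusion_matrix`. Rows are true grades and columns are predicted grades. The toy data is hypothetical and mimics the pattern described above, with classes 1, 3 and 4 collapsing onto class 2.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

def confusion_from_probs(y_true_onehot, probs, num_classes=5):
    """Rows: true grade, columns: predicted grade (argmax of the softmax output)."""
    y_true = np.argmax(y_true_onehot, axis=1)
    y_pred = np.argmax(probs, axis=1)
    return confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))

# Hypothetical toy predictions.
true_grades = np.array([0, 1, 2, 3, 4, 4])
pred_grades = np.array([0, 2, 2, 2, 2, 4])
y_true_onehot = np.eye(5)[true_grades]
probs = np.eye(5)[pred_grades] * 0.6 + 0.08  # each row sums to 1, peak at pred_grades
cm = confusion_from_probs(y_true_onehot, probs)
print(cm)
```

With a real model, `probs` would be `model.predict(val_ds)`, and the labels must be collected from the same, unshuffled iteration of `val_ds` so that rows line up.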
HRNet (High-Resolution Network) was recently developed for human pose estimation, but it can also be used for image classification, object detection and other tasks. The researchers provide their code here. The official implementation is written in PyTorch, so we had to rewrite it in TensorFlow. HRNet maintains high-resolution representations throughout the network by connecting high-to-low resolution convolutions in parallel, and it produces strong high-resolution representations by repeatedly fusing information across the parallel convolutions. The research paper is linked here.
For image classification, we need to replace the head with a softmax layer. You can find the code in my GitHub repository. The results from this model were not encouraging: we got an accuracy of 68%. We experimented with different loss functions and optimizers but were not able to improve the performance.
Let’s look at the classification report. In our case, recall for classes 3 and 4 is very low, which means the model misses many of these severe cases, exactly where the cost of a mistake is highest. We need to improve the model and raise the recall for every class.
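Per-class recall is TP / (TP + FN), the share of images of a given grade that the model labels correctly. This sketch prints scikit-learn’s `classification_report` with the grade names from the scale above. The predictions are hypothetical toy data in which half of the Severe and Proliferative DR images are missed.

```python
from sklearn.metrics import classification_report

CLASS_NAMES = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]

# Hypothetical toy grades, not the article's results.
true_grades = [0, 0, 1, 2, 2, 3, 3, 4, 4]
pred_grades = [0, 0, 2, 2, 2, 2, 3, 2, 4]

kwargs = dict(labels=list(range(5)), target_names=CLASS_NAMES, zero_division=0)
print(classification_report(true_grades, pred_grades, **kwargs))

report = classification_report(true_grades, pred_grades, output_dict=True, **kwargs)
print(report["Severe"]["recall"])  # 0.5: one of the two Severe images was missed
```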
Let’s visualize the image with actual and predicted labels.
We can see the image and the model prediction with the probability score for each class.
Actual Label - Proliferative DR
Predicted Label - Moderate
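The predicted label is the class with the highest softmax probability. This sketch formats one prediction the way it is shown above. The probability vector is hypothetical, and `describe_prediction` is our own helper name.

```python
import numpy as np

CLASS_NAMES = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]

def describe_prediction(probs, true_grade):
    """Return the predicted grade and a text summary with per-class probabilities."""
    probs = np.asarray(probs, dtype=float)
    pred = int(np.argmax(probs))  # most probable grade
    lines = [
        f"Actual Label - {CLASS_NAMES[true_grade]}",
        f"Predicted Label - {CLASS_NAMES[pred]}",
    ]
    lines += [f"{name}: {p:.2f}" for name, p in zip(CLASS_NAMES, probs)]
    return pred, "\n".join(lines)

# Hypothetical softmax output for a Proliferative DR image.
pred, text = describe_prediction([0.05, 0.10, 0.55, 0.10, 0.20], true_grade=4)
print(text)
```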
The inference time is 0.45 seconds per image, which corresponds to a rate of about 2.2 predictions per second (1 / 0.45).
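Latency and throughput are reciprocals of each other. This standard-library sketch measures both. `slow_predict` is a hypothetical stand-in for the model’s predict function that sleeps 10 ms per call.

```python
import time

def measure_inference(predict_fn, samples, warmup=1):
    """Return (seconds per prediction, predictions per second) for predict_fn."""
    # Warm-up calls absorb one-off costs such as graph tracing or weight loading.
    for sample in samples[:warmup]:
        predict_fn(sample)
    start = time.perf_counter()
    for sample in samples:
        predict_fn(sample)
    elapsed = time.perf_counter() - start
    latency = elapsed / len(samples)
    return latency, 1.0 / latency

def slow_predict(x):
    """Hypothetical stand-in for a model call."""
    time.sleep(0.01)
    return x

latency, throughput = measure_inference(slow_predict, list(range(20)))
print(f"{latency:.3f} s per prediction, {throughput:.1f} predictions per second")
```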
- Use data from other sources, such as EyePACS or Messidor, which could further improve accuracy.
- Set up a continuous integration system for the codebase that checks the functionality of the code and evaluates the model before it is deployed.
- Package the prediction system as a REST API and deploy it as a Docker container running as a serverless function on AWS Lambda.