3 Tips on how to train machine learning models efficiently when your data is noisy and mislabeled…
In this article, I would like to talk about 3 tricks that helped me train models efficiently and win a silver medal in a Kaggle competition where the dataset was mislabeled and contained a significant amount of noise.
Rule n° 1 in data science: Garbage In = Garbage Out.
Mislabeled data is part of real-world data: not all datasets are clean. Most datasets contain some amount of noise, which can make training a machine learning model challenging. The good news is that the Garbage In = Garbage Out rule can be mitigated with a few tricks that help your model adapt to mislabeled data.
Cassava leaf disease prediction: a computer vision competition with a dataset of 21,367 labeled images of cassava plants. The aim of the competition was to classify each cassava image into one of four disease categories or a fifth category indicating a healthy leaf.
After a quick exploratory data analysis, I realized that some of the images were mislabeled. Take the 2 images below as an example:
We can clearly see that the 1st image contains diseased leaves while the 2nd one has healthy leaves. Yet both images were labeled as ‘healthy’ in this dataset, which makes the model's task harder: it has to extract and learn the features of both healthy and diseased leaves and assign them to the same class, Healthy.
In the following sections, I present 3 tricks I found useful for dealing with noisy datasets:
Picking the right loss function is critical in machine learning, and the choice depends a lot on your data, task and metric. Here we have a multi-class classification problem (5 classes) with categorical accuracy as the metric, so the first loss function that comes to mind is categorical cross-entropy.
However, our dataset is mislabeled, and the cross-entropy loss is very sensitive to outliers: since -log(p) grows without bound as the predicted probability p of the given label approaches 0, a few confidently "wrong" (mislabeled) images can stretch the decision boundaries and dominate the overall loss.
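A back-of-the-envelope check of that claim, with made-up numbers: suppose a batch of 10 images where the model assigns probability 0.9 to the given label of 9 correctly labeled images, but only 0.001 to the given label of a single mislabeled image.

```python
import numpy as np

# Hypothetical probabilities the model assigns to the *given* label of each image.
# The last image is mislabeled, so a good model gives its (wrong) label a tiny probability.
p_given_label = np.array([0.9] * 9 + [0.001])

per_sample_ce = -np.log(p_given_label)          # per-image cross-entropy
share_of_mislabeled = per_sample_ce[-1] / per_sample_ce.sum()
print(per_sample_ce.round(3))
print(f"mislabeled image's share of the batch loss: {share_of_mislabeled:.0%}")  # 88%
```

One image out of ten accounts for almost 90% of the loss, so the gradient updates are driven mostly by trying to fit the wrong label.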
To solve this problem, Google AI researchers introduced a “bi-tempered” generalization of the logistic loss with two tunable parameters, which they call “temperatures”: t1, which characterizes boundedness, and t2, which characterizes tail-heaviness.
It’s a cross-entropy loss with 2 new tunable parameters t1 and t2. The standard cross-entropy can be recovered by setting both t1 and t2 equal to 1.
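For reference, here is a minimal NumPy sketch of the loss following the definitions in the Bi-Tempered paper: the tempered logarithm log_t(u) = (u^(1-t) - 1)/(1 - t), the tempered exponential exp_t(u) = [1 + (1 - t)u]_+^(1/(1-t)), a tempered softmax built from exp_t with temperature t2, and a loss built from log_t with temperature t1. It is an illustrative re-implementation, not the official code from the Google repository: the function names are chosen here, and the normalization constant of the tempered softmax is found by plain bisection for simplicity.

```python
import numpy as np

def log_t(u, t):
    """Tempered logarithm: (u**(1-t) - 1) / (1-t); plain log when t == 1."""
    return np.log(u) if t == 1.0 else (u ** (1.0 - t) - 1.0) / (1.0 - t)

def exp_t(u, t):
    """Tempered exponential: [1 + (1-t) u]_+ ** (1/(1-t)); plain exp when t == 1."""
    if t == 1.0:
        return np.exp(u)
    return np.maximum(1.0 + (1.0 - t) * u, 0.0) ** (1.0 / (1.0 - t))

def tempered_softmax(activations, t, num_iters=60):
    """exp_t(a - lam), with lam chosen by bisection so that each row sums to 1."""
    a = np.asarray(activations, dtype=np.float64)
    if t == 1.0:
        z = np.exp(a - a.max(axis=-1, keepdims=True))
        return z / z.sum(axis=-1, keepdims=True)
    mu = a.max(axis=-1, keepdims=True)
    lo = mu.copy()                            # largest term is exp_t(0) = 1, so sum >= 1
    hi = mu - log_t(1.0 / a.shape[-1], t)     # every term <= 1/n, so sum <= 1
    for _ in range(num_iters):
        mid = 0.5 * (lo + hi)
        row_sum = exp_t(a - mid, t).sum(axis=-1, keepdims=True)
        lo = np.where(row_sum > 1.0, mid, lo)  # sum too large: lam must increase
        hi = np.where(row_sum > 1.0, hi, mid)
    return exp_t(a - 0.5 * (lo + hi), t)

def bi_tempered_logistic_loss(activations, labels, t1, t2):
    """Per-example loss; t1 < 1 makes it bounded, t2 > 1 gives heavy tails."""
    labels = np.asarray(labels, dtype=np.float64)
    probs = tempered_softmax(activations, t2)
    term1 = labels * (log_t(labels + 1e-10, t1) - log_t(probs, t1))
    term2 = (labels ** (2.0 - t1) - probs ** (2.0 - t1)) / (2.0 - t1)
    return (term1 - term2).sum(axis=-1)

# Two examples, both labeled class 0; the 2nd is confidently predicted as class 1.
logits = np.array([[2.0, 1.0, 0.1, -1.0, 0.5],
                   [-5.0, 6.0, 0.0, 0.0, 0.0]])
y = np.eye(5)[[0, 0]]
print(bi_tempered_logistic_loss(logits, y, t1=1.0, t2=1.0))  # ordinary cross-entropy
print(bi_tempered_logistic_loss(logits, y, t1=0.2, t2=4.0))  # bounded, heavy-tailed
```

With t1 = t2 = 1 the first call reproduces softmax cross-entropy, so the 2nd example costs about 11. With t1 = 0.2, no example can cost more than 1/(1 - t1) = 1.25, however confidently the model contradicts its label.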
So, what happens when we tune the t1 and t2 parameters?
Let’s understand what’s happening here:
- With small-margin noise: The noise stretched the decision boundary into a heavy-tailed shape. The Bi-Tempered loss solved this by tuning the t2 parameter from t2=1 to t2=4.
- With large-margin noise: The large noise stretched the decision boundary in a bounded way, covering more surface than the heavy tail in the small-margin case. The Bi-Tempered loss solved this by tuning the t1 parameter from t1=1 to t1=0.2.
- With random noise: Here, we can see both heavy tails and bounded decision boundaries, so both t1 and t2 are adjusted in the Bi-Tempered loss.
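To get a feel for what the two temperatures control, the sketch below compares the tempered logarithm and exponential (standard definitions from the Bi-Tempered paper; the helper names are illustrative) with their ordinary counterparts. With t1 < 1, the per-example log term is capped, so a far-away mislabeled point cannot contribute an arbitrarily large loss. With t2 > 1, the tempered exponential decays polynomially instead of exponentially, which gives the predicted probabilities a heavier tail.

```python
import numpy as np

def log_t(u, t):
    """Tempered logarithm (reduces to np.log when t == 1)."""
    return np.log(u) if t == 1.0 else (u ** (1.0 - t) - 1.0) / (1.0 - t)

def exp_t(u, t):
    """Tempered exponential (reduces to np.exp when t == 1)."""
    if t == 1.0:
        return np.exp(u)
    return np.maximum(1.0 + (1.0 - t) * u, 0.0) ** (1.0 / (1.0 - t))

# Boundedness (t1): loss term for the labeled class as its probability shrinks.
p = np.array([0.5, 1e-2, 1e-4, 1e-8])
print(-log_t(p, 1.0))   # ordinary log loss: grows without bound as p -> 0
print(-log_t(p, 0.2))   # t1 = 0.2: never exceeds 1 / (1 - 0.2) = 1.25

# Tail-heaviness (t2): weight given to a class x units below the top logit.
x = np.array([1.0, 5.0, 10.0, 20.0])
print(exp_t(-x, 1.0))   # ordinary exp: exponentially light tail
print(exp_t(-x, 4.0))   # t2 = 4: (1 + 3x) ** (-1/3), a polynomial heavy tail
```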
A practical way to fine-tune the t1 and t2 parameters is to plot your model's decision boundary, check whether it is heavy-tailed, bounded or both, and then tweak t1 and t2 accordingly.
If you are dealing with tabular data, you can use the plot_decision_regions() function from mlxtend to visualize your model's decision boundaries (it plots the regions over one or two features at a time).
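```python
from mlxtend.plotting import plot_decision_regions
import matplotlib.pyplot as plt

# Plot decision boundary (x_test: two feature columns, y_test: integer labels)
plot_decision_regions(X=x_test, y=y_test, clf=model, legend=2)
plt.show()
```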
You can learn more about the Bi-Tempered loss on the Google AI Blog and in their GitHub repository.
If you are already familiar with knowledge distillation, where knowledge is transferred from a teacher model to a student model, self-distillation is a very similar concept.
This new concept was introduced in the paper Be Your Own Teacher: Improve the Performance of Convolutional Neural Networks via Self Distillation. The idea is simple:
Self Distillation: You train your model and then you retrain it using itself as a teacher.
The paper discusses a more advanced approach that includes several loss functions and some architecture modifications (additional bottleneck and fully connected layers). In this article, I'd like to introduce a much simpler approach.
I first read about this approach in the first-place solution of the Plant Pathology competition on Kaggle, where the winning team used self-distillation to deal with the noisy dataset. You can check the code in their GitHub repository.
Self Distillation in 3 steps:
- 1- Split your dataset into k folds for cross-validation.
- 2- Train model 1 on the folds and save its out-of-fold (OOF) predictions, i.e. the predictions for each fold made by the copy of the model that was not trained on that fold.
- 3- Load the saved OOF predictions and blend them with the original labels, then train model 2 on these blended (soft) labels. The blending coefficients are tunable; the original labels should get the higher coefficient.
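The three steps can be sketched end to end as below. This is a toy illustration under loud assumptions: a tiny NumPy softmax regression stands in for the image model, the data is synthetic, and the helper names, k = 5 and the blending weight alpha = 0.7 are hypothetical choices, not values taken from the Plant Pathology solution.

```python
import numpy as np

def predict_proba(X, W, b):
    """Softmax probabilities of a linear model (numerically stable)."""
    Z = X @ W + b
    Z -= Z.max(axis=1, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=1, keepdims=True)

def train_softmax_regression(X, Y, epochs=300, lr=0.5):
    """Hypothetical stand-in for the real CNN: softmax regression fitted by
    full-batch gradient descent on one-hot or soft targets Y."""
    W = np.zeros((X.shape[1], Y.shape[1]))
    b = np.zeros(Y.shape[1])
    for _ in range(epochs):
        grad = (predict_proba(X, W, b) - Y) / len(X)  # d(mean cross-entropy)/d(logits)
        W -= lr * X.T @ grad
        b -= lr * grad.sum(axis=0)
    return W, b

def out_of_fold_predictions(X, Y, k=5, seed=0):
    """Steps 1-2: every sample is predicted by a model that never saw it."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(X)), k)
    oof = np.zeros_like(Y, dtype=float)
    for val_idx in folds:
        train_idx = np.setdiff1d(np.arange(len(X)), val_idx)
        W, b = train_softmax_regression(X[train_idx], Y[train_idx])
        oof[val_idx] = predict_proba(X[val_idx], W, b)
    return oof

def blend_labels(Y_onehot, oof, alpha=0.7):
    """Step 3: soft labels; alpha > 0.5 keeps the original labels dominant."""
    return alpha * Y_onehot + (1.0 - alpha) * oof

# Synthetic 4-class toy data (assumption: stands in for the image dataset).
rng = np.random.default_rng(42)
X = rng.normal(size=(200, 4))
y = (X[:, 0] > 0).astype(int) + 2 * (X[:, 1] > 0).astype(int)
Y = np.eye(4)[y]

oof = out_of_fold_predictions(X, Y)           # model 1, trained k times
soft = blend_labels(Y, oof)                   # blended targets
W2, b2 = train_softmax_regression(X, soft)    # model 2 trained on the soft labels
```

In a real pipeline, model 2 would be the CNN retrained on the soft labels with a loss that accepts probability targets (categorical cross-entropy does).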
The out-of-fold predictions here are the class probabilities predicted by model 1:
- In this particular example we have a multiclass classification with 5 classes [0,1,2,3,4].
- The labels are one-hot encoded: class 2 is represented as [0,0,1,0,0].
- Model 1 predicted class 2 correctly, [0.1, 0.1, 0.4, 0.1, 0.3], giving it a probability of 0.4, higher than any other class. However, it also gave class 4 (the fifth class) a high probability of 0.3.
- Model 2 will use this information to improve its predictions.
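Plugging this example into step 3, with a hypothetical weight of 0.7 on the original label and 0.3 on model 1's out-of-fold prediction:

```python
import numpy as np

one_hot = np.array([0, 0, 1, 0, 0], dtype=float)  # original label: class 2
oof = np.array([0.1, 0.1, 0.4, 0.1, 0.3])         # model 1's out-of-fold prediction
alpha = 0.7                                       # hypothetical weight on the original label

soft_label = alpha * one_hot + (1 - alpha) * oof  # target used to train model 2
print(np.round(soft_label, 2))                    # [0.03 0.03 0.82 0.03 0.09]
```

Class 2 still clearly dominates the target, but the last class (index 4) keeps part of the probability model 1 gave it, so model 2 is no longer pushed to assign it zero probability.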
Ensemble learning is well known to improve prediction quality in general. With noisy datasets it can be especially helpful: models with different architectures learn different patterns, so they tend not to make the same mistakes on the noisy images.
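A common way to combine such models is to average their predicted class probabilities, optionally with per-model weights. A minimal sketch; the model outputs below are made up for illustration and the helper name is hypothetical:

```python
import numpy as np

def ensemble_average(prob_list, weights=None):
    """Weighted average of per-model class probabilities (each: n_samples x n_classes)."""
    probs = np.stack(prob_list)                  # (n_models, n_samples, n_classes)
    if weights is None:
        weights = np.ones(len(prob_list))
    weights = np.asarray(weights, dtype=float)
    weights = weights / weights.sum()            # normalize so rows still sum to 1
    return np.tensordot(weights, probs, axes=1)  # (n_samples, n_classes)

# Hypothetical outputs of a ViT and two CNNs for 2 images and 5 classes.
vit = np.array([[0.1, 0.1, 0.5, 0.1, 0.2], [0.6, 0.1, 0.1, 0.1, 0.1]])
cnn1 = np.array([[0.2, 0.1, 0.3, 0.1, 0.3], [0.5, 0.2, 0.1, 0.1, 0.1]])
cnn2 = np.array([[0.1, 0.2, 0.4, 0.1, 0.2], [0.3, 0.4, 0.1, 0.1, 0.1]])

avg = ensemble_average([vit, cnn1, cnn2])
print(avg.argmax(axis=1))  # [2 0]
```

Here the second CNN alone would pick class 1 for the second image, but the average follows the other two models and picks class 0.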
I had been planning to try the Vision Transformer models released by Google AI in the paper An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale, and this competition was the perfect place to try them and learn more about them. They introduce a new approach to computer vision, different from the convolutional neural networks that dominate the field.
In short, the ensemble of a vision transformer model with 2 different CNN architectures improves on the prediction quality of the single models: