
Data scarcity occurs when a) there is a limited amount, or a complete lack, of labeled training data, or b) there is too little data for a given label relative to the other labels (a.k.a. data imbalance). Larger technology companies tend to have access to abundant data, although they can still encounter data imbalance. Smaller technology companies typically suffer from a limited supply of labeled training data. In this post, I will share my notes on how I have approached scenarios where the amount of available training data was low.
When training a machine learning model, a general rule of thumb is to have at least 10 examples for each degree of freedom in your dataset (this post provides a nice overview of what degrees of freedom are). For example, a linear model with 20 features would, by this rule, call for at least roughly 200 labeled examples. As the degrees of freedom increase, so does the amount of data needed to train a reasonable model.
Furthermore, it is important to consider the amount of training data available when choosing the model type. Fig. 2 compares an offline model performance metric for traditional machine learning models (e.g., logistic regression) vs. deep learning methods. Deep learning performs much better when there is a lot more training data; on the other hand, deep learning models are much more likely to overfit when the dataset is small.
Let’s walk through some scenarios and the approaches that would work for them.
Let’s consider a few examples where data availability is likely to be low: 1) a startup experimenting with the idea of adding a machine learning model to its product, and 2) a team in a large company that has created a fairly new product and wants to apply machine learning to optimize a certain problem. The analogy of a seedling aptly describes this scenario (Fig. 3). How does one approach building a machine learning model in such cases? Having been in similar situations, I have learnt that starting with simple heuristics works well. Heuristics have a few advantages:
- They do not require a significant amount of data and can be created with intuition or domain knowledge.
- They have a high degree of interpretability. Usually the question “why did a model predict something as spam?” can be answered by looking at the code or logging the code path that was triggered for a given query.
- Generally, they do not suffer from the problems and complexities of a typical machine learning pipeline: data skew, code skew, model staleness, sensitivity to variance in feature distributions, etc.
To explain this further, let’s dive deeper into the startup example: say the startup’s main product is a mobile application that pulls relevant local news given a device’s geo-location. Building a machine learning model here would be hard because there is no labeled data. However, a heuristic model is still possible. We can think of a few signals to determine the rank of a given news article: 1) the relevance score from the geo-location match (the score returned by the underlying IR system, e.g., ElasticSearch), 2) article recency based on publish time, and 3) a pre-determined popularity score for the publisher (say in the range [1, 5], where 5 is most popular and 1 is least popular). These signals can be combined using a linear function:
w1 * f1 + w2 * f2 + w3 * f3
where {w1, w2, w3} are numerical weights that express the emphasis placed on each signal and {f1, f2, f3} are the three signals mentioned above. The logic to retrieve and rank news articles can be expressed in two phases (a code sketch follows the list):
- Retrieve matching news articles (along with their relevance scores) for a given geo-location. Drop articles that have a relevance score below a threshold (threshold can be picked using anecdotal examples)
- Apply the linear heuristic model to the resulting set of news articles to generate a “score”. Sort the articles by this score to produce the final result.
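To make this concrete, here is a minimal sketch of the two-phase heuristic ranker in Python. The weights, the relevance threshold, and the article field names are illustrative placeholders rather than values from a real system:

```python
from datetime import datetime, timezone

# Illustrative weights and threshold; in practice these are hand-tuned.
WEIGHTS = {"relevance": 0.5, "recency": 0.3, "popularity": 0.2}
RELEVANCE_THRESHOLD = 0.4

def recency_signal(published_at: datetime, now: datetime) -> float:
    """Map publish time to (0, 1]; newer articles score higher."""
    hours_old = max((now - published_at).total_seconds() / 3600.0, 0.0)
    return 1.0 / (1.0 + hours_old)

def score(article: dict, now: datetime) -> float:
    """Linear heuristic: w1 * f1 + w2 * f2 + w3 * f3."""
    return (
        WEIGHTS["relevance"] * article["relevance"]                           # f1: IR relevance score
        + WEIGHTS["recency"] * recency_signal(article["published_at"], now)   # f2: recency
        + WEIGHTS["popularity"] * article["publisher_popularity"] / 5.0       # f3: popularity, scaled to [0, 1]
    )

def rank_articles(articles: list) -> list:
    now = datetime.now(timezone.utc)
    # Phase 1: drop articles below the relevance threshold.
    kept = [a for a in articles if a["relevance"] >= RELEVANCE_THRESHOLD]
    # Phase 2: score and sort the remaining articles.
    return sorted(kept, key=lambda a: score(a, now), reverse=True)
```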
The weights in the heuristic model can be tuned until a qualitative review of the news feed looks good. This approach works well to get the product off the ground. As additional data becomes available, intuition/domain knowledge can be combined with insights from the data to further tune the heuristic model. Once you have enough user interaction data, you could run a logistic regression to find the right weights. The step after that would be to set up a model retraining pipeline. The more complex your setup gets, the more attention you need to pay to things like data quality and model performance. The key takeaway from this section is that it is absolutely acceptable to start with heuristics to optimize a product experience, especially when the data size is low.
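Once enough interaction data (e.g., clicks) accumulates, the hand-tuned weights can be replaced with learned ones. Below is a rough sketch using scikit-learn; the toy feature matrix, click labels, and column ordering are assumptions made for illustration:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# X: one row per impression, columns = [relevance, recency, popularity] signals.
# y: 1 if the user clicked the article, 0 otherwise. Both are placeholders here.
X = np.array([[0.9, 0.8, 0.6], [0.5, 0.2, 1.0], [0.7, 0.9, 0.2], [0.3, 0.1, 0.4]])
y = np.array([1, 0, 1, 0])

model = LogisticRegression().fit(X, y)

# The learned coefficients play the role of w1, w2, w3 in the heuristic.
w1, w2, w3 = model.coef_[0]
print(w1, w2, w3)
```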
There are scenarios where heuristics are not an option and state-of-the-art techniques are required from the start. For example, object recognition in images: given an image, identify people, animals, buildings, vehicles, etc. In this case, there are several API providers (AWS Rekognition and Google Cloud’s Vision AI, to name a few) that offer APIs to detect objects in images, perform face recognition, extract text from an image, etc. Breaking down your problem so that it can leverage these APIs is a good way to get off the ground. Fig. 4 provides an analogy for this situation: you can buy fruit from a farmer instead of growing it yourself. Let’s take an example: you want to build an application that identifies the specific brand (let’s limit brands to Adidas, Nike and Puma) of shoes worn by people in an image. One way to design a solution is to use vision APIs to detect shoes in the image along with the text on the shoes, and then use the resulting information (text on the shoe plus the detected shoe object) to infer the brand. This path is much easier to implement than the alternative of building a model that detects specific shoe brands directly from an input image. The cost implications are usually small given the free tiers offered by the popular API providers.
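As a sketch of what this delegation might look like, the snippet below uses AWS Rekognition through boto3 to detect a shoe and read any text in the image; the label names, confidence thresholds, and brand-matching logic are illustrative assumptions:

```python
import boto3

BRANDS = {"adidas", "nike", "puma"}
SHOE_LABELS = {"shoe", "footwear", "sneaker"}  # illustrative label names

# Assumes AWS credentials are already configured in the environment.
rekognition = boto3.client("rekognition")

def detect_shoe_brand(image_bytes: bytes):
    """Infer a shoe brand by combining object detection with text detection."""
    labels = rekognition.detect_labels(
        Image={"Bytes": image_bytes}, MaxLabels=20, MinConfidence=70
    )["Labels"]
    if not any(label["Name"].lower() in SHOE_LABELS for label in labels):
        return None  # no shoe detected in the image

    # Look for a known brand name among the words detected in the image.
    detections = rekognition.detect_text(Image={"Bytes": image_bytes})["TextDetections"]
    for detection in detections:
        if detection["DetectedText"].lower() in BRANDS:
            return detection["DetectedText"].lower()
    return None
```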
One way to tackle the lack of labeled data is to manufacture synthetic data for your specific problem. Synthetic data is used in machine learning to capture examples that are not present in the training dataset but are theoretically possible. For example, when building an object recognition application, a user might point the camera at an object (say, a deer) at an angle instead of head-on. To help the model recognize the object correctly, the original object image can be rotated and inserted into the dataset as a new example with the same label. This teaches the model that an image taken at a different angle corresponds to the same object.
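As a minimal sketch of this kind of augmentation, the snippet below uses Pillow to generate rotated copies of an image while reusing the original label; the file name and rotation angles are placeholders:

```python
from PIL import Image

def augment_with_rotations(image_path: str, label: str, angles=(15, -15, 30, -30)):
    """Create rotated copies of an image, each keeping the original label."""
    original = Image.open(image_path)
    augmented = []
    for angle in angles:
        # expand=True keeps the full rotated image instead of cropping the corners.
        rotated = original.rotate(angle, expand=True)
        augmented.append((rotated, label))
    return augmented

# Hypothetical usage: each (image, label) pair gets added to the training set.
new_examples = augment_with_rotations("deer.jpg", "deer")
```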
Another case where synthetic data is useful is when your dataset is heavily imbalanced, i.e., one class is heavily over-represented compared to the other class(es). An example is the email spam detection problem, where the number of positive examples is small compared to the negative examples (spam is annoying but rare). Typically, spam ranges between 0.1% and 2% of the entire dataset. In this case, a technique like SMOTE can be used to generate more examples of the rare labels in your dataset. However, there is a downside: the generated data may not be representative of real-world data, which could hurt model performance.
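Here is a brief sketch of oversampling with SMOTE from the imbalanced-learn library; the fabricated feature matrix and the ~1% positive rate are made up for illustration:

```python
from collections import Counter

from imblearn.over_sampling import SMOTE
from sklearn.datasets import make_classification

# Fabricate an imbalanced dataset: roughly 1% positive (spam) examples.
X, y = make_classification(
    n_samples=5000, n_features=20, weights=[0.99, 0.01], random_state=42
)
print(Counter(y))  # e.g. approximately {0: 4950, 1: 50}

# SMOTE synthesizes new minority-class examples by interpolating between neighbors.
X_resampled, y_resampled = SMOTE(random_state=42).fit_resample(X, y)
print(Counter(y_resampled))  # classes are now balanced
```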
Wouldn’t it be magical if you could take abundantly available, non-domain-specific data and train a model that works for a domain-specific task? This is possible using transfer learning. The technique deserves a blog post of its own, but I will provide a brief overview here.
Consider an example: you would like to train a classifier that classifies a tweet as positive, negative or neutral in sentiment. You only have a few hundred examples of positive, negative and neutral tweets that you have manually labeled (yourself or through an external labeling service like Amazon Mechanical Turk). You are aiming to train a model that is on par with state-of-the-art sentiment classification models.
In this case, you could use the transfer learning technique described in the ULMFiT paper (I have used ULMFiT in the past as well). The idea is to first learn a base language model. A language model predicts the probability of the next word given an input sequence of words. A Wikipedia dataset can be used to train this language model (the most recent dataset as of Jan 2021 has over 3.7 billion words). In this first step, the resulting language model learns language representations in the general domain (Wikipedia). The second step is to “fine-tune” the first model on the target dataset (your few hundred tweet examples). The fine-tuning step uses the same setup (language modeling) as the first step, except the model is now trained to capture the idiosyncrasies of your target dataset (tweets). The intuition is that the language of tweets differs from the language of Wikipedia articles, so this difference needs to be captured. The third and final step is to train the fine-tuned model to classify sentiment instead of predicting the next word. One way of doing this is to remove the final layer of the fine-tuned model and replace it with a classification layer for the specific task: sentiment prediction handled by a softmax layer for the multi-class classification problem. Fig. 7 (an image taken from the ULMFiT paper) shows the 3 steps in training the neural network. Please refer to the paper [1] for more details.
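For a concrete starting point, here is a rough sketch of these three steps using the fastai library (which implements ULMFiT); the DataFrame columns, epoch counts, and learning rates are assumptions, not values from the paper:

```python
import pandas as pd
from fastai.text.all import (
    AWD_LSTM, TextDataLoaders, language_model_learner, text_classifier_learner
)

# Assumed DataFrame with a 'text' column (tweets) and a 'sentiment' column
# (positive / negative / neutral).
df = pd.read_csv("tweets.csv")

# Step 1 is already done for us: fastai ships an AWD_LSTM language model
# pretrained on Wikipedia text (WikiText-103).

# Step 2: fine-tune the language model on the target domain (tweets).
dls_lm = TextDataLoaders.from_df(df, text_col="text", is_lm=True)
lm_learn = language_model_learner(dls_lm, AWD_LSTM, pretrained=True)
lm_learn.fine_tune(3, 1e-2)
lm_learn.save_encoder("tweet_encoder")

# Step 3: swap in a classification head and train the sentiment classifier.
dls_clf = TextDataLoaders.from_df(
    df, text_col="text", label_col="sentiment", text_vocab=dls_lm.vocab
)
clf_learn = text_classifier_learner(dls_clf, AWD_LSTM, pretrained=True)
clf_learn.load_encoder("tweet_encoder")
clf_learn.fine_tune(5, 1e-2)
```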
Research has shown that transfer learning helps improve target task accuracy by a substantial margin. Intuitively, this makes sense. An analogy is a person who knows how to walk/run can learn how to ice skate/ski much faster than a person who tries to learn how to ice skate/ski directly.
Finally, to end with a plant analogy, Fig. 6 shows a cherry tree being grafted. Grafting has similar advantages: a) it strengthens the plant’s resistance to certain diseases, and b) it adapts varieties to adverse soil or climatic conditions by leveraging the adapted host plant.
In sum, the following are the general set of steps I have followed when dealing with low amounts of training data:
- Use heuristics if possible. For example, weights hand-tuned by a domain expert to optimize item ranking are a good start.
- Delegate if possible. Break down the problem in a way such that external APIs can be leveraged.
- Experiment with synthetic data especially in cases where datasets are heavily imbalanced.
- Experiment with transfer learning. Be aware that the time investment in this step can be high.
How do you deal with data scarcity? Feel free to leave a comment with your thoughts.
- [1] Howard, J., & Ruder, S. (2018). Universal language model fine-tuning for text classification. arXiv preprint arXiv:1801.06146. https://arxiv.org/abs/1801.06146v5