
Active Learning boost for your ML problem

You can also see my presentation and video from the FWdays conference on using Active Learning and Weak Supervision to solve NLP problems.

Mariia Havrylovych
Towards Data Science
12 min read · Feb 4, 2021


Meme by the Author using the photo on imgflip.com

It’s hard to disagree that the most widespread and effective way to solve machine learning (ML) problems is vanilla supervised learning. During the last decade, algorithms have improved a lot, especially Deep Learning models. However, a model’s success depends heavily on data quality and quantity, which in turn require time, money, and human resources. This is where complications may appear.

If you want to build a successful data science solution but face a low-resource problem or lack the resources to get labels, this material is for you! Even if you or your company has enough (financial) resources to obtain labels, this post can still be useful. Even at a company as big as Wix, where we have a whole in-house labeling department, we still struggle to get enough quality labeled data.

Consider possible scenarios:

  • your problem has domain-specific data, like legal documents or bio-data. To get the labels, you need highly qualified experts;
  • you cannot label your data on a cheap crowdsourcing platform because it contains sensitive information under NDA;
  • you don’t have a lot of data, for example, an imbalanced dataset with rare classes;
  • you need to train a more complex model, such as a neural network, which is particularly hungry for data;
  • the labeling process is time-consuming, and you want to speed it up.

Whichever scenario applies, you end up without enough labeled data.

Here Active Learning comes into the spotlight!

With this fantastic tool, you may need fewer manually labeled samples in your data science task without harming performance, and you may even improve it!

At Wix.com, we use it for text classification (business categorization) problems, sentiment analysis of user support requests, and object and portrait segmentation.

In addition to the post, you can explore the prepared Colab notebook and get your hands dirty with the code :)

Let’s start!

Table of contents

  1. Dataset and starting components
  2. Active Learning Framework
  3. Query strategies for Active Learning
    Basic query strategies
    How to choose a query strategy
    How to customize query strategy
  4. Full pipeline
  5. Results
  6. Summary
  7. References

Dataset and starting components

In this post, for the sake of simplicity and comprehension, we will use the classic SMS Spam Collection dataset [1][2] to detect spam messages. The dataset contains ~5200 samples. The class proportion is shown in fig. 1 below.

We divide this dataset into three parts:

  • 20% is taken as a test set.
  • ~75% is taken as an SMS pool, where we pretend that there are no labels.
  • The rest (~200 samples) is taken as the initial labeled train set, the seed.
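A split like this can be sketched on row indices alone. The block below is only an illustration: the function name `split_indices`, its parameters, and the plain (non-stratified) shuffle are assumptions, not the code from the notebook.

```python
import numpy as np


def split_indices(n_samples, test_frac=0.2, seed_size=200, random_state=0):
    """Shuffle row indices and cut them into seed, pool, and test parts.

    Hypothetical helper: a plain random split, not stratified by class.
    """
    rng = np.random.default_rng(random_state)
    order = rng.permutation(n_samples)
    n_test = int(round(test_frac * n_samples))
    test_idx = order[:n_test]                      # held out for evaluation
    seed_idx = order[n_test:n_test + seed_size]    # small labeled seed
    pool_idx = order[n_test + seed_size:]          # treated as unlabeled
    return seed_idx, pool_idx, test_idx


seed_idx, pool_idx, test_idx = split_indices(5200)
print(len(seed_idx), len(pool_idx), len(test_idx))
```

With about 5200 rows, the pool ends up holding roughly three quarters of the data, which matches the proportions listed above.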

We will mimic the label shortage case with an initial small train set and show how active learning will tackle this problem!

As a model, we use a TF-IDF vectorizer followed by a simple logistic regression. The metric we monitor is the F1-score for the “spam” class.

Figure 1 — The class proportion in the SMS spam dataset — Figure by Author

Active Learning Framework

With endless resources, time, and money for model training, you would simply collect all available data and assign a label to every sample. Regrettably, that’s not feasible. Now imagine that your model is so smart that it can decide which data it likes and wants to train on. Of course, the model has no feelings about data, but it is still a good idea to train it only on the samples that significantly affect its performance. That is exactly what active learning does.

The main idea behind active learning (AL) is to let the model select its training data in a “smart” way. A similar idea is used in the AdaBoost algorithm, which assigns a higher weight to the more informative samples, i.e. the ones on which the model makes more mistakes. We feed the model the most informative and complex examples, so the cost of training drops because fewer data points need labels.

There are different types of active learning [3]. We will concentrate on the most popular one, pool-based active learning.

Let’s figure out what the actual difference is between regular ML and ML with active learning.

Figure 2 — Classic and Active Learning ML comparison — Image by the Author
Figure 3— Active learning cycle — Image by the Author, inspired by [3]

As the figures above show, the difference between classic supervised training and pool-based active learning lies in the query strategy, that is, in how you choose the data to label for the train set.

Also, do not forget that the ML process is iterative by nature: if the evaluation results are not good enough, you go back and decide whether to add more data or to adjust the model. In regular ML you may well stop after the first iteration, but in active learning you will certainly have several. You set a batch size parameter, which is how many samples you label in one iteration.

Query strategies for Active Learning

The best way to understand the algorithm is to code it. Thus we will implement selected query strategies here.

You can try active learning packages such as modAL or ALiPy, which already implement various query strategies. However, if you do not need sophisticated strategies or just want a quick POC, code it yourself.

I used ideas for implementation from this tutorial [5].

First, we need to define an abstract base class with only one abstract method, “selection”.
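A minimal sketch of such a base class, using Python’s `abc` module, could look like this. The class names `QueryStrategy` and `TakeFirst` and the exact method signature are assumptions made for illustration; only the method name “selection” comes from the post.

```python
from abc import ABC, abstractmethod

import numpy as np


class QueryStrategy(ABC):
    """Base class for query strategies (hypothetical name)."""

    @abstractmethod
    def selection(self, probas: np.ndarray, num_samples: int) -> np.ndarray:
        """Return indices of the pool samples to send for labeling."""


class TakeFirst(QueryStrategy):
    """Trivial strategy, shown only to demonstrate how the interface is filled in."""

    def selection(self, probas, num_samples):
        return np.arange(min(num_samples, len(probas)))


print(TakeFirst().selection(np.zeros((5, 2)), 3))
```

Because `selection` is abstract, `QueryStrategy()` cannot be instantiated directly; every concrete strategy has to provide its own implementation.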

Basic query strategies

Random selection

The naive way to select data for labeling is random sampling: we simply draw samples at random from the unlabeled pool, which is, in fact, passive learning. Compared with non-random selection, it has one significant advantage: the train set distribution reflects reality. Its obvious drawback is the high labeling cost.

The “selection” method for this strategy needs only two parameters: pool_len, the number of data samples in the unlabeled pool, and num_samples, our batch size, i.e. how many samples we label in one iteration.
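Random selection can be sketched in a few lines of NumPy. It is written here as a plain function so that it stands alone; the function name and the optional `seed` argument are assumptions, while `pool_len` and `num_samples` are the parameters described above.

```python
import numpy as np


def random_selection(pool_len, num_samples, seed=None):
    """Pick num_samples distinct pool indices uniformly at random."""
    rng = np.random.default_rng(seed)
    num_samples = min(num_samples, pool_len)  # never ask for more than the pool holds
    return rng.choice(pool_len, size=num_samples, replace=False)


selected = random_selection(pool_len=100, num_samples=10, seed=0)
print(selected)
```

Sampling without replacement guarantees that the same pool item is never sent for labeling twice within one batch.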

Uncertainty sampling: Margin selection

Another very popular group of query strategies is uncertainty sampling: selecting the samples about which the model is uncertain. The concept is straightforward for models with probabilistic output, like Logistic Regression, and less so for models that just output some score in the [0,1] range.

In margin selection, one of the uncertainty sampling query strategies, the most uncertain samples are the ones with the smallest difference between the two most confident predictions (the two highest class probabilities).

The implementation of margin selection step by step:

  1. For each sample in the unlabeled pool, sort its predicted probabilities from largest to smallest.
  2. Calculate the difference between the two largest probabilities.
  3. Take the samples (in num_samples amount) with the smallest values, which are the most uncertain ones.
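The three steps above can be sketched as follows. The function name and the toy probability matrix are illustrative assumptions; `probas` stands for the output of a classifier’s `predict_proba` on the pool, with one row per sample and one column per class.

```python
import numpy as np


def margin_selection(probas, num_samples):
    """Return indices of the pool samples with the smallest top-two margin."""
    # 1. Sort each row's class probabilities from largest to smallest.
    ordered = np.sort(probas, axis=1)[:, ::-1]
    # 2. Margin = difference between the two largest probabilities.
    margins = ordered[:, 0] - ordered[:, 1]
    # 3. The smallest margins are the most uncertain samples.
    return np.argsort(margins, kind="stable")[:num_samples]


probas = np.array([[0.90, 0.10],
                   [0.55, 0.45],
                   [0.20, 0.80],
                   [0.50, 0.50]])
selected = margin_selection(probas, num_samples=2)
print(selected)  # rows 3 and 1 have the smallest margins
```

For a binary problem the margin is small exactly when the predicted probability is close to 0.5, so this strategy queries the samples nearest the decision boundary.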

To read more about query strategies, see the active learning literature survey [3].

To learn more about active learning Python packages, look at this survey [4].

How to choose a query strategy

Different query strategies select different samples, so you need to choose the one from which your model will benefit the most.

The model does not support probabilistic output. In such a case, you may consider query strategies that do not use probabilities, like query by committee. Also, keep in mind that ‘the least certain instance lies on the classification boundary, but is not “representative” of other instances in the distribution, so knowing its label is unlikely to improve accuracy on the data as a whole’ [3]. That’s why, even if your model has probabilistic output, you may not want to rely on uncertainty-based query strategies alone.

Imbalanced datasets. When you have an anomaly or fraud detection task, or any imbalanced dataset, the minority class is often the more important one. In such a situation, similarity- or density-based query strategies [6] can be more useful.

Scattered or high-density feature distribution. If the query strategy tries to increase data variance by taking the most dissimilar samples, it can end up querying outliers, which will not improve model performance.

Model performance does not change from iteration to iteration. Some query strategies are greedy: within one iteration, they select data points that are all informative in the same way. Consequently, we need a query strategy that selects data that is not only informative but also diverse. For example, you can combine a couple of query strategies or use batch-aware strategies [7].

Do not forget about the cold start. In the beginning, try a couple of query strategies rather than committing to a specific one.

Fortunately, you are not limited to existing query strategies. Moreover, you can create query strategies that are suitable for your specific problem. Do not be afraid to use customized query strategies, for example, based on some rules or heuristics. You will see: it is easy to do!

How to customize query strategy

As you already know, the central point is to give the model the most valuable data instances. It is up to you to decide which samples could be informative in a specific problem. You may be interested in a particular class that your model detects poorly. Alternatively, after error analysis, you may already know the model’s weak spots and want to work on them.

The classic uncertainty-based query strategy selects samples whose predicted probability is near 0.5. But what if you move your decision threshold from the default 0.5 to another value? Then the interesting samples are the ones with probabilities near your threshold, and you can create a strategy that selects exactly those.
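One way to sketch such a threshold-aware strategy is to rank pool samples by the distance between their predicted positive-class probability and the threshold. The function name, the `threshold` parameter, and the convention that column 1 holds the positive (“spam”) class are assumptions for this example.

```python
import numpy as np


def threshold_selection(probas, num_samples, threshold=0.5):
    """Return indices of the samples whose positive-class probability
    lies closest to the decision threshold."""
    positive = probas[:, 1]                    # assumed: column 1 = positive class
    distance = np.abs(positive - threshold)
    return np.argsort(distance, kind="stable")[:num_samples]


probas = np.array([[0.90, 0.10],
                   [0.55, 0.45],
                   [0.20, 0.80],
                   [0.72, 0.28],
                   [0.67, 0.33]])
near_default = threshold_selection(probas, 2)                # closest to 0.5
near_custom = threshold_selection(probas, 2, threshold=0.3)  # closest to 0.3
print(near_default, near_custom)
```

Moving the threshold from 0.5 to 0.3 changes which samples count as uncertain, so the two calls return different batches for the same predictions.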

Maximum errors in the interval

Another way to understand where the model performs worse is to plot a histogram of errors over the predicted probability values and find the probability range with the largest number of misclassifications.

Figure 4 — Example of error distribution due to different probability range — Figure by Author

For example, the plot above shows that the (0.3, 0.4] probability interval contains the largest number of errors.

Based on this, we will implement a query strategy that selects samples from the probability range with the highest number of errors:

  1. Split the range of model-predicted probabilities into 10 (feel free to set another number) equal-sized intervals.
  2. Make predictions on the train set and count the errors in each interval.
  3. Select the interval with the highest number of errors.
  4. Take the samples from the unlabeled pool whose predicted probabilities fall in the interval selected in step 3.
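A possible sketch of these four steps for a binary classifier is shown below. The function name, its parameters, and the choice to draw a random subset when the interval holds more than num_samples candidates are assumptions; the inputs are positive-class probabilities (1-D arrays) and 0/1 labels.

```python
import numpy as np


def max_error_selection(train_probas, train_labels, pool_probas,
                        num_samples, n_bins=10, threshold=0.5, seed=None):
    """Pick pool samples from the probability interval with the most train errors."""
    train_probas = np.asarray(train_probas, dtype=float)
    train_labels = np.asarray(train_labels)
    pool_probas = np.asarray(pool_probas, dtype=float)

    # 1. Equal-width intervals over [0, 1]; bin k covers (edges[k], edges[k + 1]].
    edges = np.linspace(0.0, 1.0, n_bins + 1)

    def to_bin(p):
        return np.clip(np.digitize(p, edges, right=True) - 1, 0, n_bins - 1)

    # 2. Count misclassified train samples in each interval.
    errors = (train_probas >= threshold).astype(int) != train_labels
    error_counts = np.bincount(to_bin(train_probas)[errors], minlength=n_bins)

    # 3. The interval with the highest number of errors.
    worst_bin = int(np.argmax(error_counts))

    # 4. Pool samples whose predicted probability falls into that interval.
    candidates = np.flatnonzero(to_bin(pool_probas) == worst_bin)
    if len(candidates) > num_samples:
        # Assumption: when there are too many candidates, sample them at random.
        rng = np.random.default_rng(seed)
        candidates = rng.choice(candidates, size=num_samples, replace=False)
    return candidates


train_probas = [0.05, 0.32, 0.35, 0.38, 0.62, 0.91]
train_labels = [0, 1, 1, 0, 0, 1]
pool_probas = [0.31, 0.72, 0.39, 0.10, 0.36]
selected = max_error_selection(train_probas, train_labels, pool_probas,
                               num_samples=2, seed=0)
print(selected)
```

In this toy input two of the three train errors fall into (0.3, 0.4], so the batch is drawn from the pool samples at positions 0, 2, and 4.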

This “maximum error” query strategy can be used for any classification problem.

Text similarity query strategy

At Wix, we faced a problem in the website categorization task: there are billions of Wix sites, and to find, for example, Yoga instructors’ sites you would need to randomly select and label a huge amount of data. So, starting from a small seed of Yoga site examples, we select for labeling only the sites whose text is similar to the seed. This way, we were able to label and build datasets for many categories at a much lower cost!

Now, try to answer the question: which samples would be the most informative specifically for SMS spam detection?

In most spam detection applications, catching “spam” correctly is crucial, and because only 13% of our samples are spam, we can say that the “spam” class is much more critical than regular non-spam samples.

So here is a similarity-based query strategy that uses text embeddings to calculate the similarity between data points and selects samples according to your needs.

We can use simple TF-IDF or BOW vectors here, or pre-trained embeddings such as Universal Sentence Encoder from TensorFlow or any other state-of-the-art model (BERT, LaBSE, etc.). To get the samples most similar to our “spam” class, we also have to define a similarity metric. We decided to go with Universal Sentence Encoder embeddings and the cosine similarity metric.

Now, let’s code this query strategy step by step:

  0. First, prepare the embedding function.
  1. Get the “spam” samples from the labeled train set and embed them.
  2. Calculate embeddings for the unlabeled pool.
  3. Calculate the cosine similarity between the “spam” embeddings from step 1 and the pool embeddings from step 2.
  4. Get the samples from the unlabeled pool with the highest similarity.
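The steps above can be sketched without any external model. To keep the example self-contained, the `embed` function below is a stand-in that uses normalized bag-of-words counts instead of Universal Sentence Encoder; any function that maps texts to vectors can be swapped in. Scoring each pool sample by its closest “spam” example (the maximum similarity) is also an assumption, since other aggregations such as the mean are equally possible.

```python
import numpy as np


def embed(texts, vocabulary):
    """Step 0. Stand-in embedding: L2-normalized bag-of-words counts."""
    index = {word: i for i, word in enumerate(vocabulary)}
    vectors = np.zeros((len(texts), len(vocabulary)))
    for row, text in enumerate(texts):
        for word in text.lower().split():
            if word in index:
                vectors[row, index[word]] += 1.0
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    return vectors / np.maximum(norms, 1e-12)


def similarity_selection(spam_texts, pool_texts, num_samples):
    """Return indices of the pool texts most similar to the known spam texts."""
    vocabulary = sorted({w for t in list(spam_texts) + list(pool_texts)
                         for w in t.lower().split()})
    spam_emb = embed(spam_texts, vocabulary)   # step 1
    pool_emb = embed(pool_texts, vocabulary)   # step 2
    # Step 3: rows are unit-length, so the dot product is the cosine similarity.
    similarity = pool_emb @ spam_emb.T
    best = similarity.max(axis=1)              # assumed aggregation: closest spam sample
    # Step 4: highest similarity first.
    return np.argsort(-best, kind="stable")[:num_samples]


spam_seed = ["win a free prize now", "free cash call now"]
pool = ["are we meeting for lunch",
        "call now to win free cash",
        "see you at home",
        "free prize waiting call now"]
selected = similarity_selection(spam_seed, pool, num_samples=2)
print(selected)  # the two spam-like messages rank first
```

Replacing `-best` with `best` in the last step would give the variant that selects the most dissimilar samples instead.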

That’s all! You can modify this query strategy by selecting the samples most dissimilar to the whole train set (to increase dataset variance) or the samples that are similar to all classes (if you consider such samples the trickiest).

As you see, customizing a query strategy is not rocket science but an area where you can use your creativity and have fun!

Full pipeline

Now that we have implemented query strategies, it is time to put them in the active learning cycle and train the model.

  1. Define a classifier.
  2. Train the classifier on our initial small labeled set.
  3. Make predictions on the unlabeled pool.
  4. Make the selection with the chosen query strategy (e.g. margin selection).
  5. Add the selected data samples, with their labels, to the existing train set.
  6. Delete the selected items from the pool (because we have already added these data points to the train set).
  7. Train the classifier on the updated train set and calculate the updated metrics.
  8. Repeat steps 2–7 until you achieve satisfactory results or run out of resources.
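Put together, the cycle can be sketched with scikit-learn as below. This is an illustration under several assumptions: the function names and parameters are made up for the example, the toy messages are randomly generated rather than taken from the SMS dataset, and the “labeling” step simply reveals labels that were hidden from the model, where a real project would ask a human annotator.

```python
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.pipeline import make_pipeline


def margin_selection(probas, num_samples):
    ordered = np.sort(probas, axis=1)[:, ::-1]
    return np.argsort(ordered[:, 0] - ordered[:, 1], kind="stable")[:num_samples]


def active_learning_loop(train_x, train_y, pool_x, pool_y, test_x, test_y,
                         batch_size=10, n_iterations=3):
    train_x, train_y = list(train_x), list(train_y)
    pool_x, pool_y = list(pool_x), list(pool_y)

    # Steps 1-2: define the classifier and train it on the seed.
    clf = make_pipeline(TfidfVectorizer(), LogisticRegression())
    clf.fit(train_x, train_y)
    scores = [f1_score(test_y, clf.predict(test_x))]

    for _ in range(n_iterations):
        if not pool_x:
            break
        probas = clf.predict_proba(pool_x)                     # step 3
        chosen = margin_selection(probas, batch_size).tolist()  # step 4
        # Step 5: add the newly labeled samples to the train set.
        train_x += [pool_x[i] for i in chosen]
        train_y += [pool_y[i] for i in chosen]
        # Step 6: remove them from the pool.
        chosen_set = set(chosen)
        pool_x = [x for i, x in enumerate(pool_x) if i not in chosen_set]
        pool_y = [y for i, y in enumerate(pool_y) if i not in chosen_set]
        # Step 7: retrain and re-evaluate.
        clf.fit(train_x, train_y)
        scores.append(f1_score(test_y, clf.predict(test_x)))
    return clf, scores, len(train_x)


# Toy data, generated only so that the sketch runs end to end.
rng = np.random.default_rng(0)
spam_vocab = ["free", "win", "prize", "cash", "claim", "urgent"]
ham_vocab = ["lunch", "home", "meeting", "thanks", "later", "tomorrow"]
shared_vocab = ["call", "now", "please", "today"]
labels = [1 if i % 5 == 0 else 0 for i in range(120)]
texts = [" ".join(list(rng.choice(spam_vocab if y else ham_vocab, size=3))
                  + list(rng.choice(shared_vocab, size=2))) for y in labels]

clf, scores, n_labeled = active_learning_loop(
    texts[:20], labels[:20], texts[50:], labels[50:], texts[20:50], labels[20:50])
print(n_labeled, scores)
```

The returned list of scores has one entry per training round, which is what you would plot to compare query strategies against each other.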

Results

Finally, let’s review the model performance on a test set with different query strategies.

Figure 5 — Active learning results — Figure by Author

In fig. 5, we can see that if we label the whole unlabeled pool, we get an 87% F1-score. So, 87% F1 will be the score baseline — because it is the performance you get if you label all data.

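With margin selection, the same performance is reached after 14 iterations (140 labeled samples). Together with our initial train set of 207 rows, that is 347 samples in total, versus roughly 4,100 if we label the whole pool (~75% of ~5200 samples plus the seed). That’s more than ten times less data!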

In the table below, you can see how many labels will be needed in order to get 87% baseline F1 score with different query strategies. The margin selection performance is almost the same as maximum errors in the interval selection, while text similarity selection is slightly worse — but still very good.

Table 1: Active learning experiments results

To conclude, no matter which query strategy we use, we get good performance with far fewer labels.

In the SMS case, margin selection was the best strategy, but keep in mind that we ran the experiment on a pretty simple classification problem with a Logistic Regression model, which gives well-calibrated probabilistic output. In a real-world scenario, you will probably not get such good performance with an uncertainty-based query strategy alone.

Summary

Let’s sum up all that we have done above:

  • Active learning may significantly reduce the number of data points you need to label (roughly ten times fewer in our example!);
  • Active learning query strategies can be easily customized and optimized for specific tasks.

There are a lot of ways to overcome data bottlenecks in your data science task. You just need to select the one that is most suitable for your problem. In this post, we successfully tackled it with active learning tools.

Acknowledgments
Many thanks to Gilad Barkan, Olga Diadenko, and Lior Sidi for feedback on earlier versions of this post!

References

[1] UCI Machine Learning Repository: SMS Spam Collection Data Set
[2] Almeida, T.A., Gómez Hidalgo, J.M., Yamakami, A. Contributions to the Study of SMS Spam Filtering: New Collection and Results (2011), Proceedings of the 2011 ACM Symposium on Document Engineering (DOCENG’11).
[3] Burr Settles. Active Learning Literature Survey (2010), Computer Sciences Technical Report 1648, University of Wisconsin–Madison.
[4] Alexandre Abraham. A Proactive Look at Active Learning Packages, data from the trenches (2020).
[5] Ori Kohen. Active Learning Tutorial (2018).
[6] Yukun Chen, Thomas A. Lasko, Qiaozhu Mei, Joshua C. Denny, Hua Xu. A study of active learning methods for named entity recognition in clinical text (2015), Journal of Biomedical Informatics, Volume 58, pp. 11-18.
[7] Daniel Gissin. Batch Active Learning | Discriminative Active Learning (2018).
