We’re introducing the BERT deep learning architecture for text data to Azure Automated ML. This model usually performs much better than older machine learning techniques that rely on bag of words-style features for text classification. BERT, which is both a neural net architecture and a particular transfer learning technique, has had a huge impact on large and small companies (example use cases include Microsoft, Google, Stitch Fix). Because Automated ML uses a BERT model that has already been pretrained on a large corpus of text, you don't need much training data to see a lot of benefit (even a few hundred rows can be enough in some circumstances), which is very valuable if labeled data is hard or expensive to acquire. We’ve implemented BERT in Automated ML so that it is first fine-tuned on the dataset the user provides; Automated ML then uses the embeddings from the fine-tuned BERT model as text features for other ML algorithms like logistic regression or LightGBM. Our implementation of BERT uses the popular transformers repository (paper, GitHub).
What does deep learning for text data do that older techniques don’t do?
Deep learning for text data gives you more accurate models than bag of words-based approaches to handling text data. The bag of words-style features that Automated ML uses for text include unigrams, bigrams, and character trigrams. Another text feature sometimes used in Automated ML is static pretrained word embeddings. It’s not big news anymore that a deep learning architecture outperforms older “shallower” learning techniques, but what really grabbed our attention was how much more BERT tends to outperform bag of words-type approaches on small training data (hundreds of rows) than on larger data of 10,000 rows or more. In general, we’ve found that BERT can often get the same accuracy as the bag of words approach with only 1/10 of the data! This ability to learn from small data can really benefit your product if, as is often the case, labeled data is difficult or expensive to get. To illustrate this, we ran Automated ML with and without BERT on a four-class news dataset (called AG News) and plotted the learning curves in Figure 1. To illustrate the value of pretraining (both through BERT and pretrained word embeddings), we also trained a logistic regression model with unigram and bigram features as a simple baseline. Notably, Automated ML with BERT achieves 94.7% accuracy on AG News when trained with 120k rows, which would put it at 4th place on this leaderboard for AG News as of this writing. To ensure that training does not take too long and to avoid GPU memory issues, Automated ML uses a smaller BERT model (called bert-base-uncased) that will run on any Azure GPU VM.
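To make the bag of words-style features mentioned above concrete, here is a minimal sketch in plain Python of how unigrams, bigrams, and character trigrams can be counted for one document. It is an illustration only, not the featurizer Automated ML actually uses; the function names and the "w1=", "w2=", "c3=" feature prefixes are made up for this example.

```python
from collections import Counter
import re


def tokenize(text):
    """Lowercase the text and split it into simple word tokens."""
    return re.findall(r"[a-z0-9']+", text.lower())


def word_ngrams(tokens, n):
    """Return all contiguous word n-grams as space-joined strings."""
    return [" ".join(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def char_trigrams(text):
    """Return all contiguous 3-character substrings of the lowercased text."""
    text = text.lower()
    return [text[i:i + 3] for i in range(len(text) - 2)]


def bag_of_words_features(text):
    """Count unigrams, bigrams, and character trigrams in one sparse mapping.

    The prefixes (hypothetical naming) keep the three feature families apart.
    """
    tokens = tokenize(text)
    features = Counter()
    features.update("w1=" + gram for gram in word_ngrams(tokens, 1))
    features.update("w2=" + gram for gram in word_ngrams(tokens, 2))
    features.update("c3=" + gram for gram in char_trigrams(text))
    return features


example = bag_of_words_features("Stocks fall as oil prices rise")
for name, count in sorted(example.items())[:5]:
    print(name, count)
```

A linear model such as logistic regression would then learn one weight per feature name. Note that nothing in these counts records where in the document a word appeared beyond the very short window of a bigram.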
What's so special about BERT that makes it so much better than bag of words?
Theoretically, BERT has a big advantage over bag of words methods in that BERT is sensitive to word order and long-range word interdependencies (e.g. the meaning of "it" might depend on a particular word tens of words to the left of "it"), but we were curious how this difference plays out with real-world data. So we examined how well BERT and bag of words each do on our holdout set for this news classification dataset. After looking at many examples and the predictions from BERT versus simpler methods, we did notice a pattern: when a news article is misclassified by bag of words methods, it often contains one to a few words that shift the meaning of the entire document. To illustrate this, here are two examples that we constructed and fed into the models trained on the AG News dataset.
1. “The two players were evenly matched, but the first player's skill with the joystick pushed her over the top.”
2. “The two players were evenly matched, but the first player's skill with the hockey stick pushed her over the top.”
The sentence with "hockey stick" is easy to classify as being about "sports", and indeed both the bag of words approach and BERT classify it correctly. The sentence with "joystick" is harder to classify because it has a lot of words ordinarily associated with sports, but it's actually in the "science & tech" category. For the "joystick" sentence, BERT correctly predicts the "science & tech" category, while the bag of words-based model incorrectly predicts that it is about sports. What’s particularly interesting about this example is how just one word, “joystick”, or phrase, “hockey stick”, dramatically changes the meaning of the document from being about video games to being about physical sports. This is why the bag of words approach fails for the "joystick" sentence: the preponderance of strongly “sports”-like words pushes the prediction to sports. BERT, on the other hand, doesn’t just model words as a static collection of distinct things (aka “bag of words”), but rather contains sophisticated mechanisms that ensure the features of one word depend both on that word’s position in a document and on the other words around it. This way it can know that a “player” in this document is actually a video game “player” by virtue of the presence of “joystick” elsewhere in the document. This sophistication, we hypothesize, is the reason BERT isn’t fooled into thinking a video game competition text snippet like example 1 is about physical sports.
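The following sketch shows, under simplified assumptions, why the two example sentences look nearly identical to a bag of words model. It only counts unigrams with Python's standard library (the helper name unigram_counts is ours) and does not reproduce the models we trained; it simply shows how few features differ between the sentences and that word order is thrown away.

```python
from collections import Counter
import re


def unigram_counts(text):
    """Return word counts for a text, ignoring case, order, and punctuation."""
    return Counter(re.findall(r"[a-z']+", text.lower()))


joystick_sentence = (
    "The two players were evenly matched, but the first player's "
    "skill with the joystick pushed her over the top."
)
hockey_sentence = (
    "The two players were evenly matched, but the first player's "
    "skill with the hockey stick pushed her over the top."
)

joystick_bag = unigram_counts(joystick_sentence)
hockey_bag = unigram_counts(hockey_sentence)

# Multiset operations on Counters: what the two bags share and where they differ.
shared = joystick_bag & hockey_bag
only_in_joystick = joystick_bag - hockey_bag
only_in_hockey = hockey_bag - joystick_bag

print("shared word occurrences:", sum(shared.values()))
print("only in joystick sentence:", dict(only_in_joystick))
print("only in hockey sentence:", dict(only_in_hockey))

# Reversing the word order leaves the bag unchanged: position carries no signal.
reversed_text = " ".join(reversed(re.findall(r"[a-z']+", joystick_sentence.lower())))
print("order-insensitive:", unigram_counts(reversed_text) == joystick_bag)
```

Almost all of the evidence a bag of words classifier sees is shared between the two sentences, so the "sports"-like words dominate unless the single differing token carries a very large weight.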
How Is BERT integrated into Automated ML?
BERT is used in the featurization layer of Automated ML. In this layer we detect whether a column contains free text or another type of data, such as timestamps or simple numbers, and we featurize accordingly. For BERT, we fine-tune the model using the user-provided labels, then we output document embeddings (for BERT, this is the final hidden state associated with the special [CLS] token) as features alongside the other features that many typical datasets have, such as timestamp-based features (e.g. day of week) or numbers. Please see Figure 2 for a schematic of this.
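As a rough illustration of the featurization step described above, the sketch below picks the [CLS] vector out of one document's final hidden states and concatenates it with a day-of-week encoding and a numeric column. It uses toy numbers in place of real BERT output, and every function name here is hypothetical; it is not the Automated ML implementation.

```python
from datetime import datetime

CLS_POSITION = 0  # BERT inputs start with the special [CLS] token


def cls_embedding(final_hidden_states):
    """Pick the [CLS] vector from one document's final hidden states.

    final_hidden_states is a list of per-token vectors, i.e. it has
    shape (sequence_length, hidden_size).
    """
    return list(final_hidden_states[CLS_POSITION])


def day_of_week_one_hot(timestamp):
    """Encode a timestamp's weekday (Monday=0) as a 7-element one-hot list."""
    one_hot = [0.0] * 7
    one_hot[timestamp.weekday()] = 1.0
    return one_hot


def featurize_row(final_hidden_states, timestamp, numeric_values):
    """Concatenate the text embedding with timestamp and numeric features."""
    return (
        cls_embedding(final_hidden_states)
        + day_of_week_one_hot(timestamp)
        + list(numeric_values)
    )


# Toy stand-in for BERT output: 3 tokens with hidden size 4.
# (bert-base-uncased has a hidden size of 768.)
toy_hidden_states = [
    [0.1, 0.2, 0.3, 0.4],  # [CLS]
    [0.5, 0.5, 0.5, 0.5],  # a word token
    [0.9, 0.8, 0.7, 0.6],  # [SEP]
]

row = featurize_row(toy_hidden_states, datetime(2020, 6, 1), [42.0])
print(len(row), row)
```

The resulting flat feature vector is the kind of input that a downstream learner such as logistic regression or LightGBM can consume directly.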
It’s worth noting that we don’t technically need to train BERT, as it’s pretrained on a large corpus of text. However, there’s no good way to get document embeddings unless BERT is fine-tuned. One way to see what fine-tuning does to embeddings is to visualize BERT-generated document embeddings using a t-SNE plot (this visualization method places points close together in 2 dimensions according to the probability that they are close together in the original embedding space, which has a few hundred dimensions). We created two 2D t-SNE plots: one where BERT was trained on 1% of a dataset, and another where BERT was trained on the full dataset. Each point represents a document, and its color is the ground-truth class label of that document. Both models use the same four-class text dataset. In the 1% case, you can see that the embeddings don’t display much class structure, since most points belong to one blob. On the right, where BERT was trained on the full dataset, the class structure is much more obvious. So it’s apparent that fine-tuning BERT is quite important if you want BERT to generate embeddings that “know” about the particularities of a dataset!
With the gains of BERT for training data both big and small in mind, we conclude with a recommendation for when you should use BERT and when you might want to use a bag of words-based model. If you need predictions to be very fast (for example, under a few milliseconds per prediction) and/or you want to perform predictions on a CPU, then you should not use BERT and should stick with bag of words-based models. In the end, you will need to choose between the fast inference time of bag of words-type models and the high accuracy of BERT.
So how can I try BERT in Azure Automated ML?
To get started with Azure Automated ML, you can read our docs here. If you're comfortable with Python, you can jump right into our Jupyter notebook that illustrates BERT. The main thing to keep in mind is that to benefit from BERT, you need to
Once your Automated ML run is complete, you can use your trained Automated ML model to do inference on a GPU or a CPU Azure VM. Note that performing inference on a GPU will be much faster.
Contributors (alphabetical order):
Eric Clausen-Brown – Senior Data & Applied Scientist
Zubin Pahuja – Software Engineer
Anup Shirgaonkar – Principal Data & Applied Scientist
Arjun Singh – Intern Data & Applied Scientist