
What’s Wrong with My Machine Learning Model?

Stuart Feffer
Published: August 17, 2022

At Renesas, we make software that customers use to build machine learning models on sensor and signal data – and it usually works pretty well. But many customers hit a point somewhere in the process where they're not getting the results they're looking for, and they call up our Customer Success line and ask some variation on "What's wrong with my machine learning model? What do I need to do to improve the accuracy?"

Here's the process we go through to find the best path forward.

1. Does it fit?

The first thing to check is whether the various parts of the model are right for the job – looking at the parsing, features and algorithm selected.

To diagnose, the first thing to check is "training separation accuracy".

This is the accuracy of a model when it predicts on its own training set. It reflects how well the training data could be separated by the machine learning engine.

This accuracy should be very high – after all, the machine learning model has seen every one of these observations before. If we don't get very high accuracy, beating chance by 2+ sigma, there's a wrong-tool-for-the-job kind of problem in some aspect of the model.
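To make that check concrete, here's a minimal sketch of computing training separation accuracy. It uses scikit-learn and made-up placeholder features and labels (not the Reality AI toolchain): train a model, then score it on the very same observations it was trained on.

    import numpy as np
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.metrics import accuracy_score

    # Placeholder feature matrix and labels, purely for illustration.
    rng = np.random.default_rng(0)
    X = rng.normal(size=(500, 20))       # 500 observations x 20 features
    y = rng.integers(0, 3, size=500)     # 3 classes

    model = RandomForestClassifier(random_state=0).fit(X, y)

    # Score the model on its own training set.
    train_separation_accuracy = accuracy_score(y, model.predict(X))

    # With 3 roughly balanced classes, chance is about 0.33; training
    # separation accuracy should sit far above that if the parsing,
    # features and algorithm fit the problem.
    print("training separation accuracy:", round(train_separation_accuracy, 3))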

It might not be the algorithm though – far more common are problems in parsing and feature selection.

Reality AI is used with real-time streaming sensor data, and we have often seen accuracy depend quite sensitively on the length of the decision window, for example.

This makes sense, right?

If the phenomenon you want to detect occurs on a time scale much longer or much shorter than the window considered by your algorithm, that target is going to be much harder to spot. So we suggest that our customers create alternate training sets with their streaming data parsed into decision windows of different lengths.

Then we tell them to explore for features (Reality AI explores features automatically) and look at how marginal class accuracy varies by window length. They can then zero in on the optimal window length, and if needed combine sub-models using different window lengths into an ensemble.
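As a rough sketch of what that sweep looks like (using plain numpy and scikit-learn on a synthetic 60 Hz stream; the windowing helper, the candidate window lengths and the two summary features are all placeholders, not part of the Reality AI workflow):

    import numpy as np
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import cross_val_predict
    from sklearn.metrics import recall_score

    def make_windows(signal, labels, window_len, overlap=0.5):
        """Parse a 1-D stream into fixed-length windows with the given overlap."""
        step = max(1, int(window_len * (1 - overlap)))
        starts = range(0, len(signal) - window_len + 1, step)
        X = np.array([signal[s:s + window_len] for s in starts])
        # Label each window by the majority label of the samples inside it.
        y = np.array([np.bincount(labels[s:s + window_len]).argmax() for s in starts])
        return X, y

    # Placeholder 60 Hz stream: 100 one-second segments from two synthetic classes.
    rng = np.random.default_rng(1)
    labels = np.repeat(rng.integers(0, 2, size=100), 60)
    signal = rng.normal(size=labels.size) + 2.0 * labels

    for seconds in (0.5, 1.0, 2.0):               # candidate decision windows
        window_len = int(seconds * 60)            # samples at 60 Hz
        X, y = make_windows(signal, labels, window_len)
        # Two simple summary statistics stand in for automatic feature discovery.
        feats = np.column_stack([X.mean(axis=1), X.std(axis=1)])
        pred = cross_val_predict(RandomForestClassifier(random_state=0), feats, y, cv=5)
        per_class = recall_score(y, pred, average=None)   # marginal class accuracy
        print(f"{seconds:.1f} s window -> per-class accuracy {np.round(per_class, 3)}")

The window length whose per-class accuracies are highest and most even is the natural one to zero in on.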

Another thing to look at in parsing streaming sensor data is window offset in the training set.
Offset (sometimes also called "stride" or "overlap") refers to how much of the previous window is included in the next window.

In 60 Hz data, for example, it's possible to create one-second windows starting at each successive data point, with 59/60ths overlap between them. A little bit of overlap is a good thing, as it helps to make sure we have a diversity of windows and do not miss transients in the signal that span windows.

Too much overlap, on the other hand, may needlessly create very large data sets and, worse, create a situation in which the redundancy between consecutive samples distorts the statistical metrics used by the training algorithms. We usually recommend training initially on a balanced dataset (see "Is it balanced?" below), parsed at 50% window overlap.

For testing a trained model, however, it may be useful to evaluate sample lists with tighter offset.
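To put numbers on that trade-off, here's a tiny back-of-the-envelope sketch with assumed figures (60 Hz data, one-second windows, a 10-minute recording):

    # 60 Hz data, 1-second windows over a 10-minute recording (illustrative numbers).
    sample_rate = 60
    window_len = 1 * sample_rate          # 60 samples per window
    n_samples = 10 * 60 * sample_rate     # 36,000 samples

    def n_windows(step):
        return (n_samples - window_len) // step + 1

    # Stride of one sample -> 59/60 overlap between consecutive windows.
    print("stride of 1 sample:", n_windows(1), "windows")            # 35,941
    # 50% overlap -> stride of half a window, the usual training default.
    print("50% overlap       :", n_windows(window_len // 2), "windows")  # 1,199
    # No overlap -> stride of a full window.
    print("no overlap        :", n_windows(window_len), "windows")       # 600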

2. Will the machine learning model generalize?

Generalization refers to the ability of a machine learning model to predict accurately in circumstances not represented in its training data. Generalization is a good thing – and is one of the main goals of the model building process. But it's hard to judge early on when you're using every scrap of data available for training.

To assess generalization in circumstances with limited data, we suggest looking at "k-fold accuracy." K-fold is a cross-validation method in which accuracy is computed by dividing the data into folds – say, 10 different randomly selected subsets.

We hold out one of the 10 folds, train on the remaining 9, predict on the holdout and then repeat until we've held out all 10. The prediction accuracy on those 10 holdout sets is "k-fold accuracy" and is meant to be a statistically valid estimate of how that model will perform on data it's never seen – at least to the degree that the training set is representative of broader reality.
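A minimal sketch of that calculation, again with scikit-learn and placeholder data rather than anything Reality AI specific:

    import numpy as np
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import cross_val_score

    rng = np.random.default_rng(2)
    X = rng.normal(size=(500, 20))       # placeholder feature matrix
    y = rng.integers(0, 3, size=500)     # placeholder class labels

    model = RandomForestClassifier(random_state=0)

    # 10 folds: hold out one fold, train on the other 9, score the holdout,
    # and repeat until every fold has been held out once.
    scores = cross_val_score(model, X, y, cv=10)
    print("k-fold accuracy:", round(scores.mean(), 3))

    # Compare against training separation accuracy: high training accuracy
    # paired with low k-fold accuracy suggests the model cannot generalize yet.
    train_acc = model.fit(X, y).score(X, y)
    print("training separation accuracy:", round(train_acc, 3))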

If training separation accuracy is high and k-fold is low, the model can't generalize yet and needs more data. Chances are there is a wider range of variation in the training data than is adequately covered by the number of examples you provide. You'll need to add more data to the mix, data that better spans the diversity of each of the target classes.


3. Is it overtrained?

Overtraining refers to a model that has learned its training set "too well" – predicting it almost perfectly but failing on anything new.

If training separation accuracy is high and k-fold accuracy is also very high (we are always suspicious of 100% k-folds), but performance on holdouts or new data is very low, your model is probably overtrained. This is another version of a failure to generalize – the training set did not include enough data representative of variation in the target or background, but in this case the machine learning model overcompensated.

You'll need to add more data to the mix, with as much diversity as possible, representative of the circumstances in which the model will be expected to perform. Forcing the model to span the real-world variations will loosen the fit, and accuracy on new data will improve.
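As a hedged illustration of the diagnosis (placeholder data; in practice the "new" set would be recordings from sessions or conditions the training set never saw), the overtraining signature shows up when you line up three numbers:

    import numpy as np
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import cross_val_score

    # Placeholder data: a training set, plus "new" data from different conditions.
    rng = np.random.default_rng(5)
    X_train = rng.normal(size=(400, 10))
    y_train = rng.integers(0, 2, size=400)
    X_new = rng.normal(size=(100, 10)) + 0.5   # shifted to mimic unseen conditions
    y_new = rng.integers(0, 2, size=100)

    model = RandomForestClassifier(random_state=0)
    kfold_acc = cross_val_score(model, X_train, y_train, cv=10).mean()
    model.fit(X_train, y_train)
    train_acc = model.score(X_train, y_train)
    new_acc = model.score(X_new, y_new)

    # Overtraining signature: train_acc and kfold_acc both very high
    # (be suspicious of 100% k-folds), but new_acc much lower.
    print("train:", round(train_acc, 3),
          "k-fold:", round(kfold_acc, 3),
          "new data:", round(new_acc, 3))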

4. Is it balanced?

Unbalanced training data is also a frequent cause of machine learning performance issues. Balance refers to whether the total number of observations in the training set is allocated more or less equally across classes or predicted values. If we have metadata that address other components of variation (environments or circumstances in which data is collected), we'd like to see a balance there as well.

Lack of balance can bias feature selection, and it also tends to bias many types of learning algorithms to favor overrepresented classes. Models will naturally be better at detecting what they see the most of. Our feature discovery environment automatically creates a balanced subset of the data through random subsampling before running the discovery algorithm. We don't always force balance during training, because sometimes the resulting bias is useful – it may even be representative of real-world expectations. But it's always one of the first things we suggest looking at when samples from many different classes are being incorrectly classified into the same class.

Here's a quick test to see whether the imbalance is a significant factor in model accuracy: Create a balanced subset from your training data and retrain the model.

If it's a classification model, compare the marginal accuracies of each prediction class using the original vs balanced set.

If it's a regression model, look at the distribution of error for different predicted values.

If the balanced subset makes the marginal accuracies converge, even if overall accuracy is lower, then balance is a big part of the problem.
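Here's one way to sketch that quick test for a classification model, assuming plain numpy arrays X and y; the undersampling helper below is hypothetical, not a Reality AI function:

    import numpy as np
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import cross_val_predict
    from sklearn.metrics import recall_score

    def balanced_subsample(X, y, seed=0):
        """Randomly undersample every class down to the rarest class's count."""
        rng = np.random.default_rng(seed)
        n_min = np.bincount(y).min()
        idx = np.concatenate([
            rng.choice(np.flatnonzero(y == c), n_min, replace=False)
            for c in np.unique(y)
        ])
        return X[idx], y[idx]

    # Placeholder imbalanced data: class 0 heavily overrepresented.
    rng = np.random.default_rng(3)
    y = np.concatenate([np.zeros(900, dtype=int), np.ones(100, dtype=int)])
    X = rng.normal(size=(1000, 10)) + y[:, None]

    model = RandomForestClassifier(random_state=0)
    for name, (Xs, ys) in {"original": (X, y),
                           "balanced": balanced_subsample(X, y)}.items():
        pred = cross_val_predict(model, Xs, ys, cv=5)
        # Marginal (per-class) accuracy; watch whether the classes converge.
        print(name, np.round(recall_score(ys, pred, average=None), 3))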

To fix an unbalanced dataset:

  1. If your training data is unbalanced and you have plenty of data, create a balanced subsample.
  2. If you don't have plenty of data, collect more observations from the underrepresented classes or ranges.

There is no number 3.

5. What else could it be?

Once you've ruled out basic fitness-for-purpose issues with the algorithms, features, and parsing methods, added data but still seen no improvement in the model's ability to generalize, balanced your dataset, and satisfied yourself that the model isn't overtrained, it's time for a deeper dive.

Try to identify common factors in data where the model fails. Take a closer look at feature selection. You might explore k-fold validation using hold-outs grouped by metadata to determine if there is an explainable factor. Perhaps do some clustering or anomaly detection on the data set to see if there are inherent differences between types of data that might benefit from different sub-models. (Reality AI supports this.)
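For instance, the metadata-grouped validation idea might look something like this sketch, using scikit-learn's GroupKFold and a hypothetical per-observation "session" label standing in for whatever metadata you actually have:

    import numpy as np
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import GroupKFold, cross_val_score

    # Placeholder data: 600 observations collected in 6 sessions (metadata).
    rng = np.random.default_rng(4)
    X = rng.normal(size=(600, 15))
    y = rng.integers(0, 2, size=600)
    sessions = np.repeat(np.arange(6), 100)   # hypothetical metadata grouping

    model = RandomForestClassifier(random_state=0)

    # Each fold holds out whole sessions, so the model is always scored on
    # circumstances (sessions) it never saw during training.
    scores = cross_val_score(model, X, y, cv=GroupKFold(n_splits=6), groups=sessions)
    print("per-session holdout accuracy:", np.round(scores, 3))
    # A session scoring far below the rest points to an explainable factor.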

But this stuff is wasted effort until you've gone through steps 1-4 – in our experience, these address most of our customers' problems. You don't need a PhD in machine learning to diagnose and deal with most machine learning model problems. Just a little detective work, some common sense about data, and these four steps can get you a very long way.

Have you encountered other issues with your machine learning model? Do you want to share a different diagnosis approach? Let us know in the comment section below.
