The best machine learning algorithm for classification and regression

[Please note that I’m still very much a novice in this field – and that I change my mind about things often]

I had a hard time naming this post. Here are a few other titles I could have used:

Why I think Random/Decision Forests are the best machine learning algorithm.

I know there are exceptions to this; there exist scenarios where this title is not true. But rather than giving the vague, unhelpful answer of “it depends”, here’s why I think that Random Forests should be your first and default choice when choosing a machine learning algorithm for classification and/or regression.

Here’s a working list (in no particular order) of why I really like working with Random/Decision forests:

1. Probabilistic output: Random forests can output a probability, or a confidence in the prediction. If you have an answer, how much nicer is it to associate a confidence with that answer!
2. Intuitive approach: Random forests are essentially a bunch of decision trees, where each tree is trained on a random resample of the data and considers a random subset of the features at each split. How much more intuitive can that be?
3. Built-in multi-class support: You are not limited to a yes/no or 2-class problem. Just by adding other labels to your training data, you can predict more than 2 class labels.
4. Examine the decision making and see what features are important: You can actually examine how the Random Forest makes its decisions by looking at what features are split in each tree. This helps you understand what features are useful in splitting your data.
5. Simple to switch between classification and regression: In MATLAB, this is almost as simple as changing a single parameter of the TreeBagger class.
6. Robustly represents spaces: See this really nice post that gives an intuitive visualization of how well different machine learning approaches can split the data. See how well random forests do compared with other popular approaches.
7. Not much tuning of parameters: Sure, there are a few key parameters you need to set. For example, you need to set the number of trees, but this is fairly simple to do by plotting the out-of-bag error against the number of trees and seeing where it stabilizes.
8. Works fine with categorical/nominal features/variables: If you have features/variables that are categories (labels without meaning in how they are ordered) you can use them directly. There is no need to drastically increase your feature vector by doing a 1-of-k coding scheme.
9. No requirement to scale different features: If you have features with very large values, and different features with very small values, that is okay! Plug them in. (SVMs generally require scaling first.)
10. Can be used for soooooo many things: See this work by A. Criminisi, J. Shotton and E. Konukoglu.
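
To make a few of these points concrete (1, 2, 3 and 9 in particular), here is a minimal sketch of a toy random forest in plain Python. It is not how TreeBagger is implemented, just an illustration under my own assumptions: the function names (grow_forest, predict_proba), the n_try parameter, the depth limit and the three-class toy data are all made up for this example. Each tree is grown on a bootstrap resample, each split only looks at a random subset of the features, and the "probability" is simply the fraction of trees voting for each class.

```python
import random
from collections import Counter

def gini(labels):
    """Gini impurity of a list of class labels."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def best_split(rows, labels, features):
    """Return the (feature, threshold) that lowers weighted Gini the most, or None."""
    best, best_score = None, gini(labels)
    for f in features:
        for t in sorted({r[f] for r in rows}):
            left = [y for r, y in zip(rows, labels) if r[f] <= t]
            right = [y for r, y in zip(rows, labels) if r[f] > t]
            if not left or not right:
                continue
            score = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)
            if score < best_score:
                best, best_score = (f, t), score
    return best

def grow_tree(rows, labels, rng, n_try, depth=0, max_depth=5):
    """Grow one tree; every split considers only n_try randomly chosen features."""
    if depth >= max_depth:
        return Counter(labels)  # leaf: class counts
    features = rng.sample(range(len(rows[0])), n_try)
    split = best_split(rows, labels, features)
    if split is None:
        return Counter(labels)  # no split improves purity: make a leaf
    f, t = split
    go_left = [r[f] <= t for r in rows]
    left = grow_tree([r for r, g in zip(rows, go_left) if g],
                     [y for y, g in zip(labels, go_left) if g],
                     rng, n_try, depth + 1, max_depth)
    right = grow_tree([r for r, g in zip(rows, go_left) if not g],
                      [y for y, g in zip(labels, go_left) if not g],
                      rng, n_try, depth + 1, max_depth)
    return (f, t, left, right)

def tree_predict(tree, row):
    """Walk down to a leaf and return its majority class."""
    while isinstance(tree, tuple):
        f, t, left, right = tree
        tree = left if row[f] <= t else right
    return tree.most_common(1)[0][0]

def grow_forest(rows, labels, n_trees=50, n_try=1, seed=0):
    """Grow n_trees trees, each on its own bootstrap resample of the data."""
    rng = random.Random(seed)
    forest = []
    for _ in range(n_trees):
        idx = [rng.randrange(len(rows)) for _ in rows]  # sample with replacement
        forest.append(grow_tree([rows[i] for i in idx],
                                [labels[i] for i in idx], rng, n_try))
    return forest

def predict_proba(forest, row):
    """Fraction of trees voting for each class."""
    votes = Counter(tree_predict(tree, row) for tree in forest)
    return {label: n / len(forest) for label, n in votes.items()}

# Made-up toy data: three classes, one tiny-valued and one huge-valued feature.
data_rng = random.Random(1)
centers = {"a": (0.002, 1000.0), "b": (0.008, 1000.0), "c": (0.005, 9000.0)}
rows, labels = [], []
for label, (c0, c1) in centers.items():
    for _ in range(30):
        rows.append((c0 + data_rng.uniform(-0.001, 0.001),
                     c1 + data_rng.uniform(-500.0, 500.0)))
        labels.append(label)

forest = grow_forest(rows, labels)
proba = predict_proba(forest, (0.002, 1000.0))
print(proba)

# Rescale the large feature (by a power of two, so the rescaling is exact in
# floating point) and retrain with the same seed: the votes do not change.
scaled_rows = [(r[0], r[1] / 1024) for r in rows]
scaled_forest = grow_forest(scaled_rows, labels)
scaled_proba = predict_proba(scaled_forest, (0.002, 1000.0 / 1024))
print(scaled_proba == proba)
```

Nothing here had to be told that there are three classes, or that one feature is about a million times larger than the other. The splits only care about the ordering of values within a feature, which is why rescaling one feature leaves the predictions unchanged. A real implementation such as TreeBagger adds much more, out-of-bag error estimates and feature importance among other things.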

Nice rant Jer, but do you have anything supporting your “best” claim?

Sort of… “best” is a hard thing to claim. It turns out that there exists some peer-reviewed work supporting my unsupported claims that Random Forests are good (and very close to the best). See this work by Caruana and Niculescu-Mizil entitled “An Empirical Comparison of Supervised Learning Algorithms”.

The authors’ conclusions are a bit hidden, but if you read carefully you can find statements like:

… calibrated boosted trees were the best learning algorithm overall. Random forests are close second …

and since calibration can be a bit of a pain, it’s nice to see that…

… If not calibrated, the best models overall are bagged trees, random forests, and neural nets…

and they summarize that among the tested supervised machine learning algorithms…

…it appears to be a clean sweep for ensembles of trees…

Interesting, so what are your practical experiences with random forests?

So far, I’ve used Random Forests for delineating the spinal cord in MRIs, and predicting the disability level of multiple sclerosis patients based on spinal cord features.

Alright, you have me somewhat curious… how do I get started with random forests?

Now I’m sure you are excited. So, to help you get started predicting things, here’s a simple MATLAB example that uses Random Forests (called TreeBagger in MATLAB) for classification.