Now that we have learned several methods and explored them with illustrative examples, we are going to try them out on a real example: the MNIST digits.
We can load this data using the dslabs package:
The dataset includes two components, a training set and a test set:
Each of these components includes a matrix with features in the columns:
and a vector with the classes as integers:
Because we want this example to run on a small laptop and in less than one hour, we will consider a subset of the dataset. We will sample 10,000 random rows from the training set and 1,000 random rows from the test set:
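The subsetting pattern can be sketched in base R. Loading MNIST requires dslabs and a download, so this sketch assumes a small simulated matrix x_full and label vector y_full (hypothetical stand-ins for the training images and labels, with made-up sizes). Only the indexing carries over to the real data.

```r
set.seed(1990)

# Simulated stand-in for the MNIST training set: 600 rows of 784 pixel intensities
n_full <- 600
x_full <- matrix(sample(0:255, n_full * 784, replace = TRUE), nrow = n_full, ncol = 784)
y_full <- sample(0:9, n_full, replace = TRUE)

# Draw row indices once and use them for both features and labels,
# so each image stays paired with its class
index <- sample(nrow(x_full), 100)
x <- x_full[index, ]
y <- factor(y_full[index])
```

The test set is subsampled the same way, with its own index vector.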
In machine learning, we often transform predictors before running the machine learning algorithm. We also remove predictors that are clearly not useful. We call these steps preprocessing.
Examples of preprocessing include standardizing the predictors, taking the log transform of some predictors, removing predictors that are highly correlated with others, and removing predictors with very few unique values or close to zero variation. We show an example below.
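Three of these steps can be sketched in base R on simulated predictors (the column names and the 0.9 correlation cutoff are arbitrary choices for illustration). The steps are a log transform of a skewed predictor, dropping one predictor from each highly correlated pair, and standardizing. caret provides its own helpers for this kind of work, such as preProcess and findCorrelation, which this sketch does not reproduce.

```r
set.seed(1)
n <- 100
z <- rnorm(n)
x <- cbind(
  height = rnorm(n, mean = 170, sd = 10),   # predictor on a large scale
  income = rexp(n, rate = 1 / 50000),       # right-skewed predictor
  z1 = z,
  z2 = z + rnorm(n, sd = 0.01)              # nearly a copy of z1
)

# Log transform the skewed predictor (log1p is safe at zero)
x[, "income"] <- log1p(x[, "income"])

# For each pair with absolute correlation above 0.9, drop the later column
cors <- abs(cor(x))
cors[upper.tri(cors, diag = TRUE)] <- 0
to_drop <- colnames(x)[apply(cors, 1, function(r) any(r > 0.9))]
x <- x[, setdiff(colnames(x), to_drop), drop = FALSE]

# Standardize: center each column to mean 0 and scale to standard deviation 1
x_std <- scale(x)
```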
We can run the nearZeroVar function from the caret package to see that several features do not vary much from observation to observation. In fact, a large number of features have zero variability:
This is expected because there are parts of the image that rarely contain writing (dark pixels).
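The zero-variability check amounts to computing each column's standard deviation. This base-R sketch uses simulated 28 x 28 "images" in which, by assumption, only a central 14 x 14 block ever contains ink. That mimics the blank borders of MNIST digits; the sizes are made up. For large matrices, matrixStats::colSds is a faster alternative to apply, and hist(sds) shows the spike at zero.

```r
set.seed(2)
n <- 200

# Simulated images: only the central 14 x 14 block ever contains ink
ink <- matrix(FALSE, 28, 28)
ink[8:21, 8:21] <- TRUE
x <- matrix(0, nrow = n, ncol = 28 * 28)
x[, which(ink)] <- sample(0:255, n * sum(ink), replace = TRUE)

# Standard deviation of each pixel (column) across images
sds <- apply(x, 2, sd)
sum(sds == 0)   # 588 border pixels never change
```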
The caret package includes a function that recommends features to be removed due to near zero variance:
We can see the columns recommended for removal:
So we end up keeping this number of columns:
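This simplified base-R sketch makes the removal rule concrete. It approximates the criterion described in caret's documentation for nearZeroVar. A predictor is flagged if it takes a single value. It is also flagged if its most common value is more than freqCut times as frequent as the second most common (default 95/5) and its distinct values make up no more than uniqueCut percent of the observations (default 10). The helper near_zero_var and the toy columns are hypothetical; in practice, use caret's function, which handles cases this sketch ignores.

```r
# Hypothetical helper: simplified version of the near-zero-variance rule
near_zero_var <- function(x, freq_cut = 95 / 5, unique_cut = 10) {
  flagged <- apply(x, 2, function(col) {
    counts <- sort(table(col), decreasing = TRUE)
    if (length(counts) == 1) return(TRUE)              # zero variance
    freq_ratio <- counts[[1]] / counts[[2]]            # most common vs. runner-up
    pct_unique <- 100 * length(counts) / length(col)   # distinct values, in percent
    freq_ratio > freq_cut && pct_unique <= unique_cut
  })
  which(flagged)
}

x <- cbind(
  blank  = rep(0, 100),              # never varies
  rare   = c(rep(0, 99), 255),       # almost always 0
  varied = rep(1:10, times = 10)     # ten values, equally common
)
nzv <- near_zero_var(x)

# Keep every column that was not flagged
col_index <- setdiff(seq_len(ncol(x)), nzv)
length(col_index)
```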
Now we are ready to fit some models. Before we start, we need to add column names to the feature matrices as these are required by caret:
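Naming the columns can be sketched on small random matrices (the sizes are arbitrary). Numbering columns by position before dropping any of them has a useful side effect: retained columns keep their original positions as names. Giving the test matrix the same names lets features be matched at prediction time.

```r
x <- matrix(runif(4 * 5), nrow = 4)
x_test <- matrix(runif(2 * 5), nrow = 2)

# Name columns by position; the test matrix gets identical names
colnames(x) <- seq_len(ncol(x))
colnames(x_test) <- colnames(x)

# After subsetting, the kept columns still carry their original numbers
kept <- x[, c(2, 5)]
colnames(kept)
```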
Let’s start with kNN. The first step is to optimize for \(k\). Keep in mind that when we run the algorithm, we will have to compute a distance between each observation in the test set and each observation in the training set. That is a lot of computation, so we will use k-fold cross validation, which needs fewer model fits than caret's default of 25 bootstrap samples, to improve speed.
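This from-scratch base-R sketch shows where the cost comes from. It runs kNN with k-fold cross validation on simulated two-class data. The helpers knn_predict and cv_knn are hypothetical and are not caret's implementation of method = "knn". They only show the structure: every held-out row is compared with every training row, once per fold and per candidate k.

```r
set.seed(3)

# Hypothetical helper: k-nearest-neighbor classification by majority vote
knn_predict <- function(train_x, train_y, test_x, k) {
  # Squared Euclidean distance between every test row and every training row
  d <- outer(rowSums(test_x^2), rowSums(train_x^2), "+") - 2 * test_x %*% t(train_x)
  apply(d, 1, function(dist_row) {
    nearest <- order(dist_row)[seq_len(k)]
    votes <- table(train_y[nearest])
    names(votes)[which.max(votes)]
  })
}

# Hypothetical helper: cross-validated accuracy for each candidate k
# (the number of folds is unrelated to the k of kNN)
cv_knn <- function(x, y, ks, folds = 5) {
  fold_id <- sample(rep(seq_len(folds), length.out = nrow(x)))
  sapply(ks, function(k) {
    fold_acc <- sapply(seq_len(folds), function(f) {
      held_out <- fold_id == f
      pred <- knn_predict(x[!held_out, , drop = FALSE], y[!held_out],
                          x[held_out, , drop = FALSE], k)
      mean(pred == as.character(y[held_out]))
    })
    mean(fold_acc)
  })
}

# Two well-separated simulated classes in two dimensions
n <- 100
x <- rbind(matrix(rnorm(n, mean = 0), ncol = 2),
           matrix(rnorm(n, mean = 4), ncol = 2))
y <- factor(rep(c("a", "b"), each = n / 2))

ks <- c(1, 3, 5, 7)
accuracy <- cv_knn(x, y, ks)
ks[which.max(accuracy)]
```

Recomputing the distances for every k is wasteful. A more efficient version would compute the distance matrix once per fold and reuse it across all candidate values of k.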
If we run the following code, the computing time on a standard laptop will be several minutes.
In general, it is a good idea to try a test run with a subset of the data to get an idea of timing before we start running code that might take hours to complete. We can do this as follows:
We can then increase the size of the subset and b, and try to establish a pattern of how they affect computing time. This gives an idea of how long the fitting process will take for the full data and larger values of b. You want to know if a function is going to take hours, or even days, before you run it.
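Timing runs like this can be scripted with base R's system.time. This sketch times only a pairwise distance computation (dist) on simulated data, as a stand-in for the expensive part of kNN. The function name time_distances and the sizes are illustrative assumptions. The number of pairs grows with the square of n, so doubling n should roughly quadruple the elapsed time. Checking that pattern on small runs helps before scaling up.

```r
set.seed(4)

# Time one pairwise-distance computation for n simulated observations with p features
time_distances <- function(n, p = 100) {
  x <- matrix(rnorm(n * p), nrow = n, ncol = p)
  system.time(dist(x))[["elapsed"]]
}

ns <- c(500, 1000, 2000)
timings <- sapply(ns, time_distances)
data.frame(n = ns, seconds = timings)
```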
Once we optimize our algorithm, we can fit it to the entire dataset:
We now achieve a high accuracy:
From the sensitivity and specificity, we also see that 8s are the hardest to detect (they have the lowest sensitivity) and that 7 is the most commonly incorrectly predicted digit (it has the lowest specificity).
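Per-class sensitivity and specificity can be computed directly from a confusion matrix. This base-R sketch uses ten made-up predictions for three digits; they are not results from the kNN fit. Sensitivity is TP / (TP + FN), the fraction of an actual class that was detected. Specificity is TN / (TN + FP), the fraction of the other classes that were not predicted as that class.

```r
# Ten made-up test labels and predictions for three digits
lev <- c("1", "7", "8")
truth <- factor(c("1", "1", "1", "7", "7", "7", "8", "8", "8", "8"), levels = lev)
pred  <- factor(c("1", "1", "7", "7", "7", "7", "8", "8", "1", "7"), levels = lev)

# Confusion matrix: rows are predictions, columns are actual classes
cm <- table(predicted = pred, actual = truth)

accuracy <- sum(diag(cm)) / sum(cm)

# Sensitivity per class: correct predictions / all observations of that class
sensitivity <- setNames(diag(cm) / colSums(cm), lev)

# Specificity per class: among observations of other classes,
# the fraction not predicted as this class
specificity <- setNames(sapply(seq_along(lev), function(i) {
  tn <- sum(cm[-i, -i])   # other classes, predicted as other classes
  fp <- sum(cm[i, -i])    # other classes, predicted as class i
  tn / (tn + fp)
}), lev)

round(rbind(sensitivity, specificity), 2)
```

In this toy example, 8 has the lowest sensitivity and 7 the lowest specificity.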
Now let’s see if we can do even better with the random forest algorithm.
With random forest, computation time is a challenge. For each forest, we need to build hundreds of trees. We also have several parameters we can tune.
Because with random forest the fitting, rather than the predicting (as with kNN), is the slowest part of the procedure, we will use only five-fold cross validation. We will also reduce the number of trees that are fit, since we are not yet building our final model.
Finally, to compute on a smaller dataset, we will take a random sample of the observations when constructing each tree. We can change this number with the
Now that we have optimized our algorithm, we are ready to fit our final model:
To check that we ran enough trees, we can use the plot function:
We see that we achieve high accuracy:
With some further tuning, we can get even higher accuracy.
The following function computes the importance of each feature:
We can see which features are being used most by plotting an image:
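Permutation importance is one generic way to measure how much a model relies on each feature: shuffle one feature at a time and record the drop in accuracy. This base-R sketch swaps in a logistic regression (glm) on simulated data where only x1 matters. Random-forest software may compute its importance values differently, so this is only an illustration of the idea. With 784 pixel features, the resulting vector could be reshaped with matrix(importance, 28, 28) and drawn with image().

```r
set.seed(5)
n <- 500
x <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))

# Only x1 drives the outcome; x2 and x3 are pure noise
y <- factor(ifelse(x$x1 + rnorm(n, sd = 0.3) > 0, "yes", "no"), levels = c("no", "yes"))
fit <- glm(y ~ x1 + x2 + x3, data = data.frame(x, y = y), family = binomial)

predict_class <- function(newx) {
  p_yes <- predict(fit, newdata = newx, type = "response")
  factor(ifelse(p_yes > 0.5, "yes", "no"), levels = levels(y))
}

base_acc <- mean(predict_class(x) == y)

# Shuffle one feature at a time and record how much accuracy drops
importance <- sapply(names(x), function(v) {
  shuffled <- x
  shuffled[[v]] <- sample(shuffled[[v]])
  base_acc - mean(predict_class(shuffled) == y)
})
importance
```

For simplicity this evaluates on the training data; measuring the drop on held-out data gives a more honest estimate.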
An important part of data analysis is visualizing results to determine why we are failing. How we do this depends on the application. Below we show the images of digits for which we made an incorrect prediction. We can compare what we get with kNN to random forest.
Here are some errors for the random forest:
By examining errors like this, we often find specific weaknesses of algorithms or parameter choices and can try to correct them.
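One way to decide which errors to look at first is to sort the misclassified observations by how confident the model was. This base-R sketch assumes the class probabilities are in a matrix with one row per test image. The values are made up; a data frame of probabilities can be converted with as.matrix. With real data, the indices in ind would then be used to display the corresponding images.

```r
# Made-up class probabilities for five test images (rows) and three digits (columns)
probs <- rbind(c(0.90, 0.05, 0.05),
               c(0.20, 0.70, 0.10),
               c(0.10, 0.10, 0.80),
               c(0.60, 0.30, 0.10),
               c(0.20, 0.45, 0.35))
colnames(probs) <- c("1", "7", "8")
truth <- c("1", "1", "8", "7", "8")

# Predicted class = column with the highest probability; confidence = that probability
pred <- colnames(probs)[max.col(probs, ties.method = "first")]
p_max <- apply(probs, 1, max)

# Indices of the errors, most confident mistakes first
ind <- which(pred != truth)
ind <- ind[order(p_max[ind], decreasing = TRUE)]
ind   # 2 4 5
```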
The idea of an ensemble is similar to the idea of combining data from different pollsters to obtain a better estimate of the true support for each candidate.
In machine learning, one can often greatly improve the final results by combining the results of different algorithms.
Here is a simple example where we compute new class probabilities by taking the average of random forest and kNN. We can see that the accuracy improves to 0.96:
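The averaging step can be sketched in base R with two small probability matrices. The values below are made up and chosen so that each model makes one mistake that the other model gets right. In that situation, the average can beat both models. The helper predict_from is hypothetical.

```r
# Made-up class probabilities from two models for four observations
p_rf  <- rbind(c(0.70, 0.30), c(0.45, 0.55), c(0.20, 0.80), c(0.70, 0.30))
p_knn <- rbind(c(0.40, 0.60), c(0.80, 0.20), c(0.30, 0.70), c(0.90, 0.10))
classes <- c("a", "b")
truth <- c("a", "a", "b", "a")

# Predicted class = column with the largest probability
predict_from <- function(p) classes[max.col(p, ties.method = "first")]

# Ensemble: average the two probability matrices, then take the most likely class
p_ensemble <- (p_rf + p_knn) / 2

acc <- c(rf = mean(predict_from(p_rf) == truth),
         knn = mean(predict_from(p_knn) == truth),
         ensemble = mean(predict_from(p_ensemble) == truth))
acc
```

Averaging only helps when the models make different mistakes. If both models erred on the same observations, the ensemble would inherit those errors.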
In the exercises we are going to build several machine learning models for the
mnist_27 dataset and then build an ensemble.
1. Use the mnist_27 training set to build several of the models available from the caret package. For example, you can try these:
We have not explained many of these, but apply them anyway using
train with all the default parameters. Keep the results in a list. You might need to install some packages. Keep in mind that you will likely get some warnings.
2. Now that you have all the trained models in a list, use map to create a matrix of predictions for the test set. You should end up with a matrix with length(mnist_27$test$y) rows and one column per model.
3. Now compute accuracy for each model on the test set.
4. Now build an ensemble prediction by majority vote and compute its accuracy.
5. Earlier we computed the accuracy of each method on the test set and noticed they varied. Which individual methods do better than the ensemble?
6. It is tempting to remove the methods that do not perform well and re-do the ensemble. The problem with this approach is that we are using the test data to make a decision. However, we could use the accuracy estimates obtained from cross validation with the training data. Obtain these estimates and save them in an object.
7. Now let’s only consider the methods with an estimated accuracy of at least 0.8 when constructing the ensemble. What is the accuracy now?
8. Advanced: If two methods give results that are the same, ensembling them will not change the results at all. For each pair of methods, compare the percentage of time they make the same prediction. Then use the
heatmap function to visualize the results. Hint: use the
method = "binary" argument in the
9. Advanced: Note that each method can also produce an estimated conditional probability. Instead of majority vote, we can take the average of these estimated conditional probabilities. For most methods, we can use the
type = "prob" argument in the predict function. However, some of the methods require you to use the argument
trControl=trainControl(classProbs=TRUE) when calling train. Also, these methods do not work if classes have numbers as names. Hint: change the levels like this:
10. In this chapter, we illustrated a couple of machine learning algorithms on a subset of the MNIST dataset. Try fitting a model to the entire dataset.
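For exercise 4, a majority vote over a matrix of predictions (one row per observation, one column per model) can be sketched in base R. The predictions below are made up and are not from the mnist_27 models, so this shows the mechanics rather than the answer. With an even number of models, ties are possible; which.max then picks the first class in sorted order, which is an arbitrary rule.

```r
# Made-up predictions from three models (columns) for four observations (rows)
preds <- cbind(glm = c("2", "7", "7", "2"),
               knn = c("2", "2", "7", "7"),
               rf  = c("7", "2", "7", "2"))
truth <- c("2", "2", "7", "2")

# Majority vote: the most frequent prediction in each row
vote <- apply(preds, 1, function(row) names(which.max(table(row))))

colMeans(preds == truth)   # accuracy of each model
mean(vote == truth)        # accuracy of the vote
```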