Here we use a deep neural network to classify the well-known Iris flower data set, collected by Edgar Anderson and popularized by Ronald Fisher. The data set contains 150 observations of iris flowers, each consisting of four measurements (sepal length, sepal width, petal length, and petal width) together with a classification into one of three species (I. setosa, I. versicolor, and I. virginica).
Here we repeat the classic task for which this data set is used: predicting the species from the four measured quantities.
Training and Test Data
We have divided the data into training and test sets: the former is used to build the model, the latter to measure its predictive accuracy.
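The text does not say how the split was made. The following is a minimal plain-Python sketch, assuming a random shuffle with a fixed seed so the split is reproducible. The function name `train_test_split` and its parameters are illustrative choices, not the code used in this example.

```python
import random

def train_test_split(rows, n_test, seed=0):
    """Shuffle a copy of `rows` and split off `n_test` rows as the test set."""
    shuffled = list(rows)                  # copy so the caller's list is untouched
    random.Random(seed).shuffle(shuffled)  # fixed seed gives a reproducible split
    return shuffled[n_test:], shuffled[:n_test]

# Placeholder rows standing in for the 150 Iris observations
rows = list(range(150))
train, test = train_test_split(rows, n_test=30)
print(len(train), len(test))  # 120 30
```

Shuffling before splitting matters for Iris in particular, because the usual file lists the species in order. Taking the last 30 rows without shuffling would produce a test set containing only one species.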
This data set has 150 samples (120 for training and 30 for testing), and the Species column contains three distinct species.
To simplify things, we replace the strings designating the species with the numbers 0, 1, and 2 (corresponding to setosa, versicolor, and virginica, respectively).
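The encoding step amounts to a fixed lookup table. Below is a small Python sketch, assuming the Species column holds the bare names `setosa`, `versicolor`, and `virginica`. Some copies of the data set spell them `Iris-setosa` and so on. The helper names are hypothetical.

```python
SPECIES = ["setosa", "versicolor", "virginica"]
# setosa -> 0, versicolor -> 1, virginica -> 2
SPECIES_TO_CODE = {name: code for code, name in enumerate(SPECIES)}

def encode_species(labels):
    """Replace species strings with their integer class codes."""
    return [SPECIES_TO_CODE[label] for label in labels]

def decode_species(codes):
    """Map integer class codes back to species names."""
    return [SPECIES[code] for code in codes]

print(encode_species(["virginica", "setosa", "versicolor"]))  # [2, 0, 1]
```

Keeping the list and the dictionary side by side makes the mapping reversible. That reversibility is needed later to turn a predicted class number back into a species name.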
Training the Deep Neural Network Model
With our data prepared, we can now actually define and train the model.
Our first step is to define a feature for each of the four measured quantities, that is, every column of the data except the last one (Species), which we aim to predict.
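The rule "every column except Species is a feature" can be expressed directly. This Python sketch assumes each row is a dictionary keyed by column name. The column names and values are placeholders, not the document's actual data.

```python
def split_features_and_labels(rows, label="Species"):
    """Turn dict rows into feature names, a feature matrix X, and a label list y."""
    # Every column except the label becomes a feature. The order follows the
    # first row's keys, because dicts keep insertion order in Python 3.7+.
    feature_names = [name for name in rows[0] if name != label]
    X = [[float(row[name]) for name in feature_names] for row in rows]
    y = [row[label] for row in rows]
    return feature_names, X, y

# Hypothetical column names and placeholder values
rows = [
    {"SepalLength": 5.0, "SepalWidth": 3.4, "PetalLength": 1.5, "PetalWidth": 0.2, "Species": 0},
    {"SepalLength": 6.4, "SepalWidth": 3.1, "PetalLength": 5.5, "PetalWidth": 1.8, "Species": 2},
]
names, X, y = split_features_and_labels(rows)
print(names)  # ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth']
```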
We can now define a deep neural network classifier with these features. It has three classes because there are three species of iris in the data set.
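To show what such a classifier computes, here is a minimal forward pass in plain Python. It uses four inputs, several hidden layers, and a three-way softmax output. The hidden sizes `[10, 20, 10]`, the ReLU activation, and the weight initialization are assumptions made for illustration. They are not taken from the classifier defined above. The network is untrained, so its probabilities carry no meaning yet.

```python
import math
import random

def init_network(layer_sizes, seed=0):
    """Create (weights, biases) for each layer, with small random weights."""
    rng = random.Random(seed)
    layers = []
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:]):
        weights = [[rng.gauss(0.0, 0.5) for _ in range(n_in)] for _ in range(n_out)]
        layers.append((weights, [0.0] * n_out))
    return layers

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def forward(layers, x):
    """ReLU on the hidden layers, softmax on the output layer."""
    activation = x
    for i, (weights, biases) in enumerate(layers):
        z = [sum(w * a for w, a in zip(row, activation)) + b
             for row, b in zip(weights, biases)]
        is_output = i == len(layers) - 1
        activation = z if is_output else [max(0.0, v) for v in z]
    return softmax(activation)

# 4 input features -> hidden layers (sizes assumed) -> 3 output classes
net = init_network([4, 10, 20, 10, 3])
probs = forward(net, [5.1, 3.5, 1.4, 0.2])
print(probs)
```

The output layer has exactly three units, one per species. The softmax turns their scores into the per-class probabilities discussed at the end of this example.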
We are now ready to train the model.
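The text does not show the training procedure. One standard approach minimizes softmax cross-entropy by gradient descent with backpropagation. The Python sketch below assumes exactly that, and for brevity it uses a single ReLU hidden layer rather than a deeper stack. The learning rate, epoch count, and toy data are hypothetical placeholders, not Iris measurements.

```python
import math
import random

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def train(X, y, n_hidden=8, n_classes=3, lr=0.05, epochs=200, seed=0):
    """Full-batch gradient descent on softmax cross-entropy (one hidden layer)."""
    rng = random.Random(seed)
    n_in, n = len(X[0]), len(X)
    W1 = [[rng.gauss(0, 0.5) for _ in range(n_in)] for _ in range(n_hidden)]
    b1 = [0.0] * n_hidden
    W2 = [[rng.gauss(0, 0.5) for _ in range(n_hidden)] for _ in range(n_classes)]
    b2 = [0.0] * n_classes
    history = []
    for _ in range(epochs):
        gW1 = [[0.0] * n_in for _ in range(n_hidden)]
        gb1 = [0.0] * n_hidden
        gW2 = [[0.0] * n_hidden for _ in range(n_classes)]
        gb2 = [0.0] * n_classes
        loss = 0.0
        for x, t in zip(X, y):
            # Forward pass: ReLU hidden layer, softmax output
            z1 = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(W1, b1)]
            h = [max(0.0, v) for v in z1]
            p = softmax([sum(w * hj for w, hj in zip(row, h)) + b for row, b in zip(W2, b2)])
            loss -= math.log(p[t])
            # Backward pass: d(loss)/d(logits) = p - one_hot(t)
            d2 = [pk - (1.0 if k == t else 0.0) for k, pk in enumerate(p)]
            for k in range(n_classes):
                gb2[k] += d2[k]
                for j in range(n_hidden):
                    gW2[k][j] += d2[k] * h[j]
            # Propagate through W2 and the ReLU derivative
            d1 = [sum(d2[k] * W2[k][j] for k in range(n_classes)) * (1.0 if z1[j] > 0 else 0.0)
                  for j in range(n_hidden)]
            for j in range(n_hidden):
                gb1[j] += d1[j]
                for i in range(n_in):
                    gW1[j][i] += d1[j] * x[i]
        history.append(loss / n)
        # Step against the mean gradient over the batch
        for k in range(n_classes):
            b2[k] -= lr * gb2[k] / n
            for j in range(n_hidden):
                W2[k][j] -= lr * gW2[k][j] / n
        for j in range(n_hidden):
            b1[j] -= lr * gb1[j] / n
            for i in range(n_in):
                W1[j][i] -= lr * gW1[j][i] / n
    return (W1, b1, W2, b2), history

# Toy placeholder data: three small clusters in 4-D, labelled 0, 1, 2
X = [[0, 0, 0, 0], [0.1, 0, 0, 0.1], [1, 1, 0, 0],
     [1, 0.9, 0.1, 0], [0, 0, 1, 1], [0.1, 0, 1, 0.9]]
y = [0, 0, 1, 1, 2, 2]
params, history = train(X, y)
print(f"mean loss: {history[0]:.3f} -> {history[-1]:.3f}")
```

The falling mean loss in `history` is the signal that training is making progress. A real run would also monitor loss or accuracy on held-out data.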
Once the model is trained, we can evaluate the classifier on the test set, where it achieves 96.7% predictive accuracy (29 of the 30 test samples classified correctly).
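Predictive accuracy is the fraction of test samples whose predicted class matches the true class. With 30 test samples, 96.7% corresponds to 29/30. The Python sketch below checks this arithmetic. The labels are hypothetical and the single misclassification is invented for illustration.

```python
def accuracy(predicted, actual):
    """Fraction of predictions that match the true class labels."""
    if len(predicted) != len(actual) or not actual:
        raise ValueError("need two non-empty sequences of equal length")
    correct = sum(p == a for p, a in zip(predicted, actual))
    return correct / len(actual)

# Hypothetical test labels with one misclassification out of 30
actual = [0] * 10 + [1] * 10 + [2] * 10
predicted = list(actual)
predicted[15] = 2  # a versicolor sample wrongly predicted as virginica
print(f"{accuracy(predicted, actual):.1%}")  # 96.7%
```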
Using a Trained Model
We can now build a predictor function that takes an arbitrary set of measurements as a DataSeries and returns a prediction.
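The key job of such a predictor is to put named measurements into the same order the features had during training. This Python sketch assumes a dictionary in place of a DataSeries and any callable `model` that accepts a feature list. The feature names and `make_predictor` are hypothetical.

```python
FEATURES = ["SepalLength", "SepalWidth", "PetalLength", "PetalWidth"]

def make_predictor(model, feature_names=FEATURES):
    """Wrap `model` so it accepts measurements keyed by name."""
    def predict(measurements):
        missing = [f for f in feature_names if f not in measurements]
        if missing:
            raise KeyError(f"missing measurements: {missing}")
        # Order the inputs exactly as the features were defined for training
        x = [float(measurements[f]) for f in feature_names]
        return model(x)
    return predict

# A stand-in "model" that echoes its input, to show the reordering
echo = make_predictor(lambda x: x)
print(echo({"PetalWidth": 0.2, "SepalLength": 5.1, "SepalWidth": 3.5, "PetalLength": 1.4}))
# [5.1, 3.5, 1.4, 0.2]
```

Looking values up by name means the caller can supply the measurements in any order. A missing measurement raises an error instead of silently shifting the remaining values into the wrong inputs.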
With this predictor we can take an arbitrary new data point and generate a prediction from the trained model.
The probabilities field of the prediction result records the estimated probability of each class.
In this case, the model predicts with high probability that this particular sample is class 2, and therefore I. virginica.
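Reading off the prediction means taking the class with the largest probability and decoding its index with the 0, 1, 2 mapping defined earlier. The Python sketch below does this. The probability values in the usage line are illustrative only and are not the model's actual output.

```python
SPECIES = ["I. setosa", "I. versicolor", "I. virginica"]

def interpret(probabilities):
    """Return (class index, species name, probability) for the most probable class."""
    best = max(range(len(probabilities)), key=lambda k: probabilities[k])
    return best, SPECIES[best], probabilities[best]

# Illustrative probabilities only
print(interpret([0.01, 0.04, 0.95]))  # (2, 'I. virginica', 0.95)
```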