The Data Science Lab
Nearest Centroid Classification for Numeric Data Using C#
Here's a complete end-to-end demo of what Dr. James McCaffrey of Microsoft Research says is arguably the simplest possible classification technique.
The goal of a machine learning classification system is to predict the value of a discrete variable. For example, you might want to predict the sex (male or female -- binary classification) of a person based on their age, income, and so on. Or you might want to predict the political leaning (conservative or moderate or liberal -- multi-class classification) of a person based on their sex, age, and so on.
There are many machine learning classification techniques. Common techniques include logistic regression, neural network classification, naive Bayes classification, decision tree classification, k-nearest neighbors classification, and several more.
This article presents a complete end-to-end demo of a technique called nearest centroid classification. Briefly, in nearest centroid classification, the vector centroids (also called means or averages) in the training data are computed for each of the classes to predict. To classify a data item, the distance between the item and each centroid is computed. The predicted class is the class associated with the nearest centroid.
Nearest centroid classification is arguably the simplest possible classification technique. Compared to other techniques, four advantages of nearest centroid classification (NCC) are that NCC is easy to implement, NCC can work with very small datasets, NCC is highly interpretable, and NCC works for both binary classification and multi-class classification. Two disadvantages of NCC are that basic NCC works only with strictly numeric predictor variables (although there are newer techniques that modify NCC to work with mixed numeric and categorical predictors), and, most importantly, NCC is the least powerful classification technique because it doesn't take interactions between predictor variables into account.
A good way to see where this article is headed is to take a look at the screenshot in Figure 1. The demo program loads a subset of the Penguin Dataset into memory. The goal is to predict the species of a penguin (0 = Adelie, 1 = Chinstrap, 2 = Gentoo) from its bill length, bill width, flipper length, and body mass. The demo uses just 30 data items for training and 10 items for testing. The predictor values of the first few raw training items are:
[ 0] 50.0 16.3 230.0 5700.0
[ 1] 39.1 18.7 181.0 3750.0
[ 2] 38.8 17.2 180.0 3800.0
[ 3] 39.3 20.6 190.0 3650.0
. . .
The raw penguin data is normalized so that all the predictor values are between 0 and 1. This prevents the body mass predictor variable, which has large magnitude values, from overwhelming the other predictor variables. Data normalization is necessary for nearest centroid classification (and many other classification techniques). The normalized training data predictor values look like:
[ 0] 0.851 0.257 1.000 0.941
[ 1] 0.249 0.600 0.020 0.176
[ 2] 0.232 0.386 0.000 0.196
[ 3] 0.260 0.871 0.200 0.137
. . .
Next, the centroids of each of the three classes/species are computed:
Class centroids:
[0] 0.2700 0.7786 0.2025 0.3174
[1] 0.8432 0.6418 0.2969 0.1931
[2] 0.7103 0.2095 0.6956 0.7996
The resulting model predicts the training data with 0.9333 accuracy (28 out of 30 correct) and the test data with 1.0000 accuracy (10 out of 10 correct).
The demo program concludes by predicting the species of a previously unseen penguin with raw values of bill length = 46.5, bill width = 17.9, flipper length = 192, body mass = 3500. Using the min-max values from the training data, the raw predictor values are normalized to (0.6575, 0.4857, 0.2400, 0.0784). The NCC model predicts that the penguin is class = 1 = Chinstrap because the normalized values are closest to centroid [1], which is (0.8432, 0.6418, 0.2969, 0.1931).
This article assumes you have intermediate or better programming skill but doesn't assume you know anything about nearest centroid classification. The source code is too long to be presented in its entirety in this article, but the complete code and data are in the accompanying file download. You can also find the code and data online. The demo program is implemented using the C# language but you shouldn't have too much trouble refactoring the demo to another C-family language if you wish. The demo program has all normal error checking removed to keep the main ideas as clear as possible.
Understanding Nearest Centroid Classification
Nearest centroid classification is best explained by using a concrete example. Suppose your goal is to predict the species of a toad from its height and weight. You have a small set of training data that has seven items and three classes:
label   ht     wt
  0     56   1800
  0     20   5000
  1     68   1000
  1     74   4200
  1     80   7400
  2     68   9000
  2     44   7400
min     20   1000
max     80   9000
The classes/labels/species are 0 = Asiatic, 1 = Bolivian, 2 = Canadian. The first step is to normalize the data so that the large weight values don't overwhelm the small height values. The demo program uses min-max normalization. When using min-max normalization, for each column, x' = (x - min) / (max - min), where x' is the normalized value, x is the raw value, min is the column minimum value, and max is the column maximum value. For example, the first height value of 56 is normalized to x' = (56 - 20) / (80 - 20) = 36 / 60 = 0.60. The first weight value is normalized to x' = (1800 - 1000) / (9000 - 1000) = 800 / 8000 = 0.10.
The resulting min-max normalized toad data is:
label   ht'    wt'
  0    0.60   0.10
  0    0.00   0.50
  1    0.80   0.00
  1    0.90   0.40
  1    1.00   0.80
  2    0.80   1.00
  2    0.40   0.80
The values in each predictor column will be between 0.0 and 1.0 where 0.0 corresponds to the smallest value in the column and 1.0 corresponds to the largest value in the column.
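To make the computation concrete, here is a minimal C# sketch of min-max normalization applied to the toad data. The variable names are hypothetical (they are not part of the demo program) and the statements are assumed to live inside a Main() method:

double[][] toadX = new double[][] {
  new double[] { 56, 1800 }, new double[] { 20, 5000 },
  new double[] { 68, 1000 }, new double[] { 74, 4200 },
  new double[] { 80, 7400 }, new double[] { 68, 9000 },
  new double[] { 44, 7400 } };
double[] mins = new double[] { 20.0, 1000.0 };   // column minimums
double[] maxs = new double[] { 80.0, 9000.0 };   // column maximums
for (int i = 0; i < toadX.Length; ++i)
  for (int j = 0; j < toadX[i].Length; ++j)
    toadX[i][j] = (toadX[i][j] - mins[j]) /
      (maxs[j] - mins[j]);          // x' = (x - min) / (max - min)
// toadX is now (0.60, 0.10), (0.00, 0.50), . . (0.40, 0.80)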
Next, the centroids for each of the three classes are computed from the normalized data:
centroids'
0 (0.30, 0.30)
1 (0.90, 0.40)
2 (0.60, 0.90)
Centroids are also called vector means or vector averages. To compute the centroid for a set of vectors, you compute the arithmetic average of each of the vector elements. For example, the two vectors that belong to class 0 are (0.60, 0.10) and (0.00, 0.50). Therefore, the centroid for class 0 is ((0.60 + 0.00)/2, (0.10 + 0.50)/2) = (0.30, 0.30).
Similarly, the centroid for class 1 is ((0.80 + 0.90 + 1.00)/3, (0.00 + 0.40 + 0.80)/3) = (0.90, 0.40). The centroid for class 2 is ((0.80 + 0.40)/2, (1.00 + 0.80)/2) = (0.60, 0.90).
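Continuing the sketch, the class centroids can be computed by summing the normalized vectors that belong to each class and then dividing each sum by the class count:

int[] toadY = new int[] { 0, 0, 1, 1, 1, 2, 2 };   // class labels
int numClasses = 3;
int dim = 2;                          // two predictors: ht, wt
double[][] centroids = new double[numClasses][];
int[] counts = new int[numClasses];
for (int c = 0; c < numClasses; ++c)
  centroids[c] = new double[dim];
for (int i = 0; i < toadX.Length; ++i) {
  int c = toadY[i];
  ++counts[c];
  for (int j = 0; j < dim; ++j)
    centroids[c][j] += toadX[i][j];   // accumulate element sums
}
for (int c = 0; c < numClasses; ++c)
  for (int j = 0; j < dim; ++j)
    centroids[c][j] /= counts[c];     // sums become averages
// centroids = (0.30, 0.30), (0.90, 0.40), (0.60, 0.90)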
To classify a data item, you compute the Euclidean distance between the normalized item and each of the three centroids, and return the class label that is associated with the smallest distance. The Euclidean distance between two vectors is the square root of the sum of the squared differences between vector elements.
For example, suppose the toad item to classify is x = (50, 6600). The normalized item is x' = ((50 - 20)/(80 - 20), (6600 - 1000)/(9000 - 1000)) = (0.50, 0.70). The distances to each centroid are:
Distance from x' to centroid[0] = sqrt((0.50 - 0.30)^2 + (0.70 - 0.30)^2)
= sqrt(0.20)
= 0.45
Distance from x' to centroid[1] = sqrt((0.50 - 0.90)^2 + (0.70 - 0.40)^2)
= sqrt(0.25)
= 0.50
Distance from x' to centroid[2] = sqrt((0.50 - 0.60)^2 + (0.70 - 0.90)^2)
= sqrt(0.05)
= 0.22
Because the smallest distance is between the normalized item and centroid[2], the predicted class/label/species = 2 = Canadian.
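The classification step of the sketch computes the Euclidean distance from the normalized item to each centroid and keeps track of the smallest distance seen so far:

double[] xNorm = new double[] { 0.50, 0.70 };  // normalized item
int predClass = -1;
double smallestDist = double.MaxValue;
for (int c = 0; c < numClasses; ++c) {
  double sum = 0.0;                   // sum of squared differences
  for (int j = 0; j < dim; ++j)
    sum += (xNorm[j] - centroids[c][j]) *
      (xNorm[j] - centroids[c][j]);
  double dist = Math.Sqrt(sum);       // Euclidean distance
  if (dist < smallestDist) {
    smallestDist = dist;
    predClass = c;
  }
}
// predClass = 2 = Canadian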
The Demo Data
The demo program uses a subset of the Penguin Dataset. The full Penguin Dataset has 344 items, of which 333 are valid (no missing values). The demo data uses just 30 training items (10 of each species) and 10 test items (4 Adelie, 3 Chinstrap, 3 Gentoo). The test data file is:
0, 40.6, 18.6, 183.0, 3550.0
0, 40.5, 18.9, 180.0, 3950.0
0, 37.2, 18.1, 178.0, 3900.0
0, 40.9, 18.9, 184.0, 3900.0
1, 52.0, 19.0, 197.0, 4150.0
1, 49.5, 19.0, 200.0, 3800.0
1, 52.8, 20.0, 205.0, 4550.0
2, 49.2, 15.2, 221.0, 6300.0
2, 48.7, 15.1, 222.0, 5350.0
2, 50.2, 14.3, 218.0, 5700.0
The full Penguin Dataset has additional columns. In addition to species, bill length, bill width, flipper length, and body mass, the dataset has categorical columns island (where the penguin was found and measured) and sex (male or female). The demo data does not use the island and sex columns because standard nearest centroid classification does not directly work with categorical predictor variables.
The full Penguin Dataset and its documentation can be found online.
Program Organization
The overall organization of the demo program is shown in Listing 1. All of the control logic is in the Main() method. All of the nearest centroid classification functionality is in a NearestCentroidClassifier class. The Program class has helper functions to load data, normalize data, and display data.
Listing 1: Demo Program Organization
using System;
using System.IO;

namespace NearestCentroidClassification
{
  internal class NearestCentroidProgram
  {
    static void Main(string[] args)
    {
      Console.WriteLine("Begin nearest " +
        "centroid classification demo ");

      // load raw data into memory
      // normalize training data
      // load and normalize test data
      // instantiate a NearestCentroidClassifier object
      // train the model (compute centroids)
      // evaluate model accuracy
      // use model to make a prediction

      Console.WriteLine("End demo ");
      Console.ReadLine();
    }

    public static double[][] MatLoad(string fn,
      int[] usecols, char sep, string comment) { . . }

    public static int[] VecLoad(string fn, int usecol,
      string comment) { . . }

    public static double[][] MatMinMaxValues(double[][] X)
      { . . }

    public static double[][] MatNormalizeUsing(double[][] X,
      double[][] minsMaxs) { . . }

    public static double[] VecNormalizeUsing(double[] x,
      double[][] minsMaxs) { . . }

    public static void MatShow(double[][] M, int dec,
      int wid, int numRows, bool showIndices) { . . }

    public static void VecShow(int[] vec,
      int wid) { . . }

    public static void VecShow(int[] vec, int wid,
      int nItems) { . . }

    public static void VecShow(double[] vec, int decimals,
      int wid) { . . }
  } // Program

  public class NearestCentroidClassifier
  {
    public int numClasses;
    public double[][] centroids;  // of each class

    public NearestCentroidClassifier(int numClasses) { . . }
    public void Train(double[][] trainX, int[] trainY) { . . }
    public int Predict(double[] x) { . . }
    public double[] PredictProbs(double[] x) { . . }
    public double Accuracy(double[][] dataX, int[] dataY) { . . }
    private double EucDistance(double[] v1, double[] v2) { . . }

    public int[][] ConfusionMatrix(double[][] dataX,
      int[] dataY) { . . }
    public void ShowConfusion(int[][] cm) { . . }
  } // class
} // ns
The demo program begins by loading the raw training data into memory and displaying the first four rows:
string trainFile = "..\\..\\..\\Data\\penguin_train_30.txt";
double[][] trainX =
MatLoad(trainFile, new int[] { 1, 2, 3, 4 }, ',', "#");
MatShow(trainX, 1, 9, 4, true);
The raw training data file looks like:
2, 50.0, 16.3, 230.0, 5700.0
0, 39.1, 18.7, 181.0, 3750.0
1, 38.8, 17.2, 180.0, 3800.0
2, 39.3, 20.6, 190.0, 3650.0
0, 39.2, 19.6, 195.0, 4675.0
. . .
The raw data is comma-delimited and the species/class/label is in the first column, [0]. The arguments in the call to program-defined MatLoad() mean load columns [1], [2], [3], [4] of the comma-delimited data, where a line that begins with "#" indicates a comment line. The return value is an array-of-arrays style matrix of type double values.
The arguments in the call to program-defined MatShow() function mean display values using 1 decimal in a field width of 9, just the first 4 rows, and show the row indices.
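The complete helper function implementations are in the accompanying file download. To give a rough idea of what is involved, a simplified MatLoad()-style loader might look like this sketch (not the exact demo code; it additionally assumes a using System.Collections.Generic directive):

public static double[][] MatLoad(string fn, int[] usecols,
  char sep, string comment)
{
  // read all lines, skip comment and blank lines,
  // keep only the requested columns
  string[] lines = File.ReadAllLines(fn);
  List<double[]> rows = new List<double[]>();
  foreach (string line in lines) {
    if (line.StartsWith(comment) || line.Trim().Length == 0)
      continue;
    string[] tokens = line.Split(sep);
    double[] row = new double[usecols.Length];
    for (int j = 0; j < usecols.Length; ++j)
      row[j] = double.Parse(tokens[usecols[j]].Trim());
    rows.Add(row);
  }
  return rows.ToArray();
}

The VecLoad() helper is analogous, except that it extracts a single column and parses the values as integers.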
Normalizing the Data
The demo program normalizes the data using these statements:
double[][] minsMaxs = MatMinMaxValues(trainX);
trainX = MatNormalizeUsing(trainX, minsMaxs);
Console.WriteLine("Done ");
Console.WriteLine("X training normalized: ");
MatShow(trainX, 4, 9, 4, true);
The program-defined MatMinMaxValues() function scans through the training data and computes the minimum and maximum values for each column. For the demo data, the values stored in the minsMaxs matrix are:
34.6 14.5 180.0 3300.0
52.7 21.5 230.0 5850.0
The four predictor minimum values are in row [0] and the maximum values are in row [1]. These values are needed to normalize the test data and to normalize any new, previously unseen data that you wish to classify.
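A simplified sketch of how a MatMinMaxValues()-style function could be implemented (the exact version is in the download) is:

public static double[][] MatMinMaxValues(double[][] X)
{
  int nCols = X[0].Length;
  double[][] result = new double[2][];  // row [0] = mins, row [1] = maxs
  result[0] = new double[nCols];
  result[1] = new double[nCols];
  for (int j = 0; j < nCols; ++j) {
    result[0][j] = double.MaxValue;
    result[1][j] = double.MinValue;
  }
  for (int i = 0; i < X.Length; ++i) {
    for (int j = 0; j < nCols; ++j) {
      if (X[i][j] < result[0][j]) result[0][j] = X[i][j];
      if (X[i][j] > result[1][j]) result[1][j] = X[i][j];
    }
  }
  return result;
}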
The MatNormalizeUsing() function returns a matrix of normalized data values. The demo program assigns these values to the original un-normalized matrix, destroying the un-normalized values. One of several design alternatives is to keep both versions of the data like so:
double[][] trainX_normalized = MatNormalizeUsing(trainX, minsMaxs);
The best design to use will depend on your specific problem scenario.
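The MatNormalizeUsing() and VecNormalizeUsing() helpers apply the stored min and max values using the x' = (x - min) / (max - min) equation. Plausible simplified implementations are:

public static double[][] MatNormalizeUsing(double[][] X,
  double[][] minsMaxs)
{
  int nCols = X[0].Length;
  double[][] result = new double[X.Length][];
  for (int i = 0; i < X.Length; ++i) {
    result[i] = new double[nCols];
    for (int j = 0; j < nCols; ++j)
      result[i][j] = (X[i][j] - minsMaxs[0][j]) /
        (minsMaxs[1][j] - minsMaxs[0][j]);
  }
  return result;
}

public static double[] VecNormalizeUsing(double[] x,
  double[][] minsMaxs)
{
  double[] result = new double[x.Length];
  for (int j = 0; j < x.Length; ++j)
    result[j] = (x[j] - minsMaxs[0][j]) /
      (minsMaxs[1][j] - minsMaxs[0][j]);
  return result;
}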
The demo fetches and displays the class labels/species values of the training data like so:
int[] trainY = VecLoad(trainFile, 0, "#");
Console.WriteLine("Y training: ");
VecShow(trainY, wid: 3);
The arguments passed to program-defined VecLoad() mean load the integer values in column [0] where lines that start with "#" are comments. Next, the test data is loaded and normalized:
string testFile = "..\\..\\..\\Data\\penguin_test_10.txt";
double[][] testX = MatLoad(testFile,
new int[] { 1, 2, 3, 4 }, ',', "#");
testX = MatNormalizeUsing(testX, minsMaxs);
int[] testY = VecLoad(testFile, 0, "#");
Notice that when normalizing the test data, the min and max values from the training data are used. The idea here is somewhat subtle. In machine learning classification, when you have a set of test data, you should not use the test data in any way until after training. Put another way, you operate as if the test data doesn't exist until after a prediction model has been created.
Not using test data until after model creation means that, in principle, when creating a classification prediction system, you should first split the source data into training and test sets, then normalize the training data, and then use the information from the training normalization (such as min and max for min-max normalization, or mean and standard deviation for z-score normalization) to normalize the test data. But in practice, it's not uncommon to normalize the entire source data and then split the data into training and test sets, even though this is not the theoretically correct approach.
Training the Nearest Centroid Classifier
The statements to create and train the nearest centroid classifier model are:
int numClasses = 3;
NearestCentroidClassifier ncc =
new NearestCentroidClassifier(numClasses);
ncc.Train(trainX, trainY);
Console.WriteLine("Class centroids: ");
MatShow(ncc.centroids, 4, 9, 3, true); // 4 decimals, 9 wide, 3 rows
The term "train" often infers an iterative process where a prediction model is created from the training data. In the case of nearest centroid classification, training just means computing the centroids for each class/label. In such situations, the term "fit" is sometimes used instead of the term "train." For example, the scikit-learn machine learning library uses this naming style.
The NearestCentroidClassifier object does not implement a get-property for the public scope centroids member, so the centroids are accessed directly.
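One plausible (simplified) implementation of the Train() method, along with the Predict() method and the private EucDistance() helper that it relies on, is:

public void Train(double[][] trainX, int[] trainY)
{
  int dim = trainX[0].Length;
  this.centroids = new double[this.numClasses][];
  int[] counts = new int[this.numClasses];
  for (int c = 0; c < this.numClasses; ++c)
    this.centroids[c] = new double[dim];
  for (int i = 0; i < trainX.Length; ++i) {
    int c = trainY[i];
    ++counts[c];
    for (int j = 0; j < dim; ++j)
      this.centroids[c][j] += trainX[i][j];  // accumulate sums
  }
  for (int c = 0; c < this.numClasses; ++c)
    for (int j = 0; j < dim; ++j)
      this.centroids[c][j] /= counts[c];     // sums become means
}

public int Predict(double[] x)
{
  int result = -1;
  double smallestDist = double.MaxValue;
  for (int c = 0; c < this.numClasses; ++c) {
    double dist = this.EucDistance(x, this.centroids[c]);
    if (dist < smallestDist) {
      smallestDist = dist;
      result = c;
    }
  }
  return result;
}

private double EucDistance(double[] v1, double[] v2)
{
  double sum = 0.0;
  for (int j = 0; j < v1.Length; ++j)
    sum += (v1[j] - v2[j]) * (v1[j] - v2[j]);
  return Math.Sqrt(sum);
}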
Evaluating the Trained Model
After the nearest centroid classifier model has been created and trained, it is evaluated in two ways. First, the overall accuracy of the model on the training and test data is computed and displayed:
double accTrain = ncc.Accuracy(trainX, trainY);
Console.WriteLine("Accuracy on train: " + accTrain.ToString("F4"));
double accTest = ncc.Accuracy(testX, testY);
Console.WriteLine("Accuracy on test: " + accTest.ToString("F4"));
The NearestCentroidClassifier class defines an Accuracy() method. An alternative design is to define a static accuracy function that is external to the class definition. The calling code would look like double accTrain = Accuracy(ncc, trainX, trainY).
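A straightforward way to implement the Accuracy() method (a sketch; the exact code is in the download) is to count the number of correct predictions:

public double Accuracy(double[][] dataX, int[] dataY)
{
  int numCorrect = 0;
  for (int i = 0; i < dataX.Length; ++i)
    if (this.Predict(dataX[i]) == dataY[i])
      ++numCorrect;
  return (numCorrect * 1.0) / dataX.Length;
}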
The trained model scores 0.9333 accuracy (28 out of 30 correct) on the training data and 1.0000 accuracy (10 out of 10) on the test data. In fact, I had to manipulate the training data by changing the class label values for two of the training items so that I could get any incorrect predictions at all.
Such high accuracy is rare for a nearest centroid classification model. The high accuracy indicates that the species of a penguin is almost completely determined by one or two of the predictor variables. If you look back at the centroid values, you can see that class 0 = Adelie is unambiguously identified by a low value for bill length and a high value for bill width. Class 1 = Chinstrap is identified by a high value for bill length and a low value for body mass. Class 2 = Gentoo is identified by a high value for flipper length and a high value for body mass.
An overall accuracy value for the training or test data is informative, but in most scenarios you want to see the accuracy for each of the class labels. The demo program computes and displays a confusion matrix using these statements:
int[][] cm = ncc.ConfusionMatrix(trainX, trainY);
ncc.ShowConfusion(cm);
The output is:
actual 0: 8 0 0 | 8 | 1.0000
actual 1: 1 12 0 | 13 | 0.9231
actual 2: 1 0 8 | 9 | 0.8889
The second row means that for actual class label = 1 = Chinstrap, 1 item was incorrectly predicted as class 0, 12 items were correctly predicted as class 1, and 0 items were incorrectly predicted as class 2. The accuracy for actual class = 1 is 12 out of 13 correct = 0.9231. Values on the diagonal of the confusion matrix are counts of correct predictions, and values off the diagonal are counts of incorrect predictions.
To keep the size of the output screenshot small, the demo does not compute a confusion matrix for the test data, as you'd want to do in a non-demo scenario.
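The ConfusionMatrix() method can be implemented as a simple loop over the data that tallies (actual, predicted) pairs. A simplified sketch, which does not show the row formatting performed by ShowConfusion(), is:

public int[][] ConfusionMatrix(double[][] dataX, int[] dataY)
{
  // cm[actual][predicted] = count
  int[][] cm = new int[this.numClasses][];
  for (int c = 0; c < this.numClasses; ++c)
    cm[c] = new int[this.numClasses];
  for (int i = 0; i < dataX.Length; ++i) {
    int predicted = this.Predict(dataX[i]);
    int actual = dataY[i];
    ++cm[actual][predicted];
  }
  return cm;
}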
Using the Model to Make Predictions
The demo program predicts the species for a new, previously unseen penguin:
Console.WriteLine("Predicting species for x = 46.5, 17.9, 192, 3500");
string[] speciesNames = new string[] { "Adelie", "Chinstrap", "Gentoo" };
double[] xRaw = { 46.5, 17.9, 192, 3500 };
double[] xNorm = VecNormalizeUsing(xRaw, minsMaxs);
Console.Write("Normalized x =");
VecShow(xNorm, 4, 9);
int lbl = ncc.Predict(xNorm);
Console.WriteLine("predicted label/class = " + lbl);
Console.WriteLine("predicted species = " + speciesNames[lbl]);
double[] pseudoProbs = ncc.PredictProbs(xNorm);
Console.WriteLine("prediction pseudo-probs = ");
VecShow(pseudoProbs, 4, 9);
The Predict() method compares the new x predictor values against each of the three model centroids using Euclidean distance, and returns the class associated with the centroid that is nearest to x, which is class 1 = Chinstrap. Because the model centroids were computed using normalized training data, the raw input must be normalized in the same way, in this case, by using the min and max values from the training data via the VecNormalizeUsing() function.
The Predict() method returns only the predicted class/label/species. The PredictProbs() method returns three pseudo-probabilities, in this case (0.28, 0.55, 0.17). The predicted class is the one associated with the largest pseudo-probability, again class 1. The pseudo-probabilities are computed as an array of the normalized inverse distances to each centroid. For example, the distances to the three centroids for x = (46.5, 17.9, 192, 3500) are (0.54, 0.27, 0.90). Taking 1 over each distance value gives (1.84, 3.65, 1.11). Those values sum to 6.60. Dividing each inverse distance by their sum gives (0.28, 0.55, 0.17), which sum to 1. These values can be loosely interpreted as confidence scores, or likelihoods, for each possible predicted class.
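The inverse-distance computation described above can be sketched as follows (simplified; the exact code is in the download). Note that the computation assumes that no distance is exactly zero:

public double[] PredictProbs(double[] x)
{
  // pseudo-probs = normalized inverse distances to each centroid
  double[] invDists = new double[this.numClasses];
  double sum = 0.0;
  for (int c = 0; c < this.numClasses; ++c) {
    double dist = this.EucDistance(x, this.centroids[c]);
    invDists[c] = 1.0 / dist;
    sum += invDists[c];
  }
  double[] result = new double[this.numClasses];
  for (int c = 0; c < this.numClasses; ++c)
    result[c] = invDists[c] / sum;
  return result;
}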
Wrapping Up
Nearest centroid classification isn't as powerful as other classification techniques because it doesn't deal with interactions between predictor variables. But nearest centroid classification is a good way to establish a baseline prediction model result. In most scenarios, models created by more powerful techniques, such as neural network classifiers and decision tree classifiers, should have better predictive accuracy. That said, there are some situations, such as predicting penguin species in the Penguin Dataset, where nearest centroid classification is surprisingly powerful.
The code presented in this article can be used as-is for most classification problems where the raw data predictor variables are all numeric and the class labels to predict have been encoded using zero-based ordinal values. Some recent, but unpublished, research work has shown that there are effective techniques to encode non-numeric predictor values so that machine learning algorithms that use Euclidean distance can work well. That topic will be the subject of a future Data Science Lab column in Visual Studio Magazine.