Machine Learning and Kaggle Digit Recognizer Competition

This is an introductory article on Machine Learning that shows how to use one of the freely available libraries to predict a value of some entity using classification.
For those interested in Machine Learning and just starting their endeavours in the field, a great place to put your acquired theory to use is solving some of the competition problems listed on the Kaggle website.

Some of the competitions are prize-based and some are just for learning purposes, like the one described here, called Digit Recognizer.

So, straight to the point: the Digit Recognizer competition is about working out which digit has been drawn on a given 28×28 pixel matrix. Both the training data and the test data come as CSV files.

A 28×28 grid gives us 784 pixels in total. A single record of the training data looks like the following:


where the first column contains the actual digit (the label) and the other 784 columns contain the pixel values for each cell of the 28×28 grid. Each pixel value is a simple integer between 0 and 255, where 255 means black and 0 means white (or the opposite, depending on your imagination 😉).
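To make that layout concrete, here is a small sketch that splits one training row into the label and a 28×28 grid. The class and method names are my own (hypothetical), and it assumes the pixels are stored row by row, so pixel i sits at row i / 28 and column i % 28, which is also the order the image-conversion program further down walks them in.

```java
public class DigitRecord {
    static final int SIZE = 28; // the images are 28x28 pixels

    public final int label;     // the digit actually drawn (first CSV column)
    public final int[][] grid;  // grid[row][col], values 0-255

    DigitRecord(int label, int[][] grid) {
        this.label = label;
        this.grid = grid;
    }

    /**
     * Parses one training row: the label followed by 784 comma-separated pixel values.
     * Skip the header line of train.csv before calling this, it is not numeric.
     */
    public static DigitRecord parse(String line) {
        String[] fields = line.split(",");
        if (fields.length != 1 + SIZE * SIZE) {
            throw new IllegalArgumentException("expected 785 columns, got " + fields.length);
        }
        int label = Integer.parseInt(fields[0].trim());
        int[][] grid = new int[SIZE][SIZE];
        for (int i = 0; i < SIZE * SIZE; i++) {
            // row-major order: pixel i lives at row i / 28, column i % 28
            grid[i / SIZE][i % SIZE] = Integer.parseInt(fields[i + 1].trim());
        }
        return new DigitRecord(label, grid);
    }

    public static void main(String[] args) {
        // synthetic row: label 7, every pixel 0 except the one at row 1, column 2
        StringBuilder sb = new StringBuilder("7");
        for (int i = 0; i < SIZE * SIZE; i++) {
            sb.append(',').append(i == 1 * SIZE + 2 ? 255 : 0);
        }
        DigitRecord r = parse(sb.toString());
        System.out.println(r.label + " " + r.grid[1][2] + " " + r.grid[0][0]); // prints 7 255 0
    }
}
```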

The training data contains 42,000 rows to train your machine learning program. The test data contains the same information except for the first column, which is what you need to predict. The predicted values for the test data make up your submission file: to submit your solution you need to produce a CSV file with two columns:

ImageId,Label
where ImageId is the row number and Label is the digit you predicted for that row.
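For reference, a minimal sketch of writing such a file in Java. It assumes the predictions are already collected in test-file order and that ImageId counts from 1 (as in Kaggle's sample submission); the Submission class and write method are hypothetical names.

```java
import java.io.PrintWriter;
import java.io.StringWriter;
import java.io.Writer;
import java.util.Arrays;
import java.util.List;

public class Submission {
    /**
     * Writes predictions in the "ImageId,Label" format. predictions.get(0) is the digit
     * predicted for the first test row, which gets ImageId 1 (assumed 1-based numbering).
     */
    public static void write(List<Integer> predictions, Writer out) {
        PrintWriter pw = new PrintWriter(out);
        pw.println("ImageId,Label");
        for (int i = 0; i < predictions.size(); i++) {
            pw.println((i + 1) + "," + predictions.get(i));
        }
        pw.flush();
    }

    public static void main(String[] args) {
        StringWriter sw = new StringWriter();
        write(Arrays.asList(2, 0, 9), sw);
        System.out.print(sw); // ImageId,Label / 1,2 / 2,0 / 3,9 on separate lines
    }
}
```

To write a real file, pass a `FileWriter` (or `Files.newBufferedWriter(...)`) instead of the `StringWriter`.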

While playing around with your solution, it is useful to check whether your algorithm is any good. I created a very simple program that converts the matrix data into a PNG file, so we can see for ourselves which digit should be there:

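import au.com.bytecode.opencsv.CSVReader; // opencsv 2.x; newer releases use com.opencsv.CSVReader
import javax.imageio.ImageIO;
import java.awt.Color;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;

public class DigitImages { // class name is arbitrary

        public static void main(String[] args) throws IOException {
                // the subset must contain only pixel columns: no header line and no label column
                CSVReader reader = new CSVReader(new FileReader("/path to your csv file/test_subset.csv"));
                String[] nextLine;
                int count = 0;
                while ((nextLine = reader.readNext()) != null) {
                        final int[] pixels = new int[nextLine.length];
                        for (int i = 0; i < nextLine.length; i++) {
                                pixels[i] = Integer.parseInt(nextLine[i].trim());
                        }
                        save(convertRGBImage(pixels), "png", count++);
                }
                reader.close();
        }

        private static void save(BufferedImage image, String ext, int count) {
                String fileName = "image" + count;
                File file = new File(fileName + "." + ext);
                try {
                        ImageIO.write(image, ext, file); // ignore returned boolean
                } catch (IOException e) {
                        System.err.println("Could not write " + file + ": " + e.getMessage());
                }
        }

        public static BufferedImage convertRGBImage(int[] pixels) {
                final int height = 28;
                final int width = 28;

                BufferedImage bufferedImage = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
                int counter = 0;
                // pixels are stored row by row, so walk y (rows) in the outer loop
                for (int y = 0; y < height; y++) {
                        for (int x = 0; x < width; x++) {
                                final Color col = convertValToColor(pixels[counter++]);
                                bufferedImage.setRGB(x, y, col.getRGB());
                        }
                }
                return bufferedImage;
        }

        private static Color convertValToColor(final int pixel) {
                // todo: map the 0-255 intensity to a nicer palette; plain grey scale for now
                return new Color(pixel, pixel, pixel);
        }
}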

Before running this, please create a subset of the test CSV file, as you probably don’t want to generate 28k images :).
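A minimal sketch for cutting such a subset with plain Java NIO. The class name, file names and the 20-row size are placeholders of mine, and it assumes your test.csv starts with a header line (Kaggle's does, as far as I know), which has to be dropped so the conversion program can parse every line as numbers.

```java
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

public class CsvSubset {
    /**
     * Copies the first maxRows data rows of a CSV file to a new file and returns how many
     * were written. If skipHeader is true, the first line (e.g. "pixel0,pixel1,...") is dropped.
     */
    public static int copyFirstRows(Path in, Path out, int maxRows, boolean skipHeader) throws IOException {
        int written = 0;
        try (BufferedReader reader = Files.newBufferedReader(in, StandardCharsets.UTF_8);
             BufferedWriter writer = Files.newBufferedWriter(out, StandardCharsets.UTF_8)) {
            if (skipHeader) {
                reader.readLine();
            }
            String line;
            // check the row limit first so no extra line is consumed once we are done
            while (written < maxRows && (line = reader.readLine()) != null) {
                writer.write(line);
                writer.newLine();
                written++;
            }
        }
        return written;
    }

    public static void main(String[] args) throws IOException {
        // hypothetical paths: adjust to where your Kaggle files live
        int n = copyFirstRows(Paths.get("test.csv"), Paths.get("test_subset.csv"), 20, true);
        System.out.println("wrote " + n + " rows");
    }
}
```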

Example images, each generated from a single row, look like this:

[Images: image11, image12, image13]


I believe the last one is a seven, right? Or maybe a 9? Or a 1? As you can see, this can get really tricky.

For this competition I decided to use a simple library called Java-ML, which stands for Java Machine Learning Library and can be found here:

I guess it’s a good starting point for learning ML, but I’m not sure it’s any good for serious processing, as I found it pretty slow (especially with the KNN algorithm I was using). But it’s really easy to use, so kudos to the creators anyway.
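To see why KNN in particular is expensive at prediction time, here is a minimal brute-force KNN in plain Java. It is my own illustration, not Java-ML's code, and I am not claiming Java-ML works exactly this way: "training" only stores the rows, and every prediction compares the query with every stored row. With 42,000 training rows of 784 pixels, that is about 33 million squared differences per predicted digit.

```java
import java.util.Arrays;
import java.util.Comparator;

public class NaiveKnn {
    private final int[][] trainPixels; // one row of pixel values per training example
    private final int[] trainLabels;   // the digit (0-9) for each training row
    private final int k;               // number of neighbours that get a vote

    public NaiveKnn(int[][] trainPixels, int[] trainLabels, int k) {
        this.trainPixels = trainPixels; // "training" is just remembering the data
        this.trainLabels = trainLabels;
        this.k = k;
    }

    /** Squared Euclidean distance; the square root is skipped since it doesn't change the ordering. */
    private static long distance(int[] a, int[] b) {
        long sum = 0;
        for (int i = 0; i < a.length; i++) {
            long d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    /** Returns the most common label among the k training rows closest to the query. */
    public int classify(int[] query) {
        Integer[] order = new Integer[trainPixels.length];
        long[] dist = new long[trainPixels.length];
        for (int i = 0; i < trainPixels.length; i++) {
            order[i] = i;
            dist[i] = distance(query, trainPixels[i]); // brute force: every training row
        }
        Arrays.sort(order, Comparator.comparingLong(i -> dist[i]));

        int[] votes = new int[10]; // one counter per digit
        for (int n = 0; n < Math.min(k, order.length); n++) {
            votes[trainLabels[order[n]]]++;
        }
        int best = 0;
        for (int digit = 1; digit < 10; digit++) {
            if (votes[digit] > votes[best]) best = digit; // ties go to the smaller digit
        }
        return best;
    }

    public static void main(String[] args) {
        int[][] pixels = { {0, 0}, {0, 10}, {250, 255}, {255, 240} };
        int[] labels = { 1, 1, 7, 7 };
        NaiveKnn knn = new NaiveKnn(pixels, labels, 3);
        System.out.println(knn.classify(new int[] {240, 250})); // prints 7
    }
}
```

A full sort per query is the simplest choice here; keeping only the k best distances in a bounded heap would avoid sorting all 42,000 candidates.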

I won’t go into detail on how to prepare the train.csv file for convenient use with that library, as there are many possible ways to do that; I’ll just describe the process of training and predicting the actual values. The nitty-gritty of preparing the CSV files before and after applying the KNN algorithm you will need to figure out yourself (no magic there).
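One possible way to do that preparation, as a sketch rather than what I actually ran: the loading call below passes 784 as the class index, i.e. it expects the label in the last of 785 columns, whereas Kaggle's train.csv (as I understand it) has a header line and the label first. The class and file names are hypothetical.

```java
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

public class PrepareTrainCsv {
    /** Turns "label,p0,...,p783" into "p0,...,p783,label" so the label ends up at index 784. */
    public static String moveLabelToEnd(String line) {
        int comma = line.indexOf(',');
        if (comma < 0) {
            throw new IllegalArgumentException("not a CSV row: " + line);
        }
        return line.substring(comma + 1) + "," + line.substring(0, comma);
    }

    /** Drops the header line and moves the label column of every row to the end. */
    public static void convert(Path in, Path out) throws IOException {
        try (BufferedReader reader = Files.newBufferedReader(in);
             BufferedWriter writer = Files.newBufferedWriter(out)) {
            reader.readLine(); // skip the header line ("label,pixel0,...")
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.isEmpty()) continue; // tolerate blank lines
                writer.write(moveLabelToEnd(line));
                writer.newLine();
            }
        }
    }

    public static void main(String[] args) throws IOException {
        // hypothetical file names
        convert(Paths.get("train.csv"), Paths.get("train_label_last.csv"));
    }
}
```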

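/* Load a data set: the class label is expected in column 784, i.e. the last of 785 columns */
Dataset data = FileHandler.loadDataset(new File("path to your train.csv"), 784, ",");

/* set up the classifier; k is the number of neighbours to consider */
final Classifier knn = new KNearestNeighbors(k);

/* train the classifier on the loaded data */
knn.buildClassifier(data);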

After you run the lines above, your program is trained :). Now you can use it to predict the values in the test.csv file:

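/* Load a data set. The class index must match how you prepared test.csv: if the file
   holds only the 784 pixel columns, index 783 makes the last pixel the "class" value,
   so loading it without a class index may be what you want instead */
Dataset dataForClassification = FileHandler.loadDataset(new File("path to your test.csv file"), 783, ",");

/* Predict the value for each row */
for (final Instance inst : dataForClassification) {
        Object predictedClassValue = knn.classify(inst);
        // write predictedClassValue to your output file here
}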


On my machine, predicting a single value (out of 28k) takes about 10-15 seconds, which is unacceptable: for the whole test set that adds up to 280,000-420,000 seconds, roughly 78-117 hours. But who said we need to go one by one?

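/* initialize executor service */
ExecutorService executorService =
        new ThreadPoolExecutor(
        maxThreads, // core thread pool size
        maxThreads, // maximum thread pool size
        1, TimeUnit.MINUTES, // keep-alive time for idle threads above the core size
        new ArrayBlockingQueue<Runnable>(maxThreads, true), // bounded, fair work queue
        new ThreadPoolExecutor.CallerRunsPolicy()); // queue full: the submitting thread runs the task itself

for (final Instance inst : dataForClassification) {
        executorService.submit(new Runnable() {
                public void run() {
                        Object predictedClassValue = knn.classify(inst);
                        // record predictedClassValue together with the row it belongs to;
                        // tasks finish in no particular order
                }
        });
}

// wait for all of the executor threads to finish
executorService.shutdown();
executorService.awaitTermination(1, TimeUnit.DAYS);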

where maxThreads is calculated from the number of available cores on your machine:

/* calculate number of threads to use */
final int cpus = Runtime.getRuntime().availableProcessors();
final int maxThreads = ((cpus*2) > 0 ? cpus*2 : 1);

(Note: I multiply the number of available cores by 2; you can tune this factor empirically for this algorithm.)

This speeds things up considerably. On the other hand, the tasks complete in no particular order, so we need to keep track of which row we are actually processing; if you write the results into a CSV file, you need to sort it by the first column before submitting it to the Kaggle website.
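An alternative that avoids the sorting step (a different approach from the one I used): keep the Future returned by each submit call in a list in input order, then call get() on them in that order, which yields the predictions in row order no matter which thread finishes first. The sketch takes the classifier as a plain function, so the OrderedPredictions class and predictAll method are hypothetical names and the classify function stands in for knn.classify.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.Function;

public class OrderedPredictions {
    /** Classifies rows in parallel but returns the results in the original row order. */
    public static <R, P> List<P> predictAll(List<R> rows, Function<R, P> classify, int threads)
            throws InterruptedException, ExecutionException {
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        try {
            List<Future<P>> futures = new ArrayList<>();
            for (R row : rows) {
                futures.add(pool.submit(() -> classify.apply(row))); // one task per row
            }
            List<P> results = new ArrayList<>();
            for (Future<P> f : futures) {
                results.add(f.get()); // blocks until that particular row is done
            }
            return results;
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) throws Exception {
        List<Integer> rows = Arrays.asList(3, 1, 2);
        // stand-in classifier: sleeps to shuffle completion order, then "predicts" row * 10
        List<Integer> out = predictAll(rows, r -> {
            try {
                Thread.sleep(10L * r);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            return r * 10;
        }, 4);
        System.out.println(out); // prints [30, 10, 20]
    }
}
```

With Java-ML this would be called as something like `predictAll(instances, inst -> knn.classify(inst), maxThreads)`, where `instances` is a list of the test rows; row i of the result then simply gets ImageId i + 1.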

Quick tip: for me, the easiest and fastest way to sort a CSV file is to load it into a MySQL table and then produce a new CSV file with a simple SELECT ... INTO OUTFILE query.
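If you would rather not involve a database, the same sort can be done in a few lines of plain Java. This is an alternative to the MySQL route, not what I used, and it assumes the file fits in memory and the first column is a numeric row id. Note the numeric comparison: sorted as text, "10" would come before "9". The class and file names are hypothetical.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class SortCsvByFirstColumn {
    /** Numeric value of the first comma-separated field of a line. */
    static long firstColumn(String line) {
        int comma = line.indexOf(',');
        return Long.parseLong((comma < 0 ? line : line.substring(0, comma)).trim());
    }

    /** Sorts lines by their first column as a number; the header line, if any, stays on top. */
    public static List<String> sortLines(List<String> lines, boolean hasHeader) {
        List<String> result = new ArrayList<>();
        List<String> body = new ArrayList<>(lines);
        if (hasHeader && !body.isEmpty()) {
            result.add(body.remove(0));
        }
        body.removeIf(l -> l.trim().isEmpty()); // ignore blank lines
        body.sort(Comparator.comparingLong(SortCsvByFirstColumn::firstColumn));
        result.addAll(body);
        return result;
    }

    public static void main(String[] args) throws IOException {
        // hypothetical file names
        Path in = Paths.get("predictions_unsorted.csv");
        List<String> sorted = sortLines(Files.readAllLines(in), true);
        Files.write(Paths.get("submission.csv"), sorted);
    }
}
```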

I’m lucky enough to work for a company that lets me use its AWS account for free, so I just started a cc2.8xlarge instance and copied the jar and CSV files there with scp; altogether it took 15 minutes to get the results and upload them to Kaggle. Kudos Softwaremill! 🙂

On my first submission I got a score of 0.97114, which placed me in 474th position on the leaderboard :).

Both the KNN and the Random Forest algorithms, which are useful for this competition, are available in Java-ML. I didn’t spend long tuning the number of neighbours for KNN, nor did I experiment much with Random Forest, but if you try, please let me know your setup ;).
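If you do experiment, one simple way to compare setups (different values of k, or KNN versus Random Forest) is to hold back part of the labelled training rows, train on the rest, and measure accuracy on the held-back part. The sketch below is generic and my own: the HoldOut class, the Example type and the train hook are hypothetical, and with Java-ML the hook would convert the training examples to a Dataset, call buildClassifier, and return a function that calls classify (that conversion is not shown).

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.function.Function;

public class HoldOut {
    /** A labelled example: pixel values plus the digit actually drawn. */
    public static final class Example {
        final int[] pixels;
        final int label;

        public Example(int[] pixels, int label) {
            this.pixels = pixels;
            this.label = label;
        }
    }

    /**
     * Shuffles the examples, trains on (1 - testFraction) of them and returns the fraction
     * of the held-back ones that the trained predictor labels correctly.
     * "train" is a hypothetical hook: given training examples, it returns a predictor.
     */
    public static double accuracy(List<Example> all, double testFraction, long seed,
                                  Function<List<Example>, Function<int[], Integer>> train) {
        List<Example> shuffled = new ArrayList<>(all);
        Collections.shuffle(shuffled, new Random(seed)); // fixed seed: same split for every setup
        int testSize = (int) Math.round(shuffled.size() * testFraction);
        List<Example> test = shuffled.subList(0, testSize);
        List<Example> training = shuffled.subList(testSize, shuffled.size());

        Function<int[], Integer> predictor = train.apply(training);
        int correct = 0;
        for (Example e : test) {
            if (predictor.apply(e.pixels) == e.label) correct++;
        }
        return test.isEmpty() ? 0.0 : (double) correct / test.size();
    }
}
```

Running `accuracy` with the same seed for each candidate setting means every setting is judged on exactly the same held-back rows.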

