Solved – MNIST Softmax regression hitting a wall at 70% accuracy

gradient-descent · java · logistic · multi-class · softmax

I've implemented a Softmax regression algorithm in Java as part of an Android machine learning program. However, no matter how long I let it run, the accuracy climbs to about 70.5% and then plateaus indefinitely, even though the site I got my data from states that I should be getting close to 90%. I've gone through my code over and over and haven't been able to find the source of the problem, so I'm hoping someone here can help. I don't understand why it would successfully advance to about 70% and then stop. This is the Softmax formula that I'm using, and this is the dataset I'm using.
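In short, the probability and per-class gradient I'm computing (with L2 regularization constant λ, i.e. Constants.regularizationConstant in the code below) are:

    p_i = \frac{\exp(\Theta_i \cdot x)}{\sum_{k=1}^{K} \exp(\Theta_k \cdot x)}, \qquad
    \nabla_{\Theta_i} = -x \left( \mathbf{1}\{ i = y \} - p_i \right) + 2\lambda\,\Theta_i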

Here is my softmax algorithm:

//computes the softmax-regression gradient for a single example
//weights: flattened K×D weight matrix, features: length-D input, type: true class label
public List<Double> gradient(List<Double> weights, double[] features, int type){
        int D = Constants.featureSize;
        int K = Constants.numberOfClasses;

        List<Double> grad = new ArrayList<Double>(D*K);
        for(int i = 0; i < D*K; i++){
            grad.add(i, 0.0);
        }
        //denominator: Σ_{i=1..K} exp(Θ_i · X)
        double dot = 0;
        double denom = 0;
        for(int i = 0; i < K; i++){
            //dot product w_i*x
            dot = 0;
            for(int j = 0; j < D; j++){
                dot += features[j] * weights.get(j + (D*i));
            }

            denom += Math.exp(dot);
        }

        //gradient of the L2 penalty λ·Θ², i.e. 2λΘ for each weight
        double[] regular = new double[D * K];
        for(int i = 0; i < D * K; i++){
            regular[i] = 2 * weights.get(i) * Constants.regularizationConstant;
        }

        double prob;
        //prob_i = exp(Θ_i · X)/denom
        for(int i = 0; i < K; i++) {
            //dot product w_i·x
            dot = 0;
            for (int j = 0; j < D; j++) {
                dot += features[j] * weights.get(j + (D * i));
            }
            prob = Math.exp(dot) / denom;

            //∇_Θ_i = -X(1{i = y} - prob_i)
            int match = (i == type) ? 1 : 0;    //indicator 1{i = y}
            for (int j = 0; j < D; j++) {
                grad.set(j + (D * i), -1 * features[j] * (match - prob));
            }
        }

        //apply regularization
        for(int i = 0; i < D * K; i++){
            grad.set(i, grad.get(i)+ regular[i]);
        }

        return grad;
    }
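One caveat about the exp calls above: for large dot products, Math.exp can overflow to infinity, which would make denom (and every probability) NaN. I haven't changed this in my code, but a common safeguard is to subtract the maximum logit before exponentiating; a minimal sketch of that variant:

    //Numerically stable softmax: subtracting the max logit leaves the
    //resulting probabilities unchanged but keeps Math.exp in a safe range.
    double[] stableSoftmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double l : logits) {
            max = Math.max(max, l);
        }
        double denom = 0;
        double[] probs = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            probs[i] = Math.exp(logits[i] - max);
            denom += probs[i];
        }
        for (int i = 0; i < logits.length; i++) {
            probs[i] /= denom;
        }
        return probs;
    }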

Here is where I apply the gradient to the weights (this part runs on a JavaScript server):

for (var i = 0; i < length; i++) {
    adaG[i] += gradient[i] * gradient[i];   //accumulate squared gradients (AdaGrad)
    newWeight[i] = currentWeight[i] - ((c / Math.sqrt(adaG[i] + eps)) * gradient[i]);
}

weight = newWeight;
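This is the standard AdaGrad rule: each weight gets its own effective learning rate c / √(Σ g² + ε). For clarity, here is the same update as a self-contained Java sketch (the method name and parameter passing are mine, purely for illustration; the behavior matches the JavaScript above):

    //Illustrative AdaGrad step mirroring the JavaScript update above.
    //adaG holds the running sum of squared gradients for each weight.
    public static void adagradStep(double[] weights, double[] grad, double[] adaG,
                                   double c, double eps) {
        for (int i = 0; i < weights.length; i++) {
            adaG[i] += grad[i] * grad[i];                             //accumulate g²
            weights[i] -= (c / Math.sqrt(adaG[i] + eps)) * grad[i];   //per-weight step
        }
    }

Since adaG only ever grows, each weight's step size shrinks monotonically over the course of training.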

And then finally, in case you think it might be an error in my accuracy test, here's that code:

var correct = 0;
var labels = require('fs').readFileSync(testLabels).toString().split('\n');
var features = require('fs').readFileSync(testFeatures).toString().split('\n');

function valid(str) {
    return str != "";
}

for (var i = 0; i < N; i++) {
    var label = parseFloat(labels[i]);

    //parse one comma/space-separated feature line into numbers
    var featureClean = features[i].split(/,| /).filter(valid);
    var featureArray = [];
    for (var j = 0; j < featureClean.length; j++) {
        featureArray[j] = parseFloat(featureClean[j]);
    }

    //score each class with a dot product
    var classResults = [];
    for (var h = 0; h < K; h++) {
        var dot = 0;
        for (j = 0; j < D; j++) {
            dot += featureArray[j] * testWeight[j + (h * D)];
        }
        classResults[h] = dot;
    }

    //argmax over the class scores
    var bestGuess = 0;
    for (h = 0; h < K; h++) {
        if (classResults[h] > classResults[bestGuess]) {
            bestGuess = h;
        }
    }

    if (bestGuess == label) {
        correct++;
    }
    //else {
    //    console.log(classResults);
    //    console.log('correct: ', label);
    //}
}

var accuracy = correct / N;
console.log(accuracy);

If anyone could give me some idea of where I'm making an error, please tell me.

Best Answer

After running more tests with different parameters, I found that the issue was that the initial learning rate (the c constant in the AdaGrad update above) was set too low. If I'm perfectly honest, I'm pretty frustrated with myself that the solution was that simple, but that's all it was.
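For anyone who hits the same wall: with AdaGrad, the effective learning rate for weight i after t updates is

    \eta_{i,t} = \frac{c}{\sqrt{\sum_{s=1}^{t} g_{i,s}^2 + \epsilon}}

and the denominator only ever grows, so the steps shrink monotonically. If the initial c is already tiny, the updates become negligible long before the weights are anywhere near optimal, and the accuracy curve flattens out into exactly the kind of plateau described above.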