Resize MNIST into 14*14 images

In [1]:
import numpy as np
import gzip
import cPickle
import time

%matplotlib inline
from matplotlib import pyplot as plt
In []:
f = gzip.open('mnist.pkl.gz','rb')
data = cPickle.load(f)
f.close()
print type(data), len(data)
train_set, valid_set, test_set = data

def resize_xs(xs):
    N,dim = xs.shape
    if dim != 28*28:
        print 'size wrong', dim
        return
    xs_new = np.zeros((N, 14*14),dtype='float32')
    for i in range(N):
        xi = xs[i].reshape(28,28)
        xs_new[i] = xi[::2,::2].reshape(14*14)
    return xs_new


#
xs_tr, ys_tr = train_set
xs_tr_new = resize_xs(xs_tr)
print xs_tr_new.shape

xs_te, ys_te = test_set
xs_te_new = resize_xs(xs_te)
print xs_te_new.shape

xs_val, ys_val = valid_set
xs_val_new = resize_xs(xs_val)
print xs_val_new.shape

# save
train_set_new = (xs_tr_new, ys_tr)
valid_set_new = (xs_val_new, ys_val)
test_set_new  = (xs_te_new, ys_te)
f2 = gzip.open('mnist_14x14.pkl.gz','wb')
cPickle.dump((train_set_new, valid_set_new, test_set_new), f2)
f2.close()
if 0:
    f2 = open('mnist_14x14.pkl','wb')
    cPickle.dump((train_set_new, valid_set_new, test_set_new), f2)
    f2.close()
print 'dump resized images finished'

Try MLP for small MNIST

In [2]:
import numpy
import theano
import theano.tensor as T

from sgd_model import *
from mlp_layers import *
In [3]:
import cPickle
import gzip

class dataset_mnist_small:
    def __init__(self):
        if 0: # no requirement for gzip, but big file
            f = open('mnist_14x14.pkl', 'rb')
            train_set, valid_set, test_set = cPickle.load(f)
            f.close()
        else: # requiring gzip           
            f = gzip.open('mnist_14x14.pkl.gz','rb')
            train_set, valid_set, test_set = cPickle.load(f)
            f.close()

        self.xs_tr, ys_tr = train_set
        self.xs_val, ys_val = valid_set
        self.xs_te, ys_te = test_set

        self.ys_tr =  ys_tr.astype(np.int32)
        self.ys_val = ys_val.astype(np.int32)
        self.ys_te =  ys_te.astype(np.int32)
        
ds =  dataset_mnist_small()  

print ds.xs_tr.shape
(50000L, 196L)

In []:
class mlp_small(tt_sgd_model):
    def __init__(self):
        #784
        dim = 196
        nc = 10

        self.layers = []
        self.layers += [InputLayer(dim)]

        nonlinearfunc = T.tanh #T.nnet.sigmoid #rectify #rectify
        print 'nonlinear activation function = ', nonlinearfunc
        self.layers += [ HiddenLayer(self.layers[-1], n_out = 500,  activation = T.tanh)]
        self.layers += [ HiddenLayer(self.layers[-1], n_out = nc,  activation = None)]
        self.layers +=  [ SoftmaxLayer(self.layers[-1])]

        target = T.ivector('target')
        output_eval = self.layers[-1].output()
        self.cost_eval = mcloss_negli(output_eval, target)
        self.err_eval = mc_error(output_eval, target)
        self.validate_model = theano.function([self.layers[0].input, target], self.err_eval)
        self.cost_model = theano.function([self.layers[0].input, target], self.cost_eval)

        output_tr = self.layers[-1].output(dropout_training=True)
        self.cost_tr = mcloss_negli(output_eval, target)

        learning_rate = 0.13
        all_para = all_parameters(self.layers[-1])
        updates = gen_updates_sgd(self.cost_tr, all_para, learning_rate)
        self.train_model = theano.function([self.layers[0].input, target], self.cost_tr, updates=updates)

model = mlp_small()
from dataset_mnist import dataset_mnist
import time
t0 = time.time()
print 'load dataset in %.1f secs' % (time.time()-t0)
if 0: #unit test
    import sys
    model._unit_test(ds.xs_tr[0:100], ds.ys_tr[0:100])
    sys.exit(0)

best_err, best_model,err_list_val, err_list_tr =  sgd_patience(model, ds.xs_tr, ds.ys_tr, ds.xs_te, ds.ys_te, batch_size=600, n_epochs=30)
nonlinear activation function =  Elemwise{tanh,no_inplace}
load dataset in 0.0 secs

 epoch 1, minibatch 82/83, validation error 0.133854  traing error 0.147490 
randomly shuffling...

 epoch 2, minibatch 82/83, validation error 0.115312  traing error 0.126285 
randomly shuffling...

 epoch 3, minibatch 82/83, validation error 0.104167  traing error 0.116968 
randomly shuffling...

 epoch 4, minibatch 82/83, validation error 0.100000  traing error 0.111345 
randomly shuffling...

 epoch 5, minibatch 82/83, validation error 0.098229  traing error 0.107972 
randomly shuffling...

 epoch 6, minibatch 82/83, validation error 0.095937  traing error 0.105502 
randomly shuffling...

 epoch 7, minibatch 82/83, validation error 0.092500  traing error 0.101827 
randomly shuffling...

 epoch 8, minibatch 82/83, validation error 0.091771  traing error 0.099980 
randomly shuffling...

 epoch 9, minibatch 82/83, validation error 0.091771  traing error 0.099900 
randomly shuffling...

 epoch 10, minibatch 82/83, validation error 0.090312  traing error 0.097149 
randomly shuffling...

 epoch 11, minibatch 82/83, validation error 0.089271  traing error 0.096406 
randomly shuffling...

 epoch 12, minibatch 82/83, validation error 0.089479  traing error 0.095181 
randomly shuffling...

 epoch 13, minibatch 82/83, validation error 0.088229  traing error 0.094056 
randomly shuffling...

 epoch 14, minibatch 82/83, validation error 0.087917  traing error 0.094116 
randomly shuffling...

 epoch 15, minibatch 82/83, validation error 0.087604  traing error 0.093293 
randomly shuffling...

 epoch 16, minibatch 82/83, validation error 0.086354  traing error 0.092329 
randomly shuffling...

 epoch 17, minibatch 82/83, validation error 0.086354  traing error 0.091727 
randomly shuffling...

 epoch 18, minibatch 82/83, validation error 0.085833  traing error 0.090803 
randomly shuffling...

 epoch 19, minibatch 82/83, validation error 0.084687  traing error 0.090402 
randomly shuffling...

 epoch 20, minibatch 82/83, validation error 0.085000  traing error 0.089739 
randomly shuffling...

 epoch 21, minibatch 82/83, validation error 0.084375  traing error 0.089277 
randomly shuffling...

 epoch 22, minibatch 82/83, validation error 0.084167  traing error 0.088735 
randomly shuffling...

 epoch 23, minibatch 82/83, validation error 0.085208  traing error 0.088012 
randomly shuffling...

 epoch 24, minibatch 82/83, validation error 0.083333  traing error 0.087390 
randomly shuffling...

 epoch 25, minibatch 82/83, validation error 0.082708  traing error 0.086627 
randomly shuffling...

 epoch 26, minibatch 82/83, validation error 0.082813  traing error 0.086185 
randomly shuffling...

 epoch 27, minibatch 82/83, validation error 0.081771  traing error 0.085361 
randomly shuffling...

 epoch 28, minibatch 82/83, validation error 0.083229  traing error 0.085743 
randomly shuffling...

 epoch 29, minibatch 82/83, validation error 0.082083  traing error 0.084659 
randomly shuffling...

 epoch 30, minibatch 82/83, validation error 0.081875  traing error 0.085000 
randomly shuffling...

 epoch 31, minibatch 82/83, validation error 0.080833  traing error 0.084398 
randomly shuffling...

 epoch 32, minibatch 82/83, validation error 0.081667  traing error 0.083614 
randomly shuffling...

 epoch 33, minibatch 82/83, validation error 0.080521  traing error 0.082992 
randomly shuffling...

 epoch 34, minibatch 82/83, validation error 0.080417  traing error 0.082269 
randomly shuffling...

 epoch 35, minibatch 82/83, validation error 0.080208  traing error 0.081406 
randomly shuffling...

 epoch 36, minibatch 82/83, validation error 0.080313  traing error 0.081145 
randomly shuffling...

 epoch 37, minibatch 82/83, validation error 0.080104  traing error 0.080100 
randomly shuffling...

 epoch 38, minibatch 82/83, validation error 0.078437  traing error 0.080020 
randomly shuffling...

 epoch 39, minibatch 82/83, validation error 0.079375  traing error 0.079618 
randomly shuffling...

 epoch 40, minibatch 82/83, validation error 0.078542  traing error 0.078755 
randomly shuffling...

 epoch 41, minibatch 82/83, validation error 0.077708  traing error 0.077570 
randomly shuffling...

 epoch 42, minibatch 82/83, validation error 0.077187  traing error 0.077631 
randomly shuffling...

 epoch 43, minibatch 82/83, validation error 0.076354  traing error 0.076566 
randomly shuffling...

 epoch 44, minibatch 82/83, validation error 0.076667  traing error 0.075904 
randomly shuffling...

 epoch 45, minibatch 82/83, validation error 0.075833  traing error 0.075020 
randomly shuffling...

 epoch 46, minibatch 82/83, validation error 0.075625  traing error 0.074819 
randomly shuffling...

 epoch 47, minibatch 82/83, validation error 0.074792  traing error 0.074177 
randomly shuffling...

 epoch 48, minibatch 82/83, validation error 0.074063  traing error 0.073293 
randomly shuffling...

 epoch 49, minibatch 82/83, validation error 0.071875  traing error 0.072932 
randomly shuffling...

 epoch 50, minibatch 82/83, validation error 0.071354  traing error 0.072008 
randomly shuffling...

 epoch 51, minibatch 82/83, validation error 0.071354  traing error 0.071285 
randomly shuffling...

 epoch 52, minibatch 82/83, validation error 0.071146  traing error 0.070703 
randomly shuffling...

 epoch 53, minibatch 82/83, validation error 0.070521  traing error 0.070261 
randomly shuffling...

 epoch 54, minibatch 82/83, validation error 0.069687  traing error 0.069257 
randomly shuffling...

 epoch 55, minibatch 82/83, validation error 0.068958  traing error 0.068333 
randomly shuffling...

 epoch 56, minibatch 82/83, validation error 0.067500  traing error 0.068514 
randomly shuffling...

 epoch 57, minibatch 82/83, validation error 0.067187  traing error 0.067129 
randomly shuffling...

 epoch 58, minibatch 82/83, validation error 0.065833  traing error 0.067048 
randomly shuffling...

 epoch 59, minibatch 82/83, validation error 0.066979  traing error 0.065863 
randomly shuffling...

 epoch 60, minibatch 82/83, validation error 0.065000  traing error 0.064920 
randomly shuffling...

 epoch 61, minibatch 82/83, validation error 0.063125  traing error 0.063835 
randomly shuffling...

 epoch 62, minibatch 82/83, validation error 0.063229  traing error 0.063614 
randomly shuffling...

 epoch 63, minibatch 82/83, validation error 0.061979  traing error 0.062329 
randomly shuffling...

 epoch 64, minibatch 82/83, validation error 0.061042  traing error 0.062490 
randomly shuffling...

 epoch 65, minibatch 82/83, validation error 0.060833  traing error 0.061707 
randomly shuffling...

 epoch 66, minibatch 82/83, validation error 0.060313  traing error 0.060542 
randomly shuffling...

 epoch 67, minibatch 82/83, validation error 0.058750  traing error 0.059659 
randomly shuffling...

 epoch 68, minibatch 82/83, validation error 0.058542  traing error 0.058835 
randomly shuffling...

 epoch 69, minibatch 82/83, validation error 0.059896  traing error 0.058755 
randomly shuffling...

 epoch 70, minibatch 82/83, validation error 0.058125  traing error 0.058273 
randomly shuffling...

 epoch 71, minibatch 82/83, validation error 0.058229  traing error 0.058193 
randomly shuffling...

 epoch 72, minibatch 82/83, validation error 0.056250  traing error 0.056466 
randomly shuffling...

 epoch 73, minibatch 82/83, validation error 0.055417  traing error 0.056165 
randomly shuffling...

 epoch 74, minibatch 82/83, validation error 0.054792  traing error 0.056084 
randomly shuffling...

 epoch 75, minibatch 82/83, validation error 0.055521  traing error 0.055221 
randomly shuffling...

 epoch 76, minibatch 82/83, validation error 0.055313  traing error 0.055060 
randomly shuffling...

 epoch 77, minibatch 82/83, validation error 0.054479  traing error 0.054116 
randomly shuffling...

 epoch 78, minibatch 82/83, validation error 0.053021  traing error 0.054217 
randomly shuffling...

 epoch 79, minibatch 82/83, validation error 0.053750  traing error 0.053133 
randomly shuffling...

 epoch 80, minibatch 82/83, validation error 0.052396  traing error 0.052349 
randomly shuffling...

 epoch 81, minibatch 82/83, validation error 0.053229  traing error 0.052831 
randomly shuffling...

 epoch 82, minibatch 82/83, validation error 0.052292  traing error 0.051526 
randomly shuffling...

 epoch 83, minibatch 82/83, validation error 0.052083  traing error 0.051084 
randomly shuffling...

 epoch 84, minibatch 82/83, validation error 0.051563  traing error 0.050823 
randomly shuffling...

 epoch 85, minibatch 82/83, validation error 0.050729  traing error 0.050301 
randomly shuffling...

 epoch 86, minibatch 82/83, validation error 0.051771  traing error 0.049699 
randomly shuffling...

 epoch 87, minibatch 82/83, validation error 0.051250  traing error 0.048956 
randomly shuffling...

 epoch 88, minibatch 82/83, validation error 0.049688  traing error 0.048835 
randomly shuffling...

 epoch 89, minibatch 82/83, validation error 0.050625  traing error 0.048173 
randomly shuffling...

 epoch 90, minibatch 82/83, validation error 0.050521  traing error 0.047751 
randomly shuffling...

 epoch 91, minibatch 82/83, validation error 0.049583  traing error 0.047108 
randomly shuffling...

 epoch 92, minibatch 82/83, validation error 0.048646  traing error 0.046747 
randomly shuffling...

 epoch 93, minibatch 82/83, validation error 0.048750  traing error 0.046426 
randomly shuffling...

 epoch 94, minibatch 82/83, validation error 0.048333  traing error 0.045723 
randomly shuffling...

 epoch 95, minibatch 82/83, validation error 0.047917  traing error 0.045683 
randomly shuffling...

 epoch 96, minibatch 82/83, validation error 0.047396  traing error 0.045924 
randomly shuffling...

 epoch 97, minibatch 82/83, validation error 0.046667  traing error 0.044960 
randomly shuffling...

 epoch 98, minibatch 82/83, validation error 0.047292  traing error 0.044518 
randomly shuffling...

 epoch 99, minibatch 82/83, validation error 0.046146  traing error 0.043695 
randomly shuffling...

 epoch 100, minibatch 82/83, validation error 0.045729  traing error 0.043414 
randomly shuffling...

 epoch 101, minibatch 82/83, validation error 0.045833  traing error 0.042912 
randomly shuffling...

 epoch 102, minibatch 82/83, validation error 0.045000  traing error 0.042892 
randomly shuffling...

 epoch 103, minibatch 82/83, validation error 0.045833  traing error 0.042550 
randomly shuffling...

 epoch 104, minibatch 82/83, validation error 0.045729  traing error 0.042209 
randomly shuffling...

 epoch 105, minibatch 82/83, validation error 0.045208  traing error 0.042149 
randomly shuffling...

 epoch 106, minibatch 82/83, validation error 0.044167  traing error 0.042008 
randomly shuffling...

 epoch 107, minibatch 82/83, validation error 0.043750  traing error 0.041064 
randomly shuffling...

 epoch 108, minibatch 82/83, validation error 0.044167  traing error 0.041084 
randomly shuffling...

 epoch 109, minibatch 82/83, validation error 0.044271  traing error 0.040361 
randomly shuffling...

 epoch 110, minibatch 82/83, validation error 0.044167  traing error 0.040562 
randomly shuffling...

 epoch 111, minibatch 82/83, validation error 0.043646  traing error 0.039679 
randomly shuffling...

 epoch 112, minibatch 82/83, validation error 0.043333  traing error 0.039297 
randomly shuffling...

 epoch 113, minibatch 82/83, validation error 0.042292  traing error 0.039016 
randomly shuffling...

 epoch 114, minibatch 82/83, validation error 0.042188  traing error 0.039157 
randomly shuffling...

 epoch 115, minibatch 82/83, validation error 0.042292 

Feel free to adjust learning rate, number of layers, size of batch, or try dropout!

In []: