import os
import struct
from array import array

class MNIST(object):
    """Loader for MNIST-style image/label file pairs.

    Each split is stored as two big-endian binary files: a label file whose
    header magic number is 2049 and an image file whose header magic number
    is 2051. Loaded data is cached on the instance in ``train_images`` /
    ``train_labels`` and ``test_images`` / ``test_labels``.
    """

    LABEL_MAGIC = 2049
    IMAGE_MAGIC = 2051

    def __init__(self, train_data_path, train_label_path, test_data_path, test_label_path):
        """Remember the four file paths; nothing is read until a load_* call."""
        self.train_data_path = train_data_path
        self.train_label_path = train_label_path
        self.test_data_path = test_data_path
        self.test_label_path = test_label_path

        self.test_images = []
        self.test_labels = []

        self.train_images = []
        self.train_labels = []

    def load_testing(self):
        """Load the test split, cache it on the instance and return (images, labels)."""
        ims, labels = self.load(self.test_data_path, self.test_label_path)

        self.test_images = ims
        self.test_labels = labels

        return ims, labels

    def load_training(self):
        """Load the training split, cache it on the instance and return (images, labels)."""
        ims, labels = self.load(self.train_data_path, self.train_label_path)

        self.train_images = ims
        self.train_labels = labels

        return ims, labels

    @classmethod
    def load(cls, path_img, path_lbl):
        """Read one image file and its matching label file.

        Args:
            path_img: path to the image file (magic 2051; a 16-byte header of
                magic, count, rows, cols, followed by one unsigned byte per pixel).
            path_lbl: path to the label file (magic 2049; an 8-byte header of
                magic, count, followed by one unsigned byte per label).

        Returns:
            (images, labels): ``images`` is a list holding, for each image, a
            list of ``rows * cols`` floats, each pixel divided by 255 so it
            lies in [0, 1]; ``labels`` is an ``array('B')`` of the raw label
            bytes that follow the label header.

        Raises:
            ValueError: if a magic number is wrong, the two headers declare
                different item counts, or a file holds less data than its
                header declares.
        """
        with open(path_lbl, 'rb') as file:
            magic, n_labels = struct.unpack(">II", file.read(8))
            if magic != cls.LABEL_MAGIC:
                raise ValueError('Magic number mismatch, expected %d, got %d'
                                 % (cls.LABEL_MAGIC, magic))
            labels = array("B", file.read())

        with open(path_img, 'rb') as file:
            magic, n_images, rows, cols = struct.unpack(">IIII", file.read(16))
            if magic != cls.IMAGE_MAGIC:
                raise ValueError('Magic number mismatch, expected %d, got %d'
                                 % (cls.IMAGE_MAGIC, magic))
            image_data = array("B", file.read())

        # Labels are paired with images by position, so the counts must agree.
        if n_labels != n_images:
            raise ValueError('Label/image count mismatch: %d labels, %d images'
                             % (n_labels, n_images))

        pixels = rows * cols
        # A short file would otherwise silently yield short or empty images.
        if len(labels) < n_labels or len(image_data) < n_images * pixels:
            raise ValueError('Truncated data: headers declare %d items of %d pixels'
                             % (n_images, pixels))

        # Scale every pixel from the byte range [0, 255] into [0, 1].
        images = [[x / 255. for x in image_data[i * pixels:(i + 1) * pixels]]
                  for i in range(n_images)]

        return images, labels

    def print_img0(self):
        """Load both splits, sanity-check their sizes and print the first training image.

        Expects exactly 10000 test samples and 60000 training samples, with
        matching image and label counts. Returns the text rendering of the
        first training image.
        """
        print("Mnist loader first image:")
        test_img, test_label = self.load_testing()
        train_img, train_label = self.load_training()
        assert len(test_img) == len(test_label)
        assert len(test_img) == 10000
        assert len(train_img) == len(train_label)
        assert len(train_img) == 60000
        print('Showing num:%d' % train_label[0])
        render = self.display(train_img[0])
        print(render)
        print('')
        return render

    @classmethod
    def display(cls, img, width=28):
        """Render a flat sequence of pixel values as text, ``width`` values per row.

        Every row (including the first) starts with a newline, and each value
        is written with three decimals followed by a single space.
        """
        return ''.join(('\n' if i % width == 0 else '') + '{:.3f} '.format(v)
                       for i, v in enumerate(img))
    
