#!python

from optparse import OptionParser
import pickle

from lightning.classification import CDClassifier
from lightning.classification import FistaClassifier
from lightning.classification import SGDClassifier


# Prefer the standalone svmlight_loader package when it is installed;
# otherwise fall back to the loader that ships with scikit-learn.
try:
    from svmlight_loader import load_svmlight_file
except ImportError:
    from sklearn.datasets import load_svmlight_file

# Command-line interface: every setting is a long option; the two positional
# arguments (training file, output model file) are consumed after parsing.
parser = OptionParser(usage="lightning_train [options] train_file model_file")

# --- Solver / objective selection -------------------------------------------
parser.add_option(
    "--solver",
    dest="solver",
    type="string",
    default="bcd-ls",
    help="Solver to use (bcd-ls, bcd-cst, fista, sgd). Default: bcd-ls.",
)
parser.add_option(
    "--loss",
    dest="loss",
    type="string",
    default="squared_hinge",
    help="Loss function to use (squared_hinge, log). Default: squared_hinge.",
)
parser.add_option(
    "--penalty",
    dest="penalty",
    type="string",
    default="l1/l2",
    help="Penalty to use (l2, l1, l1/l2). Default: l1/l2.",
)
parser.add_option(
    "--alpha",
    dest="alpha",
    type="float",
    default=1e-4,
    help="Strength of the penalty. Default: 1e-4.",
)
# Flag without an explicit default: opts.ovr is None unless --ovr is given.
parser.add_option(
    "--ovr",
    dest="ovr",
    action="store_true",
    help="Whether to activate one-vs-rest instead of direct multiclass loss.",
)

# --- Optimisation controls ---------------------------------------------------
parser.add_option(
    "--max-iter",
    dest="max_iter",
    type="int",
    default=10,
    help="Maximum number of outer iterations. Default: 10.",
)
parser.add_option(
    "--tol",
    dest="tol",
    type="float",
    default=0.05,
    help="Tolerance of the stopping criterion. Default: 0.05.",
)
parser.add_option(
    "--eta0",
    dest="eta0",
    type="float",
    default=1e-3,
    help="Step size (SGD only). Default: 1e-3.",
)

# Parse the command line.  Exactly two positional arguments are required:
# the training file followed by the path the fitted model is written to.
(opts, args) = parser.parse_args()

if len(args) != 2:
    parser.print_help()
    # Exit with a failure status so calling scripts can detect the misuse.
    raise SystemExit(1)

train_path, model_path = args

# Validate the solver name before loading the training data, so that a typo
# fails immediately instead of after a potentially slow file load.  As
# before, any name starting with "bcd" selects the coordinate descent solver.
use_cd = opts.solver.startswith("bcd")
if not use_cd and opts.solver not in ("fista", "sgd"):
    # parser.error() prints the usage and this message to stderr, then exits
    # with status 2.
    parser.error("Invalid solver")

X, y = load_svmlight_file(train_path)

# Direct multiclass loss is used unless --ovr was given (opts.ovr is None
# when the flag is absent, so `not opts.ovr` is True by default).
multiclass = not opts.ovr

if use_cd:
    # "bcd-cst" maps to max_steps=0, every other "bcd*" name to 30.
    # NOTE(review): presumably 0 means a constant step size instead of a
    # line search -- confirm against the CDClassifier documentation.
    if opts.solver == "bcd-cst":
        max_steps = 0
    else:
        max_steps = 30

    # C is set to 1 / (number of rows of X).
    clf = CDClassifier(penalty=opts.penalty,
                       loss=opts.loss,
                       multiclass=multiclass,
                       max_iter=opts.max_iter,
                       max_steps=max_steps,
                       alpha=opts.alpha,
                       C=1.0 / X.shape[0],
                       tol=opts.tol)
elif opts.solver == "fista":
    clf = FistaClassifier(penalty=opts.penalty,
                          loss=opts.loss,
                          multiclass=multiclass,
                          max_iter=opts.max_iter,
                          alpha=opts.alpha,
                          C=1.0 / X.shape[0])
else:
    # Only "sgd" can reach this branch thanks to the validation above.
    clf = SGDClassifier(penalty=opts.penalty,
                        loss=opts.loss,
                        multiclass=multiclass,
                        max_iter=opts.max_iter,
                        alpha=opts.alpha,
                        learning_rate="constant",
                        eta0=opts.eta0)

clf.fit(X, y)

# Single-argument print(...) calls behave identically on Python 2 and 3.
print("Percentage of non-zero rows:")
print(clf.n_nonzero(percentage=True))

# HIGHEST_PROTOCOL is a binary pickle format, so the file must be opened in
# binary mode; the with-block guarantees the file is flushed and closed.
with open(model_path, "wb") as model_file:
    pickle.dump(clf, model_file, pickle.HIGHEST_PROTOCOL)
