org.apache.mahout.classifier.discriminative
Class LinearTrainer

java.lang.Object
  extended by org.apache.mahout.classifier.discriminative.LinearTrainer
Direct Known Subclasses:
PerceptronTrainer, WinnowTrainer

public abstract class LinearTrainer
extends Object

Implementors of this class need to provide a way to train linear discriminative classifiers. As this is just the reference implementation, we assume that the dataset fits into main memory; this should be the first thing to change when switching to Hadoop.


Constructor Summary
protected LinearTrainer(int dimension, double threshold, double init, double initBias)
          Initialize the trainer.
 
Method Summary
 LinearModel getModel()
          Retrieves the trained model if called after train, otherwise the raw model.
 void train(Vector labelset, Matrix dataset)
          Initializes training.
protected abstract  void update(double label, Vector dataPoint, LinearModel model)
          Implement this method to match your training strategy.
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Constructor Detail

LinearTrainer

protected LinearTrainer(int dimension,
                        double threshold,
                        double init,
                        double initBias)
Initialize the trainer. Distance is initialized to cosine distance; all weights are represented through a dense vector.

Parameters:
dimension - number of expected features.
threshold - threshold to use for classification.
init - initial value of weight vector.
initBias - initial classification bias.
Method Detail

train

public void train(Vector labelset,
                  Matrix dataset)
           throws TrainingException
Initializes training. Runs through all data points in the training set and updates the weight vector whenever a classification error occurs. Can be called multiple times.

Parameters:
dataset - the dataset to train on. Each column is treated as a data point.
labelset - the set of labels, one for each data point. If the cardinalities of the dataset and the labelset do not match, a CardinalityException is thrown.
Throws:
TrainingException
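
The training pass described above can be sketched roughly as follows. This is a minimal, self-contained illustration and not the Mahout implementation: the class TrainingPassSketch, its trainOnce and classify methods, the use of plain arrays instead of Vector and Matrix, the label encoding (values above zero count as positive), the decision rule, and the additive correction are all assumptions made for the example. IllegalArgumentException stands in for the CardinalityException mentioned above.

```java
/** Hypothetical stand-in for the training pass; plain arrays, not Mahout's Vector, Matrix or LinearModel. */
final class TrainingPassSketch {
    private final double[] weights;   // dense weight vector
    private double bias;
    private final double threshold;

    TrainingPassSketch(int dimension, double threshold, double init, double initBias) {
        this.weights = new double[dimension];
        for (int i = 0; i < dimension; i++) {
            this.weights[i] = init;   // every weight starts at the same value
        }
        this.threshold = threshold;
        this.bias = initBias;
    }

    /** Assumed decision rule: positive when weights . point + bias exceeds the threshold. */
    boolean classify(double[] point) {
        double sum = bias;
        for (int i = 0; i < weights.length; i++) {
            sum += weights[i] * point[i];
        }
        return sum > threshold;
    }

    /**
     * One pass over the data. dataset[row][column]: each column is one data point.
     * Returns the number of corrections made during the pass.
     */
    int trainOnce(double[] labelset, double[][] dataset) {
        int columns = dataset.length == 0 ? 0 : dataset[0].length;
        if (labelset.length != columns) {
            // Stand-in for the CardinalityException mentioned above.
            throw new IllegalArgumentException(
                "labels: " + labelset.length + ", data points: " + columns);
        }
        int updates = 0;
        for (int col = 0; col < columns; col++) {
            // Copy column col into a feature array.
            double[] point = new double[weights.length];
            for (int row = 0; row < weights.length; row++) {
                point[row] = dataset[row][col];
            }
            boolean expected = labelset[col] > 0;   // assumed label encoding
            if (classify(point) != expected) {
                // Only misclassified points change the weights.
                update(labelset[col], point);
                updates++;
            }
        }
        return updates;
    }

    /** Assumed additive correction; in LinearTrainer the strategy comes from the subclass. */
    private void update(double label, double[] point) {
        double direction = label > 0 ? 1.0 : -1.0;
        for (int i = 0; i < weights.length; i++) {
            weights[i] += direction * point[i];
        }
        bias += direction;
    }
}
```

In this sketch, calling trainOnce repeatedly until it returns 0 is one way to iterate until every training point is classified correctly, which only terminates when the data is linearly separable.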

getModel

public LinearModel getModel()
Retrieves the trained model if called after train, otherwise the raw model.


update

protected abstract void update(double label,
                               Vector dataPoint,
                               LinearModel model)
Implement this method to match your training strategy.

Parameters:
model - the model to update.
label - the target label of the wrongly classified data point.
dataPoint - the data point that was classified incorrectly.
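
As an illustration of how a subclass might fill in this hook, the sketch below shows two possible correction strategies: an additive one in the style of the classic perceptron rule and a multiplicative one in the style of Winnow. All names here (SketchModel, UpdateStrategySketch, AdditiveUpdateSketch, MultiplicativeUpdateSketch, learningRate, factor) are hypothetical, the model is reduced to a weight array and a bias, and labels above zero are assumed to mean the positive class. It is not the code of PerceptronTrainer or WinnowTrainer.

```java
/** Hypothetical, simplified model: a weight array and a bias. Not Mahout's LinearModel. */
final class SketchModel {
    final double[] weights;
    double bias;

    SketchModel(double[] weights, double bias) {
        this.weights = weights;
        this.bias = bias;
    }
}

/** Mirrors the shape of the abstract hook: subclasses decide how to correct a mistake. */
abstract class UpdateStrategySketch {
    abstract void update(double label, double[] dataPoint, SketchModel model);
}

/** Additive correction: move the weights toward or away from the misclassified point. */
final class AdditiveUpdateSketch extends UpdateStrategySketch {
    private final double learningRate;   // assumed step size

    AdditiveUpdateSketch(double learningRate) {
        this.learningRate = learningRate;
    }

    @Override
    void update(double label, double[] dataPoint, SketchModel model) {
        double direction = label > 0 ? 1.0 : -1.0;
        for (int i = 0; i < model.weights.length; i++) {
            model.weights[i] += learningRate * direction * dataPoint[i];
        }
        model.bias += learningRate * direction;
    }
}

/** Multiplicative correction: promote or demote the weights of active features. */
final class MultiplicativeUpdateSketch extends UpdateStrategySketch {
    private final double factor;   // assumed promotion factor, greater than 1

    MultiplicativeUpdateSketch(double factor) {
        this.factor = factor;
    }

    @Override
    void update(double label, double[] dataPoint, SketchModel model) {
        boolean promote = label > 0;
        for (int i = 0; i < model.weights.length; i++) {
            if (dataPoint[i] > 0) {   // only features present in the point are touched
                model.weights[i] = promote ? model.weights[i] * factor
                                           : model.weights[i] / factor;
            }
        }
    }
}
```

Keeping the correction rule behind a single abstract method means the surrounding training loop stays the same while the strategy varies per subclass.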


Copyright © 2008-2012 The Apache Software Foundation. All Rights Reserved.