java.lang.Object
  org.apache.mahout.classifier.AbstractVectorClassifier
    org.apache.mahout.classifier.sgd.CrossFoldLearner

public class CrossFoldLearner
extends AbstractVectorClassifier
implements OnlineLearner
Performs cross-fold validation of log-likelihood and AUC on several online logistic regression models. Each record is passed to all but one of the models for training and to the remaining model for evaluation. To keep the folds properly segregated across repeated passes over the training data, either pass the data to this learner in the same order on every pass, or pass a tracking key, such as the file offset of the training record, with each training example.
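The routing described above can be sketched in a few lines of plain Java. This is a minimal illustration under an assumption, not Mahout's implementation: the class FoldRouting and its methods are hypothetical, and the held-out fold is assumed to be the tracking key modulo the number of folds.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical illustration of cross-fold routing; not part of the Mahout API.
class FoldRouting {
    // Assumption: the evaluation fold is the tracking key modulo the fold count.
    // floorMod keeps the result in [0, folds) even for negative keys.
    static int evaluationFold(long trackingKey, int folds) {
        return (int) Math.floorMod(trackingKey, (long) folds);
    }

    // The folds that train on this record: every fold except the evaluation fold.
    static List<Integer> trainingFolds(long trackingKey, int folds) {
        int held = evaluationFold(trackingKey, folds);
        List<Integer> result = new ArrayList<>();
        for (int k = 0; k < folds; k++) {
            if (k != held) {
                result.add(k);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        int folds = 5;
        for (long key = 0; key < 7; key++) {
            System.out.println("record " + key + ": evaluate on fold " + evaluationFold(key, folds)
                    + ", train on folds " + trainingFolds(key, folds));
        }
    }
}
```

In this sketch each key is held out by exactly one fold, so every model trains on roughly (folds - 1)/folds of the records and is evaluated only on records it never trained on.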
Constructor Summary |
---|
CrossFoldLearner() |
CrossFoldLearner(int folds, int numCategories, int numFeatures, PriorFunction prior) |
Method Summary
Modifier and Type | Method | Description
---|---|---
void | addModel(OnlineLogisticRegression model) |
CrossFoldLearner | alpha(double alpha) |
double | auc() |
Vector | classify(Vector instance) | Classify a vector, returning a vector of numCategories-1 scores.
Vector | classifyNoLink(Vector instance) | Classify a vector, but don't apply the inverse link function.
double | classifyScalar(Vector instance) | Classifies a vector in the special case of a binary classifier where classify(Vector) would return a vector with only one element.
void | close() | Prepares the classifier for classification and deallocates any temporary data structures.
CrossFoldLearner | copy() |
CrossFoldLearner | decayExponent(double x) |
OnlineAuc | getAucEvaluator() |
double | getLogLikelihood() |
java.util.List<OnlineLogisticRegression> | getModels() |
int | getNumFeatures() |
double[] | getParameters() |
PriorFunction | getPrior() |
int | getRecord() |
CrossFoldLearner | lambda(double v) |
CrossFoldLearner | learningRate(double x) |
double | logLikelihood() |
int | numCategories() | Returns the number of categories for the target variable.
double | percentCorrect() |
void | resetLineCounter() |
void | setAucEvaluator(OnlineAuc auc) |
void | setLogLikelihood(double logLikelihood) |
void | setNumFeatures(int numFeatures) |
void | setParameters(double[] parameters) |
void | setPrior(PriorFunction prior) |
void | setRecord(int record) |
void | setWindowSize(int windowSize) |
CrossFoldLearner | stepOffset(int x) |
void | train(int actual, Vector instance) | Updates the model using a particular target variable value and a feature vector.
void | train(long trackingKey, int actual, Vector instance) | Updates the model using a particular target variable value and a feature vector.
void | train(long trackingKey, java.lang.String groupKey, int actual, Vector instance) | Updates the model using a particular target variable value and a feature vector.
boolean | validModel() |
Methods inherited from class org.apache.mahout.classifier.AbstractVectorClassifier |
---|
classify, classifyFull, classifyFull, classifyFull, classifyScalar, logLikelihood |
Methods inherited from class java.lang.Object |
---|
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait |
Constructor Detail
public CrossFoldLearner()
public CrossFoldLearner(int folds, int numCategories, int numFeatures, PriorFunction prior)
Method Detail
public CrossFoldLearner lambda(double v)
public CrossFoldLearner learningRate(double x)
public CrossFoldLearner stepOffset(int x)
public CrossFoldLearner decayExponent(double x)
public CrossFoldLearner alpha(double alpha)
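The constructor takes a fold count, and the hyperparameter setters above each return a CrossFoldLearner, so calls can be chained. A plausible design, assumed here rather than taken from Mahout's source, is that the learner builds one model per fold and forwards each hyperparameter to every fold model. FoldModel and ChainedCrossFold are hypothetical stand-ins for OnlineLogisticRegression and CrossFoldLearner that only record the values they receive.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for a per-fold model; it only stores hyperparameters.
class FoldModel {
    double lambda;
    double learningRate;

    FoldModel lambda(double v) { this.lambda = v; return this; }
    FoldModel learningRate(double x) { this.learningRate = x; return this; }
}

// Hypothetical stand-in for the cross-fold learner.
class ChainedCrossFold {
    final List<FoldModel> models = new ArrayList<>();

    ChainedCrossFold(int folds) {
        // Assumed design: one model per fold.
        for (int i = 0; i < folds; i++) {
            models.add(new FoldModel());
        }
    }

    // Forward the value to every fold model, then return this so calls can be chained.
    ChainedCrossFold lambda(double v) {
        for (FoldModel m : models) m.lambda(v);
        return this;
    }

    ChainedCrossFold learningRate(double x) {
        for (FoldModel m : models) m.learningRate(x);
        return this;
    }

    public static void main(String[] args) {
        ChainedCrossFold learner = new ChainedCrossFold(5).lambda(1e-4).learningRate(50);
        System.out.println(learner.models.size() + " models, lambda=" + learner.models.get(0).lambda);
    }
}
```

With the real class, using only the signatures on this page, the same chain would read new CrossFoldLearner(5, 2, numFeatures, prior).lambda(1e-4).learningRate(50), where numFeatures and prior are placeholders and the hyperparameter values are arbitrary.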

public void train(int actual, Vector instance)
Description copied from interface: OnlineLearner
Updates the model using a particular target variable value and a feature vector.
Specified by: train in interface OnlineLearner
Parameters:
actual - The value of the target variable. This value should be in the half-open interval [0..n) where n is the number of target categories.
instance - The feature vector for this example.

public void train(long trackingKey, int actual, Vector instance)
Description copied from interface: OnlineLearner
Updates the model using a particular target variable value and a feature vector.
Specified by: train in interface OnlineLearner
Parameters:
trackingKey - The tracking key for this training example.
actual - The value of the target variable. This value should be in the half-open interval [0..n) where n is the number of target categories.
instance - The feature vector for this example.

public void train(long trackingKey, java.lang.String groupKey, int actual, Vector instance)
Description copied from interface: OnlineLearner
Updates the model using a particular target variable value and a feature vector.
Specified by: train in interface OnlineLearner
Parameters:
trackingKey - The tracking key for this training example.
groupKey - An optional value that allows examples to be grouped in the computation of the update to the model.
actual - The value of the target variable. This value should be in the half-open interval [0..n) where n is the number of target categories.
instance - The feature vector for this example.

public void close()
Description copied from interface: OnlineLearner
Prepares the classifier for classification and deallocates any temporary data structures.
Specified by: close in interface OnlineLearner

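The ordering requirement in the class description can be illustrated with a small simulation. It assumes that the held-out fold is the key modulo the fold count and that, when no tracking key is given, a running record counter serves as the key (resetLineCounter() suggests such a counter exists, but that is an inference). KeyStability and its methods are hypothetical. Two passes present the same records, the second in reverse order: counter-based keys move records between folds, while keys derived from the record itself, such as a file offset, do not.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical demo: fold assignment must not depend on presentation order.
class KeyStability {
    // Assumption: fold = key mod folds.
    static int fold(long key, int folds) {
        return (int) Math.floorMod(key, (long) folds);
    }

    // Fold per record when the key is the record's position in the pass.
    static Map<Long, Integer> byCounter(List<Long> order, int folds) {
        Map<Long, Integer> assignment = new HashMap<>();
        long counter = 0; // starts at 0 on every pass, like a reset line counter
        for (long recordId : order) {
            assignment.put(recordId, fold(counter++, folds));
        }
        return assignment;
    }

    // Fold per record when the key is a stable property of the record (e.g. its file offset).
    static Map<Long, Integer> byTrackingKey(List<Long> order, int folds) {
        Map<Long, Integer> assignment = new HashMap<>();
        for (long offset : order) {
            assignment.put(offset, fold(offset, folds));
        }
        return assignment;
    }

    public static void main(String[] args) {
        int folds = 3;
        List<Long> pass1 = new ArrayList<>();
        for (long id = 0; id < 9; id++) pass1.add(id * 100); // pretend these are file offsets
        List<Long> pass2 = new ArrayList<>(pass1);
        Collections.reverse(pass2); // second pass sees the records in a different order

        System.out.println("counter keys stable:  " + byCounter(pass1, folds).equals(byCounter(pass2, folds)));
        System.out.println("tracking keys stable: " + byTrackingKey(pass1, folds).equals(byTrackingKey(pass2, folds)));
    }
}
```

When a record changes folds between passes, a model can end up being evaluated on a record it trained on in an earlier pass, which biases the held-out estimates.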
public void resetLineCounter()
public boolean validModel()
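
public Vector classify(Vector instance)
Description copied from class: AbstractVectorClassifier
Classify a vector, returning a vector of numCategories-1 scores.
Specified by: classify in class AbstractVectorClassifier
Parameters:
instance - A feature vector to be classified.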
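
public Vector classifyNoLink(Vector instance)
Description copied from class: AbstractVectorClassifier
Classify a vector, but don't apply the inverse link function.
Specified by: classifyNoLink in class AbstractVectorClassifier
Parameters:
instance - A feature vector to be classified.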
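The only difference between classify and classifyNoLink is whether the inverse link function is applied to the raw linear scores. A common link that matches the numCategories-1 output size is a multinomial logit with category 0 as the implicit reference class (raw score 0); that choice is an assumption here, not something this page states. InverseLink is a hypothetical sketch working on plain arrays instead of Mahout Vectors.

```java
// Hypothetical inverse link for numCategories-1 raw scores (multinomial logit, reference category 0).
class InverseLink {
    // raw[i] is the linear score of category i+1; category 0 has an implicit score of 0.
    // Returns the probabilities of categories 1..n-1; category 0 gets 1 minus their sum.
    static double[] apply(double[] raw) {
        // Shift by the largest score (the reference 0 included) so exp() cannot overflow.
        double max = 0.0;
        for (double r : raw) max = Math.max(max, r);
        double denom = Math.exp(-max); // the reference category's term after the shift
        double[] p = new double[raw.length];
        for (int i = 0; i < raw.length; i++) {
            p[i] = Math.exp(raw[i] - max);
            denom += p[i];
        }
        for (int i = 0; i < p.length; i++) p[i] /= denom;
        return p;
    }

    public static void main(String[] args) {
        double[] raw = {0.0, Math.log(2.0)}; // three categories: scores 0 (ref), 0, ln 2
        double[] p = apply(raw);
        System.out.printf("p1=%.4f p2=%.4f p0=%.4f%n", p[0], p[1], 1 - p[0] - p[1]);
    }
}
```

With one raw score z (the binary case) this reduces to the logistic function 1/(1 + e^-z).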

public double classifyScalar(Vector instance)
Description copied from class: AbstractVectorClassifier
Classifies a vector in the special case of a binary classifier where classify(Vector) would return a vector with only one element. As such, using this method can avoid the allocation of a vector.
Specified by: classifyScalar in class AbstractVectorClassifier
Parameters:
instance - The feature vector to be classified.
See Also:
AbstractVectorClassifier.classify(Vector)
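In the binary case numCategories-1 equals 1, so classify(Vector) would wrap a single number in a vector, and classifyScalar can return that number directly. The sketch assumes a logistic link, so the number is sigmoid(w·x), read here as the score for category 1. BinaryScore is hypothetical, and the weights and features are made-up illustration values.

```java
// Hypothetical sketch: one-element-vector vs scalar output for a binary logistic model.
class BinaryScore {
    // Raw linear score w . x (the value a no-link classification would return).
    static double dot(double[] w, double[] x) {
        double s = 0.0;
        for (int i = 0; i < w.length; i++) s += w[i] * x[i];
        return s;
    }

    // Logistic inverse link.
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Vector-style result: allocates a one-element array.
    static double[] classify(double[] w, double[] x) {
        return new double[] {sigmoid(dot(w, x))};
    }

    // Scalar-style result: the same number without the allocation.
    static double classifyScalar(double[] w, double[] x) {
        return sigmoid(dot(w, x));
    }

    public static void main(String[] args) {
        double[] w = {0.5, -1.0};
        double[] x = {2.0, 1.0};
        System.out.println(classify(w, x)[0] + " == " + classifyScalar(w, x)); // prints 0.5 == 0.5
    }
}
```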
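
public int numCategories()
Description copied from class: AbstractVectorClassifier
Returns the number of categories for the target variable.
Specified by: numCategories in class AbstractVectorClassifier
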
public double auc()
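AUC is defined for a binary target: over a finite set of scored examples, it is the fraction of (positive, negative) pairs in which the positive example receives the higher score, with ties counted as one half. The sketch computes that exact statistic as a reference point only; it is not claimed to match how an OnlineAuc evaluator computes its estimate. ExactAuc is hypothetical, and its pairwise loop is O(n²), so it suits small samples.

```java
// Hypothetical exact AUC: P(score of a random positive > score of a random negative), ties count 1/2.
class ExactAuc {
    static double auc(int[] labels, double[] scores) {
        double wins = 0.0;
        long pairs = 0;
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] != 1) continue;          // i must be a positive example
            for (int j = 0; j < labels.length; j++) {
                if (labels[j] != 0) continue;      // j must be a negative example
                pairs++;
                if (scores[i] > scores[j]) wins += 1.0;
                else if (scores[i] == scores[j]) wins += 0.5;
            }
        }
        if (pairs == 0) {
            throw new IllegalArgumentException("need at least one positive and one negative example");
        }
        return wins / pairs;
    }

    public static void main(String[] args) {
        int[] labels = {1, 0, 1, 0};
        double[] scores = {0.9, 0.3, 0.4, 0.6};
        System.out.println(auc(labels, scores)); // prints 0.75 (3 of 4 pairs ranked correctly)
    }
}
```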
public double logLikelihood()
public double percentCorrect()
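logLikelihood() and percentCorrect() summarize held-out performance over a stream of records, and setWindowSize(int) suggests these averages favor recent records. One way to do that, assumed here and not confirmed by this page, is an incremental mean with step size 1/min(n, windowSize): it is the exact mean of the first windowSize records and afterward behaves like an exponential moving average. WindowedMean is hypothetical.

```java
// Hypothetical windowed running mean: exact mean for the first windowSize samples,
// then an exponentially weighted average with step 1/windowSize.
class WindowedMean {
    private final int windowSize;
    private long count = 0;
    private double mean = 0.0;

    WindowedMean(int windowSize) {
        if (windowSize < 1) throw new IllegalArgumentException("windowSize must be >= 1");
        this.windowSize = windowSize;
    }

    void add(double x) {
        count++;
        // Move the mean a fraction of the way toward the new sample.
        mean += (x - mean) / Math.min(count, windowSize);
    }

    double value() { return mean; }

    public static void main(String[] args) {
        WindowedMean accuracy = new WindowedMean(4);
        int[] correct = {1, 0, 1, 1, 1, 1}; // 1 = held-out prediction was correct
        for (int c : correct) accuracy.add(c);
        System.out.println(accuracy.value());
    }
}
```

For a log-likelihood average, the sample added per record would be the log of the probability the held-out model assigned to the actual category.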
public CrossFoldLearner copy()
public int getRecord()
public void setRecord(int record)
public OnlineAuc getAucEvaluator()
public void setAucEvaluator(OnlineAuc auc)
public double getLogLikelihood()
public void setLogLikelihood(double logLikelihood)
public java.util.List<OnlineLogisticRegression> getModels()
public void addModel(OnlineLogisticRegression model)
public double[] getParameters()
public void setParameters(double[] parameters)
public int getNumFeatures()
public void setNumFeatures(int numFeatures)
public void setWindowSize(int windowSize)
public PriorFunction getPrior()
public void setPrior(PriorFunction prior)