java.lang.Object
  org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression
public class AdaptiveLogisticRegression
This is a meta-learner that maintains a pool of ordinary OnlineLogisticRegression learners, each with different learning rates. Whichever learner in the pool falls behind in terms of average log-likelihood is tossed out and replaced with variants of the survivors. This lets us derive an annealing schedule automatically, one that optimizes learning speed. Since on-line learners tend to be I/O bound anyway, maintaining multiple learners in memory costs less than it might seem. Doing this adaptation on-line as we learn also decreases the number of learning-rate parameters required and replaces the normal hyper-parameter search.
One wrinkle is that the pool of learners that we maintain is actually a pool of CrossFoldLearners which themselves contain several OnlineLogisticRegression objects. These pools allow estimation of performance on the fly even if we make many passes through the data. This does, however, increase the cost of training since if we are using 5-fold cross-validation, each vector is used 4 times for training and once for classification. If this becomes a problem, then we should probably use a 2-way unbalanced train/test split rather than full cross-validation.

With the current default settings, we have 100 learners running. This is better than the alternative of running hundreds of training passes to find good hyper-parameters because we only have to parse and feature-ize our inputs once. If you already have good hyper-parameters, then you might prefer to just run one CrossFoldLearner with those settings.

The fitness used here is AUC. Another alternative would be log-likelihood, but it is much easier to get bogus values of log-likelihood than of AUC, and the two seem to accord pretty well. It would be nice to allow the fitness function to be pluggable. This use of AUC means that AdaptiveLogisticRegression is mostly suited for binary target variables. This will be fixed before long by extending OnlineAuc to handle non-binary cases or by using a different fitness value in non-binary cases.
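As a rough usage sketch (hedged: `TrainingDatum`, `data`, and the feature encoding are illustrative stand-ins; only the AdaptiveLogisticRegression methods documented on this page are relied on):

```java
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.math.Vector;

// Binary target (2 categories, since the fitness is AUC), 20 features, L1 prior.
AdaptiveLogisticRegression learner = new AdaptiveLogisticRegression(2, 20, new L1());

for (TrainingDatum datum : data) {        // TrainingDatum and data are hypothetical
  Vector features = datum.features();     // a Mahout Vector of dimension 20
  learner.train(datum.label(), features); // label must lie in [0..2)
}

learner.close();                          // finish any pending training
double quality = learner.auc();           // AUC of the current best pool member
```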
Nested Class Summary

static class AdaptiveLogisticRegression.TrainingExample
static class AdaptiveLogisticRegression.Wrapper
    Provides a shim between the EP optimization stuff and the CrossFoldLearner.
Constructor Summary

AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior)
Method Summary

double auc()
    What is the AUC for the current best member of the population?
void close()
    Prepares the classifier for classification and deallocates any temporary data structures.
State<AdaptiveLogisticRegression.Wrapper> getBest()
java.util.List<AdaptiveLogisticRegression.TrainingExample> getBuffer()
EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper> getEp()
int getMaxInterval()
int getMinInterval()
int getNumCategories()
int getNumFeatures()
PriorFunction getPrior()
int getRecord()
State<AdaptiveLogisticRegression.Wrapper> getSeed()
int nextStep(int recordNumber)
int numFeatures()
    Returns the size of the internal feature vector.
void setAucEvaluator(OnlineAuc auc)
void setAveragingWindow(int averagingWindow)
void setBest(State<AdaptiveLogisticRegression.Wrapper> best)
void setBuffer(java.util.List<AdaptiveLogisticRegression.TrainingExample> buffer)
void setEp(EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper> ep)
void setFreezeSurvivors(boolean freezeSurvivors)
void setInterval(int interval)
    How often should the evolutionary optimization of learning parameters occur?
void setInterval(int minInterval, int maxInterval)
    Starts optimization using the shorter interval and progresses to the longer using the specified number of steps per decade.
void setPoolSize(int poolSize)
void setRecord(int record)
void setSeed(State<AdaptiveLogisticRegression.Wrapper> seed)
void setThreadCount(int threadCount)
static int stepSize(int recordNumber, double multiplier)
void train(int actual, Vector instance)
    Updates the model using a particular target variable value and a feature vector.
void train(long trackingKey, int actual, Vector instance)
    Updates the model using a particular target variable value and a feature vector.
void train(long trackingKey, java.lang.String groupKey, int actual, Vector instance)
    Updates the model using a particular target variable value and a feature vector.

Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Constructor Detail

public AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior)
Method Detail
public void train(int actual, Vector instance)
Updates the model using a particular target variable value and a feature vector.
Specified by: train in interface OnlineLearner
Parameters:
actual - The value of the target variable. This value should be in the half-open interval [0..n) where n is the number of target categories.
instance - The feature vector for this example.

public void train(long trackingKey, int actual, Vector instance)
Updates the model using a particular target variable value and a feature vector.
Specified by: train in interface OnlineLearner
Parameters:
trackingKey - The tracking key for this training example.
actual - The value of the target variable. This value should be in the half-open interval [0..n) where n is the number of target categories.
instance - The feature vector for this example.

public void train(long trackingKey, java.lang.String groupKey, int actual, Vector instance)
Updates the model using a particular target variable value and a feature vector.
Specified by: train in interface OnlineLearner
Parameters:
trackingKey - The tracking key for this training example.
groupKey - An optional value that allows examples to be grouped in the computation of the update to the model.
actual - The value of the target variable. This value should be in the half-open interval [0..n) where n is the number of target categories.
instance - The feature vector for this example.

public int nextStep(int recordNumber)
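A hedged sketch of when each train overload applies, assuming an already-constructed AdaptiveLogisticRegression named `learner` (`docHash` and `userId` are illustrative names, not part of the API):

```java
// Simplest form: just the target value and the feature vector.
learner.train(label, features);

// With a tracking key for the example (e.g. a stable hash of the input record).
learner.train(docHash, label, features);

// With a group key as well, so that related examples can be grouped when the
// update to the model is computed.
learner.train(docHash, userId, label, features);
```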
public static int stepSize(int recordNumber, double multiplier)
public void close()
Prepares the classifier for classification and deallocates any temporary data structures.
Specified by: close in interface OnlineLearner
public void setInterval(int interval)
How often should the evolutionary optimization of learning parameters occur?
Parameters:
interval - Number of training examples to use in each epoch of optimization.

public void setInterval(int minInterval, int maxInterval)
Starts optimization using the shorter interval and progresses to the longer using the specified number of steps per decade.
Parameters:
minInterval - The minimum epoch length for the evolutionary optimization.
maxInterval - The maximum epoch length.

public void setPoolSize(int poolSize)
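A hedged configuration sketch (the numbers are illustrative, not recommendations):

```java
AdaptiveLogisticRegression learner = new AdaptiveLogisticRegression(2, 20, new L1());

// Fixed epoch: run the evolutionary optimization every 1000 training examples.
learner.setInterval(1000);

// Or: adapt quickly at first, then settle into longer epochs as training proceeds.
learner.setInterval(200, 10000);

learner.setPoolSize(20);    // number of learners kept in the pool
learner.setThreadCount(4);  // threads used for training the pool
```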
public void setThreadCount(int threadCount)
public void setAucEvaluator(OnlineAuc auc)
public int numFeatures()
public double auc()
public State<AdaptiveLogisticRegression.Wrapper> getBest()
public void setBest(State<AdaptiveLogisticRegression.Wrapper> best)
public int getRecord()
public void setRecord(int record)
public int getMinInterval()
public int getMaxInterval()
public int getNumCategories()
public PriorFunction getPrior()
public void setBuffer(java.util.List<AdaptiveLogisticRegression.TrainingExample> buffer)
public java.util.List<AdaptiveLogisticRegression.TrainingExample> getBuffer()
public EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper> getEp()
public void setEp(EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper> ep)
public State<AdaptiveLogisticRegression.Wrapper> getSeed()
public void setSeed(State<AdaptiveLogisticRegression.Wrapper> seed)
public int getNumFeatures()
public void setAveragingWindow(int averagingWindow)
public void setFreezeSurvivors(boolean freezeSurvivors)
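After training, the best pool member can be retrieved via getBest(). The sketch below assumes that Wrapper exposes the underlying CrossFoldLearner through a getLearner() accessor (check the AdaptiveLogisticRegression.Wrapper documentation for the exact method; `someFeatures` is a hypothetical feature vector):

```java
State<AdaptiveLogisticRegression.Wrapper> best = learner.getBest();
if (best != null) {
  // The Wrapper is the shim between the EP machinery and the CrossFoldLearner,
  // so the state's payload carries the actual model.
  CrossFoldLearner model = best.getPayload().getLearner(); // accessor assumed
  Vector scores = model.classifyFull(someFeatures);        // per-category scores
}
```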