java.lang.Object
org.apache.mahout.classifier.AbstractVectorClassifier
org.apache.mahout.classifier.sgd.AbstractOnlineLogisticRegression
public abstract class AbstractOnlineLogisticRegression
Generic definition of a 1 of n logistic regression classifier that returns probabilities in response to a feature vector. This classifier uses 1 of n-1 coding where the 0-th category is not stored explicitly.
Provides the basic SGD-based algorithm for learning a logistic regression model, but omits all annealing of learning rates. Any extension of this abstract class must define the overall and per-term annealing itself.
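To make the split between the SGD update and the annealing schedule concrete, here is a minimal stand-alone sketch for the two-category case. It does not use Mahout types: the class name `AnnealedSgdSketch`, the `schedule` parameter, and the example decay formula are all assumptions made for illustration, and the prior and per-term rates are left out.

```java
import java.util.function.IntToDoubleFunction;

/** Hypothetical stand-alone sketch; not Mahout's implementation. */
final class AnnealedSgdSketch {
    private final double[] beta;                 // one weight per feature (two-category case)
    private final IntToDoubleFunction schedule;  // assumed annealing schedule: step -> learning rate
    private int step;

    AnnealedSgdSketch(int numFeatures, IntToDoubleFunction schedule) {
        this.beta = new double[numFeatures];
        this.schedule = schedule;
    }

    /** Binomial logistic inverse link. */
    static double link(double r) {
        return 1.0 / (1.0 + Math.exp(-r));
    }

    /** Probability of category 1 for the given features. */
    double classifyScalar(double[] x) {
        double r = 0.0;
        for (int j = 0; j < beta.length; j++) {
            r += beta[j] * x[j];
        }
        return link(r);
    }

    /** One SGD step on the log-likelihood; actual must be 0 or 1. */
    void train(int actual, double[] x) {
        double rate = schedule.applyAsDouble(step);   // the part a subclass would supply
        double error = actual - classifyScalar(x);
        for (int j = 0; j < beta.length; j++) {
            beta[j] += rate * error * x[j];
        }
        step++;
    }

    int getStep() {
        return step;
    }

    public static void main(String[] args) {
        // Example schedule (an assumption): the rate decays as 0.5 / (1 + step / 100).
        AnnealedSgdSketch model = new AnnealedSgdSketch(2, s -> 0.5 / (1.0 + s / 100.0));
        double[][] xs = {{1, 0}, {1, 1}, {1, 2}, {1, 3}};   // first column is a bias term
        int[] ys = {0, 0, 1, 1};
        for (int pass = 0; pass < 200; pass++) {
            for (int i = 0; i < xs.length; i++) {
                model.train(ys[i], xs[i]);
            }
        }
        System.out.println(model.classifyScalar(new double[] {1, 0}));
        System.out.println(model.classifyScalar(new double[] {1, 3}));
    }
}
```

In the real class the schedule is not passed in; a subclass would supply it by implementing currentLearningRate() and perTermLearningRate(int j).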
Field Summary | |
---|---|
protected Matrix | beta |
protected int | numCategories |
protected PriorFunction | prior |
protected Vector | updateCounts |
protected Vector | updateSteps |
Constructor Summary |
---|
AbstractOnlineLogisticRegression() |
Method Summary | |
---|---|
Vector | classify(Vector instance) Returns n-1 probabilities, one for each category but the 0-th. |
Vector | classifyNoLink(Vector instance) Classify a vector, but don't apply the inverse link function. |
double | classifyScalar(Vector instance) Returns a single scalar probability in the case where we have two categories. |
double | classifyScalarNoLink(Vector instance) |
void | close() Prepares the classifier for classification and deallocates any temporary data structures. |
void | copyFrom(AbstractOnlineLogisticRegression other) |
abstract double | currentLearningRate() |
Matrix | getBeta() |
double | getLambda() |
PriorFunction | getPrior() |
int | getStep() |
boolean | isSealed() |
AbstractOnlineLogisticRegression | lambda(double lambda) Chainable configuration option. |
double | link(double r) Computes the binomial logistic inverse link function. |
Vector | link(Vector v) Computes the inverse link function, by default the logistic link function. |
protected void | nextStep() |
int | numCategories() Returns the number of categories for the target variable. |
int | numFeatures() |
abstract double | perTermLearningRate(int j) |
void | regularize(Vector instance) |
void | setBeta(int i, int j, double betaIJ) |
void | setGradient(Gradient gradient) |
void | setPrior(PriorFunction prior) |
void | train(int actual, Vector instance) Updates the model using a particular target variable value and a feature vector. |
void | train(long trackingKey, int actual, Vector instance) Updates the model using a particular target variable value and a feature vector. |
void | train(long trackingKey, java.lang.String groupKey, int actual, Vector instance) Updates the model using a particular target variable value and a feature vector. |
protected void | unseal() |
boolean | validModel() |
Methods inherited from class org.apache.mahout.classifier.AbstractVectorClassifier |
---|
classify, classifyFull, classifyFull, classifyFull, classifyScalar, logLikelihood |
Methods inherited from class java.lang.Object |
---|
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait |
Field Detail |
---|
protected Matrix beta
protected int numCategories
protected Vector updateSteps
protected Vector updateCounts
protected PriorFunction prior
Constructor Detail |
---|
public AbstractOnlineLogisticRegression()
Method Detail |
---|
public AbstractOnlineLogisticRegression lambda(double lambda)
lambda
- New value of lambda, the weighting factor for the prior distribution.
public Vector link(Vector v)
v
- The output of the linear combination in a GLM. Note that v is
modified in place.
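As an illustration of what an inverse link looks like under 1 of n-1 coding, the sketch below maps n-1 linear scores to n-1 probabilities on a plain `double[]`, treating the score of the implicit 0-th category as 0. The class `LinkSketch` and the helper `probabilityOfCategoryZero` are hypothetical names, and this is not a copy of Mahout's code. Like the method documented above, it overwrites its argument.

```java
/** Hypothetical stand-alone sketch of a 1 of n-1 logistic inverse link; not Mahout's code. */
final class LinkSketch {
    /**
     * Maps n-1 linear scores to n-1 probabilities, treating the score of the
     * implicit 0-th category as 0. Overwrites and returns v.
     */
    static double[] link(double[] v) {
        // Shift by the largest score (or 0 for the implicit category) to avoid overflow.
        double max = 0.0;
        for (double s : v) {
            max = Math.max(max, s);
        }
        double sum = Math.exp(-max);   // contribution of the implicit 0-th category
        for (int i = 0; i < v.length; i++) {
            v[i] = Math.exp(v[i] - max);
            sum += v[i];
        }
        for (int i = 0; i < v.length; i++) {
            v[i] /= sum;
        }
        return v;
    }

    /** Probability of the 0-th category, recovered from the other n-1. */
    static double probabilityOfCategoryZero(double[] p) {
        double rest = 0.0;
        for (double x : p) {
            rest += x;
        }
        return 1.0 - rest;
    }

    public static void main(String[] args) {
        // Two zero scores plus the implicit zero: three equally likely categories.
        double[] p = link(new double[] {0.0, 0.0});
        System.out.println(p[0] + " " + p[1] + " " + probabilityOfCategoryZero(p));
    }
}
```

With a single score r, this reduces to the binomial logistic function 1 / (1 + exp(-r)).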
public double link(double r)
r
- The value to transform.
public Vector classifyNoLink(Vector instance)
AbstractVectorClassifier
classifyNoLink
in class AbstractVectorClassifier
instance
- A feature vector to be classified.
public double classifyScalarNoLink(Vector instance)
public Vector classify(Vector instance)
classify
in class AbstractVectorClassifier
instance
- A vector of features to be classified.
public double classifyScalar(Vector instance)
classifyScalar
in class AbstractVectorClassifier
instance
- The vector of features to be classified.
java.lang.IllegalArgumentException
- If the classifier doesn't have two categories.
AbstractVectorClassifier.classify(Vector)
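The relationship between the scalar and vector forms of classification can be sketched as follows. This is a hedged illustration on plain arrays, not Mahout's implementation: the class `ScalarClassifySketch` is hypothetical, the coefficient layout (numCategories - 1 rows by numFeatures columns) is an assumption, and overflow protection is omitted for brevity.

```java
/** Hypothetical stand-alone sketch; the beta layout is an assumption, not taken from this page. */
final class ScalarClassifySketch {
    // Assumed layout: (numCategories - 1) rows, numFeatures columns.
    private final double[][] beta;

    ScalarClassifySketch(int numCategories, int numFeatures) {
        beta = new double[numCategories - 1][numFeatures];
    }

    int numCategories() {
        return beta.length + 1;
    }

    void setBeta(int i, int j, double betaIJ) {
        beta[i][j] = betaIJ;
    }

    /** Returns n-1 probabilities, one per category except the 0-th. */
    double[] classify(double[] x) {
        double[] p = new double[beta.length];
        double sum = 1.0;   // exp(0) for the implicit 0-th category
        for (int i = 0; i < beta.length; i++) {
            double r = 0.0;
            for (int j = 0; j < x.length; j++) {
                r += beta[i][j] * x[j];
            }
            p[i] = Math.exp(r);   // no overflow guard in this sketch
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) {
            p[i] /= sum;
        }
        return p;
    }

    /** Probability of category 1; only defined when there are exactly two categories. */
    double classifyScalar(double[] x) {
        if (numCategories() != 2) {
            throw new IllegalArgumentException("classifyScalar requires exactly two categories");
        }
        return classify(x)[0];
    }

    public static void main(String[] args) {
        ScalarClassifySketch model = new ScalarClassifySketch(2, 2);
        model.setBeta(0, 1, Math.log(3.0));
        System.out.println(model.classifyScalar(new double[] {0.0, 1.0}));
    }
}
```

In this sketch the scalar result is simply the single entry of the n-1 vector when n is 2, which is why any other category count is rejected.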
public void train(long trackingKey, java.lang.String groupKey, int actual, Vector instance)
OnlineLearner
train
in interface OnlineLearner
trackingKey
- The tracking key for this training example.
groupKey
- An optional value that allows examples to be grouped in the computation of
the update to the model.
actual
- The value of the target variable. This value should be in the half-open
interval [0..n) where n is the number of target categories.
instance
- The feature vector for this example.
public void train(long trackingKey, int actual, Vector instance)
OnlineLearner
train
in interface OnlineLearner
trackingKey
- The tracking key for this training example.
actual
- The value of the target variable. This value should be in the half-open
interval [0..n) where n is the number of target categories.
instance
- The feature vector for this example.
public void train(int actual, Vector instance)
OnlineLearner
train
in interface OnlineLearner
actual
- The value of the target variable. This value should be in the half-open
interval [0..n) where n is the number of target categories.
instance
- The feature vector for this example.
public void regularize(Vector instance)
public abstract double perTermLearningRate(int j)
public abstract double currentLearningRate()
public void setPrior(PriorFunction prior)
public void setGradient(Gradient gradient)
public PriorFunction getPrior()
public Matrix getBeta()
public void setBeta(int i, int j, double betaIJ)
public int numCategories()
AbstractVectorClassifier
numCategories
in class AbstractVectorClassifier
public int numFeatures()
public double getLambda()
public int getStep()
protected void nextStep()
public boolean isSealed()
protected void unseal()
public void close()
OnlineLearner
close
in interface OnlineLearner
public void copyFrom(AbstractOnlineLogisticRegression other)
public boolean validModel()