java.lang.Object
  org.apache.mahout.classifier.sgd.LogisticModelParameters
public class LogisticModelParameters
Encapsulates everything we need to know about a model and how it reads and vectorizes its input. This encapsulation lets us save a model to a file and restore it as a coherent whole. It also lets us keep the command-line arguments that affect learning together in one consistent place.
Nested Class Summary

Modifier and Type | Class | Description |
---|---|---|
static class | LogisticModelParameters.MatrixTypeAdapter | Tells GSON how to (de)serialize a Mahout matrix. |
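This page does not show the JSON layout that MatrixTypeAdapter produces. As a rough illustration of the two directions any matrix (de)serializer has to cover, here is a minimal sketch that uses a plain row-major `double[][]` in place of a Mahout matrix, a hypothetical `{"rows","cols","values"}` layout, and no GSON at all. The class name `MatrixJson` is likewise hypothetical.

```java
import java.util.Arrays;

// Hypothetical sketch: a dense row-major double[][] stands in for a Mahout matrix,
// and the JSON layout below is an assumption, not MatrixTypeAdapter's actual format.
class MatrixJson {
    // Serialize as {"rows":R,"cols":C,"values":[v0,v1,...]} in row-major order.
    static String toJson(double[][] m) {
        int rows = m.length;
        int cols = rows == 0 ? 0 : m[0].length;
        StringBuilder sb = new StringBuilder();
        sb.append("{\"rows\":").append(rows)
          .append(",\"cols\":").append(cols)
          .append(",\"values\":[");
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                if (r > 0 || c > 0) {
                    sb.append(',');
                }
                sb.append(m[r][c]);
            }
        }
        return sb.append("]}").toString();
    }

    // Parse exactly the layout produced by toJson; this is not a general JSON parser.
    static double[][] fromJson(String json) {
        int rows = intField(json, "\"rows\":");
        int cols = intField(json, "\"cols\":");
        int start = json.indexOf('[') + 1;
        int end = json.indexOf(']', start);
        String body = json.substring(start, end).trim();
        double[][] m = new double[rows][cols];
        if (!body.isEmpty()) {
            String[] parts = body.split(",");
            if (parts.length != rows * cols) {
                throw new IllegalArgumentException("value count does not match rows * cols");
            }
            for (int i = 0; i < parts.length; i++) {
                m[i / cols][i % cols] = Double.parseDouble(parts[i].trim());
            }
        }
        return m;
    }

    // Read the unsigned integer that follows key.
    private static int intField(String json, String key) {
        int at = json.indexOf(key);
        if (at < 0) {
            throw new IllegalArgumentException("missing " + key);
        }
        int from = at + key.length();
        int to = from;
        while (to < json.length() && Character.isDigit(json.charAt(to))) {
            to++;
        }
        return Integer.parseInt(json.substring(from, to));
    }

    public static void main(String[] args) {
        double[][] m = {{1.0, 2.0}, {3.0, -4.5}};
        String json = toJson(m);
        System.out.println(json);
        System.out.println(Arrays.deepToString(fromJson(json)));
    }
}
```

Storing the shape next to a flat value list keeps the reader simple: it can allocate the matrix before it reads any values.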
Constructor Summary

Constructor |
---|
LogisticModelParameters() |
Method Summary

Modifier and Type | Method | Description |
---|---|---|
OnlineLogisticRegression | createRegression() | Creates a logistic regression trainer using the parameters collected here. |
CsvRecordFactory | getCsvRecordFactory() | Returns a CsvRecordFactory compatible with this logistic model. |
double | getLambda() | |
double | getLearningRate() | |
int | getMaxTargetCategories() | |
int | getNumFeatures() | |
String | getTargetVariable() | |
Map<String,String> | getTypeMap() | |
static LogisticModelParameters | loadFrom(File in) | Reads a model in JSON format from a File. |
static LogisticModelParameters | loadFrom(Reader in) | Reads a model in JSON format. |
static void | saveModel(Writer out, OnlineLogisticRegression model, List<String> targetCategories) | |
void | saveTo(Writer out) | Saves a model in JSON format. |
void | setLambda(double lambda) | |
void | setLearningRate(double learningRate) | |
void | setMaxTargetCategories(int maxTargetCategories) | Sets the number of target categories to be considered. |
void | setNumFeatures(int numFeatures) | |
void | setTargetCategories(List<String> targetCategories) | |
void | setTargetVariable(String targetVariable) | Sets the target variable. |
void | setTypeMap(Iterable<String> predictorList, List<String> typeList) | Sets the types of the predictors. |
void | setUseBias(boolean useBias) | |
boolean | useBias() | |
Methods inherited from class java.lang.Object |
---|
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait |
Constructor Detail |
---|
public LogisticModelParameters()
Method Detail |
---|
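public CsvRecordFactory getCsvRecordFactory()
Returns a CsvRecordFactory compatible with this logistic model.

public OnlineLogisticRegression createRegression()
Creates a logistic regression trainer using the parameters collected here.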
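createRegression() turns the stored settings into a trainer. The sketch below illustrates that idea with hypothetical stand-ins (`SketchParams`, `SketchRegression`): a binary logistic regression trained by stochastic gradient descent, where `lambda` scales a simple L2 penalty and the learning rate scales each step. Mahout's OnlineLogisticRegression supports multiple categories, pluggable priors, and learning-rate annealing, so treat this only as an outline of how such parameters feed an update rule, not as its implementation.

```java
// Hypothetical stand-ins for illustration only; these are not Mahout classes.
class SketchParams {
    double lambda = 1.0e-4;     // strength of the L2 penalty in this sketch
    double learningRate = 0.5;  // fixed SGD step size (no annealing here)
    int numFeatures = 2;
    boolean useBias = true;

    // Analogue of createRegression(): build a trainer from the stored settings.
    SketchRegression createRegression() {
        return new SketchRegression(numFeatures, lambda, learningRate, useBias);
    }

    public static void main(String[] args) {
        SketchRegression lr = new SketchParams().createRegression();
        double[] a = {1.0, 0.0};  // labeled 1
        double[] b = {0.0, 1.0};  // labeled 0
        for (int pass = 0; pass < 500; pass++) {
            lr.trainStep(1, a);
            lr.trainStep(0, b);
        }
        System.out.printf("p(a)=%.3f p(b)=%.3f%n", lr.probabilityOfOne(a), lr.probabilityOfOne(b));
    }
}

// Binary logistic regression trained one example at a time.
class SketchRegression {
    private final double[] weights;
    private double bias;
    private final double lambda;
    private final double rate;
    private final boolean useBias;

    SketchRegression(int numFeatures, double lambda, double rate, boolean useBias) {
        this.weights = new double[numFeatures];
        this.lambda = lambda;
        this.rate = rate;
        this.useBias = useBias;
    }

    // Sigmoid of the linear score: the estimated probability that the label is 1.
    double probabilityOfOne(double[] x) {
        double z = useBias ? bias : 0.0;
        for (int j = 0; j < weights.length; j++) {
            z += weights[j] * x[j];
        }
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // One SGD step on the L2-penalized negative log-likelihood; label is 0 or 1.
    void trainStep(int label, double[] x) {
        double error = label - probabilityOfOne(x);
        for (int j = 0; j < weights.length; j++) {
            weights[j] += rate * (error * x[j] - lambda * weights[j]);
        }
        if (useBias) {
            bias += rate * error;  // the bias term is left unpenalized
        }
    }
}
```

Keeping the settings in one object and deriving the trainer from it means the same values can be saved alongside the model and reused when it is restored.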

public static void saveModel(Writer out, OnlineLogisticRegression model, List<String> targetCategories) throws IOException
Throws:
IOException

public void saveTo(Writer out) throws IOException
Saves a model in JSON format.
Parameters:
out - Where to write the model.
Throws:
IOException - If we can't write the model.

public static LogisticModelParameters loadFrom(Reader in)
Reads a model in JSON format.
Parameters:
in - Where to read the model from.
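saveTo and the two loadFrom overloads form a round trip through a Writer and a Reader, and the File overload presumably wraps the Reader version in code that opens and closes the file. The sketch below shows that shape with a hypothetical `SketchModelParameters` class that persists its fields with `java.util.Properties` instead of GSON-generated JSON. Unlike the documented loadFrom(Reader), it declares IOException because `Properties.load` does.

```java
import java.io.File;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.io.StringWriter;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.Properties;

// Hypothetical stand-in: persists settings with java.util.Properties, not GSON/JSON.
class SketchModelParameters {
    String targetVariable = "";
    int numFeatures;
    double lambda;
    double learningRate;
    boolean useBias = true;

    // Write every setting to out; the caller keeps ownership of the Writer.
    void saveTo(Writer out) throws IOException {
        Properties p = new Properties();
        p.setProperty("targetVariable", targetVariable);
        p.setProperty("numFeatures", Integer.toString(numFeatures));
        p.setProperty("lambda", Double.toString(lambda));
        p.setProperty("learningRate", Double.toString(learningRate));
        p.setProperty("useBias", Boolean.toString(useBias));
        p.store(out, null);  // flushes but does not close out
    }

    // Rebuild an instance from text previously produced by saveTo.
    static SketchModelParameters loadFrom(Reader in) throws IOException {
        Properties p = new Properties();
        p.load(in);
        SketchModelParameters m = new SketchModelParameters();
        m.targetVariable = p.getProperty("targetVariable", "");
        m.numFeatures = Integer.parseInt(p.getProperty("numFeatures", "0"));
        m.lambda = Double.parseDouble(p.getProperty("lambda", "0"));
        m.learningRate = Double.parseDouble(p.getProperty("learningRate", "0"));
        m.useBias = Boolean.parseBoolean(p.getProperty("useBias", "true"));
        return m;
    }

    // File overload: open the file, delegate to the Reader version, always close it.
    static SketchModelParameters loadFrom(File in) throws IOException {
        try (Reader r = Files.newBufferedReader(in.toPath(), StandardCharsets.UTF_8)) {
            return loadFrom(r);
        }
    }

    public static void main(String[] args) throws IOException {
        SketchModelParameters original = new SketchModelParameters();
        original.targetVariable = "color";
        original.numFeatures = 20;
        original.lambda = 1.0e-4;
        StringWriter buffer = new StringWriter();
        original.saveTo(buffer);
        SketchModelParameters copy = loadFrom(new StringReader(buffer.toString()));
        System.out.println(copy.targetVariable + " " + copy.numFeatures + " " + copy.lambda);
    }
}
```

Taking a Writer or Reader rather than a path lets the same code target files, in-memory buffers, or network streams; the File overload is then just a convenience that manages the stream's lifetime.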

public static LogisticModelParameters loadFrom(File in) throws IOException
Reads a model in JSON format from a File.
Parameters:
in - Where to read the model from.
Throws:
IOException - If there is an error opening or closing the file.

public void setTypeMap(Iterable<String> predictorList, List<String> typeList)
Sets the types of the predictors.
Parameters:
predictorList - The list of variable names.
typeList - The list of types in the format preferred by CsvRecordFactory.

public void setTargetVariable(String targetVariable)
Sets the target variable.
Parameters:
targetVariable - The name of the target variable.

public void setMaxTargetCategories(int maxTargetCategories)
Sets the number of target categories to be considered.
Parameters:
maxTargetCategories - The number of target categories.

public void setNumFeatures(int numFeatures)

public void setTargetCategories(List<String> targetCategories)

public void setUseBias(boolean useBias)

public boolean useBias()

public String getTargetVariable()

public Map<String,String> getTypeMap()

public int getNumFeatures()

public int getMaxTargetCategories()

public double getLambda()

public void setLambda(double lambda)

public double getLearningRate()

public void setLearningRate(double learningRate)
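setTypeMap pairs each predictor name with a type string, and getTypeMap returns the result as a `Map<String,String>`. This page does not say what happens when the two lists differ in length. The hypothetical helper below assumes that a type list shorter than the predictor list repeats its last entry for the remaining predictors, and that an empty type list is an error. The type strings `"numeric"` and `"word"` are placeholders, not a statement of what CsvRecordFactory accepts.

```java
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

class TypeMapSketch {
    // Hypothetical helper mirroring setTypeMap(predictorList, typeList).
    // Assumption: a short type list repeats its last entry for the remaining predictors.
    static Map<String, String> buildTypeMap(Iterable<String> predictorList, List<String> typeList) {
        if (typeList.isEmpty()) {
            throw new IllegalArgumentException("Must have at least one type specifier");
        }
        Map<String, String> typeMap = new LinkedHashMap<>();  // keeps predictor order
        Iterator<String> types = typeList.iterator();
        String lastType = null;
        for (String predictor : predictorList) {
            if (types.hasNext()) {
                lastType = types.next();
            }
            typeMap.put(predictor, lastType);
        }
        return typeMap;
    }

    public static void main(String[] args) {
        System.out.println(buildTypeMap(List.of("x", "y", "shape"), List.of("numeric", "numeric", "word")));
        System.out.println(buildTypeMap(List.of("x", "y", "a", "b"), List.of("numeric")));
    }
}
```

Taking an Iterable for the names but a List for the types matches the documented signature; only the types need an explicit, reusable order for the repeat-last rule assumed here.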