org.apache.mahout.classifier.sgd
Class LogisticModelParameters

java.lang.Object
  extended by org.apache.mahout.classifier.sgd.LogisticModelParameters

public class LogisticModelParameters
extends Object

Encapsulates everything needed to describe a model, including how it reads and vectorizes its input. Keeping this state in one place lets a model be saved to and restored from a file consistently, and keeps the command-line arguments that affect learning together with the model they configure.


Nested Class Summary
static class LogisticModelParameters.MatrixTypeAdapter
          Tells GSON how to (de)serialize a Mahout matrix.
 
Constructor Summary
LogisticModelParameters()
           
 
Method Summary
 OnlineLogisticRegression createRegression()
          Creates a logistic regression trainer using the parameters collected here.
 CsvRecordFactory getCsvRecordFactory()
          Returns a CsvRecordFactory compatible with this logistic model.
 double getLambda()
           
 double getLearningRate()
           
 int getMaxTargetCategories()
           
 int getNumFeatures()
           
 String getTargetVariable()
           
 Map<String,String> getTypeMap()
           
static LogisticModelParameters loadFrom(File in)
          Reads a model in JSON format from a File.
static LogisticModelParameters loadFrom(Reader in)
          Reads a model in JSON format.
static void saveModel(Writer out, OnlineLogisticRegression model, List<String> targetCategories)
           
 void saveTo(Writer out)
          Saves a model in JSON format.
 void setLambda(double lambda)
           
 void setLearningRate(double learningRate)
           
 void setMaxTargetCategories(int maxTargetCategories)
          Sets the number of target categories to be considered.
 void setNumFeatures(int numFeatures)
           
 void setTargetCategories(List<String> targetCategories)
           
 void setTargetVariable(String targetVariable)
          Sets the target variable.
 void setTypeMap(Iterable<String> predictorList, List<String> typeList)
          Sets the types of the predictors.
 void setUseBias(boolean useBias)
           
 boolean useBias()
           
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Constructor Detail

LogisticModelParameters

public LogisticModelParameters()

Method Detail

getCsvRecordFactory

public CsvRecordFactory getCsvRecordFactory()
Returns a CsvRecordFactory compatible with this logistic model. The factory is created here so that the list of target categories is available when the model is saved. If the input is not CSV, calling setTargetCategories before calling saveTo is sufficient.

Returns:
The CsvRecordFactory.

createRegression

public OnlineLogisticRegression createRegression()
Creates a logistic regression trainer using the parameters collected here.

Returns:
The newly allocated OnlineLogisticRegression object

saveModel

public static void saveModel(Writer out,
                             OnlineLogisticRegression model,
                             List<String> targetCategories)
                      throws IOException
Throws:
IOException

saveTo

public void saveTo(Writer out)
            throws IOException
Saves a model in JSON format. This includes the current state of the logistic regression trainer and the dictionary for the target categories.

Parameters:
out - Where to write the model.
Throws:
IOException - If we can't write the model.

loadFrom

public static LogisticModelParameters loadFrom(Reader in)
Reads a model in JSON format.

Parameters:
in - Where to read the model from.
Returns:
The LogisticModelParameters object that we read.

loadFrom

public static LogisticModelParameters loadFrom(File in)
                                        throws IOException
Reads a model in JSON format from a File.

Parameters:
in - Where to read the model from.
Returns:
The LogisticModelParameters object that we read.
Throws:
IOException - If there is an error opening or closing the file.

setTypeMap

public void setTypeMap(Iterable<String> predictorList,
                       List<String> typeList)
Sets the types of the predictors. These types are used later when reading CSV data. If you do not read CSV data and instead convert records to vectors yourself, you do not need to call this method.

Parameters:
predictorList - The list of variable names.
typeList - The list of types in the format preferred by CsvRecordFactory.
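
As an illustration of how the two arguments relate, the following is a minimal sketch that pairs a list of predictor names with a list of type names to build a name-to-type map like the one returned by getTypeMap. It is a stand-in written against the standard library only, not the Mahout implementation. The class TypeMapSketch, the method pairTypes, and the column names are hypothetical. The rule that a short type list is padded by repeating its last entry is an assumption that should be checked against the Mahout source, and the type names "numeric" and "word" are illustrative values for CsvRecordFactory.

```java
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in, not part of Mahout: pairs predictor names with type names.
final class TypeMapSketch {
  private TypeMapSketch() {
  }

  static Map<String, String> pairTypes(Iterable<String> predictorList, List<String> typeList) {
    if (typeList.isEmpty()) {
      throw new IllegalArgumentException("Must have at least one type specifier");
    }
    // LinkedHashMap keeps the predictors in the order they were supplied.
    Map<String, String> typeMap = new LinkedHashMap<String, String>();
    Iterator<String> types = typeList.iterator();
    String lastType = null;
    for (String predictor : predictorList) {
      // Assumption: once the type list is exhausted, the last type is reused.
      if (types.hasNext()) {
        lastType = types.next();
      }
      typeMap.put(predictor, lastType);
    }
    return typeMap;
  }

  public static void main(String[] args) {
    // Three hypothetical predictors but only two types: "comment" reuses "word".
    Map<String, String> typeMap = pairTypes(
        Arrays.asList("age", "city", "comment"),
        Arrays.asList("numeric", "word"));
    System.out.println(typeMap); // prints {age=numeric, city=word, comment=word}
  }
}
```

Under these assumptions, a caller who wants every predictor treated the same way could pass a single-element type list.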

setTargetVariable

public void setTargetVariable(String targetVariable)
Sets the target variable. If you don't use the CSV record factory, then this is irrelevant.

Parameters:
targetVariable - The name of the target variable.

setMaxTargetCategories

public void setMaxTargetCategories(int maxTargetCategories)
Sets the number of target categories to be considered.

Parameters:
maxTargetCategories - The maximum number of target categories.

setNumFeatures

public void setNumFeatures(int numFeatures)

setTargetCategories

public void setTargetCategories(List<String> targetCategories)

setUseBias

public void setUseBias(boolean useBias)

useBias

public boolean useBias()

getTargetVariable

public String getTargetVariable()

getTypeMap

public Map<String,String> getTypeMap()

getNumFeatures

public int getNumFeatures()

getMaxTargetCategories

public int getMaxTargetCategories()

getLambda

public double getLambda()

setLambda

public void setLambda(double lambda)

getLearningRate

public double getLearningRate()

setLearningRate

public void setLearningRate(double learningRate)


Copyright © 2008-2010 The Apache Software Foundation. All Rights Reserved.