public abstract class ClassificationModel<FeaturesType,M extends ClassificationModel<FeaturesType,M>> extends PredictionModel<FeaturesType,M> implements ClassifierParams

Model produced by a Classifier. Classes are indexed {0, 1, ..., numClasses - 1}.

Type Parameters:
- FeaturesType - Type of input features. E.g., Vector
- M - Concrete Model type
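For orientation, here is a minimal sketch of how a concrete subclass is obtained and handled through this abstract type. It assumes a Spark 3.x setup, uses LogisticRegressionModel (whose FeaturesType is Vector) as the concrete model, and loads the sample libsvm file shipped with the Spark distribution; the data path is illustrative.

```java
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class ClassificationModelSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .appName("ClassificationModelSketch")
        .getOrCreate();

    // Illustrative input: a DataFrame with "label" and "features" columns.
    Dataset<Row> training = spark.read().format("libsvm")
        .load("data/mllib/sample_libsvm_data.txt");

    // LogisticRegressionModel is one concrete ClassificationModel:
    // FeaturesType = Vector, M = LogisticRegressionModel.
    LogisticRegressionModel lrModel = new LogisticRegression().fit(training);

    // The fitted model can be handled through the generic base type.
    ClassificationModel<Vector, LogisticRegressionModel> model = lrModel;
    System.out.println("numClasses = " + model.numClasses());

    spark.stop();
  }
}
```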
| Constructor and Description |
| --- |
| ClassificationModel() |
| Modifier and Type | Method and Description |
| --- | --- |
| abstract int | numClasses() Number of classes (values which the label can take). |
| double | predict(FeaturesType features) Predict label for the given features. |
| abstract Vector | predictRaw(FeaturesType features) Raw prediction for each possible label. |
| Param<String> | rawPredictionCol() Param for raw prediction (a.k.a. confidence) column name. |
| M | setRawPredictionCol(String value) |
| Dataset<Row> | transform(Dataset<?> dataset) Transforms dataset by reading from featuresCol, and appending new columns as specified by parameters: predicted labels as predictionCol of type Double; raw predictions (confidences) as rawPredictionCol of type Vector. |
| Dataset<Row> | transformImpl(Dataset<?> dataset) |
| StructType | transformSchema(StructType schema) Check transform validity and derive the output schema from the input schema. |
Methods inherited from class org.apache.spark.ml.PredictionModel:
featuresCol, labelCol, numFeatures, predictionCol, setFeaturesCol, setPredictionCol

Methods inherited from class org.apache.spark.ml.Transformer:
transform, transform, transform

Methods inherited from class org.apache.spark.ml.PipelineStage:
params

Methods inherited from class java.lang.Object:
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait

Methods inherited from interface org.apache.spark.ml.classification.ClassifierParams:
extractInstances, validateAndTransformSchema

Methods inherited from interface org.apache.spark.ml.PredictorParams:
extractInstances, extractInstances

Methods inherited from interface org.apache.spark.ml.param.shared.HasLabelCol:
getLabelCol, labelCol

Methods inherited from interface org.apache.spark.ml.param.shared.HasFeaturesCol:
featuresCol, getFeaturesCol

Methods inherited from interface org.apache.spark.ml.param.shared.HasPredictionCol:
getPredictionCol, predictionCol

Methods inherited from interface org.apache.spark.ml.param.Params:
clear, copy, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn

Methods inherited from interface org.apache.spark.ml.util.Identifiable:
toString, uid

Methods inherited from interface org.apache.spark.ml.param.shared.HasRawPredictionCol:
getRawPredictionCol

Methods inherited from interface org.apache.spark.internal.Logging:
$init$, initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, initLock, isTraceEnabled, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning, org$apache$spark$internal$Logging$$log__$eq, org$apache$spark$internal$Logging$$log_, uninitialize
public abstract int numClasses()

Number of classes (values which the label can take).
public double predict(FeaturesType features)

Predict label for the given features. This method is used to implement transform() and output predictionCol.

This default implementation for classification predicts the index of the maximum value from predictRaw().

Overrides: predict in class PredictionModel<FeaturesType,M extends ClassificationModel<FeaturesType,M>>
Parameters: features - (undocumented)
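As a rough illustration of that default behaviour (not the actual Spark source), the sketch below picks the index of the largest raw score; the class and helper names are made up for the example.

```java
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;

public class ArgmaxSketch {
  // Mirrors the documented default: the predicted label is the index of the
  // largest value in the raw prediction vector (labels are 0 .. numClasses - 1).
  static double argmaxLabel(Vector rawPrediction) {
    double[] scores = rawPrediction.toArray();
    int best = 0;
    for (int i = 1; i < scores.length; i++) {
      if (scores[i] > scores[best]) {
        best = i;
      }
    }
    return best;
  }

  public static void main(String[] args) {
    // Purely illustrative raw scores for a 3-class problem.
    Vector raw = Vectors.dense(-1.2, 3.4, 0.7);
    System.out.println(argmaxLabel(raw)); // 1.0
  }
}
```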
public abstract Vector predictRaw(FeaturesType features)

Raw prediction for each possible label. This internal method is used to implement transform() and output rawPredictionCol.

Parameters: features - (undocumented)
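To make the "one raw score per label" contract concrete, here is a hypothetical two-class scorer with hard-coded weights (not a real Spark model): element i of the returned vector is the raw confidence for label i, and a larger value indicates more confidence.

```java
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;

public class ToyRawScores {
  // Hypothetical binary scorer: a single margin is turned into two raw scores,
  // one per label, so that predict() can take the argmax over them.
  static Vector rawScores(Vector features) {
    double margin = 2.0 * features.apply(0) - 1.0 * features.apply(1);
    return Vectors.dense(-margin, margin); // index 0 = label 0, index 1 = label 1
  }

  public static void main(String[] args) {
    System.out.println(rawScores(Vectors.dense(1.0, 0.5))); // [-1.5,1.5]
  }
}
```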
public final Param<String> rawPredictionCol()

Description copied from interface: HasRawPredictionCol
Param for raw prediction (a.k.a. confidence) column name.

Specified by: rawPredictionCol in interface HasRawPredictionCol
public M setRawPredictionCol(String value)
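The column name is an ordinary Param (default "rawPrediction"). A small sketch of setting and reading it, shown on the LogisticRegression estimator, which shares the same param; "myRawPrediction" is just an example value.

```java
import org.apache.spark.ml.classification.LogisticRegression;

public class RawPredictionColSketch {
  public static void main(String[] args) {
    // The same param exists on the estimator and on the fitted model;
    // no training data is needed to demonstrate setting it.
    LogisticRegression lr = new LogisticRegression()
        .setRawPredictionCol("myRawPrediction");
    System.out.println(lr.getRawPredictionCol()); // myRawPrediction
  }
}
```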
public Dataset<Row> transform(Dataset<?> dataset)

Transforms dataset by reading from featuresCol, and appending new columns as specified by parameters:
- predicted labels as predictionCol of type Double
- raw predictions (confidences) as rawPredictionCol of type Vector

Overrides: transform in class PredictionModel<FeaturesType,M extends ClassificationModel<FeaturesType,M>>
Parameters: dataset - input dataset
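Continuing the class-level sketch near the top of this page (the fitted model and training DataFrame are assumed from there, with the default column names), the appended columns can be inspected like this:

```java
// "model" and "training" come from the earlier sketch; "prediction" and
// "rawPrediction" are the default output column names.
Dataset<Row> predictions = model.transform(training);
predictions.select("features", "rawPrediction", "prediction").show(5);
```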
public StructType transformSchema(StructType schema)

Check transform validity and derive the output schema from the input schema.

Description copied from class: PipelineStage
We check validity for interactions between parameters during transformSchema and raise an exception if any parameter value is invalid. Parameter value checks which do not depend on other parameters are handled by Param.validate().

A typical implementation should first verify schema changes and parameter validity, including complex parameter interaction checks.

Overrides: transformSchema in class PredictionModel<FeaturesType,M extends ClassificationModel<FeaturesType,M>>
Parameters: schema - (undocumented)
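A quick way to see the derived output schema without running a job is to hand transformSchema a matching input schema. The sketch below assumes the fitted model from the earlier example and builds the expected "label"/"features" input schema by hand; the field names are the defaults.

```java
import org.apache.spark.ml.linalg.SQLDataTypes;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

// Build an input schema matching the expected labelCol/featuresCol layout,
// then let the model derive (and validate) the output schema.
StructType inputSchema = new StructType()
    .add("label", DataTypes.DoubleType)
    .add("features", SQLDataTypes.VectorType());
StructType outputSchema = model.transformSchema(inputSchema);
outputSchema.printTreeString(); // includes the appended prediction columns
```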