public class LogisticRegressionModel extends ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
Model produced by LogisticRegression.

| Modifier and Type | Method and Description |
|---|---|
| void | checkThresholdConsistency(): If threshold and thresholds are both set, ensures they are consistent. |
| LogisticRegressionModel | copy(ParamMap extra): Creates a copy of this instance with the same UID and some extra params. |
| double | getThreshold(): Get threshold for binary classification. |
| double[] | getThresholds(): Get thresholds for binary or multiclass classification. |
| boolean | hasSummary(): Indicates whether a training summary exists for this model instance. |
| double | intercept() |
| int | numClasses(): Number of classes (values which the label can take). |
| protected double | predict(Vector features): Predict label for the given feature vector. |
| protected Vector | predictRaw(Vector features): Raw prediction for each possible label. |
| protected double | probability2prediction(Vector probability): Given a vector of class conditional probabilities, select the predicted label. |
| protected double | raw2prediction(Vector rawPrediction): Given a vector of raw predictions, select the predicted label. |
| protected Vector | raw2probabilityInPlace(Vector rawPrediction): Estimate the probability of each class given the raw prediction, doing the computation in-place. |
| LogisticRegressionModel | setThreshold(double value): Set threshold in binary classification, in range [0, 1]. |
| LogisticRegressionModel | setThresholds(double[] value): Set thresholds in multiclass (or binary) classification to adjust the probability of predicting each class. |
| LogisticRegressionTrainingSummary | summary(): Gets summary of model on training set. |
| java.lang.String | uid(): An immutable unique ID for the object and its derivatives. |
| StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType): Validates and transforms the input schema with the provided param map. |
| void | validateParams() |
| Vector | weights() |
Inherited methods:
normalizeToProbabilitiesInPlace, predictProbability, raw2probability, setProbabilityCol, transform
setRawPredictionCol
featuresDataType, setFeaturesCol, setPredictionCol, transformImpl, transformSchema
transform, transform, transform
transformSchema
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn, validateParams
toString
initializeIfNecessary, initializeLogging, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning

public java.lang.String uid()
Description copied from interface: Identifiable
An immutable unique ID for the object and its derivatives.
uid in interface Identifiable

public Vector weights()

public double intercept()

public LogisticRegressionModel setThreshold(double value)
Set threshold in binary classification, in range [0, 1].
If the estimated probability of class label 1 is greater than threshold, the model predicts 1; otherwise it predicts 0. A high threshold encourages the model to predict 0 more often; a low threshold encourages the model to predict 1 more often.
Note: Calling this with threshold p is equivalent to calling setThresholds(Array(1-p, p)).
When setThreshold() is called, any user-set value for thresholds is cleared.
If both threshold and thresholds are set in a ParamMap, they must be equivalent.
Default is 0.5.
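
The decision rule described here can be sketched in plain Java. This is a minimal illustration, not Spark's implementation: the class ThresholdRuleSketch and both method names are hypothetical, and the rule used for per-class thresholds (pick the class with the largest probability-to-threshold ratio) is an assumption that this page does not spell out.

```java
/** Hypothetical helper illustrating the binary threshold rule; not part of Spark. */
final class ThresholdRuleSketch {

    /** Predicts 1.0 if the estimated probability of label 1 is > threshold, else 0.0. */
    static double predict(double probabilityOfOne, double threshold) {
        if (threshold < 0.0 || threshold > 1.0) {
            throw new IllegalArgumentException("threshold must be in [0, 1]: " + threshold);
        }
        return probabilityOfOne > threshold ? 1.0 : 0.0;
    }

    /**
     * ASSUMPTION: per-class thresholds are applied by choosing the class with the
     * largest ratio probability[i] / thresholds[i]; ties go to the lowest index.
     * Zero thresholds are not handled in this sketch.
     */
    static double predictWithThresholds(double[] probability, double[] thresholds) {
        int best = 0;
        for (int i = 1; i < probability.length; i++) {
            if (probability[i] / thresholds[i] > probability[best] / thresholds[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double p1 = 0.6; // estimated probability of label 1
        System.out.println(predict(p1, 0.5));
        System.out.println(predict(p1, 0.7));
        // Same decision as threshold 0.7, expressed as thresholds (1 - 0.7, 0.7).
        System.out.println(predictWithThresholds(new double[] {1 - p1, p1}, new double[] {0.3, 0.7}));
    }
}
```

With an estimated probability of 0.6 for label 1, a threshold of 0.5 yields label 1, while a threshold of 0.7 yields label 0. Under the assumed ratio rule, thresholds (0.3, 0.7) give the same answer as threshold 0.7, which is the equivalence stated in the note.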
Parameters:
value - (undocumented)

public double getThreshold()
Get threshold for binary classification.
If threshold is set, returns that value.
Otherwise, if thresholds is set with length 2 (i.e., binary classification), returns the equivalent threshold: 1 / (1 + thresholds(0) / thresholds(1)).
Otherwise, returns the default value of threshold.
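
The resolution order for getThreshold() can be sketched as follows. This is a hedged illustration in plain Java rather than the library's code: the class ThresholdResolutionSketch, the method resolveThreshold, and the use of null to mean "not set" are all assumptions made for the example. Note that 1 / (1 + t0 / t1) is algebraically the same as t1 / (t0 + t1).

```java
/** Hypothetical sketch of the getThreshold() resolution order; not part of Spark. */
final class ThresholdResolutionSketch {

    /** Default documented for the threshold param. */
    static final double DEFAULT_THRESHOLD = 0.5;

    /**
     * Resolves the effective binary threshold.
     * ASSUMPTION: a null argument stands for "param not set".
     */
    static double resolveThreshold(Double threshold, double[] thresholds) {
        if (threshold != null) {
            return threshold; // an explicitly set threshold wins
        }
        if (thresholds != null) {
            if (thresholds.length != 2) {
                throw new IllegalArgumentException(
                    "thresholds must have length 2 for binary classification, got " + thresholds.length);
            }
            // Equivalent threshold: 1 / (1 + thresholds(0) / thresholds(1))
            return 1.0 / (1.0 + thresholds[0] / thresholds[1]);
        }
        return DEFAULT_THRESHOLD; // neither param is set
    }

    public static void main(String[] args) {
        System.out.println(resolveThreshold(0.3, null));
        System.out.println(resolveThreshold(null, new double[] {0.3, 0.7}));
        System.out.println(resolveThreshold(null, null));
    }
}
```

For thresholds (0.3, 0.7) the equivalent threshold is approximately 0.7, up to floating-point rounding.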
Throws:
IllegalArgumentException - if thresholds is set to an array of length other than 2.

public LogisticRegressionModel setThresholds(double[] value)
Set thresholds in multiclass (or binary) classification to adjust the probability of predicting each class.
Note: When setThresholds() is called, any user-set value for threshold is cleared.
If both threshold and thresholds are set in a ParamMap, they must be equivalent.
setThresholds in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
Parameters:
value - (undocumented)

public double[] getThresholds()
Get thresholds for binary or multiclass classification.
If thresholds is set, returns its value.
Otherwise, if threshold is set, returns the equivalent thresholds for binary classification: (1-threshold, threshold).
If neither is set, throws an exception.
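
The fallback order for getThresholds() can be sketched in the same spirit. This is an illustration under stated assumptions, not the library's code: the class ThresholdsResolutionSketch and the method resolveThresholds are hypothetical, null stands for "not set", and IllegalStateException is an arbitrary choice because this page does not name the exception type.

```java
/** Hypothetical sketch of the getThresholds() fallback order; not part of Spark. */
final class ThresholdsResolutionSketch {

    /**
     * Resolves the effective per-class thresholds.
     * ASSUMPTION: a null argument stands for "param not set".
     */
    static double[] resolveThresholds(double[] thresholds, Double threshold) {
        if (thresholds != null) {
            return thresholds.clone(); // explicitly set thresholds win
        }
        if (threshold != null) {
            // Binary equivalent: (1 - threshold, threshold)
            return new double[] {1.0 - threshold, threshold};
        }
        // ASSUMPTION: the exception type is not specified in the documentation.
        throw new IllegalStateException("neither thresholds nor threshold is set");
    }

    public static void main(String[] args) {
        double[] pair = resolveThresholds(null, 0.7);
        System.out.println(pair[0] + ", " + pair[1]);
        // Converting back with 1 / (1 + t0 / t1) recovers the threshold up to rounding.
        System.out.println(1.0 / (1.0 + pair[0] / pair[1]));
    }
}
```

The round trip in main shows that the pair (1-threshold, threshold) maps back to the original threshold through the formula documented for getThreshold().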
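
public int numClasses()
Description copied from class: ClassificationModel
Number of classes (values which the label can take).
numClasses in class ClassificationModel<Vector,LogisticRegressionModel>

public LogisticRegressionTrainingSummary summary()
Gets summary of model on training set. An exception is thrown if trainingSummary == None.

public boolean hasSummary()
Indicates whether a training summary exists for this model instance.
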
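protected double predict(Vector features)
Predict label for the given feature vector. The behavior of this can be adjusted using thresholds.
predict in class ClassificationModel<Vector,LogisticRegressionModel>
Parameters:
features - (undocumented)

protected Vector raw2probabilityInPlace(Vector rawPrediction)
Description copied from class: ProbabilisticClassificationModel
Estimate the probability of each class given the raw prediction, doing the computation in-place.
This internal method is used to implement transform() and output probabilityCol.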
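
As a rough illustration of an in-place conversion from raw predictions to probabilities, the sketch below applies the logistic function to a two-element array. Everything in it is an assumption made for the example: the class BinaryLogisticProbabilitySketch is hypothetical, it works on a plain double[] rather than a Spark Vector, and it assumes the binary raw prediction is laid out as (-margin, margin). This page does not document the actual layout or implementation.

```java
/** Hypothetical sketch of an in-place raw-to-probability step for the binary case; not part of Spark. */
final class BinaryLogisticProbabilitySketch {

    /**
     * Overwrites a binary raw prediction with class probabilities and returns the same array.
     * ASSUMPTION: rawPrediction is (-margin, margin), so element 1 holds the margin for label 1.
     */
    static double[] raw2probabilityInPlace(double[] rawPrediction) {
        if (rawPrediction.length != 2) {
            throw new IllegalArgumentException("this sketch handles binary classification only");
        }
        double margin = rawPrediction[1];
        double probabilityOfOne = 1.0 / (1.0 + Math.exp(-margin)); // logistic (sigmoid) function
        rawPrediction[0] = 1.0 - probabilityOfOne; // probability of label 0
        rawPrediction[1] = probabilityOfOne;       // probability of label 1
        return rawPrediction;                      // same instance, modified in place
    }

    public static void main(String[] args) {
        double[] raw = {-2.0, 2.0};
        double[] probability = raw2probabilityInPlace(raw);
        System.out.println(probability[0] + ", " + probability[1]);
        System.out.println(probability == raw); // no new array was allocated
    }
}
```

Working in place avoids allocating a second array per row, at the cost of destroying the raw values, so a caller that still needs them has to copy first.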
raw2probabilityInPlace in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
Parameters:
rawPrediction - (undocumented)

protected Vector predictRaw(Vector features)
Description copied from class: ClassificationModel
Raw prediction for each possible label.
This internal method is used to implement transform() and output rawPredictionCol.
predictRaw in class ClassificationModel<Vector,LogisticRegressionModel>
Parameters:
features - (undocumented)

public LogisticRegressionModel copy(ParamMap extra)
Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params.
copy in interface Params
copy in class Model<LogisticRegressionModel>
Parameters:
extra - (undocumented)
See also:
defaultCopy()

protected double raw2prediction(Vector rawPrediction)
Description copied from class: ClassificationModel
Given a vector of raw predictions, select the predicted label.
raw2prediction in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
Parameters:
rawPrediction - (undocumented)

protected double probability2prediction(Vector probability)
Description copied from class: ProbabilisticClassificationModel
Given a vector of class conditional probabilities, select the predicted label.
probability2prediction in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
Parameters:
probability - (undocumented)

public void checkThresholdConsistency()
If threshold and thresholds are both set, ensures they are consistent.
Throws:
java.lang.IllegalArgumentException - if threshold and thresholds are not equivalent

public void validateParams()

public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
Validates and transforms the input schema with the provided param map.
Parameters:
schema - input schema
fitting - whether this is called during fitting
featuresDataType - SQL DataType for FeaturesType. E.g., VectorUDT for vector features.