Developing estimators
Based on: Developing scikit-learn estimators
Or “How to safely interact with Pipelines and model selection”
Estimator API
Instantiation
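- `__init__` accepts model hyperparameters (constants that determine the estimator's behavior), but must not accept training data, which is reserved for fitting.

✅ Do:

```python
from sklearn.svm import SVC

estimator = SVC(C=1)
```

❌ Don't:

```python
estimator = SVC([[0, 1], [1, 2]])  # training data passed to the constructor
```

- Model hyperparameters should have default values, so that a user can instantiate the model without passing any arguments.
- Every `__init__` parameter should be stored unchanged in an attribute of the same name, with no additional logic. Model selection tools and `clone` rely on this to get and set parameters.

✅ Do:

```python
def __init__(self, p_1=1, p_2=2):
    # Store each parameter unchanged, under its own name
    self.p_1 = p_1
    self.p_2 = p_2
```

❌ Don't:

```python
def __init__(self, p_1=1, p_2=2):
    if p_1 > 1:  # logic that modifies a parameter
        p_1 += 1
    self.p_1 = p_1
    self.p_3 = p_2  # attribute name does not match the parameter name
```

- No parameter validation in `__init__`; validate parameters in `fit` instead.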
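A minimal sketch of why these rules matter, assuming a recent scikit-learn version: `sklearn.base.clone` rebuilds an estimator from `get_params()`, so it succeeds when `__init__` stores parameters unchanged and raises an error when the constructor modifies them. `GoodEstimator` and `BadEstimator` are hypothetical classes written only for this illustration.

```python
from sklearn.base import BaseEstimator, clone


class GoodEstimator(BaseEstimator):
    """Hypothetical estimator whose __init__ follows the rules above."""

    def __init__(self, p_1=1, p_2=2):
        # Defaults allow GoodEstimator() with no arguments
        self.p_1 = p_1
        self.p_2 = p_2


class BadEstimator(BaseEstimator):
    """Hypothetical estimator whose __init__ modifies a parameter."""

    def __init__(self, p_1=1, p_2=2):
        if p_1 > 1:
            p_1 += 1  # the stored value no longer matches the argument
        self.p_1 = p_1
        self.p_2 = p_2


good = GoodEstimator(p_1=5)
good_copy = clone(good)  # fresh, unfitted estimator with the same parameters
print(good_copy.get_params())  # {'p_1': 5, 'p_2': 2}

try:
    clone(BadEstimator(p_1=5))
    bad_clone_failed = False
except RuntimeError as err:
    # clone() checks that the constructor kept every parameter as passed
    bad_clone_failed = True
    print("clone failed:", err)
```

The same check is what breaks hyperparameter searches on estimators with logic in `__init__`, since they clone the estimator for every candidate.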
Fitting
- For supervised learning, `estimator = estimator.fit(X, y)`; for unsupervised learning, `estimator = estimator.fit(X)`.
- Additional keyword arguments (`kwargs`) can be added to `fit`, but only for data-dependent variables. For example, a precomputed matrix is data-dependent, while a tolerance criterion is not (it belongs in `__init__`).
- The estimator holds no reference to `X` or `y` (exception: precomputed kernels, where this data must be stored for use by the `predict` method).
- Use `sklearn.utils.check_X_y` to ensure that `X` and `y` have consistent lengths.
- Even if `fit` doesn't use a target `y`, it must accept `y=None` so the estimator can be used in pipelines.
- The method must return `self`, for better usability with chained operations: `y_pred = SVC(C=1).fit(X_train, y_train).predict(X_test)`
- `fit` must be idempotent: any new call to `fit` overwrites the result of the previous call (except when a `warm_start=True` strategy is used to speed up subsequent fits).
- Names of attributes created during `fit` must end with a trailing underscore, e.g. `param_`.
- An `n_features_in_` attribute can be set during `fit` to make input expectations explicit.
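The fitting rules can be combined into a minimal sketch. `MeanRegressor` and its `shift` parameter are hypothetical, chosen only to show validation inside `fit`, `check_X_y`, trailing-underscore attributes, `n_features_in_`, and returning `self`.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils import check_X_y
from sklearn.utils.validation import check_array, check_is_fitted


class MeanRegressor(RegressorMixin, BaseEstimator):
    """Hypothetical regressor that predicts the training mean of y plus `shift`."""

    def __init__(self, shift=0.0):
        self.shift = shift  # stored as-is; validated in fit, not here

    def fit(self, X, y):
        # Parameter validation belongs in fit
        if not np.isscalar(self.shift):
            raise ValueError("shift must be a scalar")
        # Raises ValueError if X and y have inconsistent lengths
        X, y = check_X_y(X, y)
        # Learned attributes end with "_"; no reference to X or y is kept
        self.mean_ = float(np.mean(y)) + self.shift
        self.n_features_in_ = X.shape[1]
        return self  # allows chaining, e.g. MeanRegressor().fit(X, y).predict(X)

    def predict(self, X):
        check_is_fitted(self)
        X = check_array(X)
        return np.full(X.shape[0], self.mean_)


X = [[0.0], [1.0], [2.0]]
y = [1.0, 2.0, 3.0]
y_pred = MeanRegressor(shift=1.0).fit(X, y).predict([[5.0], [6.0]])
print(y_pred)  # [3. 3.]
```

Because every fitted value is recomputed from scratch, calling `fit` again on new data simply overwrites `mean_`.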
Predictor
- For supervised or unsupervised estimators: `prediction = predictor.predict(X)`
- Classifiers can also quantify the certainty of a prediction, without applying a threshold (e.g. with `predict_proba` or `decision_function`): `probabilities = predictor.predict_proba(X)`
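For example, with scikit-learn's `LogisticRegression` on the bundled iris dataset (chosen here only for illustration), `predict` returns hard labels while `predict_proba` returns one probability column per class, ordered as in `classes_`.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)
clf = LogisticRegression(max_iter=1000).fit(X, y)

labels = clf.predict(X[:5])  # hard class labels, taken from clf.classes_
proba = clf.predict_proba(X[:5])  # shape (5, n_classes); each row sums to 1
scores = clf.decision_function(X[:5])  # unthresholded scores, one column per class

# For this model, the predicted label is the highest-probability class
same = np.array_equal(labels, clf.classes_[np.argmax(proba, axis=1)])
print(proba.shape, same)  # (5, 3) True
```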
Transformer
- For transforming or filtering data, in a supervised or unsupervised way: `X_transformed = transformer.transform(X)`
- When fitting and transforming can be implemented together more efficiently: `X_transformed = transformer.fit_transform(X)`
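A minimal sketch of a custom transformer, assuming `TransformerMixin`, which supplies a default `fit_transform` that chains `fit` and `transform`. `Centerer` is hypothetical; its `fit` accepts `y=None` so it can sit in a `Pipeline` in front of a supervised model.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.utils.validation import check_array, check_is_fitted


class Centerer(TransformerMixin, BaseEstimator):
    """Hypothetical transformer that subtracts the per-feature training mean."""

    def fit(self, X, y=None):  # y is ignored but accepted, for Pipeline compatibility
        X = check_array(X)
        self.mean_ = X.mean(axis=0)
        self.n_features_in_ = X.shape[1]
        return self

    def transform(self, X):
        check_is_fitted(self)
        X = check_array(X)
        return X - self.mean_


X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = np.array([1.0, 2.0, 3.0])

X_transformed = Centerer().fit_transform(X)  # fit_transform comes from TransformerMixin
print(X_transformed)

# In a pipeline, Centerer receives y during fitting and ignores it
pipe = make_pipeline(Centerer(), LinearRegression()).fit(X, y)
print(pipe.predict(X).shape)  # (3,)
```

A dedicated `fit_transform` is only worth writing when it can reuse intermediate results from fitting; otherwise the mixin's default is enough.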
Rolling your own estimator
Backbone
- Test your estimator using `check_estimator`:

```python
from sklearn.svm import LinearSVC
from sklearn.utils.estimator_checks import check_estimator

check_estimator(LinearSVC())  # passes
```

- You can use the project template to get started with all the methods an estimator requires.
- You can also inherit from `BaseEstimator`, `ClassifierMixin` or `RegressorMixin` to significantly reduce the amount of boilerplate code, including:
  - `get_params`: takes no required argument and returns the estimator's parameters; `deep=True` also returns the parameters of sub-estimators
  - `set_params`: overwrites parameters
  - `sklearn.base.clone` compatibility, enabled by `get_params`
  - `_estimator_type`, which must be `"classifier"`, `"regressor"` or `"clusterer"`; this is set automatically when inheriting from `ClassifierMixin`, `RegressorMixin` or `ClusterMixin`
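A minimal sketch of what inheritance provides, assuming only `__init__` is written by hand. `TemplateClassifier` is a hypothetical skeleton (it has no `fit`), yet `BaseEstimator` and `ClassifierMixin` already supply `get_params`, `set_params` (including the `<sub-estimator>__<param>` syntax for nested parameters), `clone` support, and classifier detection through `sklearn.base.is_classifier`.

```python
from sklearn.base import BaseEstimator, ClassifierMixin, clone, is_classifier
from sklearn.svm import LinearSVC


class TemplateClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical skeleton: only __init__ is written by hand."""

    def __init__(self, alpha=1.0, estimator=None):
        self.alpha = alpha
        self.estimator = estimator  # an optional nested sub-estimator


est = TemplateClassifier(alpha=0.5, estimator=LinearSVC(C=2.0))

# deep=True (the default) also lists the sub-estimator's parameters
print(est.get_params()["estimator__C"])  # 2.0

# set_params uses "__" to reach into the sub-estimator
est.set_params(alpha=0.1, estimator__C=10.0)

est_copy = clone(est)  # new object with the same parameters; the nested estimator is cloned too
print(is_classifier(est))  # True, thanks to ClassifierMixin
```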
Specific estimators
- A classifier's `fit` can accept `y` with integer or string values, using the following conversion:

```python
self.classes_, y = np.unique(y, return_inverse=True)
```

- A classifier's `predict` method must return arrays containing class labels from `classes_`:

```python
def predict(self, X):
    # Assumes decision_function returns one column per class
    D = self.decision_function(X)
    return self.classes_[np.argmax(D, axis=1)]
```

- In linear models, coefficients are stored in `coef_` and the intercept in `intercept_`.
- The `sklearn.utils.multiclass` module contains useful functions for working with multiclass and multilabel problems.
- Also, check estimator tags to declare your estimator's capabilities.
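The classifier conventions can be sketched with a hypothetical nearest-centroid classifier (`CentroidClassifier` is not a scikit-learn class). It encodes labels with `np.unique(..., return_inverse=True)`, checks targets with `sklearn.utils.multiclass.check_classification_targets`, and maps `argmax` indices back through `classes_`, so string labels come back unchanged.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils import check_X_y
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import check_array, check_is_fitted


class CentroidClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical nearest-centroid classifier, for illustration only."""

    def fit(self, X, y):
        X, y = check_X_y(X, y)
        check_classification_targets(y)  # rejects continuous targets
        # Encode labels (ints or strings) as 0..n_classes-1
        self.classes_, y_encoded = np.unique(y, return_inverse=True)
        self.centroids_ = np.array(
            [X[y_encoded == k].mean(axis=0) for k in range(len(self.classes_))]
        )
        self.n_features_in_ = X.shape[1]
        return self

    def decision_function(self, X):
        check_is_fitted(self)
        X = check_array(X)
        # One column per class: negative squared distance to each centroid
        diff = X[:, np.newaxis, :] - self.centroids_[np.newaxis, :, :]
        return -np.sum(diff**2, axis=2)

    def predict(self, X):
        D = self.decision_function(X)
        # Map the winning column index back to the original labels
        return self.classes_[np.argmax(D, axis=1)]


X = [[0.0, 0.0], [0.2, 0.1], [5.0, 5.0], [5.1, 4.9]]
y = ["cat", "cat", "dog", "dog"]
clf = CentroidClassifier().fit(X, y)
print(clf.classes_)  # ['cat' 'dog']
print(clf.predict([[0.1, 0.0], [4.8, 5.2]]))  # ['cat' 'dog']
```

Because this `decision_function` always returns one column per class, the `argmax(D, axis=1)` pattern also works in the binary case.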
Coding guidelines
How new code should be written for inclusion in scikit-learn, in a way that makes review easier.
Style
- Format and indentation follow PEP8.
- Use underscores to separate words in non-class names: `n_samples` rather than `nsamples`.
- Avoid multiple statements on one line. Prefer a line return after `if`/`for`.
- Use relative imports for references inside scikit-learn.
- Please don't use `import *` in any case:
  - code becomes hard to read, since the origin of each symbol is no longer explicit
  - static analysis tools can no longer resolve the referenced names
- Use the numpy docstring standard.

Check the `utils` module for better integration and reusability.
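As an illustration of the style rules, here is a hypothetical helper (`center_columns` is not part of scikit-learn) with underscore-separated names, one statement per line, and a docstring following the numpydoc layout (`Parameters`, `Returns`, `Examples` sections).

```python
import numpy as np


def center_columns(X):
    """Subtract the mean of each column from a 2D array.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input data.

    Returns
    -------
    X_centered : ndarray of shape (n_samples, n_features)
        Copy of ``X`` whose columns have zero mean.

    Examples
    --------
    >>> center_columns([[1.0, 2.0], [3.0, 4.0]])
    array([[-1., -1.],
           [ 1.,  1.]])
    """
    X = np.asarray(X, dtype=float)
    column_means = X.mean(axis=0)
    X_centered = X - column_means
    return X_centered
```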