Developing estimators
Based on: Developing scikit-learn estimators
Or “How to safely interact with Pipelines and model selection”
Estimator API
Instantiation
- `__init__` accepts model constraints, but must not accept training data (that is reserved for fitting).

  ✅ Do: `estimator = SVC(C=1)`

  ❌ Don’t: `estimator = SVC([[0, 1], [1, 2]])`
- Model hyperparameters should have default values, so that the user can instantiate a model without passing any arguments.
- Every parameter should directly match an attribute, without additional logic, to enable `model_selection`.

  ✅ Do:

  ```python
  def __init__(self, p_1=1, p_2=2):
      self.p_1 = p_1
      self.p_2 = p_2
  ```

  ❌ Don’t:

  ```python
  def __init__(self, p_1=1, p_2=2):
      if p_1 > 1:
          p_1 += 1
      self.p_1 = p_1
      self.p_3 = p_2
  ```
- No parameter validation in `__init__`; do it in `fit` instead.
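For illustration, here is a minimal sketch of where validation belongs; the `MyEstimator` name and its `alpha` parameter are hypothetical:

```python
from sklearn.base import BaseEstimator


class MyEstimator(BaseEstimator):  # hypothetical estimator, for illustration only
    def __init__(self, alpha=1.0):
        # __init__ only stores the hyperparameter, unmodified and unvalidated.
        self.alpha = alpha

    def fit(self, X, y=None):
        # Parameter validation is deferred to fit.
        if self.alpha <= 0:
            raise ValueError(f"alpha must be strictly positive, got {self.alpha!r}")
        # ... actual fitting logic would go here ...
        return self
```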
Fitting
- For supervised estimators, `estimator = estimator.fit(X, y)`; for unsupervised ones, `estimator = estimator.fit(X)`.
- `kwargs` can be added to `fit`, but they must be restricted to data-dependent variables. For example, a precomputed matrix is data dependent; a tolerance criterion is not.
- The estimator holds no reference to `X` or `y` (an exception is a precomputed kernel, where the training data must be stored for use by the `predict` method).
- Use the `check_X_y` utility to ensure that `X` and `y` have consistent lengths.
- Even if the fit method does not imply a target `y`, `y=None` must be accepted in the signature to enable pipelines.
- The method must return `self` for better usability with chained operations: `y_pred = SVC(C=1).fit(X_train, y_train).predict(X_test)`
- Fit must be idempotent: any new call to fit overwrites the result of the previous call (an exception is the `warm_start=True` strategy, which reuses the previous state to speed up subsequent fits).
- Names of attributes created during fit must end with a trailing underscore: `param_`
- The `n_features_in_` attribute can be added to make input expectations explicit.
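Putting the fitting rules above together, a minimal sketch could look like the following; the `MeanRegressor` name and its `shrinkage` parameter are invented for illustration:

```python
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_X_y


class MeanRegressor(BaseEstimator):  # hypothetical estimator, for illustration only
    def __init__(self, shrinkage=0.0):
        self.shrinkage = shrinkage

    def fit(self, X, y):
        X, y = check_X_y(X, y)            # checks that X and y have consistent lengths
        self.n_features_in_ = X.shape[1]  # makes input expectations explicit
        # Attributes learned during fit end with a trailing underscore; recomputing
        # them from scratch on every call keeps fit idempotent.
        self.mean_ = (1.0 - self.shrinkage) * float(np.mean(y))
        return self                       # returning self enables chained calls
```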
Predictor
- For supervised or unsupervised estimators: `prediction = predictor.predict(X)`
- Classifiers can also quantify a prediction, without applying thresholding: `prediction = predictor.predict_proba(X)`
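For instance, with an existing probabilistic classifier such as `LogisticRegression` (used here purely as an illustration), both methods are available:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(random_state=0)
clf = LogisticRegression().fit(X, y)

labels = clf.predict(X)        # hard class labels
scores = clf.predict_proba(X)  # per-class probabilities, no thresholding applied
```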
Transformer
- For transforming or filtering data, in a supervised or unsupervised way: `X_transformed = transformer.transform(X)`
- When fitting and transforming can be implemented together more efficiently: `X_transformed = transformer.fit_transform(X)`
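As a sketch, a transformer that inherits from `TransformerMixin` gets `fit_transform` for free from its `fit` and `transform`; the `ColumnScaler` below is a made-up example:

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_array


class ColumnScaler(BaseEstimator, TransformerMixin):  # hypothetical, for illustration
    def fit(self, X, y=None):
        X = check_array(X)
        self.scale_ = np.abs(X).max(axis=0) + 1e-12  # learned during fit
        return self

    def transform(self, X):
        X = check_array(X)
        return X / self.scale_


# TransformerMixin provides fit_transform(X) as fit(X).transform(X); override it
# only when a combined implementation is genuinely more efficient.
```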
Rolling your own estimator
Backbone
- Test your estimator using `check_estimator`:

  ```python
  from sklearn.utils.estimator_checks import check_estimator
  from sklearn.svm import LinearSVC

  check_estimator(LinearSVC())  # passes
  ```
- You can leverage the project template to get started with all the required estimator methods.
- You can also inherit from `BaseEstimator`, `ClassifierMixin` or `RegressorMixin` to significantly reduce the amount of boilerplate code (see the sketch after this list), including:
  - `get_params`: takes no mandatory arguments and returns the estimator's parameters; use `deep=True` to also return sub-estimator parameters.
  - `set_params`: overwrites the estimator's parameters.
  - `base.clone` compatibility is enabled by a consistent `get_params`.
  - `_estimator_type` must be `"classifier"`, `"regressor"` or `"clusterer"`. This is automatic when inheriting from `ClassifierMixin`, `RegressorMixin` or `ClusterMixin`.
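A small sketch of what the mixins provide; the `TemplateClassifier` name and its `demo_param` parameter are invented for illustration:

```python
from sklearn.base import BaseEstimator, ClassifierMixin, clone


class TemplateClassifier(ClassifierMixin, BaseEstimator):  # hypothetical skeleton
    def __init__(self, demo_param="demo"):
        self.demo_param = demo_param

    def fit(self, X, y):
        # ... real fitting logic would go here ...
        return self


est = TemplateClassifier(demo_param="x")
print(est.get_params())         # {'demo_param': 'x'}, provided by BaseEstimator
est.set_params(demo_param="y")  # also provided by BaseEstimator
est_copy = clone(est)           # works because parameters map directly to attributes
```

`ClassifierMixin` additionally marks the estimator as a classifier and gives it a default accuracy-based `score` method.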
Specific estimators
- A classifier's fit can accept `y` with integer or string values, using the following conversion: `self.classes_, y = np.unique(y, return_inverse=True)`
- A classifier's predict method must return arrays containing class labels from `classes_`:

  ```python
  def predict(self, X):
      D = self.decision_function(X)
      return self.classes_[np.argmax(D, axis=1)]
  ```
- In linear models, coefficients are stored in `coef_` and the intercept in `intercept_`.
- The `sklearn.utils.multiclass` module contains useful functions for working with multiclass and multilabel problems (a short usage example follows this list).
- Also, check estimator tags to define your estimator's capabilities.
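For example, `type_of_target` and `unique_labels` from `sklearn.utils.multiclass` are convenient when validating targets:

```python
from sklearn.utils.multiclass import type_of_target, unique_labels

y = ["spam", "ham", "spam", "eggs"]
print(type_of_target(y))  # 'multiclass'
print(unique_labels(y))   # ['eggs' 'ham' 'spam'] as a NumPy array
```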
Coding guidelines
How new code should be written for inclusion in scikit-learn, to make review easier.
Style
- Formatting and indentation follow PEP8.
- Use underscores to separate words in non-class names: `n_samples` rather than `nsamples`.
- Avoid multiple statements on one line. Prefer a line return after `if`/`for` (see the short example at the end of this section).
- Use relative imports for references inside scikit-learn.
- Please don’t use `import *` in any case:
  - the code becomes hard to read;
  - static analysis tools have no explicit references to work from.
- Use the numpy docstring standard.
Check the `utils` module for better integration and reusability.
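As a small illustration of the naming and one-statement-per-line rules above (the variable names here are arbitrary):

```python
# Underscore-separated names (n_samples rather than nsamples) and one
# statement per line, with a line return after for/if:
n_samples = 0
for value in [3, -1, 4]:
    if value > 0:
        n_samples += 1
```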