Based on: Developing scikit-learn estimators
Or “How to safely interact with Pipelines and model selection”
Estimator API
Instantiation
__init__ accepts constants that configure the model (hyperparameters), but must not accept training data (reserved for fitting)
Do:
estimator = SVC(C=1)
Don’t:
estimator = SVC([[0, 1], [1, 2]])
Model hyperparameters should have default values, so that the user can instantiate a model without passing any arguments
Every __init__ parameter should be stored unchanged in an attribute of the same name, with no additional logic. This is what makes get_params, set_params and therefore model_selection work.
Do:
def __init__(self, p_1=1, p_2=2):
    self.p_1 = p_1
    self.p_2 = p_2
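Don’t:
def __init__(self, p_1=1, p_2=2):
    # Wrong: logic in __init__ alters the parameter the user passed
    if p_1 > 1:
        p_1 += 1
    self.p_1 = p_1
    # Wrong: the attribute name does not match the parameter name
    self.p_3 = p_2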
No parameter validation in __init__, only in fit
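The instantiation rules above can be sketched as follows. ScaledMean and its scale parameter are hypothetical names invented for this illustration; only BaseEstimator and clone come from scikit-learn. The sketch assumes scikit-learn and NumPy are installed.

```python
import numpy as np
from sklearn.base import BaseEstimator, clone


class ScaledMean(BaseEstimator):
    """Hypothetical estimator, used only to illustrate the __init__ rules."""

    def __init__(self, scale=1.0):
        # Store the argument untouched: no validation, no conversion.
        self.scale = scale

    def fit(self, X, y=None):
        # Validation lives in fit, so values set later through
        # set_params() are checked as well.
        if self.scale <= 0:
            raise ValueError(f"scale must be positive, got {self.scale!r}")
        self.mean_ = self.scale * np.asarray(X, dtype=float).mean(axis=0)
        return self


est = ScaledMean(scale=-1.0)  # accepted: __init__ does not validate
cloned = clone(est)           # works because parameters are stored unchanged

try:
    est.fit([[0.0, 1.0], [2.0, 3.0]])
    fit_error = None
except ValueError as exc:
    fit_error = str(exc)      # the invalid value is only reported by fit
```

Deferring validation to fit means an invalid value is reported at one well-defined place, whether it came from the constructor or from set_params.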
Fitting
For supervised
estimator = estimator.fit(X, y)
or for unsupervised
estimator = estimator.fit(X)
Extra kwargs can be added to fit, but only for data-dependent variables
example: a precomputed matrix is data dependent, a tolerance criterion is not
The estimator should hold no reference to X or y after fit (exception: precomputed kernels, where this data must be stored for use by the predict method)
Use check_X_y from sklearn.utils to ensure that X and y have consistent lengths
Even if the fit method does not use a target, it must accept y=None as its second argument so that the estimator works in pipelines
The method must return self for better usability with chained operations
y_pred = SVC(C=1).fit(X_train, y_train).predict(X_test)
fit must be idempotent: any new call to fit overwrites the result of the previous call (exception: the warm_start=True strategy, which reuses the previous result to speed up the next fit)
Names of attributes created during fit must end with a trailing underscore: param_
The n_features_in_ attribute can be set during fit to make input expectations explicit
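A minimal sketch of a fit method that follows the conventions listed above. MeanRegressor is a hypothetical class written for this note, not part of scikit-learn; check_X_y, check_array and check_is_fitted are existing helpers from sklearn.utils.validation.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted


class MeanRegressor(RegressorMixin, BaseEstimator):
    """Hypothetical regressor that always predicts the training mean."""

    def fit(self, X, y):
        # Validates shapes and raises ValueError if lengths differ.
        X, y = check_X_y(X, y)
        # Learned attributes end with a trailing underscore.
        self.n_features_in_ = X.shape[1]
        self.mean_ = float(y.mean())
        # No reference to X or y is kept; return self to allow chaining.
        return self

    def predict(self, X):
        check_is_fitted(self, "mean_")
        X = check_array(X)
        return np.full(X.shape[0], self.mean_)


X = [[0.0], [1.0], [2.0]]
y = [1.0, 2.0, 3.0]
model = MeanRegressor().fit(X, y)
first = model.mean_
second = model.fit(X, y).mean_  # refitting overwrites with the same result
```

Because every learned attribute is recomputed from scratch inside fit, calling it twice on the same data gives the same state, which is the idempotence requirement.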
Predictor
For supervised or unsupervised:
prediction = predictor.predict(X)
Classifiers can also quantify the confidence of a prediction (without applying thresholding)
prediction = predictor.predict_proba(X)
Transformer
For transforming or filtering data, in a supervised or unsupervised way
X_transformed = transformer.transform(X)
When fitting and transforming can be implemented together more efficiently
X_transformed = transformer.fit_transform(X)
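As an illustration of the transformer interface, here is a small sketch. CenteringTransformer is a hypothetical class; it relies on TransformerMixin, which supplies a default fit_transform that calls fit and then transform.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_array, check_is_fitted


class CenteringTransformer(TransformerMixin, BaseEstimator):
    """Hypothetical transformer that subtracts the column means seen in fit."""

    def fit(self, X, y=None):
        # y is unused but accepted, so the transformer fits into pipelines.
        X = check_array(X)
        self.n_features_in_ = X.shape[1]
        self.mean_ = X.mean(axis=0)
        return self

    def transform(self, X):
        check_is_fitted(self, "mean_")
        X = check_array(X)
        return X - self.mean_


X = np.array([[1.0, 2.0], [3.0, 6.0]])
transformer = CenteringTransformer()
X_centered = transformer.fit_transform(X)       # inherited from TransformerMixin
X_new = transformer.transform([[2.0, 4.0]])     # reuses the fitted means
```

Overriding fit_transform is only worthwhile when the combined computation is cheaper than running fit and transform one after the other.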
Rolling your own estimator
Backbone
Test your estimator using check_estimator
from sklearn.utils.estimator_checks import check_estimator
from sklearn.svm import LinearSVC
check_estimator(LinearSVC()) # passes
You can leverage the project template to get started with all the required estimator methods.
You can also inherit from BaseEstimator, ClassifierMixin or RegressorMixin to significantly reduce the amount of boilerplate code, including:
get_params: takes no positional arguments and returns a dict of the __init__ parameters with their current values.
pass deep=True to also return the parameters of sub-estimators
set_params: overwrites __init__ parameters from keyword arguments
base.clone compatibility is enabled by get_params
_estimator_type must be "classifier", "regressor" or "clusterer"
This is automatic with ClassifierMixin, RegressorMixin or ClusterMixin inheritance
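The inherited parameter handling can be sketched as follows. Wrapper and its alpha parameter are hypothetical, invented for this illustration; get_params, set_params and clone are the scikit-learn calls described above.

```python
from sklearn.base import BaseEstimator, clone
from sklearn.svm import SVC


class Wrapper(BaseEstimator):
    """Hypothetical meta-estimator; it only stores its parameters."""

    def __init__(self, estimator=None, alpha=1.0):
        self.estimator = estimator
        self.alpha = alpha


wrapper = Wrapper(estimator=SVC(C=1), alpha=0.5)

# Only the wrapper's own __init__ parameters.
shallow = wrapper.get_params(deep=False)
# Also the sub-estimator's parameters, prefixed with "estimator__".
deep = wrapper.get_params(deep=True)

# The same double-underscore syntax reaches into the sub-estimator.
wrapper.set_params(alpha=2.0, estimator__C=10)

# clone builds a new, unfitted object with the same parameters.
duplicate = clone(wrapper)
```

This double-underscore naming is what lets grid search address a nested parameter such as estimator__C without any extra code in the estimator.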
Specific estimators
A classifier fit can accept y with integers or string values, with the following conversion:
self.classes_, y = np.unique(y, return_inverse=True)
A classifier predict method must return arrays containing class labels from classes_
def predict(self, X):
    D = self.decision_function(X)
    return self.classes_[np.argmax(D, axis=1)]
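Put together, the label conversion in fit and the lookup in predict could look like the sketch below. CentroidClassifier and its scoring rule (negative squared distance to each class centroid) are assumptions made for this example, not a scikit-learn class.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted


class CentroidClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical classifier assigning each sample to the nearest centroid."""

    def fit(self, X, y):
        X, y = check_X_y(X, y)
        # Map arbitrary labels (ints or strings) to indices 0..n_classes-1.
        self.classes_, y_idx = np.unique(y, return_inverse=True)
        self.n_features_in_ = X.shape[1]
        self.centroids_ = np.array(
            [X[y_idx == k].mean(axis=0) for k in range(len(self.classes_))]
        )
        return self

    def decision_function(self, X):
        check_is_fitted(self, "centroids_")
        X = check_array(X)
        # One column per class: negative squared distance to its centroid.
        diff = X[:, np.newaxis, :] - self.centroids_[np.newaxis, :, :]
        return -np.sum(diff ** 2, axis=2)

    def predict(self, X):
        D = self.decision_function(X)
        # Translate the winning column index back to the original label.
        return self.classes_[np.argmax(D, axis=1)]


X_train = [[0, 0], [0, 1], [5, 5], [5, 6]]
y_train = ["a", "a", "b", "b"]
clf = CentroidClassifier().fit(X_train, y_train)
labels = clf.predict([[0.0, 0.5], [5.0, 5.5]])
```

Since np.unique returns sorted unique labels, column k of the decision function always corresponds to classes_[k].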
In linear models, coefficients are stored in coef_ and the intercept in intercept_
The sklearn.utils.multiclass module contains useful functions for working with multiclass and multilabel problems.
Also, check the estimator tags to declare your estimator's capabilities
Coding guidelines
How new code should be written to be included in scikit-learn and to make review easier.
Style
Format and indentation follow PEP 8
Use underscores to separate words in non-class names: n_samples rather than nsamples.
Avoid multiple statements on one line. Prefer a line break after a control-flow statement (if/for)
Use relative imports for references inside scikit-learn.
Please don’t use import * in any case
the code becomes hard to read, because the origin of each name is no longer explicit
static analysis tools can no longer resolve names and find bugs automatically
Use the numpy docstring standard
Check the utils module for better integration and reusability
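A short sketch of these style points applied to a helper function. count_nonzero_per_feature is a hypothetical function written for this note; check_array is an existing helper in sklearn.utils. The example lives outside scikit-learn, so it uses absolute imports; code inside the package would use relative ones.

```python
import numpy as np
from sklearn.utils import check_array  # explicit import, never "import *"


def count_nonzero_per_feature(X):
    """Count the non-zero entries in each feature column.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input data.

    Returns
    -------
    n_nonzero : ndarray of shape (n_features,)
        Number of non-zero entries in each column of ``X``.
    """
    # Reuse the shared validation helper instead of hand-written checks.
    X = check_array(X)
    # Underscore-separated name: n_nonzero rather than nnonzero.
    n_nonzero = np.count_nonzero(X, axis=0)
    return n_nonzero


counts = count_nonzero_per_feature([[0, 1], [2, 3]])
```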