class ray.train.sklearn.SklearnCheckpoint(local_path: Optional[Union[str, os.PathLike]] = None, data_dict: Optional[dict] = None, uri: Optional[str] = None)

Bases: ray.air.checkpoint.Checkpoint

A Checkpoint with sklearn-specific functionality.

To create one from a generic Checkpoint, call SklearnCheckpoint.from_checkpoint(ckpt).

PublicAPI (alpha): This API is in alpha and may change before becoming stable.

classmethod from_estimator(estimator: sklearn.base.BaseEstimator, *, path: os.PathLike, preprocessor: Optional[Preprocessor] = None) -> SklearnCheckpoint

Create a Checkpoint that stores an sklearn Estimator.

Parameters:

  • estimator – The Estimator to store in the checkpoint.

  • path – The directory where the checkpoint will be stored.

  • preprocessor – A fitted preprocessor to be applied before inference.

Returns:

An SklearnCheckpoint containing the specified Estimator.

Examples:

>>> from ray.train.sklearn import SklearnCheckpoint
>>> from sklearn.ensemble import RandomForestClassifier
>>> estimator = RandomForestClassifier()
>>> checkpoint = SklearnCheckpoint.from_estimator(estimator, path=".")

You can use an SklearnCheckpoint to create an SklearnPredictor and perform inference.

>>> from ray.train.sklearn import SklearnPredictor
>>> predictor = SklearnPredictor.from_checkpoint(checkpoint)
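
The storage step that from_estimator performs can be pictured without Ray. The sketch below is an assumption, not Ray's implementation: it pickles a fitted estimator into a checkpoint directory and reads it back. The helper save_estimator and the file name model.pkl are hypothetical, because this page does not say how SklearnCheckpoint lays out its directory. The preprocessor argument is left out for brevity.

```python
import os
import pickle
import tempfile

from sklearn.ensemble import RandomForestClassifier

# Hypothetical file name; the layout Ray actually uses is not specified here.
MODEL_FILE = "model.pkl"


def save_estimator(estimator, path):
    """Pickle `estimator` into the directory `path` (hypothetical helper)."""
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, MODEL_FILE), "wb") as f:
        pickle.dump(estimator, f)
    return path


# Fit before saving so the stored estimator is usable for inference later.
X = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]
y = [0, 1, 1, 0]
estimator = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)

with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = save_estimator(estimator, os.path.join(tmp, "checkpoint"))
    stored_files = sorted(os.listdir(ckpt_dir))
    # Read the estimator back, standing in for retrieving it from the checkpoint.
    with open(os.path.join(ckpt_dir, MODEL_FILE), "rb") as f:
        restored = pickle.load(f)
    restored_preds = restored.predict(X).tolist()

print(stored_files)  # prints ['model.pkl']
```

Pickle is used here only because it is the standard way to persist scikit-learn estimators; whatever format Ray chooses, the point is that the directory passed as path ends up holding a serialized copy of the estimator.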

get_estimator() -> sklearn.base.BaseEstimator

Retrieve the Estimator stored in this checkpoint.
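
The from_estimator example stores a freshly constructed, unfitted RandomForestClassifier(). Assuming the checkpoint keeps a serialized copy of the estimator (an assumption, not stated on this page), get_estimator() would return that same unfitted estimator, and inference with it would fail. The Ray-free sketch below round-trips an unfitted and a fitted estimator through pickle as a stand-in for storing and retrieving; is_fitted is a hypothetical helper built on scikit-learn's check_is_fitted.

```python
import pickle

from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression
from sklearn.utils.validation import check_is_fitted


def is_fitted(estimator):
    """Return True if scikit-learn considers `estimator` fitted (hypothetical helper)."""
    try:
        check_is_fitted(estimator)
        return True
    except NotFittedError:
        return False


X = [[0.0], [1.0], [2.0], [3.0]]
y = [0, 0, 1, 1]

unfitted = LogisticRegression()
fitted = LogisticRegression().fit(X, y)

# Round-trip both through pickle, standing in for storing and then retrieving.
restored_unfitted = pickle.loads(pickle.dumps(unfitted))
restored_fitted = pickle.loads(pickle.dumps(fitted))

print(is_fitted(restored_unfitted), is_fitted(restored_fitted))  # prints False True
```

In practice this means fitting the estimator before calling from_estimator if the checkpoint will be used for inference.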