Distributed XGBoost on Ray¶
XGBoost-Ray is a distributed backend for XGBoost, built on top of the distributed computing framework Ray.
XGBoost-Ray
- enables multi-node and multi-GPU training,
- integrates seamlessly with the distributed hyperparameter optimization library Ray Tune,
- comes with advanced fault tolerance handling mechanisms, and
- supports distributed dataframes and distributed data loading.
All releases are tested on large clusters and workloads.
Installation¶
You can install the latest XGBoost-Ray release with pip:
pip install "xgboost_ray"
If you’d like to install the latest master, use this command instead:
pip install "git+https://github.com/ray-project/xgboost_ray.git#egg=xgboost_ray"
Usage¶
XGBoost-Ray provides a drop-in replacement for XGBoost’s train
function. To pass data, instead of using xgb.DMatrix
you will
have to use xgboost_ray.RayDMatrix
. You can also use a scikit-learn
interface - see next section.
Just as with the original xgb.train()
function, the training parameters
are passed as the params
dictionary.
Ray-specific distributed training parameters are configured with a
xgboost_ray.RayParams
object. For instance, you can set
the num_actors
property to specify how many distributed actors
you would like to use.
Here is a simplified example (which requires sklearn
):
Training:
from xgboost_ray import RayDMatrix, RayParams, train
from sklearn.datasets import load_breast_cancer
train_x, train_y = load_breast_cancer(return_X_y=True)
train_set = RayDMatrix(train_x, train_y)
evals_result = {}
bst = train(
    {
        "objective": "binary:logistic",
        "eval_metric": ["logloss", "error"],
    },
    train_set,
    evals_result=evals_result,
    evals=[(train_set, "train")],
    verbose_eval=False,
    ray_params=RayParams(
        num_actors=2,  # Number of remote actors
        cpus_per_actor=1))
bst.save_model("model.xgb")
print("Final training error: {:.4f}".format(
evals_result["train"]["error"][-1]))
Prediction:
from xgboost_ray import RayDMatrix, RayParams, predict
from sklearn.datasets import load_breast_cancer
import xgboost as xgb
data, labels = load_breast_cancer(return_X_y=True)
dpred = RayDMatrix(data, labels)
bst = xgb.Booster(model_file="model.xgb")
pred_ray = predict(bst, dpred, ray_params=RayParams(num_actors=2))
print(pred_ray)
scikit-learn API¶
XGBoost-Ray also features a scikit-learn API that fully mirrors the pure XGBoost scikit-learn API, providing a complete drop-in replacement. The following estimators are available:
- RayXGBClassifier
- RayXGBRegressor
- RayXGBRFClassifier
- RayXGBRFRegressor
- RayXGBRanker
Example usage of RayXGBClassifier
:
from xgboost_ray import RayXGBClassifier, RayParams
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
seed = 42
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=0.25, random_state=42
)
clf = RayXGBClassifier(
    n_jobs=4,  # In XGBoost-Ray, n_jobs sets the number of actors
    random_state=seed
)
# scikit-learn API will automatically convert the data
# to RayDMatrix format as needed.
# You can also pass X as a RayDMatrix, in which case
# y will be ignored.
clf.fit(X_train, y_train)
pred_ray = clf.predict(X_test)
print(pred_ray)
pred_proba_ray = clf.predict_proba(X_test)
print(pred_proba_ray)
# It is also possible to pass a RayParams object
# to fit/predict/predict_proba methods - will override
# n_jobs set during initialization
clf.fit(X_train, y_train, ray_params=RayParams(num_actors=2))
pred_ray = clf.predict(X_test, ray_params=RayParams(num_actors=2))
print(pred_ray)
Things to keep in mind:
- The n_jobs parameter controls the number of actors spawned. You can pass a RayParams object to the fit/predict/predict_proba methods as the ray_params argument for greater control over resource allocation. Doing so will override the value of n_jobs with the value of the ray_params.num_actors attribute. For more information, refer to the Resources section below.
- By default, n_jobs is set to 1, which means the training will not be distributed. Make sure to either set n_jobs to a higher value or pass a RayParams object as outlined above in order to take advantage of XGBoost-Ray's functionality.
- After calling fit, additional evaluation results (e.g. training time, number of rows, callback results) will be available under the additional_results_ attribute.
- XGBoost-Ray's scikit-learn API is based on XGBoost 1.4. While we try to support older XGBoost versions, please note that this library is only fully tested and supported for XGBoost >= 1.4.
For more information on the scikit-learn API, refer to the XGBoost documentation.
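For instance, the n_jobs override and the additional results mentioned above can be seen in a minimal sketch like the following (the exact contents of additional_results_ depend on your setup):
from xgboost_ray import RayXGBClassifier, RayParams
from sklearn.datasets import load_breast_cancer

X, y = load_breast_cancer(return_X_y=True)

clf = RayXGBClassifier(n_jobs=1, random_state=42)

# ray_params overrides n_jobs: training runs with 2 actors
clf.fit(X, y, ray_params=RayParams(num_actors=2, cpus_per_actor=1))

# Extra results (e.g. training time, number of rows) collected during fit
print(clf.additional_results_)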
Data loading¶
Data is passed to XGBoost-Ray via a RayDMatrix
object.
The RayDMatrix lazily loads data and stores it sharded in the Ray object store. The Ray XGBoost actors then access these shards to run their training on.
A RayDMatrix supports various data and file types, such as Pandas DataFrames, Numpy arrays, CSV files, and Parquet files.
Example loading multiple parquet files:
import glob
from xgboost_ray import RayDMatrix, RayFileType
# We can also pass a list of files
path = list(sorted(glob.glob("/data/nyc-taxi/*/*/*.parquet")))
# This argument will be passed to `pd.read_parquet()`
columns = [
    "passenger_count",
    "trip_distance", "pickup_longitude", "pickup_latitude",
    "dropoff_longitude", "dropoff_latitude",
    "fare_amount", "extra", "mta_tax", "tip_amount",
    "tolls_amount", "total_amount"
]
dtrain = RayDMatrix(
    path,
    label="passenger_count",  # Will select this column as the label
    columns=columns,
    # ignore=["total_amount"],  # Optional list of columns to ignore
    filetype=RayFileType.PARQUET)
Hyperparameter Tuning¶
XGBoost-Ray integrates with Ray Tune to provide distributed hyperparameter tuning for your
distributed XGBoost models. You can run multiple XGBoost-Ray training runs in parallel, each with a different
hyperparameter configuration, and each training run parallelized by itself. All you have to do is move your training
code to a function, and pass the function to tune.run
. Internally, train
will detect if tune
is being used and will
automatically report results to tune.
Example using XGBoost-Ray with Ray Tune:
from xgboost_ray import RayDMatrix, RayParams, train
from sklearn.datasets import load_breast_cancer
num_actors = 4
num_cpus_per_actor = 1
ray_params = RayParams(
    num_actors=num_actors,
    cpus_per_actor=num_cpus_per_actor)

def train_model(config):
    train_x, train_y = load_breast_cancer(return_X_y=True)
    train_set = RayDMatrix(train_x, train_y)
    evals_result = {}
    bst = train(
        params=config,
        dtrain=train_set,
        evals_result=evals_result,
        evals=[(train_set, "train")],
        verbose_eval=False,
        ray_params=ray_params)
    bst.save_model("model.xgb")
from ray import tune
# Specify the hyperparameter search space.
config = {
    "tree_method": "approx",
    "objective": "binary:logistic",
    "eval_metric": ["logloss", "error"],
    "eta": tune.loguniform(1e-4, 1e-1),
    "subsample": tune.uniform(0.5, 1.0),
    "max_depth": tune.randint(1, 9)
}

# Make sure to use the `get_tune_resources` method to set the `resources_per_trial`
analysis = tune.run(
    train_model,
    config=config,
    metric="train-error",
    mode="min",
    num_samples=4,
    resources_per_trial=ray_params.get_tune_resources())
print("Best hyperparameters", analysis.best_config)
Also see examples/simple_tune.py for another example.
Fault tolerance¶
XGBoost-Ray leverages the stateful Ray actor model to enable fault-tolerant training. There are currently two modes implemented.
Non-elastic training (warm restart)¶
When an actor or node dies, XGBoost-Ray will retain the state of the remaining actors. In non-elastic training, the failed actors will be replaced as soon as resources are available again. Only these actors will reload their parts of the data. Training will resume once all actors are ready for training again.
You can set this mode in the RayParams
:
from xgboost_ray import RayParams
ray_params = RayParams(
    elastic_training=False,  # Use non-elastic training
    max_actor_restarts=2,    # How often are actors allowed to fail
)
Elastic training¶
In elastic training, XGBoost-Ray will continue training with fewer actors (and on less data) when a node or actor dies. The missing actors are staged in the background and reintegrated into training once they are back and have loaded their data.
This mode will train on less data for a period of time, which can impact accuracy. In practice, we found these effects to be minor, especially for large shuffled datasets. The immediate benefit is that training time is reduced significantly compared to non-elastic training, staying almost at the same level as if no actors had died. Thus, especially when data loading takes a large part of the total training time, this setting can dramatically speed up training for large distributed jobs.
You can configure this mode in the RayParams
:
from xgboost_ray import RayParams
ray_params = RayParams(
    elastic_training=True,   # Use elastic training
    max_failed_actors=3,     # Only allow at most 3 actors to die at the same time
    max_actor_restarts=2,    # How often are actors allowed to fail
)
Resources¶
By default, XGBoost-Ray tries to determine the number of CPUs available and distributes them evenly across actors.
In the case of very large clusters or clusters with many different
machine sizes, it makes sense to limit the number of CPUs per actor
by setting the cpus_per_actor
argument. Consider always
setting this explicitly.
The number of XGBoost actors always has to be set manually with
the num_actors
argument.
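For illustration, here is a minimal sketch of an explicit configuration; the node and CPU counts are assumptions about a hypothetical cluster, not recommendations:
from xgboost_ray import RayParams

# Hypothetical cluster: 4 nodes with 8 CPUs each.
# One actor per node, each using all 8 CPUs of its node.
ray_params = RayParams(
    num_actors=4,
    cpus_per_actor=8,
)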
Multi GPU training¶
XGBoost-Ray enables multi GPU training. The XGBoost core backend
will automatically leverage NCCL2 for cross-device communication.
All you have to do is to start one actor per GPU and set XGBoost's tree_method to a GPU-compatible option, e.g. gpu_hist (see the XGBoost documentation for more details).
For instance, if you have 2 machines with 4 GPUs each, you will want
to start 8 remote actors, and set gpus_per_actor=1
. There is usually
no benefit in allocating less (e.g. 0.5) or more than one GPU per actor.
You should divide the CPUs evenly across actors per machine, so if your machines have 16 CPUs in addition to the 4 GPUs, each actor should have 4 CPUs to use.
from xgboost_ray import RayParams
ray_params = RayParams(
    num_actors=8,
    gpus_per_actor=1,
    cpus_per_actor=4,  # Divide evenly across actors per machine
)
How many remote actors should I use?¶
This depends on your workload and your cluster setup. Generally there is no inherent benefit in running more than one remote actor per node for CPU-only training, because the XGBoost core can already leverage multiple CPUs via threading.
However, there are some cases when you should consider starting more than one actor per node:
- For multi GPU training, each GPU should have a separate remote actor. Thus, if your machine has 24 CPUs and 4 GPUs, you will want to start 4 remote actors with 6 CPUs and 1 GPU each.
- In a heterogeneous cluster, you might want to find the greatest common divisor for the number of CPUs. E.g. for a cluster with three nodes of 4, 8, and 12 CPUs, respectively, you should set the number of actors to 6 and the CPUs per actor to 4, as shown in the sketch below.
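For the heterogeneous three-node example above (4, 8, and 12 CPUs), a minimal sketch of the resulting configuration:
from xgboost_ray import RayParams

# Greatest common divisor of 4, 8, and 12 CPUs is 4:
# the nodes can host 1, 2, and 3 actors respectively (6 in total).
ray_params = RayParams(
    num_actors=6,
    cpus_per_actor=4,
)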
Distributed data loading¶
XGBoost-Ray can leverage both centralized and distributed data loading.
In centralized data loading, the data is partitioned by the head node and stored in the object store. Each remote actor then retrieves its partitions by querying the Ray object store. Centralized loading is used when you pass centralized in-memory dataframes, such as Pandas dataframes or Numpy arrays, or when you pass a single source file, such as a single CSV or Parquet file.
from xgboost_ray import RayDMatrix
# This will use centralized data loading, as only one source file is specified
# `label_col` is a column in the CSV, used as the target label
ray_params = RayDMatrix("./source_file.csv", label="label_col")
In distributed data loading, each remote actor loads its data directly from the source (e.g. local hard disk, NFS, HDFS, S3), without a central bottleneck. The data is still stored in the object store, but local to each actor. This mode is used automatically when loading data from multiple CSV or Parquet files. Please note that we do not check or enforce partition sizes in this case - it is your job to make sure the data is evenly distributed across the source files.
from xgboost_ray import RayDMatrix
# This will use distributed data loading, as four source files are specified
# Please note that you cannot schedule more than four actors in this case.
# `label_col` is a column in the Parquet files, used as the target label
dtrain = RayDMatrix([
    "hdfs:///tmp/part1.parquet",
    "hdfs:///tmp/part2.parquet",
    "hdfs:///tmp/part3.parquet",
    "hdfs:///tmp/part4.parquet",
], label="label_col")
Lastly, XGBoost-Ray supports distributed dataframe representations, such as Ray Datasets, Modin and Dask dataframes (used with Dask on Ray). Here, XGBoost-Ray will check on which nodes the distributed partitions are currently located, and will assign partitions to actors in order to minimize cross-node data transfer. Please note that we also assume here that partition sizes are uniform.
from xgboost_ray import RayDMatrix
# This will try to allocate the existing Modin partitions
# to co-located Ray actors. If this is not possible, data will
# be transferred across nodes
dtrain = RayDMatrix(existing_modin_df)
Data sources¶
The following data sources can be used with a RayDMatrix
object.
Type | Centralized loading | Distributed loading
---|---|---
Numpy array | Yes | No
Pandas dataframe | Yes | No
Single CSV | Yes | No
Multi CSV | Yes | Yes
Single Parquet | Yes | No
Multi Parquet | Yes | Yes
Petastorm | Yes | Yes
Ray Dataset | Yes | Yes
Dask dataframe | Yes | Yes
Modin dataframe | Yes | Yes
Memory usage¶
XGBoost uses a compute-optimized data structure, the DMatrix, to hold training data. When converting a dataset to a DMatrix, XGBoost creates intermediate copies and ends up holding a complete copy of the full data. The data will be converted into the local data format (on a 64-bit system these are 64-bit floats). Depending on the system and original dataset dtype, this matrix can thus occupy more memory than the original dataset.
The peak memory usage for CPU-based training is at least 3x the dataset size (assuming dtype float32 on a 64-bit system) plus about 400,000 KiB for other resources, like operating system requirements and storing of intermediate results.
Example
- Machine type: AWS m5.xlarge (4 vCPUs, 16 GiB RAM)
- Usable RAM: ~15,350,000 KiB
- Dataset: 1,250,000 rows with 1024 features, dtype float32. Total size: 5,000,000 KiB
- XGBoost DMatrix size: ~10,000,000 KiB
This dataset will fit exactly on this node for training.
Note that the DMatrix size might be lower on a 32-bit system.
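As a back-of-the-envelope sketch of the arithmetic above (the 3x factor and the ~400,000 KiB overhead are the rough estimates from this section, not exact measurements):
rows, features = 1_250_000, 1024
bytes_per_value = 4  # float32

dataset_kib = rows * features * bytes_per_value / 1024  # 5,000,000 KiB
dmatrix_kib = 2 * dataset_kib                           # ~10,000,000 KiB
peak_kib = 3 * dataset_kib + 400_000                    # ~15,400,000 KiB

print(dataset_kib, dmatrix_kib, peak_kib)
# The peak estimate lands right at the ~15,350,000 KiB of usable RAM.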
GPUs
Generally, the same memory requirements exist for GPU-based training. Additionally, the GPU must have enough memory to hold the dataset.
In the example above, the GPU must have at least 10,000,000 KiB (about 9.6 GiB) of memory. However, empirically we found that using a DeviceQuantileDMatrix seems to result in higher peak GPU memory usage (about 10% more), possibly due to intermediate storage when loading data.
Best practices
In order to reduce peak memory usage, consider the following suggestions:
- Store data as float32 or less. More precision is often not needed, and keeping data in a smaller format will help reduce peak memory usage for initial data loading.
- Pass the dtype when loading data from CSV. Otherwise, floating point values will be loaded as np.float64 by default, increasing peak memory usage by 33%. A minimal sketch of this follows below.
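For example, since keyword arguments to RayDMatrix are forwarded to the underlying loading function (pandas.read_parquet for Parquet files, and presumably pandas.read_csv for CSV files), the dtype can be passed through. This is a minimal sketch with hypothetical file and column names:
import numpy as np
from xgboost_ray import RayDMatrix

# Forwarded to the CSV loader; read values as float32 instead of the
# default float64 to reduce peak memory during initial data loading.
dtrain = RayDMatrix(
    "train_data.csv",   # hypothetical file
    label="label_col",  # hypothetical label column
    dtype=np.float32,
)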
Placement Strategies¶
XGBoost-Ray leverages Ray’s Placement Group API (https://docs.ray.io/en/master/placement-group.html) to implement placement strategies for better fault tolerance.
By default, a SPREAD strategy is used for training, which attempts to spread all of the training workers
across the nodes in a cluster on a best-effort basis. This improves fault tolerance since it minimizes the
number of worker failures when a node goes down, but comes at the cost of increased inter-node communication.
To disable this strategy, set the RXGB_USE_SPREAD_STRATEGY
environment variable to 0. If disabled, no
particular placement strategy will be used.
Note that this strategy is used only when elastic_training
is not used. If elastic_training
is set to True
,
no placement strategy is used.
When XGBoost-Ray is used with Ray Tune for hyperparameter tuning, a PACK strategy is used. This strategy attempts to place all workers for each trial on the same node on a best-effort basis. This means that if a node goes down, it will be less likely to impact multiple trials.
When placement strategies are used, XGBoost-Ray will wait for 100 seconds for the required resources
to become available, and will fail if the required resources cannot be reserved and the cluster cannot autoscale
to increase the number of resources. You can change the RXGB_PLACEMENT_GROUP_TIMEOUT_S
environment variable to modify
how long this timeout should be.
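Both settings are plain environment variables, so a minimal sketch of adjusting them (set before training starts; the timeout value is just an example):
import os

# Disable the SPREAD placement strategy for training
os.environ["RXGB_USE_SPREAD_STRATEGY"] = "0"

# Wait up to 5 minutes for placement group resources before failing
os.environ["RXGB_PLACEMENT_GROUP_TIMEOUT_S"] = "300"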
More examples¶
For complete end-to-end examples, please have a look at the examples folder:
- Simple sklearn breast cancer dataset example (requires sklearn)
- HIGGS classification example with Parquet (uses the same dataset)
- Test data classification (uses a self-generated dataset)
API reference¶
- class xgboost_ray.RayParams(num_actors: int = 0, cpus_per_actor: int = 0, gpus_per_actor: int = -1, resources_per_actor: Optional[Dict] = None, elastic_training: bool = False, max_failed_actors: int = 0, max_actor_restarts: int = 0, checkpoint_frequency: int = 5, distributed_callbacks: Optional[List[xgboost_ray.callback.DistributedCallback]] = None)[source]¶
Parameters to configure Ray-specific behavior.
- Parameters
num_actors (int) – Number of parallel Ray actors.
cpus_per_actor (int) – Number of CPUs to be used per Ray actor.
gpus_per_actor (int) – Number of GPUs to be used per Ray actor.
resources_per_actor (Optional[Dict]) – Dict of additional resources required per Ray actor.
elastic_training (bool) – If True, training will continue with fewer actors if an actor fails. Default False.
max_failed_actors (int) – If elastic_training is True, this specifies the maximum number of failed actors with which we still continue training.
max_actor_restarts (int) – Number of retries when Ray actors fail. Defaults to 0 (no retries). Set to -1 for unlimited retries.
checkpoint_frequency (int) – How often to save checkpoints. Defaults to 5 (every 5th iteration).
PublicAPI (beta): This API is in beta and may change before becoming stable.
- class xgboost_ray.RayDMatrix(data: Union[str, List[str], numpy.ndarray, pandas.core.frame.DataFrame, pandas.core.series.Series], label: Optional[Union[str, List[str], numpy.ndarray, pandas.core.frame.DataFrame, pandas.core.series.Series]] = None, weight: Optional[Union[str, List[str], numpy.ndarray, pandas.core.frame.DataFrame, pandas.core.series.Series]] = None, base_margin: Optional[Union[str, List[str], numpy.ndarray, pandas.core.frame.DataFrame, pandas.core.series.Series]] = None, missing: Optional[float] = None, label_lower_bound: Optional[Union[str, List[str], numpy.ndarray, pandas.core.frame.DataFrame, pandas.core.series.Series]] = None, label_upper_bound: Optional[Union[str, List[str], numpy.ndarray, pandas.core.frame.DataFrame, pandas.core.series.Series]] = None, feature_names: Optional[List[str]] = None, feature_types: Optional[List[numpy.dtype]] = None, qid: Optional[Union[str, List[str], numpy.ndarray, pandas.core.frame.DataFrame, pandas.core.series.Series]] = None, num_actors: Optional[int] = None, filetype: Optional[xgboost_ray.data_sources.data_source.RayFileType] = None, ignore: Optional[List[str]] = None, distributed: Optional[bool] = None, sharding: xgboost_ray.matrix.RayShardingMode = RayShardingMode.INTERLEAVED, lazy: bool = False, **kwargs)[source]¶
XGBoost on Ray DMatrix class.
This is the data object that the training and prediction functions expect. This wrapper manages distributed data by sharding the data for the workers and storing the shards in the object store.
If this class is called without the num_actors argument, it will be lazily loaded. Thus, it will return immediately and only load the data and store it in the Ray object store after load_data(num_actors) or get_data(rank, num_actors) is called.
If this class is instantiated with the num_actors argument, it will directly load the data and store it in the object store. If this should be deferred, pass lazy=True as an argument.
Loading the data will store it in the Ray object store. This object then stores references to the data shards in the Ray object store. Actors can request these shards with the get_data(rank) method, returning dataframes according to the actor rank.
The total number of actors has to remain constant and cannot be changed once it has been set.
- Parameters
data – Data object. Can be a pandas dataframe, pandas series, numpy array, modin dataframe, string pointing to a csv or parquet file, or list of strings pointing to csv or parquet files.
label – Optional label object. Can be a pandas series, numpy array, modin series, string pointing to a csv or parquet file, or a string indicating the column of the data dataframe that contains the label. If this is not a string it must be of the same type as the data argument.
num_actors – Number of actors to shard this data for. If this is not None, data will be loaded and stored into the object store after initialization. If this is None, it will be set by the xgboost_ray.train() function, and the data will be loaded and stored in the object store then. Defaults to None.
filetype (Optional[RayFileType]) – Type of data to read. This is disregarded if a data object like a pandas dataframe is passed as the data argument. For filenames, the filetype is automatically detected via the file name (e.g. .csv will be detected as RayFileType.CSV). Passing this argument will overwrite the detected filetype. If the filetype cannot be determined from the data object, passing this is mandatory. Defaults to None (auto detection).
ignore (Optional[List[str]]) – Exclude these columns from the dataframe after loading the data.
distributed (Optional[bool]) – If True, use distributed loading (each worker loads a share of the dataset). If False, use central loading (the head node loads the whole dataset and distributes it). If None, auto-detect and default to distributed loading, if possible.
sharding (RayShardingMode) – How to shard the data for different workers. RayShardingMode.INTERLEAVED will divide the data per row, i.e. every i-th row will be passed to the first worker, every (i+1)-th row to the second worker, etc. RayShardingMode.BATCH will divide the data in batches, i.e. the first 0-(m-1) rows will be passed to the first worker, the m-(2m-1) rows to the second worker, etc. Defaults to RayShardingMode.INTERLEAVED. If using distributed data loading, sharding happens on a per-file basis, and not on a per-row basis, i.e. for interleaved sharding every i-th file will be passed to the first worker, etc.
lazy (bool) – If num_actors is passed, setting this to True will defer data loading and storing until load_data() or get_data() is called. Defaults to False.
**kwargs – Keyword arguments will be passed to the data loading function. For instance, with RayFileType.PARQUET, these arguments will be passed to pandas.read_parquet().
from xgboost_ray import RayDMatrix, RayFileType

files = ["data_one.parquet", "data_two.parquet"]
columns = ["feature_1", "feature_2", "label_column"]

dtrain = RayDMatrix(
    files,
    num_actors=4,  # Will shard the data for four workers
    label="label_column",  # Will select this column as the label
    columns=columns,  # Will be passed to `pandas.read_parquet()`
    filetype=RayFileType.PARQUET)
PublicAPI (beta): This API is in beta and may change before becoming stable.
- load_data(num_actors: Optional[int] = None, rank: Optional[int] = None)[source]¶
Load data, putting it into the Ray object store.
If a rank is given, only data for this rank is loaded (for distributed data sources only).
- get_data(rank: int, num_actors: Optional[int] = None) Dict[str, Union[None, pandas.core.frame.DataFrame, List[Optional[pandas.core.frame.DataFrame]]]] [source]¶
Get data, i.e. return dataframe for a specific actor.
This method is called from an actor, given its rank and the total number of actors. If the data is not yet loaded, loading is triggered.
- xgboost_ray.train(params: Dict, dtrain: xgboost_ray.matrix.RayDMatrix, num_boost_round: int = 10, *args, evals: Union[List[Tuple[xgboost_ray.matrix.RayDMatrix, str]], Tuple[xgboost_ray.matrix.RayDMatrix, str]] = (), evals_result: Optional[Dict] = None, additional_results: Optional[Dict] = None, ray_params: Union[None, xgboost_ray.main.RayParams, Dict] = None, _remote: Optional[bool] = None, **kwargs) xgboost.core.Booster [source]¶
Distributed XGBoost training via Ray.
This function will connect to a Ray cluster, create num_actors remote actors, send data shards to them, and have them train an XGBoost classifier. The XGBoost parameters will be shared and combined via Rabit's all-reduce protocol.
If running inside a Ray Tune session, this function will automatically report results to Tune for hyperparameter search.
Failure handling:
XGBoost on Ray supports automatic failure handling that can be configured with the ray_params argument. If an actor or local training task dies, the Ray actor is marked as dead, and there are three options on how to proceed.
First, if ray_params.elastic_training is True and the number of dead actors is below ray_params.max_failed_actors, training will continue right away with fewer actors. No data will be loaded again and the latest available checkpoint will be used. A maximum of ray_params.max_actor_restarts restarts will be tried before exiting.
Second, if ray_params.elastic_training is False and the number of restarts is below ray_params.max_actor_restarts, Ray will try to schedule the dead actor again, load the data shard on this actor, and then continue training from the latest checkpoint.
Third, if none of the above is the case, training is aborted.
- Parameters
params (Dict) – parameter dict passed to
xgboost.train()
dtrain (RayDMatrix) – Data object containing the training data.
evals (Union[List[Tuple[RayDMatrix, str]], Tuple[RayDMatrix, str]]) – evals tuple passed to xgboost.train().
evals_result (Optional[Dict]) – Dict to store evaluation results in.
additional_results (Optional[Dict]) – Dict to store additional results.
ray_params (Union[None, xgboost_ray.RayParams, Dict]) – Parameters to configure Ray-specific behavior. See xgboost_ray.RayParams for a list of valid configuration parameters.
_remote (bool) – Whether to run the driver process in a remote function. This is enabled by default in Ray client mode.
**kwargs – Keyword arguments will be passed to the local xgb.train() calls.
Returns: An xgboost.Booster object.
PublicAPI (beta): This API is in beta and may change before becoming stable.
- xgboost_ray.predict(model: xgboost.core.Booster, data: xgboost_ray.matrix.RayDMatrix, ray_params: Union[None, xgboost_ray.main.RayParams, Dict] = None, _remote: Optional[bool] = None, **kwargs) Optional[numpy.ndarray] [source]¶
Distributed XGBoost predict via Ray.
This function will connect to a Ray cluster, create num_actors remote actors, send data shards to them, and have them predict labels using an XGBoost booster model. The results are then combined and returned.
- Parameters
model (xgb.Booster) – Booster object to call for prediction.
data (RayDMatrix) – Data object containing the prediction data.
ray_params (Union[None, xgboost_ray.RayParams, Dict]) – Parameters to configure Ray-specific behavior. See xgboost_ray.RayParams for a list of valid configuration parameters.
_remote (bool) – Whether to run the driver process in a remote function. This is enabled by default in Ray client mode.
**kwargs – Keyword arguments will be passed to the local xgb.predict() calls.
Returns: An np.ndarray containing the predicted labels.
PublicAPI (beta): This API is in beta and may change before becoming stable.
scikit-learn API¶
- class xgboost_ray.RayXGBClassifier(*, objective: Optional[Union[str, Callable[[numpy.ndarray, numpy.ndarray], Tuple[numpy.ndarray, numpy.ndarray]]]] = 'binary:logistic', use_label_encoder: bool = False, **kwargs: Any)[source]¶
Implementation of the scikit-learn API for Ray-distributed XGBoost classification.
- Parameters
n_estimators (int) – Number of boosting rounds.
max_depth (Optional[int]) – Maximum tree depth for base learners.
max_leaves – Maximum number of leaves; 0 indicates no limit.
max_bin – If using histogram-based algorithm, maximum number of bins per feature
grow_policy – Tree growing policy. 0: favor splitting at nodes closest to the node, i.e. grow depth-wise. 1: favor splitting at nodes with highest loss change.
learning_rate (Optional[float]) – Boosting learning rate (xgb’s “eta”)
verbosity (Optional[int]) – The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
objective (typing.Union[str, typing.Callable[[numpy.ndarray, numpy.ndarray], typing.Tuple[numpy.ndarray, numpy.ndarray]], NoneType]) – Specify the learning task and the corresponding learning objective or a custom objective function to be used (see note below).
booster (Optional[str]) – Specify which booster to use: gbtree, gblinear or dart.
tree_method (Optional[str]) – Specify which tree method to use. Default to auto. If this parameter is set to default, XGBoost will choose the most conservative option available. It’s recommended to study this option from the parameters document tree method
n_jobs (Optional[int]) – Number of parallel threads used to run xgboost. When used with other Scikit-Learn algorithms like grid search, you may choose which algorithm to parallelize and balance the threads. Creating thread contention will significantly slow down both algorithms.
gamma (Optional[float]) – (min_split_loss) Minimum loss reduction required to make a further partition on a leaf node of the tree.
min_child_weight (Optional[float]) – Minimum sum of instance weight(hessian) needed in a child.
max_delta_step (Optional[float]) – Maximum delta step we allow each tree’s weight estimation to be.
subsample (Optional[float]) – Subsample ratio of the training instance.
sampling_method –
- Sampling method. Used only by gpu_hist tree method.
uniform: select random training instances uniformly.
gradient_based: select random training instances with higher probability when the gradient and hessian are larger. (cf. CatBoost)
colsample_bytree (Optional[float]) – Subsample ratio of columns when constructing each tree.
colsample_bylevel (Optional[float]) – Subsample ratio of columns for each level.
colsample_bynode (Optional[float]) – Subsample ratio of columns for each split.
reg_alpha (Optional[float]) – L1 regularization term on weights (xgb’s alpha).
reg_lambda (Optional[float]) – L2 regularization term on weights (xgb’s lambda).
scale_pos_weight (Optional[float]) – Balancing of positive and negative weights.
base_score (Optional[float]) – The initial prediction score of all instances, global bias.
random_state (Optional[Union[numpy.random.RandomState, int]]) –
Random number seed.
Note
Using gblinear booster with shotgun updater is nondeterministic as it uses Hogwild algorithm.
missing (float, default np.nan) – Value in the data which needs to be present as a missing value.
num_parallel_tree (Optional[int]) – Used for boosting random forest.
monotone_constraints (Optional[Union[Dict[str, int], str]]) – Constraint of variable monotonicity. See tutorial for more information.
interaction_constraints (Optional[Union[str, List[Tuple[str]]]]) – Constraints for interaction representing permitted interactions. The constraints must be specified in the form of a nested list, e.g.
[[0, 1], [2, 3, 4]]
, where each inner list is a group of indices of features that are allowed to interact with each other. See tutorial for more informationimportance_type (Optional[str]) –
The feature importance type for the feature_importances_ property:
For tree model, it’s either “gain”, “weight”, “cover”, “total_gain” or “total_cover”.
For linear model, only “weight” is defined and it’s the normalized coefficients without bias.
gpu_id (Optional[int]) – Device ordinal.
validate_parameters (Optional[bool]) – Give warnings for unknown parameter.
predictor (Optional[str]) – Force XGBoost to use specific predictor, available choices are [cpu_predictor, gpu_predictor].
enable_categorical (bool) –
New in version 1.5.0.
Note
This parameter is experimental
Experimental support for categorical data. When enabled, cudf/pandas.DataFrame should be used to specify categorical data type. Also, JSON/UBJSON serialization format is required.
max_cat_to_onehot (Optional[int]) –
New in version 1.6.0.
Note
This parameter is experimental
A threshold for deciding whether XGBoost should use one-hot encoding based split for categorical data. When number of categories is lesser than the threshold then one-hot encoding is chosen, otherwise the categories will be partitioned into children nodes. Only relevant for regression and binary classification. See Categorical Data for details.
eval_metric (Optional[Union[str, List[str], Callable]]) –
New in version 1.6.0.
Metric used for monitoring the training result and early stopping. It can be a string or list of strings as names of predefined metrics in XGBoost (see doc/parameter.rst), one of the metrics in sklearn.metrics, or any other user defined metric that looks like sklearn.metrics.
If a custom objective is also provided, then the custom metric should implement the corresponding reverse link function.
Unlike the scoring parameter commonly used in scikit-learn, when a callable object is provided, it’s assumed to be a cost function and by default XGBoost will minimize the result during early stopping.
For advanced usage of early stopping, like directly choosing to maximize instead of minimize, see xgboost.callback.EarlyStopping.
See Custom Objective and Evaluation Metric for more.
Note
This parameter replaces eval_metric in the fit() method. The old one receives un-transformed prediction regardless of whether custom objective is being used.

import xgboost as xgb
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_absolute_error

X, y = load_diabetes(return_X_y=True)
reg = xgb.XGBRegressor(
    tree_method="hist",
    eval_metric=mean_absolute_error,
)
reg.fit(X, y, eval_set=[(X, y)])
early_stopping_rounds (Optional[int]) –
New in version 1.6.0.
Activates early stopping. Validation metric needs to improve at least once in every early_stopping_rounds round(s) to continue training. Requires at least one item in eval_set in fit().
The method returns the model from the last iteration (not the best one). If there's more than one item in eval_set, the last entry will be used for early stopping. If there's more than one metric in eval_metric, the last metric will be used for early stopping.
If early stopping occurs, the model will have three additional fields: best_score, best_iteration, and best_ntree_limit.
Note
This parameter replaces early_stopping_rounds in the fit() method.
callbacks (Optional[List[TrainingCallback]]) –
List of callback functions that are applied at end of each iteration. It is possible to use predefined callbacks by using Callback API.
Note
States in callback are not preserved during training, which means callback objects can not be reused for multiple training sessions without reinitialization or deepcopy.
for params in parameters_grid:
    # be sure to (re)initialize the callbacks before each run
    callbacks = [xgb.callback.LearningRateScheduler(custom_rates)]
    xgboost.train(params, Xy, callbacks=callbacks)
kwargs (dict, optional) –
Keyword arguments for XGBoost Booster object. Full documentation of parameters can be found here. Attempting to set a parameter via the constructor args and **kwargs dict simultaneously will result in a TypeError.
Note
**kwargs unsupported by scikit-learn
**kwargs is unsupported by scikit-learn. We do not guarantee that parameters passed via this argument will interact properly with scikit-learn.
Note
Custom objective function
A custom objective function can be provided for the objective parameter. In this case, it should have the signature objective(y_true, y_pred) -> grad, hess:
- y_true: array_like of shape [n_samples]
The target values
- y_pred: array_like of shape [n_samples]
The predicted values
- grad: array_like of shape [n_samples]
The value of the gradient for each sample point.
- hess: array_like of shape [n_samples]
The value of the second derivative for each sample point
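As an illustration of this signature (not part of XGBoost-Ray itself), a minimal sketch of a hand-written binary logistic objective, assuming y_pred contains raw margin scores:
import numpy as np

def logistic_objective(y_true: np.ndarray, y_pred: np.ndarray):
    # Convert raw margin scores to probabilities
    prob = 1.0 / (1.0 + np.exp(-y_pred))
    grad = prob - y_true        # first derivative of the logistic loss
    hess = prob * (1.0 - prob)  # second derivative of the logistic loss
    return grad, hess

# e.g. clf = RayXGBClassifier(objective=logistic_objective)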
- fit(X, y, *, sample_weight=None, base_margin=None, eval_set=None, eval_metric=None, early_stopping_rounds=None, verbose=True, xgb_model=None, sample_weight_eval_set=None, base_margin_eval_set=None, feature_weights=None, callbacks=None, ray_params: Union[None, xgboost_ray.main.RayParams, Dict] = None, _remote: Optional[bool] = None, ray_dmatrix_params: Optional[Dict] = None)[source]¶
Fit gradient boosting classifier.
Note that calling fit() multiple times will cause the model object to be re-fit from scratch. To resume training from a previous checkpoint, explicitly pass the xgb_model argument.
- Parameters
X – Feature matrix. Can also be a RayDMatrix.
y – Labels
sample_weight – instance weights
base_margin – global bias for each instance.
eval_set – A list of (X, y) tuple pairs to use as validation sets, for which metrics will be computed. Validation metrics will help us track the performance of the model.
eval_metric (str, list of str, or callable, optional) –
Deprecated since version 1.6.0: Use eval_metric in __init__() or set_params() instead.
early_stopping_rounds (int) –
Deprecated since version 1.6.0: Use early_stopping_rounds in __init__() or set_params() instead.
verbose – If verbose and an evaluation set is used, writes the evaluation metric measured on the validation set to stderr.
xgb_model – file name of stored XGBoost model or ‘Booster’ instance XGBoost model to be loaded before training (allows training continuation).
sample_weight_eval_set – A list of the form [L_1, L_2, …, L_n], where each L_i is an array like object storing instance weights for the i-th validation set.
base_margin_eval_set – A list of the form [M_1, M_2, …, M_n], where each M_i is an array like object storing base margin for the i-th validation set.
feature_weights – Weight for each feature, defines the probability of each feature being selected when colsample is being used. All values must be greater than 0, otherwise a ValueError is thrown.
callbacks –
Deprecated since version 1.6.0: Use callbacks in __init__() or set_params() instead.
ray_params (None or xgboost_ray.RayParams or Dict) – Parameters to configure Ray-specific behavior. See xgboost_ray.RayParams for a list of valid configuration parameters. Will override the n_jobs attribute with its own num_actors parameter.
_remote (bool) – Whether to run the driver process in a remote function. This is enabled by default in Ray client mode.
ray_dmatrix_params (dict) – Dict of parameters (such as sharding mode) passed to the internal RayDMatrix initialization.
- predict(X, output_margin=False, ntree_limit=None, validate_features=True, base_margin=None, iteration_range: Optional[Tuple[int, int]] = None, ray_params: Union[None, xgboost_ray.main.RayParams, Dict] = None, _remote: Optional[bool] = None, ray_dmatrix_params: Optional[Dict] = None)[source]¶
Predict with X. If the model is trained with early stopping, then best_iteration is used automatically. For tree models, when data is on GPU, like cupy array or cuDF dataframe and predictor is not specified, the prediction is run on GPU automatically, otherwise it will run on CPU.
Note
This function is only thread safe for gbtree and dart.
- Parameters
X – Data to predict with. Can also be a
RayDMatrix
.output_margin – Whether to output the raw untransformed margin value.
ntree_limit – Deprecated, use iteration_range instead.
validate_features – When this is True, validate that the Booster’s and data’s feature_names are identical. Otherwise, it is assumed that the feature_names are the same.
base_margin – Margin added to prediction.
iteration_range – Specifies which layer of trees are used in prediction. For example, if a random forest is trained with 100 rounds, specifying iteration_range=(10, 20) means that only the forests built during rounds [10, 20) (half open set) are used in this prediction. New in version 1.4.0.
ray_params (None or xgboost_ray.RayParams or Dict) – Parameters to configure Ray-specific behavior. See xgboost_ray.RayParams for a list of valid configuration parameters. Will override the n_jobs attribute with its own num_actors parameter.
_remote (bool) – Whether to run the driver process in a remote function. This is enabled by default in Ray client mode.
ray_dmatrix_params (dict) – Dict of parameters (such as sharding mode) passed to the internal RayDMatrix initialization.
- Returns
prediction
- predict_proba(X, ntree_limit=None, validate_features=False, base_margin=None, iteration_range: Optional[Tuple[int, int]] = None, ray_params: Union[None, xgboost_ray.main.RayParams, Dict] = None, _remote: Optional[bool] = None, ray_dmatrix_params: Optional[Dict] = None) numpy.ndarray [source]¶
Predict the probability of each X example being of a given class.
Note
This function is only thread safe for gbtree and dart.
- Parameters
X (array_like) – Feature matrix. Can also be a RayDMatrix.
ntree_limit (int) – Deprecated, use iteration_range instead.
validate_features (bool) – When this is True, validate that the Booster’s and data’s feature_names are identical. Otherwise, it is assumed that the feature_names are the same.
base_margin (array_like) – Margin added to prediction.
iteration_range – Specifies which layer of trees are used in prediction. For example, if a random forest is trained with 100 rounds, specifying iteration_range=(10, 20) means that only the forests built during rounds [10, 20) (half open set) are used in this prediction.
ray_params (None or xgboost_ray.RayParams or Dict) – Parameters to configure Ray-specific behavior. See xgboost_ray.RayParams for a list of valid configuration parameters. Will override the n_jobs attribute with its own num_actors parameter.
_remote (bool) – Whether to run the driver process in a remote function. This is enabled by default in Ray client mode.
ray_dmatrix_params (dict) – Dict of parameters (such as sharding mode) passed to the internal RayDMatrix initialization.
- Returns
prediction – a numpy array of shape (n_samples, n_classes) with the probability of each data example being of a given class.
- load_model(fname)[source]¶
Load the model from a file or bytearray. The path to the file can be local or a URI.
The model is loaded from XGBoost format which is universal among the various XGBoost interfaces. Auxiliary attributes of the Python Booster object (such as feature_names) will not be loaded when using binary format. To save those attributes, use JSON/UBJ instead. See Model IO for more info.
model.load_model("model.json") # or model.load_model("model.ubj")
- Parameters
fname – Input file name or memory buffer(see also save_raw)
- class xgboost_ray.RayXGBRegressor(*, objective: Optional[Union[str, Callable[[numpy.ndarray, numpy.ndarray], Tuple[numpy.ndarray, numpy.ndarray]]]] = 'reg:squarederror', **kwargs: Any)[source]¶
Implementation of the scikit-learn API for Ray-distributed XGBoost regression.
- Parameters
n_estimators (int) – Number of gradient boosted trees. Equivalent to number of boosting rounds.
max_depth (Optional[int]) – Maximum tree depth for base learners.
max_leaves – Maximum number of leaves; 0 indicates no limit.
max_bin – If using histogram-based algorithm, maximum number of bins per feature
grow_policy – Tree growing policy. 0: favor splitting at nodes closest to the node, i.e. grow depth-wise. 1: favor splitting at nodes with highest loss change.
learning_rate (Optional[float]) – Boosting learning rate (xgb’s “eta”)
verbosity (Optional[int]) – The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
objective (typing.Union[str, typing.Callable[[numpy.ndarray, numpy.ndarray], typing.Tuple[numpy.ndarray, numpy.ndarray]], NoneType]) – Specify the learning task and the corresponding learning objective or a custom objective function to be used (see note below).
booster (Optional[str]) – Specify which booster to use: gbtree, gblinear or dart.
tree_method (Optional[str]) – Specify which tree method to use. Default to auto. If this parameter is set to default, XGBoost will choose the most conservative option available. It’s recommended to study this option from the parameters document tree method
n_jobs (Optional[int]) – Number of parallel threads used to run xgboost. When used with other Scikit-Learn algorithms like grid search, you may choose which algorithm to parallelize and balance the threads. Creating thread contention will significantly slow down both algorithms.
gamma (Optional[float]) – (min_split_loss) Minimum loss reduction required to make a further partition on a leaf node of the tree.
min_child_weight (Optional[float]) – Minimum sum of instance weight(hessian) needed in a child.
max_delta_step (Optional[float]) – Maximum delta step we allow each tree’s weight estimation to be.
subsample (Optional[float]) – Subsample ratio of the training instance.
sampling_method –
- Sampling method. Used only by gpu_hist tree method.
uniform: select random training instances uniformly.
gradient_based: select random training instances with higher probability when the gradient and hessian are larger. (cf. CatBoost)
colsample_bytree (Optional[float]) – Subsample ratio of columns when constructing each tree.
colsample_bylevel (Optional[float]) – Subsample ratio of columns for each level.
colsample_bynode (Optional[float]) – Subsample ratio of columns for each split.
reg_alpha (Optional[float]) – L1 regularization term on weights (xgb’s alpha).
reg_lambda (Optional[float]) – L2 regularization term on weights (xgb’s lambda).
scale_pos_weight (Optional[float]) – Balancing of positive and negative weights.
base_score (Optional[float]) – The initial prediction score of all instances, global bias.
random_state (Optional[Union[numpy.random.RandomState, int]]) –
Random number seed.
Note
Using gblinear booster with shotgun updater is nondeterministic as it uses Hogwild algorithm.
missing (float, default np.nan) – Value in the data which needs to be present as a missing value.
num_parallel_tree (Optional[int]) – Used for boosting random forest.
monotone_constraints (Optional[Union[Dict[str, int], str]]) – Constraint of variable monotonicity. See tutorial for more information.
interaction_constraints (Optional[Union[str, List[Tuple[str]]]]) – Constraints for interaction representing permitted interactions. The constraints must be specified in the form of a nested list, e.g.
[[0, 1], [2, 3, 4]]
, where each inner list is a group of indices of features that are allowed to interact with each other. See tutorial for more informationimportance_type (Optional[str]) –
The feature importance type for the feature_importances_ property:
For tree model, it’s either “gain”, “weight”, “cover”, “total_gain” or “total_cover”.
For linear model, only “weight” is defined and it’s the normalized coefficients without bias.
gpu_id (Optional[int]) – Device ordinal.
validate_parameters (Optional[bool]) – Give warnings for unknown parameter.
predictor (Optional[str]) – Force XGBoost to use specific predictor, available choices are [cpu_predictor, gpu_predictor].
enable_categorical (bool) –
New in version 1.5.0.
Note
This parameter is experimental
Experimental support for categorical data. When enabled, cudf/pandas.DataFrame should be used to specify categorical data type. Also, JSON/UBJSON serialization format is required.
max_cat_to_onehot (Optional[int]) –
New in version 1.6.0.
Note
This parameter is experimental
A threshold for deciding whether XGBoost should use one-hot encoding based split for categorical data. When number of categories is lesser than the threshold then one-hot encoding is chosen, otherwise the categories will be partitioned into children nodes. Only relevant for regression and binary classification. See Categorical Data for details.
eval_metric (Optional[Union[str, List[str], Callable]]) –
New in version 1.6.0.
Metric used for monitoring the training result and early stopping. It can be a string or list of strings as names of predefined metrics in XGBoost (see doc/parameter.rst), one of the metrics in sklearn.metrics, or any other user defined metric that looks like sklearn.metrics.
If a custom objective is also provided, then the custom metric should implement the corresponding reverse link function.
Unlike the scoring parameter commonly used in scikit-learn, when a callable object is provided, it’s assumed to be a cost function and by default XGBoost will minimize the result during early stopping.
For advanced usage of early stopping, like directly choosing to maximize instead of minimize, see xgboost.callback.EarlyStopping.
See Custom Objective and Evaluation Metric for more.
Note
This parameter replaces eval_metric in the fit() method. The old one receives un-transformed prediction regardless of whether custom objective is being used.

import xgboost as xgb
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_absolute_error

X, y = load_diabetes(return_X_y=True)
reg = xgb.XGBRegressor(
    tree_method="hist",
    eval_metric=mean_absolute_error,
)
reg.fit(X, y, eval_set=[(X, y)])
early_stopping_rounds (Optional[int]) –
New in version 1.6.0.
Activates early stopping. Validation metric needs to improve at least once in every early_stopping_rounds round(s) to continue training. Requires at least one item in eval_set in fit().
The method returns the model from the last iteration (not the best one). If there's more than one item in eval_set, the last entry will be used for early stopping. If there's more than one metric in eval_metric, the last metric will be used for early stopping.
If early stopping occurs, the model will have three additional fields: best_score, best_iteration, and best_ntree_limit.
Note
This parameter replaces early_stopping_rounds in the fit() method.
callbacks (Optional[List[TrainingCallback]]) –
List of callback functions that are applied at end of each iteration. It is possible to use predefined callbacks by using Callback API.
Note
States in callback are not preserved during training, which means callback objects can not be reused for multiple training sessions without reinitialization or deepcopy.
for params in parameters_grid:
    # be sure to (re)initialize the callbacks before each run
    callbacks = [xgb.callback.LearningRateScheduler(custom_rates)]
    xgboost.train(params, Xy, callbacks=callbacks)
kwargs (dict, optional) –
Keyword arguments for XGBoost Booster object. Full documentation of parameters can be found here. Attempting to set a parameter via the constructor args and **kwargs dict simultaneously will result in a TypeError.
Note
**kwargs unsupported by scikit-learn
**kwargs is unsupported by scikit-learn. We do not guarantee that parameters passed via this argument will interact properly with scikit-learn.
Note
Custom objective function
A custom objective function can be provided for the objective parameter. In this case, it should have the signature objective(y_true, y_pred) -> grad, hess:
- y_true: array_like of shape [n_samples]
The target values
- y_pred: array_like of shape [n_samples]
The predicted values
- grad: array_like of shape [n_samples]
The value of the gradient for each sample point.
- hess: array_like of shape [n_samples]
The value of the second derivative for each sample point
- fit(X, y, *, sample_weight=None, base_margin=None, eval_set=None, eval_metric=None, early_stopping_rounds=None, verbose=True, xgb_model: Optional[Union[xgboost.core.Booster, xgboost.sklearn.XGBModel, str]] = None, sample_weight_eval_set=None, base_margin_eval_set=None, feature_weights=None, callbacks=None, ray_params: Union[None, xgboost_ray.main.RayParams, Dict] = None, _remote: Optional[bool] = None, ray_dmatrix_params: Optional[Dict] = None)[source]¶
Fit gradient boosting model.
Note that calling fit() multiple times will cause the model object to be re-fit from scratch. To resume training from a previous checkpoint, explicitly pass the xgb_model argument.
- Parameters
X – Feature matrix. Can also be a RayDMatrix.
y – Labels
sample_weight – instance weights
base_margin – global bias for each instance.
eval_set – A list of (X, y) tuple pairs to use as validation sets, for which metrics will be computed. Validation metrics will help us track the performance of the model.
eval_metric (str, list of str, or callable, optional) –
Deprecated since version 1.6.0: Use eval_metric in __init__() or set_params() instead.
early_stopping_rounds (int) –
Deprecated since version 1.6.0: Use early_stopping_rounds in __init__() or set_params() instead.
verbose – If verbose and an evaluation set is used, writes the evaluation metric measured on the validation set to stderr.
xgb_model – file name of stored XGBoost model or ‘Booster’ instance XGBoost model to be loaded before training (allows training continuation).
sample_weight_eval_set – A list of the form [L_1, L_2, …, L_n], where each L_i is an array like object storing instance weights for the i-th validation set.
base_margin_eval_set – A list of the form [M_1, M_2, …, M_n], where each M_i is an array like object storing base margin for the i-th validation set.
feature_weights – Weight for each feature, defines the probability of each feature being selected when colsample is being used. All values must be greater than 0, otherwise a ValueError is thrown.
callbacks –
Deprecated since version 1.6.0: Use callbacks in __init__() or set_params() instead.
ray_params (None or xgboost_ray.RayParams or Dict) – Parameters to configure Ray-specific behavior. See xgboost_ray.RayParams for a list of valid configuration parameters. Will override the n_jobs attribute with its own num_actors parameter.
_remote (bool) – Whether to run the driver process in a remote function. This is enabled by default in Ray client mode.
ray_dmatrix_params (dict) – Dict of parameters (such as sharding mode) passed to the internal RayDMatrix initialization.
- predict(X, output_margin=False, ntree_limit=None, validate_features=True, base_margin=None, iteration_range=None, ray_params: Union[None, xgboost_ray.main.RayParams, Dict] = None, _remote: Optional[bool] = None, ray_dmatrix_params: Optional[Dict] = None)[source]¶
Predict with X. If the model is trained with early stopping, then best_iteration is used automatically. For tree models, when data is on GPU, like cupy array or cuDF dataframe and predictor is not specified, the prediction is run on GPU automatically, otherwise it will run on CPU.
Note
This function is only thread safe for gbtree and dart.
- Parameters
X – Data to predict with. Can also be a
RayDMatrix
.output_margin – Whether to output the raw untransformed margin value.
ntree_limit – Deprecated, use iteration_range instead.
validate_features – When this is True, validate that the Booster’s and data’s feature_names are identical. Otherwise, it is assumed that the feature_names are the same.
base_margin – Margin added to prediction.
iteration_range – Specifies which layer of trees are used in prediction. For example, if a random forest is trained with 100 rounds, specifying iteration_range=(10, 20) means that only the forests built during rounds [10, 20) (half open set) are used in this prediction. New in version 1.4.0.
ray_params (None or xgboost_ray.RayParams or Dict) – Parameters to configure Ray-specific behavior. See xgboost_ray.RayParams for a list of valid configuration parameters. Will override the n_jobs attribute with its own num_actors parameter.
_remote (bool) – Whether to run the driver process in a remote function. This is enabled by default in Ray client mode.
ray_dmatrix_params (dict) – Dict of parameters (such as sharding mode) passed to the internal RayDMatrix initialization.
- Returns
prediction
- load_model(fname)[source]¶
Load the model from a file or bytearray. The path to the file can be local or a URI.
The model is loaded from XGBoost format which is universal among the various XGBoost interfaces. Auxiliary attributes of the Python Booster object (such as feature_names) will not be loaded when using binary format. To save those attributes, use JSON/UBJ instead. See Model IO for more info.
model.load_model("model.json") # or model.load_model("model.ubj")
- Parameters
fname – Input file name or memory buffer(see also save_raw)
- class xgboost_ray.RayXGBRFClassifier(*, learning_rate=1, subsample=0.8, colsample_bynode=0.8, reg_lambda=1e-05, **kwargs)[source]¶
scikit-learn API for Ray-distributed XGBoost random forest classification.
- Parameters
n_estimators (int) – Number of trees in random forest to fit.
max_depth (Optional[int]) – Maximum tree depth for base learners.
max_leaves – Maximum number of leaves; 0 indicates no limit.
max_bin – If using histogram-based algorithm, maximum number of bins per feature
grow_policy – Tree growing policy. 0: favor splitting at nodes closest to the node, i.e. grow depth-wise. 1: favor splitting at nodes with highest loss change.
learning_rate (Optional[float]) – Boosting learning rate (xgb’s “eta”)
verbosity (Optional[int]) – The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
objective (typing.Union[str, typing.Callable[[numpy.ndarray, numpy.ndarray], typing.Tuple[numpy.ndarray, numpy.ndarray]], NoneType]) – Specify the learning task and the corresponding learning objective or a custom objective function to be used (see note below).
booster (Optional[str]) – Specify which booster to use: gbtree, gblinear or dart.
tree_method (Optional[str]) – Specify which tree method to use. Defaults to auto. If this parameter is set to default, XGBoost will choose the most conservative option available. It’s recommended to study this option in the parameters document (tree method).
n_jobs (Optional[int]) – Number of parallel threads used to run xgboost. When used with other Scikit-Learn algorithms like grid search, you may choose which algorithm to parallelize and balance the threads. Creating thread contention will significantly slow down both algorithms.
gamma (Optional[float]) – (min_split_loss) Minimum loss reduction required to make a further partition on a leaf node of the tree.
min_child_weight (Optional[float]) – Minimum sum of instance weight (hessian) needed in a child.
max_delta_step (Optional[float]) – Maximum delta step we allow each tree’s weight estimation to be.
subsample (Optional[float]) – Subsample ratio of the training instance.
sampling_method –
Sampling method. Used only by the gpu_hist tree method.
uniform: select random training instances uniformly.
gradient_based: select random training instances with higher probability when the gradient and hessian are larger. (cf. CatBoost)
colsample_bytree (Optional[float]) – Subsample ratio of columns when constructing each tree.
colsample_bylevel (Optional[float]) – Subsample ratio of columns for each level.
colsample_bynode (Optional[float]) – Subsample ratio of columns for each split.
reg_alpha (Optional[float]) – L1 regularization term on weights (xgb’s alpha).
reg_lambda (Optional[float]) – L2 regularization term on weights (xgb’s lambda).
scale_pos_weight (Optional[float]) – Balancing of positive and negative weights.
base_score (Optional[float]) – The initial prediction score of all instances, global bias.
random_state (Optional[Union[numpy.random.RandomState, int]]) –
Random number seed.
Note
Using gblinear booster with shotgun updater is nondeterministic as it uses Hogwild algorithm.
missing (float, default np.nan) – Value in the data which needs to be treated as a missing value.
num_parallel_tree (Optional[int]) – Used for boosting random forest.
monotone_constraints (Optional[Union[Dict[str, int], str]]) – Constraint of variable monotonicity. See tutorial for more information.
interaction_constraints (Optional[Union[str, List[Tuple[str]]]]) – Constraints for interaction representing permitted interactions. The constraints must be specified in the form of a nested list, e.g. [[0, 1], [2, 3, 4]], where each inner list is a group of indices of features that are allowed to interact with each other. See the tutorial for more information.
importance_type (Optional[str]) –
The feature importance type for the feature_importances_ property:
The feature importance type for the feature_importances_ property:
For tree model, it’s either “gain”, “weight”, “cover”, “total_gain” or “total_cover”.
For linear model, only “weight” is defined and it’s the normalized coefficients without bias.
gpu_id (Optional[int]) – Device ordinal.
validate_parameters (Optional[bool]) – Give warnings for unknown parameter.
predictor (Optional[str]) – Force XGBoost to use specific predictor, available choices are [cpu_predictor, gpu_predictor].
enable_categorical (bool) –
New in version 1.5.0.
Note
This parameter is experimental
Experimental support for categorical data. When enabled, cudf/pandas.DataFrame should be used to specify categorical data type. Also, JSON/UBJSON serialization format is required.
max_cat_to_onehot (Optional[int]) –
New in version 1.6.0.
Note
This parameter is experimental
A threshold for deciding whether XGBoost should use one-hot encoding based split for categorical data. When the number of categories is less than the threshold, one-hot encoding is chosen; otherwise the categories will be partitioned into children nodes. Only relevant for regression and binary classification. See Categorical Data for details.
eval_metric (Optional[Union[str, List[str], Callable]]) –
New in version 1.6.0.
Metric used for monitoring the training result and early stopping. It can be a string or list of strings as names of a predefined metric in XGBoost (see doc/parameter.rst), one of the metrics in sklearn.metrics, or any other user-defined metric that looks like sklearn.metrics.
If a custom objective is also provided, then the custom metric should implement the corresponding reverse link function.
Unlike the scoring parameter commonly used in scikit-learn, when a callable object is provided, it’s assumed to be a cost function and by default XGBoost will minimize the result during early stopping.
For advanced usage of early stopping, such as directly choosing to maximize instead of minimize, see xgboost.callback.EarlyStopping.
See Custom Objective and Evaluation Metric for more.
Note
This parameter replaces eval_metric in the fit() method. The old one receives un-transformed predictions regardless of whether a custom objective is being used.
import xgboost as xgb
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_absolute_error

X, y = load_diabetes(return_X_y=True)
reg = xgb.XGBRegressor(
    tree_method="hist",
    eval_metric=mean_absolute_error,
)
reg.fit(X, y, eval_set=[(X, y)])
early_stopping_rounds (Optional[int]) –
New in version 1.6.0.
Activates early stopping. The validation metric needs to improve at least once in every early_stopping_rounds round(s) to continue training. Requires at least one item in eval_set in fit().
The method returns the model from the last iteration (not the best one). If there’s more than one item in eval_set, the last entry will be used for early stopping. If there’s more than one metric in eval_metric, the last metric will be used for early stopping.
If early stopping occurs, the model will have three additional fields: best_score, best_iteration and best_ntree_limit.
Note
This parameter replaces early_stopping_rounds in the fit() method.
callbacks (Optional[List[TrainingCallback]]) –
List of callback functions that are applied at the end of each iteration. It is possible to use predefined callbacks by using the Callback API.
Note
States in callbacks are not preserved during training, which means callback objects cannot be reused for multiple training sessions without reinitialization or deepcopy.
for params in parameters_grid:
    # be sure to (re)initialize the callbacks before each run
    callbacks = [xgb.callback.LearningRateScheduler(custom_rates)]
    xgb.train(params, Xy, callbacks=callbacks)
kwargs (dict, optional) –
Keyword arguments for XGBoost Booster object. Full documentation of parameters can be found here. Attempting to set a parameter via the constructor args and **kwargs dict simultaneously will result in a TypeError.
Note
**kwargs unsupported by scikit-learn
**kwargs is unsupported by scikit-learn. We do not guarantee that parameters passed via this argument will interact properly with scikit-learn.
Note
Custom objective function
A custom objective function can be provided for the
objective
parameter. In this case, it should have the signature objective(y_true, y_pred) -> grad, hess:
- y_true: array_like of shape [n_samples]
The target values
- y_pred: array_like of shape [n_samples]
The predicted values
- grad: array_like of shape [n_samples]
The value of the gradient for each sample point.
- hess: array_like of shape [n_samples]
The value of the second derivative for each sample point
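To make the signature concrete, here is a hedged sketch of a custom objective; the squared-error form and the function name are illustrative only (for classification you would typically use a log-loss-style objective instead):
import numpy as np

def squared_error_objective(y_true, y_pred):
    # gradient of 0.5 * (y_pred - y_true)**2 with respect to y_pred
    grad = y_pred - y_true
    # second derivative is constant
    hess = np.ones_like(y_pred)
    return grad, hess

# Illustrative usage: passed via the `objective` constructor parameter, e.g.
# model = RayXGBRFClassifier(objective=squared_error_objective)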
- class xgboost_ray.RayXGBRFRegressor(*, learning_rate=1, subsample=0.8, colsample_bynode=0.8, reg_lambda=1e-05, **kwargs)[source]¶
scikit-learn API for Ray-distributed XGBoost random forest regression.
- Parameters
n_estimators (int) – Number of trees in random forest to fit.
max_depth (Optional[int]) – Maximum tree depth for base learners.
max_leaves – Maximum number of leaves; 0 indicates no limit.
max_bin – If using histogram-based algorithm, maximum number of bins per feature
grow_policy – Tree growing policy. 0: favor splitting at nodes closest to the node, i.e. grow depth-wise. 1: favor splitting at nodes with highest loss change.
learning_rate (Optional[float]) – Boosting learning rate (xgb’s “eta”)
verbosity (Optional[int]) – The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
objective (typing.Union[str, typing.Callable[[numpy.ndarray, numpy.ndarray], typing.Tuple[numpy.ndarray, numpy.ndarray]], NoneType]) – Specify the learning task and the corresponding learning objective or a custom objective function to be used (see note below).
booster (Optional[str]) – Specify which booster to use: gbtree, gblinear or dart.
tree_method (Optional[str]) – Specify which tree method to use. Defaults to auto. If this parameter is set to default, XGBoost will choose the most conservative option available. It’s recommended to study this option in the parameters document (tree method).
n_jobs (Optional[int]) – Number of parallel threads used to run xgboost. When used with other Scikit-Learn algorithms like grid search, you may choose which algorithm to parallelize and balance the threads. Creating thread contention will significantly slow down both algorithms.
gamma (Optional[float]) – (min_split_loss) Minimum loss reduction required to make a further partition on a leaf node of the tree.
min_child_weight (Optional[float]) – Minimum sum of instance weight (hessian) needed in a child.
max_delta_step (Optional[float]) – Maximum delta step we allow each tree’s weight estimation to be.
subsample (Optional[float]) – Subsample ratio of the training instance.
sampling_method –
Sampling method. Used only by the gpu_hist tree method.
uniform: select random training instances uniformly.
gradient_based: select random training instances with higher probability when the gradient and hessian are larger. (cf. CatBoost)
colsample_bytree (Optional[float]) – Subsample ratio of columns when constructing each tree.
colsample_bylevel (Optional[float]) – Subsample ratio of columns for each level.
colsample_bynode (Optional[float]) – Subsample ratio of columns for each split.
reg_alpha (Optional[float]) – L1 regularization term on weights (xgb’s alpha).
reg_lambda (Optional[float]) – L2 regularization term on weights (xgb’s lambda).
scale_pos_weight (Optional[float]) – Balancing of positive and negative weights.
base_score (Optional[float]) – The initial prediction score of all instances, global bias.
random_state (Optional[Union[numpy.random.RandomState, int]]) –
Random number seed.
Note
Using gblinear booster with shotgun updater is nondeterministic as it uses Hogwild algorithm.
missing (float, default np.nan) – Value in the data which needs to be treated as a missing value.
num_parallel_tree (Optional[int]) – Used for boosting random forest.
monotone_constraints (Optional[Union[Dict[str, int], str]]) – Constraint of variable monotonicity. See tutorial for more information.
interaction_constraints (Optional[Union[str, List[Tuple[str]]]]) – Constraints for interaction representing permitted interactions. The constraints must be specified in the form of a nested list, e.g. [[0, 1], [2, 3, 4]], where each inner list is a group of indices of features that are allowed to interact with each other. See the tutorial for more information.
importance_type (Optional[str]) –
The feature importance type for the feature_importances_ property:
The feature importance type for the feature_importances_ property:
For tree model, it’s either “gain”, “weight”, “cover”, “total_gain” or “total_cover”.
For linear model, only “weight” is defined and it’s the normalized coefficients without bias.
gpu_id (Optional[int]) – Device ordinal.
validate_parameters (Optional[bool]) – Give warnings for unknown parameter.
predictor (Optional[str]) – Force XGBoost to use specific predictor, available choices are [cpu_predictor, gpu_predictor].
enable_categorical (bool) –
New in version 1.5.0.
Note
This parameter is experimental
Experimental support for categorical data. When enabled, cudf/pandas.DataFrame should be used to specify categorical data type. Also, JSON/UBJSON serialization format is required.
max_cat_to_onehot (Optional[int]) –
New in version 1.6.0.
Note
This parameter is experimental
A threshold for deciding whether XGBoost should use one-hot encoding based split for categorical data. When the number of categories is less than the threshold, one-hot encoding is chosen; otherwise the categories will be partitioned into children nodes. Only relevant for regression and binary classification. See Categorical Data for details.
eval_metric (Optional[Union[str, List[str], Callable]]) –
New in version 1.6.0.
Metric used for monitoring the training result and early stopping. It can be a string or list of strings as names of a predefined metric in XGBoost (see doc/parameter.rst), one of the metrics in sklearn.metrics, or any other user-defined metric that looks like sklearn.metrics.
If a custom objective is also provided, then the custom metric should implement the corresponding reverse link function.
Unlike the scoring parameter commonly used in scikit-learn, when a callable object is provided, it’s assumed to be a cost function and by default XGBoost will minimize the result during early stopping.
For advanced usage of early stopping, such as directly choosing to maximize instead of minimize, see xgboost.callback.EarlyStopping.
See Custom Objective and Evaluation Metric for more.
Note
This parameter replaces eval_metric in the fit() method. The old one receives un-transformed predictions regardless of whether a custom objective is being used.
import xgboost as xgb
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_absolute_error

X, y = load_diabetes(return_X_y=True)
reg = xgb.XGBRegressor(
    tree_method="hist",
    eval_metric=mean_absolute_error,
)
reg.fit(X, y, eval_set=[(X, y)])
early_stopping_rounds (Optional[int]) –
New in version 1.6.0.
Activates early stopping. The validation metric needs to improve at least once in every early_stopping_rounds round(s) to continue training. Requires at least one item in eval_set in fit().
The method returns the model from the last iteration (not the best one). If there’s more than one item in eval_set, the last entry will be used for early stopping. If there’s more than one metric in eval_metric, the last metric will be used for early stopping.
If early stopping occurs, the model will have three additional fields: best_score, best_iteration and best_ntree_limit.
Note
This parameter replaces early_stopping_rounds in the fit() method.
callbacks (Optional[List[TrainingCallback]]) –
List of callback functions that are applied at the end of each iteration. It is possible to use predefined callbacks by using the Callback API.
Note
States in callbacks are not preserved during training, which means callback objects cannot be reused for multiple training sessions without reinitialization or deepcopy.
for params in parameters_grid:
    # be sure to (re)initialize the callbacks before each run
    callbacks = [xgb.callback.LearningRateScheduler(custom_rates)]
    xgb.train(params, Xy, callbacks=callbacks)
kwargs (dict, optional) –
Keyword arguments for XGBoost Booster object. Full documentation of parameters can be found here. Attempting to set a parameter via the constructor args and **kwargs dict simultaneously will result in a TypeError.
Note
**kwargs unsupported by scikit-learn
**kwargs is unsupported by scikit-learn. We do not guarantee that parameters passed via this argument will interact properly with scikit-learn.
Note
Custom objective function
A custom objective function can be provided for the
objective
parameter. In this case, it should have the signature objective(y_true, y_pred) -> grad, hess:
- y_true: array_like of shape [n_samples]
The target values
- y_pred: array_like of shape [n_samples]
The predicted values
- grad: array_like of shape [n_samples]
The value of the gradient for each sample point.
- hess: array_like of shape [n_samples]
The value of the second derivative for each sample point
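As a closing sketch, the regressor can be used as a drop-in replacement for its plain XGBoost counterpart; the dataset, estimator settings, and actor count below are illustrative:
from sklearn.datasets import load_diabetes
from xgboost_ray import RayXGBRFRegressor, RayParams

X, y = load_diabetes(return_X_y=True)

reg = RayXGBRFRegressor(n_estimators=100, random_state=42)
reg.fit(X, y, ray_params=RayParams(num_actors=2))

preds = reg.predict(X, ray_params=RayParams(num_actors=2))
print(preds[:5])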