
Getting Started Guide
This tutorial gives you a quick tour of Ray's features. We'll start by installing Ray. Most of the examples in this guide are based on Python, but we'll also show you how to use Ray Core in Java.
Python
To use Ray in Python, install it with
pip install ray
Java
To use Ray in Java, first add the ray-api and ray-runtime dependencies to your project (for example, in your Maven or Gradle build file).
Want to build Ray from source or with docker? Need more details? Check out our detailed installation guide.
Ray ML Quick Start
Ray has a rich ecosystem of libraries and frameworks built on top of it. Simply click on the dropdowns below to see examples of our most popular libraries.
Data: Creating and Transforming Datasets
Ray Datasets are the standard way to load and exchange data in Ray libraries and applications. Datasets provide basic distributed data transformations such as map, filter, and repartition, and they are compatible with a variety of file formats, datasources, and distributed frameworks.
Note
To get started with this example, install Ray Data as follows.
pip install "ray[data]" dask
Get started by creating Datasets from synthetic data using ray.data.range() and ray.data.from_items(). Datasets can hold either plain Python objects (schema is a Python type) or Arrow records (schema is Arrow).
import ray
# Create a Dataset of Python objects.
ds = ray.data.range(10000)
# -> Dataset(num_blocks=200, num_rows=10000, schema=<class 'int'>)
ds.take(5)
# -> [0, 1, 2, 3, 4]
ds.count()
# -> 10000
# Create a Dataset of Arrow records.
ds = ray.data.from_items([{"col1": i, "col2": str(i)} for i in range(10000)])
# -> Dataset(num_blocks=200, num_rows=10000, schema={col1: int64, col2: string})
ds.show(5)
# -> {'col1': 0, 'col2': '0'}
# -> {'col1': 1, 'col2': '1'}
# -> {'col1': 2, 'col2': '2'}
# -> {'col1': 3, 'col2': '3'}
# -> {'col1': 4, 'col2': '4'}
ds.schema()
# -> col1: int64
# -> col2: string
Datasets can be created from files on local disk or from remote datasources such as S3. Any filesystem supported by pyarrow can be used to specify file locations.
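For example, here is a minimal sketch of reading files into a Dataset; the file paths below are hypothetical placeholders, not files shipped with this guide.
import ray
# Read a CSV file from local disk (hypothetical path).
ds = ray.data.read_csv("example.csv")
# Read Parquet files from S3, or any other pyarrow-supported filesystem
# (hypothetical bucket and prefix).
ds = ray.data.read_parquet("s3://my-bucket/path/to/data")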
You can also create a Dataset from existing data in the Ray object store or from Ray-compatible distributed DataFrames:
import pandas as pd
import dask.dataframe as dd
# Create a Dataset from a list of Pandas DataFrame objects.
pdf = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
ds = ray.data.from_pandas([pdf])
# Create a Dataset from a Dask-on-Ray DataFrame.
dask_df = dd.from_pandas(pdf, npartitions=10)
ds = ray.data.from_dask(dask_df)
Datasets can be transformed in parallel using .map(). Transformations are executed eagerly and block until the operation is finished. Datasets also support .filter() and .flat_map().
ds = ray.data.range(10000)
ds = ds.map(lambda x: x * 2)
# -> Map Progress: 100%|████████████████████| 200/200 [00:00<00:00, 1123.54it/s]
# -> Dataset(num_blocks=200, num_rows=10000, schema=<class 'int'>)
ds.take(5)
# -> [0, 2, 4, 6, 8]
ds.filter(lambda x: x > 5).take(5)
# -> Map Progress: 100%|████████████████████| 200/200 [00:00<00:00, 1859.63it/s]
# -> [6, 8, 10, 12, 14]
ds.flat_map(lambda x: [x, -x]).take(5)
# -> Map Progress: 100%|████████████████████| 200/200 [00:00<00:00, 1568.10it/s]
# -> [0, 0, 2, -2, 4]
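Datasets can also be repartitioned to control the number of underlying blocks; here is a minimal sketch (the target of 100 blocks is arbitrary):
ds = ray.data.range(10000)
# Redistribute the 10000 rows across 100 blocks; the rows themselves are unchanged.
ds = ds.repartition(100)
# The resulting Dataset should report num_blocks=100.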
Train: Distributed Model Training
Ray Train abstracts away the complexity of setting up a distributed training system. Let's walk through the following simple examples:
This example shows how you can use Ray Train with PyTorch.
First, set up your dataset and model.
import torch
import torch.nn as nn
num_samples = 20
input_size = 10
layer_size = 15
output_size = 5
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.layer1 = nn.Linear(input_size, layer_size)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(layer_size, output_size)

    def forward(self, input):
        return self.layer2(self.relu(self.layer1(input)))
# In this example we use a randomly generated dataset.
input = torch.randn(num_samples, input_size)
labels = torch.randn(num_samples, output_size)
Now define your single-worker PyTorch training function.
import torch.optim as optim
def train_func():
    num_epochs = 3
    model = NeuralNetwork()
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1)

    for epoch in range(num_epochs):
        output = model(input)
        loss = loss_fn(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"epoch: {epoch}, loss: {loss.item()}")
This training function can be executed with:
train_func()
Now let's convert this to a distributed multi-worker training function! All you have to do is use the ray.train.torch.prepare_model and ray.train.torch.prepare_data_loader utility functions to easily set up your model and data for distributed training. This will automatically wrap your model with DistributedDataParallel and place it on the right device, and add a DistributedSampler to your DataLoaders.
from ray import train
import ray.train.torch
def train_func_distributed():
    num_epochs = 3
    model = NeuralNetwork()
    model = train.torch.prepare_model(model)
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1)

    for epoch in range(num_epochs):
        output = model(input)
        loss = loss_fn(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"epoch: {epoch}, loss: {loss.item()}")
Then, instantiate a Trainer that uses a "torch" backend with 4 workers, and use it to run the new training function!
from ray.train import Trainer
trainer = Trainer(backend="torch", num_workers=4)
# For GPU Training, set `use_gpu` to True.
# trainer = Trainer(backend="torch", num_workers=4, use_gpu=True)
trainer.start()
results = trainer.run(train_func_distributed)
trainer.shutdown()
This example shows how you can use Ray Train to set up Multi-worker training with Keras (https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras).
First, set up your dataset and model.
import numpy as np
import tensorflow as tf
def mnist_dataset(batch_size):
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    # The `x` arrays are in uint8 and have values in the [0, 255] range.
    # You need to convert them to float32 with values in the [0, 1] range.
    x_train = x_train / np.float32(255)
    y_train = y_train.astype(np.int64)
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
    return train_dataset

def build_and_compile_cnn_model():
    model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(28, 28)),
        tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(32, 3, activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10)
    ])
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
        metrics=['accuracy'])
    return model
Now define your single-worker TensorFlow training function.
def train_func():
    batch_size = 64
    single_worker_dataset = mnist_dataset(batch_size)
    single_worker_model = build_and_compile_cnn_model()
    single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)
This training function can be executed with:
train_func()
Now let's convert this to a distributed multi-worker training function! All you need to do is:
1. Set the global batch size - each worker will process the same size batch as in the single-worker code.
2. Choose your TensorFlow distributed training strategy. In this example we use the MultiWorkerMirroredStrategy.
import json
import os
def train_func_distributed():
    per_worker_batch_size = 64
    # This environment variable will be set by Ray Train.
    tf_config = json.loads(os.environ['TF_CONFIG'])
    num_workers = len(tf_config['cluster']['worker'])

    strategy = tf.distribute.MultiWorkerMirroredStrategy()

    global_batch_size = per_worker_batch_size * num_workers
    multi_worker_dataset = mnist_dataset(global_batch_size)

    with strategy.scope():
        # Model building/compiling need to be within `strategy.scope()`.
        multi_worker_model = build_and_compile_cnn_model()

    multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)
Then, instantiate a Trainer that uses a "tensorflow" backend with 4 workers, and use it to run the new training function!
from ray.train import Trainer
trainer = Trainer(backend="tensorflow", num_workers=4)
# For GPU Training, set `use_gpu` to True.
# trainer = Trainer(backend="tensorflow", num_workers=4, use_gpu=True)
trainer.start()
results = trainer.run(train_func_distributed)
trainer.shutdown()
Tune: Hyperparameter Tuning at Scale
Tune is a library for hyperparameter tuning at any scale. With Tune, you can launch a multi-node distributed hyperparameter sweep in less than 10 lines of code. Tune supports any deep learning framework, including PyTorch, TensorFlow, and Keras.
Note
To run this example, you will need to install the following:
pip install "ray[tune]"
This example runs a small grid search with an iterative training function.
from ray import tune
# 1. Define an objective function.
def objective(config):
    score = config["a"] ** 2 + config["b"]
    return {"score": score}

# 2. Define a search space.
search_space = {
    "a": tune.grid_search([0.001, 0.01, 0.1, 1.0]),
    "b": tune.choice([1, 2, 3]),
}
# 3. Start a Tune run and print the best result.
analysis = tune.run(objective, config=search_space)
print(analysis.get_best_config(metric="score", mode="min"))
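The objective above reports just one final score per trial. For a truly iterative training function, you can report a result after every step with tune.report; here is a minimal sketch (the function name iterative_objective and the toy score are illustrative):
def iterative_objective(config):
    for step in range(10):
        # Report one (toy) intermediate result per training step.
        score = config["a"] ** 2 + config["b"] + 0.1 * step
        tune.report(score=score)

analysis = tune.run(iterative_objective, config=search_space)
print(analysis.get_best_config(metric="score", mode="min"))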
If TensorBoard is installed, you can automatically visualize all trial results:
tensorboard --logdir ~/ray_results
Serve: Scalable Model Serving
Ray Serve is a scalable model-serving library built on Ray.
Note
To run this example, you will need to install the following libraries.
pip install "ray[serve]" scikit-learn
This example serves a scikit-learn gradient boosting classifier.
import requests
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier
from ray import serve
serve.start()
# Train model.
iris_dataset = load_iris()
model = GradientBoostingClassifier()
model.fit(iris_dataset["data"], iris_dataset["target"])
@serve.deployment(route_prefix="/iris")
class BoostingModel:
    def __init__(self, model):
        self.model = model
        self.label_list = iris_dataset["target_names"].tolist()

    async def __call__(self, request):
        payload = (await request.json())["vector"]
        print(f"Received http request with data {payload}")

        prediction = self.model.predict([payload])[0]
        human_name = self.label_list[prediction]
        return {"result": human_name}
# Deploy model.
BoostingModel.deploy(model)
# Query it!
sample_request_input = {"vector": [1.2, 1.0, 1.1, 0.9]}
response = requests.get(
    "http://localhost:8000/iris", json=sample_request_input)
print(response.text)
As a result you will see {"result": "versicolor"}.
RLlib: Industry-Grade Reinforcement Learning
RLlib is an industry-grade library for reinforcement learning (RL) built on top of Ray. RLlib offers high scalability and unified APIs for a variety of industry and research applications.
Note
To run this example, you will need to install rllib and either tensorflow or pytorch.
pip install "ray[rllib]" tensorflow # or torch
import gym
from ray.rllib.agents.ppo import PPOTrainer
# Define your problem using Python and OpenAI's Gym API:
class SimpleCorridor(gym.Env):
    """Corridor in which an agent must learn to move right to reach the exit.

    ---------------------
    | S | 1 | 2 | 3 | G |   S=start; G=goal; corridor_length=5
    ---------------------

    Possible actions to choose from are: 0=left; 1=right
    Observations are floats indicating the current field index, e.g. 0.0 for
    starting position, 1.0 for the field next to the starting position, etc..
    Rewards are -0.1 for all steps, except when reaching the goal (+1.0).
    """

    def __init__(self, config):
        self.end_pos = config["corridor_length"]
        self.cur_pos = 0
        self.action_space = gym.spaces.Discrete(2)  # left and right
        self.observation_space = gym.spaces.Box(0.0, self.end_pos, shape=(1,))

    def reset(self):
        """Resets the episode and returns the initial observation of the new one."""
        self.cur_pos = 0
        # Return initial observation.
        return [self.cur_pos]

    def step(self, action):
        """Takes a single step in the episode given `action`.

        Returns:
            New observation, reward, done-flag, info-dict (empty).
        """
        # Walk left.
        if action == 0 and self.cur_pos > 0:
            self.cur_pos -= 1
        # Walk right.
        elif action == 1:
            self.cur_pos += 1
        # Set `done` flag when end of corridor (goal) reached.
        done = self.cur_pos >= self.end_pos
        # +1 when goal reached, otherwise -0.1.
        reward = 1.0 if done else -0.1
        return [self.cur_pos], reward, done, {}
# Create an RLlib Trainer instance.
trainer = PPOTrainer(
    config={
        # Env class to use (here: our gym.Env sub-class from above).
        "env": SimpleCorridor,
        # Config dict to be passed to our custom env's constructor.
        "env_config": {
            # Use corridor with 20 fields (including S and G).
            "corridor_length": 20
        },
        # Parallelize environment rollouts.
        "num_workers": 3,
    }
)
# Train for n iterations and report results (mean episode rewards).
# Since we have to move at least 19 times in the env to reach the goal and
# each move gives us -0.1 reward (except the last move at the end: +1.0),
# we can expect to reach an optimal episode reward of -0.1*18 + 1.0 = -0.8
for i in range(5):
    results = trainer.train()
    print(f"Iter: {i}; avg. reward={results['episode_reward_mean']}")
# Perform inference (action computations) based on given env observations.
# Note that we are using a slightly different env here (len 10 instead of 20),
# however, this should still work as the agent has (hopefully) learned
# to "just always walk right!"
env = SimpleCorridor({"corridor_length": 10})
# Get the initial observation (should be: [0.0] for the starting position).
obs = env.reset()
done = False
total_reward = 0.0
# Play one episode.
while not done:
    # Compute a single action, given the current observation
    # from the environment.
    action = trainer.compute_single_action(obs)
    # Apply the computed action in the environment.
    obs, reward, done, info = env.step(action)
    # Sum up rewards for reporting purposes.
    total_reward += reward
# Report results.
print(f"Played 1 episode; total-reward={total_reward}")
Ray Core Quick Start
Ray Core provides simple primitives for building and running distributed applications. Below you'll find examples that show you how to easily turn your functions and classes into Ray tasks and actors, in both Python and Java.
Core: Parallelizing Functions with Ray Tasks
First, you import Ray and initialize it with ray.init(). Then you decorate your function with @ray.remote to declare that you want to run this function remotely. Lastly, you call that function with .remote() instead of calling it normally. This remote call yields a future, a so-called Ray object reference, that you can then fetch with ray.get.
import ray
ray.init()
@ray.remote
def f(x):
    return x * x
futures = [f.remote(i) for i in range(4)]
print(ray.get(futures)) # [0, 1, 4, 9]
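Object references can also be passed to other tasks without fetching them first; Ray resolves them to their values before the task runs. Here is a minimal sketch building on f above (the add task is illustrative):
@ray.remote
def add(a, b):
    return a + b

# Pass the ObjectRefs returned by `f` directly as arguments.
result_ref = add.remote(f.remote(2), f.remote(3))
print(ray.get(result_ref))  # 4 + 9 = 13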
First, use Ray.init to initialize the Ray runtime. Then you can use Ray.task(...).remote() to convert any Java static method into a Ray task. The task will run asynchronously in a remote worker process. The remote method will return an ObjectRef, and you can then fetch the actual result with get.
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import java.util.ArrayList;
import java.util.List;
public class RayDemo {

  public static int square(int x) {
    return x * x;
  }

  public static void main(String[] args) {
    // Initialize Ray runtime.
    Ray.init();
    List<ObjectRef<Integer>> objectRefList = new ArrayList<>();
    // Invoke the `square` method 4 times remotely as Ray tasks.
    // The tasks will run in parallel in the background.
    for (int i = 0; i < 4; i++) {
      objectRefList.add(Ray.task(RayDemo::square, i).remote());
    }
    // Get the actual results of the tasks.
    System.out.println(Ray.get(objectRefList)); // [0, 1, 4, 9]
  }
}
In the above code block we defined some Ray Tasks. While these are great for stateless operations, sometimes you must maintain the state of your application. You can do that with Ray Actors.
Core: Parallelizing Classes with Ray Actors
Ray provides actors to allow you to parallelize an instance of a class in Python or Java. When you instantiate a class that is a Ray actor, Ray will start a remote instance of that class in the cluster. This actor can then execute remote method calls and maintain its own internal state.
import ray
ray.init() # Only call this once.
@ray.remote
class Counter(object):
    def __init__(self):
        self.n = 0

    def increment(self):
        self.n += 1

    def read(self):
        return self.n
counters = [Counter.remote() for i in range(4)]
[c.increment.remote() for c in counters]
futures = [c.read.remote() for c in counters]
print(ray.get(futures)) # [1, 1, 1, 1]
import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
public class RayDemo {

  public static class Counter {

    private int value = 0;

    public void increment() {
      this.value += 1;
    }

    public int read() {
      return this.value;
    }
  }

  public static void main(String[] args) {
    // Initialize Ray runtime.
    Ray.init();
    List<ActorHandle<Counter>> counters = new ArrayList<>();
    // Create 4 actors from the `Counter` class.
    // They will run in remote worker processes.
    for (int i = 0; i < 4; i++) {
      counters.add(Ray.actor(Counter::new).remote());
    }

    // Invoke the `increment` method on each actor.
    // This will send an actor task to each remote actor.
    for (ActorHandle<Counter> counter : counters) {
      counter.task(Counter::increment).remote();
    }
    // Invoke the `read` method on each actor, and print the results.
    List<ObjectRef<Integer>> objectRefList = counters.stream()
        .map(counter -> counter.task(Counter::read).remote())
        .collect(Collectors.toList());
    System.out.println(Ray.get(objectRefList)); // [1, 1, 1, 1]
  }
}
Ray Cluster Quick Start
You can deploy your applications on Ray clusters, often with minimal changes to your existing code. See an example of this below.
Clusters: Launching a Ray Cluster on AWS
Ray programs can run on a single machine, or seamlessly scale to large clusters. Take this simple example that waits for individual nodes to join the cluster.
example.py
from collections import Counter
import sys
import time
import ray
@ray.remote
def get_host_name(x):
    import platform
    import time

    time.sleep(0.01)
    return x + (platform.node(),)


def wait_for_nodes(expected):
    # Wait for all nodes to join the cluster.
    while True:
        num_nodes = len(ray.nodes())
        if num_nodes < expected:
            print(
                "{} nodes have joined so far, waiting for {} more.".format(
                    num_nodes, expected - num_nodes
                )
            )
            sys.stdout.flush()
            time.sleep(1)
        else:
            break


def main():
    wait_for_nodes(4)

    # Check that objects can be transferred from each node to each other node.
    for i in range(10):
        print("Iteration {}".format(i))
        results = [get_host_name.remote(get_host_name.remote(())) for _ in range(100)]
        print(Counter(ray.get(results)))
        sys.stdout.flush()

    print("Success!")
    sys.stdout.flush()
    time.sleep(20)


if __name__ == "__main__":
    ray.init(address="localhost:6379")
    main()
You can also download this example from our GitHub repository.
Go ahead and store it locally in a file called example.py.
To execute this script in the cloud, just download this configuration file, or copy it here:
cluster.yaml
# A unique identifier for the head node and workers of this cluster.
cluster_name: default

# The maximum number of worker nodes to launch in addition to the head
# node.
max_workers: 2

# The autoscaler will scale up the cluster faster with higher upscaling speed.
# E.g., if the task requires adding more nodes then autoscaler will gradually
# scale up the cluster in chunks of upscaling_speed*currently_running_nodes.
# This number should be > 0.
upscaling_speed: 1.0

# This executes all commands on all nodes in the docker container,
# and opens all the necessary ports to support the Ray cluster.
# Empty string means disabled.
docker:
    image: "rayproject/ray-ml:latest-gpu" # You can change this to latest-cpu if you don't need GPU support and want a faster startup
    # image: rayproject/ray:latest-gpu # use this one if you don't need ML dependencies, it's faster to pull
    container_name: "ray_container"
    # If true, pulls latest version of image. Otherwise, `docker run` will only pull the image
    # if no cached version is present.
    pull_before_run: True
    run_options: # Extra options to pass into "docker run"
        - --ulimit nofile=65536:65536

    # Example of running a GPU head with CPU workers
    # head_image: "rayproject/ray-ml:latest-gpu"
    # Allow Ray to automatically detect GPUs
    # worker_image: "rayproject/ray-ml:latest-cpu"
    # worker_run_options: []

# If a node is idle for this many minutes, it will be removed.
idle_timeout_minutes: 5

# Cloud-provider specific configuration.
provider:
    type: aws
    region: us-west-2
    # Availability zone(s), comma-separated, that nodes may be launched in.
    # Nodes will be launched in the first listed availability zone and will
    # be tried in the subsequent availability zones if launching fails.
    availability_zone: us-west-2a,us-west-2b
    # Whether to allow node reuse. If set to False, nodes will be terminated
    # instead of stopped.
    cache_stopped_nodes: True # If not present, the default is True.

# How Ray will authenticate with newly launched nodes.
auth:
    ssh_user: ubuntu
    # By default Ray creates a new private keypair, but you can also use your own.
    # If you do so, make sure to also set "KeyName" in the head and worker node
    # configurations below.
    # ssh_private_key: /path/to/your/key.pem

# Tell the autoscaler the allowed node types and the resources they provide.
# The key is the name of the node type, which is just for debugging purposes.
# The node config specifies the launch config and physical instance type.
available_node_types:
    ray.head.default:
        # The node type's CPU and GPU resources are auto-detected based on AWS instance type.
        # If desired, you can override the autodetected CPU and GPU resources advertised to the autoscaler.
        # You can also set custom resources.
        # For example, to mark a node type as having 1 CPU, 1 GPU, and 5 units of a resource called "custom", set
        # resources: {"CPU": 1, "GPU": 1, "custom": 5}
        resources: {}
        # Provider-specific config for this node type, e.g. instance type. By default
        # Ray will auto-configure unspecified fields such as SubnetId and KeyName.
        # For more documentation on available fields, see:
        # http://boto3.readthedocs.io/en/latest/reference/services/ec2.html#EC2.ServiceResource.create_instances
        node_config:
            InstanceType: m5.large
            ImageId: ami-0a2363a9cff180a64 # Deep Learning AMI (Ubuntu) Version 30
            # You can provision additional disk space with a conf as follows
            BlockDeviceMappings:
                - DeviceName: /dev/sda1
                  Ebs:
                      VolumeSize: 100
            # Additional options in the boto docs.
    ray.worker.default:
        # The minimum number of worker nodes of this type to launch.
        # This number should be >= 0.
        min_workers: 0
        # The maximum number of worker nodes of this type to launch.
        # This takes precedence over min_workers.
        max_workers: 2
        # The node type's CPU and GPU resources are auto-detected based on AWS instance type.
        # If desired, you can override the autodetected CPU and GPU resources advertised to the autoscaler.
        # You can also set custom resources.
        # For example, to mark a node type as having 1 CPU, 1 GPU, and 5 units of a resource called "custom", set
        # resources: {"CPU": 1, "GPU": 1, "custom": 5}
        resources: {}
        # Provider-specific config for this node type, e.g. instance type. By default
        # Ray will auto-configure unspecified fields such as SubnetId and KeyName.
        # For more documentation on available fields, see:
        # http://boto3.readthedocs.io/en/latest/reference/services/ec2.html#EC2.ServiceResource.create_instances
        node_config:
            InstanceType: m5.large
            ImageId: ami-0a2363a9cff180a64 # Deep Learning AMI (Ubuntu) Version 30
            # Run workers on spot by default. Comment this out to use on-demand.
            # NOTE: If relying on spot instances, it is best to specify multiple different instance
            # types to avoid interruption when one instance type is experiencing heightened demand.
            # Demand information can be found at https://aws.amazon.com/ec2/spot/instance-advisor/
            InstanceMarketOptions:
                MarketType: spot
                # Additional options can be found in the boto docs, e.g.
                #   SpotOptions:
                #       MaxPrice: MAX_HOURLY_PRICE
            # Additional options in the boto docs.

# Specify the node type of the head node (as configured above).
head_node_type: ray.head.default

# Files or directories to copy to the head and worker nodes. The format is a
# dictionary from REMOTE_PATH: LOCAL_PATH, e.g.
file_mounts: {
    # "/path1/on/remote/machine": "/path1/on/local/machine",
    # "/path2/on/remote/machine": "/path2/on/local/machine",
}

# Files or directories to copy from the head node to the worker nodes. The format is a
# list of paths. The same path on the head node will be copied to the worker node.
# This behavior is a subset of the file_mounts behavior. In the vast majority of cases
# you should just use file_mounts. Only use this if you know what you're doing!
cluster_synced_files: []

# Whether changes to directories in file_mounts or cluster_synced_files in the head node
# should sync to the worker node continuously
file_mounts_sync_continuously: False

# Patterns for files to exclude when running rsync up or rsync down
rsync_exclude:
    - "**/.git"
    - "**/.git/**"

# Pattern files to use for filtering out files when running rsync up or rsync down. The file is searched for
# in the source directory and recursively through all subdirectories. For example, if .gitignore is provided
# as a value, the behavior will match git's behavior for finding and using .gitignore files.
rsync_filter:
    - ".gitignore"

# List of commands that will be run before `setup_commands`. If docker is
# enabled, these commands will run outside the container and before docker
# is setup.
initialization_commands: []

# List of shell commands to run to set up nodes.
setup_commands: []
    # Note: if you're developing Ray, you probably want to create a Docker image that
    # has your Ray repo pre-cloned. Then, you can replace the pip installs
    # below with a git checkout <your_sha> (and possibly a recompile).
    # To run the nightly version of ray (as opposed to the latest), either use a rayproject docker image
    # that has the "nightly" (e.g. "rayproject/ray-ml:nightly-gpu") or uncomment the following line:
    # - pip install -U "ray[default] @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-2.0.0.dev0-cp37-cp37m-manylinux2014_x86_64.whl"

# Custom commands that will be run on the head node after common setup.
head_setup_commands: []

# Custom commands that will be run on worker nodes after common setup.
worker_setup_commands: []

# Command to start ray on the head node. You don't need to change this.
head_start_ray_commands:
    - ray stop
    - ray start --head --port=6379 --object-manager-port=8076 --autoscaling-config=~/ray_bootstrap_config.yaml

# Command to start ray on worker nodes. You don't need to change this.
worker_start_ray_commands:
    - ray stop
    - ray start --address=$RAY_HEAD_IP:6379 --object-manager-port=8076

head_node: {}
worker_nodes: {}
Assuming you have stored this configuration in a file called cluster.yaml, you can now launch an AWS cluster as follows:
ray submit cluster.yaml example.py --start
Learn More
Here are some talks, papers, and press coverage involving Ray and its libraries. Please raise an issue if any of the below links are broken, or if you’d like to add your own talk!
Blog and Press
Talks (Videos)
Unifying Large Scale Data Preprocessing and Machine Learning Pipelines with Ray Datasets | PyData 2021 (slides)
Programming at any Scale with Ray | SF Python Meetup Sept 2019
Ray: A Cluster Computing Engine for Reinforcement Learning Applications | Spark Summit
Enabling Composition in Distributed Reinforcement Learning | Spark Summit 2018