import json
import logging
import os
import random
import shutil
import subprocess
import sys
import tempfile
import threading
import time
from typing import Any, Dict, Optional
import yaml
import ray
from ray._private.dict import deep_update
from ray.autoscaler._private.fake_multi_node.node_provider import (
FAKE_DOCKER_DEFAULT_CLIENT_PORT,
FAKE_DOCKER_DEFAULT_GCS_PORT,
)
from ray.util.queue import Empty, Queue
logger = logging.getLogger(__name__)
DEFAULT_DOCKER_IMAGE = "rayproject/ray:nightly-py{major}{minor}-cpu"
class ResourcesNotReadyError(RuntimeError):
pass
class DockerCluster:
"""Docker cluster wrapper.
Creates a directory for starting a fake multinode docker cluster.
Includes APIs to update the cluster config as needed in tests,
and to start and connect to the cluster.
"""
def __init__(self, config: Optional[Dict[str, Any]] = None):
self._base_config_file = os.path.join(
os.path.dirname(__file__), "example_docker.yaml"
)
self._tempdir = None
self._config_file = None
self._nodes_file = None
self._nodes = {}
self._status_file = None
self._status = {}
self._partial_config = config
self._cluster_config = None
self._docker_image = None
self._monitor_script = os.path.join(
os.path.dirname(__file__), "docker_monitor.py"
)
self._monitor_process = None
self._execution_thread = None
self._execution_event = threading.Event()
self._execution_queue = None
@property
def config_file(self):
return self._config_file
@property
def cluster_config(self):
return self._cluster_config
@property
def cluster_dir(self):
return self._tempdir
@property
def gcs_port(self):
return self._cluster_config.get("provider", {}).get(
"host_gcs_port", FAKE_DOCKER_DEFAULT_GCS_PORT
)
@property
def client_port(self):
return self._cluster_config.get("provider", {}).get(
"host_client_port", FAKE_DOCKER_DEFAULT_CLIENT_PORT
)
def connect(self, client: bool = True, timeout: int = 120, **init_kwargs):
"""Connect to the docker-compose Ray cluster.
Assumes the cluster is at RAY_TESTHOST (defaults to
``127.0.0.1``).
Args:
client: If True, uses Ray client to connect to the
cluster. If False, uses GCS to connect to the cluster.
timeout: Connection timeout in seconds.
**init_kwargs: kwargs to pass to ``ray.init()``.
"""
host = os.environ.get("RAY_TESTHOST", "127.0.0.1")
if client:
port = self.client_port
address = f"ray://{host}:{port}"
else:
port = self.gcs_port
address = f"{host}:{port}"
timeout_at = time.monotonic() + timeout
while time.monotonic() < timeout_at:
try:
ray.init(address, **init_kwargs)
self.wait_for_resources({"CPU": 1})
except ResourcesNotReadyError:
time.sleep(1)
continue
else:
break
try:
ray.cluster_resources()
except Exception as e:
raise RuntimeError(f"Timed out connecting to Ray: {e}")
def remote_execution_api(self) -> "RemoteAPI":
"""Create an object to control cluster state from within the cluster."""
self._execution_queue = Queue(actor_options={"num_cpus": 0})
stop_event = self._execution_event
def entrypoint():
while not stop_event.is_set():
try:
cmd, kwargs = self._execution_queue.get(timeout=1)
except Empty:
continue
if cmd == "kill_node":
self.kill_node(**kwargs)
self._execution_thread = threading.Thread(target=entrypoint)
self._execution_thread.start()
return RemoteAPI(self._execution_queue)
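    # Illustrative sketch: the returned handle can be serialized into a Ray
    # task, which then asks the host-side execution thread to kill a node
    # (useful for fault-tolerance tests driven from inside the cluster).
    #
    #     api = cluster.remote_execution_api()
    #
    #     @ray.remote
    #     def chaos(api):
    #         api.kill_node(rand="worker")
    #
    #     ray.get(chaos.remote(api))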
@staticmethod
def wait_for_resources(resources: Dict[str, float], timeout: int = 60):
"""Wait until Ray cluster resources are available
Args:
resources: Minimum resources needed before
this function returns.
timeout: Timeout in seconds.
"""
        timeout_at = time.monotonic() + timeout
        available = ray.cluster_resources()
        while any(available.get(k, 0.0) < v for k, v in resources.items()):
            if time.monotonic() > timeout_at:
raise ResourcesNotReadyError(
f"Timed out waiting for resources: {resources}"
)
time.sleep(1)
available = ray.cluster_resources()
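    # Illustrative usage: block until at least 4 CPUs are registered, or raise
    # ResourcesNotReadyError after the default 60 second timeout.
    #
    #     DockerCluster.wait_for_resources({"CPU": 4})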
def update_config(self, config: Optional[Dict[str, Any]] = None):
"""Update autoscaling config.
Does a deep update of the base config with a new configuration.
This can change autoscaling behavior.
Args:
config: Partial config to update current
config with.
"""
assert self._tempdir, "Call setup() first"
config = config or {}
if config:
self._partial_config = config
if not config.get("provider", {}).get("image"):
            # No image specified; try to get one from the environment
            # (e.g. set by Buildkite CI).
docker_image = os.environ.get("RAY_DOCKER_IMAGE", None)
if not docker_image:
# If still no docker image, use one according to Python version
mj = sys.version_info.major
mi = sys.version_info.minor
docker_image = DEFAULT_DOCKER_IMAGE.format(major=mj, minor=mi)
self._docker_image = docker_image
with open(self._base_config_file, "rt") as f:
cluster_config = yaml.safe_load(f)
if self._partial_config:
deep_update(cluster_config, self._partial_config, new_keys_allowed=True)
if self._docker_image:
cluster_config["provider"]["image"] = self._docker_image
cluster_config["provider"]["shared_volume_dir"] = self._tempdir
self._cluster_config = cluster_config
with open(self._config_file, "wt") as f:
yaml.safe_dump(self._cluster_config, f)
logging.info(f"Updated cluster config to: {self._cluster_config}")
    def maybe_pull_image(self):
        if self._docker_image:
            try:
                images_str = subprocess.check_output(
                    f"docker image inspect {self._docker_image}", shell=True
                )
                images = json.loads(images_str)
            except Exception as e:
                # `docker image inspect` exits non-zero if the image is not
                # available locally, so treat failure as "image missing" and
                # fall through to pulling it.
                logger.info(f"Could not inspect image {self._docker_image}: {e}")
                images = []
            if not images:
                try:
                    subprocess.check_call(
                        f"docker pull {self._docker_image}", shell=True
                    )
                except Exception as e:
                    logger.error(f"Error pulling image {self._docker_image}: {e}")
def setup(self):
"""Setup docker compose cluster environment.
Creates the temporary directory, writes the initial config file,
and pulls the docker image, if required.
"""
self._tempdir = tempfile.mkdtemp(dir=os.environ.get("RAY_TEMPDIR", None))
os.chmod(self._tempdir, 0o777)
self._config_file = os.path.join(self._tempdir, "cluster.yaml")
self._nodes_file = os.path.join(self._tempdir, "nodes.json")
self._status_file = os.path.join(self._tempdir, "status.json")
self.update_config()
self.maybe_pull_image()
def teardown(self, keep_dir: bool = False):
"""Tear down docker compose cluster environment.
Args:
keep_dir: If True, cluster directory
will not be removed after termination.
"""
if not keep_dir:
shutil.rmtree(self._tempdir)
self._tempdir = None
self._config_file = None
def _start_monitor(self):
self._monitor_process = subprocess.Popen(
[sys.executable, self._monitor_script, self.config_file]
)
time.sleep(2)
    def _stop_monitor(self):
        if self._monitor_process:
            try:
                # Give the monitor up to 30 seconds to exit on its own after
                # `ray down`, then terminate it if it is still running.
                self._monitor_process.wait(timeout=30)
            except subprocess.TimeoutExpired:
                self._monitor_process.terminate()
        self._monitor_process = None
def start(self):
"""Start docker compose cluster.
Starts the monitor process and runs ``ray up``.
"""
self._start_monitor()
subprocess.check_call(
f"RAY_FAKE_CLUSTER=1 ray up -y {self.config_file}", shell=True
)
def stop(self):
"""Stop docker compose cluster.
Runs ``ray down`` and stops the monitor process.
"""
        if ray.is_initialized():
ray.shutdown()
subprocess.check_call(
f"RAY_FAKE_CLUSTER=1 ray down -y {self.config_file}", shell=True
)
self._stop_monitor()
self._execution_event.set()
def _update_nodes(self):
with open(self._nodes_file, "rt") as f:
self._nodes = json.load(f)
def _update_status(self):
with open(self._status_file, "rt") as f:
self._status = json.load(f)
def _get_node(
self,
node_id: Optional[str] = None,
num: Optional[int] = None,
rand: Optional[str] = None,
) -> str:
self._update_nodes()
if node_id:
assert (
not num and not rand
), "Only provide either `node_id`, `num`, or `random`."
elif num:
assert (
not node_id and not rand
), "Only provide either `node_id`, `num`, or `random`."
base = "fffffffffffffffffffffffffffffffffffffffffffffffffff"
node_id = base + str(num).zfill(5)
elif rand:
assert (
not node_id and not num
), "Only provide either `node_id`, `num`, or `random`."
assert rand in [
"worker",
"any",
], "`random` must be one of ['worker', 'any']"
choices = list(self._nodes.keys())
if rand == "worker":
choices.remove(
"fffffffffffffffffffffffffffffffffffffffffffffffffff00000"
)
# Else: any
node_id = random.choice(choices)
assert node_id in self._nodes, f"Node with ID {node_id} is not in active nodes."
return node_id
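    # Example of the id construction above: with num=3 the id becomes
    # base + "00003"; the head node is base + "00000", which is why
    # rand="worker" removes it from the candidate list.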
def _get_docker_container(self, node_id: str) -> Optional[str]:
self._update_status()
node_status = self._status.get(node_id)
if not node_status:
return None
return node_status["Name"]
def kill_node(
self,
node_id: Optional[str] = None,
num: Optional[int] = None,
rand: Optional[str] = None,
):
"""Kill node.
If ``node_id`` is given, kill that node.
If ``num`` is given, construct node_id from this number, and kill
that node.
If ``rand`` is given (as either ``worker`` or ``any``), kill a random
node.
"""
node_id = self._get_node(node_id=node_id, num=num, rand=rand)
container = self._get_docker_container(node_id=node_id)
subprocess.check_call(f"docker kill {container}", shell=True)
class RemoteAPI:
"""Remote API to control cluster state from within cluster tasks.
This API uses a Ray queue to interact with an execution thread on the
host machine that will execute commands passed to the queue.
Instances of this class can be serialized and passed to Ray remote actors
to interact with cluster state (but they can also be used outside actors).
    Only a limited subset of commands is supported (currently ``kill_node``).
Args:
queue: Ray queue to push command instructions to.
"""
def __init__(self, queue: Queue):
self._queue = queue
def kill_node(
self,
node_id: Optional[str] = None,
num: Optional[int] = None,
rand: Optional[str] = None,
):
self._queue.put(("kill_node", dict(node_id=node_id, num=num, rand=rand)))
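# Minimal illustrative demo, assuming a local Docker daemon and docker-compose
# are available; it only exercises the public DockerCluster API defined above.
if __name__ == "__main__":
    cluster = DockerCluster()
    cluster.setup()
    try:
        cluster.start()
        cluster.connect(client=True, timeout=120)
        print("Cluster resources:", ray.cluster_resources())
    finally:
        cluster.stop()
        cluster.teardown()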