import asyncio
from functools import singledispatch
import importlib
from itertools import groupby
import json
import logging
import random
import string
import time
from typing import Iterable, List, Dict, Tuple
import os
from ray.serve.exceptions import RayServeException
from collections import UserDict
from pathlib import Path
import starlette.requests
import requests
import numpy as np
import pydantic
import ray
from ray.serve.constants import HTTP_PROXY_TIMEOUT
from ray.ray_constants import MEMORY_RESOURCE_UNIT_BYTES
# NOTE(review): not referenced in the visible part of this module; presumably
# an upper bound, in seconds (per the ``_S`` suffix), on how long callers keep
# retrying after an actor failure -- confirm against its users.
ACTOR_FAILURE_RETRY_TIMEOUT_S = 60
class ServeMultiDict(UserDict):
    """Dict that mimics the ``getlist`` accessor of Starlette query params.

    Values are stored exactly as given. ``getlist`` hands back the stored
    value itself (it does not wrap scalars into a list).
    """

    def getlist(self, key):
        """Return the value stored under ``key``; a fresh ``[]`` if absent."""
        if key in self.data:
            return self.data[key]
        return []
class ServeRequest:
    """The request object used when passing arguments via ServeHandle.

    ServeRequest partially implements the API of Starlette Request. You only
    need to write your model serving code once; it can be queried by both HTTP
    and Python.

    To use the full Starlette Request interface with ServeHandle, you may
    instead directly pass in a Starlette Request object to the ServeHandle.
    """

    def __init__(self, data, kwargs, headers, method):
        """Store the pieces of a handle call.

        Args:
            data: The request payload. ``body()`` returns it as-is, and
                ``json()`` returns it when it is a ``dict``.
            kwargs: Keyword arguments of the call, exposed through
                ``query_params`` wrapped in a ``ServeMultiDict``.
            headers: HTTP headers, exposed through ``headers``.
            method: HTTP method, exposed through ``method``.
        """
        self._data = data
        self._kwargs = ServeMultiDict(kwargs)
        self._headers = headers
        self._method = method

    @property
    def headers(self):
        """The HTTP headers from ``handle.option(http_headers=...)``."""
        return self._headers

    @property
    def method(self):
        """The HTTP method data from ``handle.option(http_method=...)``."""
        return self._method

    @property
    def query_params(self):
        """The keyword arguments from ``handle.remote(**kwargs)``."""
        return self._kwargs

    async def json(self):
        """The request dictionary, from ``handle.remote(dict)``.

        Raises:
            RayServeException: if the request data is not a ``dict``.
        """
        if not isinstance(self._data, dict):
            raise RayServeException("Request data is not a dictionary. "
                                    f"It is {type(self._data)}.")
        return self._data

    async def body(self):
        """The request data from ``handle.remote(obj)``."""
        return self._data
def parse_request_item(request_item):
    """Turn a handle request item into the object passed to user code.

    When the item carries exactly one positional argument, that argument is
    the payload; otherwise the payload is ``None``. A payload that is already
    a Starlette ``Request`` is returned unchanged. Anything else is wrapped in
    a ``ServeRequest`` together with the item's kwargs and the HTTP headers
    and method taken from ``request_item.metadata``.
    """
    positional = request_item.args
    payload = positional[0] if len(positional) == 1 else None

    # Web requests already expose the full Starlette API; pass them through.
    if isinstance(payload, starlette.requests.Request):
        return payload

    return ServeRequest(
        payload,
        request_item.kwargs,
        headers=request_item.metadata.http_headers,
        method=request_item.metadata.http_method,
    )
def _get_logger():
logger = logging.getLogger("ray.serve")
# TODO(simon): Make logging level configurable.
log_level = os.environ.get("SERVE_LOG_DEBUG")
if log_level and int(log_level):
logger.setLevel(logging.DEBUG)
else:
logger.setLevel(logging.INFO)
return logger
logger = _get_logger()
class ServeEncoder(json.JSONEncoder):
    """Ray.Serve's utility JSON encoder. Adds support for:

    - bytes (decoded as UTF-8)
    - numpy.ndarray (converted to nested lists of Python scalars)
    - numpy scalar types such as ``np.int64``, ``np.float32``, ``np.bool_``
    - Exceptions (encoded as ``str(exc)``)
    - Pydantic types (encoded via ``.dict()``)
    """

    def default(self, o):  # pylint: disable=E0202
        if isinstance(o, bytes):
            return o.decode("utf-8")
        if isinstance(o, np.ndarray):
            # tolist() already yields native Python int/float/bool elements.
            # The former astype(int) step silently wrapped uint64 values
            # above 2**63 - 1, so it is intentionally not used here.
            return o.tolist()
        if isinstance(o, np.generic):
            # numpy scalars (e.g. np.int64, np.float32) are not JSON
            # serializable by default; convert to the matching Python type.
            return o.item()
        if isinstance(o, Exception):
            return str(o)
        if isinstance(o, pydantic.BaseModel):
            return o.dict()
        return super().default(o)
@ray.remote(num_cpus=0)
def block_until_http_ready(http_endpoint,
                           backoff_time_s=1,
                           check_ready=None,
                           timeout=HTTP_PROXY_TIMEOUT):
    """Poll ``http_endpoint`` until it is ready; raise once ``timeout`` expires.

    Runs as a Ray remote task with ``num_cpus=0``. The endpoint counts as
    ready once a GET returns status 200 and, if ``check_ready`` is given,
    ``check_ready(resp)`` is truthy. Any exception raised while requesting
    or checking is treated as "not ready yet". Attempts are retried every
    ``backoff_time_s`` seconds.

    Args:
        http_endpoint: URL to GET.
        backoff_time_s: Seconds to sleep between attempts.
        check_ready: Optional predicate over the ``requests`` response.
        timeout: Overall deadline in seconds; a value <= 0 disables it.

    Raises:
        TimeoutError: if the endpoint is not ready within ``timeout`` seconds.
    """
    # Bound each individual request as well, so one hung connection cannot
    # keep the loop blocked past the overall deadline indefinitely.
    request_timeout = timeout if timeout > 0 else None
    start_time = time.time()
    while True:
        try:
            resp = requests.get(http_endpoint, timeout=request_timeout)
            # Explicit check rather than ``assert``: asserts are stripped
            # under ``python -O``, which would treat any status as ready.
            if resp.status_code == 200 and (check_ready is None
                                            or check_ready(resp)):
                # Return right away: no extra sleep, and no spurious timeout
                # when readiness is reached on the final attempt.
                return
        except Exception:
            # Best effort: the proxy may not be accepting connections yet.
            pass
        if 0 < timeout < time.time() - start_time:
            raise TimeoutError(
                "HTTP proxy not ready after {} seconds.".format(timeout))
        time.sleep(backoff_time_s)
def get_random_letters(length=6):
    """Return ``length`` ASCII letters sampled with replacement via ``random``."""
    alphabet = string.ascii_letters
    picks = random.choices(alphabet, k=length)
    return "".join(picks)
def format_actor_name(actor_name, controller_name=None, *modifiers):
    """Build ``[controller_name:]actor_name[-modifier...]``.

    The controller prefix is added only when ``controller_name`` is not
    ``None``. Each extra positional ``modifier`` is appended as ``-modifier``.
    """
    prefix = (actor_name if controller_name is None
              else f"{controller_name}:{actor_name}")
    if not modifiers:
        return prefix
    return prefix + "".join(f"-{modifier}" for modifier in modifiers)
def get_conda_env_dir(env_name):
    """Given an environment name like `tf1`, find and validate the
    corresponding conda directory. Untested on Windows.

    Args:
        env_name: Name of the conda environment to locate.

    Returns:
        The path ``<conda root>/envs/<env_name>``.

    Raises:
        RayServeException: if neither ``CONDA_PREFIX`` nor ``CONDA_EXE`` is
            set.
        ValueError: if the resolved directory does not exist.
    """
    conda_prefix = os.environ.get("CONDA_PREFIX")
    if conda_prefix is None:
        # The caller is neither in a conda env nor in (base). This is rare
        # because by default, new terminals start in (base), but we can still
        # support this case.
        conda_exe = os.environ.get("CONDA_EXE")
        if conda_exe is None:
            raise RayServeException(
                "Ray Serve cannot find environment variables set by conda. "
                "Please verify conda is installed.")
        # CONDA_EXE points at the conda executable, e.g.
        # $HOME/anaconda3/bin/conda; going up two directories yields the conda
        # root. That root is the base install, so envs live under `envs/`
        # regardless of CONDA_DEFAULT_ENV (previously this case fell into the
        # "strip the env name" branch and produced $HOME/<env_name>).
        conda_root = str(Path(conda_exe).parent.parent)
        env_dir = os.path.join(conda_root, "envs", env_name)
    elif os.environ.get("CONDA_DEFAULT_ENV") == "base":
        # Caller is running in base conda env: CONDA_PREFIX=$HOME/anaconda3.
        # Not recommended by conda, but we can still try to support it.
        env_dir = os.path.join(conda_prefix, "envs", env_name)
    else:
        # Caller is in a user-created env:
        # CONDA_PREFIX=$HOME/anaconda3/envs/$current_env. Strip the last
        # component (normpath first, so a trailing slash is tolerated).
        conda_envs_dir = os.path.dirname(os.path.normpath(conda_prefix))
        env_dir = os.path.join(conda_envs_dir, env_name)
    if not os.path.isdir(env_dir):
        raise ValueError(
            "conda env " + env_name +
            " not found in conda envs directory. Run `conda env list` to " +
            "verify the name is correct.")
    return env_dir
@singledispatch
def chain_future(src, dst):
    """Base method for chaining futures together.

    Chaining futures means the output from source future(s) are written as the
    results of the destination future(s). This method can work with the
    following inputs:
    - src: Future, dst: Future
    - src: List[Future], dst: List[Future] (same length, chained pairwise)

    Raises:
        NotImplementedError: for any other type of ``src``.
    """
    raise NotImplementedError()


@chain_future.register(asyncio.Future)
def _chain_future_single(src: asyncio.Future, dst: asyncio.Future):
    # NOTE: relies on asyncio's private helper, which copies the source's
    # final state (result, exception or cancellation) into ``dst``.
    asyncio.futures._chain_future(src, dst)


@chain_future.register(list)
def _chain_future_list(src: List[asyncio.Future], dst: List[asyncio.Future]):
    """Chain ``src[i]`` into ``dst[i]`` for every index.

    Raises:
        ValueError: if the two lists differ in length.
    """
    if len(src) != len(dst):
        # Was ``.foramt(...)``: a typo that raised AttributeError instead of
        # the intended ValueError.
        raise ValueError(
            "Source and destination list doesn't have the same length. "
            "Source: {}. Destination: {}.".format(len(src), len(dst)))
    for s, d in zip(src, dst):
        chain_future(s, d)
def unpack_future(src: asyncio.Future, num_items: int) -> List[asyncio.Future]:
    """Unpack the result of source future to num_items futures.

    This function takes in a Future and splits its result into many futures.
    Once ``src`` completes with a sized sequence of ``num_items`` items, the
    i-th returned future resolves to the i-th item.

    - If ``src`` fails, every destination future gets the same exception.
    - If ``src`` is cancelled, every destination future is cancelled.
    - If the result does not contain exactly ``num_items`` items, every
      destination future fails with ``ValueError``.

    Destination futures that are already done (e.g. cancelled by a caller)
    are left untouched.

    Args:
        src: The future whose result is the sequence to split.
        num_items: Number of destination futures to create.

    Returns:
        ``num_items`` new futures created on ``asyncio.get_event_loop()``.
    """
    loop = asyncio.get_event_loop()
    dest_futures = [loop.create_future() for _ in range(num_items)]

    def fail_all(exception):
        # Propagate one exception to every destination that is still pending.
        for future in dest_futures:
            if not future.done():
                future.set_exception(exception)

    def unwrap_callback(fut: asyncio.Future):
        # fut.exception() raises CancelledError on a cancelled future, which
        # previously left every destination pending forever.
        if fut.cancelled():
            for future in dest_futures:
                future.cancel()
            return
        exception = fut.exception()
        if exception is not None:
            fail_all(exception)
            return
        result = fut.result()
        # Explicit check instead of ``assert``: asserts vanish under -O, and
        # an error raised inside a done-callback never reaches the awaiters.
        if len(result) != num_items:
            fail_all(
                ValueError(
                    "Expected {} items from the source future, got {}.".format(
                        num_items, len(result))))
            return
        for item, future in zip(result, dest_futures):
            if not future.done():
                future.set_result(item)

    src.add_done_callback(unwrap_callback)
    return dest_futures
def try_schedule_resources_on_nodes(
        requirements: List[dict],
        ray_resource: Dict[str, Dict] = None,
) -> List[bool]:
    """Test given resource requirements can be scheduled on ray nodes.

    Requirements are placed greedily and in order. Each one goes to the first
    node, in ``ray_resource`` iteration order, that has enough of every
    requested resource. That node's available amounts are then reduced
    before the next requirement is considered.

    Args:
        requirements(List[dict]): The list of resource requirements, each a
            mapping of resource name to requested amount. Non-positive
            amounts are ignored.
        ray_resource(Optional[Dict[str, Dict]]): The available-resource
            dictionary keyed by node id. By default it reads from
            ``ray.state.state._available_resources_per_node()``. The
            dictionary passed in is not modified.

    Returns:
        successfully_scheduled(List[bool]): A list with the same length as
        requirements. Each element indicates whether or not the requirement
        can be satisfied.
    """
    if ray_resource is None:
        ray_resource = ray.state.state._available_resources_per_node()
    # Simulate placements on a copy so the caller's dictionary is untouched.
    remaining = {
        node_id: dict(node_resource)
        for node_id, node_resource in ray_resource.items()
    }

    def available_amount(node_resource, key):
        # Legacy behaviour: memory-like resources are stored per node in
        # chunks of MEMORY_RESOURCE_UNIT_BYTES; convert to bytes to compare.
        amount = node_resource.get(key, 0)
        if "memory" in key:
            amount *= MEMORY_RESOURCE_UNIT_BYTES
        return amount

    successfully_scheduled = []
    for resource_dict in requirements:
        # Filter out zero (and negative) values.
        demand = {k: v for k, v in resource_dict.items() if v > 0}
        for node_resource in remaining.values():
            feasible = all(
                available_amount(node_resource, key) >= count
                for key, count in demand.items())
            if feasible:
                for key, count in demand.items():
                    if "memory" in key:
                        # Deduct in the node's own (chunk) units. Subtracting
                        # raw bytes made every later memory check fail.
                        node_resource[key] -= (count /
                                               MEMORY_RESOURCE_UNIT_BYTES)
                    else:
                        node_resource[key] -= count
                successfully_scheduled.append(True)
                break
        else:
            successfully_scheduled.append(False)
    return successfully_scheduled
def get_all_node_ids():
    """Return ``(indexed_id, node_id)`` pairs for every node in the cluster.

    Several virtual nodes on the same host share an IP and therefore a
    node_id, so ``ray.state.node_ids()`` can repeat an id. Each occurrence
    is disambiguated with a per-id counter: the k-th occurrence of
    ``node_id`` (k starting at 0) becomes ``"node_id-k"``. The second element
    of each pair is the raw id, usable as a resource requirement for actor
    placement. Pairs come out in sorted node-id order.
    """
    return [("{}-{}".format(node_id, index), node_id)
            for _, same_ids in groupby(sorted(ray.state.node_ids()))
            for index, node_id in enumerate(same_ids)]
def get_node_id_for_actor(actor_handle):
    """Return the node id hosting the actor behind ``actor_handle``.

    Looks the actor up in ``ray.actors()`` by its hex actor id and reads
    the ``NodeID`` field of its ``Address`` entry.
    """
    actor_table = ray.actors()
    actor_info = actor_table[actor_handle._actor_id.hex()]
    return actor_info["Address"]["NodeID"]
def import_attr(full_path: str):
    """Given a full import path to a module attr, return the imported attr.

    For example, the following are equivalent:
        MyClass = import_attr("module.submodule.MyClass")
        from module.submodule import MyClass

    Returns:
        Imported attr

    Raises:
        ValueError: if ``full_path`` is not of the form ``module.attr``
            (e.g. it has no ``.``, or an empty module or attribute part).
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no such attribute.
    """
    # rpartition avoids the rfind() == -1 pitfall: previously "os" was
    # split into module "o" and attr "os".
    module_name, _, attr_name = full_path.rpartition(".")
    if not module_name or not attr_name:
        raise ValueError(
            "import_attr expects a dotted path like 'module.attr', "
            f"got {full_path!r}.")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)
async def mock_imported_function(batch):
    """Test helper: await ``body()`` of each request in order; return the list."""
    return [await request.body() for request in batch]
class MockImportedBackend:
    """Used for testing backends.ImportedBackend.

    The class must be importable inside worker processes, which is why it
    lives in this module. Mocking out importlib instead would be messier and
    would make the test less truly end-to-end.
    """

    def __init__(self, arg):
        # Constructor argument, echoed back by __call__.
        self.arg = arg
        # Last value passed to reconfigure(); None until then.
        self.config = None

    def reconfigure(self, config):
        """Remember ``config`` so later calls report it."""
        self.config = config

    def __call__(self, batch):
        """Return one fresh ``{"arg", "config"}`` dict per item in ``batch``."""
        batch_size = len(batch)
        responses = []
        for _ in range(batch_size):
            responses.append({"arg": self.arg, "config": self.config})
        return responses

    async def other_method(self, batch):
        """Await ``body()`` of each request in order; return the results."""
        return [await request.body() for request in batch]
def compute_iterable_delta(old: Iterable,
                           new: Iterable) -> Tuple[set, set, set]:
    """Compare two iterables as sets and return ``(added, removed, updated)``.

    - added: elements only in ``new``.
    - removed: elements only in ``old``.
    - updated: elements present in both.

    Usage:
        >>> compute_iterable_delta({"a", "b"}, {"a", "d"})
        ({'d'}, {'b'}, {'a'})
    """
    before = set(old)
    after = set(new)
    return after - before, before - after, before & after
def compute_dict_delta(old_dict, new_dict) -> Tuple[dict, dict, dict]:
    """Diff two dicts by key into ``(added, removed, updated)`` sub-dicts.

    - added: keys only in ``new_dict``, with their new values.
    - removed: keys only in ``old_dict``, with their old values.
    - updated: keys in both, with their values taken from ``new_dict``.
      Values are not compared, so unchanged entries appear here as well.

    Usage:
        >>> old = {"a": 1, "b": 2}
        >>> new = {"a": 3, "d": 4}
        >>> compute_dict_delta(old, new)
        ({'d': 4}, {'b': 2}, {'a': 3})
    """
    added, removed, shared = compute_iterable_delta(old_dict.keys(),
                                                    new_dict.keys())

    def select(source, keys):
        return {key: source[key] for key in keys}

    return (select(new_dict, added), select(old_dict, removed),
            select(new_dict, shared))
def get_current_node_resource_key() -> str:
    """Return the ``node:...`` Ray resource key of the node running this code.

    The key can be used for actor placement. It is found by matching the
    runtime context's node id against ``ray.nodes()`` and taking the first
    resource name that starts with ``"node:"``.

    Raises:
        ValueError: if no matching node with such a resource key is found.
    """
    current_node_id = ray.get_runtime_context().node_id.hex()
    matching_nodes = (node for node in ray.nodes()
                      if node["NodeID"] == current_node_id)
    for node in matching_nodes:
        node_keys = [
            key for key in node["Resources"] if key.startswith("node:")
        ]
        if node_keys:
            return node_keys[0]
    raise ValueError("Cannot found the node dictionary for current node.")