Source code for ray.rllib.env.external.rllink
from enum import Enum
from packaging.version import Version
from ray.rllib.utils.checkpoints import try_import_msgpack
from ray.util.annotations import DeveloperAPI
msgpack = None
[docs]
@DeveloperAPI
class RLlink(Enum):
    """Message types of the RLlink protocol.

    The members are grouped into requests sent by the client (external env) to the
    server (RLlib), responses sent by the server back to the client, and a set of
    old-API-stack message types that are slated for deprecation.

    Except for `PROTOCOL_VERSION`, each member's value is a string identical to
    the member's name. `str(member)` returns the member's name.
    """

    # Version of the protocol.
    # NOTE(review): since this is assigned inside the Enum body, it presumably
    # becomes an enum member itself (the `Version` object would then be reachable
    # via `RLlink.PROTOCOL_VERSION.value`) - confirm this is intended.
    PROTOCOL_VERSION = Version("0.0.1")

    # Requests: Client (external env) -> Server (RLlib).
    # ----
    # Ping command (initial handshake).
    PING = "PING"
    # List of episodes (similar to what an EnvRunner.sample() call would return).
    EPISODES = "EPISODES"
    # Request state (e.g. model weights).
    GET_STATE = "GET_STATE"
    # Request Algorithm config.
    GET_CONFIG = "GET_CONFIG"
    # Send episodes and request the next state update right after that.
    # Clients sending this message should wait for a SET_STATE message as an
    # immediate response. Useful for external samplers that must collect on-policy
    # data.
    EPISODES_AND_GET_STATE = "EPISODES_AND_GET_STATE"

    # Responses: Server (RLlib) -> Client (external env).
    # ----
    # Pong response (initial handshake).
    PONG = "PONG"
    # Set state (e.g. model weights).
    SET_STATE = "SET_STATE"
    # Set Algorithm config.
    SET_CONFIG = "SET_CONFIG"

    # @OldAPIStack (to be deprecated soon).
    ACTION_SPACE = "ACTION_SPACE"
    OBSERVATION_SPACE = "OBSERVATION_SPACE"
    GET_WORKER_ARGS = "GET_WORKER_ARGS"
    GET_WEIGHTS = "GET_WEIGHTS"
    REPORT_SAMPLES = "REPORT_SAMPLES"
    START_EPISODE = "START_EPISODE"
    GET_ACTION = "GET_ACTION"
    LOG_ACTION = "LOG_ACTION"
    LOG_RETURNS = "LOG_RETURNS"
    END_EPISODE = "END_EPISODE"

    def __str__(self) -> str:
        """Returns the member's name instead of the default `RLlink.<NAME>` form."""
        return self.name
[docs]
@DeveloperAPI
def send_rllink_message(sock_, message: dict):
    """Sends a msgpack-encoded message to the peer, prefixed with a length header.

    The wire format is an 8-character, zero-padded, decimal header holding the
    length of the body in bytes, immediately followed by the msgpack-encoded body.

    Args:
        sock_: The socket-like object (must provide `sendall()`) to send on.
        message: The message dict to serialize via msgpack and send.

    Raises:
        ValueError: If the encoded body is so large that its length does not fit
            into the fixed-size 8-digit header. Nothing is sent in this case.
        ConnectionError: If sending the data over the socket fails.
    """
    global msgpack
    # Lazily import msgpack on first use.
    if msgpack is None:
        msgpack = try_import_msgpack(error=True)

    body = msgpack.packb(message, use_bin_type=True)
    header = str(len(body)).zfill(8).encode("utf-8")
    # The receiving side always reads exactly 8 header bytes. A longer header
    # would be cut off there and silently corrupt the stream, so refuse to send.
    if len(header) != 8:
        raise ValueError(
            f"Cannot send RLlink message: encoded body has {len(body)} bytes, but "
            "the 8-digit length header supports at most 99999999 bytes."
        )

    try:
        sock_.sendall(header + body)
    except Exception as e:
        # Chain explicitly so that the original error is kept as `__cause__`.
        raise ConnectionError(
            f"Error sending message {message} to server on socket {sock_}! "
            f"Original error was: {e}"
        ) from e
[docs]
@DeveloperAPI
def get_rllink_message(sock_):
    """Receives one message from the peer following the length-header protocol.

    Reads the 8-byte decimal length header first, then reads exactly that many
    bytes as the message body and decodes the body via msgpack.

    Args:
        sock_: The socket-like object to receive from.

    Returns:
        A tuple `(msg_type, message)`: `msg_type` is the `RLlink` member matching
        the message's `type` field and `message` is the decoded dict with the
        `type` key removed.

    Raises:
        ConnectionError: If receiving, parsing, or validating the message fails for
            any reason. The original error is attached as `__cause__`.
    """
    global msgpack
    # Lazily import msgpack on first use.
    if msgpack is None:
        msgpack = try_import_msgpack(error=True)

    try:
        # Read the length header (8 bytes).
        header = _get_num_bytes(sock_, 8)
        msg_length = int(header.decode("utf-8"))
        # Read the message body.
        body = _get_num_bytes(sock_, msg_length)
        # Decode the msgpack payload.
        message = msgpack.unpackb(body, raw=False)
        # Check for proper protocol: The payload must be a dict with a `type` key.
        # A non-dict payload cannot carry that field, so it is rejected with the
        # same protocol error (rather than through an incidental AttributeError).
        if not isinstance(message, dict) or "type" not in message:
            raise ConnectionError(
                "Protocol Error! Message from peer does not contain `type` field."
            )
        return RLlink(message.pop("type")), message
    except Exception as e:
        # Everything (including the protocol error above and unknown `type`
        # values) is reported to the caller as a ConnectionError.
        raise ConnectionError(
            f"Error receiving message from peer on socket {sock_}! "
            f"Original error was: {e}"
        ) from e
def _get_num_bytes(sock_, num_bytes):
    """Helper function to receive a specific number of bytes.

    Calls `sock_.recv()` repeatedly until `num_bytes` bytes have been collected.

    Args:
        sock_: The socket-like object (must provide `recv()`) to read from.
        num_bytes: The number of bytes to read. For values <= 0, nothing is read.

    Returns:
        The received data as a `bytes` object.

    Raises:
        ConnectionError: If `recv()` returns no data before `num_bytes` bytes
            have been received.
    """
    # Collect the chunks and join them once at the end. Growing a `bytes` object
    # via `+=` re-copies everything received so far on each iteration, which is
    # quadratic for large message bodies.
    chunks = []
    remaining = num_bytes
    while remaining > 0:
        packet = sock_.recv(remaining)
        if not packet:
            raise ConnectionError(f"No data received from socket {sock_}!")
        chunks.append(packet)
        remaining -= len(packet)
    return b"".join(chunks)