Source code for ray.util.rpdb

# Some code in this file is from
# https://github.com/ionelmc/python-remote-pdb/blob/07d563331c4ab9eb45731bb272b158816d98236e/src/remote_pdb.py
# (BSD 2-Clause "Simplified" License)

import errno
import inspect
import json
import logging
import os
import re
import select
import socket
import sys
import time
import traceback
import uuid
from pdb import Pdb
from typing import Callable

import setproctitle

import ray
from ray._private import ray_constants
from ray.experimental.internal_kv import _internal_kv_del, _internal_kv_put
from ray.util.annotations import DeveloperAPI

log = logging.getLogger(__name__)


def _cry(message, stderr=sys.__stderr__):
    print(message, file=stderr)
    stderr.flush()


class _LF2CRLF_FileWrapper(object):
    def __init__(self, connection):
        self.connection = connection
        self.stream = fh = connection.makefile("rw")
        self.read = fh.read
        self.readline = fh.readline
        self.readlines = fh.readlines
        self.close = fh.close
        self.flush = fh.flush
        self.fileno = fh.fileno
        if hasattr(fh, "encoding"):
            self._send = lambda data: connection.sendall(data.encode(fh.encoding))
        else:
            self._send = connection.sendall

    @property
    def encoding(self):
        return self.stream.encoding

    def __iter__(self):
        return self.stream.__iter__()

    def write(self, data, nl_rex=re.compile("\r?\n")):
        data = nl_rex.sub("\r\n", data)
        self._send(data)

    def writelines(self, lines, nl_rex=re.compile("\r?\n")):
        for line in lines:
            self.write(line, nl_rex)


class _PdbWrap(Pdb):
    """Wrap PDB to run a custom exit hook on continue."""

    def __init__(self, exit_hook: Callable[[], None]):
        self._exit_hook = exit_hook
        Pdb.__init__(self)

    def do_continue(self, arg):
        self._exit_hook()
        return Pdb.do_continue(self, arg)

    do_c = do_cont = do_continue


class _RemotePdb(Pdb):
    """
    This will run pdb as a ephemeral telnet service. Once you connect no one
    else can connect. On construction this object will block execution till a
    client has connected.
    Based on https://github.com/tamentis/rpdb I think ...
    To use this::
        RemotePdb(host="0.0.0.0", port=4444).set_trace()
    Then run: telnet 127.0.0.1 4444
    """

    active_instance = None

    def __init__(
        self,
        breakpoint_uuid,
        host,
        port,
        ip_address,
        patch_stdstreams=False,
        quiet=False,
    ):
        self._breakpoint_uuid = breakpoint_uuid
        self._quiet = quiet
        self._patch_stdstreams = patch_stdstreams
        self._listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
        self._listen_socket.bind((host, port))
        self._ip_address = ip_address

    def listen(self):
        if not self._quiet:
            _cry(
                "RemotePdb session open at %s:%s, "
                "use 'ray debug' to connect..."
                % (self._ip_address, self._listen_socket.getsockname()[1])
            )
        self._listen_socket.listen(1)
        connection, address = self._listen_socket.accept()
        if not self._quiet:
            _cry("RemotePdb accepted connection from %s." % repr(address))
        self.handle = _LF2CRLF_FileWrapper(connection)
        Pdb.__init__(
            self,
            completekey="tab",
            stdin=self.handle,
            stdout=self.handle,
            skip=["ray.*"],
        )
        self.backup = []
        if self._patch_stdstreams:
            for name in (
                "stderr",
                "stdout",
                "__stderr__",
                "__stdout__",
                "stdin",
                "__stdin__",
            ):
                self.backup.append((name, getattr(sys, name)))
                setattr(sys, name, self.handle)
        _RemotePdb.active_instance = self

    def __restore(self):
        if self.backup and not self._quiet:
            _cry("Restoring streams: %s ..." % self.backup)
        for name, fh in self.backup:
            setattr(sys, name, fh)
        self.handle.close()
        _RemotePdb.active_instance = None

    def do_quit(self, arg):
        self.__restore()
        return Pdb.do_quit(self, arg)

    do_q = do_exit = do_quit

    def do_continue(self, arg):
        self.__restore()
        self.handle.connection.close()
        return Pdb.do_continue(self, arg)

    do_c = do_cont = do_continue

    def set_trace(self, frame=None):
        if frame is None:
            frame = sys._getframe().f_back
        try:
            Pdb.set_trace(self, frame)
        except IOError as exc:
            if exc.errno != errno.ECONNRESET:
                raise

    def post_mortem(self, traceback=None):
        # See https://github.com/python/cpython/blob/
        # 022bc7572f061e1d1132a4db9d085b29707701e7/Lib/pdb.py#L1617
        try:
            t = sys.exc_info()[2]
            self.reset()
            Pdb.interaction(self, None, t)
        except IOError as exc:
            if exc.errno != errno.ECONNRESET:
                raise

    def do_remote(self, arg):
        """remote
        Skip into the next remote call.
        """
        # Tell the next task to drop into the debugger.
        ray._private.worker.global_worker.debugger_breakpoint = self._breakpoint_uuid
        # Tell the debug loop to connect to the next task.
        data = json.dumps(
            {
                "job_id": ray.get_runtime_context().get_job_id(),
            }
        )
        _internal_kv_put(
            "RAY_PDB_CONTINUE_{}".format(self._breakpoint_uuid),
            data,
            namespace=ray_constants.KV_NAMESPACE_PDB,
        )
        self.__restore()
        self.handle.connection.close()
        return Pdb.do_continue(self, arg)

    def do_get(self, arg):
        """get
        Skip to where the current task returns to.
        """
        ray._private.worker.global_worker.debugger_get_breakpoint = (
            self._breakpoint_uuid
        )
        self.__restore()
        self.handle.connection.close()
        return Pdb.do_continue(self, arg)


def _connect_ray_pdb(
    host=None,
    port=None,
    patch_stdstreams=False,
    quiet=None,
    breakpoint_uuid=None,
    debugger_external=False,
):
    """
    Opens a remote PDB on first available port.
    """
    if debugger_external:
        assert not host, "Cannot specify both host and debugger_external"
        host = "0.0.0.0"
    elif host is None:
        host = os.environ.get("REMOTE_PDB_HOST", "127.0.0.1")
    if port is None:
        port = int(os.environ.get("REMOTE_PDB_PORT", "0"))
    if quiet is None:
        quiet = bool(os.environ.get("REMOTE_PDB_QUIET", ""))
    if not breakpoint_uuid:
        breakpoint_uuid = uuid.uuid4().hex
    if debugger_external:
        ip_address = ray._private.worker.global_worker.node_ip_address
    else:
        ip_address = "localhost"
    rdb = _RemotePdb(
        breakpoint_uuid=breakpoint_uuid,
        host=host,
        port=port,
        ip_address=ip_address,
        patch_stdstreams=patch_stdstreams,
        quiet=quiet,
    )
    sockname = rdb._listen_socket.getsockname()
    pdb_address = "{}:{}".format(ip_address, sockname[1])
    parentframeinfo = inspect.getouterframes(inspect.currentframe())[2]
    data = {
        "proctitle": setproctitle.getproctitle(),
        "pdb_address": pdb_address,
        "filename": parentframeinfo.filename,
        "lineno": parentframeinfo.lineno,
        "traceback": "\n".join(traceback.format_exception(*sys.exc_info())),
        "timestamp": time.time(),
        "job_id": ray.get_runtime_context().get_job_id(),
        "node_id": ray.get_runtime_context().get_node_id(),
        "worker_id": ray.get_runtime_context().get_worker_id(),
        "actor_id": ray.get_runtime_context().get_actor_id(),
        "task_id": ray.get_runtime_context().get_task_id(),
    }
    _internal_kv_put(
        "RAY_PDB_{}".format(breakpoint_uuid),
        json.dumps(data),
        overwrite=True,
        namespace=ray_constants.KV_NAMESPACE_PDB,
    )
    rdb.listen()
    _internal_kv_del(
        "RAY_PDB_{}".format(breakpoint_uuid), namespace=ray_constants.KV_NAMESPACE_PDB
    )

    return rdb


[docs] @DeveloperAPI def set_trace(breakpoint_uuid=None): """Interrupt the flow of the program and drop into the Ray debugger. Can be used within a Ray task or actor. """ if os.environ.get("RAY_DEBUG", "1") == "1": return ray.util.ray_debugpy.set_trace(breakpoint_uuid) if os.environ.get("RAY_DEBUG", "1") == "legacy": # If there is an active debugger already, we do not want to # start another one, so "set_trace" is just a no-op in that case. if ray._private.worker.global_worker.debugger_breakpoint == b"": frame = sys._getframe().f_back rdb = _connect_ray_pdb( host=None, port=None, patch_stdstreams=False, quiet=None, breakpoint_uuid=breakpoint_uuid.decode() if breakpoint_uuid else None, debugger_external=ray._private.worker.global_worker.ray_debugger_external, # noqa: E501 ) rdb.set_trace(frame=frame)
def _driver_set_trace(): """The breakpoint hook to use for the driver. This disables Ray driver logs temporarily so that the PDB console is not spammed: https://github.com/ray-project/ray/issues/18172 """ if os.environ.get("RAY_DEBUG", "1") == "1": return ray.util.ray_debugpy.set_trace() if os.environ.get("RAY_DEBUG", "1") == "legacy": print("*** Temporarily disabling Ray worker logs ***") ray._private.worker._worker_logs_enabled = False def enable_logging(): print("*** Re-enabling Ray worker logs ***") ray._private.worker._worker_logs_enabled = True pdb = _PdbWrap(enable_logging) frame = sys._getframe().f_back pdb.set_trace(frame) def _is_ray_debugger_post_mortem_enabled(): return os.environ.get("RAY_DEBUG_POST_MORTEM", "0") == "1" def _post_mortem(): if os.environ.get("RAY_DEBUG", "1") == "1": return ray.util.ray_debugpy._post_mortem() rdb = _connect_ray_pdb( host=None, port=None, patch_stdstreams=False, quiet=None, debugger_external=ray._private.worker.global_worker.ray_debugger_external, ) rdb.post_mortem() def _connect_pdb_client(host, port): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.connect((host, port)) while True: # Get the list of sockets which are readable. read_sockets, write_sockets, error_sockets = select.select( [sys.stdin, s], [], [] ) for sock in read_sockets: if sock == s: # Incoming message from remote debugger. data = sock.recv(4096) if not data: return else: sys.stdout.write(data.decode()) sys.stdout.flush() else: # User entered a message. msg = sys.stdin.readline() s.send(msg.encode())