"""A utility for debugging serialization issues."""
import inspect
from contextlib import contextmanager
from typing import Any, Optional, Set, Tuple
# Import ray first to use the bundled colorama
import ray # noqa: F401
import colorama
import ray.cloudpickle as cp
from ray.util.annotations import DeveloperAPI
@contextmanager
def _indent(printer):
printer.level += 1
yield
printer.level -= 1
class _Printer:
def __init__(self, print_file):
self.level = 0
self.print_file = print_file
def indent(self):
return _indent(self)
def print(self, msg):
indent = " " * self.level
print(indent + msg, file=self.print_file)
@DeveloperAPI
class FailureTuple:
"""Represents the serialization 'frame'.
Attributes:
obj: The object that fails serialization.
name: The variable name of the object.
parent: The object that references the `obj`.
"""
def __init__(self, obj: Any, name: str, parent: Any):
self.obj = obj
self.name = name
self.parent = parent
def __repr__(self):
return f"FailTuple({self.name} [obj={self.obj}, parent={self.parent}])"
def _inspect_func_serialization(base_obj, depth, parent, failure_set, printer):
"""Adds the first-found non-serializable element to the failure_set."""
assert inspect.isfunction(base_obj)
closure = inspect.getclosurevars(base_obj)
found = False
if closure.globals:
printer.print(
f"Detected {len(closure.globals)} global variables. "
"Checking serializability..."
)
with printer.indent():
for name, obj in closure.globals.items():
serializable, _ = _inspect_serializability(
obj,
name=name,
depth=depth - 1,
parent=parent,
failure_set=failure_set,
printer=printer,
)
found = found or not serializable
if found:
break
if closure.nonlocals:
printer.print(
f"Detected {len(closure.nonlocals)} nonlocal variables. "
"Checking serializability..."
)
with printer.indent():
for name, obj in closure.nonlocals.items():
serializable, _ = _inspect_serializability(
obj,
name=name,
depth=depth - 1,
parent=parent,
failure_set=failure_set,
printer=printer,
)
found = found or not serializable
if found:
break
if not found:
printer.print(
f"WARNING: Did not find non-serializable object in {base_obj}. "
"This may be an oversight."
)
return found
def _inspect_generic_serialization(base_obj, depth, parent, failure_set, printer):
"""Adds the first-found non-serializable element to the failure_set."""
assert not inspect.isfunction(base_obj)
functions = inspect.getmembers(base_obj, predicate=inspect.isfunction)
found = False
with printer.indent():
for name, obj in functions:
serializable, _ = _inspect_serializability(
obj,
name=name,
depth=depth - 1,
parent=parent,
failure_set=failure_set,
printer=printer,
)
found = found or not serializable
if found:
break
with printer.indent():
members = inspect.getmembers(base_obj)
for name, obj in members:
if name.startswith("__") and name.endswith("__") or inspect.isbuiltin(obj):
continue
serializable, _ = _inspect_serializability(
obj,
name=name,
depth=depth - 1,
parent=parent,
failure_set=failure_set,
printer=printer,
)
found = found or not serializable
if found:
break
if not found:
printer.print(
f"WARNING: Did not find non-serializable object in {base_obj}. "
"This may be an oversight."
)
return found
[docs]
@DeveloperAPI
def inspect_serializability(
base_obj: Any,
name: Optional[str] = None,
depth: int = 3,
print_file: Optional[Any] = None,
) -> Tuple[bool, Set[FailureTuple]]:
"""Identifies what objects are preventing serialization.
Args:
base_obj: Object to be serialized.
name: Optional name of string.
depth: Depth of the scope stack to walk through. Defaults to 3.
print_file: file argument that will be passed to print().
Returns:
bool: True if serializable.
set[FailureTuple]: Set of unserializable objects.
.. versionadded:: 1.1.0
"""
printer = _Printer(print_file)
return _inspect_serializability(base_obj, name, depth, None, None, printer)
def _inspect_serializability(
base_obj, name, depth, parent, failure_set, printer
) -> Tuple[bool, Set[FailureTuple]]:
colorama.init()
top_level = False
declaration = ""
found = False
if failure_set is None:
top_level = True
failure_set = set()
declaration = f"Checking Serializability of {base_obj}"
printer.print("=" * min(len(declaration), 80))
printer.print(declaration)
printer.print("=" * min(len(declaration), 80))
if name is None:
name = str(base_obj)
else:
printer.print(f"Serializing '{name}' {base_obj}...")
try:
cp.dumps(base_obj)
return True, failure_set
except Exception as e:
printer.print(
f"{colorama.Fore.RED}!!! FAIL{colorama.Fore.RESET} " f"serialization: {e}"
)
found = True
try:
if depth == 0:
failure_set.add(FailureTuple(base_obj, name, parent))
# Some objects may not be hashable, so we skip adding this to the set.
except Exception:
pass
if depth <= 0:
return False, failure_set
# TODO: we only differentiate between 'function' and 'object'
# but we should do a better job of diving into something
# more specific like a Type, Object, etc.
if inspect.isfunction(base_obj):
_inspect_func_serialization(
base_obj,
depth=depth,
parent=base_obj,
failure_set=failure_set,
printer=printer,
)
else:
_inspect_generic_serialization(
base_obj,
depth=depth,
parent=base_obj,
failure_set=failure_set,
printer=printer,
)
if not failure_set:
failure_set.add(FailureTuple(base_obj, name, parent))
if top_level:
printer.print("=" * min(len(declaration), 80))
if not failure_set:
printer.print(
"Nothing failed the inspect_serialization test, though "
"serialization did not succeed."
)
else:
fail_vars = (
f"\n\n\t{colorama.Style.BRIGHT}"
+ "\n".join(str(k) for k in failure_set)
+ f"{colorama.Style.RESET_ALL}\n\n"
)
printer.print(
f"Variable: {fail_vars}was found to be non-serializable. "
"There may be multiple other undetected variables that were "
"non-serializable. "
)
printer.print(
"Consider either removing the "
"instantiation/imports of these variables or moving the "
"instantiation into the scope of the function/class. "
)
printer.print("=" * min(len(declaration), 80))
printer.print(
"Check https://docs.ray.io/en/master/ray-core/objects/serialization.html#troubleshooting for more information." # noqa
)
printer.print(
"If you have any suggestions on how to improve "
"this error message, please reach out to the "
"Ray developers on github.com/ray-project/ray/issues/"
)
printer.print("=" * min(len(declaration), 80))
return not found, failure_set