ray.util.tpu.init_jax_profiler

ray.util.tpu.init_jax_profiler(port: int | None = None) -> None

Set up a JAX profiler server for in-process JAX profiling.

This opens a background gRPC profiling port inside the current worker process and automatically registers that port in the GCS internal key-value store (internal_kv), so that the Ray Dashboard can discover the profiling endpoint.

Parameters:

port – The port on which the JAX profiler server listens. If None, the port is read from the JAX_PROFILER_PORT environment variable, falling back to 9999 if that variable is unset.
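The port-selection rule for the port parameter can be summarized with a small sketch. The helper below, resolve_profiler_port, and the constant DEFAULT_JAX_PROFILER_PORT are hypothetical names used only for illustration; they are not part of Ray's API and do not reproduce Ray's implementation (for example, how Ray handles a non-numeric JAX_PROFILER_PORT value is not stated here).

```python
import os
from typing import Optional

# Hypothetical constant mirroring the documented fallback; not a Ray API.
DEFAULT_JAX_PROFILER_PORT = 9999


def resolve_profiler_port(port: Optional[int] = None) -> int:
    """Hypothetical helper: return the port the profiler server would use.

    An explicit ``port`` wins; otherwise JAX_PROFILER_PORT is read from the
    environment, falling back to 9999 when the variable is unset.
    """
    if port is not None:
        return port
    # int() raises ValueError if the variable holds a non-numeric value.
    return int(os.environ.get("JAX_PROFILER_PORT", DEFAULT_JAX_PROFILER_PORT))
```

In practice, the call to init_jax_profiler itself belongs inside the worker process that runs the JAX code (for example, in a Ray actor's constructor) rather than in the driver, for the reason given in the note that follows.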

Note

JAX profiling is inherently an in-process operation. The JAX profiler server must run inside the memory space of the target worker process executing the JAX/XLA code in order to capture trace events, Python thread stacks, and XLA execution times.

PublicAPI (alpha): This API is in alpha and may change before becoming stable.