ray.util.tpu.init_jax_profiler
- ray.util.tpu.init_jax_profiler(port: int | None = None) -> None
Set up a JAX profiler server for in-process JAX profiling.
This starts a background gRPC profiler server inside the current worker process and automatically registers its port in the GCS internal_kv store so that the Ray Dashboard can discover the profiling endpoint.
- Parameters:
port – The port on which the JAX profiler server listens. If None, the port is read from the JAX_PROFILER_PORT environment variable (default: 9999 if the variable is not set).
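The port-selection rule for `port` can be sketched in plain Python. This is a minimal illustration, not Ray's implementation: the helper name `resolve_profiler_port` is hypothetical, and it assumes that `JAX_PROFILER_PORT` holds a decimal integer and that 9999 applies only when the variable is unset.

```python
import os
from typing import Optional

DEFAULT_JAX_PROFILER_PORT = 9999  # default documented for init_jax_profiler


def resolve_profiler_port(port: Optional[int] = None) -> int:
    """Hypothetical helper mirroring the documented port-selection rule.

    An explicit port wins; otherwise JAX_PROFILER_PORT is read from the
    environment, falling back to 9999 when the variable is unset.
    Assumption: the variable contains a decimal integer.
    """
    if port is not None:
        return port
    return int(os.environ.get("JAX_PROFILER_PORT", DEFAULT_JAX_PROFILER_PORT))


if __name__ == "__main__":
    os.environ.pop("JAX_PROFILER_PORT", None)
    print(resolve_profiler_port())      # 9999
    os.environ["JAX_PROFILER_PORT"] = "9012"
    print(resolve_profiler_port())      # 9012
    print(resolve_profiler_port(8000))  # 8000
```

An explicit argument always takes precedence, so callers that run several profiled processes on one node can pass distinct ports instead of relying on the shared environment variable.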
Note
JAX profiling is inherently an in-process operation. The JAX profiler server must run inside the address space of the worker process that executes the JAX/XLA code in order to capture trace events, Python thread stacks, and XLA execution times.
PublicAPI (alpha): This API is in alpha and may change before becoming stable.