Source code for ray.train.tensorflow.config
import json
import logging
import os
from dataclasses import dataclass
from typing import List

import ray
from ray.train._internal.utils import get_address_and_port
from ray.train._internal.worker_group import WorkerGroup
from ray.train.backend import Backend, BackendConfig
from ray.util import PublicAPI

logger = logging.getLogger(__name__)
@PublicAPI(stability="beta")
@dataclass
class TensorflowConfig(BackendConfig):
    @property
    def backend_cls(self):
        return _TensorflowBackend
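For context, backend_cls is what ties this user-facing config to the private backend defined below; a minimal sketch of what that wiring amounts to (illustrative only, not the actual Ray Train internals):

    config = TensorflowConfig()
    backend = config.backend_cls()  # an instance of _TensorflowBackend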
def _setup_tensorflow_environment(worker_addresses: List[str], index: int):
    """Set up distributed Tensorflow training information.

    This function should be called on each worker.

    Args:
        worker_addresses: Addresses of all the workers.
        index: Index (i.e. world rank) of the current worker.
    """
    tf_config = {
        "cluster": {"worker": worker_addresses},
        "task": {"type": "worker", "index": index},
    }
    os.environ["TF_CONFIG"] = json.dumps(tf_config)
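For illustration, a minimal sketch (not part of the module) of what this helper writes to the environment of the current process, assuming two made-up worker addresses:

    >>> _setup_tensorflow_environment(
    ...     worker_addresses=["10.0.0.1:12345", "10.0.0.2:23456"], index=0
    ... )
    >>> json.loads(os.environ["TF_CONFIG"])
    {'cluster': {'worker': ['10.0.0.1:12345', '10.0.0.2:23456']}, 'task': {'type': 'worker', 'index': 0}}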
class _TensorflowBackend(Backend):
    def on_start(self, worker_group: WorkerGroup, backend_config: TensorflowConfig):
        # Compute URL for initializing distributed setup.
        def get_url():
            address, port = get_address_and_port()
            return f"{address}:{port}"

        urls = worker_group.execute(get_url)

        # Get setup tasks in order to throw errors on failure.
        setup_futures = []
        for i in range(len(worker_group)):
            setup_futures.append(
                worker_group.execute_single_async(
                    i,
                    _setup_tensorflow_environment,
                    worker_addresses=urls,
                    index=i,
                )
            )
        ray.get(setup_futures)
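As a usage sketch (not part of this module): TensorflowConfig is normally passed to a TensorflowTrainer, which drives _TensorflowBackend.on_start so that TF_CONFIG is set on every worker before the user's training loop runs; tf.distribute.MultiWorkerMirroredStrategy then reads TF_CONFIG to discover its peers. The TensorflowTrainer and ScalingConfig import paths and keyword arguments below follow the public Ray Train API, but treat them as assumptions if your Ray version differs.

import numpy as np
import tensorflow as tf

from ray.train import ScalingConfig
from ray.train.tensorflow import TensorflowConfig, TensorflowTrainer


def train_loop_per_worker():
    # TF_CONFIG has already been set by _TensorflowBackend.on_start, so the
    # strategy can discover the other workers in the Ray Train worker group.
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
        model.compile(optimizer="sgd", loss="mse")
    x = np.array([[1.0], [2.0], [3.0]])
    y = np.array([[2.0], [4.0], [6.0]])
    model.fit(x, y, epochs=1)


trainer = TensorflowTrainer(
    train_loop_per_worker,
    tensorflow_config=TensorflowConfig(),
    scaling_config=ScalingConfig(num_workers=2),
)
trainer.fit()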