First, update your training code to support distributed training. Begin by wrapping your code in a training function:
```python
def train_func():
    # Your model training code here.
    ...
```
Each distributed training worker executes this function.
You can also pass arguments to train_func as a single dictionary through the Trainer's train_loop_config parameter. For example:
```python
import ray.train.torch

def train_func(config):
    lr = config["lr"]
    num_epochs = config["num_epochs"]
    # Use lr and num_epochs in your training loop.
    ...

config = {"lr": 1e-4, "num_epochs": 10}
trainer = ray.train.torch.TorchTrainer(train_func, train_loop_config=config, ...)
```
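Because train_func is an ordinary Python function, one way to check how it consumes the config dictionary is to call it directly with a dict before handing it to a Trainer. The sketch below is hypothetical and does not use Ray: a toy one-weight regression trained with plain gradient descent stands in for real model code. The hyperparameter values are chosen only so that this toy problem converges. They are not recommendations.

```python
def train_func(config):
    # Read hyperparameters from the dictionary passed via train_loop_config.
    lr = config["lr"]
    num_epochs = config["num_epochs"]

    # Toy model (hypothetical): fit a single weight w so that w * x
    # approximates y = 2 * x, using gradient descent on mean squared error.
    xs = [1.0, 2.0, 3.0, 4.0]
    ys = [2.0 * x for x in xs]
    w = 0.0
    for _ in range(num_epochs):
        # Gradient of mean((w * x - y) ** 2) with respect to w.
        grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)
        w -= lr * grad
    return w

# Call the function directly, outside any Trainer, as a quick local check.
w = train_func({"lr": 0.05, "num_epochs": 100})
print(round(w, 4))  # 2.0
```

Keeping all hyperparameters in one dictionary means the same function body runs unchanged whether the dictionary comes from a local call like this or from train_loop_config.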
Warning
Avoid passing large data objects through train_loop_config, because doing so adds serialization and deserialization overhead. Instead, initialize large objects (for example, datasets and models) directly in train_func:
```diff
 import ray.train.torch

 def load_dataset():
     # Return a large in-memory dataset
     ...

 def load_model():
     # Return a large in-memory model instance
     ...

-config = {"data": load_dataset(), "model": load_model()}

-def train_func(config):
-    data = config["data"]
-    model = config["model"]
+def train_func():
+    data = load_dataset()
+    model = load_model()
     ...

-trainer = ray.train.torch.TorchTrainer(train_func, train_loop_config=config, ...)
+trainer = ray.train.torch.TorchTrainer(train_func, ...)
```
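To get a rough sense of why the warning matters, the sketch below compares the serialized size of a config that carries the data itself with one that carries only hyperparameters. Ray uses its own serializer, so this is only an approximation: the standard-library pickle module serves as a rough proxy, and the million-float list is a hypothetical stand-in for a real dataset.

```python
import pickle

# Hypothetical stand-in for a large in-memory dataset (one million floats).
large_dataset = [float(i) for i in range(1_000_000)]

# A config that carries the data versus one that carries only hyperparameters.
heavy_config = {"data": large_dataset, "lr": 1e-4}
light_config = {"lr": 1e-4, "num_epochs": 10}

# Serialized size in bytes, using pickle as a rough proxy for Ray's serializer.
heavy_bytes = len(pickle.dumps(heavy_config))
light_bytes = len(pickle.dumps(light_config))
print(f"heavy config: {heavy_bytes:,} bytes")
print(f"light config: {light_bytes:,} bytes")
```

The heavy config grows with the size of the data it carries. The light config stays small no matter how large the dataset is, because loading happens inside the training function.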