ray.train.torch.prepare_model

ray.train.torch.prepare_model(model: torch.nn.Module, move_to_device: bool | torch.device = True, parallel_strategy: str | None = 'ddp', parallel_strategy_kwargs: Dict[str, Any] | None = None) → torch.nn.Module

Prepares the model for distributed execution.

This allows you to use exactly the same code regardless of the number of workers or the device type (CPU or GPU).
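As a rough illustration of that claim, the sketch below calls prepare_model inside a training function whose body does not change with the worker count or device type. The layer sizes, the random data, and the helper name `launch` are assumptions made for this example, not part of the API.

```python
def train_func():
    # Imports live inside the function because Ray runs it on each worker.
    import torch
    import torch.nn as nn
    import ray.train.torch

    model = nn.Linear(4, 1)  # toy model, chosen only for illustration

    # Prepare the model for distributed execution with the defaults
    # (move_to_device=True, parallel_strategy="ddp").
    model = ray.train.torch.prepare_model(model)

    # Create inputs on the same device this worker was assigned.
    device = ray.train.torch.get_device()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    for _ in range(3):
        inputs = torch.randn(8, 4, device=device)
        targets = torch.randn(8, 1, device=device)
        loss = loss_fn(model(inputs), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def launch(num_workers=2, use_gpu=False):
    # Hypothetical helper: only the scaling config changes between
    # CPU/GPU or single/multi-worker runs; train_func stays the same.
    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer

    trainer = TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),
    )
    return trainer.fit()
```

Calling `launch()` starts two CPU workers; `launch(num_workers=4, use_gpu=True)` would reuse the same `train_func` on GPUs, assuming the cluster has them available.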

Parameters:
  • model – A torch model to prepare.

  • move_to_device – Either a boolean indicating whether to move the model to the correct device, or the device to move the model to. If set to False, you must move the model to the correct device manually.

  • parallel_strategy – Whether to wrap models in DistributedDataParallel, FullyShardedDataParallel, or neither. Must be one of "ddp", "fsdp", or None.

  • parallel_strategy_kwargs – Keyword arguments to pass to the DistributedDataParallel or FullyShardedDataParallel initializer when parallel_strategy is set to "ddp" or "fsdp", respectively.

Returns:

The prepared model, wrapped according to parallel_strategy.
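The non-default parameters can be combined as in the following sketch, which places the model manually and forwards keyword arguments to the wrapper. The config dictionaries and their keys are hypothetical names introduced for this example; `find_unused_parameters` is a DistributedDataParallel option used here only as a sample keyword argument.

```python
# Hypothetical configs, passed to the training function by the trainer.
DDP_CONFIG = {
    "parallel_strategy": "ddp",
    "parallel_strategy_kwargs": {"find_unused_parameters": True},
}
NO_WRAP_CONFIG = {"parallel_strategy": None}  # no parallel wrapper


def train_func(config):
    # Imports live inside the function because Ray runs it on each worker.
    import torch.nn as nn
    import ray.train.torch

    model = nn.Linear(4, 1)  # toy model, chosen only for illustration

    # With move_to_device=False, placing the model is the caller's job.
    device = ray.train.torch.get_device()
    model = model.to(device)

    model = ray.train.torch.prepare_model(
        model,
        move_to_device=False,
        parallel_strategy=config.get("parallel_strategy", "ddp"),
        parallel_strategy_kwargs=config.get("parallel_strategy_kwargs"),
    )
    # ... the training loop would use `model` here ...
```

One way to run this is `TorchTrainer(train_func, train_loop_config=DDP_CONFIG, scaling_config=ScalingConfig(num_workers=2)).fit()`; swapping in `NO_WRAP_CONFIG` should leave the model without a parallel wrapper.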