ray.train.torch.prepare_model(model: torch.nn.modules.module.Module, move_to_device: bool = True, parallel_strategy: Optional[str] = 'ddp', parallel_strategy_kwargs: Optional[Dict[str, Any]] = None) → torch.nn.modules.module.Module

Prepares the model for distributed execution.

This allows you to use exactly the same code regardless of the number of workers or the device type being used (CPU or GPU).

  • model (torch.nn.Module) – A torch model to prepare.

  • move_to_device (bool) – Whether to move the model to the correct device. If set to False, the model must be moved to the correct device manually.

  • parallel_strategy ("ddp", "fsdp", or None) – Whether to wrap models in DistributedDataParallel, FullyShardedDataParallel, or neither.

  • parallel_strategy_kwargs (Dict[str, Any]) – Args to pass into DistributedDataParallel or FullyShardedDataParallel initialization if parallel_strategy is set to “ddp” or “fsdp”, respectively.
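A minimal sketch of how prepare_model is typically called inside a Ray Train per-worker training function (assuming ray[train] and torch are installed; the model shape and training loop here are illustrative, not part of the API):

```python
def train_func():
    # Imports kept inside the training function, which Ray runs on each worker.
    import torch.nn as nn
    from ray.train.torch import prepare_model

    model = nn.Linear(in_features=8, out_features=1)

    # Default: moves the model to this worker's device and wraps it in
    # DistributedDataParallel (parallel_strategy="ddp").
    model = prepare_model(model)

    # Variants (shown as comments to keep this sketch minimal):
    #   prepare_model(model, parallel_strategy="fsdp")  # FullyShardedDataParallel
    #   prepare_model(model, parallel_strategy=None)    # no wrapping
    #   prepare_model(model, move_to_device=False)      # move manually yourself

    # ... run the usual forward/backward/optimizer steps with `model` ...
```

This function would then be passed to a Ray Train trainer (e.g. TorchTrainer), which invokes it once per worker; the same body works for one CPU worker or many GPU workers.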

PublicAPI (beta): This API is in beta and may change before becoming stable.