ray.train.torch.prepare_model(model: torch.nn.Module, move_to_device: bool | torch.device = True, parallel_strategy: str | None = 'ddp', parallel_strategy_kwargs: Dict[str, Any] | None = None) -> torch.nn.Module

Prepares the model for distributed execution.

This lets you use exactly the same code regardless of the number of workers or the device type (CPU or GPU).

Parameters:

  • model (torch.nn.Module) – A torch model to prepare.

  • move_to_device (bool or torch.device) – Either a boolean indicating whether to move the model to the correct device, or the specific device to move the model to. If set to False, you must move the model to the correct device manually.

  • parallel_strategy ("ddp", "fsdp", or None) – Whether to wrap the model in DistributedDataParallel, FullyShardedDataParallel, or neither. Defaults to "ddp".

  • parallel_strategy_kwargs (Dict[str, Any]) – Keyword arguments passed to the DistributedDataParallel or FullyShardedDataParallel constructor when parallel_strategy is "ddp" or "fsdp", respectively.