ray.train.torch.prepare_model
- ray.train.torch.prepare_model(model: torch.nn.modules.module.Module, move_to_device: bool = True, parallel_strategy: Optional[str] = 'ddp', parallel_strategy_kwargs: Optional[Dict[str, Any]] = None) -> torch.nn.modules.module.Module
Prepares the model for distributed execution.
This allows you to use the exact same code regardless of the number of workers or the device type being used (CPU or GPU).
- Parameters
model (torch.nn.Module) – A torch model to prepare.
move_to_device (bool) – Whether to move the model to the correct device. If set to False, the model must be moved to the correct device manually.
parallel_strategy ("ddp", "fsdp", or None) – Whether to wrap models in
DistributedDataParallel
,FullyShardedDataParallel
, or neither.parallel_strategy_kwargs (Dict[str, Any]) – Args to pass into
DistributedDataParallel
orFullyShardedDataParallel
initialization ifparallel_strategy
is set to “ddp” or “fsdp”, respectively.
PublicAPI (beta): This API is in beta and may change before becoming stable.
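A minimal usage sketch, assuming a training function launched with TorchTrainer; the model, worker count, and device settings here are illustrative placeholders, not part of this API's definition:

import torch.nn as nn

from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer, prepare_model


def train_loop_per_worker():
    # Wraps the model in DistributedDataParallel (the default "ddp" strategy)
    # and moves it to the device assigned to this worker.
    model = prepare_model(nn.Linear(10, 1))
    # ... run the usual training loop on `model` ...


trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
)
trainer.fit()

Because prepare_model handles the DistributedDataParallel (or FullyShardedDataParallel) wrapping and device placement, the body of train_loop_per_worker can stay identical whether it runs on one CPU worker or many GPU workers.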