ray.train.torch.TorchCheckpoint
- class ray.train.torch.TorchCheckpoint(local_path: Optional[Union[str, os.PathLike]] = None, data_dict: Optional[dict] = None, uri: Optional[str] = None)
Bases: ray.air.checkpoint.Checkpoint
A Checkpoint with Torch-specific functionality.
Create this from a generic Checkpoint by calling TorchCheckpoint.from_checkpoint(ckpt).
PublicAPI (beta): This API is in beta and may change before becoming stable.
- classmethod from_state_dict(state_dict: Dict[str, Any], *, preprocessor: Optional[Preprocessor] = None) -> TorchCheckpoint
Create a Checkpoint that stores a model state dictionary.
Tip
This is the recommended method for creating TorchCheckpoints.
- Parameters
state_dict – The model state dictionary to store in the checkpoint.
preprocessor – A fitted preprocessor to be applied before inference.
- Returns
A TorchCheckpoint containing the specified state dictionary.
Examples
    import torch
    import torch.nn as nn
    from ray.train.torch import TorchCheckpoint

    # Set manual seed
    torch.manual_seed(42)

    # Function to create a NN model
    def create_model() -> nn.Module:
        model = nn.Sequential(nn.Linear(1, 10), nn.ReLU(), nn.Linear(10, 1))
        return model

    # Create a TorchCheckpoint from our model's state_dict
    model = create_model()
    checkpoint = TorchCheckpoint.from_state_dict(model.state_dict())

    # Now load the model from the TorchCheckpoint by providing the
    # model architecture
    model_from_chkpt = checkpoint.get_model(create_model())

    # Assert they have the same state dict
    assert str(model.state_dict()) == str(model_from_chkpt.state_dict())
    print("worked")
- classmethod from_model(model: torch.nn.modules.module.Module, *, preprocessor: Optional[Preprocessor] = None) -> TorchCheckpoint
Create a Checkpoint that stores a Torch model.
Note
PyTorch recommends storing state dictionaries. To create a TorchCheckpoint from a state dictionary, call from_state_dict(). To learn more about state dictionaries, read Saving and Loading Models.
- Parameters
model – The Torch model to store in the checkpoint.
preprocessor – A fitted preprocessor to be applied before inference.
- Returns
A TorchCheckpoint containing the specified model.
Examples
    import torch
    from ray.train.torch import TorchCheckpoint
    from ray.train.torch import TorchPredictor

    # Set manual seed
    torch.manual_seed(42)

    # Create an identity model and pass a random tensor through it
    model = torch.nn.Identity()
    input = torch.randn(2, 2)
    output = model(input)

    # Create a checkpoint
    checkpoint = TorchCheckpoint.from_model(model)

    # You can use a TorchCheckpoint to create a
    # ray.train.torch.TorchPredictor and perform inference.
    predictor = TorchPredictor.from_checkpoint(checkpoint)
    pred = predictor.predict(input.numpy())

    # Convert prediction dictionary value into a tensor
    pred = torch.tensor(pred['predictions'])

    # Assert the outputs from the original and checkpoint models are the same
    assert torch.equal(output, pred)
    print("worked")
- get_model(model: Optional[torch.nn.modules.module.Module] = None) -> torch.nn.modules.module.Module
Retrieve the model stored in this checkpoint.
- Parameters
model – If the checkpoint contains a model state dict rather than the model itself, the state dict is loaded into this model. Otherwise, this argument is discarded.
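The two cases for the model argument can be illustrated with a small stand-alone sketch. It is not Ray's implementation: TinyModel is a hypothetical stand-in for torch.nn.Module, resolve_model is a hypothetical helper that mimics the behavior described above, and raising ValueError when no model is supplied for a state dict is an assumption of this sketch.

```python
from typing import Any, Dict, Optional, Union


class TinyModel:
    """Hypothetical stand-in for torch.nn.Module (only the two methods used here)."""

    def __init__(self, weight: float = 0.0) -> None:
        self.weight = weight

    def state_dict(self) -> Dict[str, Any]:
        return {"weight": self.weight}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.weight = state["weight"]


def resolve_model(
    saved: Union[TinyModel, Dict[str, Any]],
    model: Optional[TinyModel] = None,
) -> TinyModel:
    """Hypothetical helper mirroring the documented get_model(model) semantics."""
    if isinstance(saved, TinyModel):
        # The checkpoint holds the model itself: the `model` argument is discarded.
        return saved
    if model is None:
        # Assumption of this sketch: a state dict cannot be restored without
        # an architecture to load it into.
        raise ValueError("A model is required to load a state dict into.")
    # The checkpoint holds only a state dict: load it into the supplied model.
    model.load_state_dict(saved)
    return model


trained = TinyModel(weight=3.5)

# Case 1: only the state dict was stored, so a fresh architecture is filled in.
restored = resolve_model(trained.state_dict(), TinyModel())

# Case 2: the full model was stored, so the supplied architecture is ignored.
same = resolve_model(trained, TinyModel())
```

In the first case the returned object is the instance passed in, now carrying the stored weights; in the second case it is the stored model and the extra instance goes unused.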