ray.train.torch.TorchCheckpoint

class ray.train.torch.TorchCheckpoint(local_path: Optional[Union[str, os.PathLike]] = None, data_dict: Optional[dict] = None, uri: Optional[str] = None)

Bases: ray.air.checkpoint.Checkpoint

A Checkpoint with Torch-specific functionality.

Create this from a generic Checkpoint by calling TorchCheckpoint.from_checkpoint(ckpt).

PublicAPI (beta): This API is in beta and may change before becoming stable.

classmethod from_state_dict(state_dict: Dict[str, Any], *, preprocessor: Optional[Preprocessor] = None) TorchCheckpoint

Create a Checkpoint that stores a model state dictionary.

Tip

This is the recommended method for creating TorchCheckpoints.

Parameters
  • state_dict – The model state dictionary to store in the checkpoint.

  • preprocessor – A fitted preprocessor to be applied before inference.

Returns

A TorchCheckpoint containing the specified state dictionary.

Examples

import torch
import torch.nn as nn
from ray.train.torch import TorchCheckpoint

# Set manual seed
torch.manual_seed(42)

# Function to create a NN model
def create_model() -> nn.Module:
    model = nn.Sequential(
        nn.Linear(1, 10),
        nn.ReLU(),
        nn.Linear(10, 1),
    )
    return model

# Create a TorchCheckpoint from our model's state_dict
model = create_model()
checkpoint = TorchCheckpoint.from_state_dict(model.state_dict())

# Now load the model from the TorchCheckpoint by providing the
# model architecture
model_from_chkpt = checkpoint.get_model(create_model())

# Assert that both models hold identical parameters, tensor by tensor
restored_state = model_from_chkpt.state_dict()
assert all(torch.equal(v, restored_state[k]) for k, v in model.state_dict().items())
print("worked")

classmethod from_model(model: torch.nn.modules.module.Module, *, preprocessor: Optional[Preprocessor] = None) TorchCheckpoint

Create a Checkpoint that stores a Torch model.

Note

PyTorch recommends storing state dictionaries rather than whole models. To create a TorchCheckpoint from a state dictionary, call from_state_dict(). To learn more about state dictionaries, see Saving and Loading Models in the PyTorch documentation.
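
As an illustration of the state-dictionary approach, here is a minimal sketch in plain PyTorch. It does not use TorchCheckpoint; the in-memory buffer and the small linear model are assumptions chosen only to keep the example short.

```python
import io

import torch
import torch.nn as nn

model = nn.Linear(2, 1)

# Serialize only the parameters (the state dict), not the Python class.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# Restoring requires re-creating the architecture, then loading the weights.
restored = nn.Linear(2, 1)
restored.load_state_dict(torch.load(buffer))

keys = list(restored.state_dict().keys())
weights_match = torch.equal(model.weight, restored.weight)
```

Because only tensors are serialized, the saved data does not depend on the import path of the model class, but the caller has to supply the architecture when loading.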

Parameters
  • model – The Torch model to store in the checkpoint.

  • preprocessor – A fitted preprocessor to be applied before inference.

Returns

A TorchCheckpoint containing the specified model.

Examples

from ray.train.torch import TorchCheckpoint
from ray.train.torch import TorchPredictor
import torch

# Set manual seed
torch.manual_seed(42)

# Create an identity model and pass a random tensor through it
model = torch.nn.Identity()
input = torch.randn(2, 2)
output = model(input)

# Create a checkpoint
checkpoint = TorchCheckpoint.from_model(model)

# Use the TorchCheckpoint to create a
# ray.train.torch.TorchPredictor and perform inference.
predictor = TorchPredictor.from_checkpoint(checkpoint)
pred = predictor.predict(input.numpy())

# Convert prediction dictionary value into a tensor
pred = torch.tensor(pred['predictions'])

# Assert the outputs of the original model and the checkpoint model are the same
assert torch.equal(output, pred)
print("worked")

get_model(model: Optional[torch.nn.modules.module.Module] = None) torch.nn.modules.module.Module

Retrieve the model stored in this checkpoint.

Parameters

model – If the checkpoint contains a model state dict rather than the model itself, the state dict is loaded into this model. Otherwise, this argument is discarded and the stored model is returned.
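
Examples

The two cases described for the model parameter can be sketched in plain PyTorch as follows. The helper resolve_model is hypothetical and written only for illustration. It is not Ray's implementation, and the error raised when no model is supplied is an assumption.

```python
from typing import Any, Dict, Optional, Union

import torch
import torch.nn as nn


def resolve_model(
    saved: Union[nn.Module, Dict[str, Any]],
    model: Optional[nn.Module] = None,
) -> nn.Module:
    """Hypothetical helper mirroring the behavior described for get_model."""
    if isinstance(saved, nn.Module):
        # A full model was stored, so the `model` argument is ignored.
        return saved
    if model is None:
        # Assumption: without an architecture, a state dict cannot be used.
        raise ValueError("A state dict was stored; pass a model to load it into.")
    # A state dict was stored: copy its weights into the supplied architecture.
    model.load_state_dict(saved)
    return model


def create_model() -> nn.Module:
    return nn.Sequential(nn.Linear(1, 10), nn.ReLU(), nn.Linear(10, 1))


torch.manual_seed(42)
trained = create_model()
x = torch.randn(4, 1)

# Case 1: a state dict was stored, so a fresh architecture must be passed in.
from_state = resolve_model(trained.state_dict(), create_model())
same_outputs = torch.equal(trained(x), from_state(x))

# Case 2: the full model was stored, so the extra argument is discarded.
from_full = resolve_model(trained, create_model())
returned_stored_model = from_full is trained
```

With a TorchCheckpoint, the first case corresponds to a checkpoint created with from_state_dict() and the second to one created with from_model().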