class ray.train.torch.TorchPredictor(model: torch.nn.modules.module.Module, preprocessor: Optional[Preprocessor] = None, use_gpu: bool = False)

Bases: ray.train._internal.dl_predictor.DLPredictor

A predictor for PyTorch models.

Parameters

  • model – The torch module to use for predictions.

  • preprocessor – A preprocessor used to transform data batches prior to prediction.

  • use_gpu – If True, the model is moved to GPU on instantiation and prediction runs on GPU.

PublicAPI (beta): This API is in beta and may change before becoming stable.
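Conceptually, a predictor like this wraps the usual PyTorch inference steps. The following is a minimal sketch of those steps in plain PyTorch, without Ray, assuming a model that takes a single tensor. `predict_batch` is a hypothetical helper for illustration, not part of the Ray API, and Ray's actual implementation may differ.

```python
import numpy as np
import torch


def predict_batch(
    model: torch.nn.Module, batch: np.ndarray, use_gpu: bool = False
) -> np.ndarray:
    """Hypothetical helper sketching single-tensor inference (not Ray's code)."""
    device = torch.device("cuda" if use_gpu else "cpu")
    model = model.to(device)
    model.eval()  # switch layers such as dropout to inference behavior
    inputs = torch.as_tensor(batch, dtype=torch.float32, device=device)
    with torch.no_grad():  # no autograd graph is needed for prediction
        outputs = model(inputs)
    # Copy the results back to host memory as a NumPy array.
    return outputs.cpu().numpy()


torch.manual_seed(0)
model = torch.nn.Linear(2, 1)
preds = predict_batch(model, np.array([[1.0, 2.0], [3.0, 4.0]]))
print(preds.shape)  # (2, 1)
```

Running the forward pass under `torch.no_grad()` avoids building an autograd graph, which saves memory during inference.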

classmethod from_checkpoint(checkpoint: ray.air.checkpoint.Checkpoint, model: Optional[torch.nn.modules.module.Module] = None, use_gpu: bool = False) → ray.train.torch.torch_predictor.TorchPredictor

Instantiate the predictor from a Checkpoint.

The checkpoint is expected to be the result of a TorchTrainer run.

Parameters

  • checkpoint – The checkpoint to load the model and preprocessor from. It is expected to be from the result of a TorchTrainer run.

  • model – If the checkpoint contains a model state dict rather than the model itself, the state dict is loaded into this model. If the checkpoint already contains the model itself, this argument is ignored.

  • use_gpu – If True, the model is moved to GPU on instantiation and prediction runs on GPU.
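The `model` argument matters only when the checkpoint holds a state dict rather than a full module. The sketch below mirrors that rule in plain PyTorch. `resolve_model` is a hypothetical helper written for illustration, not Ray's implementation, and it assumes the saved object is either an `nn.Module` or a state dict.

```python
from typing import Mapping, Optional, Union

import torch


def resolve_model(
    saved: Union[torch.nn.Module, Mapping[str, torch.Tensor]],
    model: Optional[torch.nn.Module] = None,
) -> torch.nn.Module:
    """Hypothetical helper: return a ready-to-use module from a saved object."""
    if isinstance(saved, torch.nn.Module):
        # A full module was saved, so any `model` argument is ignored.
        return saved
    if model is None:
        raise ValueError("A state dict was saved; pass `model` to load it into.")
    # Only weights were saved: copy them into the caller-provided architecture.
    model.load_state_dict(saved)
    return model


trained = torch.nn.Linear(2, 1)
state_dict = trained.state_dict()

# State dict saved: the weights are copied into a fresh module.
restored = resolve_model(state_dict, model=torch.nn.Linear(2, 1))
print(torch.equal(restored.weight, trained.weight))  # True
```

The caller must supply a module with the same architecture as the trained one, because `load_state_dict` matches parameters by name and shape.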

call_model(inputs: Union[torch.Tensor, Dict[str, torch.Tensor]]) → Union[torch.Tensor, Dict[str, torch.Tensor]]

Runs inference on a single batch of tensor data.

This method is called by TorchPredictor.predict after converting the original data batch to torch tensors.

Override this method to add custom logic for processing the model input or output.


Parameters

inputs – A batch of data to predict on, represented either as a single PyTorch tensor or, for multi-input models, as a dictionary of tensors.

Returns

The model outputs, either as a single tensor or a dictionary of tensors.


import numpy as np
import torch
from ray.train.torch import TorchPredictor

# The default TorchPredictor does not support list outputs,
# so define a custom TorchPredictor that overrides call_model.
class MyModel(torch.nn.Module):
    def forward(self, input_tensor):
        return [input_tensor, input_tensor]

# Use a custom predictor to format the model output as a dict.
class CustomPredictor(TorchPredictor):
    def call_model(self, inputs):
        model_output = super().call_model(inputs)
        return {
            str(i): model_output[i] for i in range(len(model_output))
        }

# Create a data batch
data_batch = np.array([1, 2])
# Create the custom predictor and run prediction
predictor = CustomPredictor(model=MyModel())
predictions = predictor.predict(data_batch)
print(f"Predictions: {predictions.get('0')}, {predictions.get('1')}")
Predictions: [1 2], [1 2]

DeveloperAPI: This API may change across minor Ray releases.

predict(data: Union[numpy.ndarray, pandas.DataFrame, Dict[str, numpy.ndarray]], dtype: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None) → Union[numpy.ndarray, pandas.DataFrame, Dict[str, numpy.ndarray]]

Run inference on a data batch.

If the provided data is a single array or a dataframe/table with a single column, it is converted into a single PyTorch tensor before being passed to the model.

If the provided data is a multi-column table or a dict of numpy arrays, it is converted into a dict of tensors before being passed to the model. This is useful for multi-modal inputs (for example, if your model accepts both images and text).
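The conversion rules above can be approximated in plain PyTorch and pandas. This is a minimal sketch, not Ray's internal code: `to_model_input` and `_dtype_for` are hypothetical helpers, and the sketch assumes each dataframe column holds scalar values and that `dtype` is either one `torch.dtype` or a mapping from column name to dtype, as described for the `dtype` parameter.

```python
from typing import Dict, Optional, Union

import numpy as np
import pandas as pd
import torch

DTypeArg = Optional[Union[torch.dtype, Dict[str, torch.dtype]]]


def _dtype_for(dtype: DTypeArg, name: Optional[str] = None) -> Optional[torch.dtype]:
    # A mapping selects a per-column dtype; a single dtype applies to every tensor.
    if isinstance(dtype, dict):
        return dtype.get(name)
    return dtype


def to_model_input(
    data: Union[np.ndarray, pd.DataFrame, Dict[str, np.ndarray]],
    dtype: DTypeArg = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    """Hypothetical helper mirroring the documented conversion rules."""
    if isinstance(data, np.ndarray):
        # A single array becomes a single tensor.
        return torch.as_tensor(data, dtype=_dtype_for(dtype))
    if isinstance(data, pd.DataFrame):
        if len(data.columns) == 1:
            # A one-column dataframe also becomes a single tensor.
            name = data.columns[0]
            return torch.as_tensor(data[name].to_numpy(), dtype=_dtype_for(dtype, name))
        # A multi-column dataframe is split into one array per column.
        data = {name: data[name].to_numpy() for name in data.columns}
    # Dicts of arrays (and multi-column dataframes) become a dict of tensors.
    return {
        name: torch.as_tensor(values, dtype=_dtype_for(dtype, name))
        for name, values in data.items()
    }


df = pd.DataFrame({"A": [1, 2], "B": [3, 4]})
batch = to_model_input(df, dtype={"A": torch.float32, "B": torch.int64})
print({name: t.dtype for name, t in batch.items()})
# {'A': torch.float32, 'B': torch.int64}
```

A mapping such as `{"A": torch.float32}` lets each column keep its own precision, while a single dtype (for example `torch.float`) is applied to every tensor.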

Parameters

  • data – A batch of input data of DataBatchType.

  • dtype – The dtypes to use for the tensors. Either a single dtype for all tensors or a mapping from column name to dtype.


Returns

Prediction result. The return type is the same as the input type.

Return type

Union[numpy.ndarray, pandas.DataFrame, Dict[str, numpy.ndarray]]

import numpy as np
import pandas as pd
import torch
import ray
from ray.train.torch import TorchPredictor

# Define a custom PyTorch module that takes a dict of tensors as input
class CustomModule(torch.nn.Module):
    def __init__(self):
        # nn.Module.__init__ must run before submodules are assigned.
        super().__init__()
        self.linear1 = torch.nn.Linear(1, 1)
        self.linear2 = torch.nn.Linear(1, 1)

    def forward(self, input_dict: dict):
        out1 = self.linear1(input_dict["A"].unsqueeze(1))
        out2 = self.linear2(input_dict["B"].unsqueeze(1))
        return out1 + out2

# Set a manual seed so the output is reproducible
torch.manual_seed(42)

# Use a standard PyTorch model
model = torch.nn.Linear(2, 1)
predictor = TorchPredictor(model=model)
# Define the input data
data = np.array([[1, 2], [3, 4]])
predictions = predictor.predict(data, dtype=torch.float)
print(f"Standard model predictions: {predictions}")

# Use the custom PyTorch model with TorchPredictor
predictor = TorchPredictor(model=CustomModule())
# Define the input data and predict with the custom model
data = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"])
predictions = predictor.predict(data, dtype=torch.float)
print(f"Custom model predictions: {predictions}")
Standard model predictions: {'predictions': array([[1.5487633],
       [3.8037925]], dtype=float32)}
Custom model predictions:     predictions
0  [0.61623406]
1    [2.857038]