ray.train.batch_predictor.BatchPredictor.from_checkpoint

classmethod BatchPredictor.from_checkpoint(checkpoint: ray.air.checkpoint.Checkpoint, predictor_cls: Type[ray.train.predictor.Predictor], **kwargs) → ray.train.batch_predictor.BatchPredictor

Create a BatchPredictor from a Checkpoint.

Example

from torchvision import models

from ray.train.batch_predictor import BatchPredictor
from ray.train.torch import TorchCheckpoint, TorchPredictor

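# Load a pretrained ResNet-50 and wrap it in a checkpoint.
model = models.resnet50(pretrained=True)
checkpoint = TorchCheckpoint.from_model(model)
# Build a batch predictor that uses TorchPredictor for inference.
predictor = BatchPredictor.from_checkpoint(checkpoint, TorchPredictor)

Parameters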
  • checkpoint – A Checkpoint containing model state and optionally a preprocessor.

  • predictor_cls – The Predictor subclass to use for inference, for example TorchPredictor.

  • **kwargs – Optional keyword arguments to pass to the predictor_cls constructor.
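
The way the three parameters fit together can be illustrated with a small, self-contained sketch. Everything in it (ToyCheckpoint, ToyPredictor, ToyBatchPredictor, and the scale argument) is a hypothetical stand-in written for illustration, not Ray's implementation. It assumes that the batch predictor stores the checkpoint, the predictor class, and the extra keyword arguments, and only builds the actual predictor when predictions are requested.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Type


@dataclass
class ToyCheckpoint:
    """Hypothetical stand-in for a checkpoint: it only holds model state."""
    state: Dict[str, Any]


class ToyPredictor:
    """Hypothetical stand-in for a Predictor subclass."""

    def __init__(self, weight: float, scale: float = 1.0):
        self.weight = weight
        self.scale = scale

    @classmethod
    def from_checkpoint(cls, checkpoint: ToyCheckpoint, **kwargs) -> "ToyPredictor":
        # Model state comes from the checkpoint; extra options come from kwargs.
        return cls(weight=checkpoint.state["weight"], **kwargs)

    def predict(self, batch: List[float]) -> List[float]:
        return [x * self.weight * self.scale for x in batch]


class ToyBatchPredictor:
    """Hypothetical stand-in showing the from_checkpoint pattern."""

    def __init__(self, checkpoint: ToyCheckpoint, predictor_cls: Type[ToyPredictor], **predictor_kwargs):
        self.checkpoint = checkpoint
        self.predictor_cls = predictor_cls
        self.predictor_kwargs = predictor_kwargs

    @classmethod
    def from_checkpoint(cls, checkpoint: ToyCheckpoint, predictor_cls: Type[ToyPredictor], **kwargs) -> "ToyBatchPredictor":
        # Store everything; the predictor itself is built later, in predict().
        return cls(checkpoint, predictor_cls, **kwargs)

    def predict(self, batches: List[List[float]]) -> List[List[float]]:
        # Forward the stored kwargs when constructing the predictor.
        predictor = self.predictor_cls.from_checkpoint(self.checkpoint, **self.predictor_kwargs)
        return [predictor.predict(batch) for batch in batches]


batch_predictor = ToyBatchPredictor.from_checkpoint(
    ToyCheckpoint({"weight": 2.0}), ToyPredictor, scale=0.5
)
results = batch_predictor.predict([[1.0, 2.0], [3.0]])
print(results)
```

In this sketch, scale plays the role of an optional predictor argument: it is not consumed by from_checkpoint itself but is kept and handed to the predictor class later.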