ray.train.huggingface.HuggingFaceCheckpoint
- class ray.train.huggingface.HuggingFaceCheckpoint(local_path: Optional[Union[str, os.PathLike]] = None, data_dict: Optional[dict] = None, uri: Optional[str] = None)
Bases: ray.air.checkpoint.Checkpoint
A Checkpoint with HuggingFace-specific functionality. Use HuggingFaceCheckpoint.from_model to create this type of checkpoint.
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
- classmethod from_model(model: Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module], tokenizer: Optional[transformers.tokenization_utils.PreTrainedTokenizer] = None, *, path: os.PathLike, preprocessor: Optional[Preprocessor] = None) -> HuggingFaceCheckpoint
Create a Checkpoint that stores a HuggingFace model.
- Parameters
model – The pretrained Transformers model or PyTorch model to store in the checkpoint.
tokenizer – The tokenizer to use in the Transformers pipeline for inference.
path – The directory where the checkpoint will be stored.
preprocessor – A fitted preprocessor to be applied before inference.
- Returns
A HuggingFaceCheckpoint containing the specified model.
- get_model(model: Union[Type[transformers.modeling_utils.PreTrainedModel], torch.nn.modules.module.Module], **pretrained_model_kwargs) -> Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module]
Retrieve the model stored in this checkpoint.