Source code for ray.air.checkpoint

import contextlib
import io
import logging
import os
import platform
import shutil
import tarfile
import tempfile
import traceback
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Tuple, Type, Union

import ray
from ray import cloudpickle as pickle
from ray.air._internal.checkpointing import (
    load_preprocessor_from_dir,
    save_preprocessor_to_dir,
)
from ray.air._internal.filelock import TempFileLock
from ray.air._internal.remote_storage import (
    download_from_uri,
    fs_hint,
    is_non_local_path_uri,
    read_file_from_uri,
    upload_to_uri,
)
from ray.air._internal.util import _copy_dir_ignore_conflicts
from ray.air.constants import PREPROCESSOR_KEY, CHECKPOINT_ID_ATTR
from ray.util.annotations import DeveloperAPI, PublicAPI

if TYPE_CHECKING:
    from ray.data.preprocessor import Preprocessor


_DICT_CHECKPOINT_FILE_NAME = "dict_checkpoint.pkl"
_CHECKPOINT_METADATA_FILE_NAME = ".metadata.pkl"
_DICT_CHECKPOINT_ADDITIONAL_FILE_KEY = "_ray_additional_checkpoint_files"
_METADATA_CHECKPOINT_SUFFIX = ".meta.pkl"
_FS_CHECKPOINT_KEY = "fs_checkpoint"
_BYTES_DATA_KEY = "bytes_data"
_METADATA_KEY = "_metadata"
_CHECKPOINT_DIR_PREFIX = "checkpoint_tmp_"
# A constant UUID namespace, as defined in RFC 4122, used to derive
# deterministic checkpoint UUIDs from URIs and prevent conflicts.
_CHECKPOINT_UUID_URI_NAMESPACE = uuid.UUID("627fe696-f135-436f-bc4b-bda0306e0181")
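# A minimal illustration (not part of the original module) of why a constant
# namespace is used: ``uuid.uuid5`` is deterministic, so every checkpoint
# object pointing at the same URI maps to the same UUID (and hence the same
# temporary directory name) on a node. See ``Checkpoint.__init__`` below.
#
#     import uuid
#     ns = uuid.UUID("627fe696-f135-436f-bc4b-bda0306e0181")
#     a = uuid.uuid5(ns, "s3://bucket/ckpt")
#     b = uuid.uuid5(ns, "s3://bucket/ckpt")
#     assert a == b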

logger = logging.getLogger(__name__)


@dataclass
class _CheckpointMetadata:
    """Metadata about a checkpoint.

    Attributes:
        checkpoint_type: The checkpoint class. For example, ``TorchCheckpoint``.
        checkpoint_state: A dictionary that maps object attributes to their
            values. When a serialized checkpoint is loaded, these values are
            restored on the new checkpoint object.
    """

    checkpoint_type: Type["Checkpoint"]
    checkpoint_state: Dict[str, Any]

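# A short sketch (illustrative only; ``MyCheckpoint`` is hypothetical) of how
# ``_CheckpointMetadata`` travels with a checkpoint: it records the concrete
# class and the values of its ``_SERIALIZED_ATTRS`` so both can be restored
# on deserialization.
#
#     meta = _CheckpointMetadata(
#         checkpoint_type=MyCheckpoint,
#         checkpoint_state={"_checkpoint_id": 0},
#     )
#     # On load, the ``Checkpoint._metadata`` setter applies
#     # ``checkpoint_state`` back onto the new object via ``setattr``.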

@PublicAPI(stability="beta")
class Checkpoint:
    """Ray AIR Checkpoint.

    An AIR Checkpoint is a common interface for accessing models across
    different AIR components and libraries. A Checkpoint can have its data
    represented in one of three ways:

    - as a directory on local (on-disk) storage
    - as a directory on an external storage (e.g., cloud storage)
    - as an in-memory dictionary

    The Checkpoint object also has methods to translate between different
    checkpoint storage locations. These storage representations provide
    flexibility in distributed environments, where you may want to recreate
    an instance of the same model on multiple nodes or across different Ray
    clusters.

    Example:

    .. code-block:: python

        from ray.air.checkpoint import Checkpoint

        # Create checkpoint data dict
        checkpoint_data = {"data": 123}

        # Create checkpoint object from data
        checkpoint = Checkpoint.from_dict(checkpoint_data)

        # Save checkpoint to a directory on the file system.
        path = checkpoint.to_directory()

        # This path can then be passed around,
        # e.g. to a different function or a different script.
        # You can also use `checkpoint.to_uri/from_uri` to
        # read from/write to cloud storage

        # In another function or script, recover Checkpoint object from path
        checkpoint = Checkpoint.from_directory(path)

        # Convert into dictionary again
        recovered_data = checkpoint.to_dict()

        # It is guaranteed that the original data has been recovered
        assert recovered_data == checkpoint_data

    Checkpoints can be used to instantiate a :class:`Predictor`,
    :class:`BatchPredictor`, or :class:`PredictorDeployment` class.

    The constructor is a private API; instead, use the ``from_`` methods to
    create checkpoint objects (e.g. ``Checkpoint.from_directory()``).

    **Other implementation notes:**

    When converting between different checkpoint formats, it is guaranteed
    that a full round trip of conversions (e.g. directory --> dict -->
    directory) will recover the original checkpoint data. There are no
    guarantees made about compatibility of intermediate representations.

    New data can be added to a Checkpoint during conversion. Consider the
    following conversion: directory --> dict (adding dict["foo"] = "bar")
    --> directory --> dict (expect to see dict["foo"] = "bar"). Note that
    the second directory will contain pickle files with the serialized
    additional field data in them.

    Similarly with a dict as a source: dict --> directory (add file
    "foo.txt") --> dict --> directory (will have "foo.txt" in it again).
    Note that the second dict representation will contain an extra field
    with the serialized additional files in it.

    Checkpoints can be pickled and sent to remote processes.
    Please note that checkpoints pointing to local directories will be
    pickled as data representations, so the full checkpoint data will be
    contained in the checkpoint object. If you want to avoid this, consider
    passing only the checkpoint directory to the remote task and
    re-constructing your checkpoint object in that function. Note that this
    will only work if the "remote" task is scheduled on the same node or a
    node that also has access to the local data path (e.g. on a shared file
    system like NFS).

    If you need persistence across clusters, use the ``to_uri()`` or
    ``to_directory()`` methods to persist your checkpoints to disk.

    """

    # A list of object attributes to persist. For example, if
    # `_SERIALIZED_ATTRS = ("spam",)`, then the value of `"spam"` is
    # serialized with the checkpoint data. When a user constructs a
    # checkpoint from the serialized data, the value of `"spam"` is restored.
    # In subclasses, do
    # `_SERIALIZED_ATTRS = Checkpoint._SERIALIZED_ATTRS + ("spam",)`.
    # `_checkpoint_id` is used to track checkpoints in TuneCheckpointManager
    # and is unset otherwise.
    _SERIALIZED_ATTRS = (CHECKPOINT_ID_ATTR,)
    @DeveloperAPI
    def __init__(
        self,
        local_path: Optional[Union[str, os.PathLike]] = None,
        data_dict: Optional[dict] = None,
        uri: Optional[str] = None,
    ):
        # First, resolve file:// URIs to local paths
        if uri:
            local_path = _get_local_path(uri)
            if local_path:
                uri = None

        # Only one data type can be set at any time
        if local_path:
            assert not data_dict and not uri
            if not isinstance(local_path, (str, os.PathLike)) or not os.path.exists(
                local_path
            ):
                raise RuntimeError(
                    f"Cannot create checkpoint from path as it does "
                    f"not exist on local node: {local_path}"
                )
            elif not os.path.isdir(local_path):
                raise RuntimeError(
                    f"Cannot create checkpoint from path as it does "
                    f"not point to a directory: {local_path}. If your checkpoint "
                    f"is a single file, consider passing the enclosing directory "
                    f"instead."
                )
        elif data_dict:
            assert not local_path and not uri
            if not isinstance(data_dict, dict):
                raise RuntimeError(
                    f"Cannot create checkpoint from dict as no "
                    f"dict was passed: {data_dict}"
                )
        elif uri:
            assert not local_path and not data_dict
            resolved = _get_external_path(uri)
            if not resolved:
                raise RuntimeError(
                    f"Cannot create checkpoint from URI as it is not "
                    f"supported: {resolved}"
                )
            uri = resolved
        else:
            raise ValueError("Cannot create checkpoint without data.")

        self._local_path: Optional[str] = (
            str(Path(local_path).resolve()) if local_path else local_path
        )
        self._data_dict: Optional[Dict[str, Any]] = data_dict
        self._uri: Optional[str] = uri
        self._override_preprocessor: Optional["Preprocessor"] = None
        self._override_preprocessor_set = False

        # When using a cloud URI, we make sure that the uuid is constant.
        # This ensures we do not download the data multiple times on one node.
        # Note that this is not a caching mechanism - instead, this
        # only ensures that if there are several processes downloading
        # from the same URI, only one process does the actual work
        # while the rest waits (FileLock). This also means data will not be
        # duplicated.
        self._uuid = (
            uuid.uuid4()
            if not self._uri
            else uuid.uuid5(_CHECKPOINT_UUID_URI_NAMESPACE, self._uri)
        )
    def __repr__(self):
        parameter, argument = self.get_internal_representation()
        return f"{self.__class__.__name__}({parameter}={argument})"

    @property
    def _metadata(self) -> _CheckpointMetadata:
        return _CheckpointMetadata(
            checkpoint_type=self.__class__,
            checkpoint_state={
                attr: getattr(self, attr)
                for attr in self._SERIALIZED_ATTRS
                if hasattr(self, attr)
            },
        )

    def _copy_metadata_attrs_from(self, source: "Checkpoint") -> None:
        """Copy in-place metadata attributes from ``source`` to self."""
        for attr, value in source._metadata.checkpoint_state.items():
            if attr in self._SERIALIZED_ATTRS:
                setattr(self, attr, value)

    @_metadata.setter
    def _metadata(self, metadata: _CheckpointMetadata):
        if metadata.checkpoint_type is not self.__class__:
            raise ValueError(
                f"Checkpoint type in metadata must match {self.__class__}, "
                f"got {metadata.checkpoint_type}"
            )
        for attr, value in metadata.checkpoint_state.items():
            setattr(self, attr, value)

    @property
    def path(self) -> Optional[str]:
        """Return path to checkpoint, if available.

        This will return a URI to cloud storage if this checkpoint is
        persisted on cloud, or a local path if this checkpoint is persisted
        on local disk and available on the current node.

        In all other cases, this will return None.

        Example:

            >>> from ray.air import Checkpoint
            >>> checkpoint = Checkpoint.from_uri("s3://some-bucket/some-location")
            >>> assert checkpoint.path == "s3://some-bucket/some-location"
            >>> checkpoint = Checkpoint.from_dict({"data": 1})
            >>> assert checkpoint.path is None

        Returns:
            Checkpoint path if this checkpoint is reachable from the current
            node (e.g. cloud storage or locally available directory).

        """
        if self._uri:
            return self._uri

        if self._local_path:
            return self._local_path

        return None

    @property
    def uri(self) -> Optional[str]:
        """Return checkpoint URI, if available.

        This will return a URI to cloud storage if this checkpoint is
        persisted on cloud, or a local ``file://`` URI if this checkpoint is
        persisted on local disk and available on the current node.

        In all other cases, this will return None.

        Users can then choose to persist to cloud with
        :meth:`Checkpoint.to_uri() <ray.air.Checkpoint.to_uri>`.

        Example:

            >>> from ray.air import Checkpoint
            >>> checkpoint = Checkpoint.from_uri("s3://some-bucket/some-location")
            >>> assert checkpoint.uri == "s3://some-bucket/some-location"
            >>> checkpoint = Checkpoint.from_dict({"data": 1})
            >>> assert checkpoint.uri is None

        Returns:
            Checkpoint URI if this URI is reachable from the current node
            (e.g. cloud storage or locally available file URI).

        """
        if self._uri:
            return self._uri

        if self._local_path and Path(self._local_path).exists():
            return f"file://{self._local_path}"

        return None
    @classmethod
    def from_bytes(cls, data: bytes) -> "Checkpoint":
        """Create a checkpoint from the given byte string.

        Args:
            data: Data object containing pickled checkpoint data.

        Returns:
            Checkpoint: checkpoint object.
        """
        bytes_data = pickle.loads(data)
        if isinstance(bytes_data, dict):
            data_dict = bytes_data
        else:
            data_dict = {_BYTES_DATA_KEY: bytes_data}
        return cls.from_dict(data_dict)
    def to_bytes(self) -> bytes:
        """Return Checkpoint serialized as bytes object.

        Returns:
            bytes: Bytes object containing checkpoint data.
        """
        # Todo: Add support for stream in the future (to_bytes(file_like))
        data_dict = self.to_dict()
        if _BYTES_DATA_KEY in data_dict:
            return data_dict[_BYTES_DATA_KEY]
        return pickle.dumps(data_dict)
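    # A minimal round-trip sketch for ``from_bytes``/``to_bytes``
    # (illustrative only, assuming a plain picklable payload). Because
    # ``to_dict`` adds a metadata entry, compare individual keys rather
    # than whole dicts:
    #
    #     ckpt = Checkpoint.from_dict({"weights": [1, 2, 3]})
    #     blob = ckpt.to_bytes()
    #     restored = Checkpoint.from_bytes(blob)
    #     assert restored.to_dict()["weights"] == [1, 2, 3]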
    @classmethod
    def from_dict(cls, data: dict) -> "Checkpoint":
        """Create checkpoint object from dictionary.

        Args:
            data: Dictionary containing checkpoint data.

        Returns:
            Checkpoint: checkpoint object.
        """
        state = {}
        if _METADATA_KEY in data:
            metadata = data[_METADATA_KEY]
            cls = cls._get_checkpoint_type(metadata.checkpoint_type)
            state = metadata.checkpoint_state

        checkpoint = cls(data_dict=data)
        checkpoint.__dict__.update(state)

        return checkpoint
    def to_dict(self) -> dict:
        """Return checkpoint data as dictionary.

        Returns:
            dict: Dictionary containing checkpoint data.
        """
        if self._data_dict:
            # If the checkpoint data is already a dict, return
            checkpoint_data = self._data_dict
        elif self._local_path or self._uri:
            # Else, checkpoint is either on FS or external storage
            with self.as_directory() as local_path:
                checkpoint_data_path = os.path.join(
                    local_path, _DICT_CHECKPOINT_FILE_NAME
                )
                if os.path.exists(checkpoint_data_path):
                    # If we are restoring a dict checkpoint, load the dict
                    # from the checkpoint file.
                    with open(checkpoint_data_path, "rb") as f:
                        checkpoint_data = pickle.load(f)

                    # If there are additional files in the directory, add them
                    # as _DICT_CHECKPOINT_ADDITIONAL_FILE_KEY
                    additional_files = {}
                    for file_or_dir in os.listdir(local_path):
                        if file_or_dir in [
                            ".",
                            "..",
                            _DICT_CHECKPOINT_FILE_NAME,
                            _CHECKPOINT_METADATA_FILE_NAME,
                        ]:
                            continue

                        additional_files[file_or_dir] = _pack(
                            os.path.join(local_path, file_or_dir)
                        )

                    if additional_files:
                        checkpoint_data[
                            _DICT_CHECKPOINT_ADDITIONAL_FILE_KEY
                        ] = additional_files
                else:
                    files = [
                        f
                        for f in os.listdir(local_path)
                        if os.path.isfile(os.path.join(local_path, f))
                        and f.endswith(_METADATA_CHECKPOINT_SUFFIX)
                    ]
                    metadata = {}
                    for file in files:
                        with open(os.path.join(local_path, file), "rb") as f:
                            key = file[: -len(_METADATA_CHECKPOINT_SUFFIX)]
                            value = pickle.load(f)
                            metadata[key] = value

                    data = _pack(local_path)

                    checkpoint_data = {
                        _FS_CHECKPOINT_KEY: data,
                    }
                    checkpoint_data.update(metadata)
        else:
            raise RuntimeError(f"Empty data for checkpoint {self}")

        checkpoint_data[_METADATA_KEY] = self._metadata
        # If override_preprocessor is specified, then set that in the
        # output dict.
        if self._override_preprocessor_set:
            checkpoint_data[PREPROCESSOR_KEY] = self._override_preprocessor
        return checkpoint_data
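    # A sketch of the additional-files behavior described in the class
    # docstring (illustrative only; the file name is hypothetical). Extra
    # files in a directory checkpoint survive the
    # directory --> dict --> directory round trip via
    # ``_DICT_CHECKPOINT_ADDITIONAL_FILE_KEY``:
    #
    #     path = Checkpoint.from_dict({"data": 123}).to_directory()
    #     with open(os.path.join(path, "foo.txt"), "w") as f:
    #         f.write("bar")
    #     as_dict = Checkpoint.from_directory(path).to_dict()
    #     new_path = Checkpoint.from_dict(as_dict).to_directory()
    #     assert os.path.exists(os.path.join(new_path, "foo.txt"))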
    @classmethod
    def from_directory(cls, path: Union[str, os.PathLike]) -> "Checkpoint":
        """Create checkpoint object from directory.

        Args:
            path: Directory containing checkpoint data. The caller promises
                to not delete the directory (gifts ownership of the
                directory to this Checkpoint).

        Returns:
            Checkpoint: checkpoint object.
        """
        state = {}

        checkpoint_metadata_path = os.path.join(path, _CHECKPOINT_METADATA_FILE_NAME)
        if os.path.exists(checkpoint_metadata_path):
            with open(checkpoint_metadata_path, "rb") as file:
                metadata = pickle.load(file)
            cls = cls._get_checkpoint_type(metadata.checkpoint_type)
            state = metadata.checkpoint_state

        checkpoint = cls(local_path=path)
        checkpoint.__dict__.update(state)

        return checkpoint
    @classmethod
    @DeveloperAPI
    def from_checkpoint(cls, other: "Checkpoint") -> "Checkpoint":
        """Create a checkpoint from a generic :class:`Checkpoint`.

        This method can be used to create a framework-specific checkpoint
        from a generic :class:`Checkpoint` object.

        Examples:
            >>> result = TorchTrainer.fit(...)  # doctest: +SKIP
            >>> checkpoint = TorchCheckpoint.from_checkpoint(result.checkpoint)  # doctest: +SKIP
            >>> model = checkpoint.get_model()  # doctest: +SKIP
            Linear(in_features=1, out_features=1, bias=True)
        """  # noqa: E501
        if type(other) is cls:
            return other

        new_checkpoint = cls(
            local_path=other._local_path,
            data_dict=other._data_dict,
            uri=other._uri,
        )
        new_checkpoint._copy_metadata_attrs_from(other)
        return new_checkpoint
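    # A conversion sketch (illustrative; ``MyFrameworkCheckpoint`` is
    # hypothetical). ``from_checkpoint`` re-wraps the underlying data and
    # copies over serialized metadata attributes without duplicating storage:
    #
    #     class MyFrameworkCheckpoint(Checkpoint):
    #         pass
    #
    #     generic = Checkpoint.from_dict({"model": b"..."})
    #     specific = MyFrameworkCheckpoint.from_checkpoint(generic)
    #     assert type(specific) is MyFrameworkCheckpoint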
    def _get_temporary_checkpoint_dir(self) -> str:
        """Return the name for the temporary checkpoint dir."""
        tmp_dir_path = tempfile.gettempdir()
        checkpoint_dir_name = _CHECKPOINT_DIR_PREFIX + self._uuid.hex
        if platform.system() == "Windows":
            # Max path on Windows is 260 chars, -1 for joining \
            # Also leave a little for the del lock
            del_lock_name = _get_del_lock_path("")
            checkpoint_dir_name = (
                _CHECKPOINT_DIR_PREFIX
                + self._uuid.hex[
                    -259
                    + len(_CHECKPOINT_DIR_PREFIX)
                    + len(tmp_dir_path)
                    + len(del_lock_name) :
                ]
            )
            if not checkpoint_dir_name.startswith(_CHECKPOINT_DIR_PREFIX):
                raise RuntimeError(
                    "Couldn't create checkpoint directory due to length "
                    "constraints. Try specifying a shorter checkpoint path."
                )
        return os.path.join(tmp_dir_path, checkpoint_dir_name)

    def _save_checkpoint_metadata_in_directory(self, path: str) -> None:
        checkpoint_metadata_path = os.path.join(path, _CHECKPOINT_METADATA_FILE_NAME)
        with open(checkpoint_metadata_path, "wb") as file:
            pickle.dump(self._metadata, file)

    def _to_directory(self, path: str, move_instead_of_copy: bool = False) -> None:
        if self._data_dict:
            data_dict = self.to_dict()

            if _FS_CHECKPOINT_KEY in data_dict:
                for key in data_dict.keys():
                    if key == _FS_CHECKPOINT_KEY:
                        continue
                    metadata_path = os.path.join(
                        path, f"{key}{_METADATA_CHECKPOINT_SUFFIX}"
                    )
                    with open(metadata_path, "wb") as f:
                        pickle.dump(data_dict[key], f)

                # This used to be a true fs checkpoint, so restore
                _unpack(data_dict[_FS_CHECKPOINT_KEY], path)
            else:
                # This is a dict checkpoint.
                # First, restore any additional files
                additional_files = data_dict.pop(
                    _DICT_CHECKPOINT_ADDITIONAL_FILE_KEY, {}
                )
                for file, content in additional_files.items():
                    _unpack(stream=content, path=os.path.join(path, file))

                # Then dump data into checkpoint.pkl
                checkpoint_data_path = os.path.join(path, _DICT_CHECKPOINT_FILE_NAME)
                with open(checkpoint_data_path, "wb") as f:
                    pickle.dump(data_dict, f)
        else:
            # This is either a local fs, remote node fs, or external fs
            local_path = self._local_path
            path_pathlib = Path(path).resolve()
            external_path = _get_external_path(self._uri)
            if local_path:
                local_path_pathlib = Path(local_path).resolve()
                if local_path_pathlib != path_pathlib:
                    # If this exists on the local path, just copy over
                    if move_instead_of_copy:
                        os.makedirs(str(path_pathlib.absolute()), exist_ok=True)
                        self._local_path = str(path_pathlib.absolute())
                        for inner in local_path_pathlib.iterdir():
                            dest = path_pathlib / inner.name
                            if dest.exists():
                                # Ignore files that already exist.
                                # For example, checkpoints from every rank may
                                # all have a same .is_checkpoint file.
                                continue
                            shutil.move(
                                str(inner.absolute()), str(path_pathlib.absolute())
                            )
                    else:
                        _copy_dir_ignore_conflicts(local_path_pathlib, path_pathlib)
            elif external_path:
                logger.info(
                    f"Downloading checkpoint from {external_path} to {path} ..."
                )
                # If this exists on external storage (e.g. cloud), download
                download_from_uri(uri=external_path, local_path=path, filelock=False)
            else:
                raise RuntimeError(
                    f"No valid location found for checkpoint {self}: {self._uri}"
                )

        self._save_checkpoint_metadata_in_directory(path)

        if self._override_preprocessor_set and self._override_preprocessor:
            save_preprocessor_to_dir(self._override_preprocessor, path)

    def _to_directory_safe(self, path: str, move_instead_of_copy: bool = False) -> str:
        try:
            # Timeout 0 means there will be only one attempt to acquire
            # the file lock. If it cannot be acquired, a TimeoutError
            # will be thrown.
            with TempFileLock(f"{path}.lock", timeout=0):
                self._to_directory(path, move_instead_of_copy=move_instead_of_copy)
        except TimeoutError:
            # If the directory is already locked, then wait but do not do
            # anything.
            with TempFileLock(f"{path}.lock", timeout=-1):
                pass
            if not os.path.exists(path):
                raise RuntimeError(
                    f"Checkpoint directory {path} does not exist, "
                    "even though it should have been created by "
                    "another process. Please raise an issue on GitHub: "
                    "https://github.com/ray-project/ray/issues"
                )
        return path

    def _move_directory(self, path: str) -> str:
        """Move to a new directory, changing state.

        Only for local directory backed checkpoints."""
        if not self._local_path:
            raise RuntimeError(
                "_move_directory requires the checkpoint to be backed by"
                " a local directory"
            )
        path = os.path.normpath(str(path))
        _make_dir(path)
        self._local_path = self._to_directory_safe(path, move_instead_of_copy=True)
        return self._local_path
    def to_directory(self, path: Optional[str] = None) -> str:
        """Write checkpoint data to directory.

        Args:
            path: Target directory to restore data in. If not specified,
                will create a temporary directory.

        Returns:
            str: Directory containing checkpoint data.
        """
        user_provided_path = path is not None
        path = path if user_provided_path else self._get_temporary_checkpoint_dir()
        path = os.path.normpath(str(path))

        _make_dir(path, acquire_del_lock=not user_provided_path)

        return self._to_directory_safe(path)
    @contextlib.contextmanager
    def as_directory(self) -> Iterator[str]:
        """Return checkpoint directory path in a context.

        This function makes checkpoint data available as a directory while
        avoiding unnecessary copies and left-over temporary data.

        If the checkpoint is already a directory checkpoint, it will return
        the existing path. If it is not, it will create a temporary
        directory, which will be deleted after the context is exited.

        Users should treat the returned checkpoint directory as read-only
        and avoid changing any data within it, as it might get deleted when
        exiting the context.

        Example:

        .. code-block:: python

            with checkpoint.as_directory() as checkpoint_dir:
                # Do some read-only processing of files within checkpoint_dir
                pass

            # At this point, if a temporary directory was created, it will
            # have been deleted.

        """
        if self._local_path:
            yield self._local_path
        else:
            temp_dir = self.to_directory()
            del_lock_path = _get_del_lock_path(temp_dir)
            yield temp_dir

            # Cleanup
            try:
                os.remove(del_lock_path)
            except Exception:
                logger.warning(
                    f"Could not remove {del_lock_path} deletion file lock. "
                    f"Traceback:\n{traceback.format_exc()}"
                )

            # In the edge case (process crash before del lock file is
            # removed), we do not remove the directory at all.
            # Since it's in /tmp, this is not that big of a deal.
            # Check if any lock files are remaining.
            temp_dir_base_name = Path(temp_dir).name
            if not list(
                Path(temp_dir).parent.glob(
                    _get_del_lock_path(temp_dir_base_name, "*")
                )
            ):
                try:
                    # Timeout 0 means there will be only one attempt to
                    # acquire the file lock. If it cannot be acquired,
                    # a TimeoutError will be thrown.
                    with TempFileLock(f"{temp_dir}.lock", timeout=0):
                        shutil.rmtree(temp_dir, ignore_errors=True)
                except TimeoutError:
                    pass
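    # A usage sketch (illustrative only). For dict- or URI-backed
    # checkpoints, ``as_directory`` materializes a temporary directory that
    # is cleaned up on exit once no other process holds a deletion lock:
    #
    #     ckpt = Checkpoint.from_dict({"data": 123})
    #     with ckpt.as_directory() as checkpoint_dir:
    #         files = os.listdir(checkpoint_dir)  # read-only access
    #     # checkpoint_dir may no longer exist here.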
    @classmethod
    def from_uri(cls, uri: str) -> "Checkpoint":
        """Create checkpoint object from location URI (e.g. cloud storage).

        Valid locations currently include AWS S3 (``s3://``),
        Google cloud storage (``gs://``), HDFS (``hdfs://``), and
        local files (``file://``).

        Args:
            uri: Source location URI to read data from.

        Returns:
            Checkpoint: checkpoint object.
        """
        state = {}
        try:
            checkpoint_metadata_uri = os.path.join(
                uri, _CHECKPOINT_METADATA_FILE_NAME
            )
            metadata = pickle.loads(read_file_from_uri(checkpoint_metadata_uri))
        except Exception:
            pass
        else:
            cls = cls._get_checkpoint_type(metadata.checkpoint_type)
            state = metadata.checkpoint_state

        checkpoint = cls(uri=uri)
        checkpoint.__dict__.update(state)

        return checkpoint
    def to_uri(self, uri: str) -> str:
        """Write checkpoint data to location URI (e.g. cloud storage).

        Args:
            uri: Target location URI to write data to.

        Returns:
            str: Cloud location containing checkpoint data.
        """
        if uri.startswith("file://"):
            local_path = uri[7:]
            return self.to_directory(local_path)

        if not is_non_local_path_uri(uri):
            raise RuntimeError(
                f"Cannot upload checkpoint to URI: Provided URI "
                f"does not belong to a registered storage provider: `{uri}`. "
                f"Hint: {fs_hint(uri)}"
            )

        with self.as_directory() as local_path:
            checkpoint_metadata_path = os.path.join(
                local_path, _CHECKPOINT_METADATA_FILE_NAME
            )
            with open(checkpoint_metadata_path, "wb") as file:
                pickle.dump(self._metadata, file)

            upload_to_uri(local_path=local_path, uri=uri)

        return uri
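    # A persistence sketch (illustrative; the bucket URI is hypothetical).
    # ``file://`` URIs short-circuit to ``to_directory``; anything else must
    # belong to a registered remote storage provider:
    #
    #     ckpt = Checkpoint.from_dict({"data": 123})
    #     local = ckpt.to_uri("file:///tmp/my_ckpt")  # writes a directory
    #     remote = ckpt.to_uri("s3://bucket/ckpt")    # uploads, returns URI
    #     restored = Checkpoint.from_uri("s3://bucket/ckpt")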
    @DeveloperAPI
    def get_internal_representation(
        self,
    ) -> Tuple[str, Union[dict, str, ray.ObjectRef]]:
        """Return tuple of (type, data) for the internal representation.

        The internal representation can be used e.g. to compare checkpoint
        objects for equality or to access the underlying data storage.

        The returned type is a string and one of
        ``["local_path", "data_dict", "uri"]``.

        The data is the respective data value.

        Note that paths converted from ``file://...`` will be returned
        as ``local_path`` (without the ``file://`` prefix) and not as
        ``uri``.

        Returns:
            Tuple of type and data.
        """
        if self._local_path:
            return "local_path", self._local_path
        elif self._data_dict:
            return "data_dict", self._data_dict
        elif self._uri:
            return "uri", self._uri
        else:
            raise RuntimeError(
                "Cannot get internal representation of empty checkpoint."
            )
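    # An equality-check sketch (illustrative only), per the docstring above.
    # Two dict checkpoints built from equal dicts compare equal on their
    # internal representation:
    #
    #     a = Checkpoint.from_dict({"data": 123})
    #     b = Checkpoint.from_dict({"data": 123})
    #     assert (
    #         a.get_internal_representation() == b.get_internal_representation()
    #     )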
    def __getstate__(self):
        if self._local_path:
            blob = self.to_bytes()
            return self.__class__.from_bytes(blob).__getstate__()
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)

    def __fspath__(self):
        raise TypeError(
            "You cannot use `air.Checkpoint` objects directly as paths. "
            "Use `Checkpoint.to_directory()` or `Checkpoint.as_directory()` "
            "instead."
        )
    def get_preprocessor(self) -> Optional["Preprocessor"]:
        """Return the saved preprocessor, if one exists."""
        if self._override_preprocessor_set:
            return self._override_preprocessor

        # The preprocessor will either be stored in an in-memory dict or
        # written to storage. In either case, it will use the
        # PREPROCESSOR_KEY key.

        # If this is a pure directory checkpoint (not a dict checkpoint
        # saved to dir), then do not convert to dictionary as that takes
        # a lot of time and memory.
        if self.uri:
            with self.as_directory() as checkpoint_path:
                if _is_persisted_directory_checkpoint(checkpoint_path):
                    # If this is a persisted directory checkpoint, then we
                    # load the files from the temp directory created by the
                    # context. That way we avoid having to download the
                    # files twice.
                    loaded_checkpoint = self.from_directory(checkpoint_path)
                    preprocessor = _get_preprocessor(loaded_checkpoint)
                else:
                    preprocessor = load_preprocessor_from_dir(checkpoint_path)
        else:
            preprocessor = _get_preprocessor(self)

        return preprocessor
    def set_preprocessor(self, preprocessor: Optional["Preprocessor"]):
        """Saves the provided preprocessor to this Checkpoint."""
        self._override_preprocessor = preprocessor
        self._override_preprocessor_set = True
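    # A minimal sketch (illustrative only) of the preprocessor override
    # round trip. Assumes a hypothetical fitted ``Preprocessor`` instance
    # ``prep``; once set, it takes precedence over any stored preprocessor:
    #
    #     ckpt = Checkpoint.from_dict({"model": b"..."})
    #     ckpt.set_preprocessor(prep)
    #     assert ckpt.get_preprocessor() is prep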
    @classmethod
    def _get_checkpoint_type(
        cls, serialized_cls: Type["Checkpoint"]
    ) -> Type["Checkpoint"]:
        """Return the class that's a subclass of the other class, if one
        exists.

        Raises:
            ValueError: If neither class is a subclass of the other.
        """
        if issubclass(cls, serialized_cls):
            return cls
        if issubclass(serialized_cls, cls):
            return serialized_cls
        raise ValueError(
            f"You're trying to load a serialized `{serialized_cls.__name__}`, but "
            f"`{serialized_cls.__name__}` isn't compatible with `{cls.__name__}`."
        )
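    # A resolution sketch (illustrative; ``MyCheckpoint`` is hypothetical).
    # Whichever of the two classes is the subclass wins, so a generic
    # serialized ``Checkpoint`` can be loaded as a more specific type and
    # vice versa:
    #
    #     class MyCheckpoint(Checkpoint):
    #         pass
    #
    #     assert MyCheckpoint._get_checkpoint_type(Checkpoint) is MyCheckpoint
    #     assert Checkpoint._get_checkpoint_type(MyCheckpoint) is MyCheckpoint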
def _get_local_path(path: Optional[str]) -> Optional[str]:
    """Check if path is a local path. Otherwise return None."""
    if path is None or is_non_local_path_uri(path):
        return None
    if path.startswith("file://"):
        path = path[7:]
    if os.path.exists(path):
        return path
    return None


def _get_external_path(path: Optional[str]) -> Optional[str]:
    """Check if path is an external path. Otherwise return None."""
    if not isinstance(path, str) or not is_non_local_path_uri(path):
        return None
    return path


def _pack(path: str) -> bytes:
    """Pack directory in ``path`` into an archive, return as bytes string."""
    stream = io.BytesIO()

    def filter_function(tarinfo):
        if tarinfo.name.endswith(_METADATA_CHECKPOINT_SUFFIX):
            return None
        else:
            return tarinfo

    with tarfile.open(fileobj=stream, mode="w", format=tarfile.PAX_FORMAT) as tar:
        tar.add(path, arcname="", filter=filter_function)

    return stream.getvalue()


def _unpack(stream: bytes, path: str) -> str:
    """Unpack archive in bytes string into directory in ``path``."""
    with tarfile.open(fileobj=io.BytesIO(stream)) as tar:
        tar.extractall(path)
    return path


def _get_del_lock_path(path: str, pid: Optional[str] = None) -> str:
    """Get the path to the deletion lock file."""
    pid = pid if pid is not None else os.getpid()
    return f"{path}.del_lock_{pid}"


def _make_dir(path: str, acquire_del_lock: bool = False) -> None:
    """Create the temporary checkpoint dir in ``path``."""
    if acquire_del_lock:
        # Each process drops a deletion lock file it then cleans up.
        # If there are no lock files left, the last process
        # will remove the entire directory.
        del_lock_path = _get_del_lock_path(path)
        open(del_lock_path, "a").close()

    os.makedirs(path, exist_ok=True)


def _is_persisted_directory_checkpoint(path: str) -> bool:
    return Path(path, _DICT_CHECKPOINT_FILE_NAME).exists()


def _get_preprocessor(checkpoint: "Checkpoint") -> Optional["Preprocessor"]:
    # First try converting to dictionary.
    checkpoint_dict = checkpoint.to_dict()
    preprocessor = checkpoint_dict.get(PREPROCESSOR_KEY, None)

    if preprocessor is None:
        # Fallback to reading from directory.
        with checkpoint.as_directory() as checkpoint_path:
            preprocessor = load_preprocessor_from_dir(checkpoint_path)

    return preprocessor
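
# ---------------------------------------------------------------------------
# A minimal, self-contained demo (not part of the original module) of the
# dict --> directory --> dict round trip guaranteed by the class docstring.
# Guarded by ``__main__`` so importing this module stays side-effect free.
if __name__ == "__main__":
    demo_checkpoint = Checkpoint.from_dict({"data": 123})
    demo_path = demo_checkpoint.to_directory()
    recovered = Checkpoint.from_directory(demo_path).to_dict()
    assert recovered["data"] == 123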