Source code for ray.rllib.utils.replay_buffers.reservoir_replay_buffer

from typing import Any, Dict
import random

# Importing ray before psutil makes sure we use psutil's bundled version
import ray  # noqa F401
import psutil  # noqa E402

from ray.rllib.utils.annotations import ExperimentalAPI, override
from ray.rllib.utils.replay_buffers.replay_buffer import (
    ReplayBuffer,
    warn_replay_capacity,
)
from ray.rllib.utils.typing import SampleBatchType


# __sphinx_doc_reservoir_buffer__begin__
@ExperimentalAPI
class ReservoirReplayBuffer(ReplayBuffer):
    """This buffer implements reservoir sampling.

    The algorithm has been described by Jeffrey S. Vitter in
    "Random sampling with a reservoir".
    """
    def __init__(
        self, capacity: int = 10000, storage_unit: str = "timesteps", **kwargs
    ):
        """Initializes a ReservoirReplayBuffer instance.

        Args:
            capacity: Max number of timesteps to store in this buffer.
                Once the buffer is full, newly added batches replace
                randomly chosen, already stored ones (reservoir sampling),
                so that every batch ever added has the same probability of
                being kept.
            storage_unit: Either 'timesteps', 'sequences' or 'episodes'.
                Specifies how experiences are stored.
        """
        ReplayBuffer.__init__(self, capacity, storage_unit)
        self._num_add_calls = 0
        self._num_evicted = 0
    @ExperimentalAPI
    @override(ReplayBuffer)
    def _add_single_batch(self, item: SampleBatchType, **kwargs) -> None:
        """Add a SampleBatch of experiences to self._storage.

        An item consists of either one or more timesteps, a sequence or an
        episode. Differs from add() in that it does not consider the storage
        unit or type of batch and simply stores it.

        Args:
            item: The batch to be added.
            **kwargs: Forward compatibility kwargs.
        """
        # Update our timestep counts.
        self._num_timesteps_added += item.count
        self._num_timesteps_added_wrap += item.count

        # Update add counts.
        self._num_add_calls += 1

        if self._num_timesteps_added < self.capacity:
            # Buffer not full yet: Simply append the new item.
            self._storage.append(item)
            self._est_size_bytes += item.size_bytes()
        else:
            # Buffer is "full": New items may now replace randomly chosen,
            # already stored ones (eviction has started).
            self._eviction_started = True
            # Reservoir sampling: Draw a slot uniformly from all add calls so
            # far. Only if it falls inside the storage does the new item
            # replace the item at that slot; otherwise the new item is
            # discarded.
            idx = random.randint(0, self._num_add_calls - 1)
            if idx < len(self._storage):
                self._num_evicted += 1
                self._evicted_hit_stats.push(self._hit_count[idx])
                self._hit_count[idx] = 0
                # This is a bit of a hack: ReplayBuffer always inserts at
                # self._next_idx.
                self._next_idx = idx
                item_to_be_removed = self._storage[idx]
                self._est_size_bytes -= item_to_be_removed.size_bytes()
                self._storage[idx] = item
                self._est_size_bytes += item.size_bytes()

            assert item.count > 0, item
            warn_replay_capacity(item=item, num_items=self.capacity / item.count)
    @ExperimentalAPI
    @override(ReplayBuffer)
    def stats(self, debug: bool = False) -> dict:
        """Returns the stats of this buffer.

        Args:
            debug: If True, adds sample eviction statistics to the returned
                stats dict.

        Returns:
            A dictionary of stats about this buffer.
        """
        data = {
            "num_evicted": self._num_evicted,
            "num_add_calls": self._num_add_calls,
        }
        parent = ReplayBuffer.stats(self, debug)
        parent.update(data)
        return parent
    @ExperimentalAPI
    @override(ReplayBuffer)
    def get_state(self) -> Dict[str, Any]:
        """Returns all local state.

        Returns:
            The serializable local state.
        """
        parent = ReplayBuffer.get_state(self)
        parent.update(self.stats())
        return parent
    @ExperimentalAPI
    @override(ReplayBuffer)
    def set_state(self, state: Dict[str, Any]) -> None:
        """Restores all local state to the provided `state`.

        Args:
            state: The new state to set this buffer to. Can be obtained by
                calling `self.get_state()`.
        """
        self._num_evicted = state["num_evicted"]
        self._num_add_calls = state["num_add_calls"]
        ReplayBuffer.set_state(self, state)
# __sphinx_doc_reservoir_buffer__end__
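

# A minimal usage sketch (assuming RLlib's SampleBatch and the add()/sample()
# API inherited from ReplayBuffer; the batch fields below are only
# illustrative). Once the buffer is full, every batch ever added has had the
# same probability of still being stored.
if __name__ == "__main__":
    from ray.rllib.policy.sample_batch import SampleBatch

    buffer = ReservoirReplayBuffer(capacity=100, storage_unit="timesteps")
    for i in range(1000):
        # Add one-timestep batches; only a uniformly random subset of them
        # remains once the capacity of 100 timesteps is exceeded.
        buffer.add(SampleBatch({"obs": [i], "actions": [0], "rewards": [1.0]}))

    # Sample a batch of 10 items (with replacement) and print eviction stats.
    sample = buffer.sample(num_items=10)
    print(buffer.stats(debug=True))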