from collections import deque
import threading
from typing import Dict, Set
import logging
import ray
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import OldAPIStack, override
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.typing import PolicyID
tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)
[docs]
@OldAPIStack
class PolicyMap(dict):
"""Maps policy IDs to Policy objects.
Thereby, keeps n policies in memory and - when capacity is reached -
writes the least recently used to disk. This allows adding 100s of
policies to a Algorithm for league-based setups w/o running out of memory.
"""
[docs]
def __init__(
self,
*,
capacity: int = 100,
policy_states_are_swappable: bool = False,
# Deprecated args.
worker_index=None,
num_workers=None,
policy_config=None,
session_creator=None,
seed=None,
):
"""Initializes a PolicyMap instance.
Args:
capacity: The size of the Policy object cache. This is the maximum number
of policies that are held in RAM memory. When reaching this capacity,
the least recently used Policy's state will be stored in the Ray object
store and recovered from there when being accessed again.
policy_states_are_swappable: Whether all Policy objects in this map can be
"swapped out" via a simple `state = A.get_state(); B.set_state(state)`,
where `A` and `B` are policy instances in this map. You should set
this to True for significantly speeding up the PolicyMap's cache lookup
times, iff your policies all share the same neural network
architecture and optimizer types. If True, the PolicyMap will not
have to garbage collect old, least recently used policies, but instead
keep them in memory and simply override their state with the state of
the most recently accessed one.
For example, in a league-based training setup, you might have 100s of
the same policies in your map (playing against each other in various
combinations), but all of them share the same state structure
(are "swappable").
"""
if policy_config is not None:
deprecation_warning(
old="PolicyMap(policy_config=..)",
error=True,
)
super().__init__()
self.capacity = capacity
if any(
i is not None
for i in [policy_config, worker_index, num_workers, session_creator, seed]
):
deprecation_warning(
old="PolicyMap([deprecated args]...)",
new="PolicyMap(capacity=..., policy_states_are_swappable=...)",
error=False,
)
self.policy_states_are_swappable = policy_states_are_swappable
# The actual cache with the in-memory policy objects.
self.cache: Dict[str, Policy] = {}
# Set of keys that may be looked up (cached or not).
self._valid_keys: Set[str] = set()
# The doubly-linked list holding the currently in-memory objects.
self._deque = deque()
# Ray object store references to the stashed Policy states.
self._policy_state_refs = {}
# Lock used for locking some methods on the object-level.
# This prevents possible race conditions when accessing the map
# and the underlying structures, like self._deque and others.
self._lock = threading.RLock()
@with_lock
@override(dict)
def __getitem__(self, item: PolicyID):
# Never seen this key -> Error.
if item not in self._valid_keys:
raise KeyError(
f"PolicyID '{item}' not found in this PolicyMap! "
f"IDs stored in this map: {self._valid_keys}."
)
# Item already in cache -> Rearrange deque (promote `item` to
# "most recently used") and return it.
if item in self.cache:
self._deque.remove(item)
self._deque.append(item)
return self.cache[item]
# Item not currently in cache -> Get from stash and - if at capacity -
# remove leftmost one.
if item not in self._policy_state_refs:
raise AssertionError(
f"PolicyID {item} not found in internal Ray object store cache!"
)
policy_state = ray.get(self._policy_state_refs[item])
policy = None
# We are at capacity: Remove the oldest policy from deque as well as the
# cache and return it.
if len(self._deque) == self.capacity:
policy = self._stash_least_used_policy()
# All our policies have same NN-architecture (are "swappable").
# -> Load new policy's state into the one that just got removed from the cache.
# This way, we save the costly re-creation step.
if policy is not None and self.policy_states_are_swappable:
logger.debug(f"restoring policy: {item}")
policy.set_state(policy_state)
else:
logger.debug(f"creating new policy: {item}")
policy = Policy.from_state(policy_state)
self.cache[item] = policy
# Promote the item to most recently one.
self._deque.append(item)
return policy
@with_lock
@override(dict)
def __setitem__(self, key: PolicyID, value: Policy):
# Item already in cache -> Rearrange deque.
if key in self.cache:
self._deque.remove(key)
# Item not currently in cache -> store new value and - if at capacity -
# remove leftmost one.
else:
# Cache at capacity -> Drop leftmost item.
if len(self._deque) == self.capacity:
self._stash_least_used_policy()
# Promote `key` to "most recently used".
self._deque.append(key)
# Update our cache.
self.cache[key] = value
self._valid_keys.add(key)
@with_lock
@override(dict)
def __delitem__(self, key: PolicyID):
# Make key invalid.
self._valid_keys.remove(key)
# Remove policy from deque if contained
if key in self._deque:
self._deque.remove(key)
# Remove policy from memory if currently cached.
if key in self.cache:
policy = self.cache[key]
self._close_session(policy)
del self.cache[key]
# Remove Ray object store reference (if this ID has already been stored
# there), so the item gets garbage collected.
if key in self._policy_state_refs:
del self._policy_state_refs[key]
@override(dict)
def __iter__(self):
return iter(self.keys())
[docs]
@override(dict)
def items(self):
"""Iterates over all policies, even the stashed ones."""
def gen():
for key in self._valid_keys:
yield (key, self[key])
return gen()
[docs]
@override(dict)
def keys(self):
"""Returns all valid keys, even the stashed ones."""
self._lock.acquire()
ks = list(self._valid_keys)
self._lock.release()
def gen():
for key in ks:
yield key
return gen()
[docs]
@override(dict)
def values(self):
"""Returns all valid values, even the stashed ones."""
self._lock.acquire()
vs = [self[k] for k in self._valid_keys]
self._lock.release()
def gen():
for value in vs:
yield value
return gen()
@with_lock
@override(dict)
def update(self, __m, **kwargs):
"""Updates the map with the given dict and/or kwargs."""
for k, v in __m.items():
self[k] = v
for k, v in kwargs.items():
self[k] = v
@with_lock
@override(dict)
def get(self, key: PolicyID):
"""Returns the value for the given key or None if not found."""
if key not in self._valid_keys:
return None
return self[key]
@with_lock
@override(dict)
def __len__(self) -> int:
"""Returns number of all policies, including the stashed-to-disk ones."""
return len(self._valid_keys)
@with_lock
@override(dict)
def __contains__(self, item: PolicyID):
return item in self._valid_keys
@override(dict)
def __str__(self) -> str:
# Only print out our keys (policy IDs), not values as this could trigger
# the LRU caching.
return (
f"<PolicyMap lru-caching-capacity={self.capacity} policy-IDs="
f"{list(self.keys())}>"
)
def _stash_least_used_policy(self) -> Policy:
"""Writes the least-recently used policy's state to the Ray object store.
Also closes the session - if applicable - of the stashed policy.
Returns:
The least-recently used policy, that just got removed from the cache.
"""
# Get policy's state for writing to object store.
dropped_policy_id = self._deque.popleft()
assert dropped_policy_id in self.cache
policy = self.cache[dropped_policy_id]
policy_state = policy.get_state()
# If we don't simply swap out vs an existing policy:
# Close the tf session, if any.
if not self.policy_states_are_swappable:
self._close_session(policy)
# Remove from memory. This will clear the tf Graph as well.
del self.cache[dropped_policy_id]
# Store state in Ray object store.
self._policy_state_refs[dropped_policy_id] = ray.put(policy_state)
# Return the just removed policy, in case it's needed by the caller.
return policy
@staticmethod
def _close_session(policy: Policy):
sess = policy.get_session()
# Closes the tf session, if any.
if sess is not None:
sess.close()