import glob
import json
import logging
import math
import numpy as np
import os
from pathlib import Path
import random
import re
import tree # pip install dm_tree
from typing import List, Optional, TYPE_CHECKING, Union
from urllib.parse import urlparse
import zipfile
try:
from smart_open import smart_open
except ImportError:
smart_open = None
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.io_context import IOContext
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import (
DEFAULT_POLICY_ID,
MultiAgentBatch,
SampleBatch,
concat_samples,
convert_ma_batch_to_sample_batch,
)
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils.compression import unpack_if_needed
from ray.rllib.utils.spaces.space_utils import clip_action, normalize_action
from ray.rllib.utils.typing import Any, FileType, SampleBatchType
if TYPE_CHECKING:
from ray.rllib.evaluation import RolloutWorker
logger = logging.getLogger(__name__)
# Single-letter Windows drive prefixes (e.g. "c:", "d:") show up as the URL
# "scheme" when run through urlparse(), so they are treated like local paths.
WINDOWS_DRIVES = [chr(i) for i in range(ord("c"), ord("z") + 1)]
def _adjust_obs_actions_for_policy(json_data: dict, policy: Policy) -> dict:
"""Handle nested action/observation spaces for policies.
    Translates nested lists/dicts from the JSON data into proper
    np.ndarrays, according to the (nested) observation and action
    spaces of the given policy.

    Providing nested lists without this preprocessing step would
    confuse the SampleBatch constructor.
    """
for k, v in json_data.items():
data_col = (
policy.view_requirements[k].data_col
if k in policy.view_requirements
else ""
)
# No action flattening -> Process nested (leaf) action(s).
if policy.config.get("_disable_action_flattening") and (
k == SampleBatch.ACTIONS
or data_col == SampleBatch.ACTIONS
or k == SampleBatch.PREV_ACTIONS
or data_col == SampleBatch.PREV_ACTIONS
):
json_data[k] = tree.map_structure_up_to(
policy.action_space_struct,
lambda comp: np.array(comp),
json_data[k],
check_types=False,
)
# No preprocessing -> Process nested (leaf) observation(s).
elif policy.config.get("_disable_preprocessor_api") and (
k == SampleBatch.OBS
or data_col == SampleBatch.OBS
or k == SampleBatch.NEXT_OBS
or data_col == SampleBatch.NEXT_OBS
):
json_data[k] = tree.map_structure_up_to(
policy.observation_space_struct,
lambda comp: np.array(comp),
json_data[k],
check_types=False,
)
return json_data
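# Illustrative sketch only (space and values made up): with a policy whose
# action space is, say, Tuple(Box(3,), Discrete(4)) and
# `_disable_action_flattening=True`, the "actions" column arrives from JSON as
# nested plain lists mirroring `policy.action_space_struct`.
# `_adjust_obs_actions_for_policy` converts each leaf into an np.ndarray while
# keeping the tuple/dict nesting intact, so the SampleBatch constructor
# receives per-component arrays instead of ragged Python lists.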
@DeveloperAPI
def _adjust_dones(json_data: dict) -> dict:
"""Make sure DONES in json data is properly translated into TERMINATEDS."""
new_json_data = {}
for k, v in json_data.items():
# Translate DONES into TERMINATEDS.
if k == SampleBatch.DONES:
new_json_data[SampleBatch.TERMINATEDS] = v
# Leave everything else as-is.
else:
new_json_data[k] = v
return new_json_data
@DeveloperAPI
def postprocess_actions(batch: SampleBatchType, ioctx: IOContext) -> SampleBatchType:
    """Clips and/or re-normalizes the actions in `batch` based on the IO config."""
# Clip actions (from any values into env's bounds), if necessary.
cfg = ioctx.config
# TODO(jungong): We should not clip_action in input reader.
# Use connector to handle this.
if cfg.get("clip_actions"):
if ioctx.worker is None:
raise ValueError(
"clip_actions is True but cannot clip actions since no workers exist"
)
if isinstance(batch, SampleBatch):
policy = ioctx.worker.policy_map.get(DEFAULT_POLICY_ID)
if policy is None:
assert len(ioctx.worker.policy_map) == 1
policy = next(iter(ioctx.worker.policy_map.values()))
batch[SampleBatch.ACTIONS] = clip_action(
batch[SampleBatch.ACTIONS], policy.action_space_struct
)
else:
for pid, b in batch.policy_batches.items():
b[SampleBatch.ACTIONS] = clip_action(
b[SampleBatch.ACTIONS],
ioctx.worker.policy_map[pid].action_space_struct,
)
# Re-normalize actions (from env's bounds to zero-centered), if
# necessary.
if (
cfg.get("actions_in_input_normalized") is False
and cfg.get("normalize_actions") is True
):
if ioctx.worker is None:
raise ValueError(
"actions_in_input_normalized is False but"
"cannot normalize actions since no workers exist"
)
# If we have a complex action space and actions were flattened
# and we have to normalize -> Error.
error_msg = (
"Normalization of offline actions that are flattened is not "
"supported! Make sure that you record actions into offline "
"file with the `_disable_action_flattening=True` flag OR "
"as already normalized (between -1.0 and 1.0) values. "
"Also, when reading already normalized action values from "
"offline files, make sure to set "
"`actions_in_input_normalized=True` so that RLlib will not "
"perform normalization on top."
)
if isinstance(batch, SampleBatch):
policy = ioctx.worker.policy_map.get(DEFAULT_POLICY_ID)
if policy is None:
assert len(ioctx.worker.policy_map) == 1
policy = next(iter(ioctx.worker.policy_map.values()))
if isinstance(
policy.action_space_struct, (tuple, dict)
) and not policy.config.get("_disable_action_flattening"):
raise ValueError(error_msg)
batch[SampleBatch.ACTIONS] = normalize_action(
batch[SampleBatch.ACTIONS], policy.action_space_struct
)
else:
for pid, b in batch.policy_batches.items():
policy = ioctx.worker.policy_map[pid]
if isinstance(
policy.action_space_struct, (tuple, dict)
) and not policy.config.get("_disable_action_flattening"):
raise ValueError(error_msg)
b[SampleBatch.ACTIONS] = normalize_action(
b[SampleBatch.ACTIONS],
ioctx.worker.policy_map[pid].action_space_struct,
)
return batch
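# Config-driven behavior sketch (flag names as used above; exact defaults are
# not asserted here): a config like
#   {"clip_actions": True,
#    "actions_in_input_normalized": False,
#    "normalize_actions": True}
# makes `postprocess_actions` first clip the stored actions into the env's
# bounds and then map them from the env's bounds into the normalized,
# zero-centered range the policy expects; if neither condition applies, the
# batch is returned unchanged.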
@DeveloperAPI
def from_json_data(
    json_data: Any, worker: Optional["RolloutWorker"]
) -> SampleBatchType:
    """Converts parsed JSON data into a SampleBatch or MultiAgentBatch."""
# Try to infer the SampleBatchType (SampleBatch or MultiAgentBatch).
if "type" in json_data:
data_type = json_data.pop("type")
else:
raise ValueError("JSON record missing 'type' field")
if data_type == "SampleBatch":
if worker is not None and len(worker.policy_map) != 1:
raise ValueError(
"Found single-agent SampleBatch in input file, but our "
"PolicyMap contains more than 1 policy!"
)
for k, v in json_data.items():
json_data[k] = unpack_if_needed(v)
if worker is not None:
policy = next(iter(worker.policy_map.values()))
json_data = _adjust_obs_actions_for_policy(json_data, policy)
json_data = _adjust_dones(json_data)
return SampleBatch(json_data)
elif data_type == "MultiAgentBatch":
policy_batches = {}
for policy_id, policy_batch in json_data["policy_batches"].items():
inner = {}
for k, v in policy_batch.items():
# Translate DONES into TERMINATEDS.
if k == SampleBatch.DONES:
k = SampleBatch.TERMINATEDS
inner[k] = unpack_if_needed(v)
if worker is not None:
policy = worker.policy_map[policy_id]
inner = _adjust_obs_actions_for_policy(inner, policy)
inner = _adjust_dones(inner)
policy_batches[policy_id] = SampleBatch(inner)
return MultiAgentBatch(policy_batches, json_data["count"])
else:
        raise ValueError(
            "Type field must be one of ['SampleBatch', 'MultiAgentBatch'], "
            "but got {}!".format(data_type)
        )
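# Shape of the on-disk records accepted by `from_json_data` (field names taken
# from the code above; the value placeholders are illustrative):
#   {"type": "SampleBatch", "obs": ..., "actions": ..., "dones": ..., ...}
# or, for multi-agent data:
#   {"type": "MultiAgentBatch", "count": <env steps>,
#    "policy_batches": {"<policy_id>": {"obs": ..., "actions": ..., ...}}}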
# TODO(jungong) : use DatasetReader to back JsonReader, so we reduce
# codebase complexity without losing existing functionality.
@PublicAPI
class JsonReader(InputReader):
"""Reader object that loads experiences from JSON file chunks.
    The input files will be read in random order.
"""
@PublicAPI
def __init__(
self, inputs: Union[str, List[str]], ioctx: Optional[IOContext] = None
):
"""Initializes a JsonReader instance.
Args:
inputs: Either a glob expression for files, e.g. `/tmp/**/*.json`,
or a list of single file paths or URIs, e.g.,
["s3://bucket/file.json", "s3://bucket/file2.json"].
ioctx: Current IO context object or None.
"""
        logger.info(
            "You are using JsonReader. It is recommended to use "
            "DatasetReader instead for better sharding support."
        )
self.ioctx = ioctx or IOContext()
self.default_policy = self.policy_map = None
self.batch_size = 1
if self.ioctx:
self.batch_size = self.ioctx.config.get("train_batch_size", 1)
num_workers = self.ioctx.config.get("num_env_runners", 0)
if num_workers:
self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)
if self.ioctx.worker is not None:
self.policy_map = self.ioctx.worker.policy_map
self.default_policy = self.policy_map.get(DEFAULT_POLICY_ID)
if self.default_policy is None:
assert len(self.policy_map) == 1
self.default_policy = next(iter(self.policy_map.values()))
if isinstance(inputs, str):
inputs = os.path.abspath(os.path.expanduser(inputs))
if os.path.isdir(inputs):
inputs = [os.path.join(inputs, "*.json"), os.path.join(inputs, "*.zip")]
logger.warning(f"Treating input directory as glob patterns: {inputs}")
else:
inputs = [inputs]
if any(urlparse(i).scheme not in [""] + WINDOWS_DRIVES for i in inputs):
raise ValueError(
"Don't know how to glob over `{}`, ".format(inputs)
+ "please specify a list of files to read instead."
)
else:
self.files = []
for i in inputs:
self.files.extend(glob.glob(i))
elif isinstance(inputs, (list, tuple)):
self.files = list(inputs)
else:
            raise ValueError(
                "type of inputs must be list or str, not {}".format(type(inputs))
            )
if self.files:
logger.info("Found {} input files.".format(len(self.files)))
else:
raise ValueError("No files found matching {}".format(inputs))
self.cur_file = None
@override(InputReader)
def next(self) -> SampleBatchType:
ret = []
count = 0
while count < self.batch_size:
batch = self._try_parse(self._next_line())
tries = 0
while not batch and tries < 100:
tries += 1
logger.debug("Skipping empty line in {}".format(self.cur_file))
batch = self._try_parse(self._next_line())
if not batch:
raise ValueError(
"Failed to read valid experience batch from file: {}".format(
self.cur_file
)
)
batch = self._postprocess_if_needed(batch)
count += batch.count
ret.append(batch)
ret = concat_samples(ret)
return ret
def read_all_files(self) -> SampleBatchType:
"""Reads through all files and yields one SampleBatchType per line.
        Each input file is read exactly once; reading of a file stops at
        end-of-file or at the first blank or unparsable line.
Yields:
One SampleBatch or MultiAgentBatch per line in all input files.
"""
for path in self.files:
file = self._try_open_file(path)
while True:
line = file.readline()
if not line:
break
batch = self._try_parse(line)
if batch is None:
break
yield batch
def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType:
if not self.ioctx.config.get("postprocess_inputs"):
return batch
batch = convert_ma_batch_to_sample_batch(batch)
if isinstance(batch, SampleBatch):
out = []
for sub_batch in batch.split_by_episode():
out.append(self.default_policy.postprocess_trajectory(sub_batch))
return concat_samples(out)
else:
# TODO(ekl) this is trickier since the alignments between agent
# trajectories in the episode are not available any more.
raise NotImplementedError(
"Postprocessing of multi-agent data not implemented yet."
)
def _try_open_file(self, path):
if urlparse(path).scheme not in [""] + WINDOWS_DRIVES:
if smart_open is None:
raise ValueError(
"You must install the `smart_open` module to read "
"from URIs like {}".format(path)
)
ctx = smart_open
else:
# Allow shortcut for home directory ("~/" -> env[HOME]).
if path.startswith("~/"):
path = os.path.join(os.environ.get("HOME", ""), path[2:])
            # If path doesn't exist, try to interpret it as relative to the
            # rllib directory (located ../../ from this very module).
path_orig = path
if not os.path.exists(path):
path = os.path.join(Path(__file__).parent.parent, path)
if not os.path.exists(path):
raise FileNotFoundError(f"Offline file {path_orig} not found!")
# Unzip files, if necessary and re-point to extracted json file.
if re.search("\\.zip$", path):
with zipfile.ZipFile(path, "r") as zip_ref:
zip_ref.extractall(Path(path).parent)
path = re.sub("\\.zip$", ".json", path)
assert os.path.exists(path)
ctx = open
file = ctx(path, "r")
return file
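    # Path-routing sketch for `_try_open_file`: paths whose urlparse() scheme
    # is neither empty nor a single Windows drive letter (e.g.
    # "s3://bucket/file.json") are opened via `smart_open`; everything else
    # ("~/..." shortcuts, paths relative to the rllib directory, local
    # ".zip"/".json" files) is opened with the builtin open() after the
    # resolution and unzip steps above.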
def _try_parse(self, line: str) -> Optional[SampleBatchType]:
line = line.strip()
if not line:
return None
try:
batch = self._from_json(line)
except Exception:
logger.exception(
"Ignoring corrupt json record in {}: {}".format(self.cur_file, line)
)
return None
batch = postprocess_actions(batch, self.ioctx)
return batch
def _next_line(self) -> str:
if not self.cur_file:
self.cur_file = self._next_file()
line = self.cur_file.readline()
tries = 0
while not line and tries < 100:
tries += 1
if hasattr(self.cur_file, "close"): # legacy smart_open impls
self.cur_file.close()
self.cur_file = self._next_file()
line = self.cur_file.readline()
if not line:
logger.debug("Ignoring empty file {}".format(self.cur_file))
if not line:
raise ValueError(
"Failed to read next line from files: {}".format(self.files)
)
return line
def _next_file(self) -> FileType:
        # If this is the first time we open a file, make sure all workers
        # start with a different one if possible.
if self.cur_file is None and self.ioctx.worker is not None:
idx = self.ioctx.worker.worker_index
total = self.ioctx.worker.num_workers or 1
path = self.files[round((len(self.files) - 1) * (idx / total))]
# After the first file, pick all others randomly.
else:
path = random.choice(self.files)
return self._try_open_file(path)
def _from_json(self, data: str) -> SampleBatchType:
if isinstance(data, bytes): # smart_open S3 doesn't respect "r"
data = data.decode("utf-8")
json_data = json.loads(data)
return from_json_data(json_data, self.ioctx.worker)
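# Minimal usage sketch (illustrative only; "/tmp/demo-out" is a made-up path):
#
#   from ray.rllib.offline.json_reader import JsonReader
#
#   reader = JsonReader("/tmp/demo-out")  # a dir, glob, or list of files/URIs
#   batch = reader.next()                 # at least `self.batch_size` timesteps
#   for b in reader.read_all_files():     # or iterate every stored record once
#       ...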