RLlib Models, Preprocessors, and Action Distributions
The following diagram provides a conceptual overview of data flow between different components in RLlib. We start with an Environment, which, given an action, produces an observation. The observation is preprocessed by a Preprocessor and Filter (e.g., for running mean normalization) before being sent to a neural network Model. The model output is in turn interpreted by an ActionDistribution to determine the next action.
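To make the Filter stage concrete, here is a minimal pure-Python sketch of running mean/std normalization. It assumes Welford's online algorithm with per-dimension statistics; the class name RunningMeanStdFilter is hypothetical and this is not RLlib's own filter implementation.

```python
import math


class RunningMeanStdFilter:
    """Hypothetical running mean/std observation filter (illustration only).

    Tracks per-dimension statistics with Welford's online algorithm and
    normalizes each observation with the statistics seen so far.
    """

    def __init__(self, size):
        self.count = 0
        self.mean = [0.0] * size
        self.m2 = [0.0] * size  # running sum of squared deviations

    def update(self, obs):
        self.count += 1
        for i, x in enumerate(obs):
            delta = x - self.mean[i]
            self.mean[i] += delta / self.count
            self.m2[i] += delta * (x - self.mean[i])

    def std(self):
        # Fall back to unit scale until there are enough samples.
        if self.count < 2:
            return [1.0] * len(self.mean)
        return [math.sqrt(m2 / (self.count - 1)) for m2 in self.m2]

    def __call__(self, obs, update=True):
        if update:
            self.update(obs)
        return [(x - m) / (s + 1e-8)
                for x, m, s in zip(obs, self.mean, self.std())]
```

Because statistics change as data arrives, early observations are normalized with noisy estimates; a distributed setup would additionally need to keep such statistics consistent across rollout workers.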
The components highlighted in green can be replaced with custom user-defined implementations, as described in the next sections. The purple components are RLlib internal, which means they can only be modified by changing the algorithm source code.
Default Behaviours
Built-in Models and Preprocessors
RLlib picks default models based on a simple heuristic: a vision network for observations whose shape has more than two dimensions (for example, (84 x 84 x 3)), and a fully connected network for everything else. These models can be configured via the model config key, documented in the model catalog. Note that you’ll probably have to configure conv_filters if your environment observations have custom sizes, e.g., "model": {"dim": 42, "conv_filters": [[16, [4, 4], 2], [32, [4, 4], 2], [512, [11, 11], 1]]} for 42x42 observations.
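As a sanity check on the 42x42 example, the sketch below computes the spatial size after each entry of conv_filters. It assumes TF-style "same" padding for every layer except the last, which uses "valid" padding; under that assumption the last kernel has to cover the whole remaining feature map so the output collapses to 1x1. The helper names are hypothetical.

```python
import math


def conv_output_size(size, kernel, stride, padding):
    """Spatial output size of a square conv layer with TF-style padding."""
    if padding == "same":
        return math.ceil(size / stride)
    if padding == "valid":
        return math.ceil((size - kernel + 1) / stride)
    raise ValueError(f"unknown padding: {padding}")


def final_feature_size(dim, conv_filters):
    """Apply each [out_channels, [k, k], stride] entry in order.

    Assumption: all layers but the last use "same" padding, the last "valid".
    """
    size = dim
    for i, (_, kernel, stride) in enumerate(conv_filters):
        padding = "valid" if i == len(conv_filters) - 1 else "same"
        size = conv_output_size(size, kernel[0], stride, padding)
    return size


# 42 -> 21 -> 11 -> 1 for the filters suggested above.
print(final_feature_size(42, [[16, [4, 4], 2], [32, [4, 4], 2], [512, [11, 11], 1]]))
```

If the final size comes out larger than 1, the last kernel does not cover the remaining map, which is a common reason to adjust conv_filters for a custom dim.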
In addition, if you set "model": {"use_lstm": True}, then the model output will be further processed by an LSTM cell. More generally, RLlib supports the use of recurrent models for its policy gradient algorithms (A3C, PPO, PG, IMPALA), and RNN support is built into its policy evaluation utilities.
For preprocessors, RLlib tries to pick one of its built-in preprocessors based on the environment’s observation space. Discrete observations are one-hot encoded, Atari observations downscaled, and Tuple and Dict observations flattened (these are unflattened and accessible via the input_dict parameter in custom models). Note that for Atari, RLlib defaults to using the DeepMind preprocessors, which are also used by the OpenAI baselines library.
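The default handling of Discrete and Tuple observations can be sketched in plain Python: a Discrete(n) value becomes a length-n one-hot vector, and Tuple components are flattened and concatenated into one vector. The ("discrete", n) / ("box", size) space description used here is a hypothetical simplification, not RLlib's preprocessor API.

```python
def one_hot(index, n):
    """One-hot encode a Discrete(n) observation."""
    vec = [0.0] * n
    vec[index] = 1.0
    return vec


def flatten_tuple_obs(obs, spaces):
    """Flatten a Tuple observation into one fixed-size vector.

    `spaces` is a hypothetical per-component description: ("discrete", n)
    or ("box", size). Discrete parts are one-hot encoded, Box parts are
    copied as floats, and the results are concatenated in order.
    """
    flat = []
    for value, (kind, n) in zip(obs, spaces):
        if kind == "discrete":
            flat.extend(one_hot(value, n))
        else:
            if len(value) != n:
                raise ValueError("Box component has the wrong size")
            flat.extend(float(x) for x in value)
    return flat


print(flatten_tuple_obs((1, [0.5, -0.5]), [("discrete", 3), ("box", 2)]))
```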
Built-in Model Parameters
The following is a list of the built-in model hyperparameters:
from ray.rllib.utils.deprecation import DEPRECATED_VALUE

MODEL_DEFAULTS = {
    # === Built-in options ===
    # Filter config. List of [out_channels, kernel, stride] for each filter
    "conv_filters": None,
    # Nonlinearity for built-in convnet
    "conv_activation": "relu",
    # Nonlinearity for fully connected net (tanh, relu)
    "fcnet_activation": "tanh",
    # Sizes of the hidden layers for fully connected net
    "fcnet_hiddens": [256, 256],
    # For DiagGaussian action distributions, make the second half of the model
    # outputs floating bias variables instead of state-dependent. This only
    # has an effect if using the default fully connected net.
    "free_log_std": False,
    # Whether to skip the final linear layer used to resize the hidden layer
    # outputs to size `num_outputs`. If True, then the last hidden layer
    # should already match num_outputs.
    "no_final_linear": False,
    # Whether layers should be shared for the value function.
    "vf_share_layers": True,

    # == LSTM ==
    # Whether to wrap the model with an LSTM.
    "use_lstm": False,
    # Max seq len for training the LSTM, defaults to 20.
    "max_seq_len": 20,
    # Size of the LSTM cell.
    "lstm_cell_size": 256,
    # Whether to feed a_{t-1}, r_{t-1} to LSTM.
    "lstm_use_prev_action_reward": False,
    # When using modelv1 models with a modelv2 algorithm, you may have to
    # define the state shape here (e.g., [256, 256]).
    "state_shape": None,

    # == Atari ==
    # Whether to enable framestack for Atari envs
    "framestack": True,
    # Final resized frame dimension
    "dim": 84,
    # (deprecated) Converts ATARI frame to 1-channel grayscale image
    "grayscale": False,
    # (deprecated) Changes frame to range from [-1, 1] if true
    "zero_mean": True,

    # === Options for custom models ===
    # Name of a custom model to use
    "custom_model": None,
    # Extra options to pass to the custom classes. These will be available
    # to the Model's constructor in the model_config field.
    "custom_model_config": {},
    # Name of a custom action distribution to use.
    "custom_action_dist": None,
    # Custom preprocessors are deprecated. Please use a wrapper class around
    # your environment instead to preprocess observations.
    "custom_preprocessor": None,
    # Deprecated config keys.
    "custom_options": DEPRECATED_VALUE,
}
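To illustrate how a partial "model" config relates to these defaults, here is a minimal sketch assuming a simple overlay: keys you set replace the defaults, and every other key keeps its default value. MODEL_DEFAULTS_EXCERPT and resolve_model_config() are hypothetical; RLlib performs its own config merging internally, which may differ in detail.

```python
import copy

# Hypothetical excerpt of the defaults above, just enough for the sketch.
MODEL_DEFAULTS_EXCERPT = {
    "fcnet_hiddens": [256, 256],
    "fcnet_activation": "tanh",
    "use_lstm": False,
    "lstm_cell_size": 256,
    "custom_model_config": {},
}


def resolve_model_config(user_model_config):
    """Overlay user-supplied keys on a copy of the defaults."""
    unknown = set(user_model_config) - set(MODEL_DEFAULTS_EXCERPT)
    if unknown:
        raise KeyError(f"Unknown model config keys: {sorted(unknown)}")
    # Deep-copy so mutable defaults (lists, dicts) are never shared.
    merged = copy.deepcopy(MODEL_DEFAULTS_EXCERPT)
    merged.update(user_model_config)
    return merged


cfg = resolve_model_config({"fcnet_hiddens": [64], "use_lstm": True})
print(cfg["fcnet_hiddens"], cfg["lstm_cell_size"])
```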
TensorFlow Models
Note
TFModelV2 replaces the previous rllib.models.Model class, which did not support Keras-style reuse of variables. The rllib.models.Model class is deprecated and should not be used.
Custom TF models should subclass TFModelV2 to implement the __init__() and forward() methods. forward() takes in a dict of tensor inputs (the observation obs, prev_action, prev_reward, and is_training), optional RNN state, and returns the model output of size num_outputs and the new state. You can also override extra methods of the model such as value_function to implement a custom value branch. Additional supervised / self-supervised losses can be added via the custom_loss method:

class ray.rllib.models.tf.tf_modelv2.TFModelV2(obs_space, action_space, num_outputs, model_config, name)
TF version of ModelV2.
Note that this class by itself is not a valid model unless you implement forward() in a subclass.

__init__(obs_space, action_space, num_outputs, model_config, name)
Initialize a TFModelV2.
Here is an example implementation for a subclass MyModelClass(TFModelV2):
    def __init__(self, *args, **kwargs):
        super(MyModelClass, self).__init__(*args, **kwargs)
        input_layer = tf.keras.layers.Input(...)
        hidden_layer = tf.keras.layers.Dense(...)(input_layer)
        output_layer = tf.keras.layers.Dense(...)(hidden_layer)
        value_layer = tf.keras.layers.Dense(...)(hidden_layer)
        self.base_model = tf.keras.Model(
            input_layer, [output_layer, value_layer])
        self.register_variables(self.base_model.variables)

forward(input_dict, state, seq_lens)
Call the model with the given input tensors and state.
Any complex observations (dicts, tuples, etc.) will be unpacked by __call__ before being passed to forward(). To access the flattened observation tensor, refer to input_dict["obs_flat"].
This method can be called any number of times. In eager execution, each call to forward() will eagerly evaluate the model. In symbolic execution, each call to forward() creates a computation graph that operates over the variables of this model (i.e., shares weights).
Custom models should override this instead of __call__.
Parameters
input_dict (dict) – dictionary of input tensors, including "obs", "obs_flat", "prev_action", "prev_reward", "is_training"
state (list) – list of state tensors with sizes matching those returned by get_initial_state + the batch dimension
seq_lens (Tensor) – 1-D tensor holding input sequence lengths
Returns
The model output tensor of size [BATCH, num_outputs] and the new state.
Return type
(outputs, state)
Examples
>>> def forward(self, input_dict, state, seq_lens):
...     model_out, self._value_out = self.base_model(
...         input_dict["obs"])
...     return model_out, state

value_function()
Returns the value function output for the most recent forward pass.
Note that forward() has to be called first, before this method can return anything; calling this method therefore does not cause an extra forward pass through the network.
Returns
Value estimate tensor of shape [BATCH].

custom_loss(policy_loss, loss_inputs)
Override to customize the loss function used to optimize this model.
This can be used to incorporate self-supervised losses (by defining a loss over existing input and output tensors of this model), and supervised losses (by defining losses over a variable-sharing copy of this model’s layers).
You can find a runnable example in examples/custom_loss.py.
Parameters
policy_loss (Union[List[Tensor],Tensor]) – List of or single policy loss(es) from the policy.
loss_inputs (dict) – map of input placeholders for rollout data.
Returns
List of or scalar tensor for the customized loss(es) for this model.
Return type
Union[List[Tensor],Tensor]

metrics()
Override to return custom metrics from your model.
The stats will be reported as part of the learner stats, i.e.,
info:
    learner:
        model:
            key1: metric1
            key2: metric2
Returns
Dict of string keys to scalar tensors.

update_ops()
Return the list of update ops for this model.
For example, this should include any BatchNorm update ops.

variables(as_dict=False)
Returns the list (or a dict) of variables for this model.
Parameters
as_dict (bool) – Whether variables should be returned as dict-values (using descriptive keys).
Returns
The list (or dict if as_dict is True) of all variables of this ModelV2.
Return type
Union[List[any],Dict[str,any]]

trainable_variables(as_dict=False)
Returns the list of trainable variables for this model.
Parameters
as_dict (bool) – Whether variables should be returned as dict-values (using descriptive keys).
Returns
The list (or dict if as_dict is True) of all trainable (tf)/requires_grad (torch) variables of this ModelV2.
Return type
Union[List[any],Dict[str,any]]
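The contract described above (forward() returns outputs and state, and value_function() reads a value computed during the most recent forward pass) can be illustrated without TensorFlow. The sketch below is a hypothetical pure-Python stand-in, not a TFModelV2 subclass: it uses plain lists instead of tensors and one linear layer per branch.

```python
class TinyLinearModel:
    """Pure-Python stand-in for the ModelV2 forward/value_function contract."""

    def __init__(self, policy_weights, value_weights):
        self.policy_weights = policy_weights  # num_outputs rows of obs_size
        self.value_weights = value_weights    # one row of obs_size
        self._value_out = None
        self.forward_calls = 0

    def forward(self, input_dict, state, seq_lens):
        self.forward_calls += 1
        obs_batch = input_dict["obs"]
        logits = [[sum(w * x for w, x in zip(row, obs))
                   for row in self.policy_weights]
                  for obs in obs_batch]
        # Cache the value branch so value_function() needs no extra pass.
        self._value_out = [sum(w * x for w, x in zip(self.value_weights, obs))
                           for obs in obs_batch]
        return logits, state

    def value_function(self):
        if self._value_out is None:
            raise RuntimeError("call forward() before value_function()")
        return self._value_out


model = TinyLinearModel([[1, 0], [0, 1], [1, 1]], [0.5, 0.5])
logits, state = model.forward({"obs": [[1.0, 2.0]]}, [], None)
print(logits, model.value_function())
```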

Once implemented, the model can then be registered and used in place of a built-in model:
import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.tf_modelv2 import TFModelV2

class MyModelClass(TFModelV2):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name): ...
    def forward(self, input_dict, state, seq_lens): ...
    def value_function(self): ...

ModelCatalog.register_custom_model("my_model", MyModelClass)

ray.init()
trainer = ppo.PPOTrainer(env="CartPole-v0", config={
    "model": {
        "custom_model": "my_model",
        # Extra kwargs to be passed to your model's c'tor.
        "custom_model_config": {},
    },
})
For a full example of a custom model in code, see the keras model example. You can also reference the unit tests for Tuple and Dict spaces, which show how to access nested observation fields.
Recurrent Models
Instead of using the use_lstm: True option, it can be preferable to use a custom recurrent model. This provides more control over post-processing of the LSTM output and can also allow the use of multiple LSTM cells to process different portions of the input. For an RNN model it is preferred to subclass RecurrentNetwork to implement __init__(), get_initial_state(), and forward_rnn(). You can check out the custom_rnn_model.py model as an example to implement your own model:

class ray.rllib.models.tf.recurrent_net.RecurrentNetwork(obs_space, action_space, num_outputs, model_config, name)
Helper class to simplify implementing RNN models with TFModelV2.
Instead of implementing forward(), you can implement forward_rnn(), which takes batches with the time dimension already added.
Here is an example implementation for a subclass MyRNNClass(RecurrentNetwork):
    def __init__(self, *args, **kwargs):
        super(MyRNNClass, self).__init__(*args, **kwargs)
        cell_size = 256

        # Define input layers
        input_layer = tf.keras.layers.Input(
            shape=(None, self.obs_space.shape[0]))
        state_in_h = tf.keras.layers.Input(shape=(cell_size, ))
        state_in_c = tf.keras.layers.Input(shape=(cell_size, ))
        seq_in = tf.keras.layers.Input(shape=(), dtype=tf.int32)

        # Send to LSTM cell
        lstm_out, state_h, state_c = tf.keras.layers.LSTM(
            cell_size, return_sequences=True, return_state=True,
            name="lstm")(
                inputs=input_layer,
                mask=tf.sequence_mask(seq_in),
                initial_state=[state_in_h, state_in_c])
        output_layer = tf.keras.layers.Dense(...)(lstm_out)

        # Create the RNN model
        self.rnn_model = tf.keras.Model(
            inputs=[input_layer, seq_in, state_in_h, state_in_c],
            outputs=[output_layer, state_h, state_c])
        self.register_variables(self.rnn_model.variables)
        self.rnn_model.summary()

__init__(obs_space, action_space, num_outputs, model_config, name)
Initialize a TFModelV2.
Here is an example implementation for a subclass MyModelClass(TFModelV2):
    def __init__(self, *args, **kwargs):
        super(MyModelClass, self).__init__(*args, **kwargs)
        input_layer = tf.keras.layers.Input(...)
        hidden_layer = tf.keras.layers.Dense(...)(input_layer)
        output_layer = tf.keras.layers.Dense(...)(hidden_layer)
        value_layer = tf.keras.layers.Dense(...)(hidden_layer)
        self.base_model = tf.keras.Model(
            input_layer, [output_layer, value_layer])
        self.register_variables(self.base_model.variables)

forward_rnn(inputs, state, seq_lens)
Call the model with the given input tensors and state.
Parameters
inputs (Tensor) – observation tensor with shape [B, T, obs_size].
state (list) – list of state tensors, each with shape [B, size].
seq_lens (Tensor) – 1-D tensor holding input sequence lengths.
Returns
The model output tensor of shape [B, T, num_outputs] and the list of new state tensors, each with shape [B, size].
Return type
(outputs, new_state)
Sample implementation for the MyRNNClass example:
    def forward_rnn(self, inputs, state, seq_lens):
        model_out, h, c = self.rnn_model([inputs, seq_lens] + state)
        return model_out, [h, c]
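Recurrent models are trained on padded sequences: a flat batch is cut into chunks of at most max_seq_len steps, seq_lens records each chunk's true length, and the time dimension is added before forward_rnn() sees the data. The pure-Python sketch below illustrates that idea under simplifying assumptions (scalar rows, zero padding, chunks also split at episode boundaries). The function names are illustrative; RLlib's internal helpers may have different signatures and behavior.

```python
def _pad(chunk, max_seq_len, pad_value):
    return chunk + [pad_value] * (max_seq_len - len(chunk))


def chop_into_sequences(episode_ids, values, max_seq_len, pad_value=0.0):
    """Split a flat batch into per-episode chunks of at most max_seq_len steps.

    Every chunk is right-padded to max_seq_len, so the padded batch has
    len(seq_lens) * max_seq_len rows. Returns (padded, seq_lens).
    """
    padded, seq_lens, chunk = [], [], []
    prev_id = None
    for eps_id, value in zip(episode_ids, values):
        # Start a new sequence at an episode boundary or when the chunk is full.
        if chunk and (eps_id != prev_id or len(chunk) == max_seq_len):
            seq_lens.append(len(chunk))
            padded.extend(_pad(chunk, max_seq_len, pad_value))
            chunk = []
        chunk.append(value)
        prev_id = eps_id
    if chunk:
        seq_lens.append(len(chunk))
        padded.extend(_pad(chunk, max_seq_len, pad_value))
    return padded, seq_lens


def add_time_dimension(padded, max_seq_len):
    """Fold a padded flat batch of B * T rows into B sequences of T rows."""
    if len(padded) % max_seq_len != 0:
        raise ValueError("batch size must be a multiple of max_seq_len")
    return [padded[i:i + max_seq_len] for i in range(0, len(padded), max_seq_len)]


def remove_time_dimension(batched):
    """Inverse of add_time_dimension: [B][T] rows back to a flat list."""
    return [row for seq in batched for row in seq]


padded, seq_lens = chop_into_sequences([0, 0, 0, 1, 1], [1.0, 2.0, 3.0, 4.0, 5.0], 2)
print(add_time_dimension(padded, 2), seq_lens)
```

The seq_lens values are what allow the LSTM in the example above to ignore the padded steps via tf.sequence_mask(seq_in).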

Batch Normalization
You can use tf.layers.batch_normalization(x, training=input_dict["is_training"]) to add batch norm layers to your custom model: code example. RLlib will automatically run the update ops for the batch norm layers during optimization (see tf_policy.py and multi_gpu_impl.py for the exact handling of these updates).
In case RLlib does not properly detect the update ops for your custom model, you can override the update_ops() method to return the list of ops to run for updates.
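To see why update ops matter, here is a minimal pure-Python batch norm over scalars. In training mode it normalizes with batch statistics and also updates moving averages, which are what inference mode uses; in TF graph mode those moving-average assignments are the separate update ops that have to run alongside the optimizer step. The class is a hypothetical illustration, not a TF layer; momentum=0.99 and eps=1e-3 echo the defaults of tf.layers.batch_normalization.

```python
import math


class BatchNorm1D:
    """Minimal scalar batch norm with moving statistics (illustration only)."""

    def __init__(self, momentum=0.99, eps=1e-3):
        self.momentum = momentum
        self.eps = eps
        self.moving_mean = 0.0
        self.moving_var = 1.0

    def __call__(self, batch, training):
        if training:
            mean = sum(batch) / len(batch)
            var = sum((x - mean) ** 2 for x in batch) / len(batch)
            # These two assignments play the role of the batch-norm update ops.
            self.moving_mean = (self.momentum * self.moving_mean
                                + (1 - self.momentum) * mean)
            self.moving_var = (self.momentum * self.moving_var
                               + (1 - self.momentum) * var)
        else:
            mean, var = self.moving_mean, self.moving_var
        return [(x - mean) / math.sqrt(var + self.eps) for x in batch]


bn = BatchNorm1D(momentum=0.5, eps=0.0)
print(bn([1.0, 3.0], training=True), bn.moving_mean)
```

If the moving-average updates were skipped, inference would keep using the initial statistics (mean 0, variance 1) no matter how much the model was trained.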
PyTorch Models
Similarly, you can create and register custom PyTorch models for use with PyTorch-based algorithms (e.g., A2C, PG, QMIX). See these examples of fully connected, convolutional, and recurrent torch models.

class ray.rllib.models.torch.torch_modelv2.TorchModelV2(obs_space, action_space, num_outputs, model_config, name)
Torch version of ModelV2.
Note that this class by itself is not a valid model unless you inherit from nn.Module and implement forward() in a subclass.

__init__(obs_space, action_space, num_outputs, model_config, name)
Initialize a TorchModelV2.
Here is an example implementation for a subclass MyModelClass(TorchModelV2, nn.Module):
    def __init__(self, *args, **kwargs):
        TorchModelV2.__init__(self, *args, **kwargs)
        nn.Module.__init__(self)
        self._hidden_layers = nn.Sequential(...)
        self._logits = ...
        self._value_branch = ...

forward(input_dict, state, seq_lens)
Call the model with the given input tensors and state.
Any complex observations (dicts, tuples, etc.) will be unpacked by __call__ before being passed to forward(). To access the flattened observation tensor, refer to input_dict["obs_flat"].
This method can be called any number of times. In eager execution, each call to forward() will eagerly evaluate the model. In symbolic execution, each call to forward() creates a computation graph that operates over the variables of this model (i.e., shares weights).
Custom models should override this instead of __call__.
Parameters
input_dict (dict) – dictionary of input tensors, including "obs", "obs_flat", "prev_action", "prev_reward", "is_training"
state (list) – list of state tensors with sizes matching those returned by get_initial_state + the batch dimension
seq_lens (Tensor) – 1-D tensor holding input sequence lengths
Returns
The model output tensor of size [BATCH, num_outputs] and the new state.
Return type
(outputs, state)
Examples
>>> def forward(self, input_dict, state, seq_lens):
...     model_out, self._value_out = self.base_model(
...         input_dict["obs"])
...     return model_out, state

value_function()
Returns the value function output for the most recent forward pass.
Note that forward() has to be called first, before this method can return anything; calling this method therefore does not cause an extra forward pass through the network.
Returns
Value estimate tensor of shape [BATCH].

custom_loss(policy_loss, loss_inputs)
Override to customize the loss function used to optimize this model.
This can be used to incorporate self-supervised losses (by defining a loss over existing input and output tensors of this model), and supervised losses (by defining losses over a variable-sharing copy of this model’s layers).
You can find a runnable example in examples/custom_loss.py.
Parameters
policy_loss (Union[List[Tensor],Tensor]) – List of or single policy loss(es) from the policy.
loss_inputs (dict) – map of input placeholders for rollout data.
Returns
List of or scalar tensor for the customized loss(es) for this model.
Return type
Union[List[Tensor],Tensor]

metrics()
Override to return custom metrics from your model.
The stats will be reported as part of the learner stats, i.e.,
info:
    learner:
        model:
            key1: metric1
            key2: metric2
Returns
Dict of string keys to scalar tensors.

get_initial_state()
Get the initial recurrent state values for the model.
Returns
List of np.array objects containing the initial hidden state of an RNN, if applicable.
Return type
List[np.ndarray]
Examples
>>> def get_initial_state(self):
...     return [
...         np.zeros(self.cell_size, np.float32),
...         np.zeros(self.cell_size, np.float32),
...     ]

Once implemented, the model can then be registered and used in place of a built-in model:
import torch.nn as nn

import ray
from ray.rllib.agents import a3c
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

class CustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        ...  # define your layers here
    def forward(self, input_dict, state, seq_lens): ...
    def value_function(self): ...

ModelCatalog.register_custom_model("my_model", CustomTorchModel)

ray.init()
trainer = a3c.A2CTrainer(env="CartPole-v0", config={
    "framework": "torch",
    "model": {
        "custom_model": "my_model",
        # Extra kwargs to be passed to your model's c'tor.
        "custom_model_config": {},
    },
})
Custom Preprocessors
Warning
Custom preprocessors are deprecated, since they sometimes conflict with the built-in preprocessors for handling complex observation spaces. Please use wrapper classes around your environment instead of preprocessors.
Custom preprocessors should subclass the RLlib preprocessor class and be registered in the model catalog:
import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.models import ModelCatalog
from ray.rllib.models.preprocessors import Preprocessor

class MyPreprocessorClass(Preprocessor):
    def _init_shape(self, obs_space, options):
        return new_shape  # placeholder: the output shape, can vary depending on inputs

    def transform(self, observation):
        return ...  # return the preprocessed observation

ModelCatalog.register_custom_preprocessor("my_prep", MyPreprocessorClass)

ray.init()
trainer = ppo.PPOTrainer(env="CartPole-v0", config={
    "model": {
        "custom_preprocessor": "my_prep",
        # Extra kwargs to be passed to your model's c'tor.
        "custom_model_config": {},
    },
})
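Since the warning above recommends environment wrappers over custom preprocessors, here is a minimal sketch of that approach in plain Python. ScaleObservationWrapper and CounterEnv are hypothetical names, and the 4-tuple step() return assumes the classic gym API. With gym installed, the idiomatic equivalent is to subclass gym.ObservationWrapper and override its observation() method.

```python
class ScaleObservationWrapper:
    """Hypothetical env wrapper that preprocesses observations.

    Exposes the same reset()/step() interface as the wrapped env, but passes
    every observation through transform() before the agent sees it.
    """

    def __init__(self, env, scale):
        self.env = env
        self.scale = scale

    def transform(self, observation):
        return [x * self.scale for x in observation]

    def reset(self):
        return self.transform(self.env.reset())

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return self.transform(obs), reward, done, info


class CounterEnv:
    """Toy env used only to exercise the wrapper."""

    def reset(self):
        self.t = 0
        return [0.0, 1.0]

    def step(self, action):
        self.t += 1
        return [float(self.t), 1.0], 1.0, self.t >= 3, {}


env = ScaleObservationWrapper(CounterEnv(), scale=0.5)
print(env.reset(), env.step(0))
```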
Custom Action Distributions
Similar to custom models and preprocessors, you can also specify a custom action distribution class as follows. The action dist class is passed a reference to the model, which you can use to access model.model_config or other attributes of the model. This is commonly used to implement autoregressive action outputs.
import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.models import ModelCatalog
from ray.rllib.models.action_dist import ActionDistribution

class MyActionDist(ActionDistribution):
    @staticmethod
    def required_model_output_shape(action_space, model_config):
        return 7  # controls model output feature vector size

    def __init__(self, inputs, model):
        super(MyActionDist, self).__init__(inputs, model)
        assert model.num_outputs == 7

    def sample(self): ...
    def logp(self, actions): ...
    def entropy(self): ...

ModelCatalog.register_custom_action_dist("my_dist", MyActionDist)

ray.init()
trainer = ppo.PPOTrainer(env="CartPole-v0", config={
    "model": {
        "custom_action_dist": "my_dist",
    },
})
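To show what sample(), logp(), and entropy() are expected to compute, here is a pure-Python categorical distribution over a single logits vector. It is a hypothetical illustration only: a real ActionDistribution operates on batched framework tensors, and CategoricalSketch is not part of RLlib.

```python
import math
import random


class CategoricalSketch:
    """Categorical distribution parameterized by logits (illustration only)."""

    def __init__(self, logits):
        # Numerically stable softmax: subtract the max logit before exp().
        max_logit = max(logits)
        exps = [math.exp(x - max_logit) for x in logits]
        total = sum(exps)
        self.probs = [e / total for e in exps]

    def sample(self, rng=random):
        return rng.choices(range(len(self.probs)), weights=self.probs, k=1)[0]

    def logp(self, action):
        return math.log(self.probs[action])

    def entropy(self):
        return -sum(p * math.log(p) for p in self.probs if p > 0)


dist = CategoricalSketch([0.0, 0.0])
print(dist.sample(random.Random(0)), dist.logp(0), dist.entropy())
```

For two equal logits, logp() of either action is log(0.5) and the entropy is log(2), the maximum for two actions.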
Supervised Model Losses
You can mix supervised losses into any RLlib algorithm through custom models. For example, you can add an imitation learning loss on expert experiences, or a self-supervised autoencoder loss within the model. These losses can be defined over either policy evaluation inputs, or data read from offline storage.
TensorFlow: To add a supervised loss to a custom TF model, you need to override the custom_loss() method. This method takes in the existing policy loss for the algorithm, which you can add your own supervised loss to before returning. For debugging, you can also return a dictionary of scalar tensors in the metrics() method. Here is a runnable example of adding an imitation loss to CartPole training that is defined over an offline dataset.
PyTorch: There is no explicit API for adding losses to custom torch models. However, you can modify the loss in the policy definition directly. Like for TF models, offline datasets can be incorporated by creating an input reader and calling reader.next() in the loss forward pass.
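The idea of mixing a supervised term into the policy loss can be sketched in plain Python. The example assumes the supervised term is a behavior-cloning style negative log-likelihood of expert actions under a softmax over the model's logits. The function names and the weight parameter are hypothetical, and adding the term to every entry when the policy loss is a list is one possible choice.

```python
import math


def imitation_loss(logits_batch, expert_actions):
    """Mean negative log-likelihood of expert actions under softmax(logits)."""
    total = 0.0
    for logits, action in zip(logits_batch, expert_actions):
        max_logit = max(logits)
        # log-sum-exp computed stably, giving the softmax normalizer.
        log_norm = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
        total += log_norm - logits[action]
    return total / len(expert_actions)


def add_imitation_loss(policy_loss, logits_batch, expert_actions, weight=1.0):
    """Add a weighted supervised term to one policy loss or a list of them."""
    extra = weight * imitation_loss(logits_batch, expert_actions)
    if isinstance(policy_loss, list):
        return [loss + extra for loss in policy_loss]
    return policy_loss + extra


print(add_imitation_loss(1.0, [[0.0, 0.0]], [0]))  # 1 + log(2)
```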
Self-Supervised Model Losses
You can also use the custom_loss() API to add in self-supervised losses such as VAE reconstruction loss and L2 regularization.
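A self-supervised term combines with the policy loss in the same way. This sketch, with hypothetical names and coefficients, adds an L2 penalty over model weights and a mean-squared reconstruction error such as an autoencoder would use; a full VAE loss would also include a KL-divergence term, which is omitted here.

```python
def reconstruction_loss(inputs, reconstructions):
    """Mean squared error between inputs and their reconstructions."""
    return sum((x - y) ** 2 for x, y in zip(inputs, reconstructions)) / len(inputs)


def self_supervised_loss(policy_loss, weights, inputs, reconstructions,
                         l2_coeff=1e-4, recon_coeff=1.0):
    """Policy loss + L2 penalty over all weights + weighted reconstruction error."""
    l2 = sum(w * w for layer in weights for w in layer)
    return (policy_loss
            + l2_coeff * l2
            + recon_coeff * reconstruction_loss(inputs, reconstructions))


print(self_supervised_loss(1.0, [[1.0, 2.0], [3.0]], [1.0, 2.0], [1.0, 4.0],
                           l2_coeff=0.1, recon_coeff=0.5))
```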
Variable-length / Complex Observation Spaces
RLlib supports complex and variable-length observation spaces, including gym.spaces.Tuple, gym.spaces.Dict, and rllib.utils.spaces.Repeated. The handling of these spaces is transparent to the user. RLlib internally will insert preprocessors to insert padding for repeated elements, flatten complex observations into a fixed-size vector during transit, and unpack the vector into the structured tensor before sending it to the model. The flattened observation is available to the model as input_dict["obs_flat"], and the unpacked observation as input_dict["obs"].
To enable batching of struct observations, RLlib unpacks them in a StructTensor-like format. In summary, repeated fields are “pushed down” and become the outer dimensions of tensor batches, as illustrated in this figure from the StructTensor RFC.
For further information about complex observation spaces, see:
A custom environment and model that uses repeated struct fields.
The pydoc of the Repeated space.
The pydoc of the batched repeated values tensor.
The unit tests for Tuple and Dict spaces.
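The padding step for repeated elements can be illustrated with a small pure-Python sketch. It assumes a layout of one length slot followed by the items and zero padding up to a maximum count; the exact layout RLlib uses for Repeated spaces may differ, and both function names are hypothetical.

```python
def pad_repeated(items, max_len, item_size, pad_value=0.0):
    """Flatten a variable-length list of fixed-size items into a fixed-size vector.

    Assumed layout: [length, item_0..., item_1..., padding...], so the result
    always has 1 + max_len * item_size entries.
    """
    if len(items) > max_len:
        raise ValueError("too many items for max_len")
    flat = [float(len(items))]
    for item in items:
        if len(item) != item_size:
            raise ValueError("item has the wrong size")
        flat.extend(item)
    flat.extend([pad_value] * ((max_len - len(items)) * item_size))
    return flat


def unpack_repeated(flat, item_size):
    """Recover the list of items, dropping the padding."""
    length = int(flat[0])
    body = flat[1:]
    return [body[i * item_size:(i + 1) * item_size] for i in range(length)]


flat = pad_repeated([[1.0, 2.0], [3.0, 4.0]], max_len=3, item_size=2)
print(flat, unpack_repeated(flat, 2))
```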
Variable-length / Parametric Action Spaces
Custom models can be used to work with environments where (1) the set of valid actions varies per step, and/or (2) the number of valid actions is very large. The general idea is that the meaning of actions can be completely conditioned on the observation, i.e., the a in Q(s, a) becomes just a token in [0, MAX_AVAIL_ACTIONS) that only has meaning in the context of s. This works with algorithms in the DQN and policy-gradient families and can be implemented as follows:
The environment should return a mask and/or list of valid action embeddings as part of the observation for each step. To enable batching, the number of actions can be allowed to vary from 1 to some max number:
import gym
from gym.spaces import Box, Dict, Discrete

class MyParamActionEnv(gym.Env):
    def __init__(self, max_avail_actions):
        self.action_space = Discrete(max_avail_actions)
        self.observation_space = Dict({
            "action_mask": Box(0, 1, shape=(max_avail_actions, )),
            # action_embedding_sz: size of each action embedding (defined elsewhere)
            "avail_actions": Box(-1, 1, shape=(max_avail_actions, action_embedding_sz)),
            "real_obs": ...,
        })
A custom model can be defined that can interpret the action_mask and avail_actions portions of the observation. Here the model computes the action logits via the dot product of some network output and each action embedding. Invalid actions can be masked out of the softmax by scaling the probability to zero:
import tensorflow as tf
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
# FullyConnectedNetwork is RLlib's built-in TF fully connected model; its
# import path differs between RLlib versions, so it is not imported here.

class ParametricActionsModel(TFModelV2):
    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 true_obs_shape=(4,),
                 action_embed_size=2):
        super(ParametricActionsModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)
        self.action_embed_model = FullyConnectedNetwork(...)

    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions tensor from the observation.
        avail_actions = input_dict["obs"]["avail_actions"]
        action_mask = input_dict["obs"]["action_mask"]

        # Compute the predicted action embedding from the "real_obs" part
        # of the MyParamActionEnv observation above.
        action_embed, _ = self.action_embed_model({
            "obs": input_dict["obs"]["real_obs"]
        })

        # Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the
        # avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE].
        intent_vector = tf.expand_dims(action_embed, 1)

        # Batch dot product => shape of logits is [BATCH, MAX_ACTIONS].
        action_logits = tf.reduce_sum(avail_actions * intent_vector, axis=2)

        # Mask out invalid actions (use tf.float32.min for stability)
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        return action_logits + inf_mask, state
Depending on your use case it may make sense to use just the masking, just action embeddings, or both. For a runnable example of this in code, check out parametric_actions_cartpole.py. Note that since masking introduces tf.float32.min values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the tf.float32.min values. The cartpole example has working configurations for DQN (must set hiddens=[]), PPO (must disable running mean and set vf_share_layers=True), and several other algorithms. Not all algorithms support parametric actions; see the feature compatibility matrix.
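The masking trick can be checked numerically in plain Python. This sketch mirrors the forward() logic above for a single observation: logits are dot products between an intent vector and each action embedding, and invalid actions get tf.float32.min (about -3.4e38) added, so their softmax probability underflows to exactly zero. FLOAT32_MIN, masked_logits(), and softmax() are illustrative names, not RLlib APIs.

```python
import math

# Most negative finite float32 value (what tf.float32.min evaluates to).
FLOAT32_MIN = -3.4028234663852886e38


def masked_logits(intent_vector, avail_actions, action_mask):
    """Dot-product logits per action, with invalid actions pushed to FLOAT32_MIN."""
    logits = []
    for embedding, mask in zip(avail_actions, action_mask):
        logit = sum(i * e for i, e in zip(intent_vector, embedding))
        # log(1) = 0 leaves valid actions unchanged; log(0) = -inf is clamped
        # to FLOAT32_MIN, like tf.maximum(tf.math.log(mask), tf.float32.min).
        inf_mask = max(math.log(mask), FLOAT32_MIN) if mask > 0 else FLOAT32_MIN
        logits.append(logit + inf_mask)
    return logits


def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = masked_logits([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0]], [1.0, 1.0, 0.0])
print(softmax(logits))  # the third (masked) action gets probability 0.0
```

Clamping to a large finite value rather than using -inf keeps the logits finite, which avoids NaNs in later arithmetic; the caveat in the text still applies, since an algorithm that, for example, adds these logits together could overflow.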
Autoregressive Action Distributions
In an action space with multiple components (e.g., Tuple(a1, a2)), you might want a2 to be conditioned on the sampled value of a1, i.e., a2_sampled ~ P(a2 | a1_sampled, obs). Normally, a1 and a2 would be sampled independently, reducing the expressivity of the policy.
To do this, you need both a custom model that implements the autoregressive pattern, and a custom action distribution class that leverages that model. The autoregressive_action_dist.py example shows how this can be implemented for a simple binary action space. For a more complex space, a more efficient architecture such as a MADE is recommended. Note that sampling an N-part action requires N forward passes through the model; however, computing the log probability of an action can be done in one pass:
class BinaryAutoregressiveOutput(ActionDistribution):
    """Action distribution P(a1, a2) = P(a1) * P(a2 | a1)"""

    @staticmethod
    def required_model_output_shape(action_space, model_config):
        return 16  # controls model output feature vector size

    def sample(self):
        # first, sample a1
        a1_dist = self._a1_distribution()
        a1 = a1_dist.sample()

        # sample a2 conditioned on a1
        a2_dist = self._a2_distribution(a1)
        a2 = a2_dist.sample()

        # return the action tuple
        return TupleActions([a1, a2])

    def logp(self, actions):
        a1, a2 = actions[:, 0], actions[:, 1]
        a1_vec = tf.expand_dims(tf.cast(a1, tf.float32), 1)
        a1_logits, a2_logits = self.model.action_model([self.inputs, a1_vec])
        return (Categorical(a1_logits, None).logp(a1) + Categorical(
            a2_logits, None).logp(a2))

    def _a1_distribution(self):
        BATCH = tf.shape(self.inputs)[0]
        a1_logits, _ = self.model.action_model(
            [self.inputs, tf.zeros((BATCH, 1))])
        a1_dist = Categorical(a1_logits, None)
        return a1_dist

    def _a2_distribution(self, a1):
        a1_vec = tf.expand_dims(tf.cast(a1, tf.float32), 1)
        _, a2_logits = self.model.action_model([self.inputs, a1_vec])
        a2_dist = Categorical(a2_logits, None)
        return a2_dist


class AutoregressiveActionsModel(TFModelV2):
    """Implements the `.action_model` branch required above."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(AutoregressiveActionsModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)
        if action_space != Tuple([Discrete(2), Discrete(2)]):
            raise ValueError(
                "This model only supports the [2, 2] action space")

        # Inputs
        obs_input = tf.keras.layers.Input(
            shape=obs_space.shape, name="obs_input")
        a1_input = tf.keras.layers.Input(shape=(1, ), name="a1_input")
        ctx_input = tf.keras.layers.Input(
            shape=(num_outputs, ), name="ctx_input")

        # Output of the model (normally 'logits', but for an autoregressive
        # dist this is more like a context/feature layer encoding the obs)
        context = tf.keras.layers.Dense(
            num_outputs,
            name="hidden",
            activation=tf.nn.tanh,
            kernel_initializer=normc_initializer(1.0))(obs_input)

        # P(a1 | obs)
        a1_logits = tf.keras.layers.Dense(
            2,
            name="a1_logits",
            activation=None,
            kernel_initializer=normc_initializer(0.01))(ctx_input)

        # P(a2 | a1)
        # note: typically you'd want to implement P(a2 | a1, obs) as follows:
        # a2_context = tf.keras.layers.Concatenate(axis=1)(
        #     [ctx_input, a1_input])
        a2_context = a1_input
        a2_hidden = tf.keras.layers.Dense(
            16,
            name="a2_hidden",
            activation=tf.nn.tanh,
            kernel_initializer=normc_initializer(1.0))(a2_context)
        a2_logits = tf.keras.layers.Dense(
            2,
            name="a2_logits",
            activation=None,
            kernel_initializer=normc_initializer(0.01))(a2_hidden)

        # Base layers
        self.base_model = tf.keras.Model(obs_input, context)
        self.register_variables(self.base_model.variables)
        self.base_model.summary()

        # Autoregressive action sampler
        self.action_model = tf.keras.Model([ctx_input, a1_input],
                                           [a1_logits, a2_logits])
        self.action_model.summary()
        self.register_variables(self.action_model.variables)
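The asymmetry between sampling and log-probability described above can be seen in a small pure-Python sketch of P(a1, a2) = P(a1) * P(a2 | a1) for two binary actions. The two heads are hypothetical fixed functions standing in for the model's action_model branch. Sampling evaluates the heads one after another, because a2's distribution depends on the sampled a1, while logp() for a known (a1, a2) pair can evaluate both heads directly.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def a1_logits():
    """Hypothetical P(a1 | obs) head (constant here for simplicity)."""
    return [0.0, 0.0]


def a2_logits(a1):
    """Hypothetical P(a2 | a1) head: strongly prefers a2 == a1."""
    return [3.0, -3.0] if a1 == 0 else [-3.0, 3.0]


def sample(rng):
    # Sequential: the second head needs the sampled value of a1.
    a1 = rng.choices([0, 1], weights=softmax(a1_logits()))[0]
    a2 = rng.choices([0, 1], weights=softmax(a2_logits(a1)))[0]
    return a1, a2


def logp(a1, a2):
    # Both components are known, so log P(a1) + log P(a2 | a1) needs no sampling.
    return (math.log(softmax(a1_logits())[a1])
            + math.log(softmax(a2_logits(a1))[a2]))


print(sample(random.Random(0)), logp(0, 0), logp(0, 1))
```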
Note
Not all algorithms support autoregressive action distributions; see the feature compatibility matrix.