ray.rllib.models.tf.tf_modelv2.TFModelV2.get_initial_state

TFModelV2.get_initial_state() → List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]

Get the initial recurrent state values for the model.

Returns:

A list of np.array objects (for TF) or Tensor objects (for Torch) holding the initial hidden state(s) of an RNN, if applicable. A non-recurrent model returns an empty list.

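Example:

import numpy as np
from ray.rllib.models.modelv2 import ModelV2

class MyModel(ModelV2):
    # ... (constructor and forward() omitted; self.cell_size is set in __init__)

    def get_initial_state(self):
        # One zero vector per recurrent state, e.g. the h and c states of an LSTM.
        return [
            np.zeros(self.cell_size, dtype=np.float32),
            np.zeros(self.cell_size, dtype=np.float32),
        ]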
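To show how the list returned by get_initial_state() is typically consumed, here is a minimal NumPy-only sketch. It does not use RLlib: `ToyLSTMModel`, its `step()` method, and `batch_initial_state()` are hypothetical names chosen for illustration. It assumes, as the example above suggests, that each initial state is a single vector without a batch dimension, and that the caller stacks these vectors into batch form before the first recurrent step. The step itself is a standard LSTM cell update with the four gates computed in one matrix product.

```python
import numpy as np


def _sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


class ToyLSTMModel:
    """Hypothetical stand-in for a recurrent model (NOT the RLlib API).

    Keeps two recurrent states, h and c, each a vector of size cell_size,
    mirroring the two zero arrays returned by get_initial_state() above.
    """

    def __init__(self, obs_size: int, cell_size: int, seed: int = 0):
        self.cell_size = cell_size
        rng = np.random.default_rng(seed)
        # One weight matrix producing all four LSTM gates (i, f, g, o) at once.
        self.w = rng.normal(
            0.0, 0.1, size=(obs_size + cell_size, 4 * cell_size)
        ).astype(np.float32)
        self.b = np.zeros(4 * cell_size, dtype=np.float32)

    def get_initial_state(self):
        # One array per state tensor, with no batch dimension (assumption).
        return [
            np.zeros(self.cell_size, dtype=np.float32),  # h
            np.zeros(self.cell_size, dtype=np.float32),  # c
        ]

    def step(self, obs, state):
        """One LSTM step: obs is [B, obs_size]; state is [h, c], each [B, cell_size]."""
        h, c = state
        z = np.concatenate([obs, h], axis=1) @ self.w + self.b
        i, f, g, o = np.split(z, 4, axis=1)
        c_new = _sigmoid(f) * c + _sigmoid(i) * np.tanh(g)
        h_new = _sigmoid(o) * np.tanh(c_new)
        return h_new, [h_new, c_new]


def batch_initial_state(model, batch_size):
    # Hypothetical helper: stack each per-sample initial state into [batch_size, ...].
    return [np.stack([s] * batch_size) for s in model.get_initial_state()]


model = ToyLSTMModel(obs_size=3, cell_size=8)
state = batch_initial_state(model, batch_size=4)
obs = np.ones((4, 3), dtype=np.float32)
out, state = model.step(obs, state)
print(out.shape, [s.shape for s in state])  # (4, 8) [(4, 8), (4, 8)]
```

Returning zeros for every state means each sequence starts from a blank memory; the state list returned by `step()` is what a caller would feed back in on the next time step.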