ModelV2.get_initial_state() → List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]

Get the initial recurrent state values for the model.


Returns:
List of np.array objects (for tf) or Tensor objects (for torch) holding the initial hidden state of an RNN, if the model is recurrent.

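import numpy as np
from ray.rllib.models.modelv2 import ModelV2

class MyModel(ModelV2):
    # ... (__init__ is assumed to set self.cell_size, e.g. the LSTM cell size)
    def get_initial_state(self):
        # One zero-filled array per recurrent state tensor,
        # e.g. the LSTM hidden state (h) and cell state (c).
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]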
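The page does not show how the returned list is consumed. The following NumPy-only sketch illustrates the shape contract under stated assumptions: each initial state has shape [cell_size] with no batch dimension (as in the example above), and a caller stacks one copy per sequence before feeding it to a single LSTM step. `ToyLSTMModel`, `step`, and `batch_initial_state` are hypothetical names, not RLlib API, and the code does not import Ray. In RLlib itself, the policy normally tracks and batches these states for you.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


class ToyLSTMModel:
    """Hypothetical stand-in for a recurrent ModelV2 subclass (no Ray import)."""

    def __init__(self, obs_size: int, cell_size: int, seed: int = 0):
        self.cell_size = cell_size
        rng = np.random.default_rng(seed)
        # One weight matrix producing all four LSTM gates from [obs, h].
        self.w = rng.normal(
            0.0, 0.1, size=(obs_size + cell_size, 4 * cell_size)
        ).astype(np.float32)
        self.b = np.zeros(4 * cell_size, np.float32)

    def get_initial_state(self):
        # Same contract as the example above: [h, c], each of shape [cell_size].
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]

    def step(self, obs, state):
        """One LSTM step: obs is [B, obs_size]; state is [h, c], each [B, cell_size]."""
        h, c = state
        z = np.concatenate([obs, h], axis=-1) @ self.w + self.b
        i, f, g, o = np.split(z, 4, axis=-1)  # input, forget, candidate, output
        c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
        h_new = sigmoid(o) * np.tanh(c_new)
        return h_new, [h_new, c_new]


def batch_initial_state(model, batch_size):
    """Stack each per-sequence initial state into a [batch_size, cell_size] array."""
    return [np.stack([s] * batch_size) for s in model.get_initial_state()]


model = ToyLSTMModel(obs_size=3, cell_size=8)
state = batch_initial_state(model, batch_size=4)
obs = np.ones((4, 3), np.float32)
out, state = model.step(obs, state)
print([s.shape for s in state])  # [(4, 8), (4, 8)]
```

Keeping the batch dimension out of `get_initial_state()` lets the same model serve any number of parallel sequences; only the caller needs to know the batch size.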