ray.rllib.models.tf.tf_modelv2.TFModelV2.get_initial_state
- TFModelV2.get_initial_state() → List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]
Get the initial recurrent state values for the model.
- Returns:
A list of np.array objects (for tf) or Tensor objects (for torch) containing the initial hidden state of an RNN, if applicable.

Example:
import numpy as np
from ray.rllib.models.modelv2 import ModelV2


class MyModel(ModelV2):
    """Example ModelV2 subclass that supplies an initial recurrent state."""

    # ... (constructor and forward pass elided)

    def get_initial_state(self):
        """Return the initial recurrent state values for this model.

        Returns:
            A list of two zero-filled ``np.float32`` arrays, each of length
            ``self.cell_size`` (presumably one per recurrent state tensor,
            e.g. hidden and cell state of an LSTM-style cell).
        """
        # NOTE(review): assumes self.cell_size is set in the elided __init__.
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]