ray.rllib.core.learner.learner_group.LearnerGroup

class ray.rllib.core.learner.learner_group.LearnerGroup(*, config: AlgorithmConfig, module_spec: RLModuleSpec | MultiRLModuleSpec | None = None)

Bases: Checkpointable

Coordinator of n (possibly remote) Learner workers.

Each Learner worker has a copy of the RLModule, the loss function(s), and one or more optimizers.

PublicAPI (alpha): This API is in alpha and may change before becoming stable.
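
A minimal construction sketch, assuming a standard single-agent PPO setup on CartPole and that your RLlib version exposes AlgorithmConfig.build_learner_group() (recent versions do); alternatively, the constructor shown above can be called directly with a config and an optional module spec:

    import gymnasium as gym

    from ray.rllib.algorithms.ppo import PPOConfig

    # Describe the algorithm, the env, and how many (possibly remote) Learner
    # workers to use. num_learners=0 would run a single local Learner in the
    # main process instead of remote actors.
    config = (
        PPOConfig()
        .environment("CartPole-v1")
        .learners(num_learners=2, num_gpus_per_learner=0)
    )

    # The config acts as the factory for the group; the env (or its spaces) is
    # needed so that each Learner worker can build its copy of the RLModule.
    env = gym.make("CartPole-v1")
    learner_group = config.build_learner_group(env=env)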

Methods

__init__

Initializes a LearnerGroup instance.

add_module

Adds a module to the underlying MultiRLModule.

foreach_learner

Calls the given function on each Learner L, passing it the arguments (L, **kwargs).

from_checkpoint

Creates a new Checkpointable instance from the given location and returns it.

get_metadata

Returns JSON-writable metadata further describing the implementing class.

get_stats

Returns the current stats of the input queue for this LearnerGroup.

get_weights

Convenience method for self.get_state(components=...).

remove_module

Removes a module from each underlying Learner.

restore_from_path

Restores the state of the implementing class from the given path.

save_to_path

Saves the state of the implementing class (or the given state) to path (see the checkpointing sketch at the end of this page).

set_weights

Convenience method for self.set_state({'learner': {'rl_module': ...}}).

shutdown

Shuts down the LearnerGroup.

update_from_batch

Performs gradient-based update(s) on the Learner(s), based on the given batch (see the usage sketch after this list).

update_from_episodes

Performs gradient-based update(s) on the Learner(s), based on the given episodes.
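
A usage sketch for the update, weight-sync, and foreach methods listed above, continuing the construction sketch; the batch itself is only hinted at, since its required columns are algorithm-specific, and the exact structure of the returned results may differ across RLlib versions:

    # A training batch as produced by your EnvRunners or offline data pipeline;
    # its required columns (obs, actions, advantages, ...) depend on the
    # algorithm, so it is only sketched here.
    train_batch = ...

    # Run one synchronous, gradient-based update on all Learner workers and
    # collect their result dicts.
    results = learner_group.update_from_batch(train_batch)

    # Pull the updated RLModule weights, e.g. to broadcast them to EnvRunners ...
    weights = learner_group.get_weights()

    # ... and push (possibly modified) weights back into every Learner worker.
    learner_group.set_weights(weights)

    # Call an arbitrary function on each Learner worker; each Learner instance
    # is passed to the function as its first argument.
    per_learner_results = learner_group.foreach_learner(
        lambda learner: type(learner).__name__
    )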

Attributes

CLASS_AND_CTOR_ARGS_FILE_NAME

METADATA_FILE_NAME

STATE_FILE_NAME

is_local

is_remote
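
Because LearnerGroup is a Checkpointable, the *_FILE_NAME attributes above name the files written into a checkpoint directory, and the save/restore methods can be used roughly as follows (a sketch, continuing the examples above; the checkpoint directory is arbitrary):

    import tempfile

    from ray.rllib.core.learner.learner_group import LearnerGroup

    # `is_remote` is True when the group manages remote Learner actors
    # (num_learners > 0); otherwise `is_local` is True.
    print(learner_group.is_remote, learner_group.is_local)

    # Write the full LearnerGroup state (RLModule weights, optimizer states,
    # metadata) into a checkpoint directory.
    checkpoint_dir = tempfile.mkdtemp()
    learner_group.save_to_path(checkpoint_dir)

    # Restore that state into an existing group ...
    learner_group.restore_from_path(checkpoint_dir)

    # ... or construct a brand-new LearnerGroup directly from the checkpoint.
    restored_group = LearnerGroup.from_checkpoint(checkpoint_dir)

    # Release the (remote) Learner actors when done.
    learner_group.shutdown()
    restored_group.shutdown()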