Learner (Alpha)

Learner allows you to abstract the training logic of RLModules. It supports both gradient-based and non-gradient-based updates (for example, Polyak averaging). The API enables you to distribute the Learner using distributed data parallel (DDP) training. The Learner does the following:

  1. Facilitates gradient-based updates on RLModules.

  2. Provides abstractions for non-gradient-based updates, such as Polyak averaging.

  3. Reports training statistics.

  4. Checkpoints the module and optimizer states for durable training.
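Polyak averaging, the non-gradient-based update named above, moves a copy of "target" parameters a small step tau toward the trained parameters: target <- tau * source + (1 - tau) * target. The following NumPy sketch shows only that arithmetic. The function name polyak_update and the dict-of-arrays parameter layout are assumptions for illustration, not part of the RLlib API.

```python
import numpy as np


def polyak_update(target: dict, source: dict, tau: float) -> dict:
    """Hypothetical helper: return tau * source + (1 - tau) * target per parameter.

    `target` and `source` map parameter names to NumPy arrays of equal shape.
    This is a plain NumPy sketch, not RLlib's implementation.
    """
    return {
        name: tau * source[name] + (1.0 - tau) * target[name]
        for name in target
    }


# Online (trained) weights and a slowly tracking target copy.
online = {"w": np.array([1.0, 2.0]), "b": np.array([0.5])}
target = {"w": np.zeros(2), "b": np.zeros(1)}

target = polyak_update(target, online, tau=0.1)
print(target["w"])  # [0.1 0.2]
```

A small tau makes the target change slowly, which is why this kind of update is commonly used for target networks.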

The Learner class supports distributed data parallel (DDP) style training through the LearnerGroup API. Under this paradigm, the LearnerGroup maintains multiple copies of the same Learner with identical parameters and hyperparameters. Each Learner instance computes the loss and gradients on a shard of the sample batch, and the gradients are then accumulated across all Learner instances. Learn more about distributed data parallel learning in this article.
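The data-parallel step described above can be simulated in a single process. In this NumPy sketch, each simulated Learner computes the gradient of a mean-squared-error loss on its own shard, and the per-shard gradients are then averaged, which is the reduction PyTorch's DistributedDataParallel performs. The linear model, the loss, and the helper name mse_grad are assumptions chosen to keep the example small; this is not RLlib code.

```python
import numpy as np


def mse_grad(w: np.ndarray, x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Gradient of mean((x @ w - y) ** 2) with respect to w (hypothetical helper)."""
    residual = x @ w - y
    return 2.0 * x.T @ residual / len(y)


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 3))  # a batch of 8 samples with 3 features
y = rng.normal(size=8)
w = np.zeros(3)  # every replica starts from identical parameters

num_learners = 4
# Each simulated Learner sees one equally sized shard of the batch.
shards = zip(np.array_split(x, num_learners), np.array_split(y, num_learners))
local_grads = [mse_grad(w, xs, ys) for xs, ys in shards]

# All-reduce step: average the per-shard gradients across Learners.
ddp_grad = np.mean(local_grads, axis=0)

# With equal shard sizes, the averaged gradient equals the full-batch gradient.
full_grad = mse_grad(w, x, y)
print(np.allclose(ddp_grad, full_grad))  # True
```

Because every replica applies the same averaged gradient to the same starting parameters, the replicas stay identical after each update.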

LearnerGroup also allows for asynchronous training and (distributed) checkpointing for durability during training.

Enabling the Learner API in RLlib experiments

Adjust the amount of resources for training using the num_gpus_per_learner_worker, num_cpus_per_learner_worker, and num_learner_workers arguments of AlgorithmConfig.resources().

from ray.rllib.algorithms.ppo.ppo import PPOConfig

config = (
    PPOConfig()
    .experimental(_enable_new_api_stack=True)
    .resources(
        num_gpus_per_learner_worker=0,  # Set this to 1 to enable GPU training.
        num_cpus_per_learner_worker=1,
        # Set this to greater than 1 to allow for DDP-style updates.
        num_learner_workers=0,
    )
)

Note

This feature is in alpha. If you migrate to this API, enable the feature via AlgorithmConfig.experimental(_enable_new_api_stack=True).

The following algorithms support the Learner API out of the box. To use this API with other algorithms, implement the algorithm with a custom Learner.

Algorithm    Supported frameworks
PPO          PyTorch, TensorFlow
IMPALA       PyTorch, TensorFlow
APPO         PyTorch, TensorFlow

Basic usage

Use the LearnerGroup utility to interact with multiple learners.

Construction

If you enable the RLModule and Learner APIs through the AlgorithmConfig, calling build() constructs a LearnerGroup for you. If you're using these APIs standalone, you can construct the LearnerGroup as follows.

import gymnasium as gym

from ray.rllib.algorithms.ppo import PPOConfig

env = gym.make("CartPole-v1")

# Create an AlgorithmConfig object from which we can build the
# LearnerGroup.
config = (
    PPOConfig()
    # Enable the (alpha) RLModule and Learner APIs.
    .experimental(_enable_new_api_stack=True)
    # Number of Learner workers (Ray actors).
    # Use 0 for no actors, only create a local Learner.
    # Use >=1 to create n DDP-style Learner workers (Ray actors).
    .resources(num_learner_workers=1)
    # Specify the Learner's hyperparameters.
    .training(
        use_kl_loss=True,
        kl_coeff=0.01,
        kl_target=0.05,
        clip_param=0.2,
        vf_clip_param=0.2,
        entropy_coeff=0.05,
        vf_loss_coeff=0.5,
    )
)

# Construct a new LearnerGroup using our config object.
learner_group = config.build_learner_group(env=env)
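To construct a single Learner without a LearnerGroup, use build_learner() instead:

import gymnasium as gym

from ray.rllib.algorithms.ppo import PPOConfig

env = gym.make("CartPole-v1")

# Create an AlgorithmConfig object from which we can build the
# Learner.
config = (
    PPOConfig()
    # Enable the (alpha) RLModule and Learner APIs.
    .experimental(_enable_new_api_stack=True)
    # Specify the Learner's hyperparameters.
    .training(
        use_kl_loss=True,
        kl_coeff=0.01,
        kl_target=0.05,
        clip_param=0.2,
        vf_clip_param=0.2,
        entropy_coeff=0.05,
        vf_loss_coeff=0.5,
    )
)
# Construct a new Learner using our config object.
learner = config.build_learner(env=env)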

Updates

import time

# DUMMY_BATCH (a training batch) and ADDITIONAL_UPDATE_KWARGS (keyword
# arguments for the non-gradient-based update) are placeholders you provide.

# This is a blocking update.
results = learner_group.update_from_batch(batch=DUMMY_BATCH)

# This is a non-blocking update. The results are returned in a future
# call to `update_from_batch(..., async_update=True)`.
_ = learner_group.update_from_batch(batch=DUMMY_BATCH, async_update=True)

# Artificially wait for the async request to finish so that its results
# are returned by the next call to
# `LearnerGroup.update_from_batch(..., async_update=True)`.
time.sleep(5)
results = learner_group.update_from_batch(
    batch=DUMMY_BATCH, async_update=True
)
# `results` is a list of result dicts. The items in the list represent the
# remote results from the different earlier calls to
# `update_from_batch(..., async_update=True)`.
assert len(results) > 0
# Each item is a results dict, already reduced over the n Learner workers.
assert isinstance(results[0], dict), results[0]

# This is an additional non-gradient-based update.
learner_group.additional_update(**ADDITIONAL_UPDATE_KWARGS)

When updating a LearnerGroup, you can perform blocking or async updates on batches of data. Async updates are necessary for implementing async algorithms such as APPO and IMPALA. You can perform non-gradient-based updates using additional_update().
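The async contract described above (an async call returns the results of earlier requests that have already finished, not of the request it just submitted) can be modeled with the standard library. The class ToyAsyncUpdater below is hypothetical and only mimics that behavior with a thread pool; it is not how LearnerGroup is implemented.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Dict, List


class ToyAsyncUpdater:
    """Hypothetical stand-in mimicking blocking vs. async update semantics."""

    def __init__(self) -> None:
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._in_flight: List[Future] = []

    def _update(self, batch: List[float]) -> Dict[str, float]:
        time.sleep(0.01)  # pretend to compute and apply gradients
        return {"loss": sum(batch) / len(batch)}

    def update_from_batch(self, batch: List[float], async_update: bool = False):
        if not async_update:
            # Blocking: run the update now and return its result directly.
            return self._update(batch)
        # Partition earlier requests into finished and still-running ones.
        done, pending = [], []
        for fut in self._in_flight:
            (done if fut.done() else pending).append(fut)
        # Submit the new request; its result is returned by a later call.
        pending.append(self._executor.submit(self._update, batch))
        self._in_flight = pending
        return [fut.result() for fut in done]


updater = ToyAsyncUpdater()
print(updater.update_from_batch([2.0, 4.0]))  # {'loss': 3.0}
first = updater.update_from_batch([1.0, 3.0], async_update=True)
print(first)  # []  (nothing had finished yet)
time.sleep(0.5)  # let the first async request complete
later = updater.update_from_batch([5.0, 7.0], async_update=True)
print(later)  # [{'loss': 2.0}]
```

This is why the RLlib example above sleeps before its second async call: without the wait, the list of finished results could still be empty.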

# This is a blocking update (given a training batch).
result = learner.update_from_batch(batch=DUMMY_BATCH)

# This is an additional non-gradient-based update.
learner.additional_update(**ADDITIONAL_UPDATE_KWARGS)

When updating a Learner, you can only perform blocking updates on batches of data. You can perform non-gradient-based updates using additional_update().

Getting and setting state

# Get the LearnerGroup's RLModule weights and optimizer states.
state = learner_group.get_state()
learner_group.set_state(state)

# Only get the RLModule weights.
weights = learner_group.get_weights()
learner_group.set_weights(weights)

Set or get the state of all Learners in the LearnerGroup through set_state() and get_state(). This state includes both the neural network weights and the optimizer states of each Learner. To set or get only the RLModule weights of all Learners, use set_weights() and get_weights(); these don't include optimizer states.
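The difference between the full state and the weights matters when you resume training: restoring only the weights leaves optimizer buffers (such as momentum) at their old values. The following sketch uses a hypothetical ToyLearner with NumPy arrays to make that distinction concrete; its method names echo the API above, but it is an illustration, not RLlib code.

```python
import copy
from typing import Any, Dict

import numpy as np


class ToyLearner:
    """Hypothetical learner holding module weights plus an optimizer buffer."""

    def __init__(self) -> None:
        self.weights = {"w": np.zeros(2)}
        self.optimizer_state = {"momentum": np.zeros(2)}

    def get_state(self) -> Dict[str, Any]:
        # Full state: weights AND optimizer state.
        return copy.deepcopy({"module": self.weights, "optimizer": self.optimizer_state})

    def set_state(self, state: Dict[str, Any]) -> None:
        self.weights = copy.deepcopy(state["module"])
        self.optimizer_state = copy.deepcopy(state["optimizer"])

    def get_weights(self) -> Dict[str, np.ndarray]:
        # Weights only; the optimizer state is left out.
        return copy.deepcopy(self.weights)

    def set_weights(self, weights: Dict[str, np.ndarray]) -> None:
        self.weights = copy.deepcopy(weights)


trained = ToyLearner()
trained.weights["w"][:] = [1.0, -1.0]
trained.optimizer_state["momentum"][:] = [0.3, 0.3]

fresh = ToyLearner()
fresh.set_weights(trained.get_weights())
print(fresh.optimizer_state["momentum"])  # [0. 0.]  (momentum not copied)

fresh.set_state(trained.get_state())
print(fresh.optimizer_state["momentum"])  # [0.3 0.3]
```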

# Get the Learner's RLModule weights and optimizer states.
state = learner.get_state()
learner.set_state(state)

# Only get the RLModule weights (as numpy arrays).
module_state = learner.get_module_state()
learner.set_module_state(module_state)

You can set and get the full state of a Learner (RLModule weights and optimizer states) using set_state() and get_state(). To set or get only the RLModule weights, without optimizer states, use set_module_state() or get_module_state().

Checkpointing

learner_group.save_state(LEARNER_GROUP_CKPT_DIR)
learner_group.load_state(LEARNER_GROUP_CKPT_DIR)

Checkpoint the state of all Learners in the LearnerGroup via save_state() and load_state(). The checkpoint includes the neural network weights and any optimizer states. Because all Learner instances hold identical states, only the state of the first Learner needs to be saved.
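The "save only the first Learner" shortcut works because DDP keeps every replica identical. The following sketch writes one replica's state to disk as JSON and hands a copy back to every replica on load. The helper names, the JSON format, and the file name learner_state.json are assumptions for illustration; RLlib's own checkpoint format differs.

```python
import copy
import json
import os
import tempfile
from typing import Any, Dict, List

CKPT_FILE = "learner_state.json"  # hypothetical file name


def save_group_state(learner_states: List[Dict[str, Any]], ckpt_dir: str) -> None:
    """Persist only the first replica's state; DDP keeps replicas identical."""
    if any(state != learner_states[0] for state in learner_states):
        raise ValueError("Learner replicas diverged; refusing to save one copy.")
    with open(os.path.join(ckpt_dir, CKPT_FILE), "w") as f:
        json.dump(learner_states[0], f)


def load_group_state(ckpt_dir: str, num_learners: int) -> List[Dict[str, Any]]:
    """Read the saved state once and give every replica its own copy."""
    with open(os.path.join(ckpt_dir, CKPT_FILE)) as f:
        state = json.load(f)
    return [copy.deepcopy(state) for _ in range(num_learners)]


replica_state = {"weights": [0.5, -0.25], "optimizer_step": 10}
states = [copy.deepcopy(replica_state) for _ in range(3)]

with tempfile.TemporaryDirectory() as ckpt_dir:
    save_group_state(states, ckpt_dir)
    print(os.listdir(ckpt_dir))  # ['learner_state.json']
    restored = load_group_state(ckpt_dir, num_learners=3)

print(all(r == replica_state for r in restored))  # True
```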

learner.save_state(LEARNER_CKPT_DIR)
learner.load_state(LEARNER_CKPT_DIR)

Checkpoint the state of a Learner via save_state() and load_state(). The checkpoint includes the neural network weights and any optimizer states.

Implementation

Learner has many APIs for flexible implementation; however, the core ones that you need to implement are:

Method                              Description
configure_optimizers_for_module()   Set up any optimizers for an RLModule.
compute_loss_for_module()           Calculate the loss for a gradient-based update to a module.
additional_update_for_module()      Perform any non-gradient-based updates to an RLModule, e.g. target network updates.
compile_results()                   Compute training statistics and format them for downstream use.
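To show how the four hooks in the table fit into one update step, here is a framework-free toy analog built on a one-parameter linear model. The class ToyLinearLearner and its update() driver are hypothetical; they borrow the hook names from the table but do not subclass or reproduce RLlib's Learner, whose real signatures appear in the starter example.

```python
from typing import Any, Dict, Tuple

import numpy as np


class ToyLinearLearner:
    """Hypothetical, framework-free analog of the four core Learner hooks."""

    def __init__(self, num_features: int) -> None:
        self.params = np.zeros(num_features)
        self.target_params = np.zeros(num_features)
        self.configure_optimizers_for_module()

    def configure_optimizers_for_module(self) -> None:
        # Plain SGD: here the "optimizer" is just a learning rate.
        self.lr = 0.1

    def compute_loss_for_module(self, x: np.ndarray, y: np.ndarray) -> Tuple[float, np.ndarray]:
        # Mean squared error and its gradient with respect to params.
        residual = x @ self.params - y
        loss = float(np.mean(residual ** 2))
        grad = 2.0 * x.T @ residual / len(y)
        return loss, grad

    def additional_update_for_module(self, tau: float = 0.5) -> None:
        # Non-gradient update: Polyak-average params into the target copy.
        self.target_params = tau * self.params + (1 - tau) * self.target_params

    def compile_results(self, loss: float) -> Dict[str, Any]:
        # Training statistics for downstream logging.
        return {"loss": loss, "param_norm": float(np.linalg.norm(self.params))}

    def update(self, x: np.ndarray, y: np.ndarray) -> Dict[str, Any]:
        loss, grad = self.compute_loss_for_module(x, y)
        self.params -= self.lr * grad  # one gradient step
        return self.compile_results(loss)


x = np.array([[1.0], [2.0]])
y = np.array([2.0, 4.0])
toy = ToyLinearLearner(num_features=1)
results = toy.update(x, y)
toy.additional_update_for_module()
print(results["loss"])  # 10.0
```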

Starter Example

A Learner that implements behavior cloning could look like the following:

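from typing import Any, DefaultDict, Dict, Optional

import numpy as np
import torch

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.learner.torch.torch_learner import TorchLearner
from ray.rllib.core.rl_module.marl_module import ModuleID
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.nested_dict import NestedDict
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.typing import TensorType


class BCTorchLearner(TorchLearner):

    @override(Learner)
    def compute_loss_for_module(
        self,
        *,
        module_id: ModuleID,
        config: Optional[AlgorithmConfig] = None,
        batch: NestedDict,
        fwd_out: Dict[str, TensorType],
    ) -> TensorType:

        # Standard behavior cloning loss: the negative mean log-likelihood of
        # the batch's actions under the module's action distribution.
        action_dist_inputs = fwd_out[SampleBatch.ACTION_DIST_INPUTS]
        action_dist_class = self.module[module_id].get_train_action_dist_cls()
        action_dist = action_dist_class.from_logits(action_dist_inputs)
        loss = -torch.mean(action_dist.logp(batch[SampleBatch.ACTIONS]))

        return loss

    @override(Learner)
    def compile_results(
        self,
        *,
        batch: MultiAgentBatch,
        fwd_out: Dict[str, Any],
        loss_per_module: Dict[str, TensorType],
        metrics_per_module: DefaultDict[ModuleID, Dict[str, Any]],
    ) -> Dict[str, Any]:

        results = super().compile_results(
            batch=batch,
            fwd_out=fwd_out,
            loss_per_module=loss_per_module,
            metrics_per_module=metrics_per_module,
        )
        # Report the mean weight of each module.
        mean_ws = {}
        for module_id in self.module.keys():
            m = self.module[module_id]
            parameters = convert_to_numpy(self.get_parameters(m))
            mean_ws[module_id] = np.mean([w.mean() for w in parameters])
            results[module_id]["mean_weight"] = mean_ws[module_id]

        return results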
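For discrete actions, the loss in compute_loss_for_module (the negative mean of logp(actions)) reduces to the cross-entropy between a categorical distribution over the logits and the dataset actions. The NumPy sketch below spells out that computation with a numerically stable log-softmax; the function name bc_loss is an assumption for illustration and the example covers only the categorical case.

```python
import numpy as np


def bc_loss(logits: np.ndarray, actions: np.ndarray) -> float:
    """Negative mean log-likelihood of dataset actions under a categorical policy.

    logits: shape (batch, num_actions); actions: shape (batch,) integer indices.
    """
    # Numerically stable log-softmax over the action dimension.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the action actually taken in each row.
    taken = log_probs[np.arange(len(actions)), actions]
    return float(-taken.mean())


logits = np.zeros((2, 2))  # uniform policy over 2 actions
actions = np.array([0, 1])
print(bc_loss(logits, actions))  # about 0.693, i.e. ln 2
```

A uniform policy scores ln 2 on two actions; the loss approaches zero as the logits concentrate on the demonstrated actions.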