RLModule API#

RL Module specifications and configurations#

Single Agent#

SingleAgentRLModuleSpec(module_class, ...)

A utility spec class that makes constructing RLModules easier in the single-agent case.

SingleAgentRLModuleSpec.build()

Builds the RLModule from this spec.

SingleAgentRLModuleSpec.get_rl_module_config()

Returns the RLModule config for this spec.
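The spec-then-build pattern described above can be sketched in plain Python. This is a minimal illustration, not RLlib code: `ToySpec`, `ToyModule`, `get_config`, and the fields `observation_dim`, `action_dim`, and `model_config` are hypothetical stand-ins chosen for this example.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Type


class ToyModule:
    """Hypothetical stand-in for an RLModule; it only stores its config."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config


@dataclass
class ToySpec:
    """Hypothetical spec: holds everything needed to build a module later."""

    module_class: Type[ToyModule]
    observation_dim: int
    action_dim: int
    model_config: Dict[str, Any] = field(default_factory=dict)

    def get_config(self) -> Dict[str, Any]:
        # Collect the construction arguments into a single config object.
        return {
            "observation_dim": self.observation_dim,
            "action_dim": self.action_dim,
            "model_config": dict(self.model_config),
        }

    def build(self) -> ToyModule:
        # Instantiate the module class from the config derived from this spec.
        return self.module_class(self.get_config())


spec = ToySpec(ToyModule, observation_dim=4, action_dim=2, model_config={"hidden": 32})
module = spec.build()
```

Keeping the spec separate from the module means the spec can be passed around and stored cheaply, and the module is only created where it is needed.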

RLModule Configuration#

RLModuleConfig(observation_space, ...)

A utility config class that makes constructing RLModules easier.

RLModuleConfig.to_dict()

Returns a serialized representation of the config.

RLModuleConfig.from_dict(d)

Creates a config from a serialized representation.

RLModuleConfig.get_catalog()

Returns the catalog for this config.
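The `to_dict()` / `from_dict(d)` pair implies a serialization round trip. The sketch below shows that idea with a hypothetical `ToyConfig` dataclass whose fields are plain integers and a dict. It is an assumption-laden illustration only; the real config holds richer objects, and how those are serialized is not shown here.

```python
import json
from dataclasses import asdict, dataclass, field
from typing import Any, Dict


@dataclass
class ToyConfig:
    """Hypothetical config with JSON-friendly fields only."""

    observation_dim: int
    action_dim: int
    model_config: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        # Convert the config into plain Python containers.
        return asdict(self)

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "ToyConfig":
        # Rebuild the config from the serialized representation.
        return cls(**d)


config = ToyConfig(observation_dim=4, action_dim=2, model_config={"hidden": 32})
serialized = json.dumps(config.to_dict())
restored = ToyConfig.from_dict(json.loads(serialized))
```

After the round trip, `restored` compares equal to `config`, which is the property a `to_dict`/`from_dict` pair is expected to preserve.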

Multi Agent#

MultiAgentRLModuleSpec(marl_module_class, ...)

A utility spec class that makes constructing multi-agent RL (MARL) modules easier.

MultiAgentRLModuleSpec.build([module_id])

Builds the multi-agent module, or the single-agent module for module_id when one is given.

MultiAgentRLModuleSpec.get_marl_config()

Returns the MultiAgentRLModuleConfig for this spec.

RL Module API#

Constructor#

RLModule(config)

Base class for RLlib modules.

RLModule.as_multi_agent()

Returns a multi-agent wrapper around this module.

Forward methods#

forward_train(batch, **kwargs)

Forward-pass during training, called from the learner.

forward_exploration(batch, **kwargs)

Forward-pass during exploration, called from the sampler.

forward_inference(batch, **kwargs)

Forward-pass during evaluation, called from the sampler.
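To make the difference between the three forward methods concrete, here is a toy sketch in plain Python. `ToyPolicyModule` is hypothetical and ignores the observations entirely; the batch keys `"obs"`, `"logits"`, and `"actions"` are assumptions made for this example rather than the library's required keys. The design choice it illustrates is a common one: training returns raw outputs for a loss, exploration samples stochastically, and inference acts greedily.

```python
import math
import random
from typing import Dict, List


class ToyPolicyModule:
    """Hypothetical module with fixed per-action logits; not an RLlib class."""

    def __init__(self, logits: List[float], seed: int = 0):
        self.logits = logits
        self.rng = random.Random(seed)

    def _probs(self) -> List[float]:
        # Numerically stable softmax over the logits.
        m = max(self.logits)
        exps = [math.exp(x - m) for x in self.logits]
        total = sum(exps)
        return [e / total for e in exps]

    def forward_train(self, batch: Dict[str, list]) -> Dict[str, list]:
        # Training: return raw logits per observation so a loss can use them.
        return {"logits": [list(self.logits) for _ in batch["obs"]]}

    def forward_exploration(self, batch: Dict[str, list]) -> Dict[str, list]:
        # Exploration: sample one action per observation from the softmax.
        probs = self._probs()
        actions = [
            self.rng.choices(range(len(probs)), weights=probs)[0]
            for _ in batch["obs"]
        ]
        return {"actions": actions}

    def forward_inference(self, batch: Dict[str, list]) -> Dict[str, list]:
        # Inference: pick the highest-logit action deterministically.
        best = max(range(len(self.logits)), key=self.logits.__getitem__)
        return {"actions": [best for _ in batch["obs"]]}


batch = {"obs": [[0.1, 0.2], [0.3, 0.4]]}
toy = ToyPolicyModule([0.5, 2.0, -1.0])
train_out = toy.forward_train(batch)
explore_out = toy.forward_exploration(batch)
infer_out = toy.forward_inference(batch)
```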

IO specifications#

input_specs_inference()

Returns the input specs of the forward_inference method.

input_specs_exploration()

Returns the input specs of the forward_exploration method.

input_specs_train()

Returns the input specs of the forward_train method.

output_specs_inference()

Returns the output specs of the forward_inference method.

output_specs_exploration()

Returns the output specs of the forward_exploration method.

output_specs_train()

Returns the output specs of the forward_train method.
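One way to picture what input and output specs are for is a key check around a forward method. The sketch below assumes, purely for illustration, that a spec is a list of required dictionary keys; `ToySpecModule` and `check_keys` are hypothetical names, and the actual spec types supported by the library may be richer than this.

```python
from typing import Dict, List


def check_keys(data: Dict[str, list], required: List[str], label: str) -> None:
    # Raise if any key named in the spec is absent from the data.
    missing = [k for k in required if k not in data]
    if missing:
        raise ValueError(f"{label} is missing required keys: {missing}")


class ToySpecModule:
    """Hypothetical module; a spec here is just a list of required keys."""

    def input_specs_inference(self) -> List[str]:
        return ["obs"]

    def output_specs_inference(self) -> List[str]:
        return ["actions"]

    def forward_inference(self, batch: Dict[str, list]) -> Dict[str, list]:
        # Validate the input, compute a dummy output, then validate the output.
        check_keys(batch, self.input_specs_inference(), "input")
        out = {"actions": [0 for _ in batch["obs"]]}
        check_keys(out, self.output_specs_inference(), "output")
        return out


module = ToySpecModule()
result = module.forward_inference({"obs": [[0.0], [1.0]]})

# A batch with the wrong key fails the input check.
try:
    module.forward_inference({"observations": [[0.0]]})
    error_message = ""
except ValueError as exc:
    error_message = str(exc)
```

Checking both sides of the call surfaces a malformed batch at the module boundary instead of deep inside the model code.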

Saving and Loading#

get_state()

Returns the state dict of the module.

set_state(state_dict)

Sets the state dict of the module.

save_state(path)

Saves the weights of this RLModule to path.

load_state(path)

Loads the weights of an RLModule from path.

save_to_checkpoint(checkpoint_dir_path)

Saves the module to a checkpoint directory.

from_checkpoint(checkpoint_dir_path)

Loads the module from a checkpoint directory.
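The relationship between the in-memory state methods and the path-based ones can be sketched as follows. `ToyStatefulModule` is a hypothetical class, and storing the state as a single JSON file is an assumption of this example, not a description of the on-disk format the library uses.

```python
import json
import os
import tempfile
from typing import Dict, List, Optional


class ToyStatefulModule:
    """Hypothetical module whose state is a dict of float lists."""

    def __init__(self, weights: Optional[Dict[str, List[float]]] = None):
        self.weights = weights or {"w": [0.0, 0.0], "b": [0.0]}

    def get_state(self) -> Dict[str, List[float]]:
        # Return a copy so callers cannot mutate the module's weights.
        return {k: list(v) for k, v in self.weights.items()}

    def set_state(self, state_dict: Dict[str, List[float]]) -> None:
        self.weights = {k: list(v) for k, v in state_dict.items()}

    def save_state(self, path: str) -> None:
        # Path-based saving is get_state() plus a write to disk.
        with open(path, "w") as f:
            json.dump(self.get_state(), f)

    def load_state(self, path: str) -> None:
        # Path-based loading is a read from disk plus set_state().
        with open(path) as f:
            self.set_state(json.load(f))


trained = ToyStatefulModule({"w": [0.5, -1.5], "b": [0.25]})
fresh = ToyStatefulModule()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "module_state.json")
    trained.save_state(path)
    fresh.load_state(path)
```

In this toy version, the file-based methods are thin wrappers over `get_state` and `set_state`, so the two pairs cannot drift apart.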

Multi Agent RL Module API#

Constructor#

MultiAgentRLModule(*args, **kwargs)

Base class for multi-agent RLModules.

MultiAgentRLModule.setup()

Sets up the underlying RLModules.

MultiAgentRLModule.as_multi_agent()

Returns a multi-agent wrapper around this module.

Modifying the underlying RL modules#

add_module(module_id, module, *[, override])

Adds a module at run time to the multi-agent module.

remove_module(module_id, *[, ...])

Removes a module at run time from the multi-agent module.
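Adding and removing modules at run time amounts to managing a mapping from module IDs to modules. The sketch below uses a hypothetical `ToyMultiModule` container. The keyword-only `override` flag follows the signature listed above, but its default of `False`, the `ValueError` on duplicates, and the `KeyError` on unknown IDs are assumptions of this example.

```python
from typing import Any, Dict, List


class ToyMultiModule:
    """Hypothetical container mapping module IDs to modules."""

    def __init__(self) -> None:
        self._modules: Dict[str, Any] = {}

    def add_module(self, module_id: str, module: Any, *, override: bool = False) -> None:
        # Assumption: refuse to replace an existing module unless asked to.
        if module_id in self._modules and not override:
            raise ValueError(f"Module ID {module_id!r} already exists")
        self._modules[module_id] = module

    def remove_module(self, module_id: str) -> None:
        # Assumption: removing an unknown ID raises KeyError in this toy.
        del self._modules[module_id]

    def keys(self) -> List[str]:
        return list(self._modules)


multi = ToyMultiModule()
multi.add_module("policy_a", "module_a_v1")
multi.add_module("policy_b", "module_b_v1")

# Re-adding an existing ID without override is rejected.
try:
    multi.add_module("policy_a", "module_a_v2")
    duplicate_rejected = False
except ValueError:
    duplicate_rejected = True

multi.add_module("policy_a", "module_a_v2", override=True)
multi.remove_module("policy_b")
```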

Saving and Loading#

save_state(path)

Saves the weights of this MultiAgentRLModule to the directory at path.

load_state(path[, modules_to_load])

Loads the weights of a MultiAgentRLModule from the directory at path.
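The optional `modules_to_load` argument suggests that a subset of modules can be restored from a saved directory. A minimal sketch of that idea follows. The helper names `save_all` and `load_some`, and the one-JSON-file-per-module layout, are hypothetical choices for this example and say nothing about the real directory format.

```python
import json
import os
import tempfile
from typing import Dict, Iterable, Optional


def save_all(states: Dict[str, dict], directory: str) -> None:
    # Write one file per module ID (hypothetical layout).
    os.makedirs(directory, exist_ok=True)
    for module_id, state in states.items():
        with open(os.path.join(directory, f"{module_id}.json"), "w") as f:
            json.dump(state, f)


def load_some(
    directory: str, modules_to_load: Optional[Iterable[str]] = None
) -> Dict[str, dict]:
    # Load every saved module by default, or only the requested IDs.
    available = sorted(
        name[: -len(".json")]
        for name in os.listdir(directory)
        if name.endswith(".json")
    )
    wanted = available if modules_to_load is None else list(modules_to_load)
    loaded = {}
    for module_id in wanted:
        with open(os.path.join(directory, f"{module_id}.json")) as f:
            loaded[module_id] = json.load(f)
    return loaded


states = {"policy_a": {"w": [1.0]}, "policy_b": {"w": [2.0]}}

with tempfile.TemporaryDirectory() as tmp:
    save_all(states, tmp)
    everything = load_some(tmp)
    only_b = load_some(tmp, modules_to_load=["policy_b"])
```

Loading a subset is useful when only some policies in a multi-agent setup need to be restored, for example when one agent is frozen and another is retrained.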