TensorFlow-Specific Sub-Classes

TFPolicy

class ray.rllib.policy.tf_policy.TFPolicy(observation_space: gym.spaces.Space, action_space: gym.spaces.Space, config: dict, sess: tf1.Session, obs_input: Any, sampled_action: Any, loss: Union[Any, List[Any]], loss_inputs: List[Tuple[str, Any]], model: Optional[ray.rllib.models.modelv2.ModelV2] = None, sampled_action_logp: Optional[Any] = None, action_input: Optional[Any] = None, log_likelihood: Optional[Any] = None, dist_inputs: Optional[Any] = None, dist_class: Optional[type] = None, state_inputs: Optional[List[Any]] = None, state_outputs: Optional[List[Any]] = None, prev_action_input: Optional[Any] = None, prev_reward_input: Optional[Any] = None, seq_lens: Optional[Any] = None, max_seq_len: int = 20, batch_divisibility_req: int = 1, update_ops: List[Any] = None, explore: Optional[Any] = None, timestep: Optional[Any] = None)[source]

An agent policy and loss implemented in TensorFlow.

Do not sub-class this class directly (nor should you sub-class DynamicTFPolicy); instead, use rllib.policy.tf_policy_template.build_tf_policy to generate your custom tf (graph-mode or eager) Policy classes.

Extending this class enables RLlib to perform TensorFlow-specific optimizations on the policy, e.g., parallelization across GPUs or fusing multiple graphs together in the multi-agent setting.

Input tensors are typically shaped like [BATCH_SIZE, …].

Examples

>>> policy = TFPolicySubclass(
...     sess, obs_input, sampled_action, loss, loss_inputs)
>>> print(policy.compute_actions([1, 0, 2]))
(array([0, 1, 1]), [], {})
>>> print(policy.postprocess_trajectory(SampleBatch({...})))
SampleBatch({"action": ..., "advantages": ..., ...})
__init__(observation_space: gym.spaces.Space, action_space: gym.spaces.Space, config: dict, sess: tf1.Session, obs_input: Any, sampled_action: Any, loss: Union[Any, List[Any]], loss_inputs: List[Tuple[str, Any]], model: Optional[ray.rllib.models.modelv2.ModelV2] = None, sampled_action_logp: Optional[Any] = None, action_input: Optional[Any] = None, log_likelihood: Optional[Any] = None, dist_inputs: Optional[Any] = None, dist_class: Optional[type] = None, state_inputs: Optional[List[Any]] = None, state_outputs: Optional[List[Any]] = None, prev_action_input: Optional[Any] = None, prev_reward_input: Optional[Any] = None, seq_lens: Optional[Any] = None, max_seq_len: int = 20, batch_divisibility_req: int = 1, update_ops: List[Any] = None, explore: Optional[Any] = None, timestep: Optional[Any] = None)[source]

Initializes a Policy object.

Parameters
  • observation_space – Observation space of the policy.

  • action_space – Action space of the policy.

  • config – Policy-specific configuration data.

  • sess – The TensorFlow session to use.

  • obs_input – Input placeholder for observations, of shape [BATCH_SIZE, obs…].

  • sampled_action – Tensor for sampling an action, of shape [BATCH_SIZE, action…]

  • loss – Scalar policy loss output tensor or a list thereof (in case there is more than one loss).

  • loss_inputs – A (name, placeholder) tuple for each loss input argument. Each placeholder name must correspond to a SampleBatch column key returned by postprocess_trajectory(), and has shape [BATCH_SIZE, data…]. These keys will be read from postprocessed sample batches and fed into the specified placeholders during loss computation (see the sketch after this parameter list).

  • model – The optional ModelV2 to use for calculating actions and losses. If not None, TFPolicy will provide functionality for getting variables, calling the model’s custom loss (if provided), and importing weights into the model.

  • sampled_action_logp – Log probability of the sampled action.

  • action_input – Input placeholder for actions for logp/log-likelihood calculations.

  • log_likelihood – Tensor to calculate the log_likelihood (given action_input and obs_input).

  • dist_class – An optional ActionDistribution class to use for generating a dist object from distribution inputs.

  • dist_inputs – Tensor to calculate the distribution inputs/parameters.

  • state_inputs – List of RNN state input Tensors.

  • state_outputs – List of RNN state output Tensors.

  • prev_action_input – Placeholder for previous actions.

  • prev_reward_input – Placeholder for previous rewards.

  • seq_lens – Placeholder for RNN sequence lengths, of shape [NUM_SEQUENCES]. Note that NUM_SEQUENCES << BATCH_SIZE. See policy/rnn_sequencing.py for more information.

  • max_seq_len – Max sequence length for LSTM training.

  • batch_divisibility_req – Pad all agent experience batches to multiples of this value. This only has an effect if not using an LSTM model.

  • update_ops – Override the batchnorm update ops to run when applying gradients. If not given, all update ops found in the current variable scope are run.

  • explore – Placeholder for the explore parameter passed into the call to Exploration.get_exploration_action. Explicitly set this to False to avoid creating any Exploration component.

  • timestep – Placeholder for the global sampling timestep.
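
To make the loss_inputs contract concrete (see the loss_inputs parameter above), here is a purely illustrative sketch of the expected (name, placeholder) structure; the observation dimensionality and the "advantages" column are assumptions:

    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.utils.framework import try_import_tf

    tf1, tf, tfv = try_import_tf()

    obs_ph = tf1.placeholder(tf.float32, [None, 4], name="obs")  # assumed obs dim
    act_ph = tf1.placeholder(tf.int64, [None], name="actions")
    adv_ph = tf1.placeholder(tf.float32, [None], name="advantages")

    # Each name must match a SampleBatch column emitted by
    # postprocess_trajectory(); those columns are fed into the placeholders
    # at loss-computation time.
    loss_inputs = [
        (SampleBatch.CUR_OBS, obs_ph),
        (SampleBatch.ACTIONS, act_ph),
        ("advantages", adv_ph),
    ]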

compute_actions_from_input_dict(input_dict: Union[ray.rllib.policy.sample_batch.SampleBatch, Dict[str, Any]], explore: bool = None, timestep: Optional[int] = None, episodes: Optional[List[Episode]] = None, **kwargs) → Tuple[Any, List[Any], Dict[str, Any]][source]

Computes actions from collected samples (across multiple agents).

Takes an input dict (usually a SampleBatch) as its main data input. This allows for using this method in case a more complex input pattern (view requirements) is needed, for example when the Model requires the last n observations, the last m actions/rewards, or a combination of any of these.

Parameters
  • input_dict – A SampleBatch or input dict containing the Tensors to compute actions from. The input_dict already abides by the Policy’s as well as the Model’s view requirements and can thus be passed to the Model as-is.

  • explore – Whether to pick an exploitation or exploration action (default: None -> use self.config[“explore”]).

  • timestep – The current (sampling) time step.

  • episodes – This provides access to all of the internal episodes’ state, which may be useful for model-based or multi-agent algorithms.

Keyword Arguments

kwargs – Forward compatibility placeholder.

Returns

actions: Batch of output actions, with shape like [BATCH_SIZE, ACTION_SHAPE].

state_outs: List of RNN state output batches, if any, each with shape [BATCH_SIZE, STATE_SIZE].

info: Dictionary of extra feature batches, if any, with shape like {“f1”: [BATCH_SIZE, …], “f2”: [BATCH_SIZE, …]}.
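
A hedged usage sketch (the observation shape is an assumption about the environment; policy stands for any built TFPolicy instance):

    import numpy as np
    from ray.rllib.policy.sample_batch import SampleBatch

    # Input dict matching the Policy's/Model's view requirements.
    input_dict = SampleBatch({
        SampleBatch.CUR_OBS: np.array([[0.1, 0.2, 0.3, 0.4]]),  # assumed obs shape
    })
    actions, state_outs, info = policy.compute_actions_from_input_dict(
        input_dict, explore=False)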

compute_actions(obs_batch: Union[List[Any], Any], state_batches: Optional[List[Any]] = None, prev_action_batch: Union[List[Any], Any] = None, prev_reward_batch: Union[List[Any], Any] = None, info_batch: Optional[Dict[str, list]] = None, episodes: Optional[List[Episode]] = None, explore: Optional[bool] = None, timestep: Optional[int] = None, **kwargs)[source]

Computes actions for the current policy.

Parameters
  • obs_batch – Batch of observations.

  • state_batches – List of RNN state input batches, if any.

  • prev_action_batch – Batch of previous action values.

  • prev_reward_batch – Batch of previous rewards.

  • info_batch – Batch of info objects.

  • episodes – List of Episode objects, one for each obs in obs_batch. This provides access to all of the internal episode state, which may be useful for model-based or multi-agent algorithms.

  • explore – Whether to pick an exploitation or exploration action. Set to None (default) for using the value of self.config[“explore”].

  • timestep – The current (sampling) time step.

Keyword Arguments

kwargs – Forward compatibility placeholder

Returns

actions (TensorType): Batch of output actions, with shape like [BATCH_SIZE, ACTION_SHAPE].

state_outs (List[TensorType]): List of RNN state output batches, if any, each with shape [BATCH_SIZE, STATE_SIZE].

info (List[dict]): Dictionary of extra feature batches, if any, with shape like {“f1”: [BATCH_SIZE, …], “f2”: [BATCH_SIZE, …]}.
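
For orientation, a small sketch of a direct call (the observation shape is an assumption; non-recurrent policies return an empty state_outs list):

    import numpy as np

    obs_batch = np.array([[0.1, 0.2, 0.3, 0.4],
                          [0.5, 0.6, 0.7, 0.8]])  # assumed obs shape
    actions, state_outs, info = policy.compute_actions(obs_batch, explore=True)
    assert len(actions) == 2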

compute_log_likelihoods(actions: Union[List[Any], Any], obs_batch: Union[List[Any], Any], state_batches: Optional[List[Any]] = None, prev_action_batch: Union[List[Any], Any, None] = None, prev_reward_batch: Union[List[Any], Any, None] = None, actions_normalized: bool = True) → Any[source]

Computes the log-prob/likelihood for a given action and observation.

The log-likelihood is calculated using this Policy’s action distribution class (self.dist_class).

Parameters
  • actions – Batch of actions, for which to retrieve the log-probs/likelihoods (given all other inputs: obs, states, ..).

  • obs_batch – Batch of observations.

  • state_batches – List of RNN state input batches, if any.

  • prev_action_batch – Batch of previous action values.

  • prev_reward_batch – Batch of previous rewards.

  • actions_normalized – Whether the given actions are already normalized (between -1.0 and 1.0). If not, and normalize_actions=True, the given actions need to be normalized first before calculating log-likelihoods.

Returns

Batch of log probs/likelihoods, with shape [BATCH_SIZE].
learn_on_batch(postprocessed_batch: ray.rllib.policy.sample_batch.SampleBatch) → Dict[str, Any][source]

Perform one learning update, given samples.

Either this method or the combination of compute_gradients and apply_gradients must be implemented by subclasses.

Parameters

postprocessed_batch – The SampleBatch object to learn from.

Returns

Dictionary of extra metadata from compute_gradients().

Examples

>>> sample_batch = ev.sample()
>>> ev.learn_on_batch(sample_batch)
compute_gradients(postprocessed_batch: ray.rllib.policy.sample_batch.SampleBatch) → Tuple[Union[List[Tuple[Any, Any]], List[Any]], Dict[str, Any]][source]

Computes gradients given a batch of experiences.

Either this method (in combination with apply_gradients()) or learn_on_batch() must be implemented by subclasses.

Parameters

postprocessed_batch – The SampleBatch object to use for calculating gradients.

Returns

grads: List of gradient output values.

grad_info: Extra policy-specific info values.

apply_gradients(gradients: Union[List[Tuple[Any, Any]], List[Any]]) → None[source]

Applies the (previously) computed gradients.

Either this method (in combination with compute_gradients()) or learn_on_batch() must be implemented by subclasses.

Parameters

gradients – The already calculated gradients to apply to this Policy.
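
When handling gradients manually instead of calling learn_on_batch(), the two methods are typically used together. A sketch, assuming postprocessed_batch is a SampleBatch returned by postprocess_trajectory():

    # Compute gradients for the batch ...
    grads, grad_info = policy.compute_gradients(postprocessed_batch)

    # ... optionally aggregate/average them across workers, then apply.
    policy.apply_gradients(grads)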

get_weights() → Union[Dict[str, Any], List[Any]][source]

Returns model weights.

Note: The return value of this method will reside under the “weights” key in the return value of Policy.get_state(). Model weights are only one part of a Policy’s state. Other state information includes: optimizer variables, exploration state, and global state vars such as the sampling timestep.

Returns

Serializable copy or view of model weights.

set_weights(weights) → None[source]

Sets this Policy’s model’s weights.

Note: Model weights are only one part of a Policy’s state. Other state information includes: optimizer variables, exploration state, and global state vars such as the sampling timestep.

Parameters

weights – Serializable copy or view of model weights.
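
A common pattern is synchronizing weights between two policies with identical model architectures, e.g. a learner policy and a copy used for rollouts (the names below are illustrative):

    weights = learner_policy.get_weights()
    rollout_policy.set_weights(weights)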

get_exploration_state() → Dict[str, Any][source]

Returns the state of this Policy’s exploration component.

Returns

Serializable information on the self.exploration object.

is_recurrent() → bool[source]

Whether this Policy holds a recurrent Model.

Returns

True if this Policy has an RNN-based Model.

num_state_tensors() → int[source]

The number of internal states needed by the RNN-Model of the Policy.

Returns

The number of RNN internal states kept by this Policy’s Model.

Return type

int

get_state() → Union[Dict[str, Any], List[Any]][source]

Returns the entire current state of this Policy.

Note: Not to be confused with an RNN model’s internal state. State includes the Model(s)’ weights, optimizer weights, the exploration component’s state, as well as global variables, such as sampling timesteps.

Returns

Serialized local state.

set_state(state: dict) → None[source]

Restores the entire current state of this Policy from state.

Parameters

state – The new state to set this policy to. Can be obtained by calling self.get_state().
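
A hedged snapshot/restore sketch; the pickle-based persistence and the path are illustrative only (the state dict holds the model weights plus optimizer, exploration, and global-timestep information):

    import pickle

    state = policy.get_state()
    with open("/tmp/policy_state.pkl", "wb") as f:  # illustrative path
        pickle.dump(state, f)

    with open("/tmp/policy_state.pkl", "rb") as f:
        policy.set_state(pickle.load(f))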

export_checkpoint(export_dir: str, filename_prefix: str = 'model') → None[source]

Exports a TensorFlow checkpoint to export_dir.

export_model(export_dir: str, onnx: Optional[int] = None) → None[source]

Exports the TensorFlow graph to export_dir for serving.

import_model_from_h5(import_file: str) → None[source]

Imports weights into the TF model.
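
A usage sketch (paths are illustrative; ONNX export additionally requires the tf2onnx dependency to be installed):

    policy.export_checkpoint("/tmp/my_policy_ckpt")        # tf1-style checkpoint
    policy.export_model("/tmp/my_policy_model")            # graph export for serving
    # policy.export_model("/tmp/my_policy_onnx", onnx=11)  # ONNX graph, opset 11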

get_session() → Optional[tf1.Session][source]

Returns a reference to the TF session for this policy.

variables()[source]

Returns the list of all savable variables for this policy.

get_placeholder(name) → tf1.placeholder[source]

Returns the given action or loss input placeholder by name.

If the loss has not been initialized and a loss input placeholder is requested, an error is raised.

Parameters

name (str) – The name of the placeholder to return. One of SampleBatch.CUR_OBS, SampleBatch.PREV_ACTIONS, or SampleBatch.PREV_REWARDS, or a valid key from self._loss_input_dict.

Returns

The placeholder under the given str key.

Return type

tf1.placeholder
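
A brief sketch; the "advantages" loss-input key is an assumption about this policy's loss_inputs:

    from ray.rllib.policy.sample_batch import SampleBatch

    # Action-input placeholders are always available; loss-input placeholders
    # only after the loss has been built.
    obs_ph = policy.get_placeholder(SampleBatch.CUR_OBS)
    if policy.loss_initialized():
        adv_ph = policy.get_placeholder("advantages")  # assumed loss-input key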

loss_initialized() → bool[source]

Returns whether the loss term(s) have been initialized.

copy(existing_inputs: List[Tuple[str, tf1.placeholder]]) → ray.rllib.policy.tf_policy.TFPolicy[source]

Creates a copy of self using existing input placeholders.

Optional: Only required to work with the multi-GPU optimizer.

Parameters

existing_inputs (List[Tuple[str, tf1.placeholder]]) – List of (name, tf1.placeholder) tuples to re-use (share) with the returned copy of self.

Returns

A copy of self.

Return type

TFPolicy

extra_compute_action_feed_dict() → Dict[Any, Any][source]

Extra dict to pass to the compute actions session run.

Returns

A feed dict to be added to the feed_dict passed to the compute_actions() session.run() call.

Return type

Dict[TensorType, TensorType]

extra_compute_action_fetches() → Dict[str, Any][source]

Extra values to fetch and return from compute_actions().

By default we return action probability/log-likelihood info and action distribution inputs (if present).

Returns

An extra fetch-dict to be passed to and returned from the compute_actions() call.

Return type

Dict[str, TensorType]
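
A sketch of extending the default fetches in a sub-class; the "vf_preds" fetch assumes the model implements value_function(), and SomeTFPolicySubclass stands for any concrete TFPolicy sub-class:

    class MyTFPolicy(SomeTFPolicySubclass):
        def extra_compute_action_fetches(self):
            fetches = super().extra_compute_action_fetches()
            # Also fetch the value-function output whenever actions are computed.
            fetches["vf_preds"] = self.model.value_function()
            return fetches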

extra_compute_grad_feed_dict() → Dict[Any, Any][source]

Extra dict to pass to the compute gradients session run.

Returns

Extra feed_dict to be passed to the compute_gradients Session.run() call.

Return type

Dict[TensorType, TensorType]

extra_compute_grad_fetches() → Dict[str, Any][source]

Extra values to fetch and return from compute_gradients().

Returns

Extra fetch dict to be added to the fetch dict of the compute_gradients Session.run() call.

Return type

Dict[str, Any]

optimizer() → tf.keras.optimizers.Optimizer[source]

TF optimizer to use for policy optimization.

Returns

The local optimizer to use for this Policy’s Model.

Return type

tf.keras.optimizers.Optimizer

gradients(optimizer: Union[tf.keras.optimizers.Optimizer, torch.optim.Optimizer, List[Union[tf.keras.optimizers.Optimizer, torch.optim.Optimizer]]], loss: Union[Any, List[Any]]) → Union[List[Union[List[Tuple[Any, Any]], List[Any]]], List[List[Union[List[Tuple[Any, Any]], List[Any]]]]][source]

Override this for a custom gradient computation behavior.

Parameters
  • optimizer (Union[LocalOptimizer, List[LocalOptimizer]]) – A single LocalOptimizer or a list thereof to use for gradient calculations. If more than one optimizer is given, the number of optimizers must match the number of losses provided.

  • loss (Union[TensorType, List[TensorType]]) – A single loss term or a list thereof to use for gradient calculations. If more than one loss given, the number of loss terms must match the number of optimizers provided.

Returns

List of ModelGradients (grads and vars, OR just grads), OR a list of lists of ModelGradients in case we have more than one optimizer/loss.

Return type

Union[List[ModelGradients], List[List[ModelGradients]]]
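
A sketch of an override that clips gradients by global norm before they are applied, assuming a single tf1-style optimizer (one that provides compute_gradients()) and a single loss term; the clip value is arbitrary and SomeTFPolicySubclass stands for any concrete TFPolicy sub-class:

    from ray.rllib.utils.framework import try_import_tf

    tf1, tf, tfv = try_import_tf()

    class MyTFPolicy(SomeTFPolicySubclass):
        def gradients(self, optimizer, loss):
            grads_and_vars = optimizer.compute_gradients(loss)
            grads = [g for g, v in grads_and_vars if g is not None]
            variables = [v for g, v in grads_and_vars if g is not None]
            # Clip by global norm, then re-pair gradients with their variables.
            clipped_grads, _ = tf.clip_by_global_norm(grads, 40.0)  # arbitrary clip value
            return list(zip(clipped_grads, variables))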

build_apply_op(optimizer: Union[tf.keras.optimizers.Optimizer, torch.optim.Optimizer, List[Union[tf.keras.optimizers.Optimizer, torch.optim.Optimizer]]], grads_and_vars: Union[List[Tuple[Any, Any]], List[Any], List[Union[List[Tuple[Any, Any]], List[Any]]]]) → tf.Operation[source]

Override this for a custom gradient apply computation behavior.

Parameters
  • optimizer (Union[LocalOptimizer, List[LocalOptimizer]]) – The local tf optimizer to use for applying the grads and vars.

  • grads_and_vars (Union[ModelGradients, List[ModelGradients]]) – List of (gradient, variable) tuples, where each variable is the tf.Variable its gradient applies to.

Returns

The tf op that applies all computed gradients (grads_and_vars) to the model(s) via the given optimizer(s).

Return type

tf.Operation

DynamicTFPolicy

class ray.rllib.policy.dynamic_tf_policy.DynamicTFPolicy(obs_space: gym.spaces.Space, action_space: gym.spaces.Space, config: dict, loss_fn: Callable[[ray.rllib.policy.policy.Policy, ray.rllib.models.modelv2.ModelV2, Type[ray.rllib.models.tf.tf_action_dist.TFActionDistribution], ray.rllib.policy.sample_batch.SampleBatch], Any], *, stats_fn: Optional[Callable[[ray.rllib.policy.policy.Policy, ray.rllib.policy.sample_batch.SampleBatch], Dict[str, Any]]] = None, grad_stats_fn: Optional[Callable[[ray.rllib.policy.policy.Policy, ray.rllib.policy.sample_batch.SampleBatch, Union[List[Tuple[Any, Any]], List[Any]]], Dict[str, Any]]] = None, before_loss_init: Optional[Callable[[ray.rllib.policy.policy.Policy, gym.spaces.Space, gym.spaces.Space, dict], None]] = None, make_model: Optional[Callable[[ray.rllib.policy.policy.Policy, gym.spaces.Space, gym.spaces.Space, dict], ray.rllib.models.modelv2.ModelV2]] = None, action_sampler_fn: Optional[Callable[[Any, List[Any]], Tuple[Any, Any]]] = None, action_distribution_fn: Optional[Callable[[ray.rllib.policy.policy.Policy, ray.rllib.models.modelv2.ModelV2, Any, Any, Any], Tuple[Any, type, List[Any]]]] = None, existing_inputs: Optional[Dict[str, tf1.placeholder]] = None, existing_model: Optional[ray.rllib.models.modelv2.ModelV2] = None, get_batch_divisibility_req: Optional[Callable[[ray.rllib.policy.policy.Policy], int]] = None, obs_include_prev_action_reward=-1)[source]

A TFPolicy that auto-defines placeholders dynamically at runtime.

Do not sub-class this class directly (nor should you sub-class TFPolicy); instead, use rllib.policy.tf_policy_template.build_tf_policy to generate your custom tf (graph-mode or eager) Policy classes.

__init__(obs_space: gym.spaces.Space, action_space: gym.spaces.Space, config: dict, loss_fn: Callable[[ray.rllib.policy.policy.Policy, ray.rllib.models.modelv2.ModelV2, Type[ray.rllib.models.tf.tf_action_dist.TFActionDistribution], ray.rllib.policy.sample_batch.SampleBatch], Any], *, stats_fn: Optional[Callable[[ray.rllib.policy.policy.Policy, ray.rllib.policy.sample_batch.SampleBatch], Dict[str, Any]]] = None, grad_stats_fn: Optional[Callable[[ray.rllib.policy.policy.Policy, ray.rllib.policy.sample_batch.SampleBatch, Union[List[Tuple[Any, Any]], List[Any]]], Dict[str, Any]]] = None, before_loss_init: Optional[Callable[[ray.rllib.policy.policy.Policy, gym.spaces.Space, gym.spaces.Space, dict], None]] = None, make_model: Optional[Callable[[ray.rllib.policy.policy.Policy, gym.spaces.Space, gym.spaces.Space, dict], ray.rllib.models.modelv2.ModelV2]] = None, action_sampler_fn: Optional[Callable[[Any, List[Any]], Tuple[Any, Any]]] = None, action_distribution_fn: Optional[Callable[[ray.rllib.policy.policy.Policy, ray.rllib.models.modelv2.ModelV2, Any, Any, Any], Tuple[Any, type, List[Any]]]] = None, existing_inputs: Optional[Dict[str, tf1.placeholder]] = None, existing_model: Optional[ray.rllib.models.modelv2.ModelV2] = None, get_batch_divisibility_req: Optional[Callable[[ray.rllib.policy.policy.Policy], int]] = None, obs_include_prev_action_reward=-1)[source]

Initializes a DynamicTFPolicy instance.

Initialization of this class occurs in two phases and defines the static graph.

Phase 1: The model is created and model variables are initialized.

Phase 2: A fake batch of data is created, sent to the trajectory postprocessor, and then used to create placeholders for the loss function. The loss and stats functions are initialized with these placeholders.

Parameters
  • obs_space – Observation space of the policy.

  • action_space – Action space of the policy.

  • config – Policy-specific configuration data.

  • loss_fn – Function that returns a loss tensor for the policy graph.

  • stats_fn – Optional callable that, given the policy and batch input tensors, returns a dict mapping str to TF ops. These ops are fetched from the graph after loss calculations and the resulting values can be found in the results dict returned by e.g. Trainer.train() or in TensorBoard (if TB logging is enabled).

  • grad_stats_fn – Optional callable that, given the policy, batch input tensors, and calculated loss gradient tensors, returns a dict mapping str to TF ops. These ops are fetched from the graph after loss and gradient calculations and the resulting values can be found in the results dict returned by e.g. Trainer.train() or in TensorBoard (if TB logging is enabled).

  • before_loss_init – Optional function to run prior to loss init that takes the same arguments as __init__.

  • make_model – Optional function that returns a ModelV2 object given policy, obs_space, action_space, and policy config. All policy variables should be created in this function. If not specified, a default model will be created.

  • action_sampler_fn – A callable returning a sampled action and its log-likelihood given Policy, ModelV2, observation inputs, explore, and is_training. Provide action_sampler_fn if you would like to have full control over the action computation step, including the model forward pass, possible sampling from a distribution, and exploration logic. Note: If action_sampler_fn is given, action_distribution_fn must be None. If both action_sampler_fn and action_distribution_fn are None, RLlib will simply pass inputs through self.model to get distribution inputs, create the distribution object, sample from it, and apply some exploration logic to the results. The callable takes as inputs: Policy, ModelV2, obs_batch, state_batches (optional), seq_lens (optional), prev_actions_batch (optional), prev_rewards_batch (optional), explore, and is_training.

  • action_distribution_fn – A callable returning distribution inputs (parameters), a dist-class to generate an action distribution object from, and internal-state outputs (or an empty list if not applicable). Provide action_distribution_fn if you would like to only customize the model forward pass call. The resulting distribution parameters are then used by RLlib to create a distribution object, sample from it, and execute any exploration logic. Note: If action_distribution_fn is given, action_sampler_fn must be None. If both action_sampler_fn and action_distribution_fn are None, RLlib will simply pass inputs through self.model to get distribution inputs, create the distribution object, sample from it, and apply some exploration logic to the results. The callable takes as inputs: Policy, ModelV2, input_dict, explore, timestep, is_training.

  • existing_inputs – When copying a policy, this specifies an existing dict of placeholders to use instead of defining new ones.

  • existing_model – When copying a policy, this specifies an existing model to clone and share weights with.

  • get_batch_divisibility_req – Optional callable that returns the divisibility requirement for sample batches. If None, will assume a value of 1.
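
A sketch of customizing only the model forward pass via action_distribution_fn (normally passed through build_tf_policy, which constructs the DynamicTFPolicy for you); the Categorical distribution assumes a discrete action space, and my_loss_fn stands for any loss function, e.g. the one sketched further above:

    from ray.rllib.models.tf.tf_action_dist import Categorical
    from ray.rllib.policy.tf_policy_template import build_tf_policy

    def my_action_distribution_fn(policy, model, input_dict, **kwargs):
        # Custom forward pass only; RLlib still builds the distribution,
        # samples from it, and applies exploration to the result.
        dist_inputs, state_outs = model(input_dict)
        return dist_inputs, Categorical, state_outs

    MyTFPolicy = build_tf_policy(
        name="MyTFPolicy",
        loss_fn=my_loss_fn,
        action_distribution_fn=my_action_distribution_fn,
    )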

copy(existing_inputs: List[Tuple[str, tf1.placeholder]]) → ray.rllib.policy.tf_policy.TFPolicy[source]

Creates a copy of self using existing input placeholders.

get_initial_state() → List[Any][source]

Returns initial RNN state for the current policy.

Returns

Initial RNN state for the current policy.

Return type

List[TensorType]

load_batch_into_buffer(batch: ray.rllib.policy.sample_batch.SampleBatch, buffer_index: int = 0) → int[source]

Bulk-loads the given SampleBatch into the devices’ memories.

The data is split equally across all the Policy’s devices. If the data is not evenly divisible by the batch size, excess data should be discarded.

Parameters
  • batch – The SampleBatch to load.

  • buffer_index – The index of the buffer (a MultiGPUTowerStack) to use on the devices. The number of buffers on each device depends on the value of the num_multi_gpu_tower_stacks config key.

Returns

The number of tuples loaded per device.

get_num_samples_loaded_into_buffer(buffer_index: int = 0) → int[source]

Returns the number of currently loaded samples in the given buffer.

Parameters

buffer_index – The index of the buffer (a MultiGPUTowerStack) to use on the devices. The number of buffers on each device depends on the value of the num_multi_gpu_tower_stacks config key.

Returns

The number of tuples loaded per device.

learn_on_loaded_batch(offset: int = 0, buffer_index: int = 0)[source]

Runs a single step of SGD on data already loaded into a buffer.

Runs an SGD step over a slice of the pre-loaded batch, offset by the offset argument (useful for performing n minibatch SGD updates repeatedly on the same, already pre-loaded data).

Updates the model weights based on the averaged per-device gradients.

Parameters
  • offset – Offset into the preloaded data. Used for pre-loading a train-batch once to a device, then iterating over (subsampling through) this batch n times doing minibatch SGD.

  • buffer_index – The index of the buffer (a MultiGPUTowerStack) to take the already pre-loaded data from. The number of buffers on each device depends on the value of the num_multi_gpu_tower_stacks config key.

Returns

The outputs of extra_ops evaluated over the batch.
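
Together, load_batch_into_buffer() and learn_on_loaded_batch() support the pre-load-once, step-many-times pattern used for (multi-)GPU minibatch SGD. A hedged sketch (the minibatch size is illustrative; in practice the per-step slice size is governed by the policy's config):

    # Load one large train batch onto the device(s) ...
    num_loaded = policy.load_batch_into_buffer(train_batch, buffer_index=0)

    # ... then run several minibatch SGD passes over the pre-loaded data.
    minibatch_size = 128  # illustrative
    for offset in range(0, num_loaded, minibatch_size):
        results = policy.learn_on_loaded_batch(offset=offset, buffer_index=0)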