ray.train.lightning.LightningConfigBuilder

class ray.train.lightning.LightningConfigBuilder

Bases: object

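Configuration builder for LightningTrainer. Chain the methods below to describe the module, the trainer, the fit parameters, the DDP strategy, and checkpointing. Then call build() to produce the config dictionary to pass into LightningTrainer.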

Example

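import pytorch_lightning as pl
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from ray.train.lightning import LightningConfigBuilder

class LinearModule(pl.LightningModule):
    def __init__(self, input_dim, output_dim) -> None:
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, input):
        return self.linear(input)

    def training_step(self, batch, batch_idx):
        output = self.forward(batch)
        # Toy loss for illustration: the sum of the outputs, logged as "loss"
        loss = torch.sum(output)
        self.log("loss", loss)
        return loss

    def predict_step(self, batch, batch_idx):
        return self.forward(batch)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

# Hypothetical data module for illustration: 64 random samples with 32 features,
# matching input_dim below. Replace it with your own LightningDataModule.
class RandomDataModule(pl.LightningDataModule):
    def train_dataloader(self):
        return DataLoader(torch.randn(64, 32), batch_size=8)

datamodule = RandomDataModule()

lightning_config = (
    LightningConfigBuilder()
    # Pass the class itself; the remaining kwargs go to its constructor
    .module(
        cls=LinearModule,
        input_dim=32,
        output_dim=4,
    )
    # accelerator="gpu" requires GPUs; use accelerator="cpu" otherwise
    .trainer(max_epochs=5, accelerator="gpu")
    .fit_params(datamodule=datamodule)
    # Keep the 2 checkpoints with the lowest logged "loss"
    .checkpointing(monitor="loss", save_top_k=2, mode="min")
    .build()
)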

PublicAPI (alpha): This API is in alpha and may change before becoming stable.

module(cls: Type[pytorch_lightning.core.lightning.LightningModule], **kwargs) → ray.train.lightning.lightning_trainer.LightningConfigBuilder

Set up the PyTorch Lightning module class.

Parameters
  • cls – A subclass of pytorch_lightning.LightningModule that defines your model and training logic. Pass the class itself, not an instance of it.

  • **kwargs – Keyword arguments passed to the constructor of cls when the module is instantiated.
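Passing the class rather than an instance lets construction be deferred until the module is actually needed. Presumably LightningTrainer builds one instance inside each training worker, though this page does not say so. The minimal sketch below shows the deferred-construction pattern only; `StubModule` and `make_module` are hypothetical stand-ins, not part of Ray or Lightning.

```python
from typing import Any, Dict, Type


class StubModule:
    """Hypothetical stand-in for a LightningModule subclass."""

    def __init__(self, input_dim: int, output_dim: int) -> None:
        self.input_dim = input_dim
        self.output_dim = output_dim


def make_module(cls: Type[Any], init_kwargs: Dict[str, Any]) -> Any:
    # Construct a fresh instance from the stored class and kwargs.
    return cls(**init_kwargs)


# Roughly what .module(cls=StubModule, input_dim=32, output_dim=4) records
# (an assumption for illustration).
module_cls = StubModule
module_kwargs = {"input_dim": 32, "output_dim": 4}

# Each call yields an independent object, e.g. one per worker process.
worker_a = make_module(module_cls, module_kwargs)
worker_b = make_module(module_cls, module_kwargs)
print(worker_a is worker_b, worker_a.output_dim)  # False 4
```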

trainer(**kwargs) → ray.train.lightning.lightning_trainer.LightningConfigBuilder

Set up the configuration of pytorch_lightning.Trainer.

Parameters

kwargs – The initialization arguments for pytorch_lightning.Trainer. For valid arguments, refer to: https://lightning.ai/docs/pytorch/stable/common/trainer.html#init.
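Because these kwargs are forwarded to pytorch_lightning.Trainer, a misspelled argument name may only surface when the Trainer is eventually constructed. Whether the builder validates them earlier is not stated here, so treat that as an assumption. One way to catch such typos up front is to bind the kwargs against the constructor signature with `inspect.signature`. In the sketch below `StubTrainer` and `check_init_kwargs` are hypothetical; it mirrors only a few pl.Trainer parameter names.

```python
import inspect
from typing import Any, Dict, Optional


class StubTrainer:
    """Hypothetical stand-in mirroring a few pl.Trainer parameter names."""

    def __init__(self, max_epochs: Optional[int] = None,
                 accelerator: str = "auto", devices: Any = "auto") -> None:
        self.max_epochs = max_epochs
        self.accelerator = accelerator
        self.devices = devices


def check_init_kwargs(cls: type, kwargs: Dict[str, Any]) -> None:
    # bind() raises TypeError for unknown keyword arguments
    # without actually constructing the object.
    inspect.signature(cls).bind(**kwargs)


check_init_kwargs(StubTrainer, {"max_epochs": 5, "accelerator": "gpu"})  # passes

try:
    check_init_kwargs(StubTrainer, {"max_epoch": 5})  # typo: max_epoch
except TypeError as err:
    print("rejected:", err)
```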

fit_params(**kwargs) → ray.train.lightning.lightning_trainer.LightningConfigBuilder

Set up the parameters passed to pytorch_lightning.Trainer.fit().

Parameters

kwargs – Keyword arguments for pytorch_lightning.Trainer.fit(), such as datamodule. For valid arguments, refer to: https://lightning.ai/docs/pytorch/stable/common/trainer.html#fit.
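The split between trainer() and fit_params() mirrors Lightning's own split between Trainer constructor arguments and fit() arguments. The hedged sketch below shows how recorded fit kwargs would presumably be unpacked into the fit() call. `StubTrainer` is hypothetical; its fit() reuses a few parameter names of pytorch_lightning.Trainer.fit, and a string stands in for a real data module.

```python
from typing import Any, Dict, List, Optional


class StubTrainer:
    """Hypothetical stand-in that records what fit() receives."""

    def __init__(self) -> None:
        self.calls: List[Dict[str, Any]] = []

    def fit(self, model: Any, train_dataloaders: Any = None,
            val_dataloaders: Any = None, datamodule: Any = None,
            ckpt_path: Optional[str] = None) -> None:
        self.calls.append({"model": model, "datamodule": datamodule,
                           "train_dataloaders": train_dataloaders})


# Roughly what .fit_params(datamodule=...) records (an assumption);
# the string is a placeholder for a LightningDataModule.
fit_kwargs = {"datamodule": "my_datamodule"}

trainer = StubTrainer()
model = object()
# Init-time settings (.trainer) are used when the trainer is built;
# fit-time settings (.fit_params) are only unpacked when fit() is called.
trainer.fit(model, **fit_kwargs)
print(trainer.calls[0]["datamodule"])  # my_datamodule
```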

ddp_strategy(**kwargs) → ray.train.lightning.lightning_trainer.LightningConfigBuilder

Set up the configuration of pytorch_lightning.strategies.DDPStrategy.

Parameters

kwargs – The initialization arguments for pytorch_lightning.strategies.DDPStrategy. For valid arguments, refer to: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.strategies.DDPStrategy.html
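In plain Lightning, a configured strategy object is handed to the Trainer through its `strategy` argument, e.g. `Trainer(strategy=DDPStrategy(find_unused_parameters=True))`. Presumably LightningTrainer builds a comparable strategy object from these kwargs, but that is an assumption about its internals. The sketch below shows only the composition; `StubDDPStrategy` and `StubTrainer` are hypothetical stand-ins.

```python
from typing import Any, Dict


class StubDDPStrategy:
    """Hypothetical stand-in for pytorch_lightning.strategies.DDPStrategy."""

    def __init__(self, find_unused_parameters: bool = False, **kwargs: Any) -> None:
        self.find_unused_parameters = find_unused_parameters
        self.extra = kwargs


class StubTrainer:
    """Hypothetical stand-in accepting a strategy object."""

    def __init__(self, strategy: Any = None, **kwargs: Any) -> None:
        self.strategy = strategy
        self.options = kwargs


# Roughly what .ddp_strategy(find_unused_parameters=True) and
# .trainer(max_epochs=5) might record (an assumption).
ddp_kwargs: Dict[str, Any] = {"find_unused_parameters": True}
trainer_kwargs: Dict[str, Any] = {"max_epochs": 5}

# The strategy kwargs build a strategy object; the trainer kwargs go to the Trainer.
trainer = StubTrainer(strategy=StubDDPStrategy(**ddp_kwargs), **trainer_kwargs)
print(trainer.strategy.find_unused_parameters)  # True
```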

checkpointing(**kwargs) → ray.train.lightning.lightning_trainer.LightningConfigBuilder

Set up the configuration of pytorch_lightning.callbacks.ModelCheckpoint.

LightningTrainer creates a ModelCheckpoint callback from this configuration, and AIR checkpointing and logging are triggered from within that callback.

Parameters

kwargs – The initialization arguments for pytorch_lightning.callbacks.ModelCheckpoint. For valid arguments, refer to: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
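With monitor="loss", save_top_k=2, and mode="min", as in the class example, ModelCheckpoint keeps only the two checkpoints with the lowest logged loss. The retention rule can be sketched with a bounded heap. This illustrates the top-k rule only, not Lightning's implementation. The loss values are made up, ties are ignored, and the special save_top_k values 0 and -1 are not handled.

```python
import heapq
from typing import List, Tuple


def retained_checkpoints(losses: List[float], save_top_k: int = 2) -> List[Tuple[float, int]]:
    """Return (loss, epoch) pairs kept under mode="min", best first."""
    if save_top_k < 1:
        raise ValueError("this sketch only handles save_top_k >= 1")
    # Store negated losses so the heap root is the worst retained checkpoint,
    # which is the one to evict when a better checkpoint arrives.
    heap: List[Tuple[float, int]] = []
    for epoch, loss in enumerate(losses):
        if len(heap) < save_top_k:
            heapq.heappush(heap, (-loss, epoch))
        elif -heap[0][0] > loss:
            # New loss beats the worst kept one: replace it.
            heapq.heapreplace(heap, (-loss, epoch))
    return sorted((-neg_loss, epoch) for neg_loss, epoch in heap)


# Hypothetical per-epoch values of the monitored "loss" metric.
print(retained_checkpoints([0.9, 0.5, 0.7, 0.4, 0.6]))  # [(0.4, 3), (0.5, 1)]
```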

build() → Dict[str, Any]

Build and return a config dictionary to pass into LightningTrainer.
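build() is the terminal step of a fluent builder. Each setter records its kwargs and returns the builder, so calls can be chained, and build() snapshots everything into a plain dictionary. A toy, stdlib-only version of that pattern is sketched below. `ToyConfigBuilder`, `StubModule`, the dict-merge behavior of repeated calls, and the dictionary keys are all hypothetical. They are not meant to reproduce the dictionary LightningConfigBuilder actually builds.

```python
from typing import Any, Dict, Type


class StubModule:
    """Hypothetical stand-in for a LightningModule subclass."""


class ToyConfigBuilder:
    """Hypothetical, simplified illustration of the fluent builder pattern."""

    def __init__(self) -> None:
        self._module_cls: Any = None
        self._module_kwargs: Dict[str, Any] = {}
        self._trainer_kwargs: Dict[str, Any] = {}
        self._fit_kwargs: Dict[str, Any] = {}

    def module(self, cls: Type[Any], **kwargs: Any) -> "ToyConfigBuilder":
        self._module_cls = cls
        self._module_kwargs.update(kwargs)
        return self  # returning self is what enables method chaining

    def trainer(self, **kwargs: Any) -> "ToyConfigBuilder":
        self._trainer_kwargs.update(kwargs)
        return self

    def fit_params(self, **kwargs: Any) -> "ToyConfigBuilder":
        self._fit_kwargs.update(kwargs)
        return self

    def build(self) -> Dict[str, Any]:
        # Copy the nested dicts so later setter calls do not mutate
        # a config that has already been built.
        return {
            "module_cls": self._module_cls,
            "module_kwargs": dict(self._module_kwargs),
            "trainer_kwargs": dict(self._trainer_kwargs),
            "fit_kwargs": dict(self._fit_kwargs),
        }


config = (
    ToyConfigBuilder()
    .module(cls=StubModule, input_dim=32, output_dim=4)
    .trainer(max_epochs=5)
    .fit_params(datamodule=None)
    .build()
)
print(sorted(config))  # ['fit_kwargs', 'module_cls', 'module_kwargs', 'trainer_kwargs']
```

Copying in build() keeps each built config independent of the builder, so one builder can be tweaked and built again without affecting earlier results.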