ray.train.lightning.LightningConfigBuilder
- class ray.train.lightning.LightningConfigBuilder
Bases: object
Builds the configuration to pass into LightningTrainer.
Example
```python
import pytorch_lightning as pl
import torch
import torch.nn as nn
from ray.train.lightning import LightningConfigBuilder


# A minimal LightningModule: one linear layer trained with SGD.
class LinearModule(pl.LightningModule):
    def __init__(self, input_dim, output_dim) -> None:
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, input):
        return self.linear(input)

    def training_step(self, batch, batch_idx):
        output = self.forward(batch)
        loss = torch.sum(output)
        self.log("loss", loss)  # "loss" is the metric monitored by checkpointing below
        return loss

    def predict_step(self, batch, batch_idx):
        return self.forward(batch)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# `datamodule` is assumed to be a pl.LightningDataModule instance defined elsewhere.
lightning_config = (
    LightningConfigBuilder()
    .module(
        cls=LinearModule,
        input_dim=32,
        output_dim=4,
    )
    .trainer(max_epochs=5, accelerator="gpu")
    .fit_params(datamodule=datamodule)
    .checkpointing(monitor="loss", save_top_k=2, mode="min")
    .build()
)
```
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
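The fluent style in the example above works because each setter returns the builder itself, and `build()` gathers everything into one configuration object. The sketch below is a hypothetical, stripped-down `SketchConfigBuilder` that only illustrates this chaining pattern; it is not Ray's implementation, and the dictionary keys it produces are assumptions chosen for readability.

```python
from typing import Any, Dict


class SketchConfigBuilder:
    """Hypothetical stand-in that mimics the chaining style of LightningConfigBuilder."""

    def __init__(self) -> None:
        # Each section keeps the keyword arguments passed to its setter.
        self._config: Dict[str, Any] = {
            "module_class": None,
            "module_init_config": {},
            "trainer_init_config": {},
            "trainer_fit_params": {},
        }

    def module(self, cls: type, **kwargs: Any) -> "SketchConfigBuilder":
        self._config["module_class"] = cls
        self._config["module_init_config"].update(kwargs)
        return self  # returning self is what enables chaining

    def trainer(self, **kwargs: Any) -> "SketchConfigBuilder":
        self._config["trainer_init_config"].update(kwargs)
        return self

    def fit_params(self, **kwargs: Any) -> "SketchConfigBuilder":
        self._config["trainer_fit_params"].update(kwargs)
        return self

    def build(self) -> Dict[str, Any]:
        # Return copies so later setter calls do not mutate an already built config.
        return {k: (dict(v) if isinstance(v, dict) else v) for k, v in self._config.items()}


class DummyModule:
    """Plain placeholder for a LightningModule subclass (hypothetical)."""

    def __init__(self, input_dim: int, output_dim: int) -> None:
        self.input_dim = input_dim
        self.output_dim = output_dim


config = (
    SketchConfigBuilder()
    .module(cls=DummyModule, input_dim=32, output_dim=4)
    .trainer(max_epochs=5)
    .fit_params()
    .build()
)
print(config["module_init_config"])  # {'input_dim': 32, 'output_dim': 4}
```

Returning a copy from `build()` is a design choice of this sketch: it keeps a built configuration stable even if the builder is reused afterwards.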
- module(cls: Type[pytorch_lightning.core.lightning.LightningModule], **kwargs) → ray.train.lightning.lightning_trainer.LightningConfigBuilder
Set up the PyTorch Lightning module class.
- Parameters
cls – A subclass of pytorch_lightning.LightningModule that defines your model and training logic. Note that this is a class, not a class instance.
**kwargs – The initialization arguments of your LightningModule.
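One plausible reason for passing the class rather than an instance is that each training worker can then construct its own, independent copy of the model from the stored arguments. The sketch below shows that deferred construction under stated assumptions: `LinearModuleStub` is a plain stand-in for a real `LightningModule`, and `build_module_on_worker` is a hypothetical helper, not part of Ray.

```python
from typing import Any, Dict


class LinearModuleStub:
    """Plain stand-in for a LightningModule subclass (hypothetical)."""

    def __init__(self, input_dim: int, output_dim: int) -> None:
        self.input_dim = input_dim
        self.output_dim = output_dim


# What module(cls=..., **kwargs) records: a class plus its init arguments.
module_class = LinearModuleStub
module_init_kwargs: Dict[str, Any] = {"input_dim": 32, "output_dim": 4}


def build_module_on_worker(cls: type, init_kwargs: Dict[str, Any]) -> Any:
    # Hypothetical helper: calling the class on each worker
    # gives every worker a fresh, independent model object.
    return cls(**init_kwargs)


worker_0_model = build_module_on_worker(module_class, module_init_kwargs)
worker_1_model = build_module_on_worker(module_class, module_init_kwargs)
print(worker_0_model is worker_1_model)  # False
```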
- trainer(**kwargs) → ray.train.lightning.lightning_trainer.LightningConfigBuilder
Set up the configuration of pytorch_lightning.Trainer.
- Parameters
kwargs – The initialization arguments for pytorch_lightning.Trainer. For valid arguments, refer to: https://lightning.ai/docs/pytorch/stable/common/trainer.html#init.
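Because `trainer()` accepts arbitrary `**kwargs`, a misspelled argument may only surface once the `Trainer` is actually constructed. One way to catch such mistakes early is to bind the keyword arguments against the constructor's signature with Python's `inspect` module. The sketch uses a hypothetical `FakeTrainer` that exposes a few parameter names borrowed from `pytorch_lightning.Trainer`; the real constructor has many more, and they vary between Lightning versions.

```python
import inspect
from typing import Any, Dict, Optional


class FakeTrainer:
    """Hypothetical stand-in exposing a few pytorch_lightning.Trainer parameter names."""

    def __init__(
        self,
        accelerator: str = "auto",
        max_epochs: Optional[int] = None,
        log_every_n_steps: int = 50,
    ) -> None:
        self.accelerator = accelerator
        self.max_epochs = max_epochs
        self.log_every_n_steps = log_every_n_steps


def check_init_kwargs(cls: type, kwargs: Dict[str, Any]) -> Optional[str]:
    """Return None if cls(**kwargs) would bind cleanly, else the binding error message."""
    try:
        # signature(cls) describes __init__ without `self`.
        inspect.signature(cls).bind(**kwargs)
    except TypeError as err:
        return str(err)
    return None


print(check_init_kwargs(FakeTrainer, {"max_epochs": 5, "accelerator": "gpu"}))  # None
print(check_init_kwargs(FakeTrainer, {"max_epoch": 5}))  # reports the unexpected keyword
```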
- fit_params(**kwargs) → ray.train.lightning.lightning_trainer.LightningConfigBuilder
Set up the parameters to pass to pytorch_lightning.Trainer.fit().
- Parameters
kwargs – The keyword arguments for pytorch_lightning.Trainer.fit(). For valid arguments, refer to: https://lightning.ai/docs/pytorch/stable/common/trainer.html#fit.
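Conceptually, whatever is passed to `fit_params()` ends up as keyword arguments to `Trainer.fit()` alongside the model, which is how the `datamodule=datamodule` call in the class example reaches Lightning. The sketch below shows only that forwarding, using a hypothetical `RecordingTrainer` that records its `fit()` call; it does not reproduce how `LightningTrainer` wires things internally.

```python
from typing import Any, Dict, List, Tuple


class RecordingTrainer:
    """Hypothetical trainer that records how fit() was called."""

    def __init__(self) -> None:
        self.calls: List[Tuple[Any, Dict[str, Any]]] = []

    def fit(self, model: Any, **kwargs: Any) -> None:
        self.calls.append((model, kwargs))


# What fit_params(...) might store; a string stands in for a real LightningDataModule.
fit_params: Dict[str, Any] = {"datamodule": "my_datamodule", "ckpt_path": None}

trainer = RecordingTrainer()
model = object()
trainer.fit(model, **fit_params)  # stored fit params are expanded into keyword arguments

print(trainer.calls[0][1])  # {'datamodule': 'my_datamodule', 'ckpt_path': None}
```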
- ddp_strategy(**kwargs) → ray.train.lightning.lightning_trainer.LightningConfigBuilder
Set up the configuration of pytorch_lightning.strategies.DDPStrategy.
- Parameters
kwargs – The initialization arguments for DDPStrategy. For valid arguments, refer to: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.strategies.DDPStrategy.html
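The keyword arguments given to `ddp_strategy()` target the `DDPStrategy` constructor, for example `find_unused_parameters=True`, which `DDPStrategy` passes on to PyTorch's `DistributedDataParallel`. A strategy object built from them would then fill the `strategy` slot of the `Trainer` arguments. The sketch models that merge with hypothetical stand-ins (`FakeDDPStrategy`, `merge_strategy`); it illustrates the general idea under that assumption and is not Ray's actual code.

```python
from typing import Any, Dict


class FakeDDPStrategy:
    """Hypothetical stand-in for pytorch_lightning.strategies.DDPStrategy."""

    def __init__(self, **ddp_kwargs: Any) -> None:
        # The real DDPStrategy forwards extra kwargs to DistributedDataParallel.
        self.ddp_kwargs = ddp_kwargs


def merge_strategy(trainer_kwargs: Dict[str, Any], ddp_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical helper: build the strategy from the ddp_strategy() kwargs and
    # place it under the Trainer's "strategy" argument without mutating the input.
    merged = dict(trainer_kwargs)
    merged["strategy"] = FakeDDPStrategy(**ddp_kwargs)
    return merged


trainer_kwargs = {"max_epochs": 5, "accelerator": "gpu"}
ddp_kwargs = {"find_unused_parameters": True}
merged = merge_strategy(trainer_kwargs, ddp_kwargs)
print(sorted(merged))  # ['accelerator', 'max_epochs', 'strategy']
```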
- checkpointing(**kwargs) → ray.train.lightning.lightning_trainer.LightningConfigBuilder
Set up the configuration of pytorch_lightning.callbacks.ModelCheckpoint. LightningTrainer creates a ModelCheckpoint callback based on this config; the AIR checkpointing and logging methods are triggered in that callback.
- Parameters
kwargs – For valid arguments, refer to: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
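In the class example, `checkpointing(monitor="loss", save_top_k=2, mode="min")` asks `ModelCheckpoint` to keep only the two checkpoints with the lowest logged `loss`. The sketch below reproduces just that retention rule in plain Python with a hypothetical `TopKTracker`; it ignores file handling, ties, special values such as `save_top_k=-1`, and every other `ModelCheckpoint` option.

```python
import heapq
from typing import List, Tuple


class TopKTracker:
    """Hypothetical sketch of save_top_k retention with mode="min"."""

    def __init__(self, save_top_k: int) -> None:
        if save_top_k < 1:
            raise ValueError("this sketch only supports save_top_k >= 1")
        self.save_top_k = save_top_k
        # Scores are negated so the heap root is the worst (highest-loss) kept checkpoint.
        self._heap: List[Tuple[float, str]] = []

    def update(self, name: str, loss: float) -> None:
        if len(self._heap) < self.save_top_k:
            heapq.heappush(self._heap, (-loss, name))
        elif loss < -self._heap[0][0]:
            # The new checkpoint beats the worst kept one, so evict the worst.
            heapq.heapreplace(self._heap, (-loss, name))

    def kept(self) -> List[str]:
        # Best (lowest loss) first.
        return [name for _, name in sorted(self._heap, reverse=True)]


tracker = TopKTracker(save_top_k=2)
for epoch, loss in enumerate([0.9, 0.5, 0.7, 0.3, 0.6]):
    tracker.update(f"epoch={epoch}", loss)
print(tracker.kept())  # ['epoch=3', 'epoch=1']
```

A bounded heap keeps each update at O(log k), which is one reasonable way to express "keep the k best" without re-sorting every saved checkpoint.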