Dataset provides a simple abstraction for training with RaySGD.
Get in touch with us if you’re using or considering using RaySGD!
Setting up a dataset
A dataset can be constructed from any Python iterable, or from a ParallelIterator. Optionally, a batch size, download function, concurrency limit, and a transformation can also be specified.
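For example, a minimal dataset can be built straight from a Python list. The sketch below uses only constructor arguments that also appear in the complete example at the end of this section:

from ray.util.sgd.data.dataset import Dataset

# Any Python iterable can back a dataset.
data = [i / 1000 for i in range(1000)]

# batch_size is the number of data points used per worker per training step.
dataset = Dataset(data, batch_size=32)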
When constructing a dataset, a download function can be specified. For example, if a dataset is initialized with a set of paths, the download function can convert those paths to
(input, label) tuples. Downloads are executed in parallel, with the degree of parallelism capped by
max_concurrency. Limiting concurrency is useful if the backing datastore has rate limits, downloads carry high overhead, or downloading is computationally expensive. Downloaded data is stored as objects in the plasma store.
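As a sketch, a path-backed dataset might look like the following. Here load_example is a hypothetical stand-in for real datastore I/O, and the paths are illustrative:

import torch
from ray.util.sgd.data.dataset import Dataset

paths = ["s3://bucket/sample-0.pt", "s3://bucket/sample-1.pt"]

def load_example(path):
    # Hypothetical helper: a real pipeline would fetch `path` from the
    # datastore; here we just synthesize an (input, label) tensor pair.
    x = torch.rand(1, 1)
    return x, x

# At most 4 downloads run in parallel, which avoids tripping the
# datastore's rate limits. Results are placed in the plasma store.
dataset = Dataset(paths, download_func=load_example, max_concurrency=4)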
An additional, final transformation can be specified via
Dataset::transform. This function is guaranteed to run on the same worker that training takes place on. It is good practice to perform operations that produce large outputs, such as converting images to tensors, in the transform.
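For instance, assuming the transformation is passed through a transform keyword argument (as the name Dataset::transform suggests) and is applied per element, a sketch might defer tensor conversion to the training worker:

import torch
from ray.util.sgd.data.dataset import Dataset

def to_tensor(item):
    # Runs on the same worker that trains, so the large tensor outputs
    # are produced where they are consumed rather than shipped around.
    x, y = item
    return torch.tensor([[x]]).float(), torch.tensor([[y]]).float()

data = [(i / 100, i / 100) for i in range(100)]
dataset = Dataset(data, batch_size=32, transform=to_tensor)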
Finally, the batch size can be specified. The batch size is the number of data points used per training step per worker.
The batch size should be specified via the dataset's constructor, not the
config["batch_size"] passed into the Trainer constructor. In general, configure datasets via their own constructor rather than the Trainer config wherever possible.
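In other words, a quick sketch of where the batch size belongs:

from ray.util.sgd.data.dataset import Dataset

data = list(range(1000))

# Do: set the per-worker batch size on the dataset itself.
dataset = Dataset(data, batch_size=32)

# Don't: set config={"batch_size": 32} on the Trainer constructor and
# expect it to apply when training with this dataset.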
Using a dataset
To use a dataset, pass it as an argument to
trainer.train(). A dataset passed to
trainer.train() takes precedence over the trainer's data creator for that training run.
trainer.train(dataset=dataset, num_steps=10)  # Trains using a dataset
trainer.train()  # Trains with the original data creator
trainer.train(dataset=dataset2, num_steps=20)  # Trains using a different dataset
Complete dataset example
Below is an example of training a network with a single hidden layer to learn the identity function.
import ray
from ray.util.sgd.torch.torch_trainer import TorchTrainer
from ray.util.sgd.data.dataset import Dataset

import torch
from torch import nn
import torch.nn.functional as F


class Net(nn.Module):
    """A network with a single hidden layer."""

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(1, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x


def model_creator(config):
    return Net()


def optimizer_creator(model, config):
    return torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-4))


def to_mat(x):
    return torch.tensor([[x]]).float()


def dataset_creator():
    num_points = 32 * 100 * 2
    data = [i * (1 / num_points) for i in range(num_points)]
    # The download function turns each scalar into an (input, label)
    # pair of 1x1 tensors; inputs equal labels for the identity function.
    dataset = Dataset(
        data,
        batch_size=32,
        max_concurrency=2,
        download_func=lambda x: (to_mat(x), to_mat(x)))
    return dataset


def main():
    dataset = dataset_creator()
    trainer = TorchTrainer(
        model_creator=model_creator,
        data_creator=None,
        optimizer_creator=optimizer_creator,
        loss_creator=torch.nn.MSELoss,
        num_workers=2,
    )

    for i in range(10):
        # Train a full epoch using the data_creator
        # trainer.train()

        # Train for another epoch using the dataset
        trainer.train(dataset=dataset, num_steps=100)

    model = trainer.get_model()
    print("f(0.5)=", float(model(to_mat(0.5))))


if __name__ == "__main__":
    ray.init()
    main()