{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "TsniIjjg2Pym" }, "source": [ "*This example is adapted from the Continual AI Avalanche quick start: https://avalanche.continualai.org/*" ] }, { "cell_type": "markdown", "metadata": { "id": "1VsUrzVm1W-h" }, "source": [ "# Incremental Learning with Ray AIR\n", "\n", "In this example, we show how to use Ray AIR to incrementally train a simple image classification PyTorch model\n", "on a stream of incoming tasks.\n", "\n", "Each task is a random permutation of the MNIST Dataset, which is a common benchmark\n", "used for continual learning. After training on all the\n", "tasks, the model is expected to be able to make predictions on data from any task.\n", "\n", "In this example, we use a naive fine-tuning strategy, where the model is trained\n", "on each task, without any special methods to prevent [catastrophic forgetting](\n", "https://en.wikipedia.org/wiki/Catastrophic_interference). Model performance is\n", "expected to be poor.\n", "\n", "More precisely, this example showcases domain incremental training, in which, during\n", "prediction/testing time, the model is asked to predict on data from tasks trained on so far with the\n", "task ID not provided. This is in contrast to task incremental training, where the task ID is\n", "provided during prediction/testing time.\n", "\n", "For more information on the three different categories of incremental/continual\n", "learning, please see [\"Three scenarios for continual learning\" by van de Ven and Tolias](https://arxiv.org/pdf/1904.07734.pdf)." ] }, { "cell_type": "markdown", "metadata": { "id": "Q3oGiuqYfj9_" }, "source": [ "This example will cover the following:\n", "1. Loading a PyTorch Dataset into a Ray Dataset.\n", "2. Creating an `Iterator[ray.data.Dataset]` abstraction to represent a stream of data to train on for incremental training.\n", "3. Implementing a custom Ray AIR preprocessor to preprocess the Dataset.\n", "4. Incrementally training a model using data parallel training.\n", "5. Using our trained model to perform batch prediction on test data.\n", "6. Incrementally deploying our trained model with Ray Serve and performing online prediction queries." ] }, { "cell_type": "markdown", "metadata": { "id": "z52Y8O4q1bIk" }, "source": [ "# Step 1: Installations and Initializing Ray\n", "\n", "To get started, let's first install the necessary packages: Ray AIR, torch, and torchvision. Uncomment the lines below and run the cell to install them." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "kWr6BRMk1Y1j", "outputId": "dad49a31-a602-4e44-b5fe-932de603925e" }, "outputs": [], "source": [ "# !pip install -q \"ray[air]\"\n", "# !pip install -q torch\n", "# !pip install -q torchvision" ] }, { "cell_type": "markdown", "metadata": { "id": "RpD4STX3g1dq" }, "source": [ "Then, let's initialize Ray! We can just import and call `ray.init()`. If you are running on a Ray cluster, you can call `ray.init(\"auto\")` to connect to the cluster instead of initializing a new local Ray instance." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "72fEFqL4T7iA", "outputId": "9cae25f2-c712-4baa-f66b-337049e1b565" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-09-23 16:31:18,554\tINFO worker.py:1509 -- Started a local Ray instance. 
View the dashboard at \u001b[1m\u001b[32m127.0.0.1:8265 \u001b[39m\u001b[22m\n" ] }, { "data": { "text/html": [ "
\n" ], "text/plain": [ "RayContext(dashboard_url='127.0.0.1:8265', python_version='3.10.6', ray_version='3.0.0.dev0', ray_commit='{{RAY_COMMIT_SHA}}', address_info={'node_ip_address': '10.109.175.190', 'raylet_ip_address': '10.109.175.190', 'redis_address': None, 'object_store_address': '/tmp/ray/session_2022-09-23_16-31-16_736743_855752/sockets/plasma_store', 'raylet_socket_name': '/tmp/ray/session_2022-09-23_16-31-16_736743_855752/sockets/raylet', 'webui_url': '127.0.0.1:8265', 'session_dir': '/tmp/ray/session_2022-09-23_16-31-16_736743_855752', 'metrics_export_port': 59668, 'gcs_address': '10.109.175.190:64318', 'address': '10.109.175.190:64318', 'dashboard_agent_listen_port': 52365, 'node_id': '610d4158d56aeda61abd25d5751611d23ba1aa97eddb34d2ee4e6020'})" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import ray\n", "ray.init()\n", "# If runnning on a cluster, use the below line instead.\n", "# ray.init(\"auto\")" ] }, { "cell_type": "markdown", "metadata": { "id": "AedcxD_FClQL" }, "source": [ "# Step 2: Define our PyTorch Model\n", "\n", "Now that we have the necessary installations, let's define our PyTorch model. For this example to classify MNIST images, we will use a simple multi-layer perceptron." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "3TVkSmFFCHhI" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/pdmurray/.pyenv/versions/mambaforge/envs/ray/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import torch.nn as nn\n", "\n", "class SimpleMLP(nn.Module):\n", " def __init__(self, num_classes=10, input_size=28 * 28):\n", " super(SimpleMLP, self).__init__()\n", "\n", " self.features = nn.Sequential(\n", " nn.Linear(input_size, 512),\n", " nn.ReLU(inplace=True),\n", " nn.Dropout(),\n", " )\n", " self.classifier = nn.Linear(512, num_classes)\n", " self._input_size = input_size\n", "\n", " def forward(self, x):\n", " x = x.contiguous()\n", " x = x.view(-1, self._input_size)\n", " x = self.features(x)\n", " x = self.classifier(x)\n", " return x" ] }, { "cell_type": "markdown", "metadata": { "id": "L2N1U22VC_N9" }, "source": [ "# Step 3: Create the Stream of tasks\n", "\n", "We can now create a stream of tasks (where each task contains a dataset to train on). For this example, we will create an artificial stream of tasks consisting of\n", "permuted variations of MNIST, which is a classic benchmark in continual learning\n", "research.\n", "\n", "For real-world scenarios, this step is not necessary as fresh data will already be\n", "arriving as a stream of tasks. It does not need to be artificially created." ] }, { "cell_type": "markdown", "metadata": { "id": "3SVSrkqrDJuc" }, "source": [ "## 3a: Load MNIST Dataset to a Ray Dataset\n", "\n", "Let's first define a simple function that will return the original MNIST Dataset as a distributed Ray Dataset. Ray Datasets are the standard way to load and exchange data in Ray libraries and applications, read more about them [here](https://docs.ray.io/en/latest/data/dataset.html)!\n", "\n", "The function in the below code snippet does the following:\n", "1. Downloads the MNIST Dataset from torchvision in-memory\n", "2. Loads the in-memory Torch Dataset into a Ray Dataset\n", "3. Converts the Ray Dataset into Numpy format. 
Instead of the Ray Dataset iterating over tuples, it will have two columns: \"image\" and \"label\". \n", "This will allow us to apply built-in preprocessors to the Ray Dataset and allow Ray Datasets to be used with Ray AIR Predictors.\n", "\n", "For this example, since we are just working with the MNIST dataset, which is small, we use the {py:func}`~ray.data.from_torch` function, which simply loads the full MNIST dataset into memory.\n", "\n", "For loading larger datasets in a parallel fashion, you should use [Ray Dataset's additional read APIs](https://docs.ray.io/en/master/data/dataset.html#supported-input-formats) to load data from parquet, csv, image files, and more!" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "0XKwJKrNCxg4" }, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "import torchvision\n", "from torchvision.transforms import RandomCrop\n", "\n", "import ray\n", "\n", "\n", "def get_mnist_dataset(train: bool = True) -> ray.data.Dataset:\n", " \"\"\"Returns MNIST Dataset as a ray.data.Dataset.\n", " \n", " Args:\n", " train: Whether to return the train dataset or test dataset.\n", " \"\"\"\n", " if train:\n", " # Only perform random cropping on the train dataset.\n", " transform = RandomCrop(28, padding=4)\n", " else:\n", " transform = None\n", " \n", " mnist_dataset = torchvision.datasets.MNIST(\"./data\", download=True, train=train, transform=transform)\n", " mnist_dataset = ray.data.from_torch(mnist_dataset)\n", " \n", " def convert_batch_to_numpy(batch):\n", " images = np.array([np.array(item[0]) for item in batch])\n", " labels = np.array([item[1] for item in batch])\n", "\n", " return {\"image\": images, \"label\": labels}\n", "\n", " mnist_dataset = mnist_dataset.map_batches(convert_batch_to_numpy)\n", " return mnist_dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "vqrfgfl9YnVe" }, "source": [ "## 3b: Create our Stream abstraction\n", "\n", "Now we can create our \"stream\" abstraction. This abstraction provides two\n", "methods (`generate_train_stream` and `generate_test_stream`) that each return an Iterator\n", "over Ray Datasets. Each item in this iterator contains a unique permutation of\n", "MNIST, and is one task that we want to train on.\n", "\n", "In this example, \"the stream of tasks\" is contrived since all the data for all tasks already exists in an offline setting. For true online continual learning, you would want to implement a custom dataset iterator that reads from some stream datasource to produce new tasks. The only abstraction that's needed is `Iterator[ray.data.Dataset]` (a minimal sketch is shown below).\n", "\n", "Note that the test dataset stream uses the same permutations as the training dataset stream. In general for continual learning, it is expected that the data distribution of the test/prediction data follows what the model was trained on. If you notice that the distribution of new prediction queries is drifting away from the distribution of the training data, then you should probably trigger training of a new task."
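, "\n", "As a concrete sketch, a custom stream for true online continual learning might look like the following. The storage paths and the `generate_stream_from_storage` helper here are hypothetical and are not used in the rest of this example; the only contract that matters is that the generator yields `ray.data.Dataset` objects:\n", "\n", "```python\n", "from typing import Iterator, List\n", "\n", "import ray\n", "\n", "\n", "def generate_stream_from_storage(task_paths: List[str]) -> Iterator[ray.data.Dataset]:\n", "    \"\"\"Yield one Ray Dataset per incoming task, read from external storage.\"\"\"\n", "    for path in task_paths:\n", "        # Each newly arrived batch of data becomes one task in the stream.\n", "        yield ray.data.read_parquet(path)\n", "\n", "\n", "# Hypothetical locations where new task data lands as it arrives:\n", "# train_stream = generate_stream_from_storage(\n", "#     [\"s3://my-bucket/task-0/\", \"s3://my-bucket/task-1/\"]\n", "# )\n", "```"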
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "f2EagMWCN3he" }, "outputs": [], "source": [ "from typing import Dict, Iterator, List\n", "import random\n", "import numpy as np\n", "\n", "from ray.data import ActorPoolStrategy\n", "\n", "\n", "class PermutedMNISTStream:\n", " \"\"\"Generates streams of permuted MNIST Datasets.\n", " \n", " Example:\n", " \n", " permuted_mnist = PermutedMNISTStream(n_tasks=3)\n", " train_stream = permuted_mnist.generate_train_stream()\n", " \n", " # Iterate through the train_stream\n", " for train_dataset in train_stream:\n", " ...\n", " \n", " Args:\n", " n_tasks: The number of tasks to generate.\n", " \"\"\"\n", "\n", " def __init__(self, n_tasks: int = 3):\n", " self.n_tasks = n_tasks\n", " self.permutations = [\n", " np.random.permutation(28 * 28) for _ in range(self.n_tasks)\n", " ]\n", "\n", " self.train_mnist_dataset = get_mnist_dataset(train=True)\n", " self.test_mnist_dataset = get_mnist_dataset(train=False)\n", "\n", " def random_permute_dataset(\n", " self, dataset: ray.data.Dataset, permutation: np.ndarray\n", " ):\n", " \"\"\"Randomly permutes the pixels for each image in the dataset.\"\"\"\n", "\n", " class PixelsPermutation(object):\n", " def __call__(self, batch):\n", " batch[\"image\"] = batch[\"image\"].map(lambda image: image.reshape(-1)[permutation].reshape(28, 28))\n", " return batch\n", "\n", " return dataset.map_batches(PixelsPermutation, compute=ActorPoolStrategy(), batch_format=\"pandas\")\n", "\n", " def generate_train_stream(self) -> Iterator[ray.data.Dataset]:\n", " for permutation in self.permutations:\n", " permuted_mnist_dataset = self.random_permute_dataset(\n", " self.train_mnist_dataset, permutation\n", " )\n", " yield permuted_mnist_dataset\n", "\n", " def generate_test_stream(self) -> Iterator[ray.data.Dataset]:\n", " for permutation in self.permutations:\n", " permuted_mnist_dataset = self.random_permute_dataset(\n", " self.test_mnist_dataset, permutation\n", " )\n", " yield permuted_mnist_dataset\n", "\n", " def generate_test_samples(self, num_samples: int = 10) -> List[np.ndarray]:\n", " \"\"\"Generates num_samples permuted MNIST images.\"\"\"\n", " random_permutation = random.choice(self.permutations)\n", " return list(self.random_permute_dataset(\n", " self.test_mnist_dataset.random_shuffle().limit(num_samples),\n", " random_permutation,\n", " ).to_pandas()[\"image\"].to_numpy())\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "HDGHgtb699kd" }, "source": [ "# Step 4: Define the logic for Training and Inference/Prediction\n", "\n", "Now that we can get an Iterator over Ray Datasets, we can incrementally train our model in a data parallel fashion via Ray Train, while incrementally deploying our model via Ray Serve. Let's define some helper functions to allow us to do this!\n", "\n", "If you are not familiar with data parallel training: it is a distributed training strategy in which we keep multiple model replicas, and each replica trains on a different batch of data. After each batch, the gradients are synchronized across the replicas. This effectively allows us to train on more data in a shorter amount of time." ] }, { "cell_type": "markdown", "metadata": { "id": "SBWxP1sP-G-o" }, "source": [ "## 4a: Define our training logic for each Data Parallel worker\n", "\n", "The first thing we need to do is to define the training loop that will be run on each training worker. 
\n", "\n", "The training loop takes in a `config` Dict as an argument that we can use to pass in any configurations for training.\n", "\n", "This is just standard PyTorch training, with the difference being that we can leverage [Ray Train's utility functions](https://docs.ray.io/en/master/train/api.html#training-function-utilities) and [Ray AIR Sesssion](https://docs.ray.io/en/master/ray-air/package-ref.html#module-ray.air.session):\n", "- `ray.train.torch.prepare_model(...)`: This will prepare the model for distributed training by wrapping it in either PyTorch `DistributedDataParallel` or `FullyShardedDataParallel` and moving it to the correct accelerator device.\n", "- `ray.air.session.get_dataset_shard(...)`: This will get the Ray Dataset shard for this particular Data Parallel worker.\n", "- `ray.air.session.report({}, checkpoint=...)`: This will tell Ray Train to persist the provided `Checkpoint` object.\n", "- `ray.air.session.get_checkpoint()`: Returns a checkpoint to resume from. This is useful for either fault tolerance purposes, or for our purposes, to continue training the same model on a new incoming dataset." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "Y9IRDMec-GZ9" }, "outputs": [], "source": [ "from ray import train\n", "from ray.air import session, Checkpoint\n", "\n", "from torch.optim import SGD\n", "from torch.nn import CrossEntropyLoss\n", "\n", "def train_loop_per_worker(config: dict):\n", " num_epochs = config[\"num_epochs\"]\n", " learning_rate = config[\"learning_rate\"]\n", " momentum = config[\"momentum\"]\n", " batch_size = config[\"batch_size\"]\n", "\n", " model = SimpleMLP(num_classes=10)\n", "\n", " # Load model from checkpoint if there is a checkpoint to load from.\n", " checkpoint_to_load = session.get_checkpoint()\n", " if checkpoint_to_load:\n", " state_dict_to_resume_from = checkpoint_to_load.to_dict()[\"model\"]\n", " model.load_state_dict(state_dict=state_dict_to_resume_from)\n", "\n", " model = train.torch.prepare_model(model)\n", "\n", " optimizer = SGD(model.parameters(), lr=learning_rate, momentum=momentum)\n", " criterion = CrossEntropyLoss()\n", "\n", " # Get the Ray Dataset shard for this data parallel worker, and convert it to a PyTorch Dataset.\n", " dataset_shard = session.get_dataset_shard(\"train\").iter_torch_batches(\n", " batch_size=batch_size,\n", " )\n", "\n", " for epoch_idx in range(num_epochs):\n", " running_loss = 0\n", " for iteration, batch in enumerate(dataset_shard):\n", " optimizer.zero_grad()\n", " train_mb_x, train_mb_y = batch[\"image\"], batch[\"label\"]\n", " train_mb_x = train_mb_x.to(train.torch.get_device())\n", " train_mb_y = train_mb_y.to(train.torch.get_device())\n", "\n", " # Forward\n", " logits = model(train_mb_x)\n", " # Loss\n", " loss = criterion(logits, train_mb_y)\n", " # Backward\n", " loss.backward()\n", " # Update\n", " optimizer.step()\n", "\n", " running_loss += loss.item()\n", " if session.get_world_rank() == 0 and iteration % 500 == 0:\n", " print(f\"loss: {loss.item():>7f}, epoch: {epoch_idx}, iteration: {iteration}\")\n", "\n", " # Checkpoint model after every epoch.\n", " state_dict = model.state_dict()\n", " checkpoint = Checkpoint.from_dict(dict(model=state_dict))\n", " session.report({\"loss\": running_loss}, checkpoint=checkpoint)" ] }, { "cell_type": "markdown", "metadata": { "id": "9HUciluylZbX" }, "source": [ "## 4b: Define our Preprocessor\n", "\n", "Next, we define our `Preprocessor` to preprocess our data before training and prediction. 
Our preprocessor will normalize the MNIST Images by the mean and standard deviation of the MNIST training dataset. This is a common operation to do on MNIST to improve training: https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "yHzQZTlAlY-9" }, "outputs": [], "source": [ "from typing import Dict\n", "import numpy as np\n", "\n", "import torch\n", "from torchvision import transforms\n", "\n", "from ray.data.preprocessors import TorchVisionPreprocessor\n", "\n", "\n", "transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.1307,), (0.3081,))\n", "])\n", "mnist_normalize_preprocessor = TorchVisionPreprocessor(columns=[\"image\"], transform=transform)" ] }, { "cell_type": "markdown", "metadata": { "id": "Uto3v90Hagni" }, "source": [ "## 4c: Define logic for Batch/Offline Prediction.\n", "\n", "After training on each task, we want to use our trained model to do batch (i.e. offline) inference on a test dataset. \n", "\n", "To do this, we leverage the built-in `ray.air.BatchPredictor`. We define a `batch_predict` function that will take in a Checkpoint and a Test Dataset and outputs the accuracy our model achieves on the test dataset." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "DM2lFHzFa6uI" }, "outputs": [], "source": [ "from ray.train.batch_predictor import BatchPredictor\n", "from ray.train.torch import TorchPredictor\n", "\n", "def batch_predict(checkpoint: ray.air.Checkpoint, test_dataset: ray.data.Dataset) -> float:\n", " \"\"\"Perform batch prediction on the provided test dataset, and return accuracy results.\"\"\"\n", "\n", " batch_predictor = BatchPredictor.from_checkpoint(checkpoint, predictor_cls=TorchPredictor, model=SimpleMLP(num_classes=10))\n", " model_output = batch_predictor.predict(\n", " data=test_dataset, feature_columns=[\"image\"], keep_columns=[\"label\"]\n", " )\n", " \n", " # Postprocess model outputs.\n", " # Convert logits outputted from model into actual class predictions.\n", " def convert_logits_to_classes(df):\n", " best_class = df[\"predictions\"].map(lambda x: np.array(x).argmax())\n", " df[\"predictions\"] = best_class\n", " return df\n", " \n", " prediction_results = model_output.map_batches(convert_logits_to_classes, batch_format=\"pandas\")\n", " \n", " # Then, for each prediction output, see if it matches with the ground truth\n", " # label.\n", " def calculate_prediction_scores(df):\n", " return pd.DataFrame({\"correct\": df[\"predictions\"] == df[\"label\"]})\n", "\n", " correct_dataset = prediction_results.map_batches(\n", " calculate_prediction_scores, batch_format=\"pandas\"\n", " )\n", "\n", " return correct_dataset.sum(on=\"correct\") / correct_dataset.count()" ] }, { "cell_type": "markdown", "metadata": { "id": "GWiTtsmVbIZP" }, "source": [ "## 4d: Define logic for Deploying and Querying our model\n", "\n", "In addition to batch inference, we also want to deploy our model so that we can submit live queries to it for online inference. We use Ray Serve's `PredictorDeployment` utility to deploy our trained model. \n", "\n", "Once we deploy the model, we can send HTTP requests to our deployment." 
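, "\n", "For reference, a single standalone query against such an endpoint might look like the sketch below. The payload format follows the `json_to_ndarray` HTTP adapter used in the next cell, the random array simply stands in for a real permuted MNIST image, and the endpoint URL assumes the deployment from this example is running locally:\n", "\n", "```python\n", "import numpy as np\n", "import requests\n", "\n", "# A dummy 28x28 \"image\" standing in for a real permuted MNIST sample.\n", "sample = np.random.rand(28, 28).astype(\"float32\")\n", "\n", "response = requests.post(\n", "    \"http://localhost:8000/mnist_predict\",\n", "    json={\"array\": sample.tolist(), \"dtype\": \"float32\"},\n", ")\n", "\n", "# The response body contains the raw model outputs (logits) for the 10 MNIST classes.\n", "print(response.json())\n", "```"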
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "ZC3JCWz7bhR-" }, "outputs": [], "source": [ "from typing import List\n", "import requests\n", "from requests import Response\n", "import numpy as np\n", "\n", "from ray import serve\n", "from ray.serve import PredictorDeployment\n", "from ray.serve.http_adapters import json_to_ndarray\n", "\n", "\n", "def deploy_model(checkpoint: ray.air.Checkpoint) -> str:\n", " \"\"\"Deploys the model from the provided Checkpoint and returns the URL for the endpoint of the model deployment.\"\"\"\n", " serve.run(\n", " PredictorDeployment.options(\n", " name=\"mnist_model\",\n", " route_prefix=\"/mnist_predict\",\n", " num_replicas=2,\n", " ).bind(\n", " http_adapter=json_to_ndarray,\n", " predictor_cls=TorchPredictor,\n", " checkpoint=checkpoint,\n", " model=SimpleMLP(num_classes=10),\n", " )\n", " )\n", "\n", " return \"http://localhost:8000/mnist_predict\"\n", "\n", "# Function that queries our deployed model.\n", "def query_deployment(test_samples: List[np.ndarray], endpoint_uri: str) -> List[Response]:\n", " \"\"\"Given a set of test samples, queries the model deployment at the provided endpoint and returns the results.\"\"\"\n", " results = []\n", " # Convert to a Python list since NumPy arrays are not JSON serializable.\n", " for sample in test_samples:\n", " results.append(requests.post(endpoint_uri, json={\"array\": sample.tolist(), \"dtype\": \"float32\"}))\n", " return results" ] }, { "cell_type": "markdown", "metadata": { "id": "-NQDj0rFVUX3" }, "source": [ "# Step 5: Putting it all together\n", "\n", "Once we have defined our training logic and our preprocessor, we can put everything together!\n", "\n", "For each dataset in our stream, we do the following:\n", "1. Train on the dataset in Data Parallel fashion. We create a `TorchTrainer`, specify the config for the training loop we defined above, the dataset to train on, and how much we want to scale. `TorchTrainer` also accepts a `resume_from_checkpoint` arg to continue training the model from a previously saved checkpoint.\n", "2. Get the saved checkpoint from the training run.\n", "3. Test our trained model on a test set containing test data from all the tasks trained on so far.\n", "4. After training on each task, we deploy our model so we can query it for predictions.\n", "\n", "In this example, the training and test data for each task are well-defined beforehand by the benchmark. For real-world scenarios, this probably will not be the case. It is very likely that the prediction requests after training on one task will become the training data for the next task. 
\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "I_OrfQTqNYRk", "outputId": "a89da8b8-1acf-4796-cc88-9ee889a32123" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Read->Map_Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00, 3.42s/it]\n", "Read->Map_Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 5.27it/s]\n", "Map Progress (1 actors 1 pending): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.40it/s]\n", "Read->Map_Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 4.17it/s]\n", "Map Progress (1 actors 1 pending): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.78it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Starting training for task: 0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "data": { "text/html": [ "
\n", "
\n", "
\n", "

Tune Status

\n", " \n", "\n", "\n", "\n", "\n", "\n", "
Current time:2022-09-23 16:31:51
Running for: 00:00:20.79
Memory: 17.1/62.7 GiB
\n", "
\n", "
\n", "
\n", "

System Info

\n", " Using FIFO scheduling algorithm.
Resources requested: 0/24 CPUs, 0/0 GPUs, 0.0/32.53 GiB heap, 0.0/16.26 GiB objects\n", "
\n", " \n", "
\n", "
\n", "
\n", "

Trial Status

\n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "
Trial name status loc iter total time (s) loss _timestamp _time_this_iter_s
TorchTrainer_da157_00000TERMINATED10.109.175.190:856770 4 17.0121 0 1663975908 0.0839479
\n", "
\n", "
\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=856836)\u001b[0m 2022-09-23 16:31:37,847\tINFO config.py:71 -- Setting up process group for: env:// [rank=0, world_size=1]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=856836)\u001b[0m 2022-09-23 16:31:38,047\tINFO train_loop_utils.py:354 -- Moving model to device: cuda:0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=856836)\u001b[0m loss: 2.436360, epoch: 0, iteration: 0\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=856836)\u001b[0m loss: 1.608793, epoch: 0, iteration: 500\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=856836)\u001b[0m loss: 1.285775, epoch: 0, iteration: 1000\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=856836)\u001b[0m loss: 0.785092, epoch: 0, iteration: 1500\n" ] }, { "data": { "text/html": [ "
\n", "

Trial Progress

\n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "
Trial name _time_this_iter_s _timestamp _training_iterationdate done episodes_total experiment_id experiment_taghostname iterations_since_restore lossnode_ip pidshould_checkpoint time_since_restore time_this_iter_s time_total_s timestamp timesteps_since_restoretimesteps_total training_iterationtrial_id warmup_time
TorchTrainer_da157_00000 0.0839479 1663975908 42022-09-23_16-31-49True 96c794a64d6f43d79b87130a76d21f1f 0corvus 4 010.109.175.190856770True 17.0121 0.11111 17.0121 1663975909 0 4da157_00000 0.00297165
\n", "
\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "2022-09-23 16:31:51,231\tINFO tune.py:762 -- Total run time: 20.91 seconds (20.79 seconds for the tuning loop).\n", "Map_Batches: 0%| | 0/1 [00:00Map_Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 4.26it/s]\n", "Map Progress (1 actors 1 pending): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.72it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Starting training for task: 1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "data": { "text/html": [ "
\n", "
\n", "
\n", "

Tune Status

\n", " \n", "\n", "\n", "\n", "\n", "\n", "
Current time:2022-09-23 16:33:08
Running for: 00:00:19.49
Memory: 18.2/62.7 GiB
\n", "
\n", "
\n", "
\n", "

System Info

\n", " Using FIFO scheduling algorithm.
Resources requested: 0/24 CPUs, 0/0 GPUs, 0.0/32.53 GiB heap, 0.0/16.26 GiB objects\n", "
\n", " \n", "
\n", "
\n", "
\n", "

Trial Status

\n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "
Trial name status loc iter total time (s) loss _timestamp _time_this_iter_s
TorchTrainer_09424_00000TERMINATED10.109.175.190:857781 4 15.3611 0 1663975986 0.0699804
\n", "
\n", "
\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=857818)\u001b[0m 2022-09-23 16:32:55,672\tINFO config.py:71 -- Setting up process group for: env:// [rank=0, world_size=1]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=857818)\u001b[0m 2022-09-23 16:32:55,954\tINFO train_loop_utils.py:354 -- Moving model to device: cuda:0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=857818)\u001b[0m loss: 2.457292, epoch: 0, iteration: 0\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=857818)\u001b[0m loss: 1.339169, epoch: 0, iteration: 500\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=857818)\u001b[0m loss: 1.032746, epoch: 0, iteration: 1000\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=857818)\u001b[0m loss: 0.707931, epoch: 0, iteration: 1500\n" ] }, { "data": { "text/html": [ "
\n", "

Trial Progress

\n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "
Trial name _time_this_iter_s _timestamp _training_iterationdate done episodes_total experiment_id experiment_taghostname iterations_since_restore lossnode_ip pidshould_checkpoint time_since_restore time_this_iter_s time_total_s timestamp timesteps_since_restoretimesteps_total training_iteration trial_id warmup_time
TorchTrainer_09424_00000 0.0699804 1663975986 42022-09-23_16-33-06True 77c9c5f109fa4a47b459b0afadf3ba33 0corvus 4 010.109.175.190857781True 15.3611 0.0725608 15.3611 1663975986 0 409424_00000 0.00418878
\n", "
\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "2022-09-23 16:33:09,072\tINFO tune.py:762 -- Total run time: 19.62 seconds (19.49 seconds for the tuning loop).\n", "Map Progress (1 actors 1 pending): 0%| | 0/2 [00:01Map_Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 5.31it/s]\n", "Map Progress (1 actors 1 pending): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.76it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Starting training for task: 2\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "data": { "text/html": [ "
\n", "
\n", "
\n", "

Tune Status

\n", " \n", "\n", "\n", "\n", "\n", "\n", "
Current time:2022-09-23 16:34:33
Running for: 00:00:19.45
Memory: 18.4/62.7 GiB
\n", "
\n", "
\n", "
\n", "

System Info

\n", " Using FIFO scheduling algorithm.
Resources requested: 0/24 CPUs, 0/0 GPUs, 0.0/32.53 GiB heap, 0.0/16.26 GiB objects\n", "
\n", " \n", "
\n", "
\n", "
\n", "

Trial Status

\n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "
Trial name status loc iter total time (s) loss _timestamp _time_this_iter_s
TorchTrainer_3b7e3_00000TERMINATED10.109.175.190:858536 4 15.3994 0 1663976070 0.0710998
\n", "
\n", "
\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=858579)\u001b[0m 2022-09-23 16:34:19,902\tINFO config.py:71 -- Setting up process group for: env:// [rank=0, world_size=1]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=858579)\u001b[0m 2022-09-23 16:34:20,191\tINFO train_loop_utils.py:354 -- Moving model to device: cuda:0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=858579)\u001b[0m loss: 2.515887, epoch: 0, iteration: 0\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=858579)\u001b[0m loss: 1.260738, epoch: 0, iteration: 500\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=858579)\u001b[0m loss: 0.892560, epoch: 0, iteration: 1000\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=858579)\u001b[0m loss: 0.497198, epoch: 0, iteration: 1500\n" ] }, { "data": { "text/html": [ "
\n", "

Trial Progress

\n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "
Trial name _time_this_iter_s _timestamp _training_iterationdate done episodes_total experiment_id experiment_taghostname iterations_since_restore lossnode_ip pidshould_checkpoint time_since_restore time_this_iter_s time_total_s timestamp timesteps_since_restoretimesteps_total training_iterationtrial_id warmup_time
TorchTrainer_3b7e3_00000 0.0710998 1663976070 42022-09-23_16-34-30True c9312be01e964b958b931d1796623509 0corvus 4 010.109.175.190858536True 15.3994 0.0705044 15.3994 1663976070 0 43b7e3_00000 0.00414133
\n", "
\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "2022-09-23 16:34:33,315\tINFO tune.py:762 -- Total run time: 19.59 seconds (19.45 seconds for the tuning loop).\n", "Map Progress (1 actors 1 pending): 0%| | 0/3 [00:01 0\n", "\n", "permuted_mnist = PermutedMNISTStream(n_tasks=n_tasks)\n", "train_stream = permuted_mnist.generate_train_stream()\n", "test_stream = permuted_mnist.generate_test_stream()\n", "\n", "latest_checkpoint = None\n", "\n", "accuracy_for_all_tasks = []\n", "task_idx = 0\n", "all_test_datasets_seen_so_far = []\n", "for train_dataset, test_dataset in zip(train_stream, test_stream):\n", " print(f\"Starting training for task: {task_idx}\")\n", " task_idx += 1\n", "\n", " # *********Training*****************\n", "\n", " trainer = TorchTrainer(\n", " train_loop_per_worker=train_loop_per_worker,\n", " train_loop_config={\n", " \"num_epochs\": num_epochs,\n", " \"learning_rate\": learning_rate,\n", " \"momentum\": momentum,\n", " \"batch_size\": batch_size,\n", " },\n", " # Have to specify trainer_resources as 0 so that the example works on Colab. \n", " scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu, trainer_resources={\"CPU\": 0}),\n", " datasets={\"train\": train_dataset},\n", " preprocessor=mnist_normalize_preprocessor,\n", " resume_from_checkpoint=latest_checkpoint,\n", " )\n", " result = trainer.fit()\n", " latest_checkpoint = result.checkpoint\n", "\n", " # **************Batch Prediction**************************\n", "\n", " # We can do batch prediction on the test data for the tasks seen so far.\n", " # TODO: Fix type signature in Ray Datasets\n", " # TODO: Fix dataset.union when used with empty list.\n", " if len(all_test_datasets_seen_so_far) > 0:\n", " full_test_dataset = test_dataset.union(*all_test_datasets_seen_so_far)\n", " else:\n", " full_test_dataset = test_dataset\n", "\n", " all_test_datasets_seen_so_far.append(test_dataset)\n", "\n", " accuracy_for_this_task = batch_predict(latest_checkpoint, full_test_dataset)\n", " print(f\"Accuracy for task {task_idx}: {accuracy_for_this_task}\")\n", " accuracy_for_all_tasks.append(accuracy_for_this_task)\n", "\n", " # *************Model Deployment & Online Inference***************************\n", " \n", " # We can also deploy our model to do online inference with Ray Serve.\n", " # Start Ray Serve.\n", " test_samples = permuted_mnist.generate_test_samples()\n", " endpoint_uri = deploy_model(latest_checkpoint)\n", " online_inference_results = query_deployment(test_samples, endpoint_uri)\n", "\n", " if ray.available_resources().get(\"CPU\", 0) < num_workers+1:\n", " # If there are no more CPUs left, then shutdown the Serve replicas so we can continue training on the next task.\n", " serve.shutdown()\n", "\n", " \n", "serve.shutdown()" ] }, { "cell_type": "markdown", "metadata": { "id": "ORWpRkPjcPbD" }, "source": [ "Now that we have finished all of our training, let's see the accuracy of our model after training on each task. \n", "\n", "We should see the accuracy decrease over time. This is to be expected since we are using just a naive fine-tuning strategy so our model is prone to catastrophic forgetting.\n", "\n", "As we increase the number of tasks, the model performance on all the tasks trained on so far should decrease." 
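, "\n", "As a quick way to inspect this, you could also print the accuracy recorded after each task explicitly (a small sketch that reuses the `accuracy_for_all_tasks` list populated in the training loop above):\n", "\n", "```python\n", "for task_idx, accuracy in enumerate(accuracy_for_all_tasks):\n", "    # Each entry is the accuracy over all tasks seen so far, measured right\n", "    # after training on the task with this index.\n", "    print(f\"Accuracy after training on task {task_idx}: {accuracy:.4f}\")\n", "```"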
] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "thpeB0KGmr99", "outputId": "59fdbb6d-eaf4-4c2a-d350-5ff6b48e96a3" }, "outputs": [ { "data": { "text/plain": [ "[0.8678, 0.86465, 0.8439]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "accuracy_for_all_tasks" ] }, { "cell_type": "markdown", "metadata": { "id": "xLLAvsTk8LoV" }, "source": [ "# [Optional] Step 6: Compare against full training.\n", "\n", "We have now incrementally trained our simple multi-layer perceptron. Let's compare the incrementally trained model via fine tuning against a model that is trained on all the tasks up front.\n", "\n", "Since we are using a naive fine-tuning strategy, we should expect that our incrementally trained model will perform worse than the the one that is fully trained! However, there's various other strategies that have been developed and are actively being researched to improve accuracy for incremental training. And overall, incremental/continual learning allows you to train in many real world settings where the entire dataset is not available up front, but new data is arriving at a relatively high rate." ] }, { "cell_type": "markdown", "metadata": { "id": "RNHsEVBHc0p2" }, "source": [ "Let's first combine all of our datasets for each task into a single, unified Dataset" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "pU2fVH068lfF", "outputId": "fd6a3b56-dda1-4fa6-cebd-d0ee8784e698" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Map Progress (1 actors 1 pending): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.37it/s]\n", "Map Progress (1 actors 1 pending): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.37it/s]\n", "Map Progress (1 actors 1 pending): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.40it/s]\n", "Shuffle Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 40.34it/s]\n", "Shuffle Reduce: 
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 28.99it/s]\n" ] } ], "source": [ "train_stream = permuted_mnist.generate_train_stream()\n", "\n", "# Collect all datasets in the stream into a single dataset.\n", "all_training_datasets = []\n", "for train_dataset in train_stream:\n", " all_training_datasets.append(train_dataset)\n", "combined_training_dataset = all_training_datasets[0].union(*all_training_datasets[1:])\n", "\n", "\n", "combined_training_dataset = combined_training_dataset.random_shuffle()" ] }, { "cell_type": "markdown", "metadata": { "id": "tJ6Oqdgvc5dn" }, "source": [ "Then, we train a new model on the unified Dataset using the same configurations as before." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "PmH9c0-z9KME", "outputId": "653b4dfc-ed47-4307-fa84-e4c4ea3ec354" }, "outputs": [ { "data": { "text/html": [ "
\n", "
\n", "
\n", "

Tune Status

\n", " \n", "\n", "\n", "\n", "\n", "\n", "
Current time:2022-09-23 16:37:13
Running for: 00:00:25.97
Memory: 19.4/62.7 GiB
\n", "
\n", "
\n", "
\n", "

System Info

\n", " Using FIFO scheduling algorithm.
Resources requested: 0/24 CPUs, 0/0 GPUs, 0.0/32.53 GiB heap, 0.0/16.26 GiB objects\n", "
\n", " \n", "
\n", "
\n", "
\n", "

Trial Status

\n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "
Trial name status loc iter total time (s) loss _timestamp _time_this_iter_s
TorchTrainer_971af_00000TERMINATED10.109.175.190:860035 4 22.1282 0 1663976231 0.0924587
\n", "
\n", "
\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=860154)\u001b[0m 2022-09-23 16:36:55,188\tINFO config.py:71 -- Setting up process group for: env:// [rank=0, world_size=1]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=860154)\u001b[0m 2022-09-23 16:36:55,399\tINFO train_loop_utils.py:354 -- Moving model to device: cuda:0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=860154)\u001b[0m loss: 2.301066, epoch: 0, iteration: 0\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=860154)\u001b[0m loss: 1.869080, epoch: 0, iteration: 500\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=860154)\u001b[0m loss: 1.489264, epoch: 0, iteration: 1000\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=860154)\u001b[0m loss: 1.646756, epoch: 0, iteration: 1500\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=860154)\u001b[0m loss: 1.582330, epoch: 0, iteration: 2000\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=860154)\u001b[0m loss: 1.246018, epoch: 0, iteration: 2500\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=860154)\u001b[0m loss: 1.035204, epoch: 0, iteration: 3000\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=860154)\u001b[0m loss: 0.872962, epoch: 0, iteration: 3500\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=860154)\u001b[0m loss: 1.138829, epoch: 0, iteration: 4000\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=860154)\u001b[0m loss: 0.753354, epoch: 0, iteration: 4500\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=860154)\u001b[0m loss: 0.991935, epoch: 0, iteration: 5000\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=860154)\u001b[0m loss: 0.928292, epoch: 0, iteration: 5500\n" ] }, { "data": { "text/html": [ "
\n", "

Trial Progress

\n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "
Trial name _time_this_iter_s _timestamp _training_iterationdate done episodes_total experiment_id experiment_taghostname iterations_since_restore lossnode_ip pidshould_checkpoint time_since_restore time_this_iter_s time_total_s timestamp timesteps_since_restoretimesteps_total training_iterationtrial_id warmup_time
TorchTrainer_971af_00000 0.0924587 1663976231 42022-09-23_16-37-11True 26d685b2612a4752b7d062d1ebfb89f0 0corvus 4 010.109.175.190860035True 22.1282 0.0941384 22.1282 1663976231 0 4971af_00000 0.0034101
\n", "
\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "2022-09-23 16:37:13,525\tINFO tune.py:762 -- Total run time: 26.08 seconds (25.96 seconds for the tuning loop).\n" ] } ], "source": [ "# Now we do training with the same configurations as before\n", "trainer = TorchTrainer(\n", " train_loop_per_worker=train_loop_per_worker,\n", " train_loop_config={\n", " \"num_epochs\": num_epochs,\n", " \"learning_rate\": learning_rate,\n", " \"momentum\": momentum,\n", " \"batch_size\": batch_size,\n", " },\n", " # Have to specify trainer_resources as 0 so that the example works on Colab. \n", " scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu, trainer_resources={\"CPU\": 0}),\n", " datasets={\"train\": combined_training_dataset},\n", " preprocessor=mnist_normalize_preprocessor,\n", " )\n", "result = trainer.fit()\n", "full_training_checkpoint = result.checkpoint" ] }, { "cell_type": "markdown", "metadata": { "id": "jLaOcmBddRqB" }, "source": [ "Then, let's test model that was trained on all the tasks up front." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "WC7zV_Cw9TAi", "outputId": "12a86f2b-be90-47b6-e252-25e3199689f9" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Map Progress (1 actors 1 pending): 0%| | 0/3 [00:01