Computer Vision
Contents
Computer Vision#
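This guide explains how to perform common computer vision tasks like reading image data, transforming images, training vision models, creating checkpoints, batch predicting images, and serving vision models.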
Reading image data#
Datasets like ImageNet store files like this:
root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png
To load images stored in this layout, read the raw images and include the class names.
import ray
from ray.data.datasource.partitioning import Partitioning
# Example-data bucket; the "anonymous@" prefix is the bucket-access user part
# of the URI (the scraped page had this segment replaced by a placeholder).
root = "s3://anonymous@air-example-data/cifar-10/images"
# Directory partitioning rooted at `root`: the folder name under `root` is
# surfaced as a "class" column alongside each image.
partitioning = Partitioning("dir", field_names=["class"], base_dir=root)
dataset = ray.data.read_images(root, partitioning=partitioning)
Then, apply a user-defined function to encode the class names as integer targets.
from typing import Dict
import numpy as np
# Class names (as they appear in the "class" column) mapped to integer targets.
CLASS_TO_LABEL = {
    "airplane": 0,
    "automobile": 1,
    "bird": 2,
    "cat": 3,
    "deer": 4,
    "dog": 5,
    "frog": 6,
    "horse": 7,
    "ship": 8,
    "truck": 9,
}


def add_label_column(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Add an integer ``"label"`` column derived from the ``"class"`` column.

    Args:
        batch: Batch whose ``"class"`` entry holds class names that are keys
            of ``CLASS_TO_LABEL``.

    Returns:
        The same batch object, with ``"label"`` set to a NumPy array holding
        the integer label of each class name, in order.

    Raises:
        KeyError: If a class name is missing from ``CLASS_TO_LABEL``.
    """
    batch["label"] = np.array([CLASS_TO_LABEL[name] for name in batch["class"]])
    return batch
def remove_class_column(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Drop the ``"class"`` column from a batch.

    Args:
        batch: Batch that contains a ``"class"`` entry.

    Returns:
        The same batch object with ``"class"`` removed.

    Raises:
        KeyError: If the batch has no ``"class"`` entry.
    """
    del batch["class"]
    return batch
# First derive the integer labels, then drop the string class column.
labeled = dataset.map_batches(add_label_column)
dataset = labeled.map_batches(remove_class_column)
Tip
You can also use LabelEncoder to encode labels.
To load NumPy arrays into a Dataset, separately read the image and label arrays.
import ray
# Image and label arrays live in separate .npy files, so read each one into
# its own dataset. (The bucket segment was a scrape placeholder; restored.)
images = ray.data.read_numpy("s3://anonymous@air-example-data/cifar-10/images.npy")
labels = ray.data.read_numpy("s3://anonymous@air-example-data/cifar-10/labels.npy")
Then, combine the datasets and rename the columns.
dataset = images.zip(labels)
# Give the zipped columns ("__value__" and "__value___1") meaningful names.
dataset = dataset.map_batches(
    lambda batch: batch.rename(
        columns={"__value__": "image", "__value___1": "label"}
    ),
    # `rename(columns=...)` is a pandas DataFrame method, so request pandas
    # batches explicitly instead of relying on the default batch format.
    batch_format="pandas",
)
Image datasets often contain tf.train.Example messages that look like this:
features {
feature {
key: "image"
value {
bytes_list {
value: ... # Raw image bytes
}
}
}
feature {
key: "label"
value {
int64_list {
value: 3
}
}
}
}
To load examples stored in this format, read the TFRecords into a Dataset.
import ray
# The bucket segment of this URI was a scrape placeholder; restored here.
dataset = ray.data.read_tfrecords(
    "s3://anonymous@air-example-data/cifar-10/tfrecords"
)
Then, apply a user-defined function to decode the raw image bytes.
import io
from typing import Dict
import numpy as np
from PIL import Image
def decode_bytes(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Decode the raw image bytes in ``batch["image"]`` into pixel arrays.

    Args:
        batch: Batch whose ``"image"`` entry holds encoded image bytes, one
            item per row.

    Returns:
        The same batch object, with ``"image"`` replaced by a NumPy array
        built from the decoded images.
    """
    # Each item is wrapped in an in-memory file so PIL can parse it.
    batch["image"] = np.array(
        [np.array(Image.open(io.BytesIO(data))) for data in batch["image"]]
    )
    return batch
# decode_bytes works on a dict of NumPy arrays, so ask for "numpy" batches.
dataset = dataset.map_batches(
    decode_bytes,
    batch_format="numpy",
)
To load image data stored in Parquet files, call ray.data.read_parquet().
import ray
# The bucket segment of this URI was a scrape placeholder; restored here.
dataset = ray.data.read_parquet("s3://anonymous@air-example-data/cifar-10/parquet")
For more information on creating datasets, see Creating Datasets.
Transforming images#
To transform images, create a Preprocessor. Preprocessors are the standard way to preprocess data with Ray.
To apply TorchVision transforms, create a TorchVisionPreprocessor.
Create two TorchVisionPreprocessors – one to normalize images, and another to augment images. Later, you’ll pass the preprocessors to Trainers, Predictors, and PredictorDeployments.
from torchvision import transforms
from ray.data.preprocessors import TorchVisionPreprocessor
# One-off preprocessing: convert images to tensors, then center-crop to 224.
base_steps = [transforms.ToTensor(), transforms.CenterCrop(224)]
transform = transforms.Compose(base_steps)
preprocessor = TorchVisionPreprocessor(transform=transform, columns=["image"])

# Augmentation: flip each image horizontally with probability 0.5.
per_epoch_transform = transforms.RandomHorizontalFlip(p=0.5)
per_epoch_preprocessor = TorchVisionPreprocessor(
    transform=per_epoch_transform,
    columns=["image"],
)
To apply TensorFlow transforms, create a BatchMapper.
Create two BatchMappers – one to normalize images, and another to augment images. Later, you’ll pass the preprocessors to Trainers, Predictors, and PredictorDeployments.
from typing import Dict
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import imagenet_utils
from ray.data.preprocessors import BatchMapper
def preprocess(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Apply Keras ImageNet input preprocessing and resize images to 224x224.

    Args:
        batch: Batch whose ``"image"`` entry holds the images to preprocess.

    Returns:
        The same batch object with ``"image"`` replaced by the preprocessed,
        resized images as a NumPy array.
    """
    batch["image"] = imagenet_utils.preprocess_input(batch["image"])
    # tf.image.resize returns a tensor; convert back to NumPy for the batch.
    batch["image"] = tf.image.resize(batch["image"], (224, 224)).numpy()
    return batch


preprocessor = BatchMapper(preprocess, batch_format="numpy")
def augment(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Randomly flip the images in ``batch["image"]`` left-to-right.

    Args:
        batch: Batch whose ``"image"`` entry holds the images to augment.

    Returns:
        The same batch object with ``"image"`` replaced by the output of
        ``tf.image.random_flip_left_right`` as a NumPy array.
    """
    batch["image"] = tf.image.random_flip_left_right(batch["image"]).numpy()
    return batch


per_epoch_preprocessor = BatchMapper(augment, batch_format="numpy")
For more information on transforming data, see Using Preprocessors and Transforming Datasets.
Training vision models#
Trainers let you train models in parallel.
To train a Torch vision model, define the training loop per worker.
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from ray import train
from ray.air import session
from ray.air.config import DatasetConfig, ScalingConfig
from ray.train.torch import TorchCheckpoint, TorchTrainer
def train_one_epoch(
    model, *, criterion, optimizer, batch_size, epoch, report_interval=2000
):
    """Train ``model`` for one pass over this worker's ``"train"`` shard.

    Args:
        model: Model called on each batch's ``"image"`` column.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        batch_size: Batch size requested from the dataset shard.
        epoch: Epoch index, included in the reported metrics.
        report_interval: Number of batches between reports. Every
            ``report_interval`` batches the mean loss over that window is
            reported together with a checkpoint of the model. A shard with
            fewer batches than this never reports, so lower it for small
            datasets.
    """
    dataset_shard = session.get_dataset_shard("train")

    running_loss = 0
    batches = dataset_shard.iter_torch_batches(
        batch_size=batch_size, local_shuffle_buffer_size=256
    )
    for i, batch in enumerate(batches):
        inputs, labels = batch["image"], batch["label"]

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Report on the last batch of each window, then reset the window.
        if i % report_interval == report_interval - 1:
            session.report(
                metrics={
                    "epoch": epoch,
                    "batch": i,
                    "running_loss": running_loss / report_interval,
                },
                checkpoint=TorchCheckpoint.from_model(model),
            )
            running_loss = 0
def train_loop_per_worker(config):
    """Per-worker training entry point for the Torch example.

    Prepares a ResNet-50 with ``train.torch.prepare_model``, then runs
    ``train_one_epoch`` once per configured epoch using cross-entropy loss
    and SGD.

    Args:
        config: Mapping providing ``"lr"``, ``"epochs"`` and ``"batch_size"``.
    """
    model = train.torch.prepare_model(models.resnet50())
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=config["lr"])

    for epoch in range(config["epochs"]):
        train_one_epoch(
            model,
            criterion=criterion,
            optimizer=optimizer,
            batch_size=config["batch_size"],
            epoch=epoch,
        )
Then, create a TorchTrainer and call fit().
# Values read by train_loop_per_worker through its `config` argument.
loop_config = {"batch_size": 32, "lr": 0.02, "epochs": 1}
# Attach the augmentation preprocessor to the "train" dataset.
train_dataset_config = DatasetConfig(
    per_epoch_preprocessor=per_epoch_preprocessor
)

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=loop_config,
    datasets={"train": dataset},
    dataset_config={"train": train_dataset_config},
    scaling_config=ScalingConfig(num_workers=2),
    preprocessor=preprocessor,
)
results = trainer.fit()
For more in-depth examples, read Training a Torch Image Classifier and Using Trainers.
To train a TensorFlow vision model, define the training loop per worker.
import tensorflow as tf
from ray.air import session
from ray.air.integrations.keras import ReportCheckpointCallback
def train_loop_per_worker(config):
    """Per-worker training entry point for the TensorFlow example.

    Converts this worker's ``"train"`` shard to a TensorFlow dataset of
    (``"image"``, ``"label"``) pairs, builds and compiles an untrained
    ResNet50 under a multi-worker mirrored strategy, and fits it with a
    callback that reports results and checkpoints.

    Args:
        config: Mapping providing ``"batch_size"``, ``"lr"`` and ``"epochs"``.
    """
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

    train_shard = session.get_dataset_shard("train")
    train_dataset = train_shard.to_tf(
        "image",
        "label",
        batch_size=config["batch_size"],
        local_shuffle_buffer_size=256,
    )

    # Model and optimizer are created inside the strategy scope.
    with strategy.scope():
        model = tf.keras.applications.resnet50.ResNet50(weights=None)
        optimizer = tf.keras.optimizers.Adam(config["lr"])
        model.compile(
            optimizer=optimizer,
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"],
        )

    model.fit(
        train_dataset,
        epochs=config["epochs"],
        callbacks=[ReportCheckpointCallback()],
    )
Then, create a TensorflowTrainer and call fit().
from ray.air import DatasetConfig, ScalingConfig
from ray.train.tensorflow import TensorflowTrainer
# Values read by train_loop_per_worker through its `config` argument.
loop_config = {"batch_size": 32, "lr": 0.02, "epochs": 1}
# Attach the augmentation preprocessor to the "train" dataset.
train_dataset_config = DatasetConfig(
    per_epoch_preprocessor=per_epoch_preprocessor
)

trainer = TensorflowTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=loop_config,
    datasets={"train": dataset},
    dataset_config={"train": train_dataset_config},
    scaling_config=ScalingConfig(num_workers=2),
    preprocessor=preprocessor,
)
results = trainer.fit()
For more information, read Using Trainers.
Creating checkpoints#
Checkpoints are required for batch inference and model serving. They contain model state and optionally a preprocessor.
If you’re going from training to prediction, don’t create a new checkpoint. Trainer.fit() returns a Result object. Use Result.checkpoint instead.
To create a TorchCheckpoint, pass a Torch model and the Preprocessor you created in Transforming images to TorchCheckpoint.from_model().
from torchvision import models
from ray.train.torch import TorchCheckpoint
# Bundle a pretrained ResNet-50 with the preprocessor into one checkpoint.
model = models.resnet50(pretrained=True)
checkpoint = TorchCheckpoint.from_model(
    model,
    preprocessor=preprocessor,
)
To create a TensorflowCheckpoint, pass a TensorFlow model and the Preprocessor you created in Transforming images to TensorflowCheckpoint.from_model().
import tensorflow as tf
from ray.train.tensorflow import TensorflowCheckpoint
# Bundle a Keras ResNet50 with the preprocessor into one checkpoint.
model = tf.keras.applications.resnet50.ResNet50()
checkpoint = TensorflowCheckpoint.from_model(
    model,
    preprocessor=preprocessor,
)
Batch predicting images#
BatchPredictor lets you perform inference on large image datasets.
To create a BatchPredictor, call BatchPredictor.from_checkpoint and pass the checkpoint you created in Creating checkpoints.
from ray.train.batch_predictor import BatchPredictor
from ray.train.torch import TorchPredictor
# Build the predictor from the checkpoint, then predict on the "image" column.
predictor = BatchPredictor.from_checkpoint(
    checkpoint,
    TorchPredictor,
)
predictor.predict(
    dataset,
    feature_columns=["image"],
    keep_columns=["label"],
)
For more in-depth examples, read Performing GPU Batch Prediction on Images with a PyTorch Model and Using Predictors for Inference.
To create a BatchPredictor, call BatchPredictor.from_checkpoint and pass the checkpoint you created in Creating checkpoints.
import tensorflow as tf
from ray.train.batch_predictor import BatchPredictor
from ray.train.tensorflow import TensorflowPredictor
# Callable handed to the predictor as `model_definition`.
model_definition = tf.keras.applications.resnet50.ResNet50
predictor = BatchPredictor.from_checkpoint(
    checkpoint,
    TensorflowPredictor,
    model_definition=model_definition,
)
predictor.predict(
    dataset,
    feature_columns=["image"],
    keep_columns=["label"],
)
For more information, read Using Predictors for Inference.
Serving vision models#
PredictorDeployment lets you deploy a model to an endpoint and make predictions over the Internet.
Deployments use HTTP adapters to define how HTTP messages are converted to model inputs. For example, json_to_ndarray() converts HTTP messages like this:
{"array": [[1, 2], [3, 4]]}
To NumPy ndarrays like this:
array([[1., 2.],
[3., 4.]])
To deploy a Torch model to an endpoint, pass the checkpoint you created in Creating checkpoints to PredictorDeployment.bind and specify json_to_ndarray() as the HTTP adapter.
from ray import serve
from ray.serve import PredictorDeployment
from ray.serve.http_adapters import json_to_ndarray
from ray.train.torch import TorchPredictor
# Bind the predictor class, checkpoint, and HTTP adapter, then start serving.
torch_deployment = PredictorDeployment.bind(
    TorchPredictor,
    checkpoint,
    http_adapter=json_to_ndarray,
)
serve.run(torch_deployment)
Then, make a request to classify an image.
from io import BytesIO
import numpy as np
import requests
from PIL import Image
# Download a sample image and decode it from the response bytes.
image_response = requests.get("http://placekitten.com/200/300")
image = Image.open(BytesIO(image_response.content))

# json_to_ndarray payload: the pixel values as nested lists plus a dtype.
pixels = np.array(image).tolist()
payload = {"array": pixels, "dtype": "float32"}

response = requests.post("http://localhost:8000/", json=payload)
predictions = response.json()
For more in-depth examples, read Training a Torch Image Classifier and Deploying Predictors with Serve.
To deploy a TensorFlow model to an endpoint, pass the checkpoint you created in Creating checkpoints to PredictorDeployment.bind and specify json_to_multi_ndarray() as the HTTP adapter.
import tensorflow as tf
from ray import serve
from ray.serve import PredictorDeployment
from ray.serve.http_adapters import json_to_multi_ndarray
from ray.train.tensorflow import TensorflowPredictor
# Bind the predictor class, checkpoint, HTTP adapter, and model definition,
# then start serving.
tensorflow_deployment = PredictorDeployment.bind(
    TensorflowPredictor,
    checkpoint,
    http_adapter=json_to_multi_ndarray,
    model_definition=tf.keras.applications.resnet50.ResNet50,
)
serve.run(tensorflow_deployment)
Then, make a request to classify an image.
from io import BytesIO
import numpy as np
import requests
from PIL import Image
# Download a sample image and decode it from the response bytes.
image_response = requests.get("http://placekitten.com/200/300")
image = Image.open(BytesIO(image_response.content))

# json_to_multi_ndarray payload: one named entry ("image") wrapping the
# pixel values as nested lists plus a dtype.
pixels = np.array(image).tolist()
payload = {"image": {"array": pixels, "dtype": "float32"}}

response = requests.post("http://localhost:8000/", json=payload)
predictions = response.json()
For more information, read Deploying Predictors with Serve.