Distributed training#
This tutorial executes a distributed training workload that connects the following heterogeneous workloads:
preprocessing the dataset prior to training
distributed training with Ray Train and PyTorch, with observability
evaluation (batch inference and evaluation logic)
saving model artifacts to a model registry (MLOps)
Note: this tutorial doesn't tune the model, but see Ray Tune for experiment execution and hyperparameter tuning at any scale.

%%bash
pip install -q "matplotlib==3.10.0" "torch==2.7.0" "transformers==4.52.3" "scikit-learn==1.6.0" "mlflow==2.19.0" "ipywidgets==8.1.3"
Successfully registered `matplotlib, torch` and 4 other packages to be installed on all cluster nodes.
View and update dependencies here: https://console.anyscale.com/cld_kvedZWag2qA8i5BjxUevf5i7/prj_cz951f43jjdybtzkx1s5sjgz99/workspaces/expwrk_eys8cskj5aivghbf773dp2vmcd?workspace-tab=dependencies
%load_ext autoreload
%autoreload all
import os
import ray
import sys
sys.path.append(os.path.abspath(".."))
# Enable Ray Train v2. It's too good to wait for public release!
os.environ["RAY_TRAIN_V2_ENABLED"] = "1"
ray.init(
    # Connect to existing ray runtime (from previous notebook if still running).
    address=os.environ.get("RAY_ADDRESS", "auto"),
    runtime_env={
        "env_vars": {"RAY_TRAIN_V2_ENABLED": "1"},
        # working_dir to import doggos (default working_dir=".")
        "working_dir": "../",
    },
)
2025-06-23 14:26:58,662 INFO worker.py:1723 -- Connecting to existing Ray cluster at address: 10.0.52.172:6379...
2025-06-23 14:26:58,674 INFO worker.py:1908 -- Connected to Ray cluster. View the dashboard at https://session-gcwehd9xxjzkv5lxv8lgcdgx2n.i.anyscaleuserdata.com
2025-06-23 14:26:58,721 INFO packaging.py:588 -- Creating a file package for local module '../'.
2025-06-23 14:26:58,781 INFO packaging.py:380 -- Pushing file package 'gcs://_ray_pkg_df54fa2aa282ae62.zip' (13.77MiB) to Ray cluster...
2025-06-23 14:26:58,845 INFO packaging.py:393 -- Successfully pushed file package 'gcs://_ray_pkg_df54fa2aa282ae62.zip'.
%%bash
# This will be removed once Ray Train v2 is enabled by default.
echo "RAY_TRAIN_V2_ENABLED=1" > /home/ray/default/.env
# Load env vars in notebooks.
from dotenv import load_dotenv
load_dotenv()
True
Preprocess#
You need to convert the classes to labels (unique integers) so that you can train a classifier that can correctly predict the class given an input image. But before you do this, apply the same data ingestion and preprocessing as the previous notebook.
def add_class(row):
    row["class"] = row["path"].rsplit("/", 3)[-2]
    return row
# Preprocess data splits.
train_ds = ray.data.read_images("s3://doggos-dataset/train", include_paths=True, shuffle="files")
train_ds = train_ds.map(add_class)
val_ds = ray.data.read_images("s3://doggos-dataset/val", include_paths=True)
val_ds = val_ds.map(add_class)
Define a Preprocessor class that:
creates an embedding for each image. Moving the embedding step outside of the model works because you freeze the embedding layer's weights, so you don't have to recompute embeddings as part of the model's forward pass, which saves unnecessary compute.
converts the classes into labels for the classifier.
While you could do this step as a simple set of operations, organizing it as a class lets you save and load it for inference later (see the round-trip sketch after the class definition below).
def convert_to_label(row, class_to_label):
    if "class" in row:
        row["label"] = class_to_label[row["class"]]
    return row
import json

import numpy as np
from PIL import Image
import torch
from transformers import CLIPModel, CLIPProcessor

from doggos.embed import EmbedImages


class Preprocessor:
    """Preprocessor class."""

    def __init__(self, class_to_label=None):
        self.class_to_label = class_to_label or {}  # Avoid mutable default arguments.
        self.label_to_class = {v: k for k, v in self.class_to_label.items()}

    def fit(self, ds, column):
        self.classes = ds.unique(column=column)
        self.class_to_label = {tag: i for i, tag in enumerate(self.classes)}
        self.label_to_class = {v: k for k, v in self.class_to_label.items()}
        return self

    def transform(self, ds, concurrency=4, batch_size=64, num_gpus=1):
        ds = ds.map(
            convert_to_label,
            fn_kwargs={"class_to_label": self.class_to_label},
        )
        ds = ds.map_batches(
            EmbedImages,
            fn_constructor_kwargs={
                "model_id": "openai/clip-vit-base-patch32",
                "device": "cuda",
            },
            concurrency=concurrency,
            batch_size=batch_size,
            num_gpus=num_gpus,
            accelerator_type="L4",
        )
        ds = ds.drop_columns(["image"])
        return ds

    def save(self, fp):
        with open(fp, "w") as f:
            json.dump(self.class_to_label, f)
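Here's a minimal sketch of the save and load round trip this class enables. The path and toy mapping below are hypothetical and only illustrate the pattern; the real mapping gets fit from the training data in the next cell.
# Hypothetical path and mapping, for illustration only.
mapping_path = "/mnt/cluster_storage/doggos/tmp/class_to_label.json"
os.makedirs(os.path.dirname(mapping_path), exist_ok=True)
Preprocessor(class_to_label={"basset": 0, "collie": 1}).save(mapping_path)
with open(mapping_path) as f:
    restored_preprocessor = Preprocessor(class_to_label=json.load(f))
print(restored_preprocessor.label_to_class)  # {0: 'basset', 1: 'collie'}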
# Preprocess.
preprocessor = Preprocessor()
preprocessor = preprocessor.fit(train_ds, column="class")
train_ds = preprocessor.transform(ds=train_ds)
val_ds = preprocessor.transform(ds=val_ds)
2025-06-23 14:27:10,597 INFO dataset.py:3048 -- Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.
2025-06-23 14:27:10,599 INFO logging.py:295 -- Registered dataset logger for dataset dataset_65_0
2025-06-23 14:27:10,612 INFO streaming_executor.py:117 -- Starting execution of Dataset dataset_65_0. Full logs are in /tmp/ray/session_2025-06-23_13-49-50_102769_2149/logs/ray-data
2025-06-23 14:27:10,613 INFO streaming_executor.py:118 -- Execution plan of Dataset dataset_65_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> TaskPoolMapOperator[Map(add_class)] -> AllToAllOperator[Aggregate] -> LimitOperator[limit=1]
2025-06-23 14:27:17,996 INFO streaming_executor.py:227 -- ✔️ Dataset dataset_65_0 execution finished in 7.38 seconds
See this extensive guide on data loading and preprocessing for the last-mile preprocessing you need to do prior to training your models. Ray Data also supports performant joins, filters, aggregations, and other operations for the more structured data processing your workloads may need.
import shutil
# Write processed data to cloud storage.
preprocessed_data_path = os.path.join("/mnt/cluster_storage", "doggos/preprocessed_data")
if os.path.exists(preprocessed_data_path):  # Clean up.
    shutil.rmtree(preprocessed_data_path)
preprocessed_train_path = os.path.join(preprocessed_data_path, "preprocessed_train")
preprocessed_val_path = os.path.join(preprocessed_data_path, "preprocessed_val")
train_ds.write_parquet(preprocessed_train_path)
val_ds.write_parquet(preprocessed_val_path)
2025-06-23 14:19:45,048 INFO logging.py:295 -- Registered dataset logger for dataset dataset_40_0
2025-06-23 14:19:45,067 INFO streaming_executor.py:117 -- Starting execution of Dataset dataset_40_0. Full logs are in /tmp/ray/session_2025-06-23_13-49-50_102769_2149/logs/ray-data
2025-06-23 14:19:45,069 INFO streaming_executor.py:118 -- Execution plan of Dataset dataset_40_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> TaskPoolMapOperator[Map(add_class)->Map(convert_to_label)] -> ActorPoolMapOperator[MapBatches(EmbedImages)] -> TaskPoolMapOperator[MapBatches(drop_columns)->Write]
2025-06-23 14:19:45,088 INFO actor_pool_map_operator.py:633 -- Scaling up actor pool by 4 (reason=scaling to min size, running=0, restarting=0, pending=0)
(_MapWorker pid=18628, ip=10.0.102.235) Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
2025-06-23 14:19:57,926 INFO actor_pool_map_operator.py:661 -- Scaled down actor pool by 1 (reason=None; running=3, restarting=0, pending=0)
2025-06-23 14:19:58,259 INFO streaming_executor.py:227 -- ✔️ Dataset dataset_40_0 execution finished in 13.19 seconds
2025-06-23 14:19:58,573 INFO dataset.py:4603 -- Data sink Parquet finished. 2880 rows and 5.9MB data written.
2025-06-23 14:19:58,584 INFO logging.py:295 -- Registered dataset logger for dataset dataset_43_0
2025-06-23 14:19:58,602 INFO streaming_executor.py:117 -- Starting execution of Dataset dataset_43_0. Full logs are in /tmp/ray/session_2025-06-23_13-49-50_102769_2149/logs/ray-data
2025-06-23 14:19:58,603 INFO streaming_executor.py:118 -- Execution plan of Dataset dataset_43_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> TaskPoolMapOperator[Map(add_class)->Map(convert_to_label)] -> ActorPoolMapOperator[MapBatches(EmbedImages)] -> TaskPoolMapOperator[MapBatches(drop_columns)->Write]
2025-06-23 14:19:58,620 INFO actor_pool_map_operator.py:633 -- Scaling up actor pool by 4 (reason=scaling to min size, running=0, restarting=0, pending=0)
(_MapWorker pid=33082, ip=10.0.102.235) Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`. [repeated 4x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
2025-06-23 14:20:07,331 INFO actor_pool_map_operator.py:661 -- Scaled down actor pool by 1 (reason=None; running=3, restarting=0, pending=0)
2025-06-23 14:20:07,854 INFO actor_pool_map_operator.py:661 -- Scaled down actor pool by 1 (reason=None; running=2, restarting=0, pending=0)
2025-06-23 14:20:08,323 INFO streaming_executor.py:227 -- ✔️ Dataset dataset_43_0 execution finished in 9.72 seconds
2025-06-23 14:20:08,372 INFO dataset.py:4603 -- Data sink Parquet finished. 720 rows and 1.5MB data written.
Store the preprocessed data in shared cloud storage to:
keep a record of what the preprocessed data looks like
avoid triggering the entire preprocessing pipeline for every batch the model consumes
avoid materializing the preprocessed data, because you shouldn't force large data to fit in memory
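As a rough sketch of the difference, assuming the cells above have run: materializing pins every block of the dataset in object store memory, while writing Parquet and reading it back keeps execution streaming.
# Not recommended for large datasets: forces all blocks into object store memory.
# materialized_train_ds = train_ds.materialize()

# Preferred: read the written Parquet files back lazily; blocks stream through
# downstream operators on demand.
streamed_train_ds = ray.data.read_parquet(preprocessed_train_path)
print(streamed_train_ds.count())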
Model#
Define the model: a simple two-layer neural network with a softmax applied to the output logits to predict class probabilities. Notice that it's all just base PyTorch and nothing else.
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F


class ClassificationModel(torch.nn.Module):
    def __init__(self, embedding_dim, hidden_dim, dropout_p, num_classes):
        super().__init__()
        # Hyperparameters.
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.dropout_p = dropout_p
        self.num_classes = num_classes
        # Define layers.
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.batch_norm = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_p)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, batch):
        z = self.fc1(batch["embedding"])
        z = self.batch_norm(z)
        z = self.relu(z)
        z = self.dropout(z)
        z = self.fc2(z)
        return z

    @torch.inference_mode()
    def predict(self, batch):
        z = self(batch)
        y_pred = torch.argmax(z, dim=1).cpu().numpy()
        return y_pred

    @torch.inference_mode()
    def predict_probabilities(self, batch):
        z = self(batch)
        y_probs = F.softmax(z, dim=1).cpu().numpy()
        return y_probs

    def save(self, dp):
        Path(dp).mkdir(parents=True, exist_ok=True)
        with open(Path(dp, "args.json"), "w") as fp:
            json.dump({
                "embedding_dim": self.embedding_dim,
                "hidden_dim": self.hidden_dim,
                "dropout_p": self.dropout_p,
                "num_classes": self.num_classes,
            }, fp, indent=4)
        torch.save(self.state_dict(), Path(dp, "model.pt"))

    @classmethod
    def load(cls, args_fp, state_dict_fp, device="cpu"):
        with open(args_fp, "r") as fp:
            model = cls(**json.load(fp))
        model.load_state_dict(torch.load(state_dict_fp, map_location=device))
        return model
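Here's a minimal sketch of the save and load round trip this class supports; the training loop later checkpoints with the same save call. The directory and toy hyperparameters below are hypothetical.
# Hypothetical scratch directory, for illustration only.
tmp_model_dir = "/mnt/cluster_storage/doggos/tmp_model"
toy_model = ClassificationModel(embedding_dim=512, hidden_dim=256, dropout_p=0.3, num_classes=5)
toy_model.save(dp=tmp_model_dir)
restored_model = ClassificationModel.load(
    args_fp=os.path.join(tmp_model_dir, "args.json"),
    state_dict_fp=os.path.join(tmp_model_dir, "model.pt"),
)
print(restored_model.num_classes)  # 5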
# Initialize model.
num_classes = len(preprocessor.classes)
model = ClassificationModel(
embedding_dim=512,
hidden_dim=256,
dropout_p=0.3,
num_classes=num_classes,
)
print(model)
ClassificationModel(
(fc1): Linear(in_features=512, out_features=256, bias=True)
(batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU()
(dropout): Dropout(p=0.3, inplace=False)
(fc2): Linear(in_features=256, out_features=36, bias=True)
)
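As a quick sanity check, you can run a random batch through the model to confirm the output shape (batch_size, num_classes) and that the softmax rows sum to one. This sketch assumes the preceding cells have run; the embeddings are random stand-ins for CLIP outputs.
# Random embeddings standing in for CLIP outputs.
dummy_batch = {"embedding": torch.randn(4, 512)}
model.eval()  # BatchNorm1d needs eval mode (or a batch size > 1) for tiny batches.
with torch.inference_mode():
    logits = model(dummy_batch)
print(logits.shape)                     # torch.Size([4, 36])
print(F.softmax(logits, dim=1).sum(1))  # Each row sums to ~1.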
Batching#
Take a look at a sample batch of data and ensure that tensors have the proper data type.
from ray.train.torch import get_device


def collate_fn(batch):
    dtypes = {"embedding": torch.float32, "label": torch.int64}
    tensor_batch = {}
    for key in dtypes.keys():
        if key in batch:
            tensor_batch[key] = torch.as_tensor(
                batch[key],
                dtype=dtypes[key],
                device=get_device(),
            )
    return tensor_batch
# Sample batch
sample_batch = train_ds.take_batch(batch_size=3)
collate_fn(batch=sample_batch)
2025-06-23 14:27:26,458 INFO logging.py:295 -- Registered dataset logger for dataset dataset_72_0
2025-06-23 14:27:26,469 INFO streaming_executor.py:117 -- Starting execution of Dataset dataset_72_0. Full logs are in /tmp/ray/session_2025-06-23_13-49-50_102769_2149/logs/ray-data
2025-06-23 14:27:26,470 INFO streaming_executor.py:118 -- Execution plan of Dataset dataset_72_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> TaskPoolMapOperator[Map(add_class)->Map(convert_to_label)] -> ActorPoolMapOperator[MapBatches(EmbedImages)] -> TaskPoolMapOperator[MapBatches(drop_columns)] -> LimitOperator[limit=3]
2025-06-23 14:27:26,489 INFO actor_pool_map_operator.py:633 -- Scaling up actor pool by 4 (reason=scaling to min size, running=0, restarting=0, pending=0)
(_MapWorker pid=18053, ip=10.0.90.122) Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
2025-06-23 14:27:33,774 INFO streaming_executor.py:227 -- ✔️ Dataset dataset_72_0 execution finished in 7.30 seconds
/tmp/ipykernel_18629/3214280880.py:6: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:203.)
tensor_batch[key] = torch.as_tensor(
{'embedding': tensor([[-0.1921, 0.1182, -0.1963, ..., 0.7892, -0.2841, -0.0829],
[-0.0389, -0.1284, -0.5749, ..., 0.4360, 0.0745, -0.1555],
[-0.1139, 0.1539, -0.1519, ..., 0.8438, 0.3064, -0.1918]]),
'label': tensor([22, 11, 33])}
(autoscaler +35s) Tip: use `ray status` to view detailed cluster status. To disable these messages, set RAY_SCHEDULER_EVENTS=0.
Model registry#
Create a model registry in Anyscale cluster storage to save the model checkpoints to. This tutorial uses OSS MLflow, but you can easily set up other experiment trackers with Ray.
(autoscaler +57m1s) [autoscaler] Downscaling node i-03a133888407b8cf8 (node IP: 10.0.103.152) due to node idle termination.
(autoscaler +57m1s) [autoscaler] Downscaling node i-06023e83fb012b7ae (node IP: 10.0.90.122) due to node idle termination.
(autoscaler +57m6s) [autoscaler] Cluster resized to {56 CPU, 6 GPU}.
import shutil

model_registry = "/mnt/cluster_storage/mlflow/doggos"
if os.path.isdir(model_registry):
    shutil.rmtree(model_registry)  # Clean up.
os.makedirs(model_registry, exist_ok=True)
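Optionally, a quick sanity check (a sketch, not part of the original workflow): point MLflow at the file-backed registry and list its experiments. On a fresh registry, this should show only MLflow's Default experiment.
import mlflow

mlflow.set_tracking_uri(f"file:{model_registry}")
print([e.name for e in mlflow.search_experiments()])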
Training#
Define the training workload by specifying the:
experiment and model parameters
compute scaling configuration
forward pass for batches of training and validation data
train loop for each epoch of data and checkpointing

# Train loop config.
experiment_name = "doggos"
train_loop_config = {
    "model_registry": model_registry,
    "experiment_name": experiment_name,
    "embedding_dim": 512,
    "hidden_dim": 256,
    "dropout_p": 0.3,
    "lr": 1e-3,
    "lr_factor": 0.8,
    "lr_patience": 3,
    "num_epochs": 20,
    "batch_size": 256,
}

# Scaling config.
num_workers = 4
scaling_config = ray.train.ScalingConfig(
    num_workers=num_workers,
    use_gpu=True,
    resources_per_worker={"CPU": 8, "GPU": 2},
    accelerator_type="L4",
)
import tempfile

import mlflow
import numpy as np
from ray.train.torch import TorchTrainer


def train_epoch(ds, batch_size, model, num_classes, loss_fn, optimizer):
    model.train()
    loss = 0.0
    ds_generator = ds.iter_torch_batches(batch_size=batch_size, collate_fn=collate_fn)
    for i, batch in enumerate(ds_generator):
        optimizer.zero_grad()  # Reset gradients.
        z = model(batch)  # Forward pass.
        targets = F.one_hot(batch["label"], num_classes=num_classes).float()
        J = loss_fn(z, targets)  # Define loss.
        J.backward()  # Backward pass.
        optimizer.step()  # Update weights.
        loss += (J.detach().item() - loss) / (i + 1)  # Cumulative loss.
    return loss


def eval_epoch(ds, batch_size, model, num_classes, loss_fn):
    model.eval()
    loss = 0.0
    y_trues, y_preds = [], []
    ds_generator = ds.iter_torch_batches(batch_size=batch_size, collate_fn=collate_fn)
    with torch.inference_mode():
        for i, batch in enumerate(ds_generator):
            z = model(batch)
            targets = F.one_hot(batch["label"], num_classes=num_classes).float()  # One-hot (for loss_fn).
            J = loss_fn(z, targets).item()
            loss += (J - loss) / (i + 1)
            y_trues.extend(batch["label"].cpu().numpy())
            y_preds.extend(torch.argmax(z, dim=1).cpu().numpy())
    return loss, np.vstack(y_trues), np.vstack(y_preds)
def train_loop_per_worker(config):
    # Hyperparameters.
    model_registry = config["model_registry"]
    experiment_name = config["experiment_name"]
    embedding_dim = config["embedding_dim"]
    hidden_dim = config["hidden_dim"]
    dropout_p = config["dropout_p"]
    lr = config["lr"]
    lr_factor = config["lr_factor"]
    lr_patience = config["lr_patience"]
    num_epochs = config["num_epochs"]
    batch_size = config["batch_size"]
    num_classes = config["num_classes"]

    # Experiment tracking.
    if ray.train.get_context().get_world_rank() == 0:
        mlflow.set_tracking_uri(f"file:{model_registry}")
        mlflow.set_experiment(experiment_name)
        mlflow.start_run()
        mlflow.log_params(config)

    # Datasets.
    train_ds = ray.train.get_dataset_shard("train")
    val_ds = ray.train.get_dataset_shard("val")

    # Model.
    model = ClassificationModel(
        embedding_dim=embedding_dim,
        hidden_dim=hidden_dim,
        dropout_p=dropout_p,
        num_classes=num_classes,
    )
    model = ray.train.torch.prepare_model(model)

    # Training components.
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=lr_factor,
        patience=lr_patience,
    )

    # Training.
    best_val_loss = float("inf")
    for epoch in range(num_epochs):
        # Steps.
        train_loss = train_epoch(train_ds, batch_size, model, num_classes, loss_fn, optimizer)
        val_loss, _, _ = eval_epoch(val_ds, batch_size, model, num_classes, loss_fn)
        scheduler.step(val_loss)

        # Checkpoint (metrics, preprocessor, and model artifacts).
        with tempfile.TemporaryDirectory() as dp:
            model.module.save(dp=dp)
            metrics = dict(lr=optimizer.param_groups[0]["lr"], train_loss=train_loss, val_loss=val_loss)
            with open(os.path.join(dp, "class_to_label.json"), "w") as fp:
                json.dump(config["class_to_label"], fp, indent=4)
            if ray.train.get_context().get_world_rank() == 0:  # Only on main worker 0.
                mlflow.log_metrics(metrics, step=epoch)
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    mlflow.log_artifacts(dp)

    # End experiment tracking.
    if ray.train.get_context().get_world_rank() == 0:
        mlflow.end_run()
Notice that there isn't much new Ray Train code on top of the base PyTorch code: you specify how to scale out the training workload, load the Ray Dataset shards, and checkpoint on the main worker, and that's it. See these guides (PyTorch, PyTorch Lightning, Hugging Face Transformers) for the minimal code changes needed to distribute your training workloads, and see this extensive list of Ray Train user guides.
# Load preprocessed datasets.
preprocessed_train_ds = ray.data.read_parquet(preprocessed_train_path)
preprocessed_val_ds = ray.data.read_parquet(preprocessed_val_path)

# Trainer.
train_loop_config["class_to_label"] = preprocessor.class_to_label
train_loop_config["num_classes"] = len(preprocessor.class_to_label)
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    datasets={"train": preprocessed_train_ds, "val": preprocessed_val_ds},
)

# Train.
results = trainer.fit()
Ray Train#
Automatically handles multi-node, multi-GPU setup with no manual SSH setup or hostfile configs.
Lets you define per-worker fractional resource requirements, for example, 2 CPUs and 0.5 GPU per worker (see the sketch after this list).
Runs on heterogeneous machines and scales flexibly, for example, CPUs for preprocessing and GPUs for training.
Provides built-in fault tolerance with retries of failed workers and the ability to continue from the last checkpoint.
Supports Data Parallel, Model Parallel, Parameter Server, and even custom strategies.
Ray Compiled Graphs even allow you to define different parallelism for jointly optimizing multiple models (Megatron, DeepSpeed, etc., only allow one global setting).
You can also use Torch DDP, FSDP, DeepSpeed, etc., under the hood.
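For example, here's a minimal sketch of a fractional-resource scaling config, not the one this tutorial uses, where two workers share a single GPU:
# Hypothetical alternative to the scaling config above: 0.5 GPU per worker.
fractional_scaling_config = ray.train.ScalingConfig(
    num_workers=2,
    use_gpu=True,
    resources_per_worker={"CPU": 2, "GPU": 0.5},
)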
🔥 RayTurbo Train offers even more improvements to the price-performance ratio, performance monitoring, and more:
elastic training to scale to a dynamic number of workers and continue training on fewer resources, even on spot instances.
purpose-built dashboard designed to streamline the debugging of Ray Train workloads:
Monitoring: View the status of training runs and train workers.
Metrics: See insights on training throughput and training system operation time.
Profiling: Investigate bottlenecks, hangs, or errors from individual training worker processes.

You can view experiment metrics and model artifacts in the model registry. You’re using OSS MLflow so you can run the server by pointing to the model registry location:
mlflow server -h 0.0.0.0 -p 8080 --backend-store-uri /mnt/cluster_storage/mlflow/doggos
You can view the dashboard by going to the Overview tab > Open Ports.

You also have the Ray Dashboard and the Train workload-specific dashboards covered earlier.

# Sorted runs.
mlflow.set_tracking_uri(f"file:{model_registry}")
sorted_runs = mlflow.search_runs(
    experiment_names=[experiment_name],
    order_by=["metrics.val_loss ASC"],
)
best_run = sorted_runs.iloc[0]
best_run
run_id c65d5aba186c4ee58bf8188493cd047c
experiment_id 477478897635232497
status FINISHED
artifact_uri file:///mnt/cluster_storage/mlflow/doggos/4774...
start_time 2025-06-23 14:23:03.775000+00:00
end_time 2025-06-23 14:23:21.440000+00:00
metrics.train_loss 0.388298
metrics.lr 0.001
metrics.val_loss 0.664968
params.batch_size 256
params.num_epochs 20
params.lr 0.001
params.hidden_dim 256
params.experiment_name doggos
params.dropout_p 0.3
params.embedding_dim 512
params.lr_patience 3
params.class_to_label {'doberman': 0, 'collie': 1, 'dingo': 2, 'pome...
params.lr_factor 0.8
params.model_registry /mnt/cluster_storage/mlflow/doggos
params.num_classes 36
tags.mlflow.source.name /home/ray/anaconda3/lib/python3.12/site-packag...
tags.mlflow.user ray
tags.mlflow.source.type LOCAL
tags.mlflow.runName abrasive-newt-588
Name: 0, dtype: object
Production Job#
You can easily wrap the training workload as a production-grade Anyscale Job (API ref).
Note:
This Job uses a containerfile to define dependencies, but you could easily use a pre-built image as well.
You can specify the compute as a compute config or inline in a job config file.
When you don't specify compute while launching from a workspace, the Job defaults to the compute configuration of the workspace.
# Production batch job.
anyscale job submit --name=train-doggos-model \
--containerfile="/home/ray/default/containerfile" \
--compute-config="/home/ray/default/configs/aws.yaml" \
--working-dir="/home/ray/default" \
--exclude="" \
--max-retries=0 \
-- python doggos/train.py

Evaluation#
This tutorial concludes by evaluating the trained model on the test dataset. Evaluation is essentially the same as the batch inference workload: apply the model to batches of data and then calculate metrics from the predictions versus the true labels. Ray Data is heavily optimized for throughput, so preserving row order isn't a priority, but for evaluation you need each prediction matched to its true label. Achieve this by preserving the entire row and adding the predicted label as another column to each row.
from urllib.parse import urlparse

from sklearn.metrics import multilabel_confusion_matrix


class TorchPredictor:
    def __init__(self, preprocessor, model):
        self.preprocessor = preprocessor
        self.model = model
        self.model.eval()

    def __call__(self, batch, device="cuda"):
        self.model.to(device)
        batch["prediction"] = self.model.predict(collate_fn(batch))
        return batch

    def predict_probabilities(self, batch, device="cuda"):
        self.model.to(device)
        predicted_probabilities = self.model.predict_probabilities(collate_fn(batch))
        batch["probabilities"] = [
            {
                self.preprocessor.label_to_class[i]: float(prob)
                for i, prob in enumerate(probabilities)
            }
            for probabilities in predicted_probabilities
        ]
        return batch

    @classmethod
    def from_artifacts_dir(cls, artifacts_dir):
        with open(os.path.join(artifacts_dir, "class_to_label.json"), "r") as fp:
            class_to_label = json.load(fp)
        preprocessor = Preprocessor(class_to_label=class_to_label)
        model = ClassificationModel.load(
            args_fp=os.path.join(artifacts_dir, "args.json"),
            state_dict_fp=os.path.join(artifacts_dir, "model.pt"),
        )
        return cls(preprocessor=preprocessor, model=model)
# Load and preprocess the eval dataset.
artifacts_dir = urlparse(best_run.artifact_uri).path
predictor = TorchPredictor.from_artifacts_dir(artifacts_dir=artifacts_dir)
test_ds = ray.data.read_images("s3://doggos-dataset/test", include_paths=True)
test_ds = test_ds.map(add_class)
test_ds = predictor.preprocessor.transform(ds=test_ds)

# y_pred (batch inference).
pred_ds = test_ds.map_batches(
    predictor,
    concurrency=4,
    batch_size=64,
    num_gpus=1,
    accelerator_type="L4",
)
pred_ds.take(1)
2025-06-23 14:25:17,471 INFO logging.py:295 -- Registered dataset logger for dataset dataset_56_0
2025-06-23 14:25:17,483 INFO streaming_executor.py:117 -- Starting execution of Dataset dataset_56_0. Full logs are in /tmp/ray/session_2025-06-23_13-49-50_102769_2149/logs/ray-data
2025-06-23 14:25:17,484 INFO streaming_executor.py:118 -- Execution plan of Dataset dataset_56_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> TaskPoolMapOperator[Map(add_class)->Map(convert_to_label)] -> ActorPoolMapOperator[MapBatches(EmbedImages)] -> TaskPoolMapOperator[MapBatches(drop_columns)] -> TaskPoolMapOperator[MapBatches(TorchPredictor)] -> LimitOperator[limit=1]
2025-06-23 14:25:17,504 INFO actor_pool_map_operator.py:633 -- Scaling up actor pool by 4 (reason=scaling to min size, running=0, restarting=0, pending=0)
(_MapWorker pid=41895, ip=10.0.102.235) Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
(MapBatches(TorchPredictor) pid=7131, ip=10.0.90.122) /tmp/ipykernel_14938/3214280880.py:6: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:203.)
(_MapWorker pid=6304, ip=10.0.90.122) Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`. [repeated 3x across cluster]
2025-06-23 14:25:31,572 INFO streaming_executor.py:227 -- ✔️ Dataset dataset_56_0 execution finished in 14.08 seconds
[{'path': 'doggos-dataset/test/basset/basset_10288.jpg',
'class': 'basset',
'label': 26,
'embedding': array([-1.04914151e-01, -2.44789988e-01, -9.95982289e-02, 1.35369569e-01,
-5.52587211e-02, -5.80722839e-02, 1.91796571e-01, 1.56359702e-01,
-6.07913733e-01, 2.08769619e-01, -3.80898006e-02, -1.11314066e-01,
-1.96144834e-01, -6.14988208e-02, 5.18053114e-01, 2.08482340e-01,
1.18680000e+00, 2.00228021e-01, -2.38505289e-01, 7.44116083e-02,
-1.17921010e-01, 1.65986642e-02, 4.06986564e-01, 1.73043087e-02,
-7.19358325e-02, -2.49894068e-01, 5.69958836e-02, -2.07780451e-02,
-2.98084527e-01, -1.49073690e-01, 2.44870782e-02, 4.86774921e-01,
3.78374428e-01, -2.37518042e-01, 1.26714706e-01, 1.10405624e-01,
1.23483673e-01, -2.53296018e-01, -1.41814440e-01, 1.88360083e+00,
-4.67942834e-01, -1.71202213e-01, 2.93785512e-01, 9.53243077e-02,
-1.08036891e-01, -1.05388820e+00, 2.12952226e-01, 3.43122423e-01,
-9.08568352e-02, -6.02110699e-02, 1.57682300e-02, 1.13998428e-01,
-9.61582065e-02, 1.91040933e-01, 3.62998173e-02, -1.67396963e-02,
4.08946127e-01, 4.58516389e-01, -4.09091681e-01, -3.85877311e-01,
9.77702141e-01, -1.69139802e-02, 1.93179488e-01, 1.36374593e-01,
-2.66537070e-01, -6.00859582e-01, -5.44146113e-02, 1.52056739e-01,
-2.88875699e-01, 2.30367318e-01, 6.66391551e-02, -3.48750651e-01,
1.32896990e-01, 2.43517846e-01, -3.36779654e-03, 2.86127269e-01,
-3.56745601e-01, -1.14945844e-01, 1.51565939e-01, 4.90366817e-02,
7.63746500e-02, -2.27382034e-02, 2.54388422e-01, -5.34341276e-01,
3.07917655e-01, 4.43625525e-02, 3.23391706e-02, -3.16016555e-01,
3.49402249e-01, 1.40896916e-01, -3.93401146e-01, -6.98464215e-01,
-7.05318165e+00, -9.64104384e-02, -1.29345521e-01, 1.01153195e-01,
1.66721642e-03, 2.46858150e-01, -6.62657797e-01, 8.84700537e-01,
-2.41105676e-01, -1.67729586e-01, -2.76175410e-01, -1.06329188e-01,
4.68529433e-01, -2.96109051e-01, 5.00090122e-01, -1.51693597e-02,
1.84735969e-01, -4.76171166e-01, 2.78874516e-01, -7.43267417e-01,
3.29548061e-01, 9.67882574e-03, -2.46126920e-01, -2.13637024e-01,
-5.42725086e-01, 3.51180196e-01, -2.11806729e-01, 3.27730656e-01,
1.95189789e-01, 1.26086920e-01, 6.48027122e-01, 2.56954640e-01,
4.22701418e-01, -2.30529577e-01, -1.10486835e-01, -1.01444468e-01,
7.89555907e-03, -2.47240350e-01, 1.73558876e-01, 3.03944647e-01,
-5.77825531e-02, 9.45507646e-01, -4.95145559e-01, 2.86680222e-01,
-7.24357292e-02, -8.29979897e-01, 4.94338155e-01, 2.54262447e-01,
2.29299828e-01, -2.25470066e-02, 5.62191963e-01, 3.00550222e-01,
-2.83117369e-02, 3.84202749e-01, 2.89719075e-01, 3.54923964e-01,
2.66314894e-01, -3.58392656e-01, -3.72334182e-01, 5.86691260e-01,
-1.24578431e-01, -4.04101044e-01, -5.07451952e-01, 5.48313916e-01,
-3.14691275e-01, -1.80745274e-01, 2.89481759e-01, 5.75179756e-02,
-1.80967286e-01, 9.15101022e-02, 4.65520680e-01, 7.72555918e-02,
2.23801851e-01, -1.68022275e-01, 1.34750500e-01, 2.97952116e-01,
2.26987794e-01, 3.05612266e-01, 8.25502351e-02, 1.27266854e-01,
4.45461750e-01, 4.75219965e-01, 2.56610662e-02, -4.94095474e-01,
6.80846751e-01, 6.35496229e-02, 2.54889160e-01, -1.44209296e-01,
-5.48627734e-01, 3.29704136e-02, 4.15674299e-02, -2.43748799e-02,
-2.19443023e-01, -1.42820716e-01, -2.50694096e-01, -2.07656205e-01,
-1.79199561e-01, 3.50940913e-01, 6.33473039e-01, 3.80550534e-01,
-2.89176375e-01, 2.02112049e-01, -4.48559523e-01, 2.72922575e-01,
2.24376589e-01, -2.83806473e-01, -4.37651068e-01, -9.45880890e-01,
1.22266248e-01, 4.01376486e-02, 3.55452418e-01, 2.14725018e-01,
-3.82868618e-01, -3.58605623e-01, 1.33403972e-01, 3.17366868e-02,
8.55787545e-02, 8.59863982e-02, 9.54705626e-02, -3.47019404e-01,
-7.17684031e-02, 2.91243881e-01, 2.65088528e-01, -9.42258835e-02,
-1.77515849e-01, 2.28757620e-01, 9.07460928e-01, -1.03129521e-01,
7.33332276e-01, 2.64944017e-01, -1.47793442e-01, 3.05287898e-01,
-2.62915194e-01, 1.97677180e-01, 6.06525466e-02, -1.16444737e-01,
7.31713697e-03, 1.67819709e-01, 9.79746133e-02, 1.47581011e-01,
-4.00336832e-01, 4.21648145e-01, -8.30136314e-02, -6.39808178e-01,
-1.41640380e-01, 4.65202779e-02, 7.18399584e-02, -4.38913584e-01,
2.07775518e-01, 4.70566414e-02, -8.90242606e-02, -4.53150421e-01,
-2.14878619e-01, 2.44945884e-01, 3.16962540e-01, -3.41699839e-01,
-1.91379115e-01, -2.09521651e-02, 2.30608553e-01, 3.33673239e-01,
2.77272910e-01, -2.96298712e-01, 1.22105137e-01, -2.16433048e-01,
5.48319101e-01, 2.72968113e-01, 1.73093528e-01, 1.80758208e-01,
-3.40644240e-01, 2.62541264e-01, 1.24807566e-01, -7.05128908e-01,
-1.10303462e-02, -1.81341395e-01, -1.78187087e-01, 1.32017612e-01,
-4.31975611e-02, 3.50797176e-03, 1.59508839e-01, 9.21480432e-02,
4.54917192e-01, 2.72805333e-01, -5.77595115e-01, -2.87324011e-01,
1.66138291e-01, 8.66501480e-02, 9.02174413e-03, -3.78495932e-01,
-3.07204783e-01, 1.98499486e-02, -2.17410654e-01, -3.29564735e-02,
-9.36664641e-03, 1.02078244e-01, -5.64144492e-01, 2.59325683e-01,
-1.29754335e-01, 1.67371452e-01, 3.65311772e-01, 1.91542730e-02,
-1.80281848e-01, -1.50442168e-01, 3.04976612e-01, 3.71464863e-02,
1.42819434e-02, 1.84083462e-01, 2.46860430e-01, 1.05640769e-01,
4.84380722e-02, -3.53347808e-02, -4.98287007e-02, 2.02643886e-01,
-1.73173457e-01, -3.63763243e-01, -2.20462531e-01, 3.16181600e-01,
6.26130402e-02, 7.24823922e-02, -1.47105128e-01, 3.08875024e-01,
9.42751825e-01, 1.98151171e-02, -1.21707544e-02, -2.04986826e-01,
2.55928785e-01, -9.34749842e-02, -1.57368124e-01, -9.39193606e-01,
7.99043655e-01, 7.17637539e-01, -3.75674933e-01, 5.69818616e-01,
-1.33306235e-02, 5.30459285e-01, -5.34143746e-01, 2.46586412e-01,
-1.07142270e-01, 3.60272974e-02, -2.97878295e-01, -4.83343840e-01,
6.04178667e-01, -5.00948548e-01, 3.49492311e-01, 2.63357386e-02,
9.19313729e-02, 4.02335197e-01, 1.58837855e-01, -6.79962993e-01,
-2.58434951e-01, -4.40313041e-01, 3.03083509e-01, 3.24987084e-01,
5.39690614e-01, 5.20520747e-01, 4.50525880e-01, 4.25642878e-01,
-3.66918445e-01, 3.89405370e-01, -1.27459884e+00, 1.07019678e-01,
-2.60990173e-01, -1.43924609e-01, 7.54836053e-02, 9.26972032e-01,
3.27434987e-01, -1.17758155e+00, 1.98659331e-01, -2.22037435e-02,
7.09707081e-01, 2.66087234e-01, 1.21972881e-01, 3.83028030e-01,
-7.28927612e-01, 2.53533423e-01, -4.85364050e-01, -2.49552578e-01,
-6.45122454e-02, -7.29703009e-01, 4.32397306e-01, 2.20177278e-01,
2.00846434e-01, -9.86097157e-02, -1.90976754e-01, 2.79123753e-01,
1.66312551e+00, 4.78211313e-01, -2.51018330e-02, 2.72021592e-01,
7.38141775e-01, -1.70819223e-01, 8.71482790e-02, 5.43940544e-01,
1.69077605e-01, -3.87216598e-01, -2.42075190e-01, 2.69218534e-01,
3.44690025e-01, -8.90391588e-01, -7.69253790e-01, -3.58836114e-01,
5.44936597e-01, -5.26414633e-01, -7.02109337e-02, -9.80197862e-02,
1.44381337e-02, 2.74508834e-01, -2.26176381e-01, -4.58218932e-01,
-1.67408079e-01, 9.71819162e-02, -4.52373654e-01, 2.12075204e-01,
3.00378114e-01, -4.85782117e-01, -8.94452184e-02, -3.76136094e-01,
6.35548115e-01, -5.96615791e-01, 4.56892580e-01, 8.58041495e-02,
-4.65728045e-01, 2.77835429e-02, 3.81691009e-02, -2.30244100e-01,
2.88146824e-01, 4.18678313e-01, 2.95979947e-01, -3.73036146e-01,
2.28022650e-01, 3.33540946e-01, -1.05593085e-01, -3.15681905e-01,
-1.58446252e-01, -1.87164396e-01, -2.52391577e-01, -2.95362055e-01,
8.43314469e-01, 1.14071526e-01, -2.23938376e-02, 1.09957650e-01,
-3.88728201e-01, 1.39827147e-01, 2.20899284e-03, -1.90839812e-01,
-9.09137726e-01, 1.57145649e-01, -1.39061660e-02, -2.81439349e-02,
1.31379187e-01, 1.93342119e-02, -3.97078514e-01, 4.37840447e-02,
5.70612431e-01, -3.71424943e-01, 1.27987966e-01, -1.53837383e-01,
-1.62056446e-01, -2.61603892e-02, -9.74950790e-01, -2.85338938e-01,
1.48266554e-06, -5.19999146e-01, -1.39436916e-01, -1.61675125e-01,
2.82035142e-01, 5.65708935e-01, 1.78672537e-01, 2.84627140e-01,
-1.29202381e-02, -5.35536408e-01, 6.67068288e-02, 1.26034901e-01,
4.77381468e-01, 4.13616210e-01, -8.82375419e-01, 2.16037527e-01,
-7.70060718e-03, -1.17288813e-01, 3.86771172e-01, 3.40055674e-01,
-3.02813143e-01, -2.90828168e-01, -4.41879481e-01, -3.02490562e-01,
1.14623025e-01, 5.78140691e-02, -5.26804924e-01, -1.41756445e-01,
2.43902951e-03, 6.49944693e-02, -2.29362592e-01, -5.48198938e-01,
-7.99068272e-01, -3.52486148e-02, 4.28467467e-02, -5.25768399e-01,
1.63442969e-01, -2.11263120e-01, -6.78404570e-02, -2.00107336e-01,
4.71601546e-01, -4.66121018e-01, 2.91595191e-01, -5.46462014e-02,
-5.07597744e-01, 6.30303860e-01, -7.32594371e-01, 1.00498527e-01,
-7.07668364e-01, -8.52217302e-02, -5.60935438e-02, -1.76870823e-03,
3.38252485e-01, -1.68113291e-01, -1.64995581e-01, 1.30709872e-01,
-9.02270138e-01, 1.71258092e-01, -5.64923435e-02, -2.03939527e-01],
dtype=float32),
'prediction': 26}]
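The predictor also exposes predict_probabilities, which adds a per-class probability dictionary to each row. Here's a hedged sketch, not executed in this tutorial, of running it with the same map_batches pattern:
# Sketch only: per-class probabilities with the same batch inference pattern as above.
probs_ds = test_ds.map_batches(
    predictor.predict_probabilities,
    concurrency=4,
    batch_size=64,
    num_gpus=1,
    accelerator_type="L4",
)
probs_ds.take(1)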
def batch_metric(batch):
    labels = batch["label"]
    preds = batch["prediction"]
    mcm = multilabel_confusion_matrix(labels, preds)
    tn, fp, fn, tp = [], [], [], []
    for i in range(mcm.shape[0]):
        tn.append(mcm[i, 0, 0])  # True negatives.
        fp.append(mcm[i, 0, 1])  # False positives.
        fn.append(mcm[i, 1, 0])  # False negatives.
        tp.append(mcm[i, 1, 1])  # True positives.
    return {"TN": tn, "FP": fp, "FN": fn, "TP": tp}
# Aggregated metrics after processing all batches.
metrics_ds = pred_ds.map_batches(batch_metric)
aggregate_metrics = metrics_ds.sum(["TN", "FP", "FN", "TP"])
# Aggregate the confusion matrix components across all batches.
tn = aggregate_metrics["sum(TN)"]
fp = aggregate_metrics["sum(FP)"]
fn = aggregate_metrics["sum(FN)"]
tp = aggregate_metrics["sum(TP)"]
# Calculate metrics.
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
accuracy = (tp + tn) / (tp + tn + fp + fn)
2025-06-23 14:25:31,814 INFO logging.py:295 -- Registered dataset logger for dataset dataset_59_0
2025-06-23 14:25:31,828 INFO streaming_executor.py:117 -- Starting execution of Dataset dataset_59_0. Full logs are in /tmp/ray/session_2025-06-23_13-49-50_102769_2149/logs/ray-data
2025-06-23 14:25:31,829 INFO streaming_executor.py:118 -- Execution plan of Dataset dataset_59_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> TaskPoolMapOperator[Map(add_class)->Map(convert_to_label)] -> ActorPoolMapOperator[MapBatches(EmbedImages)] -> TaskPoolMapOperator[MapBatches(drop_columns)] -> TaskPoolMapOperator[MapBatches(TorchPredictor)] -> TaskPoolMapOperator[MapBatches(batch_metric)] -> AllToAllOperator[Aggregate] -> LimitOperator[limit=1]
2025-06-23 14:25:31,856 INFO actor_pool_map_operator.py:633 -- Scaling up actor pool by 4 (reason=scaling to min size, running=0, restarting=0, pending=0)
(_MapWorker pid=7186, ip=10.0.90.122) Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
2025-06-23 14:25:43,855 INFO actor_pool_map_operator.py:661 -- Scaled down actor pool by 1 (reason=None; running=3, restarting=0, pending=0)
(MapBatches(TorchPredictor) pid=7259, ip=10.0.90.122) /tmp/ipykernel_14938/3214280880.py:6: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:203.)
(_MapWorker pid=14469, ip=10.0.103.152) Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`. [repeated 3x across cluster]
2025-06-23 14:25:44,370 INFO actor_pool_map_operator.py:661 -- Scaled down actor pool by 1 (reason=None; running=2, restarting=0, pending=0)
2025-06-23 14:25:44,899 INFO actor_pool_map_operator.py:661 -- Scaled down actor pool by 1 (reason=None; running=1, restarting=0, pending=0)
2025-06-23 14:25:45,419 INFO actor_pool_map_operator.py:661 -- Scaled down actor pool by 1 (reason=None; running=0, restarting=0, pending=0)
(MapBatches(TorchPredictor) pid=7393, ip=10.0.90.122) /tmp/ipykernel_14938/3214280880.py:6: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:203.) [repeated 4x across cluster]
2025-06-23 14:26:35,251 INFO streaming_executor.py:227 -- ✔️ Dataset dataset_59_0 execution finished in 63.42 seconds
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")
print(f"F1: {f1:.2f}")
print(f"Accuracy: {accuracy:.2f}")
Precision: 0.84
Recall: 0.84
F1: 0.84
Accuracy: 0.98
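As an optional cross-check (a sketch, not part of the original workflow), pull the labels and predictions into memory and compare against scikit-learn's micro-averaged scores; precision, recall, and F1 should match the values above. This is fine here because the test split is small; avoid it on large datasets.
from sklearn.metrics import precision_recall_fscore_support

# Collect labels and predictions into a small pandas DataFrame.
pred_table = pred_ds.select_columns(["label", "prediction"]).to_pandas()
p, r, f, _ = precision_recall_fscore_support(
    pred_table["label"], pred_table["prediction"], average="micro"
)
print(f"Precision: {p:.2f}, Recall: {r:.2f}, F1: {f:.2f}")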
🚨 Note: Reset this notebook using the "🔄 Restart" button located in the notebook's menu bar. This way you free up all the variables, utils, etc., used in this notebook.