## Collate Utilities
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union
import torch
from ray import cloudpickle as pickle
import pyarrow as pa
# (dtype, shape, offset)
FEATURE_TYPE = Tuple[torch.dtype, torch.Size, int]
TORCH_BYTE_ELEMENT_TYPE = torch.uint8
def _create_binary_array_from_buffer(buffer: bytes) -> pa.BinaryArray:
"""Zero-copy create a binary array from a buffer."""
data_buffer = pa.py_buffer(buffer)
return pa.Array.from_buffers(
pa.binary(),
1,
[
None,
pa.array([0, data_buffer.size], type=pa.int32()).buffers()[1],
data_buffer,
],
)
@dataclass
class _Metadata:
features: Dict[str, List[FEATURE_TYPE]]
total_buffer_size: int
@dataclass
class _TensorBatch:
    """Internal class for serializing/deserializing tensor batches.

    A batch maps feature names to a tensor (or list of tensors). It is
    packed into one flat uint8 buffer plus a `_Metadata` record describing
    how to slice it back apart, so the whole batch can round-trip through
    a single-row PyArrow table with minimal copying.
    """

    # Flat uint8 tensor holding every serialized tensor's bytes back to back.
    buffer: torch.Tensor
    # Per-tensor (dtype, shape, byte offset) records plus total byte size.
    metadata: _Metadata

    @classmethod
    def from_batch(cls, batch: Dict[str, Union[List[torch.Tensor], torch.Tensor]]) -> '_TensorBatch':
        """Serialize a batch of tensors into a single buffer.

        Args:
            batch: Mapping of feature name to a tensor or list of tensors
                (assumed to live on CPU — TODO confirm against callers).

        Returns:
            A `_TensorBatch` whose buffer holds every tensor's raw bytes
            and whose metadata records how to reconstruct them.
        """
        features: Dict[str, List[FEATURE_TYPE]] = {}
        flattened_binary_tensors = []
        total_buffer_size = 0
        for name, tensors in batch.items():
            features[name] = []
            # Allow a bare tensor as shorthand for a one-element list.
            if not isinstance(tensors, list):
                tensors = [tensors]
            for tensor in tensors:
                # Reinterpret the tensor's data as raw bytes: the flat
                # uint8 view has numel * element_size entries.
                flattened_tensor = tensor.flatten().contiguous().view(TORCH_BYTE_ELEMENT_TYPE)
                flattened_binary_tensors.append(flattened_tensor)
                # Record dtype, shape, and the byte offset at which this
                # tensor's data will start in the packed buffer.
                features[name].append((tensor.dtype, tensor.shape, total_buffer_size))
                total_buffer_size += flattened_tensor.shape[0]
        # Pack all byte views back to back into one contiguous buffer.
        buffer = torch.empty(total_buffer_size, dtype=TORCH_BYTE_ELEMENT_TYPE)
        cur_offset = 0
        for flattened_tensor in flattened_binary_tensors:
            buffer[cur_offset:cur_offset + flattened_tensor.shape[0]] = flattened_tensor
            cur_offset += flattened_tensor.shape[0]
        return _TensorBatch(
            buffer=buffer,
            metadata=_Metadata(
                features=features,
                total_buffer_size=total_buffer_size,
            ),
        )

    def to_table(self) -> pa.Table:
        """Convert to a single-row PyArrow table.

        The table has two binary columns: "_buffer" (the raw tensor
        bytes, wrapped zero-copy) and "_metadata" (the pickled
        `_Metadata`, which is small so pickling it is cheap).
        """
        buffer_array = _create_binary_array_from_buffer(self.buffer.numpy().data)
        metadata_array = _create_binary_array_from_buffer(pickle.dumps(self.metadata))
        return pa.Table.from_arrays(
            arrays=[buffer_array, metadata_array],
            names=["_buffer", "_metadata"],
        )

    @classmethod
    def from_table(cls, table: pa.Table) -> '_TensorBatch':
        """Deserialize from a single-row PyArrow table.

        `buffers()[2]` is the binary array's value buffer (it follows the
        validity and offsets buffers); `torch.frombuffer` wraps it
        without copying.
        """
        return _TensorBatch(
            buffer=torch.frombuffer(
                table["_buffer"].chunks[0].buffers()[2],
                dtype=TORCH_BYTE_ELEMENT_TYPE
            ),
            # NOTE(review): unpickling table data — only safe when the
            # table comes from a trusted producer such as `to_table`.
            metadata=pickle.loads(table["_metadata"].chunks[0].buffers()[2]),
        )

    def to_batch(self, pin_memory: bool = False) -> Dict[str, List[torch.Tensor]]:
        """Deserialize back to a batch of tensors.

        Args:
            pin_memory: If True, pin each reconstructed tensor's memory
                (this copies; the default path yields zero-copy views).

        Returns:
            Mapping of feature name to a list of tensors restored to
            their original dtypes and shapes.
        """
        batch = {}
        # Collect each tensor's start offset in serialization order and
        # terminate with the total size, so every tensor's byte range is
        # offsets[i]:offsets[i + 1].
        offsets = []
        for features in self.metadata.features.values():
            for _, _, offset in features:
                offsets.append(offset)
        offsets.append(self.metadata.total_buffer_size)
        offset_id = 0
        for name, features in self.metadata.features.items():
            batch[name] = []
            for dtype, shape, _ in features:
                # Create a zero-copy view of the byte slice.
                byte_slice = self.buffer[offsets[offset_id]:offsets[offset_id + 1]]
                tensor = torch.frombuffer(
                    byte_slice.numpy().data, dtype=dtype
                ).view(shape)
                if pin_memory:
                    tensor = tensor.pin_memory()
                batch[name].append(tensor)
                offset_id += 1
        return batch
# Helper functions for use in your code
def serialize_tensors_to_table(batch: Dict[str, Union[List[torch.Tensor], torch.Tensor]]) -> pa.Table:
    """Serialize a batch of tensors to a PyArrow table."""
    tensor_batch = _TensorBatch.from_batch(batch)
    return tensor_batch.to_table()
def deserialize_table_to_tensors(table: pa.Table, pin_memory: bool = False) -> Dict[str, List[torch.Tensor]]:
    """Deserialize a PyArrow table back to tensors."""
    tensor_batch = _TensorBatch.from_table(table)
    return tensor_batch.to_batch(pin_memory=pin_memory)
## Random Text Generator
The following helper functions generate random text samples with labels:
import random
import string
import ray
def random_text(length: int) -> str:
    """Generate random lowercase text of exactly *length* characters.

    The text is built from space-separated "words" of random length; very
    short requests (3 characters or fewer) are a plain run of letters.

    Args:
        length: Desired number of characters. Non-positive lengths yield "".

    Returns:
        A string of exactly ``max(length, 0)`` characters drawn from
        lowercase ASCII letters and spaces.
    """
    if length <= 0:
        return ""
    if length <= 3:
        # Too short for word structure; just emit raw letters.
        return "".join(random.choices(string.ascii_lowercase, k=length))
    words = []
    current_length = 0
    while current_length < length:
        remaining = length - current_length
        if remaining <= 4:
            # Finish with a word that exactly fills the remaining space
            # (the separating space was already counted below).
            words.append("".join(random.choices(string.ascii_lowercase, k=remaining)))
            break
        # Cap the word at remaining - 2 so at least one character is left
        # after this word plus its trailing space. This guarantees the
        # loop always terminates via the break above and the result is
        # exactly `length` characters (the old remaining - 1 cap could
        # leave the text one character short). Here remaining >= 5, so
        # remaining - 2 >= 3 and the randint bounds are always valid.
        word_length = random.randint(3, min(10, remaining - 2))
        word = "".join(random.choices(string.ascii_lowercase, k=word_length))
        words.append(word)
        current_length += len(word) + 1  # +1 for the joining space
    return " ".join(words)
def random_label() -> int:
    """Pick a random class label in the range 0..7 (inclusive)."""
    return random.choice(range(8))
def create_mock_ray_text_dataset(dataset_size: int = 96, min_len: int = 5, max_len: int = 100):
    """Create a mock Ray dataset of random text rows.

    Each row carries "length" (the sampled text length), "text" (random
    text of that length), and "label" (a random label).
    """
    lengths = random.choices(range(min_len, max_len + 1), k=dataset_size)

    def build_row(item):
        # Ray wraps plain items under the 'item' key.
        n = item['item']
        return {
            "length": n,
            "text": random_text(n),
            "label": random_label(),
        }

    return ray.data.from_items(lengths).map(build_row)