## Collate Utilities

from dataclasses import dataclass
from typing import Dict, List, Tuple, Union
import torch
from ray import cloudpickle as pickle
import pyarrow as pa

# Each serialized tensor is described by a (dtype, shape, offset) triple,
# where `offset` is the tensor's starting byte position in the packed buffer.
FEATURE_TYPE = Tuple[torch.dtype, torch.Size, int]
# Element dtype used to reinterpret tensors as raw bytes when packing them.
TORCH_BYTE_ELEMENT_TYPE = torch.uint8

def _create_binary_array_from_buffer(buffer: bytes) -> pa.BinaryArray:
    """Wrap `buffer` in a length-1 Arrow binary array without copying the data."""
    payload = pa.py_buffer(buffer)
    # A binary array is assembled from three buffers: the validity bitmap
    # (None -> every entry valid), int32 offsets delimiting each value, and
    # the value bytes themselves. Only the tiny offsets array is allocated;
    # the payload is referenced zero-copy.
    offsets = pa.array([0, payload.size], type=pa.int32()).buffers()[1]
    return pa.Array.from_buffers(pa.binary(), 1, [None, offsets, payload])

@dataclass
class _Metadata:
    """Layout description for tensors packed into a single byte buffer."""
    # Per-feature list of (dtype, shape, byte offset) entries, in pack order.
    features: Dict[str, List[FEATURE_TYPE]]
    # Total size in bytes of the packed buffer.
    total_buffer_size: int

@dataclass
class _TensorBatch:
    """Internal class for serializing/deserializing tensor batches."""
    buffer: torch.Tensor
    metadata: _Metadata

    @classmethod
    def from_batch(cls, batch: Dict[str, Union[List[torch.Tensor], torch.Tensor]]) -> '_TensorBatch':
        """Serialize a batch of tensors into a single buffer."""
        features: Dict[str, List[FEATURE_TYPE]] = {}
        flattened_binary_tensors = []
        total_buffer_size = 0
        
        for name, tensors in batch.items():
            features[name] = []
            if not isinstance(tensors, list):
                tensors = [tensors]
            for tensor in tensors:
                flattened_tensor = tensor.flatten().contiguous().view(TORCH_BYTE_ELEMENT_TYPE)
                flattened_binary_tensors.append(flattened_tensor)
                features[name].append((tensor.dtype, tensor.shape, total_buffer_size))
                total_buffer_size += flattened_tensor.shape[0]
        
        buffer = torch.empty(total_buffer_size, dtype=TORCH_BYTE_ELEMENT_TYPE)
        cur_offset = 0
        for flattened_tensor in flattened_binary_tensors:
            buffer[cur_offset:cur_offset + flattened_tensor.shape[0]] = flattened_tensor
            cur_offset += flattened_tensor.shape[0]
        
        return _TensorBatch(
            buffer=buffer,
            metadata=_Metadata(
                features=features,
                total_buffer_size=total_buffer_size,
            ),
        )

    def to_table(self) -> pa.Table:
        """Convert to a single-row PyArrow table."""
        buffer_array = _create_binary_array_from_buffer(self.buffer.numpy().data)
        metadata_array = _create_binary_array_from_buffer(pickle.dumps(self.metadata))
        return pa.Table.from_arrays(
            arrays=[buffer_array, metadata_array],
            names=["_buffer", "_metadata"],
        )

    @classmethod
    def from_table(cls, table: pa.Table) -> '_TensorBatch':
        """Deserialize from a single-row PyArrow table."""
        return _TensorBatch(
            buffer=torch.frombuffer(
                table["_buffer"].chunks[0].buffers()[2],
                dtype=TORCH_BYTE_ELEMENT_TYPE
            ),
            metadata=pickle.loads(table["_metadata"].chunks[0].buffers()[2]),
        )

    def to_batch(self, pin_memory: bool = False) -> Dict[str, List[torch.Tensor]]:
        """Deserialize back to a batch of tensors."""
        batch = {}
        storage_buffer = self.buffer.untyped_storage()
        offsets = []
        for name, features in self.metadata.features.items():
            for _, _, offset in features:
                offsets.append(offset)
        offsets.append(self.metadata.total_buffer_size)
        
        offset_id = 0
        for name, features in self.metadata.features.items():
            batch[name] = []
            for dtype, shape, _ in features:
                # Create a zero-copy view of the byte slice.
                byte_slice = self.buffer[offsets[offset_id]:offsets[offset_id + 1]]
                tensor = torch.frombuffer(
                    byte_slice.numpy().data, dtype=dtype
                ).view(shape)
                if pin_memory:
                    tensor = tensor.pin_memory()
                batch[name].append(tensor)
                offset_id += 1
        return batch

# Helper functions for use in your code
def serialize_tensors_to_table(batch: Dict[str, Union[List[torch.Tensor], torch.Tensor]]) -> pa.Table:
    """Pack a batch of tensors into a single-row PyArrow table."""
    tensor_batch = _TensorBatch.from_batch(batch)
    return tensor_batch.to_table()

def deserialize_table_to_tensors(table: pa.Table, pin_memory: bool = False) -> Dict[str, List[torch.Tensor]]:
    """Unpack a single-row PyArrow table back into a batch of tensors."""
    tensor_batch = _TensorBatch.from_table(table)
    return tensor_batch.to_batch(pin_memory=pin_memory)

## Random Text Generator

The following helper functions generate random text samples with labels:

import random
import string
import ray

def random_text(length: int) -> str:
    """Generate random lowercase text of exactly `length` characters.

    The text is built from space-separated "words" of 3-10 letters.

    Args:
        length: Desired number of characters.

    Returns:
        A string of exactly ``length`` characters (lowercase letters and
        spaces), or "" when ``length <= 0``.
    """
    if length <= 0:
        return ""

    # Too short to bother with word structure: emit raw letters.
    if length <= 3:
        return "".join(random.choices(string.ascii_lowercase, k=length))

    words = []
    current_length = 0  # characters committed so far, counting separators

    while current_length < length:
        remaining = length - current_length

        if remaining <= 4:
            # Finish with a word that exactly fills the remaining space.
            words.append("".join(random.choices(string.ascii_lowercase, k=remaining)))
            break
        # Leave at least one character of budget for the separator; since
        # remaining > 4 here, max_word_length >= 4, so randint is always valid
        # (the original's `word_length = remaining` fallback was dead code).
        max_word_length = min(10, remaining - 1)
        word_length = random.randint(3, max_word_length)
        words.append("".join(random.choices(string.ascii_lowercase, k=word_length)))
        current_length += word_length + 1

    text = " ".join(words)
    # Bug fix: when the last drawn word is exactly `remaining - 1` letters,
    # the loop exits without a closing word and the joined text is one
    # character short. Pad with random letters so the length is always exact.
    if len(text) < length:
        text += "".join(random.choices(string.ascii_lowercase, k=length - len(text)))
    return text[:length]

def random_label() -> int:
    """Pick a random label from 0-7 (uniformly)."""
    # range(8) is a sequence, so random.choice draws the same single
    # _randbelow(8) sample as choosing from an explicit list [0..7].
    return random.choice(range(8))

def create_mock_ray_text_dataset(dataset_size: int = 96, min_len: int = 5, max_len: int = 100):
    """Create a mock Ray dataset with random text and labels.

    Args:
        dataset_size: Number of rows to generate.
        min_len: Minimum text length (inclusive).
        max_len: Maximum text length (inclusive).

    Returns:
        A Ray dataset whose rows have "length", "text", and "label" columns.
    """
    # Draw each row's target text length, then materialize as a Ray dataset.
    lengths = random.choices(range(min_len, max_len + 1), k=dataset_size)
    base_dataset = ray.data.from_items(lengths)

    def to_row(item):
        target_length = item['item']
        return {
            "length": target_length,
            "text": random_text(target_length),
            "label": random_label(),
        }

    return base_dataset.map(to_row)