# IterDataPipe for preprocessing data, tokenizing, and caching to disk
# The general file format we're going with is an Apache Parquet file with one column per tokenizer output field.
# A row is a single doc. Parquet files are efficient column stores, which means we can grab token slices from
# multiple docs in a single operation, which makes concatenation much faster (and means we don't need to cache slices).
# (We might add back in file metadata later? though huggingface deletes it)
# We don't want to have one giant file, so we'll split it up into chunks.
# In general, an IndexedDataset is a directory of parquet files plus a metadata file called the ledger.
# The ledger is a json file with the following structure:
# {
#   "files": [ { "file_name": <name>, "num_tokens": <num_tokens> }, ... ]
# }
# We don't actually use the num_tokens field, but it's useful for sanity checking.
# The ledger is written last, so we can always check to see if we were interrupted.
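# For example, a finished two-file cache might have a ledger like this (hypothetical names and counts):
# {
#   "files": [
#     {"file_name": "cache/docs-0.parquet", "num_tokens": 67108864},
#     {"file_name": "cache/docs-1.parquet", "num_tokens": 1048576}
#   ]
# }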
import json
import logging
import os
from pathlib import Path
from typing import Iterator, Optional, Union
import datasets
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
try:
    from torchdata.datapipes.iter import IterDataPipe
except ImportError:
    from torch.utils.data import IterDataPipe
from tqdm import tqdm
from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizerFast
from src.corpora.tokenization_utils import batch_tokenize, concatenate_and_group_texts
# As a heuristic, we're aiming for files that are roughly 250MB.
# Typically we're training on sequences of length ~1024 and batch size up to 512, so it's better to make the file size divisible by that.
# 4 bytes * 512 * 1024 = 2MiB per step, so we'll go with 128 * 512 * 1024 = 67108864 tokens, which is about 256MiB.
NUM_TOKENS_PER_FILE = 67108864
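# (Sanity check on the arithmetic above: 67108864 = 128 * 512 * 1024 = 64 * 1024 * 1024,
#  i.e. exactly 128 steps of 512 sequences x 1024 tokens, which is 256MiB at 4 bytes per token.)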
overwatch = logging.getLogger("mistral.corpora.indexer")
# TASKS:
# TODO: figure out directory structure for caching multiple sources
# TODO: if we're super careful we can compute the number of samples (for a given batch size and stride) in advance
# if we do that, we can implement a Map-style dataset, which is somewhat preferable when not streaming
# TODO: bring in sprucfluo/simultaneous caching and streaming if we want.
LEDGER_FILE = "ledger.json"

class IndexedDataset(IterDataPipe[BatchEncoding]):
    def __init__(self, cache_dir, seq_len: int, stride: Optional[int] = None):
        self.cache_dir = cache_dir
        self.ledger = self._load_ledger()
        self.seq_len = seq_len
        self.stride = stride

    def _files(self):
        for entry in self.ledger["files"]:
            yield entry["file_name"]

    def __iter__(self):
        for file_name in self._files():
            for entry in read_cache_file(file_name, flatten=True):
                yield from concatenate_and_group_texts(entry, self.seq_len, self.stride)

    @staticmethod
    def build_or_load(
        token_iter: Iterator[BatchEncoding],
        cache_dir: Union[str, os.PathLike],
        seq_len: int,
        stride: Optional[int] = None,
        num_tokens_per_file: int = NUM_TOKENS_PER_FILE,
        file_template: str = "docs-{}.parquet",
    ) -> "IndexedDataset":
        os.makedirs(cache_dir, exist_ok=True)
        ledger_file = os.path.join(cache_dir, LEDGER_FILE)

        if os.path.exists(ledger_file):
            overwatch.info("Found existing indexed dataset at %s", cache_dir)
            return IndexedDataset(cache_dir, seq_len, stride)

        file_index = 0
        current_writer: Optional[pq.ParquetWriter] = None
        current_num_tokens = 0
        tq: tqdm = tqdm(desc=f"file {file_index} progress", total=num_tokens_per_file, unit="token")
        file_out: Optional[Path] = None
        # list of {"file_name": ..., "num_tokens": ...} dicts, written to the ledger if we finish the whole iterator
        ledger_files = []

        def close_writer():
            nonlocal current_writer, file_out, file_index, current_num_tokens
            if current_writer is not None:
                current_writer.close()
                current_writer = None
            if current_num_tokens > 0:
                ledger_files.append({"file_name": str(file_out), "num_tokens": current_num_tokens})

        try:
            for tokens in token_iter:
                batch = _as_record_batch(tokens)
                batch_len = sum(len(t) for t in tokens["input_ids"])
                # roll over to a new file once this batch would push us past the per-file budget
                if current_writer and current_num_tokens + batch_len > num_tokens_per_file:
                    close_writer()
                if not current_writer:
                    file_out = Path(f"{cache_dir}/{file_template.format(file_index)}")
                    file_out.parent.mkdir(parents=True, exist_ok=True)
                    file_index += 1
                    current_writer = pq.ParquetWriter(file_out, batch.schema, version="2.6", compression="ZSTD")
                    current_num_tokens = 0
                    tq.reset()
                    tq.set_description(f"file {file_index} progress")
                current_writer.write_batch(batch)
                current_num_tokens += batch_len
                tq.update(batch_len)

            if current_writer:
                tq.reset(current_num_tokens)
                tq.update(current_num_tokens)
                close_writer()

            # if we successfully wrote the whole iterator, we can write the ledger
            with open(ledger_file, "w") as f:
                ledger = {"files": ledger_files}
                json.dump(ledger, f)
            return IndexedDataset(cache_dir, seq_len, stride)
        except (KeyboardInterrupt, InterruptedError):
            # don't leave behind a partially written file that could be mistaken for a complete one
            if current_writer is not None:
                current_writer.close()
                current_writer = None
            if file_out is not None:
                file_out.unlink(missing_ok=True)
            raise

    def _load_ledger(self):
        ledger_path = os.path.join(self.cache_dir, LEDGER_FILE)
        if os.path.exists(ledger_path):
            with open(ledger_path, "r") as f:
                return json.load(f)
        else:
            raise FileNotFoundError(f"{self.cache_dir} is not a complete cache")


def read_cache_file(file, flatten: bool = False) -> Iterator[BatchEncoding]:
    """Reads a cache file produced by IndexedDataset.build_or_load and yields tokenized sequences.

    If flatten is False, docs are yielded as they were presented to the caching process. If flatten is True,
    each yielded "document" is the concatenation of all docs that were written to the cache as a single batch.
    """
    for b in pq.read_table(file).to_batches():
        if flatten:
            # insert a new axis at the front so the flattened tokens look like a batch of size 1
            yield BatchEncoding(
                {
                    b.field(i).name: b.column(i).values.to_numpy(zero_copy_only=True)[np.newaxis, :]
                    for i in range(b.num_columns)
                }
            )
        else:
            yield BatchEncoding(
                {b.field(i).name: b.column(i).to_numpy(zero_copy_only=False) for i in range(b.num_columns)}
            )
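
# Illustrative shapes for read_cache_file (hypothetical file written from batches of 1000 docs; the exact
# batch<->row-group mapping is a pyarrow detail, so treat this as a sketch rather than a guarantee):
#   next(read_cache_file("docs-0.parquet"))                -> input_ids: array of ~1000 per-doc token arrays
#   next(read_cache_file("docs-0.parquet", flatten=True))  -> input_ids: shape (1, total_tokens_in_that_batch)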

def _as_record_batch(doc: BatchEncoding) -> pa.RecordBatch:
    names, columns = zip(*[(k, pa.array(v)) for k, v in doc.items()])
    return pa.RecordBatch.from_arrays(list(columns), names)
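
# For illustration (hypothetical values): a tokenizer batch like
#   {"input_ids": [[1, 2, 3], [4, 5]], "attention_mask": [[1, 1, 1], [1, 1]]}
# becomes a two-row RecordBatch with list-typed columns named "input_ids" and "attention_mask".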


if __name__ == "__main__":
    tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained("gpt2")
    dataset = datasets.load_dataset("dlwh/wikitext_2_detokenized", split="train")
    token_iter = batch_tokenize(dataset, tokenizer, batch_size=1000)
    indexer = IndexedDataset.build_or_load(token_iter, "cache/wikitext-2-indexed", seq_len=512, stride=None)

    for i, batch in enumerate(indexer):
        print(i, batch)
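
    # Hypothetical spot-check, not part of the original script: read the first cached parquet file
    # back directly with read_cache_file (assumes the ledger's relative paths resolve from the CWD).
    first_file = indexer.ledger["files"][0]["file_name"]
    flat = next(read_cache_file(first_file, flatten=True))
    print(f"{first_file}: first record batch holds {flat['input_ids'].shape[-1]} tokens")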