"""
trainer.py
Custom Hugging Face Trainer that allows for online eval of multiple datasets.
"""
import collections
import logging
import time
from dataclasses import dataclass # type: ignore
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
try:
    from torchdata.datapipes.iter import IterDataPipe
except ImportError:
    from torch.utils.data import IterDataPipe
from transformers import (
    AutoModelForCausalLM,
    BatchEncoding,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainingArguments,
)
from transformers.data.data_collator import DataCollator
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction, speed_metrics
# Nest Overwatch under root `mistral` logger, inheriting formatting!
overwatch = logging.getLogger("mistral.core.trainer")
class OnlineBenchmarkTrainer(Trainer):
    """
    Trainer that handles online evaluation datasets -- e.g., LAMBADA and WikiText-103 perplexity scores.
    Overrides `evaluate` to trigger evaluation on each online dataset.
    """

    control: Any
    _globalstep_last_logged: int

    def __init__(
        self,
        model: AutoModelForCausalLM,
        args: TrainingArguments,
        data_collator: Optional[DataCollator] = None,
        dataset_name: str = "unknown",
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        custom_eval_datasets: Optional[Dict[str, Dataset]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    ):
        super(OnlineBenchmarkTrainer, self).__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
        )
        self.dataset_name = dataset_name
        custom_eval_datasets = custom_eval_datasets if custom_eval_datasets is not None else {}

        # The eval datasets can't be stored in a plain dict -- they must be kept as separate attributes. This is
        # likely related to how `nn.Module` needs a `ModuleDict` (not a dict) for members to work with distributed models.
        self.wikitext_dataset = custom_eval_datasets.get("wikitext", None)
        self.lambada_dataset = custom_eval_datasets.get("lambada", None)
    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        eval_ppl_datasets: bool = True,
    ) -> Dict[str, float]:
        # Normal evaluate -- this calls the `on_evaluate` callback
        metrics = super(OnlineBenchmarkTrainer, self).evaluate(eval_dataset, ignore_keys, metric_key_prefix)

        # Create new metrics dictionary --> TODO trainer.A :: Fix so doesn't explicitly assume OpenWebText
        metrics = {
            f"eval_{self.dataset_name}_loss": metrics["eval_loss"],
            f"eval_{self.dataset_name}_ppl": np.exp(metrics["eval_loss"]),
            f"eval_{self.dataset_name}_runtime": metrics["eval_runtime"],
            f"eval_{self.dataset_name}_samples_per_second": metrics["eval_samples_per_second"],
            "epoch": metrics.get("epoch"),
        }
        self.log(metrics)

        if not eval_ppl_datasets:
            return metrics

        # Start memory tracker
        self._memory_tracker.start()

        # Iterate over each online evaluation dataset -- store new metrics for the control call
        new_dataset_metrics = {}
        if self.wikitext_dataset is not None:
            output_metrics = self.single_dataset_eval("wikitext", self.wikitext_dataset, metric_key_prefix)
            new_dataset_metrics.update(output_metrics)
            self.log(output_metrics)

        if self.lambada_dataset is not None:
            output_metrics = self.single_dataset_eval("lambada", self.lambada_dataset, metric_key_prefix)
            new_dataset_metrics.update(output_metrics)
            self.log(output_metrics)

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, new_dataset_metrics)
        self._memory_tracker.stop_and_update_metrics(new_dataset_metrics)

        metrics.update(new_dataset_metrics)
        return metrics
    def single_dataset_eval(self, dataset_name: str, dataset: Dataset, metric_key_prefix: str) -> Dict[str, float]:
        """Run perplexity evaluation on a single dataset."""
        custom_metric_key_prefix = f"{metric_key_prefix}_{dataset_name}"
        if dataset is not None and not isinstance(dataset, collections.abc.Sized):
            raise ValueError("eval_dataset must implement __len__")

        eval_dataloader = self.get_eval_dataloader(dataset)
        start_time = time.time()
        output = self.prediction_loop(
            eval_dataloader,
            description=f"Evaluation {dataset_name}",
            prediction_loss_only=True,
            metric_key_prefix=custom_metric_key_prefix,
        )
        n_samples = len(dataset) if dataset is not None else 1
        output.metrics.update(speed_metrics(custom_metric_key_prefix, start_time, n_samples))

        # Compute perplexity --- Note :: this is unadjusted
        ppl = np.exp(output.metrics[f"{custom_metric_key_prefix}_loss"])
        output.metrics.update({f"{custom_metric_key_prefix}_ppl": ppl})
        return output.metrics
    def get_train_dataloader(self) -> DataLoader:
        """Ensure we shuffle if we're using a new-style (iterable) dataset."""
        if isinstance(self.train_dataset, IterDataPipe):
            return DataLoader(
                self.train_dataset,
                shuffle=True,
                batch_size=self.args.per_device_train_batch_size,
                collate_fn=self.data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )
        else:
            return super().get_train_dataloader()
@dataclass
class LMDataCollator:
    """Data collator for language modeling: stacks pre-tokenized examples and masks padding in the labels."""

    tokenizer: PreTrainedTokenizerBase

    def __call__(self, examples: List[BatchEncoding]) -> BatchEncoding:
        # Stack each field (e.g., `input_ids`, `attention_mask`) into a [batch, seq_len] tensor
        batch = BatchEncoding(data={k: torch.tensor([v[k] for v in examples]) for k in examples[0].keys()})

        if "labels" in batch:
            labels = batch["labels"]
        else:
            labels = batch["input_ids"]

        if self.tokenizer.pad_token_id is not None:
            # Set pad positions to -100 so they are ignored by the model's cross-entropy loss
            labels = labels.clone()
            labels[labels == self.tokenizer.pad_token_id] = -100

        batch["labels"] = labels
        return batch
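
# Usage sketch (illustrative only, not part of the original module): one way the trainer and collator could be wired
# together. The `gpt2` checkpoint, the output directory, and the `train_data` / `owt_val` / `wikitext_val` /
# `lambada_val` dataset variables below are hypothetical placeholders.
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   model = AutoModelForCausalLM.from_pretrained("gpt2")
#   args = TrainingArguments(output_dir="runs/example", per_device_train_batch_size=8)
#
#   trainer = OnlineBenchmarkTrainer(
#       model=model,
#       args=args,
#       data_collator=LMDataCollator(tokenizer=tokenizer),
#       dataset_name="openwebtext",
#       train_dataset=train_data,
#       eval_dataset=owt_val,
#       custom_eval_datasets={"wikitext": wikitext_val, "lambada": lambada_val},
#       tokenizer=tokenizer,
#   )
#   trainer.train()
#   metrics = trainer.evaluate()  # also logs `eval_wikitext_ppl` and `eval_lambada_ppl`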