"""
callbacks.py
Custom Hugging Face Weights & Biases Callback that supports writing custom metrics and provides better resuming functionality.
"""
import logging
import os
import time
from bisect import bisect_left
from typing import Dict, List, Optional, Tuple
import jsonlines
import torch
from transformers import (
PreTrainedModel,
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
is_torch_tpu_available,
)
from transformers.integrations import WandbCallback
# Nest Overwatch under root `mistral` logger, inheriting formatting!
overwatch = logging.getLogger("mistral.core.callbacks")
# Helper Function
def rewrite_logs(d: Dict[str, float]) -> Dict[str, float]:
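    """Prefix eval metrics with `eval/` and all other metrics with `train/` (e.g., `eval_loss` -> `eval/loss`,
    `loss` -> `train/loss`) so W&B groups them into separate chart sections."""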
new_d = {}
eval_prefix = "eval_"
eval_prefix_len = len(eval_prefix)
for k, v in d.items():
if k.startswith(eval_prefix):
new_d["eval/" + k[eval_prefix_len:]] = v
else:
new_d["train/" + k] = v
return new_d


class CustomWandbCallback(WandbCallback):
"""Custom Weights and Biases Callback used by Mistral for logging information from the Huggingface Trainer."""
def __init__(
self,
project: str,
json_file: str,
        group: Optional[str] = None,
        resume: bool = False,
        resume_run_id: Optional[str] = None,
        wandb_dir: Optional[str] = None,
        api_key_path: Optional[str] = None,
):
super(CustomWandbCallback, self).__init__()
        # Authenticate via API key if available. If WANDB_API_KEY is set, the user will not be prompted to log in.
if api_key_path:
with open(os.path.expanduser(api_key_path), "r") as f:
os.environ["WANDB_API_KEY"] = f.read().strip()
# Set the Project Name
if isinstance(project, str):
overwatch.info(f"Setting W&B Project: {project}")
os.environ["WANDB_PROJECT"] = project
        # Disable the automatic wandb.watch(model) call -- it throws an error otherwise
# Note: we manually watch the model in self.on_train_begin(..)
os.environ["WANDB_WATCH"] = "false"
# Set up JSON Log File
self.json_file = json_file
# Wandb arguments
self.group, self.resume, self.resume_run_id, self.wandb_dir = group, resume, resume_run_id, wandb_dir
# Timers
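        #   > `within_time` marks the start of a step (set in `on_step_begin`); `between_time` marks its end
        #     (set in `on_step_end`), letting us log both time within a step and the gap between steps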
self.within_time: Optional[float] = None
self.between_time: Optional[float] = None

    def _append_jsonl(self, data: Dict) -> None:
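        """Append a single record (one JSON object per line) to the JSONL log file."""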
with jsonlines.open(self.json_file, mode="a") as writer:
writer.write(data)

    def _log_memory(self, state: TrainerState, prefix: str = "train_info"):
"""Simple method to log memory usage at the end of every training batch."""
if state.is_world_process_zero and torch.cuda.is_available():
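            # torch.cuda reports memory in bytes; divide by 2**20 to log MiB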
memory_usage = {
f"{prefix}/memory_allocated": torch.cuda.memory_allocated() / 2**20,
f"{prefix}/memory_max_allocated": torch.cuda.max_memory_allocated() / 2**20,
f"{prefix}/memory_reserved": torch.cuda.memory_reserved() / 2**20,
f"{prefix}/memory_max_reserved": torch.cuda.max_memory_reserved() / 2**20,
}
            # Log to W&B and, for steps not yet recorded, append to the JSONL log
self._wandb.log(memory_usage, step=state.global_step)
if state.global_step > self._last_log_step:
self._append_jsonl({"train_info": memory_usage, "step": state.global_step})

    def setup(self, args, state, model, **kwargs):
"""
        Note: we have to override this method in order to inject additional arguments into the `wandb.init`
        call, as HF currently provides no way to pass kwargs to it.

Setup the optional Weights & Biases (`wandb`) integration.
One can subclass and override this method to customize the setup if needed. Find more information `here
<https://docs.wandb.ai/integrations/huggingface>`__. You can also override the following environment variables:
Environment:
WANDB_LOG_MODEL (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to log model as artifact at the end of training.
            WANDB_WATCH (:obj:`str`, `optional`, defaults to :obj:`"gradients"`):
Can be :obj:`"gradients"`, :obj:`"all"` or :obj:`"false"`. Set to :obj:`"false"` to disable gradient
logging or :obj:`"all"` to log gradients and parameters.
WANDB_PROJECT (:obj:`str`, `optional`, defaults to :obj:`"huggingface"`):
Set this to a custom string to store results in a different project.
WANDB_DISABLED (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to disable wandb entirely. Set `WANDB_DISABLED=true` to disable.
"""
if self._wandb is None:
return
self._initialized = True
# Process Zero Barrier --> only Log on First Process!
if state.is_world_process_zero:
overwatch.info(
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
)
combined_dict = {**args.to_sanitized_dict()}
if hasattr(model, "config") and model.config is not None:
model_config = model.config.to_dict()
combined_dict = {**model_config, **combined_dict}
init_args, trial_name = {}, state.trial_name
if trial_name is not None:
run_name = trial_name
init_args["group"] = args.run_name
else:
run_name = args.run_name
init_args["group"] = self.group
# Add Additional kwargs into init_args Dictionary
init_args = {**init_args, **kwargs}
if self._wandb.run is None:
self._wandb.init(
project=os.getenv("WANDB_PROJECT", "huggingface"),
name=run_name,
**init_args,
)
# Add Configuration Parameters
self._wandb.config.update(combined_dict, allow_val_change=True)
# Keep track of Model Topology and Gradients, Unsupported on TPU
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
self._wandb.watch(
model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps)
)
# Custom JSON Resume Behavior
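            #   > The log is written as JSON Lines -- one object per line, e.g. {"logs": {...}, "step": 200} --
            #     so we can replay it on resume and avoid duplicating already-logged steps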
            if self.resume and os.path.exists(self.json_file):
                # Scan to the final record to recover the last logged step; default to -1 if the file is empty
                self._last_log_step = -1
                with jsonlines.open(self.json_file, mode="r") as resume_reader:
                    for last_log in resume_reader:
                        self._last_log_step = last_log["step"]
            else:
                self._last_log_step = -1
            # Opening in "w" truncates any stale log when starting fresh; "a" preserves it across a resume
            self.jsonl_writer = jsonlines.open(self.json_file, mode="a" if self.resume else "w")

    def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
model: PreTrainedModel = None,
tokenizer=None,
optimizer=None,
lr_scheduler=None,
train_dataloader=None,
eval_dataloader=None,
**kwargs,
):
super().on_step_begin(args, state, control, **kwargs)
if state.is_world_process_zero:
if torch.cuda.is_available():
torch.cuda.synchronize()
# Compute and Log "Between Time Taken"
between_time_taken = time.time() - self.between_time
self._wandb.log({"train_info/time_between_train_steps": between_time_taken}, step=state.global_step)
if state.global_step > self._last_log_step:
self._append_jsonl(
{
"train_info/time_between_train_steps": between_time_taken,
"step": state.global_step,
}
)
# Start the timer within a step
self.within_time = time.time()

    def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
model: PreTrainedModel = None,
tokenizer=None,
optimizer=None,
lr_scheduler=None,
train_dataloader=None,
eval_dataloader=None,
**kwargs,
):
if state.is_world_process_zero:
if torch.cuda.is_available():
torch.cuda.synchronize()
# Get time taken in step
within_time_taken = time.time() - self.within_time
# Aggregate Step Information
log_info = {
"info/global_step": state.global_step,
"train_info/time_within_train_step": within_time_taken,
}
if hasattr(optimizer, "loss_scale"):
log_info["train_info/loss_scale"] = optimizer.loss_scale
# Log
self._wandb.log(log_info, step=state.global_step)
if state.global_step > self._last_log_step:
self._append_jsonl(
{
"info/global_step": state.global_step,
"train_info/time_within_train_step": within_time_taken,
"step": state.global_step,
}
)
# Start timer for measuring between-step time
self.between_time = time.time()

    def on_train_begin(
self,
args,
state,
control,
model: PreTrainedModel = None,
tokenizer=None,
optimizer=None,
lr_scheduler=None,
train_dataloader=None,
eval_dataloader=None,
**kwargs,
):
"""Calls wandb.init, we add additional arguments to that call using this method."""
# Pass in additional keyword arguments to the wandb.init call as kwargs
super().on_train_begin(
args, state, control, model, resume=self.resume, dir=self.wandb_dir, id=self.resume_run_id, **kwargs
)
# Process Zero Barrier
if state.is_world_process_zero:
# Watch the model
os.environ["WANDB_WATCH"] = "gradients"
self._wandb.watch(model, log="gradients", log_freq=args.eval_steps)
# Log model information
self._wandb.log(
{
"model-info/num_parameters": model.num_parameters(),
"model-info/trainable_parameters": model.num_parameters(only_trainable=True),
},
step=state.global_step,
)
if state.global_step > self._last_log_step:
self._append_jsonl(
{
"num_parameters": model.num_parameters(),
"trainable_parameters": model.num_parameters(only_trainable=True),
"step": state.global_step,
}
)
# Initialize the timers
self.within_time, self.between_time = time.time(), time.time()

    def on_log(
self,
args,
state,
control,
model: PreTrainedModel = None,
tokenizer=None,
optimizer=None,
lr_scheduler=None,
train_dataloader=None,
eval_dataloader=None,
logs=None,
**kwargs,
):
# Null Check
if self._wandb is None:
return
# Initialization Check
if not self._initialized:
self.setup(args, state, model, reinit=False)
# Process Zero Barrier
if state.is_world_process_zero:
# Rewrite Logs
logs = rewrite_logs(logs)
self._wandb.log(logs, step=state.global_step)
# Log Memory Usage
self._log_memory(state)
# Append to the JSON Log
if state.global_step > self._last_log_step:
self._append_jsonl({"logs": logs, "step": state.global_step})


class CustomCheckpointCallback(TrainerCallback):
    """Custom Checkpoint Callback used by Mistral for Saving Checkpoints at different frequencies."""
    def __init__(self, frequencies: List[Tuple[int, int]]):
super(CustomCheckpointCallback, self).__init__()
# `frequencies` specifies when to checkpoint (based on the current training step). Specifically:
# Input: [(freq, until), (new_freq, until) ...]
# > We assert that `until` monotonically increases (lightweight validation)
self.freq, self.until = zip(*frequencies)
assert all(i < j for i, j in zip(self.until, self.until[1:])), "Frequency `until_step` not increasing!"

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""Borrow Checkpoint Logic from `DefaultFlowCallback` to decide when to checkpoint."""
# Save (note we explicitly save checkpoint-0 in `train.py`, so no need to do it here)
c = state.global_step
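        # `bisect_left` picks the first interval whose `until` bound is >= the current step; `hi` clamps the
        # index so steps beyond the final bound keep using the last frequency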
if args.save_steps > 0 and c % (self.freq[bisect_left(self.until, c, hi=len(self.until) - 1)]) == 0:
control.should_save = True
return control
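

# Example usage (illustrative sketch -- the frequencies below are placeholders):
#
#   # Save a checkpoint every 10 steps until step 100, then every 50 steps until step 1000
#   checkpoint_callback = CustomCheckpointCallback(frequencies=[(10, 100), (50, 1000)])
#   trainer = Trainer(model=model, args=training_args, callbacks=[checkpoint_callback])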