Initializing New Word Embeddings for Pretrained Language Models

Expanding the vocabulary of a pretrained language model can make it more useful, but new words' embeddings need to be initialized. The current default is bad, but it's easy to do better.


tl;dr

  • When we add words to the vocabulary of pretrained language models, the default behavior of huggingface is to initialize the new words’ embeddings with the same distribution used before pretraining – that is, small-norm random noise.
  • This can cause the pretrained language model to place probability \(\approx 1\) on the new word(s) for every (or most) prefix(es). Generating from this network only generates the new words.
  • This is because the pretraining process sometimes leads all logits (dot products between softmax matrix embeddings and the vectors in the last layer of the network) to be negative and large.
  • So the new words’ logits – roughly 0, since the zero vector dotted with anything is zero – dominate the partition function of the softmax, since \(\exp(0) \gg \exp(-\beta)\) for large \(\beta\).
  • Empirically, this can lead to worse domain adaptation performance – the first gradient steps just remove probability from the new words!
  • Instead, just average all existing embeddings to initialize new embeddings; this bounds the KL-divergence between the pre-expansion and post-expansion language models’ token-level distributions, and is nice for finetuning.

This problem has been observed in huggingface issues (1) (2), and averaging has been proposed as a workaround by Eric Mitchell (and anecdotally by others); this blog post argues averaging should be the new “default” and provides some math to explain the problem and show why averaging is a good solution.

Figure: Generation before vocabulary augmentation (a well-formed sentence), and after vocabulary augmentation (consisting only of the newly added words).

A hidden problem in vocabulary expansion

Pretrained language models increasingly form the foundation of modern natural language processing. Commonly, language models are trained with a fixed vocabulary of, e.g., 50,000 word (pieces). When adapting language models to a downstream task or domain, it's frequently useful to consider expanding the vocabulary. For example, if one is adapting GPT2 to J.R.R. Tolkien's The Lord of the Rings (LOTR), one may be concerned to find that Frodo is tokenized as follows:

import transformers
tok = transformers.GPT2Tokenizer.from_pretrained('gpt2')
tok.convert_ids_to_tokens(tok('Frodo')['input_ids'])
['F', 'ro', 'do']

The word Frodo is common in the domain of interest, but not in general web text. So the tokenizer splits Frodo, and many other words common in LOTR, into multiple tokens, extending average sequence lengths and forcing knowledge about those entities to be stored in a manner (1) distributed across those word tokens and (2) composed from the tokens throughout the contextualization process.

If I just generate from GPT2 conditioned on a sentence like Aragorn told Frodo to mind Lothlorien, the tokenizer breaks the 6-word sentence into 12 subwords.

lotr_sent = 'Aragorn told Frodo to mind Lothlorien'
tok.convert_ids_to_tokens(tok(lotr_sent)['input_ids'])
['Ar', 'ag', 'orn', 'Ġtold', 'ĠFro', 'do', 'Ġto', 'Ġmind', 'ĠL', 'oth', 'lor', 'ien']

Here’s a sample generated from gpt2 small (125M parameters) conditioned on that string:

lotr_sent = 'Aragorn told Frodo to mind Lothlorien'
model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
tok.decode(model.generate(**tok(lotr_sent, return_tensors='pt'), do_sample=True)[0])
'Aragorn told Frodo to mind Lothlorien for a while.\n\nThe Ring'

Nice, the model knows that The Ring is a reasonable thing to mention in this context. Still, if we want our model to perform well on LOTR text, we might want to add words to our vocabulary so that common words are less likely to be split into many pieces. This is easy in huggingface transformers:

tok.add_tokens(['Aragorn', 'Frodo', 'Lothlorien'])
tok.convert_ids_to_tokens(tok(lotr_sent)['input_ids'])
['Aragorn', 'told', 'Frodo', 'to', 'Ġmind', 'Lothlorien']

But remember, each word in the vocabulary needs its own word embedding; these new words have no corresponding embedding in the model right now. So, we have to resize the model’s embeddings to make new embeddings for the new words:

model.resize_token_embeddings(len(tok))

The new words’ embeddings haven’t been trained, so we don’t expect them to be useful yet. We need to train them as we adapt on LOTR text. Let’s do a quick sanity check on the model before we start finetuning, conditioning on a string not containing any of the new words and generating from the model:

sent2 = 'Dogs are great because they are '
tok.decode(model.generate(**tok(sent2, return_tensors='pt'), do_sample=True)[0])
'Dogs are great because they are  Aragorn Aragorn Aragorn Aragorn Frodo Aragorn'

What happened? The model only generated new tokens. Just by adding new, untrained word embeddings to our model, we’ve lost the distribution we pretrained.

In fact, if we generated from any prefix, we’d still only generate words that we’ve just now added. That seems wrong!

I'll characterize exactly when this happens, and how we can provably avoid it, for any LM and any prefix, up until we see a new word for the first time (after that, the untrained new word in the input will lead to unpredictable behavior from the LM, so it's hard to say what happens).

Vocabulary expansion

From a “user of language models” perspective, I was expecting that adding new words to my vocabulary, before training their embeddings, shouldn’t really affect the pretrained language model’s distribution.1 So why did it?

The answer has to do with the distribution of the logits of the language model: the dot products between the embeddings and the vectors at the last layer of the network.

To discuss this precisely, let’s get a bit more formal.

Language Model Setup

Let \(w_{1:T}\) be a sequence of word (identifiers) in \(\mathcal{V}\), where \(\mathcal{V} = \{1,\dots,n\}\) is a vocabulary. Let \(p_\theta(w_i\mid w_{1:i-1})\) be a neural language model parameterized by \(\theta\) and defined by

$$p_\theta(w_i \mid w_{1:i-1}) = \frac{\exp(h_{i-1}^{\top} e_{w_i})}{\sum_{j=1}^n \exp(h_{i-1}^\top e_j)},$$

where \(h_{i-1} = \phi_\theta(w_{1:i-1}) \in \mathbb{R}^d\) is the neural representation of the prefix, and \(e_i\in\mathbb{R}^d\) is the word embedding for word \(i\in\mathcal{V}\). The \(e_i\) are contained in \(\theta\).
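
To make the notation concrete, here's a minimal sketch in PyTorch, with random tensors standing in for a trained model (the shapes, not the values, are what matter; the scaling of h is just to keep the logits moderate):

import torch

d, n = 768, 50257                 # hidden size and vocab size (GPT2-small-like; illustrative)
h = torch.randn(d) / d**0.5       # stand-in for h_{i-1} = phi_theta(w_{1:i-1})
E = torch.randn(n, d)             # rows are the word embeddings e_1, ..., e_n (part of theta)

logits = E @ h                    # h_{i-1}^T e_j for every word j in the vocabulary
p = torch.softmax(logits, dim=0)  # p_theta(w_i | w_{1:i-1}); sums to 1 over the n words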

Vocabulary Expansion Setup

In vocabulary expansion, we add a new word \(n+1 \not\in \mathcal{V}\). This implies we need a new word embedding \(e_{n+1}\), which we sample from an initialization distribution. The new language model, which we denote \(p_{\theta'}(w_i \mid w_{1:i-1})\), has parameters \(\theta' = \theta \cup \{e_{n+1}\}\) and is defined by

$$p_{\theta'}(w_i \mid w_{1:i-1}) = \frac{\exp(h_{i-1}^{\top} e_{w_i})}{Z + \exp(h_{i-1}^\top e_{n+1})},$$

where \(Z = \sum_{j=1}^{n} \exp (h_{i-1}^\top e_{j})\) is the partition function of the pre-expansion LM.
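
Continuing the toy tensors above (h and E are stand-ins, and the noise scale for e_new is arbitrary), expansion just appends one embedding and adds one term to the denominator:

e_new = 0.02 * torch.randn(d)          # e_{n+1}, sampled from some initialization distribution
Z = torch.exp(E @ h).sum()             # partition function of the pre-expansion LM
Z_new = Z + torch.exp(h @ e_new)       # the post-expansion partition function gains one term

p_old = torch.exp(E @ h) / Z           # pre-expansion probabilities of words 1..n
p_new = torch.exp(E @ h) / Z_new       # post-expansion probabilities of the same words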

What happens?

The probability of one of the pre-expansion words \(1,\dots, n\) can only go down post-expansion, since the new partition function can only be larger than the old one; in particular, the new partition function is larger by the exponentiated logit of the new word, \(\exp(h_{i-1}^\top e_{n+1})\). The relationship between the pre-expansion and post-expansion probabilities is as follows:

$$p_{\theta'}(w_i \mid w_{1:i-1}) = p_{\theta}(w_i \mid w_{1:i-1}) \cdot \frac{1}{1 + \frac{\exp(h_{i-1}^\top e_{n+1})}{Z}}.$$

This equation can be interpreted as follows: the post-expansion probability of any pre-expansion word is its pre-expansion probability multiplied by a factor less than 1. The factor is \(1/\big(1 + \frac{\exp(h_{i-1}^\top e_{n+1})}{Z}\big)\). So if \(Z\) is small relative to the new word's exponentiated logit, the probabilities of all pre-expansion words decrease a lot.

To summarize how different the post-expansion language model distribution is, we can use the KL-divergence (for a specific prefix):

$$\text{KL}(p_\theta(w\mid w_{1:i-1}) \| p_{\theta'}(w\mid w_{1:i-1})) = \log (1 + \frac{\exp(h_{i-1}^\top e_{n+1})}{Z})$$
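
Here's a quick numeric check of this identity, reusing the toy h, E, and e_new from above (float64 just to make the comparison clean):

logits = (E @ h).double()
logit_new = (h @ e_new).double()

p_old = torch.softmax(logits, dim=0)                                      # pre-expansion distribution
p_post = torch.softmax(torch.cat([logits, logit_new[None]]), dim=0)[:-1]  # post-expansion, old words only

kl_direct = (p_old * (p_old / p_post).log()).sum()       # KL summed over the original vocabulary
Z = torch.exp(logits).sum()
kl_closed_form = torch.log1p(torch.exp(logit_new) / Z)
print(kl_direct.item(), kl_closed_form.item())           # these agree up to floating-point error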

Zero-init can cause problems

So why does this happen in practice? Huggingface defaults initialize to small random noise; here we’ll consider the simpler case of initializing new word embeddings to the zero vector.

So, \(e_{n+1} = \mathbf{0}\). The exponentiated logit of the new word, \(\exp(h_{i-1}^\top e_{n+1})\), is thus always exactly \(1\): the dot product of \(\mathbf{0}\) with any vector is \(0\), and \(\exp(0) = 1\). Plugging this into the general form of the KL-divergence, we get:

$$\text{KL}(p_\theta(w\mid w_{1:i-1}) \| p_{\theta'}(w\mid w_{1:i-1})) = \log (1 + \frac{1}{Z})$$

So, if \(Z\) is small relative to \(1\), we get large KL-divergence between the pre- and post-expansion LMs.

Why would \(Z\) be small? This happens when all of the logits \(h_{i-1}^\top e_j\) are large and negative. Because the softmax function is invariant to adding any scalar to all of its inputs, my best guess is that this happens as an accident of optimization. In general, this doesn't happen for all (or even most) LMs, but when it does, zero-initialization is bad.
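
Here's a toy illustration of that failure mode with synthetic logits (not from a real model): when every pre-expansion logit is large and negative, a zero-initialized new word dominates.

import torch

n = 50257
logits = torch.full((n,), -30.0)       # every pre-expansion logit is large and negative
logit_new = torch.tensor(0.0)          # zero-initialized embedding: h . 0 = 0

p = torch.softmax(torch.cat([logits, logit_new[None]]), dim=0)
print(p[-1].item())                    # ~1.0: the new word gets almost all the probability
Z = torch.exp(logits).sum()
print(torch.log1p(1 / Z).item())       # KL from the pre-expansion LM: about 19 nats here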

Intuition: Small-norm initialization behaves similarly to zero initialization when the norm gets very small.

What about averaging?

The difficulty in finding an initialization that guarantees small KL-divergence is that we can't make assumptions about the distribution of the neural activity \(h_{i-1}\).

Thankfully, there’s a simple way to ensure low KL-divergence: just average all of the existing embeddings to initialize the new embeddings:

$$ \begin{align} &\mu = \frac{1}{n} \sum_{i=1}^{n} e_i\\ &e_{n+1} =\mu \end{align} $$

If you want to add some randomness, e.g., because you’re adding a bunch of words at once (I don’t think adding random noise is necessary in this case, but regardless), here’s a simple way to do that.

$$ \begin{align} &\mu = \frac{1}{n} \sum_{i=1}^{n} e_i\\ &\Sigma = (E-\mu)^\top(E-\mu)/n\\ &e_{n+1} \sim \mathcal{N}(\mu, \Sigma)\\ \end{align} $$

where \(E\in\mathbb{R}^{n\times d}\) is the matrix of embeddings \([e_1; \dots; e_n]\). (And \(\mu\) is broadcast to \(\mathbb{R}^{n\times d}\) in the second line following the usual convention.)

I don’t have strong intuitions about why this particular noise distribution might be useful, beyond the fact that it focuses the direction of the noise where other embeddings point, and thus (unlike isotropic noise) maybe is less likely to accidentally fall in a particularly “bad” direction w.r.t. the neural activity.

Note also that the KL-divergence bound we give below doesn’t hold when we add noise. In practice, this seems to be fine.

Averaging bounds the KL-divergence

Let’s take a look at how much probability is assigned to our new word \(n+1\) given that its embedding \(e_{n+1}\) is initialized to exactly the average of the other embeddings (that is, no noise is added.)

We start by looking at the logit:

$$ \begin{align} \exp(h_{i-1}^\top e_{n+1}) &= \exp(h_{i-1}^\top \frac{1}{n} \sum_{j=1}^n e_j)\\ &= \exp(\frac{1}{n} \sum_{j=1}^{n} h_{i-1}^\top e_j)\\ &\leq \frac{1}{n} \sum_{j=1}^{n} \exp(h_{i-1}^\top e_j)\\ &= Z/n \end{align} $$

where the inequality holds by Jensen’s inequality (applied to the convex function \(\exp\)). Intuitively, the new word’s contribution to the post-expansion partition function is at most a \(\frac{1}{n}\) fraction of the old partition function \(Z\), so the new word’s probability is at most \(\frac{1}{n+1}\). As the original vocabulary size grows, the new word’s maximum probability shrinks.

This immediately leads to a bound on the KL-divergence between the pre- and post-expansion LMs:

$$\text{KL}(p_\theta(w\mid w_{1:i-1}) \| p_{\theta'}(w\mid w_{1:i-1})) \leq \log (1 + \frac{1}{n})$$

Note that this bound holds for any LM and for any prefix you give it, with the KL taken over the original vocabulary \(\mathcal{V}\). Nice!
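
Here's a numeric check of the bound with random stand-in tensors; by Jensen's inequality, it holds for any h and E:

import torch

d, n = 768, 50257
h = torch.randn(d, dtype=torch.float64) / d**0.5
E = torch.randn(n, d, dtype=torch.float64)

mu = E.mean(dim=0)                                # average of the existing embeddings
Z = torch.exp(E @ h).sum()                        # pre-expansion partition function

assert torch.exp(h @ mu) <= Z / n                 # Jensen: the new exponentiated logit is at most Z/n
kl = torch.log1p(torch.exp(h @ mu) / Z)           # KL between pre- and post-expansion LMs
assert kl <= torch.log1p(torch.tensor(1.0 / n))   # ... which is at most log(1 + 1/n)
print(kl.item(), torch.log1p(torch.tensor(1.0 / n)).item())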

An aside: nice gradients at low probabilities

You might worry that, by assigning very low probabilities to the new words in context, we might inadvertently harm learning by making the gradients very large (since the cross-entropy loss explodes at low probabilities) or very small (since the gradient of the softmax vanishes at very low probabilities). In fact, these two concerns cancel out exactly; the gradient of the loss w.r.t. the embedding of the target word \(w_t\), for example, is:

$$ \frac{\partial \mathcal{L}(p_\theta, w_{1:t-1}, w_t)}{\partial e_{w_t}} = -h_{t-1} \big(1 - p_\theta(w_t \mid w_{1:t-1})\big)$$

That is, the gradient is just \(h_{t-1}\) scaled by \(-(1 - p_\theta(w_t \mid w_{1:t-1}))\): as the probability approaches 0, the magnitude of that scaling factor approaches 1, so the gradient neither explodes nor vanishes. So we needn’t worry about ensuring there’s some lower bound on the probability assigned to the new word in any context.
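
We can sanity-check this with autograd on toy tensors (synthetic h and E, with the new word as the target): even though the averaged-init word gets probability around \(10^{-5}\), the gradient of the loss w.r.t. its embedding has norm close to \(\lVert h_{t-1}\rVert\).

import torch

d, n = 768, 50257
h = torch.randn(d) / d**0.5                      # stand-in for h_{t-1}
E = torch.randn(n, d)                            # existing embeddings
e_new = E.mean(dim=0).clone().requires_grad_()   # averaged initialization for the new word

logits = torch.cat([E @ h, (h @ e_new)[None]])
log_p = torch.log_softmax(logits, dim=0)
loss = -log_p[-1]                                # cross-entropy loss when the target is the new word
loss.backward()

p_new = log_p[-1].exp()
print(p_new.item())                                                  # tiny (at most 1/(n+1))
print(e_new.grad.norm().item(), (h.norm() * (1 - p_new)).item())     # gradient norm = ||h|| (1 - p)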

Experiments!

A quick sanity check

Let’s go back to our running example. First, we instantiate a model and tokenizer, add our new tokens, and resize the embeddings.

tok = transformers.GPT2Tokenizer.from_pretrained('gpt2')
model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
tok.add_tokens(['Aragorn', 'Frodo', 'Lothlorien'])
model.resize_token_embeddings(len(tok))

Next, we compute the distribution from which we’ll sample our new embeddings:

import torch
params = model.state_dict()
embeddings = params['transformer.wte.weight']
pre_expansion_embeddings = embeddings[:-3,:]        # all embeddings except the 3 we just added
mu = torch.mean(pre_expansion_embeddings, dim=0)    # average of the existing embeddings
n = pre_expansion_embeddings.size()[0]
sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n  # empirical covariance
dist = torch.distributions.multivariate_normal.MultivariateNormal(
        mu, covariance_matrix=1e-5*sigma)           # scaled-down covariance for small noise

We’ll load our new embeddings into the model:

new_embeddings = torch.stack(tuple((dist.sample() for _ in range(3))), dim=0)  # one sample per new word
embeddings[-3:,:] = new_embeddings
params['transformer.wte.weight'][-3:,:] = new_embeddings
model.load_state_dict(params)

Finally, we sample from the model and observe that it does not just generate the new words we just added to the vocabulary.

sent2 = 'Dogs are great because they are '
tok.decode(model.generate(**tok(sent2, return_tensors='pt'), do_sample=True)[0])
"Dogs are great because they are icky and don't tend to get in the way much,"

That’s not my favorite sample, but it passes our check!

Taking a peek at KL-divergences

First, take a look at the distribution (across some text from the wikitext dataset) of the KL-divergence between pre- and post-expansion LMs for GPT2-small if one uses zero-initialization. That is, for every prefix in some wikitext text, I stored \(\log( 1+ 1/\sum_{j=1}^{n} \exp(h_{i-1}^\top e_j))\). I added a constant of \(1\times10^{-30}\) to each partition function because some were so close to zero that they caused the \(\log\) function to return NaN. Here’s the histogram:

Figure: Histogram of token-level KL-divergences between the pre- and post-expansion GPT2-small with zero-initialized new embeddings.

So, some prefixes’ distributions have close to \(0\) KL-divergence (the bar at \(0\)), but many have very large KL-divergence. (The bar between 60 and 70 corresponds to the small constant I added, so the true KL-divergence of those is even larger.)
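
For reference, here's roughly how those per-prefix quantities can be computed; this is a sketch rather than the exact script behind the plot (the wikitext split, the slice of text, and the truncation length are assumptions):

import torch
import transformers
from datasets import load_dataset

tok = transformers.GPT2Tokenizer.from_pretrained('gpt2')
model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
model.eval()

# A chunk of wikitext text; the specific split/slice here is arbitrary.
text = ' '.join(load_dataset('wikitext', 'wikitext-2-raw-v1', split='validation')['text'][:50])
input_ids = tok(text, return_tensors='pt', truncation=True, max_length=512)['input_ids']

with torch.no_grad():
    logits = model(input_ids).logits[0].double()   # (seq_len, vocab): one row of logits per prefix
log_Z = torch.logsumexp(logits, dim=-1)            # log partition function Z for each prefix
kl = torch.log1p(torch.exp(-log_Z))                # log(1 + 1/Z): KL under zero-initialization
# Histogram kl (e.g., with matplotlib) to get a plot like the one above.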

One model that doesn’t exhibit the “only generate the new tokens” behavior is GPT2-large. Let’s take a look at the histogram of its token-level KL-divergences:

Figure: Histogram of token-level KL-divergences between the pre- and post-expansion GPT2-large with zero-initialized new embeddings.

All the KL-divergences are around zero, which means we’d predict no problems with initializing using small random noise, as observed.

Finetuning for domain adaptation

In these experiments, I took 100 words that appear frequently in wikitext, but are split into multiple subword tokens by the GPT2 tokenizer. I added the words to the vocabulary of GPT2, and initialized their embeddings in three ways:

  • default: Use the huggingface default, which initializes to small random noise.
  • zeros: Initialize all new words’ embeddings to the zero vector.
  • avg_emb: Initialize using the average of all existing embeddings + small noise.

I then finetuned each of four models: gpt2-small, gpt2-medium, gpt2-large, and EleutherAI’s 125M-parameter GPT model, for 500 gradient steps, reporting the validation perplexity on Wikitext every 20 gradient steps for each initialization method.
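
For concreteness, here's a minimal sketch of the three initialization strategies listed above; the function and the noise scales are illustrative, not the exact experiment code:

import torch

def init_new_embeddings(model, num_new, method='avg_emb'):
    """Initialize the last num_new rows of the already-resized input embedding matrix."""
    emb = model.get_input_embeddings().weight.data
    old = emb[:-num_new]
    if method == 'zeros':
        emb[-num_new:] = 0.0
    elif method == 'default':
        # roughly the huggingface behavior: small random noise around zero
        emb[-num_new:] = 0.02 * torch.randn(num_new, old.size(1))
    elif method == 'avg_emb':
        # mean of the existing embeddings plus a little noise (the scale is an assumption)
        emb[-num_new:] = old.mean(dim=0) + 1e-3 * torch.randn(num_new, old.size(1))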

Here are the results:

Figure: Validation perplexity on wikitext over 500 finetuning steps for gpt2-small, gpt2-medium, gpt2-large, and EleutherAI's 125M-parameter model, comparing the default, zeros, and avg_emb initializations.


So, it looks like for the gpt2-small, gpt2-medium, and eleuther-125M models, averaging the embeddings leads to faster adaptation to the target domain, at least as measured by perplexity on text from that domain. For gpt2-large, which we predicted shouldn’t have a problem, we see no finetuning difference between the methods.

Conclusion

A good initialization strategy for new word embeddings for a pretrained LM should enable strong adaptation to the downstream domain (or task). The current huggingface default has a subtle gotcha that can really break a pretrained LM and lead to worse adaptation. I characterized why this is the case (and why it happens only for some LMs), and then showed that averaging the existing embeddings is a general-purpose solution.

I conclude that we should average the existing embeddings as the default initialization for new word embeddings for pretrained language models.

Code for reproducing my experiments can be found at https://github.com/john-hewitt/embed-init.

Citation

For attribution in academic contexts, please cite this work as:

@misc{hewitt2021initializing,
  author = {Hewitt, John},
  title = {Initializing New Word Embeddings for Pretrained Language Models},
  year = {2021},
  howpublished = {\url{https://nlp.stanford.edu/~johnhew//vocab-expansion.html}},
}

Thanks to Sebastian Ruder (or his blog posts) for the idea of adding citation info to blog posts.

Footnotes

  1. Okay this is a little naive / hopeful – but completely destroying the pretrained distribution seems extreme. 
