Generate Text With A Trained Model

Once you’ve completed training, you can use your model to generate text.

In this tutorial we’ll walk through getting 🤗 Transformers set up and generating text with a trained GPT-2 Small model.

Set Up Hugging Face

Hugging Face’s transformers repo provides a helpful script for generating text with a GPT-2 model.

To access this script, clone the repo:

git clone https://github.com/huggingface/transformers.git

Run run_generation.py With Your Model

As your model trains, it should save checkpoints with all of the model resources in the directory you specified with artifacts.run_dir in the conf/tutorial-gpt2-micro.yaml config file.

For this example, let’s assume you have saved the checkpoints in /home/tutorial-gpt2-micro/runs/run-1. If you trained for 400000 steps, you should have a corresponding checkpoint at /home/tutorial-gpt2-micro/runs/run-1/checkpoint-400000. This directory contains all the resources for your model, including pytorch_model.bin (the model weights) and vocab.json (which maps word pieces to their indices), among other files.
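Before running generation, it can be useful to sanity-check that a checkpoint directory actually contains the expected files. The helper below is illustrative only (it is not part of Mistral or Transformers); it assumes the standard Hugging Face GPT-2 checkpoint layout, where config.json and merges.txt are additional files a checkpoint directory normally contains alongside pytorch_model.bin and vocab.json.

```python
from pathlib import Path

# Files a Hugging Face GPT-2 checkpoint directory normally contains.
EXPECTED_FILES = ["pytorch_model.bin", "config.json", "vocab.json", "merges.txt"]

def missing_checkpoint_files(checkpoint_dir):
    """Return the expected files that are absent from checkpoint_dir."""
    ckpt = Path(checkpoint_dir)
    return [name for name in EXPECTED_FILES if not (ckpt / name).is_file()]

# Example, using the checkpoint path from this tutorial:
# missing_checkpoint_files("/home/tutorial-gpt2-micro/runs/run-1/checkpoint-400000")
# An empty list means the checkpoint looks complete.
```

If the returned list is non-empty, the run may still be mid-save or the path may be wrong.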

To run text generation, issue the following command:

conda activate mistral
cd transformers/examples/text-generation
python run_generation.py --model_type=gpt2 --model_name_or_path=/home/tutorial-gpt2-micro/runs/run-1/checkpoint-400000

This will produce output like the following, requesting a text prompt:

06/28/2021 03:16:16 - WARNING - __main__ - device: cuda, n_gpu: 1, 16-bits training: False
06/28/2021 03:16:26 - INFO - __main__ - Namespace(device=device(type='cuda'), fp16=False, k=0, length=20, model_name_or_path='hello-world/runs/run-1/checkpoint-400000', model_type='gpt2', n_gpu=1, no_cuda=False, num_return_sequences=1, p=0.9, padding_text='', prefix='', prompt='', repetition_penalty=1.0, seed=42, stop_token=None, temperature=1.0, xlm_language='')
Model prompt >>>
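The Namespace line above shows the sampling configuration: k=0 disables top-k filtering, while p=0.9 enables nucleus (top-p) sampling, which keeps only the smallest set of highest-probability tokens whose cumulative probability reaches p. The sketch below illustrates top-p filtering on a toy next-token distribution in plain Python; it is a simplified illustration, not the script’s actual implementation.

```python
def top_p_filter(probs, p=0.9):
    """Keep the smallest set of highest-probability tokens whose cumulative
    probability reaches p, then renormalize. probs maps token -> probability."""
    total, kept = 0.0, {}
    for token, prob in sorted(probs.items(), key=lambda kv: kv[1], reverse=True):
        kept[token] = prob
        total += prob
        if total >= p:
            break
    return {token: prob / total for token, prob in kept.items()}

# Toy next-token distribution: "world" and "there" together reach p=0.9,
# so the low-probability token "xyzzy" is filtered out before sampling.
filtered = top_p_filter({"world": 0.7, "there": 0.25, "xyzzy": 0.05}, p=0.9)
```

In the real script, this filtering happens over the model’s logits at every generation step; temperature=1.0 means the logits are used unscaled.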

Enter an example prompt, and the script will generate a text completion for you using your model!

Model prompt >>> Hello world. This is a prompt.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
=== GENERATED SEQUENCE 1 ===
Hello world. This is a prompt. This is no ‘say what, say it’ stuff, it’s all on