Faster LLM Inference

Can you optimize the inference time of your LLM? How? Let's explore strategies to enhance the inference speed of your LLM.

June 11, 2023

Running inference with a general-purpose LLM or a fine-tuned model (with LoRA or other techniques) is typically the last step in an AI project. However, if token generation is slow, users may not even give your product a chance. Let's explore strategies to speed up inference.

Setup

The tests are conducted on a Google Colab instance equipped with a T4 GPU (16 GB VRAM) and 26 GB RAM. We'll use this prompt for most examples:

py code
prompt = f"""
<human>: What is your return policy?
<assistant>:
""".strip()

We'll use the Falcon 7b model fine-tuned with QLoRA in our previous tutorial[1].
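Quantization matters a lot on a 16 GB T4. The back-of-the-envelope estimate below computes the memory needed just to hold the weights. It assumes roughly 7 billion parameters and ignores activations, the KV cache, quantization metadata, and framework overhead, so real usage will be higher.

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Approximate memory for the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


NUM_PARAMS = 7e9  # assumption: "7b" taken as ~7 billion parameters

for bits in (32, 16, 8, 4):
    print(f"{bits:>2}-bit: ~{weight_memory_gb(NUM_PARAMS, bits):.1f} GB")
# 32-bit: ~28.0 GB
# 16-bit: ~14.0 GB
#  8-bit: ~7.0 GB
#  4-bit: ~3.5 GB
```

In half precision, the weights alone would use most of the T4's memory. Loading in 8-bit or 4-bit leaves room for everything else.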

Tokenizer

Let's load the tokenizer from the original model:

py code
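from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b", padding_side="left")
# Reuse the EOS token for padding (needed for batched inputs).
tokenizer.pad_token = tokenizer.eos_token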
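The tokenizer uses padding_side="left" because a decoder-only model appends new tokens at the end of each sequence. With right padding, the shorter prompts in a batch would have to continue generating after their pad tokens. The plain-Python sketch below illustrates what left padding produces. It uses a hypothetical left_pad helper, made-up token ids, and pad id 0: each sequence is padded at the front, and the attention mask zeros out the padding.

```python
def left_pad(batch: list[list[int]], pad_id: int) -> tuple[list[list[int]], list[list[int]]]:
    """Left-pad token id sequences to equal length and build matching attention masks."""
    max_len = max(len(seq) for seq in batch)
    input_ids, attention_mask = [], []
    for seq in batch:
        n_pad = max_len - len(seq)
        input_ids.append([pad_id] * n_pad + seq)
        # 0 marks padding the model should ignore, 1 marks real tokens.
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask


ids, mask = left_pad([[11, 12, 13], [21, 22]], pad_id=0)
print(ids)   # [[11, 12, 13], [0, 21, 22]]
print(mask)  # [[1, 1, 1], [0, 1, 1]]
```

Every row now ends with a real token, so generation can continue directly from each prompt.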

And encode the prompt:

py code
DEVICE = "cuda:0"
 
encoding = tokenizer(prompt, return_tensors="pt").to(DEVICE)

Generation Config

We'll use the same generation config for all runs:

py code
generation_config = model.generation_config
generation_config.max_new_tokens = 20
generation_config.temperature = 0
generation_config.num_return_sequences = 1
generation_config.pad_token_id = tokenizer.eos_token_id
generation_config.eos_token_id = tokenizer.eos_token_id

We keep max_new_tokens=20 for every run because inference time grows with the number of generated tokens: each new token requires another forward pass through the model.
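With do_sample=False, generate() performs greedy decoding. Each step runs one forward pass and appends the single most likely token. Generation stops at eos_token_id or after max_new_tokens steps. The plain-Python sketch below illustrates that loop. Here toy_next_token_scores is a hypothetical stand-in for the model's forward pass, not part of transformers.

```python
def greedy_generate(prompt_ids, next_token_scores, max_new_tokens, eos_token_id):
    """Greedy decoding: append the highest-scoring token until EOS or the token budget runs out."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):  # one forward pass per generated token
        scores = next_token_scores(ids)
        next_id = max(range(len(scores)), key=scores.__getitem__)
        ids.append(next_id)
        if next_id == eos_token_id:
            break
    return ids


def toy_next_token_scores(ids):
    """Hypothetical 5-token vocabulary that always prefers (last id + 1) % 5."""
    scores = [0.0] * 5
    scores[(ids[-1] + 1) % 5] = 1.0
    return scores


print(greedy_generate([0], toy_next_token_scores, max_new_tokens=20, eos_token_id=4))  # [0, 1, 2, 3, 4]
print(greedy_generate([0], toy_next_token_scores, max_new_tokens=2, eos_token_id=4))   # [0, 1, 2]
```

The first call stops early at EOS. The second call is cut off by the token budget, which caps the number of forward passes.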

Generating Output

We'll measure the text generation speed using the following code:

py code
%%timeit -r 5
 
with torch.inference_mode():
    outputs = model.generate(
        input_ids=encoding.input_ids,
        attention_mask=encoding.attention_mask,
        generation_config=generation_config,
        do_sample=False,
        use_cache=True,
    )

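The %%timeit[2] magic repeats the measurement 5 times (-r 5) and reports the mean and standard deviation of the per-loop time. Since a single generate() call takes seconds, each run is just one loop.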
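Outside a notebook, you can take a similar measurement with the standard-library timeit module. The sketch below assumes you wrap the generate() call in a function. It times one call per repeat, like "1 loop each" above, and returns the mean and population standard deviation of the repeats.

```python
import statistics
import timeit


def benchmark(fn, repeats: int = 5) -> tuple[float, float]:
    """Call fn once per repeat and return (mean, std. dev.) of the wall-clock times in seconds."""
    timings = timeit.repeat(fn, repeat=repeats, number=1)
    return statistics.mean(timings), statistics.pstdev(timings)


# Stand-in workload; replace the lambda with a function that calls model.generate(...).
mean_s, std_s = benchmark(lambda: sum(range(100_000)))
print(f"{mean_s * 1000:.2f} ms ± {std_s * 1000:.2f} ms per loop (mean ± std. dev. of 5 runs, 1 loop each)")
```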

With Training Config

For our initial attempt, let's utilize the training configuration as a baseline:

py code
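import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Same 4-bit quantization settings as in the QLoRA training run
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

MODEL_ID = "curiousily/falcon-7b-qlora-chat-support-bot-faq"
config = PeftConfig.from_pretrained(MODEL_ID)

# Load the quantized base model, then attach the fine-tuned LoRA adapter
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)

model = PeftModel.from_pretrained(model, MODEL_ID)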

Here's the result (consistent with the training tutorial):

13 s ± 222 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)

py code
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
markdown code
<human>: What is your return policy? <assistant>: Our return policy allows you
to return eligible items within 30 days of purchase for a full refund.

Pretty slow, right? Let's see how we can improve it.

Loading in 4 Bit

Since we fine-tuned with QLoRA, we can also load the model in 4-bit directly with the load_in_4bit flag of the transformers from_pretrained() method:

py code
MODEL_ID = "curiousily/falcon-7b-qlora-chat-support-bot-faq"
config = PeftConfig.from_pretrained(MODEL_ID)
 
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    device_map="auto",
    trust_remote_code=True,
    load_in_4bit=True,
)
 
model = PeftModel.from_pretrained(model, MODEL_ID)

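Note that we removed the bitsandbytes config, so the default 4-bit settings apply: FP4 instead of NF4 quantization, no double quantization, and a float32 compute dtype instead of bfloat16. Here's the result: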

9.18 s ± 47.8 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)

py code
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
markdown code
<human>: What is your return policy? <assistant>: Our return policy allows you
to return most items within 30 days of purchase. Please see our Returns

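Almost a 4-second improvement (13 s down to 9.18 s). Not bad! Note that the output has changed somewhat, though, likely because the default 4-bit settings differ from the ones used during training.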

Loading in 8 Bit

A natural next step is to load the model in 8-bit, even though we'd expect 4-bit to be faster:

py code
MODEL_ID = "curiousily/falcon-7b-qlora-chat-support-bot-faq"
config = PeftConfig.from_pretrained(MODEL_ID)
 
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    device_map="auto",
    trust_remote_code=True,
    load_in_8bit=True,
)
 
model = PeftModel.from_pretrained(model, MODEL_ID)

3.12 s ± 61.4 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)

py code
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
markdown code
<human>: What is your return policy? <assistant>: Our return policy allows you
to return most items within 30 days of purchase. Please refer to our

Now that's a significant improvement! We're almost 3 times faster than the 4-bit version. But why?

My best guess is that, at the time of writing, the 8-bit code path in the Hugging Face transformers and bitsandbytes integration is better optimized for this setup than the 4-bit one. I expect the 4-bit path to improve in the future.
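To compare the timings so far, the snippet below computes each configuration's speedup over the training-config baseline. It only does arithmetic on the mean times reported above and reruns nothing.

```python
# Mean generation times (seconds) for 20 new tokens, as reported above.
timings = {
    "training config": 13.0,
    "load_in_4bit": 9.18,
    "load_in_8bit": 3.12,
}

baseline = timings["training config"]
for name, seconds in timings.items():
    print(f"{name}: {seconds:.2f} s, {baseline / seconds:.1f}x vs. baseline")
# training config: 13.00 s, 1.0x vs. baseline
# load_in_4bit: 9.18 s, 1.4x vs. baseline
# load_in_8bit: 3.12 s, 4.2x vs. baseline
```

The 8-bit run is 9.18 / 3.12 ≈ 2.9 times faster than the plain 4-bit run.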

torch.compile()

Another popular way to improve inference speed is to use the torch.compile() method. Let's see how it works:

py code
model = torch.compile(model)

3.23 s ± 162 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)

py code
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
markdown code
<human>: What is your return policy? <assistant>: Our return policy allows you
to return most items within 30 days of purchase. Please refer to our

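That's slightly slower than the 8-bit version, although the difference (3.23 s vs. 3.12 s) is smaller than one standard deviation, so torch.compile() made no measurable difference here. You might get better results by compiling the underlying LLM instead of the PeftModel wrapper.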
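One possible reason for the lack of speedup is an assumption on my part, not something verified in these runs. As far as I know, torch.compile(model) returns a wrapper whose forward is compiled, while other attributes such as .generate are looked up on the original module. generate() then calls that original, uncompiled module for every step. The toy classes below reproduce this delegation pattern in pure Python. They are hypothetical stand-ins, not PyTorch, and they show how the "compiled" path can be skipped entirely.

```python
class ToyModel:
    def __call__(self, x):
        return x + 1

    def generate(self, x, steps):
        # Like a generate() loop, this repeatedly calls *this* object's forward.
        for _ in range(steps):
            x = self(x)
        return x


class ToyCompiledWrapper:
    """Mimics a compile wrapper: only direct calls go through the 'compiled' path."""

    def __init__(self, mod):
        self._orig_mod = mod
        self.compiled_calls = 0

    def __call__(self, x):
        self.compiled_calls += 1
        return self._orig_mod(x)

    def __getattr__(self, name):
        # Attributes not defined here (e.g. generate) are looked up on the original module.
        return getattr(self._orig_mod, name)


wrapped = ToyCompiledWrapper(ToyModel())
print(wrapped.generate(0, steps=20))  # 20
print(wrapped.compiled_calls)         # 0: generate() never went through the wrapper
```

If something similar happens here, the thing to try would be compiling the underlying model's forward method rather than wrapping the whole object, in line with the suggestion above.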

Batch Inference

The last technique we'll try with our fine-tuned Falcon 7b model is batch inference[3], where we process multiple prompts in a single generate() call. Let's begin by tokenizing the prompts:

py code
prompts = [
    f"""
<human>: What is your return policy?
<assistant>:
""".strip(),
    f"""
<human>: How can I create an account?
<assistant>:
""".strip(),
    f"""
<human>: What payment methods do you accept?
<assistant>:
""".strip(),
    f"""
<human>: What happens when I return a clearance item?
<assistant>:
""".strip(),
]
 
encoding = tokenizer(prompts, padding=True, truncation=True, return_tensors="pt").to(
    DEVICE
)

Let's look at the results:

3.15 s ± 73.5 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)

py code
for output in outputs:
    print(tokenizer.decode(output, skip_special_tokens=True))
markdown code
<human>: What is your return policy? <assistant>: Our return policy allows you
to return most items within 30 days of purchase for a full refund.
 
<human>: How can I create an account? <assistant>: To create an account, click
on the 'Sign Up' button at the top right corner of
 
<human>: What payment methods do you accept? <assistant>: We accept credit card
payments through our secure payment gateway. We also accept payments through
PayPal and Google Pay
 
<human>: What happens when I return a clearance item? <assistant>: If you return
a clearance item, you will receive a refund for the item's original price minus

Wow, the inference time is almost the same as the fastest version. But we're getting 4 answers instead of 1. That's a huge improvement!
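The gain is clearer in terms of throughput. The rough calculation below uses the reported means. It assumes every sequence produces the full max_new_tokens=20, which is an upper bound because a sequence can stop early at the EOS token.

```python
def throughput(num_sequences: int, new_tokens: int, seconds: float) -> float:
    """Generated tokens per second across the whole batch."""
    return num_sequences * new_tokens / seconds


single = throughput(1, 20, 3.12)   # one prompt (3.12 s run)
batched = throughput(4, 20, 3.15)  # four prompts (3.15 s run)
print(f"single: {single:.1f} tokens/s, batch of 4: {batched:.1f} tokens/s ({batched / single:.1f}x)")
# single: 6.4 tokens/s, batch of 4: 25.4 tokens/s (4.0x)
```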

Batch inference is a great way to increase throughput, but it isn't always applicable. For example, if you're building a chatbot, you can't wait for the user to type 4 questions before you respond. In that case, you can use other techniques, such as caching and pre-computing the answers to the most common questions.

Lit-Parrot

If you're looking for another way to speed up your LLM, you can consider a ready-made library such as Lit-Parrot[4]. This library offers ready-to-use implementations of LLMs based on nanoGPT. Let's see how it works:

py code
!git clone https://github.com/Lightning-AI/lit-parrot
%cd lit-parrot/
!git checkout a9dbfdf

Before we begin, make sure to install the necessary requirements, including a nightly build of PyTorch:

py code
!pip install --index-url https://download.pytorch.org/whl/nightly/cu118 --pre 'torch>=2.1.0dev' -qqq --progress-bar off
!pip install huggingface_hub -r requirements.txt -qqq --progress-bar off

Next, we'll download the Falcon 7b model and convert it to the Lit-Parrot format (this is where the magic happens):

py code
!python scripts/download.py --repo_id tiiuae/falcon-7b
!python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/tiiuae/falcon-7b

Let's try to recreate the generation config from the previous examples and generate an output:

py code
!python generate/base.py \
        --prompt "How can I create an account?" \
        --checkpoint_dir checkpoints/tiiuae/falcon-7b \
        --quantize llm.int8 \
        --max_new_tokens 20
markdown code
Time to instantiate model: 3.26 seconds.
Time to load the model weights: 150.54 seconds.
Global seed set to 1234
How can I create an account? What is the My Invoices area? What are the subscription plans? How does
Time for inference 1: 3.60 sec total, 5.56 tokens/sec
Memory used: 9.72 GB

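Not that fast, right? Also note that this is the base Falcon 7b model, not our fine-tuned one, so it continues with more questions instead of answering. The speed could be limited by many factors; you'll have to dig deeper into the library to understand which optimizations work for your use case.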
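The reported throughput is tokens divided by seconds, so a line of arithmetic compares it with the earlier 8-bit transformers run. Treat this as a rough comparison only: the two runs use different prompts and model weights.

```python
LIT_PARROT = (20, 3.60)  # new tokens, seconds (from the Lit-Parrot log above)
HF_8BIT = (20, 3.12)     # new tokens, seconds (8-bit transformers run)

for name, (tokens, seconds) in {"Lit-Parrot llm.int8": LIT_PARROT, "transformers 8-bit": HF_8BIT}.items():
    print(f"{name}: {tokens / seconds:.2f} tokens/sec")
# Lit-Parrot llm.int8: 5.56 tokens/sec
# transformers 8-bit: 6.41 tokens/sec
```

The first number matches the 5.56 tokens/sec in the log.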

Anyway, the library is an interesting choice if you want to fine-tune an LLM on your own dataset and don't want to deal with the low-level details.

References

Footnotes

  1. Falcon 7b fine-tuning on custom dataset tutorial

  2. timeit — Measure execution time of small code snippets

  3. Batch inference

  4. Lit-Parrot