Faster LLM Inference
Can you optimize the inference time of your LLM? How?
Using a general LLM or fine-tuned model (with LoRA or other techniques) for inference is typically the last step in your AI project. However, if the token generation speed is slow, users may not even give your product a chance. Let's explore strategies to enhance the inference speed of your LLM.
In this part, we will be using a Jupyter Notebook to run the code. If you prefer to follow along, you can find the notebook in the GitHub repository.
Setup
The tests are conducted on a Google Colab instance equipped with a T4 GPU (16 GB VRAM) and 26 GB RAM. We'll use this prompt for most examples:
prompt = f"""
<human>: What is your return policy?
<assistant>:
""".strip()
We'll use the fine-tuned Falcon 7b model with QLoRA from our previous tutorial1.
Tokenizer
Let's load the tokenizer from the original model:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
And encode the prompt:
DEVICE = "cuda:0"
encoding = tokenizer(prompt, return_tensors="pt").to(DEVICE)
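To sanity-check what the tokenizer produced, you can inspect the tensors (an optional check that isn't part of the original notebook):
print(encoding.input_ids.shape)       # torch.Size([1, prompt_length])
print(encoding.attention_mask.shape)  # same shape as input_ids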
Generation Config
We'll use the same generation config for all runs:
generation_config = model.generation_config
generation_config.max_new_tokens = 20
generation_config.temperature = 0
generation_config.num_return_sequences = 1
generation_config.pad_token_id = tokenizer.eos_token_id
generation_config.eos_token_id = tokenizer.eos_token_id
We'll keep max_new_tokens=20, since the inference time scales with the number of tokens generated.
Generating Output
We'll measure the text generation speed using the following code:
%%timeit -r 5
with torch.inference_mode():
    outputs = model.generate(
        input_ids=encoding.input_ids,
        attention_mask=encoding.attention_mask,
        generation_config=generation_config,
        do_sample=False,
        use_cache=True,
    )
The timeit2 magic command will run the code 5 times and report the average time.
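If you prefer a tokens-per-second number over raw wall-clock time (which makes runs with different max_new_tokens values easier to compare), here's a minimal helper. It's a sketch of the same measurement, not part of the original notebook:
import time

import torch

def tokens_per_second(model, encoding, generation_config):
    # Time a single generate() call and normalize by the number of newly
    # generated tokens (total output length minus the prompt length).
    start = time.perf_counter()
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=encoding.input_ids,
            attention_mask=encoding.attention_mask,
            generation_config=generation_config,
        )
    elapsed = time.perf_counter() - start
    new_tokens = outputs.shape[-1] - encoding.input_ids.shape[-1]
    return new_tokens / elapsed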
With Training Config
For our initial attempt, let's utilize the training configuration as a baseline:
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

MODEL_ID = "curiousily/falcon-7b-qlora-chat-support-bot-faq"
config = PeftConfig.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(model, MODEL_ID)
Here's the result (consistent with the training tutorial):
13 s ± 222 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
<human>: What is your return policy? <assistant>: Our return policy allows you
to return eligible items within 30 days of purchase for a full refund.
Pretty slow, right? Let's see how we can improve it.
Loading in 4 Bit
Since we performed fine-tuning using the QLoRA technique, we can now attempt to load the model in 4-bit format directly using the transformers library:
MODEL_ID = "curiousily/falcon-7b-qlora-chat-support-bot-faq"
config = PeftConfig.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    device_map="auto",
    trust_remote_code=True,
    load_in_4bit=True,
)
model = PeftModel.from_pretrained(model, MODEL_ID)
Note that we removed the bitsandbytes config. Here's the result:
9.18 s ± 47.8 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
<human>: What is your return policy? <assistant>: Our return policy allows you
to return most items within 30 days of purchase. Please see our Returns
Almost a 4 sec improvement. Not bad! Note that the output has changed somewhat, though.
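As an aside, passing load_in_4bit=True here is essentially shorthand for supplying a default 4-bit BitsAndBytesConfig. If you prefer to keep the quantization settings explicit, the call below should be roughly equivalent (a sketch, assuming the library's default 4-bit behaviour):
bnb_config = BitsAndBytesConfig(load_in_4bit=True)

model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    device_map="auto",
    trust_remote_code=True,
    quantization_config=bnb_config,
)
model = PeftModel.from_pretrained(model, MODEL_ID)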
Loading in 8 Bit
A natural next step is to load the model in 8-bit, even though you'd expect 4-bit to be faster:
MODEL_ID = "curiousily/falcon-7b-qlora-chat-support-bot-faq"
config = PeftConfig.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    device_map="auto",
    trust_remote_code=True,
    load_in_8bit=True,
)
model = PeftModel.from_pretrained(model, MODEL_ID)
3.12 s ± 61.4 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
<human>: What is your return policy? <assistant>: Our return policy allows you
to return most items within 30 days of purchase. Please refer to our
Now that's a significant improvement! We're almost 3 times faster than the 4bit version. But why?
My best guess is that the current optimizations in the HuggingFace transformers library work really well with 8bit. I'm sure we'll see improvements in the future.
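If you want to compare how much GPU memory each variant actually uses, transformers models expose get_memory_footprint(); the PEFT wrapper should forward the call to the underlying model (this check wasn't part of the original run):
# Optional check: report the loaded model's memory footprint.
print(f"Memory footprint: {model.get_memory_footprint() / 1024 ** 3:.2f} GB")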
torch.compile()
Another popular way to improve inference speed is to use torch.compile(). Let's see how it works:
model = torch.compile(model)
3.23 s ± 162 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
<human>: What is your return policy? <assistant>: Our return policy allows you
to return most items within 30 days of purchase. Please refer to our
We're slightly slower than the 8-bit version. Maybe you can improve it by compiling the actual LLM instead of the PeftModel wrapper, as in the sketch below.
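Here's a minimal sketch of that idea, assuming the PEFT wrapper exposes the underlying transformers model as base_model.model (verify the attribute path for your version of peft):
# Compile only the underlying transformers model; the PeftModel wrapper
# keeps running eagerly and delegates to the compiled module.
model.base_model.model = torch.compile(model.base_model.model)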
Batch Inference
The last approach to improve inference speed with our fine-tuned Falcon 7b model is to utilize batch inference3, where we process multiple prompts simultaneously. Let's begin by tokenizing these prompts:
prompts = [
    f"""
<human>: What is your return policy?
<assistant>:
""".strip(),
    f"""
<human>: How can I create an account?
<assistant>:
""".strip(),
    f"""
<human>: What payment methods do you accept?
<assistant>:
""".strip(),
    f"""
<human>: What happens when I return a clearance item?
<assistant>:
""".strip(),
]

encoding = tokenizer(prompts, padding=True, truncation=True, return_tensors="pt").to(
    DEVICE
)
We reuse the same %%timeit generation cell as before, now with the batched encoding. Let's look at the results:
3.15 s ± 73.5 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)
for output in outputs:
    print(tokenizer.decode(output, skip_special_tokens=True))
<human>: What is your return policy? <assistant>: Our return policy allows you
to return most items within 30 days of purchase for a full refund.
<human>: How can I create an account? <assistant>: To create an account, click
on the 'Sign Up' button at the top right corner of
<human>: What payment methods do you accept? <assistant>: We accept credit card
payments through our secure payment gateway. We also accept payments through
PayPal and Google Pay
<human>: What happens when I return a clearance item? <assistant>: If you return
a clearance item, you will receive a refund for the item's original price minus
Wow, the inference time is almost the same as the fastest single-prompt run, but we're getting 4 answers instead of 1. That's roughly a 4x boost in effective throughput!
Batch inference is a great way to improve inference speed, but it's not always possible. For example, if you're building a chatbot, you can't wait for the user to type 4 questions before you respond. In this case, you can use a combination of techniques such as caching and pre-computing the most common answers.
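For example, a simple in-memory cache for repeated questions might look like this (a hypothetical sketch; cached_answer and the exact prompt template are illustrations, not part of the original notebook):
from functools import lru_cache

@lru_cache(maxsize=1024)
def cached_answer(question: str) -> str:
    # Only hit the model for questions we haven't answered before;
    # repeated questions are served straight from the cache.
    prompt = f"<human>: {question}\n<assistant>:"
    encoding = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=encoding.input_ids,
            attention_mask=encoding.attention_mask,
            generation_config=generation_config,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)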
Lit-Parrot
If you're looking for an alternative way to enhance the inference time of your LLM, you can consider using a pre-made library such as Lit-Parrot4. This library offers ready-to-use implementations of LLMs based on nanoGPT. Let's see how it works:
!git clone https://github.com/Lightning-AI/lit-parrot
%cd lit-parrot/
!git checkout a9dbfdf
Before we begin, make sure to install the necessary requirements, including a nightly build of PyTorch:
!pip install --index-url https://download.pytorch.org/whl/nightly/cu118 --pre 'torch>=2.1.0dev' -qqq --progress-bar off
!pip install huggingface_hub -r requirements.txt -qqq --progress-bar off
Next, we'll download the Falcon 7b model and convert it to the Lit-Parrot format (this is where the magic happens):
!python scripts/download.py --repo_id tiiuae/falcon-7b
!python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/tiiuae/falcon-7b
Let's try to recreate the generation config from the previous examples and generate an output:
!python generate/base.py \
--prompt "How can I create an account?" \
--checkpoint_dir checkpoints/tiiuae/falcon-7b \
--quantize llm.int8 \
--max_new_tokens 20
Time to instantiate model: 3.26 seconds.
Time to load the model weights: 150.54 seconds.
Global seed set to 1234
How can I create an account? What is the My Invoices area? What are the subscription plans? How does
Time for inference 1: 3.60 sec total, 5.56 tokens/sec
Memory used: 9.72 GB
Not that fast, right? This might be due to a number of reasons, but you'll have to dive deeper into the library to understand which optimizations work for your use case.
Anyway, the library is an interesting choice if you want to fine-tune an LLM on your own dataset and don't want to deal with the low-level details.