In this blog post, I'll share my method for deploying fast inference services for Large Language Models (LLMs). We'll use NVIDIA's Triton Inference Server to deploy the Llama 3.1 8B Instruct model. The goal of this post is to help you understand the intricacies and attention to detail required for developing an effective inference solution. While using API services or one-click deployment solutions is common, such options often limit learning opportunities. This blog aims to bridge that gap by walking you through a hands-on approach.
Deploying LLMs involves numerous considerations. Different models have different use cases, and you may need to deploy fine-tuned models or replicate a model to ensure high throughput and reliable availability of your API service. These considerations generally fall into four main deployment categories:
In this post, we'll focus on option 1 for simplicity.
Another important aspect of LLM deployment is choosing the right backend. Options include PyTorch, Python, and TensorRT. For fast inference, I've selected TensorRT—specifically TensorRT-LLM—because it provides significant performance advantages that other backends lack.
Before starting, be sure to consult the Support Matrix to determine which models and GPUs are compatible with the TensorRT build. For instance, the NVIDIA A100 GPU, with 80 GB of memory, supports TensorRT-LLM models.
The Llama 3.1 8B parameter model, which requires around 16 GB of memory, can be loaded onto the A100 GPU without quantization. However, with its maximum context length of 128k tokens, it will need an additional 15.62 GB of GPU memory, totaling approximately 31 GB. You can refer to the inference chart for further memory requirements based on other configurations.
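If you want to sanity-check these numbers yourself, here is a rough back-of-the-envelope calculation. It assumes bf16 weights (2 bytes per parameter) and Llama 3.1 8B's published architecture (32 layers, 8 KV heads, head dimension 128); treat it as an estimate, not an exact accounting.
# Weights: ~8B parameters x 2 bytes (bf16), roughly 16 GB
echo $(( 8 * 2 ))
# KV cache per token: 2 (K and V) x 32 layers x 8 KV heads x 128 head dim x 2 bytes = 131072 bytes (128 KiB)
echo $(( 2 * 32 * 8 * 128 * 2 ))
# KV cache for a 128k-token context, in whole GiB (about 15.6 GiB before rounding down)
echo $(( 128000 * 131072 / 1024 / 1024 / 1024 ))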
Get a VM with an A100 GPU running an Ubuntu LTS Linux distribution.
SSH into the system and update your packages.
ssh [server-config-name]
sudo apt update
sudo apt upgrade
Make sure you have Docker installed and the NVIDIA Container Toolkit (the nvidia container runtime) configured, so that the NVIDIA GPU drivers are accessible from inside containers.
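One optional quick check is to run nvidia-smi inside the same CUDA base image we will build from later; if it prints your A100, GPU passthrough is working.
# Should print the A100 details; if it errors, revisit the NVIDIA Container Toolkit setup
docker run --rm --gpus all nvidia/cuda:12.5.1-devel-ubuntu22.04 nvidia-smi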
Also, install Miniconda, as it will streamline dependency management for this project.
Finally, create a dev folder in your home directory to keep your scripts, configuration files, and other resources organized.
mkdir ~/dev
cd ~/dev
In this step, we will set up the environment to build a TensorRT-LLM model by installing the model weights, building the Docker image, converting weights to TensorRT format, and verifying the model with some basic tests.
First, download the model weights from Hugging Face. Since Meta requires you to sign an agreement, you must generate a Hugging Face token to authenticate the download. You can do this by logging into Hugging Face and going to Settings > Access Tokens.
pip install --upgrade huggingface_hub
huggingface-cli login --token *****
MODEL_DIR="meta-llama-3.1-8b-instruct"
REPO_ID="meta-llama/Meta-Llama-3.1-8B-Instruct"
huggingface-cli download $REPO_ID --local-dir $MODEL_DIR
Alternatively, you can use git-lfs to download the model weights if you're familiar with Git Large File Storage (LFS), which is often used for managing large files in git repositories.
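For reference, the git-lfs route usually looks something like the following sketch; since the repo is gated, git will prompt for your Hugging Face username and token. These are not the exact commands from this workflow, just one common way to do it.
git lfs install
git clone https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct meta-llama-3.1-8b-instruct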
Clone the TensorRT-LLM repository from GitHub. The earliest release that supports the Llama 3.1 model is v0.11.0.
TENSOR_RT_VERSION="v0.11.0"
git clone -b $TENSOR_RT_VERSION https://github.com/NVIDIA/TensorRT-LLM.git
At this point, your ~/dev folder should look like this:
tree -L 1
# ~/dev
# ├── TensorRT-LLM
# └── meta-llama-3.1-8b-instruct
Next, create a Dockerfile for building the model. This Dockerfile will include the necessary dependencies to support TensorRT-LLM.
touch Dockerfile.build_model
# ~/dev/Dockerfile.build_model
# NOTE: TensorRT-LLM v0.11.0
# NOTE: Image name=tensorrtllm_build_image
FROM nvidia/cuda:12.5.1-devel-ubuntu22.04
RUN apt update && apt upgrade -y && apt -y install python3.10 python3-pip python-is-python3 openmpi-bin libopenmpi-dev git git-lfs
RUN pip3 install tensorrt_llm==0.11.0 --extra-index-url https://pypi.nvidia.com
RUN pip install numpy==1.26.4 transformers==4.43.1 optimum==1.21.4
CMD ["echo","hi"]
Build the Docker image:
docker build -t tensorrtllm_build_image -f Dockerfile.build_model .
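Optionally, you can sanity-check the image before using it: the tensorrt_llm wheel should import cleanly and report version 0.11.0, and the GPU should be visible from inside the image. This is a convenience check I like to run, not a required step.
# Confirm the TensorRT-LLM wheel imports and the GPU is visible from the image
docker run --rm --runtime=nvidia --gpus all tensorrtllm_build_image python3 -c "import tensorrt_llm; print(tensorrt_llm.__version__)"
docker run --rm --runtime=nvidia --gpus all tensorrtllm_build_image nvidia-smi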
Then, run the image as a container with the following command:
docker run --rm --runtime=nvidia --gpus all --volume ${PWD}:/tensorrtllm --workdir /tensorrtllm --entrypoint /bin/bash -it tensorrtllm_build_image
- --rm: Automatically removes the container after it stops.
- --runtime=nvidia and --gpus all: Grant access to NVIDIA GPUs within the container.
- --volume ${PWD}:/tensorrtllm --workdir /tensorrtllm: Mounts the current working directory (~/dev) into the container at /tensorrtllm, making its files accessible inside the container.
- --entrypoint /bin/bash -it: Opens an interactive Bash session inside the container.

To build the TensorRT-LLM engine, you first need to convert the Llama model into a TensorRT-LLM checkpoint. This checkpoint is then used to build the engine that performs fast inference.
Inside the container, define the following directories for storing the TensorRT checkpoints and engine files:
# INSIDE DOCKER CONTAINER: tensorrtllm_build_image
# Re-define the $MODEL_DIR variable inside the container shell
CHECKPOINT_DIR="./llama3.1_tllm_checkpoint_1gpu_bf16"
ENGINE_DIR="./tmp/llama/3.1/8B/trt_engines/bf16/gpu/1"
Now we can convert the Llama model weights to a TensorRT-LLM checkpoint. Llama is trained in bfloat16 precision, hence --dtype bfloat16. We could use a lower precision, which would consume less memory but would also reduce inference accuracy.
# Create Checkpoint Dir
python3 TensorRT-LLM/examples/llama/convert_checkpoint.py --model_dir $MODEL_DIR --output_dir $CHECKPOINT_DIR --dtype bfloat16
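If the conversion succeeds, the checkpoint directory should contain the converted weights and a model config; for a single-GPU (tp=1) conversion you can expect something like config.json and rank0.safetensors, though exact filenames may vary by version.
# Inspect the converted checkpoint
ls $CHECKPOINT_DIR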
Build the model with the trtllm-build command.
# Build From Checkpoint
trtllm-build --checkpoint_dir $CHECKPOINT_DIR --output_dir $ENGINE_DIR --max_seq_len 65000 --max_num_tokens 16384 --max_input_len 64000
Check out all configuration options for trtllm-build here.
Now, verify the model accuracy by running a few tests. If the output is as expected (i.e. the Hugging Face model and the TensorRT model have similar outputs), the build was successful.
# RUN Examples
python3 TensorRT-LLM/examples/run.py --engine_dir=$ENGINE_DIR --max_output_len 100 --tokenizer_dir $MODEL_DIR --input_text "How do I count to nine in French?"
# Expect output to count to 9.
To test the Llama3.1-specific prompt template:
# LLAMA3.1 specific prompt
python3 TensorRT-LLM/examples/run.py --engine_dir=$ENGINE_DIR --max_output_len 1000 --tokenizer_dir $MODEL_DIR --input_text "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nTools: brave_search, wolfram_alpha\n\nCutting Knowledge Date: December 2023\nToday Date: 23 Jul 2024\n\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nWhat is the current weather in Toronto, Ontario?<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
# Expect assistant to use `brave_search` tool to look up the weather in toronto ontario.
Exercise: test the model's context length by creating a long prompt in sample_long_context_file.txt and running:
# RUN FROM FILE
python3 TensorRT-LLM/examples/run.py --engine_dir=$ENGINE_DIR --max_output_len 2000 --tokenizer_dir $MODEL_DIR --input_text "$(cat sample_long_context_file.txt)"
After successful tests, you can remove the checkpoint directory, exit the container, and move to the next step.
rm -r $CHECKPOINT_DIR
exit
Next, clone the tensorrtllm_backend codebase.
git clone -b $TENSOR_RT_VERSION https://github.com/triton-inference-server/tensorrtllm_backend.git
Note: the NVIDIA/TensorRT-LLM and triton-inference-server/tensorrtllm_backend branch versions must match.
The current status of the ~/dev folder should be the following:
tree -L 1
# ~/dev
# ├── Dockerfile.build_model
# ├── TensorRT-LLM
# ├── meta-llama-3.1-8b-instruct
# └── tensorrtllm_backend
In this sub-step, we will download the configuration files and Python scripts that allow the TensorRT-LLM model to work with the Triton server, and configure them to our needs.
In the ~/dev folder, create the model repository directory:
mkdir -p all_models/inflight_batcher_llm/
Create the following variables:
SUFFIX="llama3.1"
CONFIG="https://raw.githubusercontent.com/triton-inference-server/tensorrtllm_backend/$TENSOR_RT_VERSION/all_models/inflight_batcher_llm"
TRITON_ENGINE_DIR="all_models/inflight_batcher_llm/tensorrt_llm_$SUFFIX/1/"
- SUFFIX: Used to distinguish the different models, in case we decide to deploy other models alongside this one.
- CONFIG: The base URL used to download the specific files needed for the Triton server; it depends on the TensorRT-LLM version. Alternatively, you can copy the files from the tensorrtllm_backend GitHub repo cloned in step 2.0.
- TRITON_ENGINE_DIR: The location the built engine will be copied to.

# Re-define the $ENGINE_DIR and $MODEL_DIR variables here on the host (outside the build container)
mkdir -p $TRITON_ENGINE_DIR
cp $ENGINE_DIR/* $TRITON_ENGINE_DIR
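After the copy, the Triton model's version folder should hold the engine artifacts produced by trtllm-build (typically a rank0.engine plus a config.json); a quick listing confirms nothing went missing.
# Engine artifacts should now live in the Triton model's version folder
ls $TRITON_ENGINE_DIR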
For the Llama 3.1 model we will use an ensemble to do inference. This means the ensemble model is the one clients call, and it orchestrates the order of operations for inference: preprocessing (tokenization), the tensorrt_llm engine itself, and postprocessing (detokenization).
# --------------
# preprocessing
# --------------
wget -P all_models/inflight_batcher_llm/preprocessing_$SUFFIX $CONFIG/preprocessing/config.pbtxt
wget -P all_models/inflight_batcher_llm/preprocessing_$SUFFIX/1 $CONFIG/preprocessing/1/model.py
sed -i 's/name: "preprocessing"/name: "preprocessing_${SUFFIX}"/1' all_models/inflight_batcher_llm/preprocessing_$SUFFIX/config.pbtxt
python3 tensorrtllm_backend/tools/fill_template.py -i all_models/inflight_batcher_llm/preprocessing_$SUFFIX/config.pbtxt tokenizer_dir:${MODEL_DIR},SUFFIX:${SUFFIX},tokenizer_type:auto,triton_max_batch_size:64,preprocessing_instance_count:1
In the script above:
- The config.pbtxt and model.py files are downloaded.
- sed updates the name argument in the config.pbtxt file to include the _$SUFFIX suffix.
- fill_template.py fills in the template values (tokenizer directory, batch size, instance count, and the suffix itself).

Postprocessing and the ensemble are set up with the script below:
# --------------
# postprocessing
# --------------
wget -P all_models/inflight_batcher_llm/postprocessing_$SUFFIX $CONFIG/postprocessing/config.pbtxt
sed -i 's/name: "postprocessing"/name: "postprocessing_${SUFFIX}"/1' all_models/inflight_batcher_llm/postprocessing_$SUFFIX/config.pbtxt
python3 tensorrtllm_backend/tools/fill_template.py -i all_models/inflight_batcher_llm/postprocessing_$SUFFIX/config.pbtxt tokenizer_dir:${MODEL_DIR},SUFFIX:${SUFFIX},tokenizer_type:auto,triton_max_batch_size:64,postprocessing_instance_count:1
wget -P all_models/inflight_batcher_llm/postprocessing_$SUFFIX/1 $CONFIG/postprocessing/1/model.py
# --------------
# ensemble
# --------------
wget -P all_models/inflight_batcher_llm/ensemble_$SUFFIX $CONFIG/ensemble/config.pbtxt
sed -i 's/name: "ensemble"/name: "ensemble_${SUFFIX}"/1' all_models/inflight_batcher_llm/ensemble_$SUFFIX/config.pbtxt
sed -i 's/model_name: "preprocessing"/model_name: "preprocessing_${SUFFIX}"/1' all_models/inflight_batcher_llm/ensemble_$SUFFIX/config.pbtxt
sed -i 's/model_name: "tensorrt_llm"/model_name: "tensorrt_llm_${SUFFIX}"/1' all_models/inflight_batcher_llm/ensemble_$SUFFIX/config.pbtxt
sed -i 's/model_name: "postprocessing"/model_name: "postprocessing_${SUFFIX}"/1' all_models/inflight_batcher_llm/ensemble_$SUFFIX/config.pbtxt
mkdir all_models/inflight_batcher_llm/ensemble_$SUFFIX/1
python3 tensorrtllm_backend/tools/fill_template.py -i all_models/inflight_batcher_llm/ensemble_$SUFFIX/config.pbtxt triton_max_batch_size:64,SUFFIX:$SUFFIX
Below is the tensorrt_llm script. This is fairly involved and will change depending on the model you use as well as the deployment architecture.
# --------------
# tensorrt_llm
# --------------
wget -P all_models/inflight_batcher_llm/tensorrt_llm_$SUFFIX $CONFIG/tensorrt_llm/config.pbtxt
sed -i 's/name: "tensorrt_llm"/name: "tensorrt_llm_${SUFFIX}"/1' all_models/inflight_batcher_llm/tensorrt_llm_$SUFFIX/config.pbtxt
python3 tensorrtllm_backend/tools/fill_template.py -i all_models/inflight_batcher_llm/tensorrt_llm_$SUFFIX/config.pbtxt SUFFIX:$SUFFIX,triton_max_batch_size:64,decoupled_mode:True,max_beam_width:1,engine_dir:$TRITON_ENGINE_DIR,max_tokens_in_paged_kv_cache:unspecified,max_attention_window_size:32000,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,triton_backend:tensorrtllm,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:0,gpu_device_ids:0,batch_scheduler_policy:max_utilization
Note: for streaming, decoupled mode must be set to true in the compiled model config; we already did this above with decoupled_mode:True.
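Before launching the server, it is worth a quick spot-check that the renaming and template filling landed where we expect. This is an optional verification I find handy, not part of the official workflow.
# Expect ensemble_llama3.1, preprocessing_llama3.1, postprocessing_llama3.1 and tensorrt_llm_llama3.1
tree -L 2 all_models/inflight_batcher_llm
# Each top-level name field should now carry the _llama3.1 suffix
grep -n '^name:' all_models/inflight_batcher_llm/*/config.pbtxt
If the names look right, launch the Triton Inference Server container: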
sudo docker run -it --rm --gpus all --network host --shm-size=1g -v $(pwd):/triton_server --workdir /triton_server nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3
# In docker container
python3 tensorrtllm_backend/scripts/launch_triton_server.py --model_repo all_models/inflight_batcher_llm --world_size 1 --multi-model
The -it flag keeps the container session interactive, so the server output streams to your terminal. The launch_triton_server.py script merely calls the tritonserver CLI, so you can run the server without the launch script as follows:
MODEL_REPOSITORY=all_models/inflight_batcher_llm
docker run --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 --name triton-server-trtllm --gpus all --shm-size=2g -v $(pwd):/triton_server --workdir /triton_server nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3 mpirun --allow-run-as-root -n 1 /opt/tritonserver/bin/tritonserver --model-repository=/triton_server/$MODEL_REPOSITORY --grpc-port=8001 --http-port=8000 --metrics-port=8002 --disable-auto-complete-config --backend-config=python,shm-region-prefix-name=prefix0_ :
MODEL_REPOSITORY=all_models/inflight_batcher_llm
TRITON_LOG_FILE_NAME='logs/triton-image-24-07'
docker run -d --restart on-failure:2 -p 8000:8000 -p 8001:8001 -p 8002:8002 --name triton-server-trtllm --gpus all --shm-size=2g -v $(pwd):/triton_server --workdir /triton_server nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3 mpirun --allow-run-as-root --output-filename mpirun_triton_server_logs -n 1 /opt/tritonserver/bin/tritonserver --model-repository=/triton_server/$MODEL_REPOSITORY --log-verbose=2 --log-file="/triton_server/$TRITON_LOG_FILE_NAME-datetime-$(date '+%Y-%m-%d-%H-%M-%S').log" --grpc-port=8001 --http-port=8000 --metrics-port=8002 --disable-auto-complete-config --backend-config=python,shm-region-prefix-name=prefix0_ :
- --allow-run-as-root is dangerous and should not be used in production.
- --restart on-failure:2 will restart the container at most 2 times if it fails.
- The -d flag runs the container in the background.
- -p exposes ports 8000, 8001, and 8002, which Triton needs.
- -n 1 launches a single mpi process running one Triton server instance.
- The server writes its log file into the logs directory (create it on the host first, e.g. mkdir -p logs) and serves the models in the $MODEL_REPOSITORY folder.

To remove the container, for example before relaunching it under the same name, run:
docker container rm triton-server-trtllm
Make sure the server is running:
curl http://localhost:8000/
# {"error":"Not Found"}
Make sure the model is ready:
curl http://localhost:8000/v2/models/ensemble_llama3.1/ready
# expect status code 200, no output
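In deployment scripts, rather than sleeping for a fixed time, you can poll the readiness endpoint until it returns 200. Here is a small bash loop that waits up to about a minute; it is a convenience sketch, not something the Triton docs prescribe.
# Poll readiness for up to ~60 seconds
for i in $(seq 1 60); do code=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8000/v2/models/ensemble_llama3.1/ready); [ "$code" = "200" ] && echo "model ready" && break; sleep 1; done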
curl -X POST http://localhost:8000/v2/models/ensemble_llama3.1/generate -d '{"text_input": "How do I count to nine in French?","max_tokens":100}'
# Status Code: 200
# Json Output
# {"context_logits":0.0,"cum_log_probs":0.0,"generation_logits":0.0,"model_name":"ensemble_llama3.1","model_version":"1","output_log_probs":[0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0],"sequence_end":false,"sequence_id":0,"sequence_start":false,"text_output":" Compter jusqu'à neuf en français.\nTo count to nine in French, you can use the following numbers:\nUn (one)\nDeux (two)\nTrois (three)\nQuatre (four)\nCinq (five)\nSix (six)\nSept (seven)\nHuit (eight)\nNeuf (nine)\nYou can count up to nine by saying:\nUn, deux, trois, quatre, cinq, six, sept, huit, neuf.\nIf you want to count"}
Now that we know this works as expected, let's create automated tests for this API service.
import httpx
# pip install httpx

# Config
base_url = "http://localhost:8000"
headers = {"Content-Type": "application/json"}
model_name = "ensemble_llama3.1"


# Test 1
def test1():
    print("-" * 20, "test1", "-" * 20)
    data = {
        "text_input": "How do I count to nine in French?",
        "parameters": {"max_tokens": 100, "bad_words": [""], "stop_words": [""]},
    }
    url = f"{base_url}/v2/models/{model_name}/generate"
    resp = httpx.post(url, headers=headers, json=data)
    if resp.status_code == 200:
        jd = resp.json()
        print(jd["text_output"])
    else:
        print(f"status_code={resp.status_code}: {resp.text}")


# Test 2
def test2():
    print("-" * 20, "test2", "-" * 20)
    query = "What is the current weather in Toronto, Ontario?"
    print(f"User> {query}")
    initial_prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Environment: ipython
Tools: brave_search, wolfram_alpha

Cutting Knowledge Date: December 2023
Today Date: 23 Jul 2024

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

{query}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>"""
    data = {
        "text_input": initial_prompt,
        "parameters": {
            "max_tokens": 1_000,
            "bad_words": [""],
            "stop_words": ["<|eom_id|>", "<|eot_id|>"],
        },
    }
    url = f"{base_url}/v2/models/{model_name}/generate"
    resp = httpx.post(url, headers=headers, json=data)
    if resp.status_code == 200:
        jd = resp.json()
        assistant_response = jd["text_output"].strip()
        print(f"Assistant> {assistant_response}")
        if "brave_search" in assistant_response:
            tool_call_output = "27 deg celsius"
            subsequent_prompt = f"""{initial_prompt}
<|python_tag|>{assistant_response}<|eom_id|>
<|start_header_id|>ipython<|end_header_id|>

{tool_call_output}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
            data["text_input"] = subsequent_prompt
            resp = httpx.post(url, headers=headers, json=data)
            if resp.status_code == 200:
                jd = resp.json()
                print(f'Assistant> {jd["text_output"]}')
        else:
            print("No tool call")
    else:
        print(f"status_code={resp.status_code}: {resp.text}")


if __name__ == "__main__":
    test1()
    test2()
    print("[DONE] - ran all tests")
Exercise: write a Python script that lets you chat with the Llama 3.1 model using the token generation pipeline. You can use the /generate_stream endpoint instead of /generate.
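As a starting point for the exercise, you can hit the streaming endpoint with curl first. Since we built the config with decoupled_mode:True, the server can stream, and responses come back as Server-Sent-Events-style data: {...} lines. The "stream": true field is my assumption based on the ensemble's stream input; drop it if your config version does not expose it.
# -N disables curl buffering so tokens print as they arrive
curl -s -N -X POST http://localhost:8000/v2/models/ensemble_llama3.1/generate_stream -d '{"text_input": "How do I count to nine in French?", "max_tokens": 100, "stream": true}'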
In this blog post, we walked through the process of deploying a fast inference service for Large Language Models (LLMs) using NVIDIA's Triton Inference Server. We used the Llama 3.1 8B Instruct model as an example, but the steps can be applied to other models as well. We covered the prerequisites, including choosing the right backend and model, and then delved into the hands-on process of building the TensorRT model, setting up the Triton Server, and testing the service.
This blog post is useful for anyone who wants to deploy LLMs in a production environment, whether it's for research, development, or commercial use. By following the steps outlined in this post, you can create a fast and efficient inference service that can handle a large volume of requests.
If you're interested in trying out the concepts discussed in this post, here are some additional exercises: