Transformers FLOPS and memory usage

Understanding memory usage and the number of floating point operations required to run inference using Transformers

Introduction

Every day we hear of new AI models that are larger and better than ever before. As of writing this post, state-of-the-art models have hundreds of billions of parameters (DeepSeek R1, for example, has about 671 billion parameters). How much GPU memory is needed to run such a model? How long will inference take on a given GPU? In this post, I walk through an example to understand how to calculate memory usage and estimate the time taken to run inference. This post assumes familiarity with the transformer architecture (https://arxiv.org/pdf/1706.03762).

In the first part of this post, I count all the floating point operations that happen during the forward pass. Why is this important? Every GPU has a maximum number of floating point operations it can perform per second (FLOPS), which is specified in the GPU specs. If we can count the floating point operations our model performs for a single token, we can estimate the time taken for the forward pass. (Assuming we have a KV cache, I calculate the FLOP for a single new token; more details about this later in the post.)

In the second part of this post, I calculate the GPU memory usage for a given model. This includes the memory needed to store the model parameters and the KV cache.

Floating point operations

Given a transformer model and a GPU, can we approximately calculate the inference time without running any code? We can, using this simple formula for the number of floating point operations per token:

24.n_{layers}.d_{model}^2 + 2.d_{vocab}.d_{model}

where

\begin{array}{c|c} n_{layers}: & \text{Number of transformer blocks} \\ d_{model}: & \text{Model embedding dimension} \\ d_{vocab}: & \text{Total number of tokens in the vocabulary} \end{array}

In the next few sub-sections, I go through in detail how I arrived at the above equation. Feel free to skip to the memory section if you aren’t particularly interested in the details.

A GPU datasheet specifies the theoretical maximum number of floating point operations it can perform per second. Using the transformer model architecture, let’s calculate the number of floating point operations (FLOP) per token.

FLOP for matrix multiplication

Let’s consider an example of matrix multiplication (see the Lecture Notes on FLOPS for basic operations).

A = \begin{bmatrix} a_{00} & a_{01} & a_{02} \\ a_{10} & a_{11} & a_{12} \end{bmatrix} \quad B = \begin{bmatrix} b_{00} & b_{01} & b_{02} & b_{03} \\ b_{10} & b_{11} & b_{12} & b_{13} \\ b_{20} & b_{21} & b_{22} & b_{23} \end{bmatrix}

A is a 2x3 matrix and B is a 3x4 matrix. The product of A and B will have 2x4 elements. In the generalized case, when a matrix of dimensions m x n is multiplied with a matrix of dimensions n x p, the product is an m x p matrix.

Calculating the first element of AxB,

a_{00}.b_{00} + a_{01}.b_{10} + a_{02}.b_{20}

Notice that there are 3 multiplication operations and 2 addition operations. In the general case, there are n multiplications and n-1 additions to calculate one element of the product matrix.

n + (n-1) \approx 2n

Since there are 8 elements in AxB, and m x p elements in the general case, the total number of floating point operations to calculate the product of two matrices is

\approx (2n)*m*p = 2mnp \tag{1}
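
To sanity-check (1), here is a minimal Python sketch (my own addition, not part of the original derivation) that compares the exact FLOP count of a naive matrix multiply with the 2mnp approximation.

```python
# Compare the exact FLOP count of a naive (m x n) @ (n x p) matrix multiply
# with the 2*m*n*p approximation from equation (1).

def matmul_flop(m: int, n: int, p: int) -> tuple[int, int]:
    """Return (exact, approximate) FLOP counts for an (m x n) @ (n x p) product."""
    exact = m * p * (n + (n - 1))  # n multiplications + (n - 1) additions per output element
    approx = 2 * m * n * p
    return exact, approx

if __name__ == "__main__":
    # The 2x3 @ 3x4 example above: 8 output elements, 5 FLOP each.
    print(matmul_flop(2, 3, 4))        # (40, 48)
    # For large n the two counts converge.
    print(matmul_flop(1, 4096, 4096))  # (33550336, 33554432)
```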

FLOP in a transformer block

A transformer block consists of two sub-blocks, a multi-headed attention block and a feed-forward block, which are shown in the picture below.

Let’s calculate the number of floating point operations in each of the steps in the transformer block.

\begin{array}{l l} t_{e}: & \text{Token embedding} \quad (\mathbb{R}^{1xd_{model}}) \\ d_{model}: & \text{Model embedding dimension} \\ d_{k}: & \text{Key, query and value dimension per head in the multi-headed attention block} \\ W_{Q}, W_{K}, W_{V}: & \text{Query, key and value matrices} \quad (\mathbb{R}^{d_{model}xd_{k}}) \\ n_{heads}: & \text{Number of heads in the multi-headed attention block} \\ n_{layers}: & \text{Number of transformer blocks in the model} \\ d_{vocab}: & \text{Vocabulary size for tokenization} \end{array}

Note that typically

d_{k} = \frac{d_{model}}{n_{heads}}

Transformer architecture (image courtesy: Attention Is All You Need)

For a single token, the first step is calculating the keys, queries and values for all the attention heads.

q = W_{Q}^{T}t_{e}^{T}

The number of floating point operations to calculate q using (1) is

2d_{model}d_{k} = \frac{2d_{model}^2}{n_{heads}} \tag{2}

The same number of FLOP are required to calculate k and v. Using (2), the total FLOP count per head is

3.\frac{2d_{model}^2}{n_{heads}} = \frac{6d_{model}^2}{n_{heads}}

With all the heads combined, the total FLOP count is

6d_{model}^2 \tag{3}

Scaled dot product attention is calculated using

\begin{aligned} attention &= softmax\left(\frac{qk^T}{d_k^{0.5}}\right).v \\ qk^T &= 2.d_{k} = \frac{2.d_{model}}{n_{heads}} \quad \text{FLOP} \\ softmax\left(\frac{qk^T}{d_k^{0.5}}\right) &= 2.d_{k} = \frac{2.d_{model}}{n_{heads}} \quad \text{FLOP} \\ softmax\left(\frac{qk^T}{d_k^{0.5}}\right).v &= 2.d_{model} \quad \text{FLOP} \\ \text{Total FLOP in a single attention head} &= \frac{4.d_{model}}{n_{heads}} + 2.d_{model} \\ \text{Total FLOP with all the attention heads} &= 4.d_{model} + 2.d_{model}.n_{heads} \tag{4} \end{aligned}

Concatenation does not involve any matrix multiplications. The output of the concatenation is a \mathbb{R}^{1xd_{model}} vector. The linear layer is a \mathbb{R}^{d_{model}xd_{model}} matrix. Based on (1), the number of FLOP in the linear layer is (assuming the linear layers have zero bias; even if they don’t, the biases will not contribute significantly to the FLOP count)

2.d_{model}^2 \tag{5}

The feed-forward layer consists of two linear layers. The first layer is a \mathbb{R}^{d_{model}x4d_{model}} matrix and the second layer is a \mathbb{R}^{4d_{model}xd_{model}} matrix. Therefore, the number of FLOP in these layers is

2.4.d_{model}^2 + 2.4.d_{model}^2 = 16.d_{model}^2 \tag{6}

Adding (3),(4),(5) and (6), the total FLOP in a transformer block is

6.d_{model}^2 + 4.d_{model} + 2.d_{model}.n_{heads} + 2.d_{model}^2 + 16.d_{model}^2 \approx 24.d_{model}^2 \tag{7}

When d_{model} is sufficiently large, the d_{model}^2 terms dominate, so we can ignore all the terms with a d_{model} coefficient. Note that I have ignored the layernorm and residual calculations; these are also a constant factor times d_{model} and are sufficiently small.
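
As a quick check of (7), the following sketch (my own, using the per-step counts from equations (3) to (6); the function name is hypothetical) sums the exact per-token FLOP of one transformer block and compares it with the 24.d_{model}^2 approximation.

```python
def transformer_block_flop(d_model: int, n_heads: int) -> dict:
    """Per-token FLOP in one transformer block, following equations (3)-(6)."""
    qkv = 6 * d_model**2                        # (3) query/key/value projections, all heads
    attn = 4 * d_model + 2 * d_model * n_heads  # (4) scaled dot-product attention terms
    proj = 2 * d_model**2                       # (5) output projection after concatenation
    ffn = 16 * d_model**2                       # (6) two feed-forward layers (4x expansion)
    exact = qkv + attn + proj + ffn
    return {"exact": exact, "approx": 24 * d_model**2}

if __name__ == "__main__":
    # A GPT2-XL sized block: d_model = 1600 with 25 heads.
    print(transformer_block_flop(1600, 25))
    # {'exact': 61526400, 'approx': 61440000}
```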

FLOP for other layers

The transformer model consists of a token embedding, a position embedding, n_{layers} transformer blocks and a final linear layer with \mathbb{R}^{d_{model}xd_{vocab}} parameters. Token and position embeddings are lookups, so there is no matrix multiplication. Using (1), the number of FLOP in the final linear layer is

2.d_{model}.d_{vocab} \tag{8}

FLOP for the full model

The transformer model has n_{layers} transformer blocks. Putting together (7) and (8), the total FLOP per token is

24.d_{model}^2.n_{layers} + 2.d_{model}.d_{vocab} \tag{9}

To get an idea of the scale of FLOP required to do the forward pass in a transformer using (9), I calculated the FLOP count per token for a few models. As a sanity check, this works out to roughly 2 FLOP per model parameter.

\begin{array}{c|c|c|c|c} \text{Model Name} & \text{Layers} & \text{Embedding dimension} & \text{Vocabulary Size} & \text{FLOP per token} \\ \hline \text{GPT2 XL (1.55B)} & 48 & 1600 & 50257 & 3.11\text{G} \\ \text{Llama 3.1 (8B)} & 32 & 4096 & 128000 & 13.9\text{G} \\ \text{Llama 3.1 (405B)} & 126 & 16384 & 128000 & 816\text{G} \end{array}

My Nvidia 3060 GPU can theoretically do 101T FLOPS at FP16 (half precision), according to its Tom’s Hardware page. If I run inference on the Llama 3.1 8B model, the compute-bound limit is about \frac{101\text{T}}{13.9\text{G}} \approx 7250 tokens per second. (Note that I have not considered the memory needed to store the parameters; in practice, single-token decoding is usually limited by memory bandwidth rather than compute, so real throughput is far lower.)
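
For convenience, here is a small sketch that wraps (9) into a function and estimates the compute-bound decode speed. The function names are my own, and the GPU throughput is an input assumption taken from the datasheet figure above.

```python
def flop_per_token(n_layers: int, d_model: int, d_vocab: int) -> int:
    """Per-token forward-pass FLOP from equation (9)."""
    return 24 * n_layers * d_model**2 + 2 * d_model * d_vocab

def tokens_per_second(n_layers: int, d_model: int, d_vocab: int, gpu_flops: float) -> float:
    """Compute-bound decode speed; ignores memory capacity and bandwidth."""
    return gpu_flops / flop_per_token(n_layers, d_model, d_vocab)

if __name__ == "__main__":
    # Llama 3.1 8B on an RTX 3060 (~101 TFLOPS at FP16).
    print(flop_per_token(32, 4096, 128000))             # ~1.39e10 FLOP per token
    print(tokens_per_second(32, 4096, 128000, 101e12))  # ~7250 tokens/s (theoretical)
```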

Memory requirements

In a transformer model, the two major components that take up GPU memory are the model parameters and the KV cache. The formula to estimate the memory usage is

2.n_{params} + 4.n_{layers}.d_{model}.n_{s} \quad bytes

where

\begin{array}{c|c} n_{params}: & \text{Number of parameters in the model} \\ n_{s}: & \text{Number of tokens held in the KV cache} \end{array}

In the following sub-sections, I will show how I arrived at this formula.

Model parameters

Typically we store all the parameters in half precision (FP16), so each parameter requires 2 bytes. If a model has n_{params} parameters, the memory required to store them is simply 2.n_{params} bytes.
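
A minimal sketch of this calculation (the helper name and the FP16 default are my own assumptions):

```python
def param_memory_gb(n_params: float, bytes_per_param: int = 2) -> float:
    """Memory needed to hold the weights, assuming FP16 (2 bytes per parameter)."""
    return n_params * bytes_per_param / 1e9

if __name__ == "__main__":
    print(param_memory_gb(1.55e9))  # GPT2 XL: ~3.1 GB
    print(param_memory_gb(8e9))     # Llama 3.1 8B: ~16 GB
    print(param_memory_gb(405e9))   # Llama 3.1 405B: ~810 GB
```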

KV Cache

Let’s first understand why a KV cache is necessary. In an autoregressive model like GPT-2, the next token depends on the previous tokens. Attention is calculated using the formula

Attention(Q,K,V) = softmax\left(\frac{QK^T}{d_k^{0.5}}\right).V

Q, K and V are matrices of dimensions d_{model}xn_{s}, where n_{s} is the length of the sequence. Notice that the QK^T computation grows quadratically with the length of the sequence. Without a cache, the key and value calculations are repeated for all the previous tokens at every step, which is wasteful (see KV cache explained). Caching the keys and values makes the QK^T computation linear in the sequence length n_{s}.
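
To make the idea concrete, here is a minimal single-head decoding sketch in NumPy (my own illustration, not code from the post or any library): each step computes the projections only for the new token, appends them to the cache, and attends over the full cached sequence.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_k = 64, 64                      # single head, so d_k == d_model here
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_k)) for _ in range(3))

k_cache, v_cache = [], []                  # grows by one row per generated token

def decode_step(t_e: np.ndarray) -> np.ndarray:
    """Attend from the newest token over all cached positions (single head)."""
    q = t_e @ W_Q                          # only the new token's projections are computed
    k_cache.append(t_e @ W_K)
    v_cache.append(t_e @ W_V)
    K = np.stack(k_cache)                  # (n_s, d_k)
    V = np.stack(v_cache)                  # (n_s, d_k)
    scores = (K @ q) / np.sqrt(d_k)        # (n_s,) -- linear in the sequence length
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()               # softmax over cached positions
    return weights @ V                     # (d_k,)

for _ in range(5):                         # pretend we decode 5 tokens
    out = decode_step(rng.standard_normal(d_model))
print(out.shape)                           # (64,)
```

A real implementation preallocates the cache and handles all heads and layers at once, but the per-step cost above already grows only linearly with the number of cached tokens.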

Let’s calculate how much memory is required to cache a single token.

\begin{array}{c|c} \text{To store keys for a single attention head} & 2.d_k \\ \text{Storing keys and values} & 4.d_k \\ \text{For all heads} & 4.d_k.n_{heads} = 4.d_{model} \\ \text{For all layers in the transformer} & 4.n_{layers}.d_{model} \quad \text{bytes} \tag{10} \end{array}

Using (10), the memory needed to store 1000 tokens for a few models is shown below

\begin{array}{c|c|c|c} \text{Model Name} & \text{Layers} & \text{Embedding dimension} & \text{Memory per 1000 tokens} \\ \hline \text{GPT2 XL (1.55B)} & 48 & 1600 & 0.307\text{GB} \\ \text{Llama 3.1 (8B)} & 32 & 4096 & 0.524\text{GB} \\ \text{Llama 3.1 (405B)} & 126 & 16384 & 8.25\text{GB} \end{array}

On my Nvidia 3060 GPU (12GB of memory), the GPT2 XL parameters take about 3.1GB at FP16, leaving roughly 9GB free, so I can cache about \frac{9\text{GB}}{0.307\text{GB per 1000 tokens}} \approx 30\text{k} tokens. The memory used for the KV cache grows linearly with both n_{layers} and the embedding dimension d_{model}.
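
Putting the two memory terms together, here is a small sketch (my own helpers, assuming FP16 weights and cache; the 12 GB budget is just my GPU’s capacity) that estimates how many tokens of KV cache fit after loading the weights.

```python
def kv_bytes_per_token(n_layers: int, d_model: int) -> int:
    """Equation (10): 2 bytes each for key and value per head, over all heads and layers."""
    return 4 * n_layers * d_model

def max_cacheable_tokens(n_params: float, n_layers: int, d_model: int, gpu_bytes: float) -> int:
    """Tokens of KV cache that fit after storing the FP16 weights."""
    free = gpu_bytes - 2 * n_params
    return int(free // kv_bytes_per_token(n_layers, d_model))

if __name__ == "__main__":
    # GPT2 XL (1.55B params, 48 layers, d_model=1600) on a 12 GB card.
    print(max_cacheable_tokens(1.55e9, 48, 1600, 12e9))  # ~29k tokens, matching the estimate above
```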

References

  1. Attention is all you need
  2. Kipply’s blog on transformer inference arithmetic
  3. Andrej Karpathy’s nanoGPT implementation
  4. KV cache explained with images/GIFs