EL-Attention: Memory Efficient Lossless Attention for Generation
Transformer models are great for NLP tasks, but also very resource intensive!
EL-attention proposes a way to reduce the memory requirements during inference, and this post will try to provide a simple summary of the approach.
Attention mechanism
The attention mechanism is the workhorse of modern NLP and computer vision. There are many variants of this operation, but we'll start with a simple implementation in PyTorch:
import torch

# We'll use typical values here
bz = 2
hidden_dim = 512
key_value_dim = 128
seq_size = 10

W_Q = torch.rand(hidden_dim, key_value_dim)  # Query projection
W_K = torch.rand(hidden_dim, key_value_dim)  # Key projection
W_V = torch.rand(hidden_dim, key_value_dim)  # Value projection
W_O = torch.rand(key_value_dim, hidden_dim)  # Output projection

hidden_state = torch.rand(bz, seq_size, hidden_dim)
def vanilla_attention(hidden_state):
    # Project hidden state into queries, keys and values
    q_proj = hidden_state @ W_Q
    k_proj = hidden_state @ W_K
    v_proj = hidden_state @ W_V
    # Compute scores between queries and keys
    # (scaling by sqrt(key_value_dim) omitted for brevity)
    scores = q_proj @ k_proj.transpose(-2, -1)
    probs = scores.softmax(-1)
    # Weighted sum of values
    weighted_sum = probs @ v_proj
    # Project the sum with the output matrix
    hidden_out = weighted_sum @ W_O
    return hidden_out
vanilla_output = vanilla_attention(hidden_state)
print(vanilla_output.shape) # (bz, seq_size, hidden_dim)
Auto-regressive generation
At inference time, transformers generate one token at a time (incremental decoding). It would be inefficient to re-compute all the projections needed for self-attention at every step, so instead we cache the keys and values of the previous tokens at every layer. Then, using only the hidden state of the last generated token, we can generate far more efficiently.
def vanilla_attention_with_cache(hidden_state, kv_cache=None):
    # During decoding, hidden_state has shape (bz, 1, hidden_dim)
    q_proj = hidden_state @ W_Q
    k_proj = hidden_state @ W_K
    v_proj = hidden_state @ W_V
    # Prepend the cached keys and values from previous steps
    if kv_cache is not None:
        k_cache, v_cache = kv_cache
        k_proj = torch.cat([k_cache, k_proj], dim=1)
        v_proj = torch.cat([v_cache, v_proj], dim=1)
    scores = q_proj @ k_proj.transpose(-2, -1)
    probs = scores.softmax(-1)
    weighted_sum = probs @ v_proj
    hidden_out = weighted_sum @ W_O
    kv_cache = (k_proj, v_proj)
    return hidden_out, kv_cache
We've traded memory for computational efficiency with kv_cache: the cache grows with the number of past tokens. When you consider

- multiple layers and the larger hidden_dim of bigger models
- longer sequences
- large batch sizes

this quickly becomes expensive!
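To get a feel for the scale, here's a back-of-the-envelope estimate. The model shape below is an assumption for illustration, not taken from the paper:

# Rough KV cache size for a hypothetical decoder
n_layers = 12
d_model = 1024       # assuming num_heads * key_value_dim == d_model
batch = 32
past_len = 1024
bytes_per_el = 4     # fp32

# Keys + values: 2 tensors of d_model floats per token, per layer
kv_cache_bytes = 2 * n_layers * batch * past_len * d_model * bytes_per_el
print(f"{kv_cache_bytes / 2**30:.1f} GiB")  # 3.0 GiB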
EL-Attention
This paper proposes an alternative way of computing the attention operation. By re-ordering some of the matrix multiplication steps, we can get away with caching just the past hidden states.
Matrix multiplication properties

There are two properties we are specifically interested in: associativity, (A @ B) @ C = A @ (B @ C), and the transpose of a product, (A @ B).T = B.T @ A.T.

Looking at the scores computation: in the incremental decoding case with key-value caching, the hidden states used for computing Q and K will be different. The one used for Q has sequence length 1, while the one for K covers all past tokens. Since the cached keys are just h_cache @ W_K, we can re-order:

scores = q_proj @ (h_cache @ W_K).T = q_proj @ W_K.T @ h_cache.T

By re-ordering the operations we can see that hidden_state caching can be substituted for key caching when computing scores.
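To convince ourselves, here's a quick numerical check of the re-ordering, re-using the matrices defined at the top of the post:

# Check: (H1 @ W_Q) @ (H @ W_K).T == (H1 @ W_Q) @ W_K.T @ H.T
H1 = torch.rand(bz, 1, hidden_dim)        # last token's hidden state
H = torch.rand(bz, seq_size, hidden_dim)  # past hidden states
scores_key_cache = (H1 @ W_Q) @ (H @ W_K).transpose(-2, -1)
scores_hidden_cache = (H1 @ W_Q) @ W_K.T @ H.transpose(-2, -1)
print(torch.allclose(scores_key_cache, scores_hidden_cache))  # True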
We still have to address value caching, but this too can be re-ordered and computed with just the past hidden states:

probs @ (h_cache @ W_V) @ W_O = (probs @ h_cache) @ W_V @ W_O
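And the same kind of sanity check for the value side, re-using H from the snippet above:

# Check: probs @ (H @ W_V) == (probs @ H) @ W_V
probs = torch.rand(bz, 1, seq_size).softmax(-1)
print(torch.allclose(probs @ (H @ W_V), (probs @ H) @ W_V))  # True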
Here’s the attention function rewritten with EL-attention
def el_attention(hidden_state, h_cache=None):
    q_proj = hidden_state @ W_Q
    # Cache raw hidden states instead of keys and values
    if h_cache is not None:
        h_cache = torch.cat([h_cache, hidden_state], dim=1)
    else:
        h_cache = hidden_state
    # Re-ordered scores: q_proj @ (h_cache @ W_K).T == q_proj @ W_K.T @ h_cache.T
    scores = q_proj @ W_K.T @ h_cache.transpose(-2, -1)
    probs = scores.softmax(-1)
    # Re-ordered values: probs @ (h_cache @ W_V) == (probs @ h_cache) @ W_V
    weighted_sum = probs @ h_cache
    hidden_out = weighted_sum @ W_V @ W_O
    return hidden_out, h_cache
output, kv_cache = vanilla_attention_with_cache(hidden_state)
output_el, h_cache = el_attention(hidden_state)
print(torch.allclose(output, output_el))  # True
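We can also replay incremental decoding token by token to check that the two cache schemes stay equivalent at every step. A quick sketch (the loose tolerance is just to absorb floating point differences from the re-ordered matmuls):

# Feed the same tokens one at a time through both paths
kv_cache, h_cache = None, None
for t in range(seq_size):
    token = hidden_state[:, t : t + 1, :]
    out_vanilla, kv_cache = vanilla_attention_with_cache(token, kv_cache)
    out_el, h_cache = el_attention(token, h_cache)
    assert torch.allclose(out_vanilla, out_el, rtol=1e-3)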
Now we've verified that vanilla attention and EL-attention produce the same output.
I’ve omitted the implementation for multi-head attention, but the memory savings still apply even in that case.
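For the curious, here's what a minimal multi-head version might look like. This is my sketch, not the paper's implementation; the point is that every head re-uses the same single h_cache, so the cache cost stays one hidden-state tensor no matter how many heads you have:

num_heads = 4
head_dim = key_value_dim  # per-head key/value dimension
# Per-head projection matrices (shapes chosen for this sketch)
W_Q_heads = torch.rand(num_heads, hidden_dim, head_dim)
W_K_heads = torch.rand(num_heads, hidden_dim, head_dim)
W_V_heads = torch.rand(num_heads, hidden_dim, head_dim)
W_O_heads = torch.rand(num_heads * head_dim, hidden_dim)

def el_attention_multi_head(hidden_state, h_cache=None):
    if h_cache is not None:
        h_cache = torch.cat([h_cache, hidden_state], dim=1)
    else:
        h_cache = hidden_state
    head_outputs = []
    for h in range(num_heads):
        q_proj = hidden_state @ W_Q_heads[h]
        # Same re-ordering trick as above, per head, against the shared h_cache
        scores = q_proj @ W_K_heads[h].T @ h_cache.transpose(-2, -1)
        probs = scores.softmax(-1)
        head_outputs.append(probs @ h_cache @ W_V_heads[h])
    hidden_out = torch.cat(head_outputs, dim=-1) @ W_O_heads
    return hidden_out, h_cache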
For decoder-only models, we've effectively halved¹ the amount of cache memory required. For encoder-decoder models, EL-attention gets even better: we can completely get rid of the key-value caches for cross-attention and share the same encoder hidden states across all the decoder cross-attention layers.
Caveats
EL-attention is almost a free optimization that you can apply to any transformer model, but there are a couple of caveats:
(Mostly) Incompatible with RoPE embeddings - RoPE applies a position-dependent transformation to the queries and keys, and that transformation is hard to move around in the way EL-attention requires.
Flash Attention - You can still use EL-attention in conjunction with it, but since the scores/softmax are computed in the hidden_state dimension, which is usually much larger than the key/value dimension, this larger dimension prevents the use of the fast Flash Attention kernels.
¹ Most transformer blocks follow `num_heads * kv_dim = hidden_dim`, so caching keys and values costs 2 × hidden_dim per token per layer, versus hidden_dim for caching hidden states.