Optimizing Relative Positional Encodings in Enformer and Borzoi
What does the relative shift do?
Models like Enformer and Borzoi use relative positional encodings so that the attention between two sequence positions can depend on the distance between them. Having spent quite some time with Borzoi and Enformer, I wondered how this works, and what the mysterious `relative_shift` function in current Borzoi and Enformer implementations is for:
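```python
# from lucidrains/enformer_pytorch
import torch

def relative_shift(x):
    # x: relative logits R of shape (batch, heads, N, 2N - 1)
    to_pad = torch.zeros_like(x[..., :1])
    x = torch.cat((to_pad, x), dim = -1)  # (b, h, N, 2N): prepend a zero column
    _, h, t1, t2 = x.shape
    x = x.reshape(-1, h, t2, t1)          # (b, h, 2N, N)
    x = x[:, :, 1:, :]                    # drop the first row -> (b, h, 2N - 1, N)
    x = x.reshape(-1, h, t1, t2 - 1)      # (b, h, N, 2N - 1); row i now starts at column N - 1 - i of R
    return x[..., :((t2 + 1) // 2)]       # keep the first N columns -> (b, h, N, N)
```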
In brief, in Transformer-XL-style attention with relative positional encodings, the product of the queries and the learned relative positional keys, \(R = QK_\text{rel}^\top\), is added to the standard content logits \(QK^\top\) before the softmax.
Here, \(Q \in \mathbb{R}^{N \times d}\) is the query matrix and \(K_\text{rel} \in \mathbb{R}^{(2N-1) \times d}\) is the relative position matrix, with one row for each relative distance from \(-(N-1)\) to \(N-1\).
Computing \(R\) as a full matrix product, however, produces entries that are never used, such as relative positions that cannot occur for a given query (e.g., negative distances for the first query).
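To quantify the waste, assume column \(c\) of \(R\) holds relative distance \(c - (N-1)\), the convention that the `relative_shift` function above relies on. Query \(i\) attending to key \(j\) then needs column \(c(i, j)\), and the number of unused entries follows directly:

```latex
\begin{aligned}
c(i, j) &= (j - i) + (N - 1), \qquad 0 \le i, j \le N - 1,\\
\{\, c(i, j) : 0 \le j \le N - 1 \,\} &= \{\, N - 1 - i, \dots, 2N - 2 - i \,\},\\
\underbrace{N(2N - 1)}_{\text{computed}} - \underbrace{N^2}_{\text{used}} &= N(N - 1) \ \text{unused entries}.
\end{aligned}
```

Nearly half of \(R\) is therefore discarded. For the first query (\(i = 0\)) the needed columns are \(N-1, \dots, 2N-2\), which hold the non-negative distances.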

To extract the relevant entries, Enformer and Borzoi use the relative shift operation. For each query \(q_i\), it keeps the \(N\) entries of row \(i\) of \(R\) that correspond to the relative distances \(j - i\), \(j = 0, \dots, N-1\). From one row to the next, this window moves one column to the left, so the selected entries form an anti-diagonal band of \(R\). As the code above shows, the reference implementation prepends a column of zeros to \(R\) so that a sequence of reshapes and slices can carry out this extraction.
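A quick way to see what `relative_shift` selects is to feed it a toy \(R\) whose entries encode their own row and column. The sketch below uses an arbitrary \(N = 3\) with batch and head sizes of 1, repeats the function so it runs on its own, and compares its output with direct indexing at column \((j - i) + (N - 1)\):

```python
import torch

# Copy of relative_shift from lucidrains/enformer_pytorch, repeated so this runs standalone.
def relative_shift(x):
    to_pad = torch.zeros_like(x[..., :1])
    x = torch.cat((to_pad, x), dim=-1)
    _, h, t1, t2 = x.shape
    x = x.reshape(-1, h, t2, t1)
    x = x[:, :, 1:, :]
    x = x.reshape(-1, h, t1, t2 - 1)
    return x[..., :((t2 + 1) // 2)]

N = 3
# Toy R with batch = heads = 1 and R[i, c] = 10 * i + c, so each value records its position.
rows = torch.arange(N).view(N, 1)
cols = torch.arange(2 * N - 1).view(1, 2 * N - 1)
R = (10 * rows + cols).float().view(1, 1, N, 2 * N - 1)

out = relative_shift(R)[0, 0]
print(out.tolist())  # [[2.0, 3.0, 4.0], [11.0, 12.0, 13.0], [20.0, 21.0, 22.0]]

# Direct indexing: out[i, j] should equal R[i, (j - i) + (N - 1)].
i = torch.arange(N).view(N, 1)
j = torch.arange(N).view(1, N)
expected = R[0, 0][i, (j - i) + (N - 1)]
assert torch.equal(out, expected)
```

Row \(i\) starts at column \(N-1-i\) (values 2, 11, 20), so each row's window sits one column further left than the one above it.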
[Figure: the relative shift illustrated without batch and head dimensions]

Prepending this padding column requires a concatenation, which copies all of \(R\) into a new tensor; since \(R\) grows quadratically with sequence length, this step is costly for long sequences. Profiling on an Nvidia A40 GPU shows that it accounts for:
- 15% of forward pass time in Borzoi, where the transformer operates on 4096 embeddings at 128 bp resolution.
- 5% of forward pass time in Enformer, whose transformer operates on 1536 embeddings.
Most of this overhead stems from memory traffic: concatenation does no arithmetic, but it reads and writes every element of \(R\), and on GPUs such operations are limited by memory bandwidth rather than compute.
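A back-of-the-envelope estimate shows why the copy hurts. The hypothetical helper `concat_traffic_bytes` below counts the bytes the concatenation moves, assuming it reads \(R\) once and writes the padded \(N \times 2N\) copy once, for float32 values, and ignoring the zero column and any caching:

```python
def concat_traffic_bytes(n, heads=1, batch=1, bytes_per_elem=4):
    """Rough memory traffic of torch.cat((zeros, R), dim=-1) for R of shape
    (batch, heads, n, 2n - 1): R is read once and the padded (n, 2n) copy is
    written once. The zero column and any cache effects are ignored."""
    read = n * (2 * n - 1)
    write = n * (2 * n)
    return batch * heads * (read + write) * bytes_per_elem

# One sequence, one head, float32 (assumed; real runs may use other dtypes).
borzoi_like = concat_traffic_bytes(4096)    # 268_419_072 bytes, about 268 MB
enformer_like = concat_traffic_bytes(1536)  # 37_742_592 bytes, about 38 MB
print(borzoi_like, enformer_like, round(borzoi_like / enformer_like, 2))
```

The roughly 7x ratio is close to \((4096/1536)^2\), so under these assumptions the traffic grows quadratically with length. That is consistent with the concatenation taking a larger share of Borzoi's forward pass than Enformer's, although the exact percentages also depend on the rest of each model.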
A More Efficient Implementation
We can eliminate the concatenation entirely by indexing directly into \(R\) with an offset and strides. The view starts at an offset of \(N-1\) elements, the middle of the first row of \(R\), and uses a row stride of \(2(N-1)\) instead of the \(2N-1\) of a contiguous row, so each successive row starts one column further to the left. This yields the staircase-shaped selection as a view, without copying any data.
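The offset and stride can be read off from the row-major layout of \(R\). A short derivation, assuming the flattened \(R\) (shape \(N \times (2N-1)\)) starts at storage offset 0 and that column \(c\) holds relative distance \(c - (N-1)\):

```latex
\begin{aligned}
\operatorname{flat}(i, c) &= i\,(2N-1) + c,\\
\operatorname{flat}\bigl(i,\ (j - i) + (N-1)\bigr) &= i\,(2N-1) + j - i + N - 1\\
&= \underbrace{(N-1)}_{\text{offset}} + i \cdot \underbrace{2(N-1)}_{\text{row stride}} + j \cdot \underbrace{1}_{\text{column stride}}.
\end{aligned}
```

The largest index, at \(i = j = N-1\), is \(2N(N-1) \le N(2N-1) - 1\), so the view never reads past the end of \(R\).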
This selects exactly the same entries as the original method but avoids the extra memory traffic. Using PyTorch's `as_strided` and `vmap` functions, we can implement it efficiently by vmapping over the batch and head dimensions. Note that the function also computes \(R\) itself, so it replaces both the einsum and `relative_shift`:
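```python
import torch
from torch import einsum

def fast_relative_shift(a, b):
    # a: queries of one (batch, head), shape (N, d)
    # b: relative keys of one head, shape (2N - 1, d)
    return (
        einsum("i d, j d -> i j", a, b)  # R = a @ b.T, shape (N, 2N - 1)
        .flatten()
        .as_strided(
            size=(a.shape[0], a.shape[0]),  # result has shape N x N, like QK^T
            stride=((a.shape[0] - 1) * 2, 1),  # start a new row every 2(N-1) elements
            storage_offset=a.shape[0] - 1,  # start in the middle of the first row of R
        )
    )

# Inner vmap: heads (queries and relative keys both have a head dim).
# Outer vmap: batch (the relative keys have no batch dim, but the queries do).
fast_relative_shift = torch.vmap(torch.vmap(fast_relative_shift), in_dims=(0, None))
```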
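The vmapped version works on one (batch, head) slice at a time. As an alternative sketch, not taken from either codebase, the same view can be built with a single `as_strided` call by spelling out the batch and head strides of a contiguous \(R\) of shape \((B, H, N, 2N-1)\). The helper name `strided_relative_shift` is hypothetical; the check compares it with the reference `relative_shift` on random data with Enformer-style shapes (queries \((B, H, N, d)\), relative keys \((H, 2N-1, d)\)):

```python
import torch

def relative_shift(x):
    # Reference implementation from lucidrains/enformer_pytorch.
    to_pad = torch.zeros_like(x[..., :1])
    x = torch.cat((to_pad, x), dim=-1)
    _, h, t1, t2 = x.shape
    x = x.reshape(-1, h, t2, t1)
    x = x[:, :, 1:, :]
    x = x.reshape(-1, h, t1, t2 - 1)
    return x[..., :((t2 + 1) // 2)]

def strided_relative_shift(r):
    """Hypothetical vmap-free variant: select R[b, h, i, (j - i) + (N - 1)]
    with one as_strided view over a contiguous (B, H, N, 2N - 1) tensor."""
    r = r.contiguous()  # the strides below assume a row-major layout
    b, h, n, m = r.shape
    assert m == 2 * n - 1, "expected 2N - 1 relative positions"
    return r.as_strided(
        size=(b, h, n, n),
        stride=(h * n * m, n * m, m - 1, 1),  # row stride 2(N - 1) = m - 1
        storage_offset=r.storage_offset() + n - 1,  # middle of each first row
    )

torch.manual_seed(0)
B, H, N, d = 2, 4, 16, 8
q = torch.randn(B, H, N, d)
k_rel = torch.randn(H, 2 * N - 1, d)
R = torch.einsum("bhid,hjd->bhij", q, k_rel)  # relative logits, (B, H, N, 2N - 1)

ref = relative_shift(R)
fast = strided_relative_shift(R)
assert fast.shape == ref.shape == (B, H, N, N)
assert torch.equal(fast, ref)  # pure indexing, so the values match exactly
```

Unlike `fast_relative_shift`, this helper takes a precomputed \(R\), so it can replace `relative_shift` at its call site. It returns a non-contiguous view: the memory reads are deferred to whichever op consumes it, such as the addition to the content logits.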
Benchmarks
Benchmarking the optimized implementation on an Nvidia A40 GPU gives the following speedups:
- Borzoi:
- 19.2% faster forward passes.
- 28.5% faster forward and backward passes.
- Enformer:
- 4.3% faster forward passes.
- 7.3% faster forward and backward passes.
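The post does not show its benchmark harness, so the following is only a sketch of how a comparison of the shift itself could be set up with `torch.utils.benchmark`; the head count, sequence lengths and the hypothetical `strided_relative_shift` helper are assumptions, not the setup behind the numbers above. Because a strided view is almost free to create, the sketch times each shift together with the addition and softmax that consume it:

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark

def relative_shift(x):  # reference version from lucidrains/enformer_pytorch
    to_pad = torch.zeros_like(x[..., :1])
    x = torch.cat((to_pad, x), dim=-1)
    _, h, t1, t2 = x.shape
    x = x.reshape(-1, h, t2, t1)
    x = x[:, :, 1:, :]
    x = x.reshape(-1, h, t1, t2 - 1)
    return x[..., :((t2 + 1) // 2)]

def strided_relative_shift(r):  # hypothetical vmap-free as_strided variant
    r = r.contiguous()
    b, h, n, m = r.shape
    return r.as_strided((b, h, n, n), (h * n * m, n * m, m - 1, 1), r.storage_offset() + n - 1)

def attention_weights(content_logits, rel_logits, shift):
    # The memory reads of a strided view happen here, in the add and softmax.
    return F.softmax(content_logits + shift(rel_logits), dim=-1)

device = "cuda" if torch.cuda.is_available() else "cpu"
B, H = 1, 8                            # assumed batch size and head count
N = 1536 if device == "cuda" else 256  # keep the CPU fallback small
content = torch.randn(B, H, N, N, device=device)
rel = torch.randn(B, H, N, 2 * N - 1, device=device)

results = {}
for name, shift in [("cat + reshape", relative_shift), ("as_strided", strided_relative_shift)]:
    timer = benchmark.Timer(
        stmt="attention_weights(content, rel, shift)",
        globals={"attention_weights": attention_weights, "content": content,
                 "rel": rel, "shift": shift},
    )
    results[name] = timer.blocked_autorange(min_run_time=0.2).median  # seconds
    print(f"{name}: {results[name] * 1e3:.3f} ms")
```

This isolates one piece of one attention layer, so its ratios will not match the full-model forward and backward numbers.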

Importantly, these gains are achieved without any change in numerical accuracy; the outputs remain identical to the original implementation.
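Since the benchmarks also cover backward passes, the gradients should match as well. A minimal check on random stand-in data, reusing the hypothetical vmap-free `strided_relative_shift` rather than the vmapped function:

```python
import torch

def relative_shift(x):  # reference version from lucidrains/enformer_pytorch
    to_pad = torch.zeros_like(x[..., :1])
    x = torch.cat((to_pad, x), dim=-1)
    _, h, t1, t2 = x.shape
    x = x.reshape(-1, h, t2, t1)
    x = x[:, :, 1:, :]
    x = x.reshape(-1, h, t1, t2 - 1)
    return x[..., :((t2 + 1) // 2)]

def strided_relative_shift(r):  # hypothetical vmap-free as_strided variant
    r = r.contiguous()
    b, h, n, m = r.shape
    return r.as_strided((b, h, n, n), (h * n * m, n * m, m - 1, 1), r.storage_offset() + n - 1)

torch.manual_seed(0)
B, H, N = 2, 3, 10
R = torch.randn(B, H, N, 2 * N - 1, requires_grad=True)
W = torch.randn(B, H, N, N)  # stands in for the upstream gradient

# Gradient of sum(shift(R) * W) with respect to R, for both versions.
(g_ref,) = torch.autograd.grad((relative_shift(R) * W).sum(), R)
(g_fast,) = torch.autograd.grad((strided_relative_shift(R) * W).sum(), R)

assert torch.equal(g_ref, g_fast)
# Only the N * N selected entries per (batch, head) receive gradient; the rest stay zero.
print(int((g_ref != 0).sum()), "nonzero of", g_ref.numel())  # 600 nonzero of 1140
```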
Conclusion
This optimized approach significantly reduces the computational overhead of relative positional encodings, especially for models like Borzoi that process long sequences. As sequence lengths increase in next-generation models, such optimizations could play a critical role in improving efficiency without sacrificing performance. FlexAttention doesn’t yet work for supervised genomics models…
Acknowledgments
I’d like to thank Alexander Karollus and Laura Martens for fun evenings investigating the peculiarities of Borzoi and Enformer.