![](https://crypto4nerd.com/wp-content/uploads/2023/11/1TPuxEStK-3at1fcsLn4edA-1024x269.png)
In this article, we will go through the RoFormer paper, which introduced rotary positional embeddings (RoPE) for the transformer architecture. We will also implement it using the JAX deep learning framework.
Before jumping into RoPE (rotary positional encoding), let's first discuss the positional encoding introduced for the transformer architecture in the original transformer paper.
If you want to understand transformers better, I would suggest these videos (I), (II), and (III). I will assume that you know the basics of transformers and how multi-head attention works.
- The self-attention formulation is as follows
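As a reminder (the original equation image is not reproduced here), the scaled dot-product attention from the original transformer paper can be written as:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where $Q$, $K$, and $V$ are the query, key, and value projections of the token embeddings and $d_k$ is the dimension of the keys.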
- Here, the attention for the query vector of the token at position m is given by
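Following the notation of the RoFormer paper, with $q_m = W_q x_m$, $k_n = W_k x_n$, and $v_n = W_v x_n$, the attention weights and the output for the token at position $m$ can be written as:

$$a_{m,n} = \frac{\exp\!\left(\frac{q_m^\top k_n}{\sqrt{d}}\right)}{\sum_{j=1}^{N}\exp\!\left(\frac{q_m^\top k_j}{\sqrt{d}}\right)}, \qquad o_m = \sum_{n=1}^{N} a_{m,n}\, v_n$$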
- We take the dot product of the query projection with the key projections of all the tokens (only the preceding tokens for a decoder, all tokens for an encoder), apply softmax to obtain attention weights, and use them to take a weighted sum of the value projections. Finally, the result is passed through an output projection (another linear layer).
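As a minimal sketch of this computation (a single attention head, no masking or batching, with illustrative weight names `W_q`, `W_k`, `W_v`, `W_o`):

```python
import jax.numpy as jnp
from jax import random, nn

def self_attention(x, W_q, W_k, W_v, W_o):
    """Single-head self-attention over a sequence x of shape (seq_len, d_model)."""
    q = x @ W_q                                 # query projections, (seq_len, d_k)
    k = x @ W_k                                 # key projections,   (seq_len, d_k)
    v = x @ W_v                                 # value projections, (seq_len, d_k)
    scores = q @ k.T / jnp.sqrt(q.shape[-1])    # dot products, (seq_len, seq_len)
    weights = nn.softmax(scores, axis=-1)       # attention weights per query
    out = weights @ v                           # weighted sum of values
    return out @ W_o                            # final output projection

# Example with random weights
key = random.PRNGKey(0)
d_model, d_k, seq_len = 512, 64, 10
ks = random.split(key, 5)
x = random.normal(ks[0], (seq_len, d_model))
W_q, W_k, W_v = (random.normal(k, (d_model, d_k)) for k in ks[1:4])
W_o = random.normal(ks[4], (d_k, d_model))
print(self_attention(x, W_q, W_k, W_v, W_o).shape)  # (10, 512)
```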
- But we can see here that the self-attention computation is permutation invariant: if we shuffle the order of the tokens, the attention values stay the same. However, token order matters in natural language, so the authors proposed positional encodings, which are added to each token embedding before it passes through the transformer layers (note that this is done only once, at the input).
- Assuming that the embedding for each token has a dimension of 512, the positional embeddings are drawn above.
- So for each token, we iterate through the embedding vectors, and for every…
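To make the construction concrete, here is a minimal JAX sketch of the sinusoidal positional encoding from the original transformer paper, where even embedding indices use a sine and odd indices use a cosine of the scaled position (function and variable names are illustrative):

```python
import jax.numpy as jnp

def sinusoidal_positional_encoding(seq_len, d_model):
    """Return a (seq_len, d_model) matrix of sinusoidal positional encodings."""
    positions = jnp.arange(seq_len)[:, None]                  # (seq_len, 1)
    dims = jnp.arange(0, d_model, 2)[None, :]                 # (1, d_model/2)
    angles = positions / jnp.power(10000.0, dims / d_model)   # (seq_len, d_model/2)
    pe = jnp.zeros((seq_len, d_model))
    pe = pe.at[:, 0::2].set(jnp.sin(angles))   # even dimensions: sine
    pe = pe.at[:, 1::2].set(jnp.cos(angles))   # odd dimensions: cosine
    return pe

# The encodings are simply added to the token embeddings once, at the input:
# embeddings = token_embeddings + sinusoidal_positional_encoding(seq_len, d_model)
print(sinusoidal_positional_encoding(10, 512).shape)  # (10, 512)
```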