A blog post on novel perspectives for understanding random feature attention
This blog post introduces a new perspective for understanding the Random Feature Attention (RFA) mechanism. We show that 1) conventional softmax attention can be equivalently rewritten as an expectation over RFAs, and 2) RFA is in fact a self-normalized importance sampler for estimating conventional softmax attention. This new perspective grounds the heuristic RFA approximation and sheds light on how to further generalize and improve RFA. More details can be found in our ICML paper.
The attention mechanism
For each query $\mbq_n$, conventional softmax attention computes the following quantity:
$$
\begin{equation}
\mathsf{SoftmaxAttn}(\mbq_n, \mbK,\mbV) = \sum_{m=1}^M \frac{\exp \left(\mbq_{n}^\top \mbk_{m} \right)}{\sum_{m'=1}^M \exp \left(\mbq_{n}^\top \mbk_{m'} \right)} \mbv_{m}^{\top},
\end{equation}
$$
where $\mbK$ and $\mbV$ denote the key and value matrices, with $\mbk_m$ and $\mbv_m$ being their $m$-th rows. Computing this for all queries costs time quadratic in the sequence length.
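For concreteness, here is a minimal NumPy sketch of this computation for a single query (a toy illustration of ours; the function and variable names are not from the paper):

```python
import numpy as np

def softmax_attn(q_n, K, V):
    """Exact softmax attention for a single query.

    q_n: (d,) query vector; K: (M, d) key matrix; V: (M, d_v) value matrix.
    Returns the (d_v,) attention output, i.e. a convex combination of the values.
    """
    scores = K @ q_n                         # q_n^T k_m for every key, shape (M,)
    weights = np.exp(scores - scores.max())  # subtract max for numerical stability
    weights /= weights.sum()                 # softmax over the M keys
    return weights @ V                       # sum_m softmax_m * v_m
```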
To reduce the computational complexity of softmax attention, researchers proposed to use random features (RF) to approximate the exponential kernel. The key building block is the identity
$$
\begin{equation}
\exp(\mbx^{\top} \mby) = \mathbb{E}_{\mathcal{N}(\omega;0,\mathbf{I})}\left[\xi(\mbx,\omega)^{\top}\xi(\mby, \omega)\right], \label{eqn:identity}
\end{equation}
$$
where $\xi(\mbx,\omega) \in \R^{l}$ is a randomized mapping of the input $\mbx$.
To estimate the expectation in \eqref{eqn:identity}, one can draw multiple Monte Carlo samples $\omega_1, \dots, \omega_S \sim \mathcal{N}(\omega;0,\mathbf{I})$ so that $\exp(\mbx^{\top} \mby) \approx \frac{1}{S}\sum_{s=1}^S \xi(\mbx,\omega_s)^{\top}\xi(\mby, \omega_s)$. By substituting this approximation into softmax attention, we obtain random feature attention (RFA):
$$
\begin{equation}
\mathsf{RFA}\left(\mbq_{n},\mbK,\mbV\right) = \frac{\sum_{s=1}^S\xi(\mbq_n,\omega_s)^{\top}\sum_{m=1}^M\xi(\mbk_m, \omega_s)\mbv_{m}^{\top}}{\sum_{s=1}^S \xi(\mbq_n,\omega_s)^{\top}\sum_{m'=1}^M\xi(\mbk_{m'}, \omega_s)} \approx \mathsf{SoftmaxAttn}(\mbq_n, \mbK,\mbV).
\end{equation}
$$
Why is it called random feature attention?
This is because the sample average can be written as $\mbphi(\mbx,\mbw)^\top \mbphi(\mby,\mbw)$, where $\mbphi(\mbx,\mbw) \coloneqq 1/\sqrt{S}\,[\xi(\mbx,\omega_1), \dots, \xi(\mbx, \omega_S)]^\top \in \R^{lS}$. The map $\mbphi(\cdot,\cdot)$ can be viewed as a feature map that transforms the input vector into a new vector representation; as a result, $\mbphi(\cdot,\cdot)$ is conveniently referred to as a random feature map.
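To make this concrete, below is a minimal NumPy sketch (ours) of RFA for a single query. It assumes the Performer-style positive random feature $\xi(\mbx,\omega) = \exp(\omega^\top\mbx - \lVert\mbx\rVert^2/2)$ with $l = 1$, which is one common choice satisfying the identity above; the paper's implementation may differ.

```python
import numpy as np

def xi(x, omegas):
    """Positive random features (one common choice, l = 1):
    xi(x, omega) = exp(omega^T x - ||x||^2 / 2), so that
    E_{omega ~ N(0, I)}[xi(x, omega) * xi(y, omega)] = exp(x^T y).
    x: (d,), omegas: (S, d) -> (S,)."""
    return np.exp(omegas @ x - 0.5 * np.dot(x, x))

def rfa(q_n, K, V, S=64, seed=0):
    """Random feature attention for a single query (a sketch, not the paper's code)."""
    rng = np.random.default_rng(seed)
    omegas = rng.standard_normal((S, q_n.shape[0]))        # omega_1, ..., omega_S ~ N(0, I)
    q_feat = xi(q_n, omegas)                               # (S,)
    k_feat = np.stack([xi(k, omegas) for k in K], axis=1)  # (S, M)
    numer = (q_feat[:, None] * k_feat).sum(0) @ V          # sum_s xi(q) * sum_m xi(k_m) v_m
    denom = (q_feat[:, None] * k_feat).sum()               # sum_s xi(q) * sum_m xi(k_m)
    return numer / denom
```

In practice, the key statistics $\sum_{m}\xi(\mbk_m,\omega_s)\mbv_m^\top$ and $\sum_m \xi(\mbk_m,\omega_s)$ are computed once and reused for every query, which is what gives RFA its linear complexity in the sequence length.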
RFA suffers from significant performance degradation across various tasks, as observed in many previous studies (e.g., see here). Although computationally efficient, RFA has reduced modeling capacity, which greatly limits its practical usage. Researchers have tried to improve its performance from different perspectives, such as
In this work, we explore an orthogonal axis by re-examining the estimation bias of RFA. Our key observation is that RFA is a heuristic approximation to the whole softmax attention: although each individual exponential kernel is estimated without bias, the resulting estimate of the ratio of exponential kernels is no longer unbiased. This is due to the non-linearity of ratios: in general, $\mathbb{E}\left[A/B\right] \neq \mathbb{E}\left[A\right]/\mathbb{E}\left[B\right]$, even when $A$ and $B$ are unbiased estimates of the numerator and the denominator.
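As a quick toy illustration of this (ours, not from the paper), consider two independent unbiased estimates whose ratio is nevertheless biased:

```python
import numpy as np

rng = np.random.default_rng(0)
# A_hat and B_hat are both unbiased estimates of 1.0 ...
A_hat = rng.gamma(shape=3.0, scale=1.0 / 3.0, size=1_000_000)
B_hat = rng.gamma(shape=3.0, scale=1.0 / 3.0, size=1_000_000)
print(A_hat.mean(), B_hat.mean())   # both close to 1.0
# ... yet their ratio is a biased estimate of 1.0 / 1.0 = 1.0:
print((A_hat / B_hat).mean())       # close to 1.5, systematically too large
```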
Such bias not only incurs a potentially large approximation gap between RFA and softmax attention but also bottlenecks the effectiveness of unbiased kernel estimation.
Our work aims to address the following research question:
Given that we already know how to unbiasedly estimate exponential kernels, how do we construct an unbiased estimator for the whole softmax attention?
We answer this question in the affirmative and prove that softmax attention can be rewritten as an expectation of simple RFAs, $$ \begin{equation} \mathsf{SoftmaxAttn}(\mbq_n, \mbK,\mbV) = \sum_{m} \frac{\exp \left(\mathbf{q}_{n}^\top \mathbf{k}_{m} \right)}{\sum_{m'} \exp \left(\mathbf{q}_{n}^\top \mathbf{k}_{m'} \right)} \mathbf{v}_{m}^{\top} = \mathbb{E}_{p_n(\omega)}\left[f_n(\omega)\right],\label{eqn:softmax_as_expectation} \end{equation} $$ where $$ \begin{align} p_n(\omega) &= \sum_{m=1}^M \frac{\exp\left( \mbq_n^\top\mbk_m \right)}{\sum_{m'=1}^M\exp\left( \mbq_n^\top\mbk_{m'} \right)} \mathcal{N}(\omega; \mbq_n + \mbk_m, \mathbf{I}), \label{eqn:ra-density} \\ f_n(\omega) &= \frac{\xi(\mbq_n,\omega)^\top \sum_{m=1}^M \xi(\mbk_m, \omega) \mbv_{m}^{\top}}{\xi(\mbq_n,\omega)^\top \sum_{m'=1}^M \xi(\mbk_{m'}, \omega)} \label{eqn:ra-function}. \end{align} $$ Intuitively, $p_n(\omega)$ is a Gaussian mixture whose components are centered at $\mbq_n + \mbk_m$ and weighted by the softmax attention probabilities, while $f_n(\omega)$ has exactly the form of an RFA estimate computed with a single sample $\omega$; averaging such single-sample RFAs over $p_n(\omega)$ recovers softmax attention exactly.
Notably, our result can be viewed as a neat generalization of random feature approximation, which exhibits a high degree of symmetry: $$ \begin{align} \exp(\mbq_n^\top \mbk_m)\mbv_m^\top &= \mathbb{E}_{q(\omega)}\left[\xi(\mbq_n,\omega)^\top\xi(\mbk_m, \omega)\mbv_m^\top\right] \notag \\ \sum_{m} \frac{\exp \left(\mathbf{q}_{n}^\top \mathbf{k}_{m} \right)}{\sum_{m'} \exp \left(\mathbf{q}_{n}^\top \mathbf{k}_{m'} \right)} \mathbf{v}_{m}^{\top} &= \mathbb{E}_{p_n(\omega)}\left[\frac{\xi(\mbq_n,\omega)^\top \sum_{m=1}^M \xi(\mbk_m, \omega) \mbv_{m}^{\top}}{\xi(\mbq_n,\omega)^\top \sum_{m'=1}^M \xi(\mbk_{m'}, \omega)}\right].\notag \end{align} $$ Another implication is that we can construct a Monte Carlo estimator to approximate softmax attention in an unbiased way. By drawing $\omega_n \sim p_n(\omega)$, we obtain $$ \begin{align} \mathsf{SoftmaxAttn}(\mbq_n, \mbK,\mbV) &\approx \frac{\xi(\mbq_n,\omega_n)^\top \sum_{m=1}^M\xi(\mbk_m, \omega_n) \mbv_{m}^{\top}}{ \xi(\mbq_n,\omega_n)^\top \sum_{m'=1}^M\xi(\mbk_{m'}, \omega_n)} \label{eqn:ra}\\ &\coloneqq \mathsf{RA}\left(\mbq_{n},\mbK,\mbV\right). \notag \end{align} $$ We name this estimator Randomized Attention (RA), since it computes similarity scores with individual randomized mappings (instead of concatenated random features). To the best of our knowledge, this is the first result that generalizes unbiased kernel estimation to unbiased estimation of the whole softmax attention.
Remark: The proof of \eqref{eqn:softmax_as_expectation} proceeds by first reverse-engineering the formulation of RFA and equating it with self-normalized importance sampling (see below), and then completing the square of Gaussians to derive the density $p_n(\omega)$; the function $f_n(\omega)$ then follows by substituting $p_n(\omega)$ back into the expectation. See the paper for a detailed proof.
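To illustrate the estimator, here is a minimal NumPy sketch (ours) of RA for a single query, again assuming the Performer-style positive feature $\xi(\mbx,\omega) = \exp(\omega^\top\mbx - \lVert\mbx\rVert^2/2)$ with $l = 1$. Note that sampling from $p_n(\omega)$ requires the exact attention weights, which is why RA itself is not a linear-time method:

```python
import numpy as np

def randomized_attn(q_n, K, V, seed=0):
    """Randomized Attention (RA) for a single query: an unbiased one-sample
    estimate of softmax attention (a sketch under the assumptions above)."""
    rng = np.random.default_rng(seed)
    xi = lambda x, w: np.exp(w @ x - 0.5 * x @ x)     # positive random feature, l = 1
    # 1) Sample omega_n from the Gaussian mixture p_n(omega): pick component m
    #    with probability softmax(q_n^T k_m), then draw omega_n ~ N(q_n + k_m, I).
    scores = K @ q_n
    probs = np.exp(scores - scores.max()); probs /= probs.sum()
    m = rng.choice(len(K), p=probs)
    omega_n = q_n + K[m] + rng.standard_normal(q_n.shape)
    # 2) Evaluate f_n(omega_n), a single-sample RFA-style ratio.
    q_feat = xi(q_n, omega_n)                          # scalar
    k_feat = np.array([xi(k, omega_n) for k in K])     # (M,)
    return (q_feat * k_feat) @ V / (q_feat * k_feat.sum())
```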
The analysis above further reveals that RFA is a specific self-normalized importance sampling estimator of softmax attention.
Importance sampling (IS) is a general approach to approximating an expectation $\mathbb{E}_{p_n(\omega)}\left[f_n(\omega)\right]$ when it is difficult to draw samples from $p_n(\omega)$ directly. In importance sampling, we use a proposal distribution $q(\omega)$ to draw samples and estimate the quantity as $$ \mathbb{E}_{p_n(\omega)}\left[f_n(\omega)\right] = \mathbb{E}_{\omega \sim q(\omega)}\left[\frac{p_n(\omega)}{q(\omega)}f_n(\omega)\right] \approx \frac{1}{S} \sum_{s=1}^S \frac{p_n(\omega_s)}{q(\omega_s)} f_n(\omega_s), $$ where $\omega_1, \dots, \omega_S \sim q(\omega)$. The self-normalized importance sampling (SNIS) estimator is defined as $$ \begin{equation} \mathbb{E}_{p_n(\omega)}\left[f_n(\omega)\right] \approx \frac{\sum_{s=1}^S\frac{p_n(\omega_s)}{q(\omega_s)}f_n(\omega_s)}{\sum_{s=1}^S\frac{p_n(\omega_s)}{q(\omega_s)}} = \sum_{s=1}^S\frac{\frac{p_n(\omega_s)}{q(\omega_s)}}{\sum_{s'=1}^S\frac{p_n(\omega_{s'})}{q(\omega_{s'})}}f_n(\omega_s). \label{eqn:snis} \end{equation} $$ The name self-normalized comes from the fact that the importance weights $p_n(\omega_s)/q(\omega_s)$ are explicitly normalized so that they sum to 1.
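For reference, a generic SNIS routine might look like the following sketch (our notation; the target log-density, proposal sampler, and proposal log-density are passed in as callables):

```python
import numpy as np

def snis(f, log_p, sample_q, log_q, S, seed=0):
    """Self-normalized importance sampling: approximate E_{p(w)}[f(w)]
    using S samples drawn from a proposal q(w) (a generic sketch)."""
    rng = np.random.default_rng(seed)
    samples = sample_q(S, rng)                          # w_1, ..., w_S ~ q
    log_ratios = np.array([log_p(w) - log_q(w) for w in samples])
    weights = np.exp(log_ratios - log_ratios.max())     # unnormalized p(w_s)/q(w_s)
    weights /= weights.sum()                            # self-normalization: sums to 1
    return weights @ np.stack([f(w) for w in samples])  # weighted average of f(w_s)
```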
Our key finding here is that the formulation of RFA can be exactly derived from SNIS. Suppose $p_n(\omega)$ and $f_n(\omega)$ are given by \eqref{eqn:ra-density} and \eqref{eqn:ra-function} respectively, and let the proposal be $q(\omega) = \mathcal{N}(\omega;0,\mathbf{I})$. Then $$ \begin{align} \mathsf{RFA}\left(\mbq_{n},\mbK,\mbV\right) &= \frac{\sum_{s=1}^S\textcolor{strings}{\xi(\mbq_n,\omega_s)^{\top}\sum_{m=1}^M\xi(\mbk_m, \omega_s)\mbv_{m}^{\top}}}{\sum_{s=1}^S \textcolor{keywords}{\xi(\mbq_n,\omega_s)^{\top}\sum_{m'=1}^M\xi(\mbk_{m'}, \omega_s)}} = \frac{ \sum_{s=1}^S\textcolor{strings}{\frac{p_n(\omega_s)}{q(\omega_s)} f_n(\omega_s)}}{ \sum_{s=1}^S\textcolor{keywords}{\frac{p_n(\omega_s)}{q(\omega_s)}}}. \label{eqn:rfa-as-snis} \end{align} $$ This formulation provides a new understanding of RFA: it is simply a specific instantiation of the SNIS estimator for softmax attention, whose proposal distribution $q(\omega)$ is chosen to be the standard Gaussian. This reveals one of the possible reasons why RFA does not work well in practice: the plain standard Gaussian proposal used by RFA can be far away from the true Gaussian mixture $p_n(\omega)$ (the distribution RA samples from), which may lead to a large approximation gap. More importantly, this view implies that we can generalize and extend RFA by using other proposal distributions or adopting other estimation schemes!
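To sanity-check the equivalence in \eqref{eqn:rfa-as-snis} numerically, the following toy script (ours, again assuming the Performer-style positive feature with $l = 1$) evaluates both sides on the same random-feature samples and confirms that they agree:

```python
import numpy as np

rng = np.random.default_rng(0)
d, M, S = 4, 6, 16
q_n, K, V = rng.standard_normal(d), rng.standard_normal((M, d)), rng.standard_normal((M, 3))

xi = lambda x, w: np.exp(w @ x - 0.5 * x @ x)          # positive random feature, l = 1
log_N = lambda w, mu: -0.5 * np.sum((w - mu) ** 2) - 0.5 * d * np.log(2 * np.pi)

omegas = rng.standard_normal((S, d))                   # omega_1..S ~ q(w) = N(0, I)
q_feat = np.array([xi(q_n, w) for w in omegas])                 # (S,)
k_feat = np.array([[xi(k, w) for k in K] for w in omegas])      # (S, M)

# (a) RFA as written in terms of random features
rfa = (q_feat[:, None] * k_feat).sum(0) @ V / (q_feat[:, None] * k_feat).sum()

# (b) the same samples pushed through the SNIS formula with weights p_n / q
attn = np.exp(K @ q_n); attn /= attn.sum()                      # softmax probabilities
log_p = lambda w: np.log(sum(a * np.exp(log_N(w, q_n + k)) for a, k in zip(attn, K)))
ratios = np.array([np.exp(log_p(w) - log_N(w, np.zeros(d))) for w in omegas])
f = np.stack([q_feat[s] * (k_feat[s] @ V) / (q_feat[s] * k_feat[s].sum()) for s in range(S)])
snis = (ratios[:, None] * f).sum(0) / ratios.sum()

print(np.allclose(rfa, snis))                                   # True
```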
So far, we have two types of estimators for approximating softmax attention: the unbiased RA and the biased RFA. Beyond this difference in bias, how do they differ in terms of practical modeling behavior? We list a comparison below to illustrate their main differences.
Motivated by this comparison, we propose LineAr-time Randomized Attention (LARA), which attempts to get the best of both worlds by combining the efficiency of RFA with the expressiveness of RA.
LARA takes the following form: $$ \begin{align} \mathsf{LARA}\left(\mbq_{n},\mbK,\mbV\right) &= \frac{\sum_{c=1}^C\alpha_{nc}(\omega_c)\frac{p_n(\omega_c)}{q_c(\omega_c)} f_n(\omega_c)}{\sum_{c=1}^C\alpha_{nc}(\omega_c)\frac{p_n(\omega_c)}{q_c(\omega_c)}}.\label{eqn:lara} \end{align} $$ Here, $\omega_c \sim q_c(\omega)$ for each of a small number $C$ of proposal distributions $\{q_c(\omega)\}_{c=1}^C$ shared across all queries, and $\alpha_{nc}(\cdot)$ are query-specific weighting functions that combine the $C$ proposals for the $n$-th query.
Remark: We provide a detailed discussion about the parameterization of our proposal distributions in Appendix G.3 of our paper. To summarize, we find the key is to let different proposals depend on different sets of query information so that they are as query-specific as possible. A good default is to divide the whole sequence into $C$ chunks, compute the chunk-wise means of queries $\{\widetilde{\mbq}_c\}_{c=1}^C$ and keys $\{\widetilde{\mbk}_c\}_{c=1}^C$, and set $q_c(\omega) = \mcN(\omega;\widetilde{\mbq}_c + \widetilde{\mbk}_c, \mathbf{I})$. We find this choice works well across various benchmarks.
LARA can be equivalently written in a similar way to RA and RFA. We spell it out here to see a systematic comparison among RA, LARA, and RFA: $$ \begin{align} \mathsf{RA}\left(\mbq_{n},\mbK,\mbV\right) &= \frac{\xi(\mbq_n,\omega_n)^\top \sum_{m=1}^M\xi(\mbk_m, \omega_n) \mbv_{m}^{\top}}{ \xi(\mbq_n,\omega_n)^\top \sum_{m'=1}^M\xi(\mbk_{m'}, \omega_n)}, &&\textcolor{keywords}{\omega_n \sim p_n(\omega)}\notag\\ \mathsf{LARA}\left(\mbq_{n},\mbK,\mbV\right) &= \frac{\sum_{c=1}^C \textcolor{strings}{\alpha'_{nc}(\omega_c)} \xi(\mbq_n,\omega_c)^\top \sum_{m=1}^M\xi(\mbk_m, \omega_c) \mbv_{m}^{\top}}{\sum_{c=1}^C \textcolor{strings}{\alpha'_{nc}(\omega_c)} \xi(\mbq_n,\omega_c)^\top \sum_{m=1}^M \xi(\mbk_{m}, \omega_c)}, &&\textcolor{keywords}{\omega_c \sim q_c(\omega)}\notag\\ \mathsf{RFA}\left(\mbq_{n},\mbK,\mbV\right) &= \frac{ \sum_{s=1}^S\xi(\mbq_n,\omega_s)^{\top}\sum_{m=1}^M\xi(\mbk_m, \omega_s)\mbv_{m}^{\top}}{\sum_{s=1}^S \xi(\mbq_n,\omega_s)^{\top}\sum_{m'=1}^M\xi(\mbk_{m'}, \omega_s)}, &&\textcolor{keywords}{\omega_1,\dots,\omega_S \sim \mcN(\omega;0, \mathbf{I})} \notag \end{align} $$ where we denote $\alpha'_{nc}(\omega_c) \coloneqq \alpha_{nc}(\omega_c)\mcN(\omega_c;0, \mathbf{I})/q_c(\omega_c)$ to simplify the notation. Note that their major difference lies in the choice of sampling distributions.
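The sketch below (ours) implements this second, linear-time form in NumPy. It follows the chunk-mean proposals from the remark above, but makes two simplifying assumptions that are not from the paper: the Performer-style positive feature with $l = 1$, and uniform weights $\alpha_{nc} = 1/C$ (the paper parameterizes $\alpha_{nc}$ more carefully); with uniform weights the constant $1/C$ cancels after self-normalization.

```python
import numpy as np

def lara(Q, K, V, C=8, seed=0):
    """LineAr-time Randomized Attention (LARA): a simplified sketch.

    Q: (N, d) queries, K: (M, d) keys, V: (M, d_v) values; assumes N, M >= C.
    Assumptions (ours): xi(x, w) = exp(w^T x - ||x||^2 / 2), chunk-mean proposals
    q_c(w) = N(w; q_tilde_c + k_tilde_c, I), and uniform alpha_{nc}.
    """
    rng = np.random.default_rng(seed)
    xi = lambda X, w: np.exp(X @ w - 0.5 * (X * X).sum(-1))      # featurize all rows of X
    # Chunk-wise means of queries and keys define the C proposal means.
    mus = np.stack([qc.mean(0) + kc.mean(0)
                    for qc, kc in zip(np.array_split(Q, C), np.array_split(K, C))])
    omegas = mus + rng.standard_normal(mus.shape)                # omega_c ~ N(mu_c, I)
    numer = np.zeros((Q.shape[0], V.shape[1])); denom = np.zeros(Q.shape[0])
    for mu_c, w_c in zip(mus, omegas):
        # alpha'_{nc} is proportional to N(w_c; 0, I) / q_c(w_c) under uniform alpha.
        alpha_prime = np.exp(0.5 * mu_c @ mu_c - mu_c @ w_c)
        q_feat, k_feat = xi(Q, w_c), xi(K, w_c)                  # (N,), (M,)
        numer += alpha_prime * np.outer(q_feat, k_feat @ V)      # O(N + M) per proposal
        denom += alpha_prime * q_feat * k_feat.sum()
    return numer / denom[:, None]
```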
LARA is not designed to be a simple interpolation between RA and RFA; instead, it is a generalized estimation framework that includes both RA and RFA as special cases. To see this, note that if every proposal is the standard Gaussian, $q_c(\omega) = \mcN(\omega;0, \mathbf{I})$, and the weights $\alpha_{nc}$ are uniform, then $\alpha'_{nc}(\omega_c) \propto 1$ and LARA reduces exactly to RFA with $C$ samples; on the other hand, if each query $n$ is assigned its own proposal $q_n(\omega) = p_n(\omega)$ carrying all of its weight, the importance ratio cancels and LARA reduces to RA.
To demonstrate the effectiveness of our approach, we first visualize the approximation error of different estimators with respect to the true softmax attention outputs. Our unbiased estimator, RA, achieves the lowest MSE among the three methods and gets very close to the true softmax attention; RFA (Performer) quickly plateaus at a large approximation error and does not improve even with more samples, possibly due to its low sample efficiency. LARA, on the other hand, exhibits much lower MSE than RFA, and its approximation error continues to decrease as the number of proposals increases.
We further verify the improved performance of LARA by conducting experiments across a wide range of data modalities, including images, videos, natural language text, and a long-sequence benchmark. From the table below, we observe that LARA substantially outperforms RFA on every task while keeping linear complexity and largely closes the gap to full softmax attention, while RA performs on par with softmax attention, consistent with its unbiasedness.
| Model | Complexity | ImageNet image classification (top-1 acc. %, 🠅) | SSv2 video recognition (top-1 acc. %, 🠅) | WMT machine translation (BLEU, 🠅) | Long Range Arena suite (avg. score, 🠅) |
|---|---|---|---|---|---|
| Softmax | $\mathcal{O}(N^2)$ | 79.9 | 66.5 | 27.5 | 59.08 |
| RA | $\mathcal{O}(N^2)$ | 80.0 | 64.9 | 27.8 | 59.30 |
| RFA | $\mathcal{O}(N)$ | 74.3 | 53.1 | 23.7 | 57.63 |
| LARA | $\mathcal{O}(N)$ | 79.5 | 63.7 | 27.0 | 59.12 |
LARA also enjoys much better scalability and can achieve state-of-the-art results on image classification when applied to advanced transformer architectures.
Finally, we evaluate the empirical efficiency of different attention methods. RA runs almost twice as slow as softmax attention, while its linear-time variant LARA runs much faster and incurs only marginal computational overhead compared to RFA.
In this work, we generalize the random feature attention (RFA) mechanism in two aspects: 1) we show that softmax attention can be rewritten exactly as an expectation of simple RFAs, which yields randomized attention (RA), an unbiased estimator of the whole softmax attention; and 2) we reinterpret RFA as a self-normalized importance sampling estimator and generalize it to LARA, which combines the expressiveness of RA with the linear complexity of RFA.
LARA greatly improves the performance of RFA while maintaining its efficiency. Our framework provides a novel perspective for understanding and improving RF-based attention approximations, and it is orthogonal to most previous work. At the same time, there are several limitations of our approach: