Motivation

In transformer attention, the logits are $q^\top k$, where $q, k \in \mathbb{R}^d$ are query and key vectors. So if μP says to scale attention by $1/d$ rather than the usual $1/\sqrt{d}$, the natural question is: why?

If $q$ and $k$ were independent, the usual $1/\sqrt{d}$ scaling would be exactly the right guess. Each term $q_i k_i$ is mean-zero with bounded variance, so the sum $q^\top k = \sum_i q_i k_i$ lives in the familiar Central Limit Theorem world and fluctuates at scale $\sqrt{d}$.

But during training, $q$ and $k$ pick up correlations. Then the summands $q_i k_i$ can reinforce one another, so the dot product starts looking less like a CLT object and more like a Law of Large Numbers object. In the large-width regime that μP cares about, this matters because we want the size of $q^\top k$ to stay $O(1)$ as $d$ grows.

So the real issue is whether $q^\top k$ behaves like a $\sqrt{d}$ object or a $d$ object. This note gives a simple Gaussian calculation that makes that distinction explicit. The object we will compute is the $L^2$ norm

$$\|q^\top k\|_2 := \left(\mathbb{E}[(q^\top k)^2]\right)^{1/2}$$

for jointly Gaussian $q$ and $k$. The answer splits neatly into a baseline term, which is present even when $q$ and $k$ are independent, and two alignment terms controlled by the cross-covariance $C = \mathbb{E}[qk^\top]$. That decomposition shows exactly where the $\sqrt{d} \to d$ transition comes from.


Setup

Of course, real query and key vectors in a trained network are not literally Gaussian. But in the spirit of μP, it is useful to start with a simplified probabilistic model and see what it already explains.

Assumption 1. Let $q, k \in \mathbb{R}^d$ be jointly Gaussian, centered:

$$ \begin{pmatrix} q \\ k \end{pmatrix} \sim \mathcal{N}\left( 0, \begin{pmatrix} \Sigma_q & C \\ C^\top & \Sigma_k \end{pmatrix} \right) $$

where $\Sigma_q = \mathbb{E}[q q^\top]$, $\Sigma_k = \mathbb{E}[k k^\top]$, and $C = \mathbb{E}[q k^\top]$ is the cross-covariance. Define the scalar $S := q^\top k = \sum_{i=1}^d q_i k_i$.

Before specializing to $S = q^\top k$, let me fix the standard probability notation: if $X$ is a random variable on a probability space and $\mathbb{E}[|X|^p] < \infty$, then its $L^p$ norm is $$ \|X\|_p := \left(\mathbb{E}[|X|^p]\right)^{1/p}. $$ See Vershynin, High-Dimensional Probability, Chapter 1.

Definition 1. The $L^2$ norm of $S = q^\top k$ is:

$$\|S\|_2 := \left(\mathbb{E}[S^2]\right)^{1/2} = \left(\mathbb{E}[(q^\top k)^2]\right)^{1/2}$$

This is the root-mean-square magnitude of the dot product.


Main Result

Theorem 1. Under Assumption 1,

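$$\boxed{ \|q^\top k\|_2 = \left( \operatorname{tr}(\Sigma_q \Sigma_k) + \operatorname{tr}(C^2) + (\operatorname{tr} C)^2 \right)^{1/2} }$$

Here $\operatorname{tr}(C^2) = \sum_{i,j} C_{ij} C_{ji}$. When $C$ is symmetric, as in all the examples in this note, this equals $\|C\|_F^2 = \operatorname{tr}(C^\top C)$; for a general cross-covariance the two differ, and $\operatorname{tr}(C^2)$ is the term that comes out of the calculation.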
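As a sanity check, the closed form can be compared against a Monte Carlo estimate. The sketch below is illustrative only: the helper names `theorem_rhs` and `monte_carlo_second_moment`, the dimension, the seed, and the sample count are all assumptions made for this example. The cross term is computed as $\operatorname{tr}(C^2) = \sum_{i,j} C_{ij} C_{ji}$, which coincides with $\|C\|_F^2$ when $C$ is symmetric; the random covariance used here has a non-symmetric $C$, so the check is sensitive to that distinction.

```python
import numpy as np

def theorem_rhs(Sigma_q, Sigma_k, C):
    """Closed form for E[(q^T k)^2]: tr(Sigma_q Sigma_k) + tr(C^2) + (tr C)^2."""
    return np.trace(Sigma_q @ Sigma_k) + np.trace(C @ C) + np.trace(C) ** 2

def monte_carlo_second_moment(Sigma, d, n_samples, rng):
    """Estimate E[(q^T k)^2] by sampling w = (q, k) ~ N(0, Sigma)."""
    w = rng.multivariate_normal(np.zeros(2 * d), Sigma, size=n_samples)
    q, k = w[:, :d], w[:, d:]
    s = np.einsum("ni,ni->n", q, k)  # one dot product per sample
    return np.mean(s ** 2)

rng = np.random.default_rng(0)  # arbitrary seed
d = 4

# A generic positive semidefinite joint covariance; its off-diagonal
# block C is not symmetric in general.
G = rng.standard_normal((2 * d, 2 * d))
Sigma = G @ G.T / (2 * d)
Sigma_q, C, Sigma_k = Sigma[:d, :d], Sigma[:d, d:], Sigma[d:, d:]

exact = theorem_rhs(Sigma_q, Sigma_k, C)
estimate = monte_carlo_second_moment(Sigma, d, 400_000, rng)
print(f"closed form: {exact:.4f}   Monte Carlo: {estimate:.4f}")
```

The two numbers should agree up to sampling noise, which shrinks like one over the square root of the sample count.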


Interpretation

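The formula is easy to read in three special cases. Take unit-variance coordinates, $\Sigma_q = \Sigma_k = I$, so that the baseline term is $\operatorname{tr}(\Sigma_q \Sigma_k) = d$:

  • Independent ($C = 0$): $\|q^\top k\|_2 = \sqrt{d}$, which is CLT scaling.
  • Isotropic correlation ($C = \rho I$): $\|q^\top k\|_2 = \sqrt{d + \rho^2 d + \rho^2 d^2}$, so for any fixed $\rho \neq 0$ the $\rho^2 d^2$ term dominates as $d$ grows.
  • Perfect alignment ($k = q$, so $C = I$): $\|q^\top k\|_2 = \sqrt{d^2 + 2d} \approx d$, which is LLN scaling.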
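To get a feel for the numbers, here is a small sketch that evaluates the three cases above at a few widths. It assumes $\Sigma_q = \Sigma_k = I$ and $C = \rho I$ throughout; the function name `rms_isotropic` and the particular values of $d$ and $\rho$ are choices made for illustration.

```python
import math

def rms_isotropic(d: int, rho: float) -> float:
    """RMS of q^T k when Sigma_q = Sigma_k = I and C = rho * I.

    The three terms are tr(Sigma_q Sigma_k) = d, the cross term rho^2 * d,
    and (tr C)^2 = rho^2 * d^2.
    """
    return math.sqrt(d + rho**2 * d + rho**2 * d**2)

for d in (64, 1024, 16384):
    independent = rms_isotropic(d, 0.0)  # C = 0
    weak = rms_isotropic(d, 0.1)         # small but nonzero correlation
    aligned = rms_isotropic(d, 1.0)      # k = q
    print(
        f"d={d:6d}  "
        f"rho=0: {independent / math.sqrt(d):.3f} * sqrt(d)  "
        f"rho=0.1: {weak / d:.3f} * d  "
        f"rho=1: {aligned / d:.3f} * d"
    )
```

Dividing by $\sqrt{d}$ in the independent case and by $d$ in the correlated cases makes the scaling visible: the printed ratios settle toward constants as $d$ grows.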

So the whole story is really carried by $C$. If $C$ is negligible, you stay in the $\sqrt{d}$ world. If $C$ has coherent structure, the $d$-scale behavior starts to take over.


Proof of Theorem 1

This clean version of the proof, based on Stein's lemma, was found with the help of ChatGPT. My own proof computed $\mathbb{E}[(q^\top k)^2]$ by expanding everything into fourth-order Gaussian moments, at which point I realized I needed Isserlis's theorem to continue. That works, but in coordinates it quickly turns into quite a mess.

After a few rounds of discussion, ChatGPT suggested using Stein's lemma, which turned out to be a surprisingly elegant alternative. We package $q^\top k$ as a quadratic form $w^\top A w$ in the joint vector $w = (q, k)$, and then use Stein's lemma to get the needed quadratic-form identity almost for free.

Lemma 1 (Stein's lemma). For $z \sim \mathcal{N}(0, I)$ and smooth $h : \mathbb{R}^n \to \mathbb{R}$ with $\mathbb{E}[|\nabla h(z)|] < \infty$,

$$\mathbb{E}[z_i \, h(z)] = \mathbb{E}[\partial_i h(z)]$$

Applying this twice gives the matrix form:

$$\mathbb{E}[zz^\top h(z)] = I\,\mathbb{E}[h(z)] + \mathbb{E}[\nabla^2 h(z)]$$
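
Spelled out entrywise, the two applications look like this. The sketch assumes a little more regularity than Lemma 1 states, namely that $h$ is twice differentiable with integrable second derivatives. Apply the lemma first in coordinate $i$ to the function $z_j h(z)$, then in coordinate $j$ to $\partial_i h$:

$$\mathbb{E}[z_i z_j \, h(z)] = \mathbb{E}[\partial_i (z_j h(z))] = \delta_{ij}\,\mathbb{E}[h(z)] + \mathbb{E}[z_j \, \partial_i h(z)] = \delta_{ij}\,\mathbb{E}[h(z)] + \mathbb{E}[\partial_j \partial_i h(z)]$$

Collecting the $(i, j)$ entries into a matrix gives the identity above.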

Lemma 2 (Quadratic form identity). For $w \sim \mathcal{N}(0, \Sigma)$ and symmetric $B$,

$$\mathbb{E}[(w^\top B w)^2] = (\operatorname{tr}(B\Sigma))^2 + 2\operatorname{tr}((B\Sigma)^2)$$

Proof of Lemma 2.

First reduce to the standard Gaussian case. Write $w = \Sigma^{1/2} z$ with $z \sim \mathcal{N}(0, I)$, and set $M = \Sigma^{1/2} B \Sigma^{1/2}$, which is still symmetric. Then $w^\top B w = z^\top M z$, and the trace identities $\operatorname{tr}(M) = \operatorname{tr}(B\Sigma)$ and $\operatorname{tr}(M^2) = \operatorname{tr}((B\Sigma)^2)$ follow from cyclicity. So all we really need is

$$\mathbb{E}[(z^\top M z)^2] = (\operatorname{tr} M)^2 + 2\operatorname{tr}(M^2)$$

Now take $h(z) = z^\top M z$. Since $\mathbb{E}[h(z)] = \operatorname{tr}(M\,\mathbb{E}[zz^\top]) = \operatorname{tr} M$ and $\nabla^2 h(z) = 2M$ (using the symmetry of $M$), the matrix form of Stein's lemma gives

$$\mathbb{E}[zz^\top (z^\top M z)] = (\operatorname{tr} M)\, I + 2M$$

Finally, use the simple identity $(z^\top M z)^2 = \operatorname{tr}(Mzz^\top)\,(z^\top M z)$:

$$\mathbb{E}[(z^\top M z)^2] = \operatorname{tr}\!\big(M \,\mathbb{E}[zz^\top (z^\top M z)]\big) = \operatorname{tr}\!\big(M((\operatorname{tr} M)\,I + 2M)\big) = (\operatorname{tr} M)^2 + 2\operatorname{tr}(M^2)$$

$\square$
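
Both the intermediate Stein identity and Lemma 2 in its standard Gaussian form can be spot-checked numerically. The following is a rough sketch: the test matrix `M`, the dimension, the seed, and the sample count are arbitrary assumptions, and agreement is only expected up to Monte Carlo error.

```python
import numpy as np

rng = np.random.default_rng(1)  # arbitrary seed
n, n_samples = 3, 500_000

# A random symmetric test matrix M (chosen only for illustration).
G = rng.standard_normal((n, n))
M = (G + G.T) / 2

z = rng.standard_normal((n_samples, n))
quad = np.einsum("ni,ij,nj->n", z, M, z)  # z^T M z for each sample

# Stein matrix identity: E[z z^T (z^T M z)] = (tr M) I + 2M
stein_lhs = np.einsum("ni,nj,n->ij", z, z, quad) / n_samples
stein_rhs = np.trace(M) * np.eye(n) + 2 * M

# Lemma 2 with Sigma = I: E[(z^T M z)^2] = (tr M)^2 + 2 tr(M^2)
lemma_lhs = np.mean(quad ** 2)
lemma_rhs = np.trace(M) ** 2 + 2 * np.trace(M @ M)

print("max entrywise gap in Stein identity:", np.max(np.abs(stein_lhs - stein_rhs)))
print("Lemma 2: sampled", lemma_lhs, "vs closed form", lemma_rhs)
```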

Proof of Theorem 1.

At this point the rest is bookkeeping. Write

$$ w = \begin{pmatrix} q \\ k \end{pmatrix} \sim \mathcal{N}(0, \Sigma), \qquad \Sigma = \begin{pmatrix} \Sigma_q & C \\ C^\top & \Sigma_k \end{pmatrix}, $$

and observe that $q^\top k = w^\top A w$ with

$$ A = \frac{1}{2}\begin{pmatrix} 0 & I_d \\ I_d & 0 \end{pmatrix}. $$

Lemma 2 therefore gives

$$ \mathbb{E}[(q^\top k)^2] = (\operatorname{tr}(A\Sigma))^2 + 2\operatorname{tr}((A\Sigma)^2). $$

A quick block multiplication shows that

$$ A\Sigma = \frac{1}{2}\begin{pmatrix} C^\top & \Sigma_k \\ \Sigma_q & C \end{pmatrix}, $$

so $\operatorname{tr}(A\Sigma) = \operatorname{tr} C$. Squaring, the diagonal blocks of $(A\Sigma)^2$ are $\frac{1}{4}\big((C^\top)^2 + \Sigma_k \Sigma_q\big)$ and $\frac{1}{4}\big(\Sigma_q \Sigma_k + C^2\big)$, so $\operatorname{tr}((A\Sigma)^2) = \frac{1}{2}\big(\operatorname{tr}(\Sigma_q \Sigma_k) + \operatorname{tr}(C^2)\big)$. Note that the cross term is $\operatorname{tr}(C^2) = \sum_{i,j} C_{ij} C_{ji}$, which equals $\|C\|_F^2 = \operatorname{tr}(C^\top C)$ exactly when $C$ is symmetric. Plugging these in gives

$$\mathbb{E}[(q^\top k)^2] = (\operatorname{tr} C)^2 + \operatorname{tr}(\Sigma_q \Sigma_k) + \operatorname{tr}(C^2)$$

and taking the square root gives $\|q^\top k\|_2$.

$\square$
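
The block bookkeeping in this proof is easy to get wrong by hand, so here is a small deterministic check. It is a sketch under assumed inputs: the random joint covariance, the dimension, and the seed are arbitrary. The cross term is computed as $\operatorname{tr}(C^2) = \sum_{i,j} C_{ij} C_{ji}$ and compared against $\|C\|_F^2$, which agrees with it only when $C$ is symmetric.

```python
import numpy as np

rng = np.random.default_rng(2)  # arbitrary seed
d = 5

# A generic positive semidefinite joint covariance with non-symmetric C.
G = rng.standard_normal((2 * d, 2 * d))
Sigma = G @ G.T
Sigma_q, C, Sigma_k = Sigma[:d, :d], Sigma[:d, d:], Sigma[d:, d:]

I, Z = np.eye(d), np.zeros((d, d))
A = 0.5 * np.block([[Z, I], [I, Z]])

# q^T k is the quadratic form w^T A w in the stacked vector w = (q, k).
w = rng.standard_normal(2 * d)
q, k = w[:d], w[d:]
quad_form, dot = w @ A @ w, q @ k

# Block structure of A Sigma and the two traces used in the proof.
AS = A @ Sigma
expected_AS = 0.5 * np.block([[C.T, Sigma_k], [Sigma_q, C]])
tr_AS = np.trace(AS)
tr_AS2 = np.trace(AS @ AS)

cross = np.trace(C @ C)                 # sum_ij C_ij C_ji
frobenius = np.linalg.norm(C, "fro") ** 2  # sum_ij C_ij^2

print("w^T A w == q^T k:", np.isclose(quad_form, dot))
print("A Sigma block form matches:", np.allclose(AS, expected_AS))
print("tr(A Sigma) == tr C:", np.isclose(tr_AS, np.trace(C)))
print("tr((A Sigma)^2) == (tr(Sq Sk) + tr(C^2)) / 2:",
      np.isclose(tr_AS2, 0.5 * (np.trace(Sigma_q @ Sigma_k) + cross)))
print("tr(C^2) vs ||C||_F^2:", cross, frobenius)
```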