equal to H(0). Moreover, under a random initialization of parameters,
the random matrix H(0) converges in probability to a certain
deterministic kernel matrix H∗ as the width goes to infinity, which
is the Neural Tangent Kernel k(·, ·) evaluated on the training data. If
H(t) = H∗ for all t, then Equation (8.1) becomes

du(t)/dt = −H∗ · (u(t) − y).    (8.5)
Note that the above dynamics are identical to the dynamics of kernel
regression under gradient flow, for which, as t → ∞, the final
prediction function is (assuming u(0) = 0)

f∗(x) = (k(x, x_1), . . . , k(x, x_n)) · (H∗)^{−1} y.    (8.6)
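To make (8.6) concrete, the sketch below (Python with NumPy) solves
kernel regression on toy data; the Gaussian kernel and the synthetic
training set are illustrative stand-ins for the NTK k(·, ·) and the
actual data, not part of the text.

    import numpy as np

    rng = np.random.default_rng(0)
    n, d = 20, 5
    X = rng.standard_normal((n, d))   # toy training inputs x_1, ..., x_n
    y = rng.standard_normal(n)        # toy training labels

    def k(x, z):
        # Illustrative Gaussian kernel; in the text, k(·,·) is the NTK.
        return np.exp(-np.sum((x - z) ** 2) / 2)

    # Kernel matrix on the training data: [H*]_ij = k(x_i, x_j).
    H_star = np.array([[k(xi, xj) for xj in X] for xi in X])

    def f_star(x):
        # (8.6): f*(x) = (k(x, x_1), ..., k(x, x_n)) · (H*)^{-1} y
        k_vec = np.array([k(x, xi) for xi in X])
        return k_vec @ np.linalg.solve(H_star, y)

Since (8.6) interpolates the training data whenever H∗ is invertible,
f_star(X[i]) returns exactly y[i].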
8.2 Coupling Ultra-wide Neural Networks and NTK
In this section, we consider a simple two-layer neural network of the
following form:
f(a, W, x) = (1/√m) ∑_{r=1}^{m} a_r σ(w_r^⊤ x),    (8.7)
where σ(·) is the activation function. Here we assume |σ̇(z)| and
|σ̈(z)| are bounded by 1 for all z ∈ ℝ. For example, the soft-plus
activation function, σ(z) = log(1 + exp(z)), satisfies this assumption.¹
We also assume every input x has unit Euclidean norm, ‖x‖_2 = 1. The
scaling 1/√m will play an important role in proving that H(t) is close
to the fixed kernel H∗. Throughout this section, to measure the
closeness of two matrices A and B, we use the operator norm ‖·‖_2.
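As a quick check that soft-plus meets the derivative assumption: its
first two derivatives are

σ̇(z) = 1/(1 + e^{−z}) ∈ (0, 1) and σ̈(z) = σ̇(z)(1 − σ̇(z)) ≤ 1/4,

so both are indeed bounded by 1 in absolute value for all z ∈ ℝ.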
We use the random initialization w_r(0) ∼ N(0, I) and a_r ∼ Unif{−1, 1}.
For simplicity, we will only optimize the first layer, i.e., W = [w_1, . . . , w_m].
Note this is still a non-convex optimization problem.
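A minimal sketch of this setup in Python with NumPy follows; the width
m, the input dimension d, and the test input are our illustrative choices.

    import numpy as np

    rng = np.random.default_rng(0)
    m, d = 4096, 10

    # Random initialization: w_r(0) ~ N(0, I), a_r ~ Unif{-1, +1}.
    W = rng.standard_normal((m, d))        # rows are w_1, ..., w_m (trainable)
    a = rng.choice([-1.0, 1.0], size=m)    # second layer, held fixed

    def softplus(z):
        return np.log1p(np.exp(z))

    def f(W, x):
        # (8.7): f(a, W, x) = (1/sqrt(m)) * sum_r a_r * sigma(w_r^T x)
        return (a * softplus(W @ x)).sum() / np.sqrt(m)

    x = rng.standard_normal(d)
    x /= np.linalg.norm(x)                 # enforce ||x||_2 = 1

During training only W would be updated; a stays at its random
initialization.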
We can first calculate H(0) and show that, as m → ∞, H(0) converges
to a fixed matrix H∗. Note that

∂f(a, W, x_i)/∂w_r = (1/√m) a_r x_i σ̇(w_r^⊤ x_i).

Therefore, each entry of H(0) admits the formula
each entry of H(0) admits the formula
[H(0)] ij
=
=
m
∑
r=1
m
∑
r=1
〈 ∂ f (a, W(0), xi )
, ∂ f (a, W(0), x 〉
j)
∂w r (0) ∂w r (0)
〈 1
)
√ a r x i ˙σ
(w r (0) ⊤ 1
)
x i , √ a r x j ˙σ
(w 〉
r (0) ⊤ x i m m
=x ⊤ i x j · ∑m r=1 ˙σ ( w r (0) ⊤ x i
) ˙σ
(
wr (0) ⊤ x j
)
Here in the last step we used a_r^2 = 1 for all r = 1, . . . , m, because
we initialize a_r ∼ Unif{−1, 1}. Recall that every w_r(0) is i.i.d.
sampled from a standard Gaussian distribution. Therefore, one can view
[H(0)]_{ij} as the average of m i.i.d. random variables.
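The sketch below illustrates this concentration numerically, comparing
H(0) at several widths against a very wide network used as a proxy for
H∗; the widths, the toy inputs, and the proxy construction are our
choices, while the text only asserts the convergence in probability.

    import numpy as np

    rng = np.random.default_rng(0)
    n, d = 4, 5
    X = rng.standard_normal((n, d))
    X /= np.linalg.norm(X, axis=1, keepdims=True)  # unit-norm inputs

    def sigma_dot(z):
        # Derivative of soft-plus: the logistic sigmoid, bounded by 1.
        return 1.0 / (1.0 + np.exp(-z))

    def H0(m):
        # [H(0)]_ij = x_i^T x_j * (1/m) * sum_r sigma_dot(w_r^T x_i) * sigma_dot(w_r^T x_j)
        W = rng.standard_normal((m, d))
        S = sigma_dot(X @ W.T)                     # n x m matrix of derivatives
        return (X @ X.T) * (S @ S.T) / m

    H_ref = H0(10**6)                              # wide-network proxy for H*
    for m in (10, 1000, 100000):
        print(m, np.linalg.norm(H0(m) - H_ref, 2))  # operator-norm distance

The distance shrinks as m grows, at roughly the 1/√m rate one expects
from averaging m i.i.d. bounded terms.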
¹ Note that the rectified linear unit (ReLU) activation function does
not satisfy this assumption. However, one can use a specialized analysis
of ReLU to show H(t) ≈ H∗ [?].