
Gradient of A Matrix Matrix Multiplication


Edward Hu


This is just matrix multiplication.

It’s good to understand how to derive gradients for your neural network. It gets a little hairy when you have
matrix-matrix multiplication, such as $WX + b$. When I was reviewing Backpropagation in CS231n, they
hand-waved over this derivation of the gradient of the loss function $L$ with respect to the weight matrix $W$:

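$$
\frac{\partial L}{\partial W} = \frac{\partial L}{\partial D} \frac{\partial D}{\partial W}
$$

$$
\frac{\partial L}{\partial W} = \frac{\partial L}{\partial D} X^T
$$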


X = (m, n) input matrix with m features and n samples

W = (H, m) weight matrix with H neurons

D = WX, an (H, n) matrix

L = f(D), a scalar value, where f is an arbitrary loss function
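To make the shapes concrete, here is a minimal NumPy sketch of the setup above. The sizes (m = 4, n = 5, H = 3) and the sum-of-squares loss are my own placeholders for illustration; the post leaves f arbitrary.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
m, n, H = 4, 5, 3  # features, samples, neurons

rng = np.random.default_rng(0)
X = rng.standard_normal((m, n))  # input: one column per sample
W = rng.standard_normal((H, m))  # weights: one row per neuron

D = W @ X            # (H, m) @ (m, n) -> (H, n)
L = np.sum(D ** 2)   # stand-in for f(D); any scalar-valued function would do

print(D.shape)       # (3, 5)
print(np.ndim(L))    # 0, meaning L is a scalar
```

Since L is a scalar and W has shape (H, m), the gradient of L with respect to W must also have shape (H, m), which is a handy sanity check on any formula for it.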

Note that others may use D = XW, where the rows of X are samples and the columns are feature dimensions. That’s
fine: you can follow the same math with the indices switched and find the analogous result, $\partial L / \partial W = X^T \, \partial L / \partial D$.
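As a rough check that the two conventions agree, the sketch below builds the row-major version of the same problem by transposing everything. The names `X_rows`, `W_cols`, and `G` (a random stand-in for the gradient of L with respect to D) are mine, not from CS231n.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, H = 4, 5, 3                 # hypothetical sizes for illustration
X = rng.standard_normal((m, n))   # columns are samples (this post's convention)
W = rng.standard_normal((H, m))
G = rng.standard_normal((H, n))   # stand-in for dL/dD, where D = W @ X

# This post's convention: D = W X, so dL/dW = G X^T, shape (H, m).
grad_W = G @ X.T

# Other convention: D' = X' W' with X' = X^T (rows are samples) and W' = W^T.
X_rows = X.T                      # (n, m)
W_cols = W.T                      # (m, H)
G_rows = G.T                      # dL/dD' is just the transpose of dL/dD
grad_W_cols = X_rows.T @ G_rows   # X'^T dL/dD', shape (m, H)

# Same numbers, laid out as the transpose.
print(np.allclose(grad_W_cols, grad_W.T))
```

So the result is the same up to a transpose: the X factor moves to the left and the roles of rows and columns swap.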

The canonical neuron is ReLU(D + b), but to make things simpler we’ll ignore the nonlinearity and bias and
say that L takes in D instead of ReLU(D + b). We want to find the gradient of L with respect to W to do
gradient descent.

We want to find $\partial L / \partial W$, so let’s start by looking at a specific weight $W_{dc}$. This way we can think more easily
about the gradient of L for a single weight and extrapolate to all the weights in W.

$$
\frac{\partial L}{\partial W_{dc}} = \sum_{i,j} \frac{\partial L}{\partial D_{ij}} \frac{\partial D_{ij}}{\partial W_{dc}}
$$

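Let’s look more closely at the partial of $D_{ij}$ with respect to $W_{dc}$. We know that $\partial D_{ij} / \partial W_{dc} = 0$ if $i \neq d$, because
$D_{ij}$ is the dot product of row $i$ of $W$ and column $j$ of $X$, and $W_{dc}$ only appears in row $d$. This means the summation can be simplified by only
looking at cases where $\partial D_{ij} / \partial W_{dc}$ can be nonzero, which is when $i = d$.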

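$$
\sum_{i,j} \frac{\partial L}{\partial D_{ij}} \frac{\partial D_{ij}}{\partial W_{dc}} = \sum_{j} \frac{\partial L}{\partial D_{dj}} \frac{\partial D_{dj}}{\partial W_{dc}}
$$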

Finally, what is $\partial D_{dj} / \partial W_{dc}$?

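$$
D_{dj} = \sum_{k=1}^{m} W_{dk} X_{kj}
$$

$$
\frac{\partial D_{dj}}{\partial W_{dc}} = \frac{\partial}{\partial W_{dc}} \sum_{k=1}^{m} W_{dk} X_{kj} = \sum_{k=1}^{m} \frac{\partial}{\partial W_{dc}} W_{dk} X_{kj}
$$

Only the $k = c$ term depends on $W_{dc}$, so every other term in the sum vanishes:

$$
\therefore \frac{\partial D_{dj}}{\partial W_{dc}} = X_{cj}
$$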
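If you want to convince yourself of this numerically, one option is to nudge a single weight and watch how D changes. The sketch below is a finite-difference check under assumed sizes; the indices `d = 1`, `c = 2` and the step `eps` are arbitrary choices of mine.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, H = 4, 5, 3                 # hypothetical sizes for illustration
X = rng.standard_normal((m, n))
W = rng.standard_normal((H, m))

d, c = 1, 2                       # the single weight W[d, c] we perturb
eps = 1e-6

W_plus = W.copy()
W_plus[d, c] += eps

# Approximate dD[i, j] / dW[d, c] for every (i, j) at once.
dD = (W_plus @ X - W @ X) / eps   # shape (H, n)

# Claim from the derivation: row d of dD equals X[c, :], all other rows are 0.
expected = np.zeros((H, n))
expected[d, :] = X[c, :]

print(np.allclose(dD, expected, atol=1e-6))
```

Because D is linear in W, the finite difference matches the exact derivative up to floating-point noise, so no careful tuning of `eps` is needed here.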

So to put it all together, we have:

$$
\frac{\partial L}{\partial W_{dc}} = \sum_{j} \frac{\partial L}{\partial D_{dj}} X_{cj}
$$

Now how can we simplify this? Well, one quick way is to see that the sum over $j$ is a dot product of
row $d$ of $\partial L / \partial D$ with column $c$ of $X^T$, once we rewrite $X_{cj}$ as $X^T_{jc}$:

$$
\frac{\partial L}{\partial W_{dc}} = \sum_{j} \frac{\partial L}{\partial D_{dj}} X^T_{jc}
$$
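Here is a small sketch of that single entry computed two ways: as the explicit sum over j, and as a row-times-column dot product. `G` is a random stand-in for the gradient of L with respect to D, and the sizes and indices are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, H = 4, 5, 3                 # hypothetical sizes for illustration
X = rng.standard_normal((m, n))
G = rng.standard_normal((H, n))   # stand-in for dL/dD

d, c = 1, 2                       # which entry of dL/dW to compute

# Explicit sum over j of G[d, j] * X[c, j].
entry_sum = 0.0
for j in range(n):
    entry_sum += G[d, j] * X[c, j]

# Same thing as a dot product: row d of G with column c of X^T.
XT = X.T                          # shape (n, m), so XT[j, c] == X[c, j]
entry_dot = G[d, :] @ XT[:, c]

print(np.isclose(entry_sum, entry_dot))
```

A row of one matrix dotted with a column of another is exactly how a single entry of a matrix product is defined, which is what makes the next step possible.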

Now we want this for all weights in W, which means we can generalize this to:

$$
\frac{\partial L}{\partial W} = \frac{\partial L}{\partial D} X^T
$$
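A gradient check is a reasonable way to test the final formula end to end. The sketch below assumes a specific loss, L = 0.5 * sum(D^2), purely so that the gradient of L with respect to D has a known form (it is just D); the helper names `loss`, `analytic_grad`, and `numeric_grad` are mine.

```python
import numpy as np

def loss(W, X):
    """Assumed example loss: L = 0.5 * sum(D**2) with D = W X."""
    D = W @ X
    return 0.5 * np.sum(D ** 2)

def analytic_grad(W, X):
    """dL/dW = dL/dD @ X^T, where dL/dD = D for the loss above."""
    D = W @ X
    dL_dD = D
    return dL_dD @ X.T

def numeric_grad(W, X, eps=1e-6):
    """Central-difference estimate of dL/dW, one weight at a time."""
    grad = np.zeros_like(W)
    for idx in np.ndindex(*W.shape):
        W_plus, W_minus = W.copy(), W.copy()
        W_plus[idx] += eps
        W_minus[idx] -= eps
        grad[idx] = (loss(W_plus, X) - loss(W_minus, X)) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 5))   # (m, n), hypothetical sizes
W = rng.standard_normal((3, 4))   # (H, m)

max_err = np.max(np.abs(analytic_grad(W, X) - numeric_grad(W, X)))
print(max_err < 1e-5)
```

The analytic gradient has the same (H, m) shape as W, and the numerical estimate should agree with it to within floating-point error. Swapping in a different loss only changes the `dL_dD` line.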

1 Comment

Michael Heinzer • 4 months ago

There is a slightly imprecise notation whenever you sum up to q, as q is never defined. The q term should probably be replaced by m. I would recommend adding the limits of your sum everywhere to make your post more