
5.5.4 Autograd nits and switching it off

From the previous training loop, we can appreciate that we only ever call backward on train_loss. Therefore, errors will only ever backpropagate based on the training set; the validation set is used to provide an independent evaluation of the accuracy of the model's output on data that wasn't used for training.

The curious reader will have an embryo of a question at this point. The model is evaluated twice, once on train_t_u and once on val_t_u, and then backward is called. Won't this confuse autograd? Won't backward be influenced by the values generated during the pass on the validation set?

Luckily for us, this isn't the case. The first line in the training loop evaluates model on train_t_u to produce train_t_p. Then train_loss is evaluated from train_t_p. This creates a computation graph that links train_t_u to train_t_p to train_loss. When model is evaluated again on val_t_u, it produces val_t_p and val_loss. In this case, a separate computation graph is created that links val_t_u to val_t_p to val_loss. Separate tensors have been run through the same functions, model and loss_fn, generating separate computation graphs, as shown in figure 5.15.

Figure 5.15 Diagram showing how gradients propagate through a graph with two losses when .backward is called on one of them

The only tensors these two graphs have in common are the parameters. When we call backward on train_loss, we run backward on the first graph. In other words, we accumulate the derivatives of train_loss with respect to the parameters based on the computation generated from train_t_u.

If we (incorrectly) called backward on val_loss as well, we would accumulate the derivatives of val_loss with respect to the parameters on the same leaf nodes. Remember the zero_grad thing, whereby gradients are accumulated on top of each other every time we call backward unless we zero out the gradients explicitly? Well, here something similar happens: calling backward on val_loss would lead to the gradients from the validation pass accumulating in the parameters' .grad on top of those generated during the train_loss.backward() call, so we would effectively end up training the model on the validation set as well.
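Here is a quick sketch of that accumulation issue, continuing with the names from the sketch above (model, loss_fn, params, and the train/val tensors are assumed to already exist):

```python
# Fresh forward passes: two separate graphs that share only the leaf tensor params.
train_loss = loss_fn(model(0.1 * train_t_u, *params), train_t_c)
val_loss = loss_fn(model(0.1 * val_t_u, *params), val_t_c)

params.grad = None                      # start from a clean gradient buffer
train_loss.backward()
grad_train_only = params.grad.clone()   # d train_loss / d (w, b)

val_loss.backward()                     # the mistake: accumulates on top
grad_mixed = params.grad                # d train_loss / d params + d val_loss / d params

print(grad_train_only)
print(grad_mixed)
```

If we then called optimizer.step(), the update would be driven partly by the validation data, which is exactly what sticking to train_loss.backward() (and zeroing gradients each iteration) prevents.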
