import torch
import torch.optim as optim
import torch.nn as nn
from sklearn.datasets import make_regression
from torch.utils.data import DataLoader, TensorDataset

from stepbystep.v3 import StepByStep
from data_generation.ball import load_data

Vanishing and Exploding Gradients

In this extra chapter, we're discussing gradients once again. The gradients, together with the learning rate, are what makes the model tick, or better yet, learn. We discussed both of these topics in quite some detail in Chapter 6, but we always assumed that the gradients were well behaved, as long as our learning rate was sensible. Unfortunately, this is not necessarily true, and sometimes the gradients may go awry: they can either vanish or explode. Either way, we need to rein them in, so let's see how we can accomplish that.

Vanishing Gradients

Do you remember how we tell PyTorch to compute gradients for us? It starts with the loss value, followed by a call to the backward() method, which works its way back up to the first layer. That's backpropagation in a nutshell. It works fine for models with a few hidden layers, but as models grow deeper, the gradients computed for the weights in the initial layers become smaller and smaller. That's the so-called vanishing gradients problem, and it has always been a major obstacle for training deeper models.
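
To see this in action, here is a minimal sketch (not the model used in this chapter; the layer sizes, dummy data, and loss function below are arbitrary choices for illustration): a deep stack of sigmoid-activated linear layers, a single call to backward(), and a look at the gradient norms layer by layer. The closer a layer is to the input, the smaller its gradients tend to be.

# A minimal sketch: ten sigmoid-activated linear layers with arbitrary sizes
# and dummy data, chosen only to illustrate the effect (not from the book)
torch.manual_seed(42)
deep_model = nn.Sequential()
for i in range(10):
    deep_model.add_module(f'linear{i}', nn.Linear(16, 16))
    deep_model.add_module(f'sigmoid{i}', nn.Sigmoid())
deep_model.add_module('output', nn.Linear(16, 1))

x = torch.randn(32, 16)  # dummy batch: 32 points, 16 features
y = torch.randn(32, 1)   # dummy targets

loss = nn.MSELoss()(deep_model(x), y)  # loss value from a forward pass
loss.backward()                        # backpropagation, all the way back to the first layer

# gradient norms per linear layer: they shrink as we move toward the input
for name, layer in deep_model.named_children():
    if isinstance(layer, nn.Linear):
        print(f'{name}: {layer.weight.grad.norm().item():.2e}')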

"Why is it so bad?"

If gradients vanish, that is, if they are close to zero, updating the weights will barely change them. In other words, the model is not learning anything; it gets stuck.
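
The arithmetic of a single gradient descent update makes the problem obvious. As a quick sketch (the numbers below are made up for illustration), a near-zero gradient multiplied by the learning rate produces an update so small that the weight is, for all practical purposes, left untouched:

# A quick sketch with made-up numbers: a "vanished" gradient barely moves the weight
w = torch.tensor(0.5)      # some weight
grad = torch.tensor(1e-9)  # a gradient that has (nearly) vanished
lr = 0.1                   # learning rate

new_w = w - lr * grad          # gradient descent update
print(w.item(), new_w.item())  # prints 0.5 0.5: the update is too small to register in float32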

"Why does it happen?"

We can blame it on the (in)famous "internal covariate shift." But, instead of

