Now, let’s be perfectly clear: when we’re done, our model will be able to handle this kind of data imbalance just fine. We could probably even train the model all the way there without changing the balancing, assuming we were willing to wait for a gajillion epochs first.⁴ But we’re busy people with things to do, so rather than cook our GPU until the heat death of the universe, let’s try to make our training data look more ideal by changing the class balance we are training with.

12.4.1 Making the data look less like the actual and more like the “ideal”

The best thing to do would be to have relatively more positive samples. During the initial epoch of training, when we’re going from randomized chaos to something more organized, having so few training samples be positive means they get drowned out.
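One generic way to present relatively more positive samples in PyTorch is to oversample them with a weighted sampler. The sketch below is not the chapter’s own code: the dataset, feature tensor, and sample counts are stand-ins, and the chapter balances its data differently, but it shows the basic idea of drawing positives and negatives with roughly equal probability.

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Tiny stand-in dataset: 10,000 negatives and 20 positives instead of the
# real ~495,000/1,000 split; 'features' is random noise for illustration.
labels = torch.cat([torch.zeros(10_000), torch.ones(20)]).long()
features = torch.randn(len(labels), 8)
dataset = TensorDataset(features, labels)

# Weight each sample inversely to its class frequency so that positives
# and negatives are drawn with roughly equal probability.
class_counts = torch.bincount(labels)
sample_weights = 1.0 / class_counts[labels].float()
sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels),
                                replacement=True)

loader = DataLoader(dataset, batch_size=32, sampler=sampler)
x, y = next(iter(loader))
print(y.float().mean())  # roughly 0.5: about half of each batch is positive

With replacement=True, the handful of positive samples simply gets repeated within an epoch, which is exactly the kind of rebalancing we’re after.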

The method by which this drowning out happens is somewhat subtle, however. Recall that since our network weights are initially randomized, the per-sample output of the network is also randomized (but clamped to the range [0, 1]).

NOTE Our loss function is nn.CrossEntropyLoss, which technically operates on the raw logits rather than the class probabilities. For our discussion, we’ll ignore that distinction and assume the loss and the label-prediction deltas are the same thing.
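To make that distinction concrete, here is a minimal sketch with made-up logits (the values and batch size are ours, not from the book) showing that nn.CrossEntropyLoss consumes raw logits and gives the same result as log-softmax followed by nn.NLLLoss:

import torch
import torch.nn as nn

# Hypothetical two-class logits for a batch of 4 samples;
# column 0 = "negative", column 1 = "positive".
logits = torch.tensor([[ 2.0, -1.0],
                       [ 0.3,  0.1],
                       [-0.5,  1.5],
                       [ 1.2,  0.9]])
labels = torch.tensor([0, 0, 1, 1])

# nn.CrossEntropyLoss takes the raw logits directly...
loss_from_logits = nn.CrossEntropyLoss()(logits, labels)

# ...which is equivalent to log-softmax plus the negative log likelihood loss.
log_probs = nn.LogSoftmax(dim=1)(logits)
loss_from_log_probs = nn.NLLLoss()(log_probs, labels)

print(loss_from_logits, loss_from_log_probs)  # identical values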

The predictions numerically close to the correct label do not result in much change to the weights of the network, while predictions that are significantly different from the correct answer are responsible for a much greater change to the weights. Since the output is random when the model is initialized with random weights, we can assume that of our ~500k training samples (495,958, to be exact), we’ll have the following approximate groups:

1. 250,000 negative samples will be predicted to be negative (0.0 to 0.5) and result in at most a small change to the network weights toward predicting negative.
2. 250,000 negative samples will be predicted to be positive (0.5 to 1.0) and result in a large swing toward the network weights predicting negative.
3. 500 positive samples will be predicted to be negative and result in a swing toward the network weights predicting positive.
4. 500 positive samples will be predicted to be positive and result in almost no change to the network weights.

NOTE Keep in mind that the actual predictions are real numbers between 0.0 and 1.0 inclusive, so these groups won’t have strict delineations.
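As a rough sanity check on those group sizes, here is a toy simulation (not from the book) that treats an untrained model’s per-sample output as a uniform random number in [0, 1] and counts how many samples fall into each group; the positive count of 1,000 is taken from the text’s 500 + 500 split:

import torch

# Stand-in for an untrained model: predictions are uniform random in [0, 1].
n_total, n_pos = 495_958, 1_000
labels = torch.cat([torch.zeros(n_total - n_pos), torch.ones(n_pos)])
preds = torch.rand(n_total)

groups = {
    "1: negative, predicted negative": ((labels == 0) & (preds < 0.5)).sum().item(),
    "2: negative, predicted positive": ((labels == 0) & (preds >= 0.5)).sum().item(),
    "3: positive, predicted negative": ((labels == 1) & (preds < 0.5)).sum().item(),
    "4: positive, predicted positive": ((labels == 1) & (preds >= 0.5)).sum().item(),
}
for name, count in groups.items():
    print(name, count)  # roughly 247,000 / 247,000 / 500 / 500

The exact counts vary run to run, but groups 1 and 2 dwarf groups 3 and 4 by a factor of about 500, which is the imbalance the next paragraph hinges on.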

Here’s the kicker, though: groups 1 and 4 can be any size, and they will continue to have close to zero impact on training. The only thing that matters is that groups 2 and 3 can counteract each other’s pull enough to prevent the network from collapsing to a degenerate “only output one thing” state. Since group 2 is 500 times larger than group 3, . . .

⁴ It’s not clear if this is actually true, but it’s plausible, and the loss was getting better . . .
