
214 CHAPTER 8 Using convolutions to generalize

8.4.1 Measuring accuracy

In order to have a measure that is more interpretable than the loss, we can take a look at our accuracies on the training and validation datasets. We use the same code as in chapter 7:

# In[32]:
train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,
                                           shuffle=False)
val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64,
                                         shuffle=False)

def validate(model, train_loader, val_loader):
    for name, loader in [("train", train_loader), ("val", val_loader)]:
        correct = 0
        total = 0

        with torch.no_grad():  # We do not want gradients here, as we will not want to update the parameters.
            for imgs, labels in loader:
                outputs = model(imgs)
                _, predicted = torch.max(outputs, dim=1)  # Gives us the index of the highest value as output
                total += labels.shape[0]  # Counts the number of examples, so total is increased by the batch size
                correct += int((predicted == labels).sum())

        print("Accuracy {}: {:.2f}".format(name, correct / total))

validate(model, train_loader, val_loader)

# Out[32]:
Accuracy train: 0.93
Accuracy val: 0.89

Comparing the predicted class that had the maximum probability with the ground-truth labels, we first get a Boolean array. Taking the sum gives the number of items in the batch where the prediction and ground truth agree. We then cast to a Python int; for integer tensors, this is equivalent to using .item(), similar to what we did in the training loop.
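The per-batch tally that validate performs can be sketched without PyTorch at all. The following is a minimal pure-Python illustration of the same counting logic; the accuracy helper and the hand-picked batches are hypothetical stand-ins for the (predicted, labels) pairs a DataLoader loop would produce:

```python
def accuracy(batches):
    # Same bookkeeping as in validate: count matches and examples per batch.
    correct = 0
    total = 0
    for predicted, labels in batches:
        total += len(labels)  # batch size, like labels.shape[0]
        # Elementwise comparison then sum, like (predicted == labels).sum()
        correct += sum(int(p == l) for p, l in zip(predicted, labels))
    return correct / total

# Two toy batches of four predictions each, three correct in each batch
batches = [([0, 1, 1, 0], [0, 1, 0, 0]),
           ([1, 1, 0, 1], [1, 0, 0, 1])]
print(accuracy(batches))  # → 0.75
```

This makes explicit that accuracy is just correct predictions divided by total examples, accumulated across batches.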

This is quite a lot better than the fully connected model, which achieved only 79% accuracy: the error rate on the validation set roughly halved, from 21% to 11%, and we did it with far fewer parameters. This tells us that, through locality and translation invariance, the model does a better job of generalizing its task of recognizing the subject of images from a new sample. We could now let it run for more epochs and see what performance we could squeeze out.

8.4.2 Saving and loading our model

Since we’re satisfied with our model so far, it would be nice to actually save it, right? It’s easy to do. Let’s save the model to a file:

# In[33]:

torch.save(model.state_dict(), data_path + 'birds_vs_airplanes.pt')

The birds_vs_airplanes.pt file now contains all the parameters of model: that is, weights and biases for the two convolution modules and the two linear modules.
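Under the hood, state_dict() returns an ordered dictionary mapping parameter names (like "conv1.weight") to tensors, and torch.save serializes that dictionary via Python's pickle machinery. As a PyTorch-free sketch of the round trip, with plain lists standing in for tensors and an in-memory buffer standing in for the file (both assumptions for illustration):

```python
import io
import pickle

# Stand-in for model.state_dict(): parameter names mapped to values.
state_dict = {"conv1.weight": [0.1, -0.2], "conv1.bias": [0.0]}

buffer = io.BytesIO()
pickle.dump(state_dict, buffer)  # analogous to torch.save(state_dict, path)
buffer.seek(0)
restored = pickle.load(buffer)   # analogous to torch.load(path)

print(restored == state_dict)    # → True
```

The practical consequence is that the saved file holds only parameters, not the model's structure: to use it later, you must first construct a model of the same class and then load the dictionary into it.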
