Notebook Cell 2.6 - Loading checkpoint to resume training

checkpoint = torch.load('model_checkpoint.pth')

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

saved_epoch = checkpoint['epoch']
saved_losses = checkpoint['loss']
saved_val_losses = checkpoint['val_loss']

model.train() # always use TRAIN for resuming training  1

1 Never forget to set the mode!
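For reference, a checkpoint with exactly these keys could have been produced earlier by something along these lines; this is a minimal sketch, and the names n_epochs, losses, and val_losses are assumed to come from the preceding training loop:

# Sketch: building and saving a checkpoint dictionary
# (assumed names: n_epochs, losses, val_losses from the earlier loop)
checkpoint = {'epoch': n_epochs,
              'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict(),
              'loss': losses,
              'val_loss': val_losses}

torch.save(checkpoint, 'model_checkpoint.pth')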

print(model.state_dict())

Output

OrderedDict([('0.weight', tensor([[1.9448]], device='cuda:0')),
             ('0.bias', tensor([1.0295], device='cuda:0'))])

Cool, we recovered our model’s state, and we can resume training.

After loading a model to resume training, make sure you ALWAYS set it to training mode:

model.train()

In our example, this is going to be redundant because our train_step_fn() function already does it. But it is important to pick up the habit of setting the mode of the model accordingly.
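To make this concrete, here is a minimal sketch of a train-step builder that sets the mode itself; it assumes model, loss_fn, and optimizer were defined earlier, so take it as an illustration rather than the exact code from previous sections:

def make_train_step_fn(model, loss_fn, optimizer):
    # Builds a function that performs one step in the training loop
    def perform_train_step_fn(x, y):
        model.train()              # sets the model to TRAIN mode on every step
        yhat = model(x)            # step 1: compute the prediction
        loss = loss_fn(yhat, y)    # step 2: compute the loss
        loss.backward()            # step 3: compute the gradients
        optimizer.step()           # step 4: update the parameters
        optimizer.zero_grad()      # zero the gradients for the next step
        return loss.item()
    return perform_train_step_fn

train_step_fn = make_train_step_fn(model, loss_fn, optimizer)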

Next, we can run Model Training V5 to train it for another 200 epochs.

"Why 200 more epochs? Can’t I choose a different number?"

Well, you could, but you’d have to change the code in Model Training V5. This clearly isn’t ideal, but we will make our model training code more flexible very soon.
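As a rough sketch of where we are heading, a more flexible version could take the number of additional epochs as a variable and pick up the epoch counter from the checkpoint. The names x_train_tensor, y_train_tensor, and the simple full-batch step are assumptions for illustration; the actual loop also handles validation, which is omitted here:

# Sketch: resuming for a configurable number of extra epochs
# (assumed names: x_train_tensor, y_train_tensor, train_step_fn;
#  validation handling is omitted for brevity)
n_epochs = 200  # a variable instead of a hard-coded value

losses = saved_losses          # continue appending to the recovered history
val_losses = saved_val_losses

for epoch in range(saved_epoch, saved_epoch + n_epochs):
    loss = train_step_fn(x_train_tensor, y_train_tensor)
    losses.append(loss)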

