20.03.2021

Deep-Learning-with-PyTorch


Updating the training script for segmentation


When we add similar code to our classification training loop in the next chapter, we'll use the F1 score.
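To see why the choice of metric matters, here is a small worked comparison of recall and F1 using their standard definitions. The confusion counts are made-up illustrative numbers, and the helper names are hypothetical rather than taken from the book's code. With many false positives, recall can look strong while F1 is pulled down by the low precision.

```python
def recall(tp: int, fn: int) -> float:
    # Fraction of actual positives that the model found.
    return tp / (tp + fn)


def precision(tp: int, fp: int) -> float:
    # Fraction of predicted positives that were actually positive.
    return tp / (tp + fp)


def f1(tp: int, fp: int, fn: int) -> float:
    # Harmonic mean of precision and recall.
    p = precision(tp, fp)
    r = recall(tp, fn)
    return 2 * p * r / (p + r)


# Hypothetical counts from one validation pass (illustrative only).
tp, fp, fn = 90, 60, 10
print(f"recall={recall(tp, fn):.2f} precision={precision(tp, fp):.2f} f1={f1(tp, fp, fn):.2f}")
# recall=0.90 precision=0.60 f1=0.72
```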

Back in the main training loop, we'll keep track of the best_score we've seen so far in this training run. When we save our model, we'll include a flag that indicates whether this is the best score we've seen so far. Recall from section 13.6.4 that we're only calling the doValidation function for the first epoch and then for every fifth epoch. That means we're only going to check for a best score on those epochs. That shouldn't be a problem, but it's something to keep in mind if you need to debug something happening on epoch 7. We do this checking just before we save the images.
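The validation schedule described above can be sketched as a small predicate. This is a minimal sketch assuming a cadence of 5; the name `is_validation_epoch` and the constant are hypothetical, not the book's code. Listing the epochs it selects makes it clear that epoch 7 never gets a best-score check.

```python
# Hypothetical constant: the text only says "every fifth epoch".
VALIDATION_CADENCE = 5


def is_validation_epoch(epoch_ndx: int, cadence: int = VALIDATION_CADENCE) -> bool:
    # Validate on the first epoch, then on every multiple of the cadence.
    return epoch_ndx == 1 or epoch_ndx % cadence == 0


checked = [e for e in range(1, 21) if is_validation_epoch(e)]
print(checked)  # [1, 5, 10, 15, 20]
```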

Listing 13.33 training.py:210, SegmentationTrainingApp.main

def main(self):
    best_score = 0.0
    # The epoch-loop we already saw
    for epoch_ndx in range(1, self.cli_args.epochs + 1):
        # if validation is wanted
            # ... line 233
            valMetrics_t = self.doValidation(epoch_ndx, val_dl)
            # Computes the score. As we saw earlier, we take the recall.
            score = self.logMetrics(epoch_ndx, 'val', valMetrics_t)
            best_score = max(score, best_score)

            # Now we only need to write saveModel. The third parameter
            # is whether we want to save it as the best model, too.
            self.saveModel('seg', epoch_ndx, score == best_score)
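The best-score bookkeeping in the listing can be exercised on its own. The sketch below uses a hypothetical `best_flags` helper and made-up scores to reproduce the `max`/`==` pattern outside the training app.

```python
def best_flags(scores):
    """Return, per validation epoch, whether its score equals the best seen so far."""
    best_score = 0.0
    flags = []
    for score in scores:
        best_score = max(score, best_score)
        # Same test as the listing: is this epoch's score the running best?
        flags.append(score == best_score)
    return flags


# Made-up validation scores, for illustration only.
print(best_flags([0.60, 0.72, 0.70, 0.72, 0.81]))
# [True, True, False, True, True]
```

Because the comparison is `score == best_score`, a later epoch that ties the best score is also flagged as best, so its checkpoint would replace the earlier one. Starting `best_score` at 0.0 works because recall is never negative.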

Let’s take a look at how we persist our model to disk.

13.6.6 Saving our model

PyTorch makes it pretty easy to save our model to disk. Under the hood, torch.save uses the standard Python pickle library, which means we could pass our model instance in directly, and it would save properly. That's not considered the ideal way to persist our model, however, since we lose some flexibility: a pickled model refers to the class it was built from, so loading it later requires that same class to be importable from the same module path.
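For contrast, here is a sketch of the whole-model approach, assuming PyTorch 1.13 or newer (where `torch.load` accepts `weights_only`). The toy model is built from standard layers so that the example stays self-contained. With a custom `nn.Module` subclass, that class would have to be importable under the same name at load time. Recent PyTorch releases (2.6 and later) default to `weights_only=True`, which rejects arbitrary pickled objects, so loading a whole module requires an explicit opt-out.

```python
import io

import torch
import torch.nn as nn

# Toy model built from standard layers (illustrative only).
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

buf = io.BytesIO()
torch.save(model, buf)  # pickles the whole module object, not just its tensors
buf.seek(0)

# Loading a pickled module needs weights_only=False. Only do this for files
# you trust: unpickling can execute arbitrary code.
loaded = torch.load(buf, weights_only=False)

print(type(loaded).__name__)                             # Sequential
print(torch.equal(loaded[0].weight, model[0].weight))    # True
```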

Instead, we will save only the parameters of our model. Doing this allows us to load those parameters into any model that expects parameters with the same names and shapes, even if its class doesn't match the model those parameters were saved under. The save-parameters-only approach allows us to reuse and remix our models in more ways than saving the entire model does.
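Here is a minimal sketch of the remixing this enables, using two hypothetical toy classes and assuming PyTorch 1.13 or newer for `weights_only`. One detail is worth knowing: `load_state_dict` matches entries by parameter name as well as by shape. With its default `strict=True`, it raises an error on missing or unexpected keys, so both classes here deliberately use the same attribute name, `linear`.

```python
import io

import torch
import torch.nn as nn


class SavedNet(nn.Module):
    """The model whose parameters we save (hypothetical)."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)


class RemixNet(nn.Module):
    """A different class with identically named and shaped parameters (hypothetical)."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        # Different behavior on top of the same weights.
        return torch.sigmoid(self.linear(x))


source = SavedNet()

# Save only the parameters (the state dict), not the pickled model object.
buf = io.BytesIO()
torch.save(source.state_dict(), buf)
buf.seek(0)

target = RemixNet()
state = torch.load(buf, weights_only=True)  # a plain dict of tensors loads safely
result = target.load_state_dict(state)

print(result.missing_keys, result.unexpected_keys)               # [] []
print(torch.equal(target.linear.weight, source.linear.weight))   # True
```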

We can get at our model's parameters using the model.state_dict() method.
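It is worth looking at what `state_dict()` actually returns. For the small hypothetical model below, it maps names such as `0.weight` to tensors. It also includes buffers such as BatchNorm's running statistics. These are not learnable parameters, but they still need to be saved for the model to behave the same after loading.

```python
import torch.nn as nn

# A small hypothetical model: a convolution followed by batch normalization.
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3),
    nn.BatchNorm2d(4),
)

for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))

# 0.weight (4, 1, 3, 3)
# 0.bias (4,)
# 1.weight (4,)
# 1.bias (4,)
# 1.running_mean (4,)
# 1.running_var (4,)
# 1.num_batches_tracked ()
```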

Listing 13.34 training.py:480, .saveModel

def saveModel(self, type_str, epoch_ndx, isBest=False):
    # ... line 496
    model = self.segmentation_model
    # Gets rid of the DataParallel wrapper, if it exists
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
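The listing stops after unwrapping the model. The following hedged sketch shows how the pieces could fit together; it is not the book's remaining code. The hypothetical `save_state` helper unwraps `DataParallel`, saves only the state dict, and copies the file to a separate "best" path when `is_best` is set. The file names are assumptions, and PyTorch 1.13 or newer is assumed for `weights_only`. Printing the wrapped model's keys shows why unwrapping matters: `DataParallel` stores the real model as a submodule named `module`, so its keys carry a `module.` prefix that a plain, unwrapped model would not expect.

```python
import os
import shutil
import tempfile

import torch
import torch.nn as nn


def save_state(model, out_dir, type_str, epoch_ndx, is_best=False):
    """Hypothetical helper: save parameters only, plus a 'best' copy if flagged."""
    # Unwrap DataParallel so the saved keys don't carry a 'module.' prefix.
    if isinstance(model, nn.DataParallel):
        model = model.module

    # These file names are illustrative assumptions, not the book's naming scheme.
    path = os.path.join(out_dir, f"{type_str}_epoch{epoch_ndx}.state")
    torch.save(model.state_dict(), path)

    if is_best:
        shutil.copyfile(path, os.path.join(out_dir, f"{type_str}_best.state"))
    return path


wrapped = nn.DataParallel(nn.Linear(3, 1))
wrapped_keys = list(wrapped.state_dict().keys())
print(wrapped_keys)  # ['module.weight', 'module.bias']

with tempfile.TemporaryDirectory() as out_dir:
    path = save_state(wrapped, out_dir, "seg", 5, is_best=True)
    saved_keys = list(torch.load(path, weights_only=True).keys())
    best_exists = os.path.exists(os.path.join(out_dir, "seg_best.state"))

print(saved_keys)   # ['weight', 'bias']
print(best_exists)  # True
```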
