
398 CHAPTER 13 Using segmentation to find suspected nodules

state = {
    'sys_argv': sys.argv,
    'time': str(datetime.datetime.now()),
    'model_state': model.state_dict(),                # The important part
    'model_name': type(model).__name__,
    'optimizer_state': self.optimizer.state_dict(),   # Preserves momentum, and so on
    'optimizer_name': type(self.optimizer).__name__,
    'epoch': epoch_ndx,
    'totalTrainingSamples_count': self.totalTrainingSamples_count,
}

torch.save(state, file_path)

We set file_path to something like data-unversioned/part2/models/p2ch13/seg_2019-07-10_02.17.22_ch12.50000.state. The .50000. part is the number of training samples we've presented to the model so far, while the other parts of the path are obvious.

TIP By saving the optimizer state as well, we could resume training seamlessly. While we don't provide an implementation of this, it could be useful if your access to computing resources is likely to be interrupted. Details on loading a model and optimizer to restart training can be found in the official documentation (https://pytorch.org/tutorials/beginner/saving_loading_models.html).
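The resume step the tip describes can be sketched as follows. This is a minimal illustration, not the book's implementation: the tiny nn.Linear model, the SGD optimizer, and the checkpoint.state filename are stand-ins for the real segmentation model and training loop.

```python
import torch
import torch.nn as nn
from torch.optim import SGD

# Hypothetical tiny model standing in for the segmentation model.
model = nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)

# Save a checkpoint in the same shape as the state dict above.
torch.save({
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'epoch': 3,
}, 'checkpoint.state')

# ... later, to resume training:
checkpoint = torch.load('checkpoint.state')
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])  # restores momentum buffers
start_epoch = checkpoint['epoch'] + 1
model.train()  # back to training mode before re-entering the loop
```

Restoring the optimizer state is what makes the resumption seamless: momentum buffers and per-parameter-group settings pick up exactly where they left off, rather than restarting cold.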

If the current model has the best score we've seen so far, we save a second copy of state with a .best.state filename. This might get overwritten later by another, higher-score version of the model. By focusing only on this best file, we can decouple consumers of our trained model from the details of how each epoch of training went (assuming, of course, that our score metric is of high quality).

Listing 13.35 training.py:514, .saveModel

if isBest:
    best_path = os.path.join(
        'data-unversioned', 'part2', 'models',
        self.cli_args.tb_prefix,
        f'{type_str}_{self.time_str}_{self.cli_args.comment}.best.state')
    shutil.copyfile(file_path, best_path)

    log.info("Saved model params to {}".format(best_path))

with open(file_path, 'rb') as f:
    log.info("SHA1: " + hashlib.sha1(f.read()).hexdigest())

We also output the SHA1 of the model we just saved. Similar to sys.argv and the timestamp we put into the state dictionary, this can help us debug exactly what model we're working with if things become confused later (for example, if a file gets renamed incorrectly).
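Anyone receiving the checkpoint can recompute that hash and compare it against the logged value to confirm they have the exact file. A minimal sketch, assuming only the Python standard library; the file_sha1 helper name is ours, and the chunked read simply avoids holding a large checkpoint file in memory all at once:

```python
import hashlib

def file_sha1(path, chunk_size=1 << 20):
    """Compute the SHA1 hex digest of a file, reading in 1 MiB chunks."""
    h = hashlib.sha1()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()
```

Comparing file_sha1('seg_....state') against the SHA1 line in the training log tells you immediately whether a renamed or copied file really is the model you think it is.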
