Deep-Learning-with-PyTorch

Exporting models


But then all we need to do is call torch.jit.trace.⁷

Listing 15.7 trace_example.py

import torch
from p2ch13.model_seg import UNetWrapper

seg_dict = torch.load(
    'data-unversioned/part2/models/p2ch13/seg_2019-10-20_15.57.21_none.best.state',
    map_location='cpu')
seg_model = UNetWrapper(in_channels=8, n_classes=1, depth=4, wf=3,
                        padding=True, batch_norm=True, up_mode='upconv')
seg_model.load_state_dict(seg_dict['model_state'])
seg_model.eval()
for p in seg_model.parameters():
    p.requires_grad_(False)    # Sets the parameters to not require gradients

dummy_input = torch.randn(1, 8, 512, 512)
traced_seg_model = torch.jit.trace(seg_model, dummy_input)    # The tracing

The tracing gives us a warning:

TracerWarning: Converting a tensor to a Python index might cause the trace
to be incorrect. We can't record the data flow of Python values, so this
value will be treated as a constant in the future. This means the trace
might not generalize to other inputs!
  return layer[:, :, diff_y:(diff_y + target_size[0]), diff_x:(diff_x +
      target_size[1])]

This stems from the cropping we do in U-Net, but as long as we only ever plan to feed 512 × 512 images into the model, we will be OK. In the next section, we'll take a closer look at what causes the warning and, if we need to, how to get around the limitation it highlights. It will also be important when we want to convert models more complex than convolutional networks and U-Nets to TorchScript.

We can save the traced model

torch.jit.save(traced_seg_model, 'traced_seg_model.pt')

and load it back without needing anything but the saved file, and then we can call it:

loaded_model = torch.jit.load('traced_seg_model.pt')

prediction = loaded_model(batch)

The PyTorch JIT keeps the model's state from when we saved it: that we had put it
into evaluation mode and that our parameters do not require gradients. If we had not
taken care of that beforehand, we would need to wrap the call in with
torch.no_grad():.
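We can check that the saved state survives the round trip. A minimal sketch, using a stand-in nn.Linear rather than the book's UNetWrapper, and an in-memory buffer rather than a file on disk (torch.jit.save and torch.jit.load accept file-like objects as well as paths):

```python
import io
import torch

# Stand-in model (hypothetical, not the book's UNetWrapper).
model = torch.nn.Linear(4, 2)
model.eval()                 # evaluation mode is recorded in the trace
for p in model.parameters():
    p.requires_grad_(False)  # frozen parameters are recorded, too

traced = torch.jit.trace(model, torch.randn(1, 4))

buf = io.BytesIO()           # in-memory stand-in for 'traced_seg_model.pt'
torch.jit.save(traced, buf)
buf.seek(0)
loaded = torch.jit.load(buf)

out = loaded(torch.randn(1, 4))
print(out.requires_grad)     # False: no autograd graph is built
```

Because the parameters were frozen before tracing, the loaded model's outputs carry no autograd history; had we skipped that step, wrapping the call in with torch.no_grad(): would achieve the same effect at inference time.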

⁷ Strictly speaking, this traces the model as a function. Recently, PyTorch gained the ability to preserve more of the module structure using torch.jit.trace_module, but for us, plain tracing is sufficient.
