
464 CHAPTER 15 Deploying to production

When we said that the JITed modules work like they did in Python, this includes the fact that we can use them for training, too. On the flip side, this means we need to set them up for inference (for example, using the torch.no_grad() context) just like our traditional models, to make them do the right thing.
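As a minimal sketch of that setup (the tiny stand-in model here is illustrative, not from the book), a scripted module is switched to inference mode exactly like an ordinary one:

```python
import torch
from torch import nn

# A tiny stand-in model; any nn.Module is handled the same way.
model = nn.Linear(4, 2)
scripted = torch.jit.script(model)

# Set up for inference just like a traditional module:
scripted.eval()
with torch.no_grad():
    out = scripted(torch.randn(1, 4))

print(out.requires_grad)  # False: no autograd graph was recorded
```

Without the torch.no_grad() context, the scripted module would record gradients during the forward pass just as it would for training.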

With algorithmically relatively simple models (like the CycleGAN, classification models, and U-Net-based segmentation), we can just trace the model as we did earlier. For more complex models, a nifty property is that we can use scripted or traced functions from other scripted or traced code, and that we can use scripted or traced submodules when constructing and tracing or scripting a module. We can also trace functions that call nn.Modules, but then we need to set all parameters to not require gradients, as the parameters will be constants for the traced model.
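As a sketch of this mix-and-match (the function and wrapper names are illustrative, not from the book), a scripted function with data-dependent control flow keeps that control flow even when it is called from code we then trace:

```python
import torch
from torch import nn

@torch.jit.script
def clamp_positive(x):
    # The if-branch survives because this function is scripted,
    # even though the surrounding module is traced.
    if bool(x.sum() < 0):
        x = -x
    return x

class Wrapper(nn.Module):
    def forward(self, x):
        return clamp_positive(x) * 2

traced = torch.jit.trace(Wrapper(), torch.randn(3))
# Both branches still work after tracing:
print(traced(torch.tensor([-1.0, -2.0, -3.0])))  # tensor([2., 4., 6.])
```

Had clamp_positive been an ordinary Python function, tracing would have baked in only the branch taken for the example input.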

As we have already seen tracing in action, let's look at a practical example of scripting in more detail.

15.3.4 Scripting the gaps of traceability

In more complex models, such as those from the Fast R-CNN family for detection or recurrent networks used in natural language processing, the bits with control flow like for loops need to be scripted. Similarly, if we needed that flexibility, this is where we would find the bit of code the tracer warned about.

Listing 15.8 From utils/unet.py

```python
class UNetUpBlock(nn.Module):
    ...

    def center_crop(self, layer, target_size):
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[:, :, diff_y:(diff_y + target_size[0]),
                     diff_x:(diff_x + target_size[1])]    # The tracer warns here.

    def forward(self, x, bridge):
        ...
        crop1 = self.center_crop(bridge, up.shape[2:])
        ...
```

What happens is that the JIT magically replaces the shape tuple up.shape with a 1D integer tensor with the same information. Now the slicing [2:] and the calculation of diff_x and diff_y are all traceable tensor operations. However, that does not save us, because the slicing then wants Python ints; and there, the reach of the JIT ends, giving us the warning.
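We can reproduce the problem in isolation (this small function is an illustration, not the book's code): slicing with bounds computed from the input's shape forces a tensor-to-int conversion during tracing, so the bounds get baked in as constants.

```python
import torch

def crop_half(x):
    # During tracing, these sizes come from a tensor; the slice
    # below needs Python ints, which triggers a TracerWarning.
    h, w = x.shape[2], x.shape[3]
    return x[:, :, :h // 2, :w // 2]

traced = torch.jit.trace(crop_half, torch.randn(1, 1, 8, 8))
print(traced(torch.randn(1, 1, 8, 8)).shape)   # torch.Size([1, 1, 4, 4])

# With a different input size, the baked-in slice bounds may
# silently be reused instead of being recomputed:
print(traced(torch.randn(1, 1, 16, 16)).shape)
```

This is exactly the class of bug the tracer's warning is trying to flag.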

But we can solve this issue in a straightforward way: we script center_crop. We slightly change the cut between caller and callee by passing up to the scripted center_crop and extracting the sizes there. Other than that, all we need is to add the @torch.jit.script decorator. The result is the following code, which makes the U-Net model traceable without warnings.
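A sketch of what such a scripted helper might look like (the exact signature is assumed here, not the book's listing): the callee receives the reference tensor and reads the sizes itself, and inside a scripted function these sizes are ordinary ints, so the slicing is unproblematic.

```python
import torch

@torch.jit.script
def center_crop(layer: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Sizes are plain ints in script mode; no tracer warning here.
    diff_y = (layer.size(2) - target.size(2)) // 2
    diff_x = (layer.size(3) - target.size(3)) // 2
    return layer[:, :, diff_y:(diff_y + target.size(2)),
                 diff_x:(diff_x + target.size(3))]

out = center_crop(torch.randn(1, 1, 10, 10), torch.randn(1, 1, 6, 6))
print(out.shape)  # torch.Size([1, 1, 6, 6])
```

The caller's crop1 = self.center_crop(bridge, up.shape[2:]) would then become a call passing up itself rather than its shape.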
