20.03.2021 Views

Deep-Learning-with-PyTorch



Listing 15.9 Rewritten excerpt from utils/unet.py

    @torch.jit.script
    def center_crop(layer, target):  # Changes the signature, taking target instead of target_size
        _, _, layer_height, layer_width = layer.size()
        _, _, target_height, target_width = target.size()  # Gets the sizes within the scripted part
        diff_y = (layer_height - target_height) // 2
        diff_x = (layer_width - target_width) // 2
        # The indexing uses the size values we got.
        return layer[:, :, diff_y:(diff_y + target_height),
                     diff_x:(diff_x + target_width)]

    class UNetUpBlock(nn.Module):
        ...
        def forward(self, x, bridge):
            ...
            crop1 = center_crop(bridge, up)  # We adapt our call to pass up rather than the size.
            ...
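The index arithmetic in center_crop can be checked by hand without any tensors. The following is only a sketch: center_crop_bounds is a hypothetical helper (it is not part of utils/unet.py), and the sizes are made-up values chosen to make the numbers easy to follow.

```python
from typing import Tuple


def center_crop_bounds(layer_hw: Tuple[int, int],
                       target_hw: Tuple[int, int]) -> Tuple[slice, slice]:
    """Return the (row, column) slices that center-crop layer_hw to target_hw.

    Hypothetical helper for illustration only: it repeats the arithmetic of
    center_crop on plain integers instead of tensor sizes.
    """
    layer_height, layer_width = layer_hw
    target_height, target_width = target_hw
    diff_y = (layer_height - target_height) // 2
    diff_x = (layer_width - target_width) // 2
    return (slice(diff_y, diff_y + target_height),
            slice(diff_x, diff_x + target_width))


# Assumed example sizes: a 64x64 bridge cropped to a 56x56 upsampled map.
rows, cols = center_crop_bounds((64, 64), (56, 56))
print(rows, cols)  # slice(4, 60, None) slice(4, 60, None)

# With an odd size difference, floor division leaves the extra
# row and column at the bottom and right edges.
rows_odd, cols_odd = center_crop_bounds((9, 9), (4, 4))
print(rows_odd, cols_odd)  # slice(2, 6, None) slice(2, 6, None)
```

In the first case, four rows and columns are dropped on each side. In the second, two are dropped at the top and left and three at the bottom and right.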

Another option we could choose, but that we will not use here, would be to move unscriptable things into custom operators implemented in C++. The TorchVision library does that for some specialty operations in Mask R-CNN models.

15.4 LibTorch: PyTorch in C++

We have seen various ways to export our models, but so far, we have used Python. We’ll now look at how we can forgo Python and work with C++ directly.

Let’s go back to the horse-to-zebra CycleGAN example. We will now take the JITed model from section 15.2.3 and run it from a C++ program.

15.4.1 Running JITed models from C++

The hardest part about deploying PyTorch vision models in C++ is choosing an image library to load the data.⁸ Here, we go with the very lightweight library CImg (http://cimg.eu). If you are very familiar with OpenCV, you can adapt the code to use that instead; we just felt that CImg is easiest for our exposition.

Running a JITed model is very simple. We’ll first show the image handling; it is not really what we are after, so we will do this very quickly.⁹

Listing 15.10 cyclegan_jit.cpp

    // Includes the PyTorch script header and CImg with native JPEG support
    #include "torch/script.h"
    #define cimg_use_jpeg
    #include "CImg.h"
    using namespace cimg_library;
    int main(int argc, char **argv) {
      CImg<float> image(argv[2]);  // Loads and decodes the image into a float array

8 But TorchVision may develop a convenience function for loading images.

9 The code works with PyTorch 1.4 and, hopefully, above. In PyTorch versions before 1.3 you needed data in place of data_ptr.
