
Deep-Learning-with-PyTorch


LibTorch: PyTorch in C++


    ...  // Spares us from reproducing some tedious things
        model->push_back(torch::nn::Conv2d(
            torch::nn::Conv2dOptions(ngf * mult, ngf * mult * 2, 3)
                .stride(2)
                .padding(1)));  // An example of Options in action
    ...
        register_module("model", model);
      }
      Tensor forward(const Tensor &inp) { return model->forward(inp); }
    };

    // Creates a wrapper ResNetGenerator around our ResNetGeneratorImpl class.
    // As archaic as it seems, the matching names are important here.
    TORCH_MODULE(ResNetGenerator);
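The Options object in the listing follows the named-parameter idiom: the required arguments (input channels, output channels, kernel size) go to the constructor, and optional settings such as stride and padding are set through chained calls that each return the options object itself. The sketch below imitates that idiom in plain C++ so it compiles without LibTorch. ConvOptions and conv_output_size are hypothetical names, not part of the LibTorch API, and the output-size formula assumes no dilation.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical, simplified stand-in for an options struct such as
// torch::nn::Conv2dOptions: required values go to the constructor,
// optional ones are set with chainable setters.
struct ConvOptions {
  ConvOptions(std::int64_t in_channels, std::int64_t out_channels,
              std::int64_t kernel_size)
      : in_channels_(in_channels),
        out_channels_(out_channels),
        kernel_size_(kernel_size) {}

  // Each setter changes one field and returns *this, so calls can be
  // chained: ConvOptions(a, b, 3).stride(2).padding(1).
  ConvOptions& stride(std::int64_t s) {
    stride_ = s;
    return *this;
  }
  ConvOptions& padding(std::int64_t p) {
    padding_ = p;
    return *this;
  }

  std::int64_t in_channels_;
  std::int64_t out_channels_;
  std::int64_t kernel_size_;
  std::int64_t stride_ = 1;   // default when .stride() is not called
  std::int64_t padding_ = 0;  // default when .padding() is not called
};

// Spatial output size along one dimension (no dilation):
// (in + 2 * padding - kernel) / stride + 1, using integer division.
std::int64_t conv_output_size(std::int64_t in, const ConvOptions& opt) {
  assert(opt.stride_ > 0);
  return (in + 2 * opt.padding_ - opt.kernel_size_) / opt.stride_ + 1;
}
```

With kernel size 3, stride 2, and padding 1, a 400-pixel side becomes (400 + 2 - 3) / 2 + 1 = 200, so each such layer halves the spatial resolution while the listing's ngf * mult to ngf * mult * 2 doubles the channel count.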

That’s it: we’ve defined the perfect C++ analogue of the Python ResNetGenerator model. Now we only need a main function to load parameters and run our model.

Loading the image with CImg and converting from image to tensor and back work exactly as in the previous section. To add some variation, we’ll display the image instead of writing it to disk.
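The listing below turns the CImg buffer into a tensor with a plain reshape, without permuting axes. That works because CImg stores pixels non-interleaved: for a 2D image, the value at (x, y) in channel c sits at offset x + y * width + c * width * height, which is exactly the offset of element [0][c][y][x] in a contiguous tensor of shape {1, channels, height, width}. The stdlib-only sketch below checks that the two indexing schemes agree; planar_offset and nchw_offset are hypothetical helper names, and the sketch assumes a single image with depth 1.

```cpp
#include <cassert>
#include <cstddef>

// Offset of pixel (x, y) in channel c of a planar (non-interleaved) 2D
// buffer: all of channel 0 first, then all of channel 1, and so on.
std::size_t planar_offset(std::size_t x, std::size_t y, std::size_t c,
                          std::size_t width, std::size_t height) {
  assert(x < width && y < height);
  return x + y * width + c * width * height;
}

// Offset of element [0][c][y][x] in a contiguous row-major tensor of
// shape {1, channels, height, width} (last index varies fastest).
std::size_t nchw_offset(std::size_t c, std::size_t y, std::size_t x,
                        std::size_t channels, std::size_t height,
                        std::size_t width) {
  assert(c < channels && y < height && x < width);
  const std::size_t n = 0;  // the batch holds a single image
  return ((n * channels + c) * height + y) * width + x;
}
```

The same correspondence is what lets the listing pass output.data_ptr<float>() back to a CImg constructor with sizes (width, height, 1, channels), assuming the output tensor is contiguous and on the CPU.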

Listing 15.16 cyclegan_cpp_api.cpp main

    ResNetGenerator model;  // Instantiates our model
    ...
    torch::load(model, argv[1]);  // Loads the parameters
    ...
    cimg_library::CImg<float> image(argv[2]);
    image.resize(400, 400);
    auto input_ =
        torch::tensor(torch::ArrayRef<float>(image.data(), image.size()));
    auto input = input_.reshape({1, 3, image.height(), image.width()});

    // Declaring a guard variable is the equivalent of the torch.no_grad()
    // context. You can put it in a { … } block if you need to limit how
    // long you turn off gradients.
    torch::NoGradGuard no_grad;

    // As in Python, eval mode is turned on (for our model, it would not be
    // strictly relevant).
    model->eval();

    // Again, we call forward rather than the model.
    auto output = model->forward(input);
    ...
    cimg_library::CImg<float> out_img(output.data_ptr<float>(),
                                      output.size(3), output.size(2),
                                      1, output.size(1));
    cimg_library::CImgDisplay disp(out_img, "See a C++ API zebra!");

    // When displaying the image, we need to wait for a key rather than
    // immediately exiting our program.
    while (!disp.is_closed()) {
      disp.wait();
    }
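torch::NoGradGuard works as an RAII guard: its constructor records the current gradient mode and switches gradients off, and its destructor restores the recorded mode when the variable goes out of scope. That is why wrapping it in a { … } block is enough to limit its effect. The sketch below reproduces the pattern with a hypothetical thread-local flag (grad_enabled) and a hypothetical NoGradGuardSketch class; it only illustrates the scoping behavior and does not touch LibTorch's real grad-mode state.

```cpp
#include <cassert>

// Hypothetical flag standing in for autograd's grad mode.
thread_local bool grad_enabled = true;

// RAII guard: remembers the previous setting and disables gradients;
// the destructor puts the previous setting back, even during stack
// unwinding after an exception.
class NoGradGuardSketch {
 public:
  NoGradGuardSketch() : prev_(grad_enabled) { grad_enabled = false; }
  ~NoGradGuardSketch() { grad_enabled = prev_; }
  NoGradGuardSketch(const NoGradGuardSketch&) = delete;
  NoGradGuardSketch& operator=(const NoGradGuardSketch&) = delete;

 private:
  bool prev_;
};

// Usage: gradients are off only inside the inner braces.
void scoped_demo() {
  const bool before = grad_enabled;
  {
    NoGradGuardSketch no_grad;  // gradients off from here...
    assert(!grad_enabled);
  }                             // ...until no_grad is destroyed here
  assert(grad_enabled == before);
}
```

Because each guard restores whatever it found, guards can also be nested: an inner guard leaves gradients disabled when it exits if an outer guard is still alive.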

The interesting changes are in how we create and run the model. As expected, we instantiate the model by declaring a variable of the model type. We load the model using torch::load (here it is important that we wrapped the model). While this looks very familiar to PyTorch practitioners, note that it works on JIT-saved files rather than Python-serialized state dictionaries.
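The wrapping matters because TORCH_MODULE(ResNetGenerator) pastes Impl onto the name it receives (which is why the implementation class must be called ResNetGeneratorImpl) and declares ResNetGenerator as a holder class that keeps a shared pointer to a ResNetGeneratorImpl. The holder forwards -> to the implementation, hence model->forward(...) and model->eval(), and copying the holder shares the same module rather than duplicating it. The stdlib-only sketch below mimics that mechanism; Holder, MY_MODULE, and GeneratorImpl are hypothetical stand-ins, far simpler than LibTorch's torch::nn::ModuleHolder.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Hypothetical, stripped-down holder: owns a shared_ptr to the
// implementation and forwards "->" to it.
template <typename Impl>
class Holder {
 public:
  Holder() : impl_(std::make_shared<Impl>()) {}  // default-constructs Impl
  explicit Holder(std::shared_ptr<Impl> impl) : impl_(std::move(impl)) {}

  Impl* operator->() const {
    assert(impl_ != nullptr);
    return impl_.get();
  }
  long use_count() const { return impl_.use_count(); }

 private:
  std::shared_ptr<Impl> impl_;
};

// Hypothetical macro mimicking the idea of TORCH_MODULE: Name##Impl pastes
// "Impl" onto the given name, so the implementation class must be named
// exactly NameImpl for the expansion to compile.
#define MY_MODULE(Name)                    \
  class Name : public Holder<Name##Impl> { \
   public:                                 \
    using Holder<Name##Impl>::Holder;      \
  }

// Toy implementation class; forward() is a placeholder computation.
struct GeneratorImpl {
  explicit GeneratorImpl(int channels = 64) : ngf(channels) {}
  int forward(int x) const { return x * ngf; }
  int ngf;
};

MY_MODULE(Generator);  // declares class Generator wrapping GeneratorImpl
```

Copying a Generator copies only the shared pointer, so both copies see the same GeneratorImpl; this mirrors how copying a LibTorch module holder shares one set of parameters.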

When running the model, we need the equivalent of Python's with torch.no_grad(): block. This is provided by instantiating a variable of type NoGradGuard and keeping it in scope for
