
Interacting with the PyTorch JIT


Let’s use a quick example to give you a taste of why looking at several operations at once can be beneficial. When PyTorch runs a sequence of operations on the GPU, it calls a subprogram (a kernel, in CUDA parlance) for each of them. Every kernel reads the input from GPU memory, computes the result, and then stores the result. Thus most of the time is typically spent not computing things, but reading from and writing to memory. This can be improved on by reading only once, computing several operations, and then writing at the very end. This is precisely what the PyTorch JIT fuser does.
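To make this concrete, here is a minimal sketch (the function names are illustrative, not from PyTorch) contrasting eager execution, where every pointwise operation launches its own kernel, with a scripted version whose body the fuser can combine into a single GPU kernel:

import torch

def pointwise_eager(x, y):
    # Eager mode: each line launches a separate CUDA kernel that
    # reads its inputs from GPU memory and writes its result back.
    a = torch.sigmoid(x)
    b = torch.tanh(y)
    return a * b

@torch.jit.script
def pointwise_fused(x, y):
    # Scripted: on the GPU, the JIT fuser can combine these pointwise
    # ops into one kernel that reads x and y once and writes once.
    return torch.sigmoid(x) * torch.tanh(y)

On CUDA tensors, the fused version reads x and y once and writes the result once, instead of materializing a and b in GPU memory.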

To give you an idea of how this works, figure 15.3 shows the pointwise computation taking place in a long short-term memory (LSTM; https://en.wikipedia.org/wiki/Long_short-term_memory) cell, a popular building block for recurrent networks.

The details of figure 15.3 are not important to us here, but there are 5 inputs at the top, 2 outputs at the bottom, and 7 intermediate results represented as rounded rectangles. By computing all of this in one go in a single CUDA function and keeping the intermediates in registers, the JIT reduces the number of memory reads from 12 to 5 and the number of writes from 9 to 2. These are the large gains the JIT gets us; it can reduce the time to train an LSTM network by a factor of four.

[Figure 15.3 image: the gate preactivations (ingate, cellgate, forgetgate, outgate) and the cell state cx flow through sigmoid and tanh activations and combine into the outputs cx and hx.]

Figure 15.3 LSTM cell pointwise operations. From five inputs at the top, this block computes two outputs at the bottom. The boxes in between are intermediate results that vanilla PyTorch will store in memory but the JIT fuser will just keep in registers.
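In code, the pointwise part of an LSTM cell could look roughly like the following sketch (the standard LSTM cell equations written out pointwise; the function name is ours). Scripting it lets the fuser turn the whole body into a single kernel:

import torch
from typing import Tuple

@torch.jit.script
def lstm_cell_pointwise(ingate: torch.Tensor, forgetgate: torch.Tensor,
                        cellgate: torch.Tensor, outgate: torch.Tensor,
                        cx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # The five inputs at the top of figure 15.3: four gate
    # preactivations plus the previous cell state cx. The seven
    # intermediate results below stay in registers once the body
    # is fused into one kernel.
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)
    cy = forgetgate * cx + ingate * cellgate  # new cell state
    hy = outgate * torch.tanh(cy)             # new hidden state
    return hy, cy  # the two outputs at the bottom of the figure

Fused, the kernel reads only the five inputs and writes only the two outputs; unfused, every intermediate would be written to GPU memory by one kernel and read back by the next.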
