
368 CHAPTER 13 Using segmentation to find suspected nodules

strongly. The fact that samples in batches are picked randomly at every epoch will minimize the chances of a dull sample ending up in an all-dull batch, and hence those dull samples getting overemphasized.

Second, since the output values are unconstrained, we are going to pass the output through an nn.Sigmoid layer to restrict the output to the range [0, 1]. Third, we will reduce the total depth and number of filters we allow our model to use. While this is jumping ahead of ourselves a bit, the capacity of the model using the standard parameters far outstrips our dataset size. This means we're unlikely to find a pretrained model that matches our exact needs. Finally, although this is not a modification, it's important to note that our output is a single channel, with each pixel of output representing the model's estimate of the probability that the pixel in question is part of a nodule.
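To make the second point concrete, here is a minimal sketch (not from the book's code) showing how nn.Sigmoid squashes unconstrained values into the range [0, 1], so that each output pixel can be read as a probability:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Unconstrained values, as a raw segmentation head might produce
logits = torch.randn(1, 1, 4, 4) * 10.0

# nn.Sigmoid maps every value into (0, 1), so each output pixel
# can be interpreted as a per-pixel nodule probability
probs = nn.Sigmoid()(logits)

print(probs.min().item() >= 0.0 and probs.max().item() <= 1.0)  # True
```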

This wrapping of U-Net can be done rather simply by implementing a model with three attributes: one each for the two features we want to add, and one for the U-Net itself, which we can treat just like any prebuilt module here. We will also pass any keyword arguments we receive into the U-Net constructor.

Listing 13.1 model.py:17, class UNetWrapper

class UNetWrapper(nn.Module):
    # kwargs is a dictionary containing all keyword
    # arguments passed to the constructor.
    def __init__(self, **kwargs):
        super().__init__()

        # BatchNorm2d wants us to specify the number of input
        # channels, which we take from the keyword argument.
        self.input_batchnorm = nn.BatchNorm2d(kwargs['in_channels'])

        # The U-Net: a small thing to include here,
        # but it's really doing all the work.
        self.unet = UNet(**kwargs)

        self.final = nn.Sigmoid()

        # Just as for the classifier in chapter 11, we use our custom
        # weight initialization. The function is copied over, so we
        # will not show the code again.
        self._init_weights()

The forward method is a similarly straightforward sequence. We could use an instance of nn.Sequential as we saw in chapter 8, but we'll be explicit here for both clarity of code and clarity of stack traces.⁵

Listing 13.2 model.py:50, UNetWrapper.forward

def forward(self, input_batch):
    bn_output = self.input_batchnorm(input_batch)
    un_output = self.unet(bn_output)
    fn_output = self.final(un_output)
    return fn_output
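To see the wrapper pattern run end to end, here is a self-contained sketch. Note the hedges: the one-convolution UNet below is only a stand-in for the book's real UNet (which is not shown in this excerpt), and in_channels=7 is just an illustrative value.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in for the book's UNet: a single convolution that accepts the
# same in_channels keyword, so the wrapper pattern can run end to end.
class UNet(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(kwargs['in_channels'], 1,
                              kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

class UNetWrapper(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        # Normalize the input batch before it reaches the U-Net
        self.input_batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
        self.unet = UNet(**kwargs)
        # Constrain the output to [0, 1] per-pixel probabilities
        self.final = nn.Sigmoid()

    def forward(self, input_batch):
        bn_output = self.input_batchnorm(input_batch)
        un_output = self.unet(bn_output)
        return self.final(un_output)

model = UNetWrapper(in_channels=7)
out = model(torch.randn(2, 7, 64, 64))
# One output channel, same spatial size, values in [0, 1]
print(out.shape)  # torch.Size([2, 1, 64, 64])
```

The wrapper treats the inner U-Net as an opaque prebuilt module: batch normalization goes in front, the sigmoid behind, and everything else passes straight through via **kwargs.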

Note that we're using nn.BatchNorm2d here. This is because U-Net is fundamentally a two-dimensional segmentation model. We could adapt the implementation to use 3D

⁵ In the unlikely event our code throws any exceptions (which it clearly won't, will it?)
