Deep-Learning-with-PyTorch


Updating the dataset for segmentation


We’ll accomplish this by using a second model, similar to all the other subclasses of

nn.Module we’ve seen so far in this book. The main difference is that we’re not interested

in backpropagating gradients through the model, and the forward method will

be doing decidedly different things. There will be some slight modifications to the

actual augmentation routines since we’re working with 2D data for this chapter, but

otherwise, the augmentation will be very similar to what we saw in chapter 12. The

model will consume tensors and produce different tensors, just like the other models

we’ve implemented.
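As a minimal sketch of that tensor-in, tensor-out pattern (a hypothetical stand-in, not the book's class), consider a module whose forward method does its work without ever tracking gradients:

```python
import torch
from torch import nn

class AddNoise(nn.Module):
    """Hypothetical example: a module used for data transformation,
    not learning, so its forward runs without gradient tracking."""
    def __init__(self, noise=0.1):
        super().__init__()
        self.noise = noise

    def forward(self, input_t):
        with torch.no_grad():  # we never backpropagate through augmentation
            return input_t + torch.randn_like(input_t) * self.noise

aug = AddNoise(noise=0.1)
x = torch.ones(2, 3)
y = aug(x)  # same shape as x; y carries no gradient history
```

The module consumes a tensor and produces a tensor, just like a regular model, but nothing in it is learnable and nothing it returns participates in backpropagation.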

Our model’s __init__ takes the same data augmentation arguments—flip,

offset, and so on—that we used in the last chapter, and assigns them to self.

Listing 13.19

model.py:56, class SegmentationAugmentation

class SegmentationAugmentation(nn.Module):
    def __init__(
        self, flip=None, offset=None, scale=None, rotate=None, noise=None
    ):
        super().__init__()

        self.flip = flip
        self.offset = offset
        # ... line 64
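Note that these arguments end up as plain Python attributes, not as learnable parameters, so the optimizer never sees them. A quick stand-in class (hypothetical, mirroring the `__init__` pattern above) makes that visible:

```python
import torch
from torch import nn

class AugStub(nn.Module):
    """Hypothetical stand-in mirroring SegmentationAugmentation.__init__."""
    def __init__(self, flip=None, offset=None, scale=None, rotate=None, noise=None):
        super().__init__()
        self.flip = flip
        self.offset = offset
        self.scale = scale
        self.rotate = rotate
        self.noise = noise

aug = AugStub(flip=True, offset=0.03, noise=25.0)
# The augmentation settings are attributes, not nn.Parameters,
# so the module exposes nothing for an optimizer to update.
print(list(aug.parameters()))  # → []
```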

Our augmentation forward method takes the input and the label, and calls out to

build the transform_t tensor that will then drive our affine_grid and grid_sample

calls. Those calls should feel very familiar from chapter 12.

Listing 13.20

model.py:68, SegmentationAugmentation.forward

def forward(self, input_g, label_g):
    # Note that we're augmenting 2D data.
    transform_t = self._build2dTransformMatrix()
    transform_t = transform_t.expand(input_g.shape[0], -1, -1)
    transform_t = transform_t.to(input_g.device, torch.float32)
    # The first dimension of the transformation is the batch, but we
    # only want the first two rows of the 3 × 3 matrices per batch item.
    affine_t = F.affine_grid(transform_t[:,:2],
            input_g.size(), align_corners=False)

    augmented_input_g = F.grid_sample(input_g,
            affine_t, padding_mode='border',
            align_corners=False)
    # We need the same transformation applied to CT and mask, so we use
    # the same grid. Because grid_sample only works with floats, we
    # convert here.
    augmented_label_g = F.grid_sample(label_g.to(torch.float32),
            affine_t, padding_mode='border',
            align_corners=False)

    if self.noise:
        noise_t = torch.randn_like(augmented_input_g)
        noise_t *= self.noise

        augmented_input_g += noise_t

    # Just before returning, we convert the mask back to Booleans by
    # comparing to 0.5. The interpolation that grid_sample performs
    # results in fractional values.
    return augmented_input_g, augmented_label_g > 0.5
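To make the affine_grid/grid_sample pairing concrete, here is a small self-contained sketch (not from the book) that applies an identity transform to a tiny 2D batch. With the identity matrix, the sampling grid lands exactly on the pixel centers, so the output reproduces the input:

```python
import torch
import torch.nn.functional as F

# One 3 × 3 identity transform, with a leading batch dimension of 1.
transform_t = torch.eye(3).unsqueeze(0)          # shape: 1 × 3 × 3

# A tiny fake 2D "CT slice": batch of 1, 1 channel, 4 × 4 pixels.
input_g = torch.arange(16, dtype=torch.float32).view(1, 1, 4, 4)

# Keep only the first two rows of the 3 × 3 matrix, as in the listing.
affine_t = F.affine_grid(transform_t[:, :2],
                         input_g.size(), align_corners=False)
out_g = F.grid_sample(input_g, affine_t,
                      padding_mode='border', align_corners=False)
# out_g matches input_g, since the transform is the identity.
```

Swapping the identity for a flip, offset, or rotation matrix would resample the same pixels from new locations, which is exactly what the forward method above does for both the CT slice and its mask.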
