
Listing 11.11  training.py:225, .computeBatchLoss

def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g):
    input_t, label_t, _series_list, _center_list = batch_tup

    input_g = input_t.to(self.device, non_blocking=True)
    label_g = label_t.to(self.device, non_blocking=True)

    logits_g, probability_g = self.model(input_g)

    loss_func = nn.CrossEntropyLoss(reduction='none')  # reduction='none' gives the loss per sample.
    loss_g = loss_func(
        logits_g,
        label_g[:,1],  # Index of the one-hot-encoded class
    )
    # ... line 238
    return loss_g.mean()  # Recombines the loss per sample into a single value.

Here we are not using the default behavior to get a loss value averaged over the batch. Instead, we get a tensor of loss values, one per sample. This lets us track the individual losses, which means we can aggregate them as we wish (per class, for example). We’ll see that in action in just a moment. For now, we’ll return the mean of those per-sample losses, which is equivalent to the batch loss. In situations where you don’t want to keep statistics per sample, using the loss averaged over the batch is perfectly fine. Whether that’s the case is highly dependent on your project and goals.
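
To make that difference concrete, here is a small stand-alone sketch (not taken from the book’s training code; the logits, labels, and masks are made up for illustration) showing how reduction='none' yields a per-sample loss that can be averaged overall or per class:

import torch
import torch.nn as nn

# Stand-alone illustration with hypothetical data.
logits = torch.randn(8, 2)                       # fake logits for 8 samples, 2 classes
labels = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1])  # fake class indices: 0 = negative, 1 = positive

loss_func = nn.CrossEntropyLoss(reduction='none')
loss_per_sample = loss_func(logits, labels)      # shape (8,): one loss value per sample

print(loss_per_sample.mean())                    # same value the default reduction='mean' would give
print(loss_per_sample[labels == 0].mean())       # loss averaged over negative samples only
print(loss_per_sample[labels == 1].mean())       # loss averaged over positive samples only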

Once that’s done, we’ve fulfilled our obligations to the calling function in terms of what’s required to do backpropagation and weight updates. Before we do that, however, we also want to record our per-sample stats for posterity (and later analysis). We’ll use the metrics_g parameter passed in to accomplish this.

Listing 11.12  training.py:26

METRICS_LABEL_NDX=0  # These named array indexes are
METRICS_PRED_NDX=1   # declared at module-level scope.
METRICS_LOSS_NDX=2
METRICS_SIZE = 3

# ... line 225
def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g):
    # ... line 238
    start_ndx = batch_ndx * batch_size
    end_ndx = start_ndx + label_t.size(0)

    # We use detach since none of our metrics need to hold on to gradients.
    metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = \
        label_g[:,1].detach()
    metrics_g[METRICS_PRED_NDX, start_ndx:end_ndx] = \
        probability_g[:,1].detach()
    metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = \
        loss_g.detach()

    return loss_g.mean()  # Again, this is the loss over the entire batch.
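
As a rough sketch of how such a metrics tensor could be allocated and consumed after an epoch (the METRICS_* names match the listing, but the allocation, the dataset length, and the dummy values below are assumptions for illustration, not the book’s actual training loop):

import torch

METRICS_LABEL_NDX = 0
METRICS_PRED_NDX = 1
METRICS_LOSS_NDX = 2
METRICS_SIZE = 3

num_samples = 1000                                   # assumed dataset length for this sketch
metrics_g = torch.zeros(METRICS_SIZE, num_samples)   # one row per metric, one column per sample

# Pretend computeBatchLoss has filled in every column over the epoch
# (dummy values here so the sketch runs on its own).
metrics_g[METRICS_LABEL_NDX] = torch.randint(0, 2, (num_samples,)).float()
metrics_g[METRICS_PRED_NDX] = torch.rand(num_samples)
metrics_g[METRICS_LOSS_NDX] = torch.rand(num_samples)

neg_mask = metrics_g[METRICS_LABEL_NDX] == 0
pos_mask = metrics_g[METRICS_LABEL_NDX] == 1
print(metrics_g[METRICS_LOSS_NDX].mean())            # overall loss for the epoch
print(metrics_g[METRICS_LOSS_NDX, neg_mask].mean())  # loss over negative samples only
print(metrics_g[METRICS_LOSS_NDX, pos_mask].mean())  # loss over positive samples only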
