20.03.2021 Views

Deep-Learning-with-PyTorch

What does an ideal dataset look like?

consists of 10 equally weighted classes, and we want to instead have 1 class (say, “airplane”) now make up 50% of all of the training images. We could decide to use WeightedRandomSampler (http://mng.bz/8plK) and weight each of the “airplane” sample indexes higher, but constructing the weights argument requires that we know in advance which indexes are airplanes.
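As a rough illustration of why the weights depend on knowing the labels, here is a minimal sketch in plain Python. The label list, the class sizes, and the choice of class 0 as “airplane” are all assumptions made up for the example, and random.choices is used only as a stand-in for a weighted sampler; it is not how WeightedRandomSampler is implemented.

```python
import random
from collections import Counter

# Hypothetical labels: 10 classes with 100 samples each; class 0 is "airplane".
AIRPLANE = 0
labels = [class_ndx for class_ndx in range(10) for _ in range(100)]

# Building the weights requires the label of every index up front.
# With 9 other classes of equal size, weighting each airplane sample 9x makes
# the airplane weights sum to half of the total weight.
weights = [9.0 if label == AIRPLANE else 1.0 for label in labels]
airplane_weight = sum(w for w, label in zip(weights, labels) if label == AIRPLANE)
airplane_fraction = airplane_weight / sum(weights)

# Stand-in for the sampler: draw indexes with replacement, proportional to weight.
rng = random.Random(0)
drawn = rng.choices(range(len(labels)), weights=weights, k=10_000)
counts = Counter(labels[ndx] for ndx in drawn)
```

In PyTorch, a list like weights could be passed as the first argument of torch.utils.data.WeightedRandomSampler along with num_samples; the point of the sketch is only that the list cannot be built without the labels.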

As we discussed, the Dataset API only specifies that subclasses provide __len__ and __getitem__, but there is nothing direct we can use to ask “Which samples are airplanes?” We’d either have to load up every sample beforehand to inquire about the class of that sample, or we’d have to break encapsulation and hope the information we need is easily obtained from looking at the internal implementation of the Dataset subclass.
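The first option, loading every sample just to read its class, might look like the sketch below. ToyDataset is a hypothetical stand-in that exposes only __len__ and __getitem__; its load_count attribute exists purely to show that every sample gets touched.

```python
class ToyDataset:
    """Hypothetical dataset that exposes only __len__ and __getitem__."""

    def __init__(self, labels):
        self._labels = list(labels)  # internal detail; callers should not rely on it
        self.load_count = 0

    def __len__(self):
        return len(self._labels)

    def __getitem__(self, ndx):
        self.load_count += 1  # stands in for an expensive read from disk
        return f"image-{ndx}", self._labels[ndx]


AIRPLANE = 0
ds = ToyDataset([0, 3, 0, 7, 1, 0])

# Using only the public API, finding the airplanes means loading every sample.
airplane_ndx_list = [ndx for ndx in range(len(ds)) if ds[ndx][1] == AIRPLANE]
```

Reading ds._labels directly would avoid the loads, but that is exactly the kind of encapsulation break the text warns about.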

Since neither of those options is particularly ideal in cases where we have control over the dataset directly, the code for part 2 implements any needed data shaping inside the Dataset subclasses instead of relying on an external sampler.

IMPLEMENTING CLASS BALANCING IN THE DATASET

We are going to directly change our LunaDataset to present a balanced, one-to-one ratio of positive and negative samples for training. We will keep separate lists of negative training samples and positive training samples, and alternate returning samples from each of those two lists. This will prevent the degenerate behavior of the model scoring well by simply answering “false” to every sample presented. In addition, the positive and negative classes will be intermixed so that the weight updates are forced to discriminate between the classes.

Let’s add a ratio_int to LunaDataset that will control the label for the Nth sample, and keep track of our samples separated by label.
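To make the idea of an index-controlled label concrete, here is one possible mapping from a sample index to one of the two lists. This is a sketch, not necessarily the code used in dsets.py: the function name pick_sample is hypothetical, and it assumes ratio_int means the number of negative samples presented per positive sample, with both lists wrapping around when they run out.

```python
def pick_sample(ndx, ratio_int, negative_list, pos_list):
    """Return the sample for position ndx, interleaving the two lists.

    Assumption: every (ratio_int + 1)-th position is positive and the
    positions in between are negative.
    """
    pos_ndx = ndx // (ratio_int + 1)
    if ndx % (ratio_int + 1):
        # Negative slot: subtract the positive slots seen so far.
        neg_ndx = ndx - 1 - pos_ndx
        return negative_list[neg_ndx % len(negative_list)]
    # Positive slot: wrap around, since positives are usually the scarcer class.
    return pos_list[pos_ndx % len(pos_list)]


negative_list = ["n0", "n1", "n2", "n3", "n4"]
pos_list = ["p0", "p1"]

one_to_one = [pick_sample(i, 1, negative_list, pos_list) for i in range(6)]
two_to_one = [pick_sample(i, 2, negative_list, pos_list) for i in range(6)]
```

With ratio_int=1 the positions alternate positive, negative, positive, and so on; the short positive list simply repeats once it has been exhausted.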

Listing 12.6 dsets.py:217, class LunaDataset

class LunaDataset(Dataset):
    def __init__(self,
                 val_stride=0,
                 isValSet_bool=None,
                 ratio_int=0,
            ):
        self.ratio_int = ratio_int
        # ... line 228
        self.negative_list = [
            nt for nt in self.candidateInfo_list if not nt.isNodule_bool
        ]
        self.pos_list = [
            nt for nt in self.candidateInfo_list if nt.isNodule_bool
        ]
        # ... line 265

    def shuffleSamples(self):
        if self.ratio_int:
            random.shuffle(self.negative_list)
            random.shuffle(self.pos_list)

We will call shuffleSamples at the top of each epoch to randomize the order of samples being presented.
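A minimal sketch of that call pattern follows. TwoListDataset is a hypothetical stand-in for a dataset that keeps one list per label, and the loop omits the actual batching and training steps.

```python
import random


class TwoListDataset:
    """Hypothetical stand-in that keeps negative and positive samples apart."""

    def __init__(self, negative_list, pos_list, ratio_int=1):
        self.ratio_int = ratio_int
        self.negative_list = list(negative_list)
        self.pos_list = list(pos_list)

    def shuffleSamples(self):
        # Only reorder when balancing is enabled (ratio_int is nonzero).
        if self.ratio_int:
            random.shuffle(self.negative_list)
            random.shuffle(self.pos_list)


random.seed(0)
ds = TwoListDataset(negative_list=range(100, 110), pos_list=range(3))

orders = []
for epoch_ndx in range(1, 4):
    ds.shuffleSamples()  # top of the epoch, before any batches are drawn
    orders.append(list(ds.negative_list))
    # ... iterate over batches and train here
```

Shuffling reorders the lists in place without adding or dropping samples, so each epoch sees the same samples in a different order.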
