
train_loader = DataLoader(
    dataset=train_dataset, batch_size=16, sampler=sampler
)

val_loader = DataLoader(dataset=val_dataset, batch_size=16)

Once again, if we’re using a sampler, we cannot use the shuffle argument.
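As a quick illustration (just a sketch, assuming the same train_dataset and sampler built above), the data loader refuses the combination outright:

# Assumption: same train_dataset and sampler as above; DataLoader raises a
# ValueError because sampler and shuffle are mutually exclusive
try:
    DataLoader(dataset=train_dataset, batch_size=16,
               sampler=sampler, shuffle=True)
except ValueError as e:
    print(e)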

There is a lot of boilerplate code here, right? Let's build yet another function, Helper Function #5, to wrap it all up:

Helper Function #5

def make_balanced_sampler(y):
    # Computes weights for compensating imbalanced classes
    classes, counts = y.unique(return_counts=True)
    weights = 1.0 / counts.float()
    sample_weights = weights[y.squeeze().long()]
    # Builds sampler with the computed weights
    generator = torch.Generator()
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        generator=generator,
        replacement=True
    )
    return sampler

sampler = make_balanced_sampler(y_train_tensor)

Much better! Its only argument is the tensor containing the labels: the function will compute the weights and build the corresponding weighted sampler on its own.
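As a quick sanity check (a sketch only, assuming y_train_tensor is the same label tensor used above), we can draw one epoch's worth of indices from the sampler and count the classes, which should come out roughly balanced:

# Assumption: y_train_tensor is the label tensor passed to make_balanced_sampler
idx = list(iter(sampler))
print(y_train_tensor[idx].unique(return_counts=True))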

Seeds and more (seeds)

Time to set the seed for the generator used in the sampler assigned to the data loader. It is a long sequence of objects, but we can work our way through it to retrieve the generator and call its manual_seed() method:
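Here is a minimal sketch of that chain, assuming the train_loader built above and an arbitrary seed of 42:

# Assumption: train_loader was created with sampler=sampler above; 42 is an
# illustrative seed
train_loader.sampler.generator.manual_seed(42)

With the generator seeded, the sequence of sampled indices, and therefore the mini-batches, becomes reproducible across runs.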
