Daniel Voigt Godoy - Deep Learning with PyTorch Step-by-Step: A Beginner's Guide (Leanpub)

Helper Function #7

```python
import torch
from torch.utils.data import TensorDataset

def preprocessed_dataset(model, loader, device=None):
    if device is None:
        device = next(model.parameters()).device
    features = None
    labels = None

    model.eval()
    for i, (x, y) in enumerate(loader):
        # Run the (frozen) model on each mini-batch and collect its outputs
        output = model(x.to(device))
        if i == 0:
            features = output.detach().cpu()
            labels = y.cpu()
        else:
            features = torch.cat([features, output.detach().cpu()])
            labels = torch.cat([labels, y.cpu()])

    # Wrap the stacked features and labels in a TensorDataset
    dataset = TensorDataset(features, labels)
    return dataset
```

We can use it to pre-process our datasets:

Data Preparation (1)

```python
train_preproc = preprocessed_dataset(alex, train_loader)
val_preproc = preprocessed_dataset(alex, val_loader)
```

There we go: we have TensorDatasets containing the features generated by AlexNet for each and every image, as well as the corresponding labels.
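Since the features are now plain tensors sitting in memory, these datasets can be wrapped in regular data loaders to train the "top" of the model. The sketch below uses randomly generated stand-ins for the pre-processed features (the feature size, 9216, is only illustrative) instead of the actual `train_preproc` dataset:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical stand-in for train_preproc: ten feature vectors with
# binary labels (the real features would come from preprocessed_dataset)
features = torch.randn(10, 9216)
labels = torch.randint(0, 2, (10,))
train_preproc = TensorDataset(features, labels)

# Batches now load straight from memory; no forward pass through the
# frozen layers is needed during training
train_preproc_loader = DataLoader(train_preproc, batch_size=4, shuffle=True)

x_batch, y_batch = next(iter(train_preproc_loader))
```

Each batch delivered by this loader is a pair of pre-computed features and labels, ready to be fed to the small trainable "top" of the model.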

IMPORTANT: This pre-processing step assumes no data augmentation. If you want to perform data augmentation, you will need to train the "top" of the model while it is still attached to the rest of the model, since the features produced by the frozen layers will be slightly different every time due to the augmentation itself.

We can also save these tensors to disk:
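A minimal sketch of saving and reloading, using `torch.save` and `torch.load` on illustrative stand-in tensors (the filenames are placeholders, not the book's):

```python
import torch
from torch.utils.data import TensorDataset

# Illustrative tensors standing in for the pre-processed features and labels
features = torch.randn(8, 4096)
labels = torch.randint(0, 2, (8,))

# torch.save serializes any picklable object, tensors included
torch.save(features, 'preproc_features.pth')
torch.save(labels, 'preproc_labels.pth')

# Loading them back rebuilds the dataset without rerunning the model
loaded_feats = torch.load('preproc_features.pth')
loaded_labels = torch.load('preproc_labels.pth')
dataset = TensorDataset(loaded_feats, loaded_labels)
```

Persisting the tensors means the expensive forward pass through the frozen model only ever has to happen once.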

518 | Chapter 7: Transfer Learning
