
Daniel Voigt Godoy - Deep Learning with PyTorch Step-by-Step A Beginner’s Guide-leanpub


model anyway. We’ll use it to compute statistics only. By the way, we need

statistics for each channel, as required by the Normalize() transform.
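
To see why the statistics are needed per channel, here is a minimal sketch of channel-wise standardization written in plain PyTorch. It assumes the documented behavior of torchvision's Normalize(), namely output[channel] = (input[channel] - mean[channel]) / std[channel]. The tensor `image` and the `mean` and `std` values are made up for illustration only.

```python
import torch

# Hypothetical single image in CHW format: 3 channels, 2x2 pixels
image = torch.tensor([
    [[0.0, 1.0], [0.0, 1.0]],   # channel 0
    [[0.2, 0.4], [0.6, 0.8]],   # channel 1
    [[0.5, 0.5], [0.5, 0.5]],   # channel 2
])

# One mean and one standard deviation per channel (illustrative values)
mean = torch.tensor([0.5, 0.5, 0.5])
std = torch.tensor([0.5, 0.5, 0.5])

# Reshape the statistics to (C, 1, 1) so they broadcast over H and W
normalized = (image - mean[:, None, None]) / std[:, None, None]
```

Each channel is shifted and scaled by its own pair of values, which is why a single mean and standard deviation for the whole image would not be enough.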

So, let’s build a function that takes a mini-batch (images and labels) and computes

the mean pixel value and standard deviation per channel of each image, adding up

the results for all images. Better yet, let’s make it a method of our StepByStep class

too.

StepByStep Method

import torch

@staticmethod
def statistics_per_channel(images, labels):
    # NCHW
    n_samples, n_channels, n_height, n_width = images.size()
    # Flatten HW into a single dimension
    flatten_per_channel = images.reshape(n_samples, n_channels, -1)

    # Computes statistics of each image per channel
    # Average pixel value per channel
    # (n_samples, n_channels)
    means = flatten_per_channel.mean(axis=2)
    # Standard deviation of pixel values per channel
    # (n_samples, n_channels)
    stds = flatten_per_channel.std(axis=2)

    # Adds up statistics of all images in a mini-batch
    # (n_channels,)
    sum_means = means.sum(axis=0)
    sum_stds = stds.sum(axis=0)
    # Makes a tensor of shape (n_channels,)
    # with the number of samples in the mini-batch
    n_samples = torch.tensor([n_samples]*n_channels).float()

    # Stack the three tensors on top of one another
    # (3, n_channels)
    return torch.stack([n_samples, sum_means, sum_stds], axis=0)

setattr(StepByStep, 'statistics_per_channel',
        statistics_per_channel)
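
As a quick sanity check, the sketch below runs a standalone copy of the function above on a synthetic mini-batch and then divides the summed statistics by the sample counts. The random `images` and `labels` tensors, the seed, and the final averaging step are assumptions added for illustration. The copy uses `dim`, the native PyTorch spelling of `axis`.

```python
import torch

def statistics_per_channel(images, labels):
    # Standalone copy of the method above (labels are not used)
    n_samples, n_channels, n_height, n_width = images.size()
    flatten_per_channel = images.reshape(n_samples, n_channels, -1)
    means = flatten_per_channel.mean(dim=2)   # (n_samples, n_channels)
    stds = flatten_per_channel.std(dim=2)     # (n_samples, n_channels)
    sum_means = means.sum(dim=0)              # (n_channels,)
    sum_stds = stds.sum(dim=0)                # (n_channels,)
    n_samples = torch.tensor([n_samples] * n_channels).float()
    return torch.stack([n_samples, sum_means, sum_stds], dim=0)

# Hypothetical mini-batch: 16 RGB images of 28x28 pixels in [0, 1)
torch.manual_seed(13)
images = torch.rand(16, 3, 28, 28)
labels = torch.randint(0, 3, (16,))

results = statistics_per_channel(images, labels)
print(results.shape)  # torch.Size([3, 3])

# Row 0 holds the sample counts, rows 1 and 2 the summed statistics,
# so dividing by row 0 gives per-channel averages over the mini-batch
avg_means = results[1] / results[0]
avg_stds = results[2] / results[0]
```

Note that `avg_stds` is the average of the per-image standard deviations, which is not the same quantity as the standard deviation computed over all pixels of all images at once.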

