22.02.2024 Views

Daniel Voigt Godoy - Deep Learning with PyTorch Step-by-Step A Beginner’s Guide-leanpub

You also want an ePaper? Increase the reach of your titles

YUMPU automatically turns print PDFs into web-optimized ePapers that Google loves.

22 self.decoder.init_keys(encoder_states)

23

def decode(self, shifted_target_seq,
           source_mask=None, target_mask=None):
    """Run the decoder on a shifted (masked) target sequence.

    Used in TRAIN mode. This is a thin wrapper: it calls
    ``self.decoder`` with the sequence and forwards both masks
    as keyword arguments.

    Args:
        shifted_target_seq: Target sequence handed to the decoder.
        source_mask: Optional mask forwarded as ``source_mask``.
        target_mask: Optional mask forwarded as ``target_mask``.

    Returns:
        Whatever ``self.decoder`` returns for these inputs.
    """
    outputs = self.decoder(shifted_target_seq,
                           source_mask=source_mask,
                           target_mask=target_mask)
    return outputs

32

def predict(self, source_seq, source_mask):
    """Generate a sequence one step at a time from its own outputs.

    Used in EVAL mode. Generation is seeded with the last element
    of ``source_seq``; on each of ``self.target_len`` iterations the
    sequence generated so far is decoded and only the newest step
    of the result is appended.

    Args:
        source_seq: Source sequence; indexed along dim 1 as the
            sequence axis (presumably (batch, length, features) -
            TODO confirm against the caller).
        source_mask: Mask forwarded unchanged to ``self.decode``.

    Returns:
        The generated steps, i.e. the accumulated sequence without
        the seed element taken from the source.
    """
    # Seed: last element of the source, keeping the sequence axis
    inputs = source_seq[:, -1:]
    for i in range(self.target_len):
        # The target mask grows with the number of steps so far
        out = self.decode(inputs,
                          source_mask,
                          self.trg_masks[:, :i+1, :i+1])
        # Keep only the newest step and append it to the sequence
        out = torch.cat([inputs, out[:, -1:, :]], dim=-2)
        # Detach so the next iteration starts from a plain tensor
        inputs = out.detach()
    # Drop the seed element; what remains are the predictions
    outputs = inputs[:, 1:, :]
    return outputs

45

46 def forward(self, X, source_mask=None):

47 # Sends the mask to the same device as the inputs

48 self.trg_masks = self.trg_masks.type_as(X).bool()

49 # Slices the input to get source sequence

50 source_seq = X[:, :self.input_len, :]

51 # Encodes source sequence AND initializes decoder

52 self.encode(source_seq, source_mask)

53 if self.training:

54 # Slices the input to get the shifted target seq

55 shifted_target_seq = X[:, self.input_len-1:-1, :]

56 # Decodes using the mask to prevent cheating

57 outputs = self.decode(shifted_target_seq,

58 source_mask,

59 self.trg_masks)

60 else:

61 # Decodes using its own predictions

62 outputs = self.predict(source_seq, source_mask)

63

760 | Chapter 9 — Part II: Sequence-to-Sequence

Hooray! Your file is uploaded and ready to be published.

Saved successfully!

Oh no, something went wrong!