22.02.2024 Views

Daniel Voigt Godoy - Deep Learning with PyTorch Step-by-Step A Beginner’s Guide-leanpub

You also want an ePaper? Increase the reach of your titles

YUMPU automatically turns print PDFs into web optimized ePapers that Google loves.

16

17 @property

18 def alphas(self):

19 # Shape: n_heads, N, 1, L (source)

20 return torch.stack(

21 [attn.alphas for attn in self.attn_heads], dim=0

22 )

23

24 def output_function(self, contexts):

25 # N, 1, n_heads * D

26 concatenated = torch.cat(contexts, axis=-1)

27 out = self.linear_out(concatenated) # N, 1, D

28 return out

29

30 def forward(self, query, mask=None):

31 contexts = [attn(query, mask=mask)

32 for attn in self.attn_heads]

33 out = self.output_function(contexts)

34 return out

35

36 class Attention(nn.Module):

37 def __init__(self, hidden_dim,

38 input_dim=None, proj_values=False):

39 super().__init__()

40 self.d_k = hidden_dim

41 self.input_dim = hidden_dim if input_dim is None \

42 else input_dim

43 self.proj_values = proj_values

44 self.linear_query = nn.Linear(self.input_dim, hidden_dim)

45 self.linear_key = nn.Linear(self.input_dim, hidden_dim)

46 self.linear_value = nn.Linear(self.input_dim, hidden_dim)

47 self.alphas = None

48

49 def init_keys(self, keys):

50 self.keys = keys

51 self.proj_keys = self.linear_key(self.keys)

52 self.values = self.linear_value(self.keys) \

53 if self.proj_values else self.keys

54

55 def score_function(self, query):

56 proj_query = self.linear_query(query)

57 # scaled dot product

Putting It All Together | 791

Hooray! Your file is uploaded and ready to be published.

Saved successfully!

Ooh no, something went wrong!