import math

import torch
from torch import nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention."""

    def __init__(self, hidden_dim, dropout_rate=0.0):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.q = nn.Linear(hidden_dim, hidden_dim)
        self.k = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout_rate)
    def forward(self, x, mask=None):
        # x: [bs, seq_len, hidden_dim]
        # mask (optional): boolean tensor broadcastable to [bs, seq_len, seq_len],
        # with True at positions that should not be attended to.
        bs, seq_len, hidden_dim = x.shape
        q = self.q(x)  # [bs, seq_len, hidden_dim]
        k = self.k(x)
        v = self.v(x)
        score = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(hidden_dim)
        if mask is not None:
            score = score.masked_fill(mask, -1e9)
        score = F.softmax(score, dim=-1)  # [bs, seq_len, seq_len]
        score = self.dropout(score)
        output = torch.matmul(score, v)
        return output

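
# A minimal usage sketch (my addition, not part of the original post): how SelfAttention
# could be called with a boolean padding mask. The shapes, the example values, and the
# helper name `_selfattention_padding_demo` are illustrative assumptions.
def _selfattention_padding_demo():
    attn = SelfAttention(hidden_dim=16)
    x = torch.rand(2, 5, 16)  # [bs, seq_len, hidden_dim]
    # True marks padded key positions that should receive no attention.
    pad = torch.tensor([[False, False, False, True, True],
                        [False, False, True, True, True]])  # [bs, seq_len]
    pad_mask = pad[:, None, :].expand(-1, 5, -1)  # broadcast to [bs, seq_len, seq_len]
    return attn(x, mask=pad_mask)  # [2, 5, 16]
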
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention."""

    def __init__(self, hidden_dim, head_num):
        super().__init__()
        assert hidden_dim % head_num == 0, "hidden_dim must be divisible by head_num"
        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.d_k = hidden_dim // head_num  # per-head dimension
        self.q = nn.Linear(hidden_dim, hidden_dim)
        self.k = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)
        self.o = nn.Linear(hidden_dim, hidden_dim)  # output projection
    def forward(self, x, mask=None):
        # x: [bs, seq_len, hidden_dim]
        # mask (optional): boolean tensor broadcastable to [bs, head_num, seq_len, seq_len],
        # with True at positions that should not be attended to.
        bs, seq_len, hidden_dim = x.shape
        h, d_k = self.head_num, self.d_k
        # Project, then split the hidden dimension into heads:
        # [bs, seq_len, hidden_dim] -> [bs, seq_len, head_num, d_k]
        q = self.q(x).view(bs, seq_len, h, d_k)
        k = self.k(x).view(bs, seq_len, h, d_k)
        v = self.v(x).view(bs, seq_len, h, d_k)
        q = q.transpose(1, 2)  # [bs, head_num, seq_len, d_k]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # Scale by sqrt(d_k), the per-head dimension.
        score = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)  # [bs, head_num, seq_len, seq_len]
        if mask is not None:
            score = score.masked_fill(mask, -1e20)
        score = F.softmax(score, dim=-1)
        # (attention-weight dropout could be applied here, as in SelfAttention above)
        out = torch.matmul(score, v)  # [bs, head_num, seq_len, d_k]
        # Merge the heads back: [bs, seq_len, head_num * d_k] == [bs, seq_len, hidden_dim]
        out = out.transpose(1, 2).contiguous()
        out = out.view(bs, seq_len, -1)
        out = self.o(out)
        return out

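
# A cross-check sketch (my addition, assuming PyTorch >= 2.0): recompute the forward
# pass with F.scaled_dot_product_attention, reusing the module's own projection layers.
# With matching causal masking this should agree with MultiHeadAttention.forward up to
# floating-point tolerance. The name `sdpa_forward` is my own, not from the original post.
def sdpa_forward(mha, x, causal=False):
    bs, seq_len, _ = x.shape
    h, d_k = mha.head_num, mha.d_k
    q = mha.q(x).view(bs, seq_len, h, d_k).transpose(1, 2)
    k = mha.k(x).view(bs, seq_len, h, d_k).transpose(1, 2)
    v = mha.v(x).view(bs, seq_len, h, d_k).transpose(1, 2)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)  # default scale = 1/sqrt(d_k)
    out = out.transpose(1, 2).contiguous().view(bs, seq_len, -1)
    return mha.o(out)
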
if __name__ == '__main__':
    bs, seq_len, head_num, hidden_dim = 3, 4, 8, 128
    x = torch.rand(bs, seq_len, hidden_dim)
    # Causal mask: True above the diagonal, so each position cannot attend to
    # future positions (masked_fill writes -1e20 where the mask is True).
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    mha = MultiHeadAttention(hidden_dim, head_num)
    output = mha(x, mask=mask)
    print(output)  # [3, 4, 128]
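    # Optional sanity check (my addition): the manual path above should match the
    # sdpa_forward sketch under the same causal masking, up to floating-point tolerance.
    print(torch.allclose(output, sdpa_forward(mha, x, causal=True), atol=1e-5))  # expected: True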
Hand-written SelfAttention and MHA (from scratch)