Self-Attention and MHA from Scratch


from torch import nn
import torch.nn.functional as F
import torch
import math


class SelfAttention(nn.Module):
    def __init__(self, hidden_dim, dropout_rate=0.0):
        super().__init__()
        self.hidden_dim = hidden_dim

        self.q = nn.Linear(hidden_dim, hidden_dim)
        self.k = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout_rate)


    def forward(self, x, mask=None):
        # the shape of x, [bs, seq_len, hidden_dim]
        bs, seq_len, hidden_dim = x.shape
        q = self.q(x)   # bs, seq_len, hidden_dim
        k = self.k(x)
        v = self.v(x)

        score = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(hidden_dim)  # bs, seq_len, seq_len
        if mask is not None:
            # mask convention: 1 = attend, 0 = masked out (e.g. a lower-triangular causal mask)
            score = score.masked_fill(mask == 0, -1e9)
        score = F.softmax(score, dim=-1)
        score = self.dropout(score)
        output = torch.matmul(score, v)  # bs, seq_len, hidden_dim
        return output
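
# A minimal usage sketch for SelfAttention (the helper name _demo_self_attention is illustrative,
# not part of the original code): it runs the module above on random input with a causal mask
# under the 1 = attend / 0 = masked-out convention, and can be called from the __main__ block below.
def _demo_self_attention():
    bs, seq_len, hidden_dim = 2, 5, 64
    x = torch.rand(bs, seq_len, hidden_dim)
    causal_mask = torch.tril(torch.ones(seq_len, seq_len))  # 1 = attend, 0 = masked out
    attn = SelfAttention(hidden_dim, dropout_rate=0.1)
    out = attn(x, mask=causal_mask)
    print('SelfAttention output shape:', out.shape)  # expected: torch.Size([2, 5, 64])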


class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_dim, head_num, dropout_rate=0.0):
        super().__init__()
        assert hidden_dim % head_num == 0, 'hidden_dim must be divisible by head_num'
        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.d_k = hidden_dim // head_num
        self.q = nn.Linear(hidden_dim, hidden_dim)
        self.k = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout_rate)

        self.o = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, mask=None):
        # the shape of x, [bs, seq_len, hidden_dim]
        bs, seq_len, hidden_dim = x.shape

        h, d_k = self.head_num, self.d_k
        q = self.q(x).view(bs, seq_len, h, d_k)  # bs, seq_len, hidden_dim -> bs, seq_len, head_num, d_k
        k = self.k(x).view(bs, seq_len, h, d_k)
        v = self.v(x).view(bs, seq_len, h, d_k)

        q = q.transpose(1, 2)   # bs, head_num, seq_len, d_k
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        score = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)  # bs, head_num, seq_len, seq_len

        if mask is not None:
            # same convention as above: 1 = attend, 0 = masked out; the mask broadcasts over batch and heads
            score = score.masked_fill(mask == 0, -1e9)

        score = F.softmax(score, dim=-1)
        score = self.dropout(score)
        out = torch.matmul(score, v)  # bs, head_num, seq_len, d_k

        out = out.transpose(1, 2).contiguous()  # bs, seq_len, head_num, d_k
        out = out.view(bs, seq_len, -1)         # bs, seq_len, hidden_dim
        out = self.o(out)

        return out
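
# A hedged cross-check sketch (assumes PyTorch >= 2.0, which provides F.scaled_dot_product_attention;
# the helper name _check_against_sdpa is illustrative). It compares the hand-written
# softmax(Q K^T / sqrt(d_k)) V core against the built-in op on random per-head tensors,
# without the Linear projections of the modules above.
def _check_against_sdpa():
    bs, h, seq_len, d_k = 2, 4, 6, 16
    q = torch.rand(bs, h, seq_len, d_k)
    k = torch.rand(bs, h, seq_len, d_k)
    v = torch.rand(bs, h, seq_len, d_k)
    manual_weights = F.softmax(torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k), dim=-1)
    manual = torch.matmul(manual_weights, v)
    builtin = F.scaled_dot_product_attention(q, k, v)
    print('max abs diff vs F.scaled_dot_product_attention:', (manual - builtin).abs().max().item())
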

if __name__ == '__main__':
    bs, seq_len, head_num, hidden_dim = 3, 4, 8, 128
    x = torch.rand(bs, seq_len, hidden_dim)

    # causal mask: 1 on and below the diagonal (attend), 0 above it (masked out)
    mask = torch.tril(torch.ones(seq_len, seq_len))
    mha = MultiHeadAttention(hidden_dim, head_num)
    output = mha(x, mask=mask)
    print(output)
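    print('MHA output shape:', output.shape)  # expected: torch.Size([3, 4, 128])

    # run the illustrative sketches defined above (added for this write-up, not part of the modules)
    _demo_self_attention()
    _check_against_sdpa()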