MLV Prasad

MULTI-HEADED ATTENTION IN TRANSFORMERS FROM SCRATCH

Updated: Nov 14, 2023

The Multi-headed Attention Mechanism Used in the Latest LLM Models





Coding from Scratch in PyTorch.

 
"ATTENTION" IS ALL YOU NEED - LLM

import math
from typing import Optional, List
import torch
from torch import nn
from labml import tracker  # labml's tracker is only used below to log attention weights

Prepare for multi-head attention


This module applies a linear transformation and splits the resulting vector into a given number of heads for multi-head attention. It is used to transform the key, query, and value vectors.

class PrepareForMultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):                             
        super().__init__()
        

Linear layer for linear transform

        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)


Number of heads

        self.heads = heads

Number of dimensions in vectors in each head

        self.d_k = d_k
    def forward(self, x: torch.Tensor):

Input has shape [seq_len, batch_size, d_model] or [batch_size, d_model].

We apply the linear transformation to the last dimension and split that into the heads.

        head_shape = x.shape[:-1]

Linear transform

        x = self.linear(x)

Split the last dimension into heads

        x = x.view(*head_shape, self.heads, self.d_k)

Output has shape [seq_len, batch_size, heads, d_k] or [batch_size, heads, d_k].

        return x
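To see what this module does to tensor shapes, here is a minimal standalone sketch that applies the same nn.Linear + view steps to a random input. The sizes (seq_len=10, batch_size=2, d_model=512, heads=8) are arbitrary values chosen for illustration.

```python
import torch
from torch import nn

# Illustrative sizes (assumptions, not taken from the post)
seq_len, batch_size, d_model, heads = 10, 2, 512, 8
d_k = d_model // heads  # 64 features per head

# The same two steps as PrepareForMultiHeadAttention.forward
linear = nn.Linear(d_model, heads * d_k, bias=True)
x = torch.randn(seq_len, batch_size, d_model)

head_shape = x.shape[:-1]                    # torch.Size([10, 2])
projected = linear(x)                        # [10, 2, 512]
y = projected.view(*head_shape, heads, d_k)  # [10, 2, 8, 64]
print(y.shape)  # torch.Size([10, 2, 8, 64])

# Head h is simply the slice h*d_k .. (h+1)*d_k - 1 of the projected vector
h = 3
print(torch.equal(y[:, :, h, :], projected[:, :, h * d_k:(h + 1) * d_k]))  # True

# The 2-D input case [batch_size, d_model] works the same way
x2 = torch.randn(batch_size, d_model)
y2 = linear(x2).view(*x2.shape[:-1], heads, d_k)
print(y2.shape)  # torch.Size([2, 8, 64])
```

Because view only reinterprets the last dimension, the heads are contiguous chunks of one larger projection rather than separate weight matrices.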

Multi-Head Attention Module

This computes scaled multi-headed attention for the given query, key and value vectors.

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

In simple terms, it finds the keys that match the query and retrieves the values of those keys. It uses the dot product of the query and a key as the measure of how well they match.
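The formula maps onto a handful of tensor operations. Below is a minimal single-head sketch on plain 2-D tensors, without the batch and head dimensions that the module adds; the function name attention_single_head and the tensor sizes are illustrative, not part of the original code.

```python
import math
import torch


def attention_single_head(q, k, v):
    """Hypothetical helper: softmax(q @ k^T / sqrt(d_k)) @ v for one head.

    q: [seq_len_q, d_k], k: [seq_len_k, d_k], v: [seq_len_k, d_v]
    """
    d_k = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # [seq_len_q, seq_len_k]
    weights = torch.softmax(scores, dim=-1)            # normalise over the keys
    return weights @ v, weights                        # [seq_len_q, d_v]


torch.manual_seed(0)
q = torch.randn(4, 16)  # 4 queries
k = torch.randn(6, 16)  # 6 keys
v = torch.randn(6, 32)  # 6 values
out, weights = attention_single_head(q, k, v)
print(out.shape)             # torch.Size([4, 32])
print(weights.sum(dim=-1))   # approximately 1.0 for every query
```

Each output row is a weighted average of the value rows, with the weights coming from how strongly that query matches each key.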

Before taking the softmax, the dot products are divided by $\sqrt{d_k}$.

This keeps large dot-product values from pushing the softmax into saturation, where it gives very small gradients, when $d_k$ is large. The softmax is calculated along the sequence (time) axis of the keys.
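The effect of the scaling can be checked numerically. The sketch below draws one random query and eight random keys with unit-variance entries (an illustrative assumption), so each raw dot product has a standard deviation of roughly $\sqrt{d_k}$. Dividing by $\sqrt{d_k}$ brings the scores back to unit scale, the softmax becomes less peaked, and its Jacobian no longer collapses towards zero.

```python
import math
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
d_k = 512
q = torch.randn(d_k)        # one query with unit-variance entries
keys = torch.randn(8, d_k)  # eight keys with unit-variance entries

raw = keys @ q                 # unscaled dot products, std roughly sqrt(512) ~ 22.6
scaled = raw / math.sqrt(d_k)  # scaled dot products, std roughly 1

p_raw = torch.softmax(raw, dim=0)
p_scaled = torch.softmax(scaled, dim=0)
# The unscaled softmax is always at least as peaked, and usually far more so
print(p_raw.max().item(), p_scaled.max().item())


def softmax0(s):
    return torch.softmax(s, dim=0)


# Total magnitude of d(probabilities)/d(scores)
print(jacobian(softmax0, raw).abs().sum().item())     # typically close to 0 (saturated)
print(jacobian(softmax0, scaled).abs().sum().item())  # typically much larger
```

Multiplying scores by a factor above 1 can only sharpen a softmax, so the unscaled version is never flatter than the scaled one; the gradient comparison shows why that sharpness hurts training.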


class MultiHeadAttention(nn.Module):
  • heads is the number of heads.

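  • d_model is the number of features in the query, key and value vectors.

  • dropout_prob is the dropout probability applied to the attention weights.

  • bias sets whether the query and key projections have a bias term (the value projection always uses one).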

    def __init__(self, heads: int, d_model: int,
                 dropout_prob: float = 0.1,
                 bias: bool = True):
        super().__init__()

Number of features per head

        self.d_k = d_model // heads

Number of heads

        self.heads = heads

These transform the query, key and value vectors for multi-headed attention.

        self.query = PrepareForMultiHeadAttention(d_model, heads, 
                                                  self.d_k, bias=bias)
        self.key   = PrepareForMultiHeadAttention(d_model, heads, 
                                                  self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, 
                                                   self.d_k, bias=True)

Softmax for attention along the time dimension of the keys (dim=1 of the [seq_len_q, seq_len_k, batch_size, heads] scores)

        self.softmax = nn.Softmax(dim=1)

Output layer

        self.output = nn.Linear(d_model, d_model)

Dropout

        self.dropout = nn.Dropout(dropout_prob)

Scaling factor before the softmax

        self.scale = 1 / math.sqrt(self.d_k)

We store the attention weights so that they can be used for logging or other computations if needed.

        self.attn = None

Calculate scores between queries and keys. This method can be overridden for other variations, such as relative attention.

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

Calculate $QK^\top$, that is, $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$

        return torch.einsum('ibhd,jbhd->ijbh', query, key)
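If the einsum string is hard to read, the sketch below checks that it computes the same scores as an ordinary batched matrix multiplication after moving the batch and head dimensions to the front. The tensor sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
seq_q, seq_k, batch, heads, d_k = 5, 7, 2, 4, 8
query = torch.randn(seq_q, batch, heads, d_k)
key = torch.randn(seq_k, batch, heads, d_k)

# The einsum from get_scores: S[i, j, b, h] = sum_d Q[i, b, h, d] * K[j, b, h, d]
s_einsum = torch.einsum('ibhd,jbhd->ijbh', query, key)

# Equivalent batched matmul: [b, h, i, d] @ [b, h, d, j] -> [b, h, i, j]
q_ = query.permute(1, 2, 0, 3)
k_ = key.permute(1, 2, 3, 0)
s_matmul = (q_ @ k_).permute(2, 3, 0, 1)  # back to [i, j, b, h]

print(s_einsum.shape)                                 # torch.Size([5, 7, 2, 4])
print(torch.allclose(s_einsum, s_matmul, atol=1e-5))  # True
```

The einsum form avoids the explicit permutes and keeps the sequence-first layout used throughout this module.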

mask has shape [seq_len_q, seq_len_k, batch_size], where the first dimension is the query dimension. If the query dimension (or the batch dimension) is 1, it is broadcast.

    def prepare_mask(self, mask: torch.Tensor,
                     query_shape: List[int],
                     key_shape: List[int]):
        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]

The same mask is applied to all heads.

        mask = mask.unsqueeze(-1)

The resulting mask has shape [seq_len_q, seq_len_k, batch_size, 1]; the last dimension broadcasts across all heads.

        return mask
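Broadcasting is what lets a mask with size-1 dimensions cover every query, batch element and head. The sketch below uses a mask of shape [1, seq_k, 1] that hides the last two key positions, adds the head dimension exactly as prepare_mask does, and applies it to a random score tensor. The sizes and the choice of hidden positions are illustrative.

```python
import torch

seq_q, seq_k, batch, heads = 3, 5, 2, 4
scores = torch.randn(seq_q, seq_k, batch, heads)

# Query dimension 1 and batch dimension 1: one key mask shared by everything
mask = torch.ones(1, seq_k, 1)
mask[:, 3:, :] = 0  # hide keys 3 and 4

mask = mask.unsqueeze(-1)  # [1, seq_k, 1, 1], as in prepare_mask
masked = scores.masked_fill(mask == 0, float('-inf'))

print(masked.shape)                             # torch.Size([3, 5, 2, 4])
print(torch.isinf(masked[:, 3:]).all().item())  # True: hidden keys for every query, batch, head
print(torch.isinf(masked[:, :3]).any().item())  # False: visible keys are untouched
```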

query, key and value are the tensors that store collections of query, key and value vectors. They have shape [seq_len, batch_size, d_model].

mask has shape [seq_len, seq_len, batch_size], and mask[i, j, b] indicates whether, for batch b, the query at position i has access to the key-value pair at position j.
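One common mask is a causal (autoregressive) mask, where query i may only see keys j <= i. The post does not say how masks are built, so the torch.tril construction below is an assumption for illustration. The sketch builds the mask in the [seq_len, seq_len, batch_size] layout (with batch_size 1 so it broadcasts), applies it the way forward does, and shows that the masked positions receive zero attention.

```python
import torch

seq_len, batch_size, heads = 4, 2, 3

# Causal mask: mask[i, j, 0] == 1 when query i may attend to key j (j <= i)
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1)  # [seq_len, seq_len, 1]
print(mask[:, :, 0])
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])

# Fill masked scores with -inf, then softmax over the key dimension (dim=1)
scores = torch.randn(seq_len, seq_len, batch_size, heads)
scores = scores.masked_fill(mask.unsqueeze(-1) == 0, float('-inf'))
attn = torch.softmax(scores, dim=1)

print(attn[0, :, 0, 0])  # the first query can only see key 0: tensor([1., 0., 0., 0.])
```

Since exp(-inf) is 0, masked keys drop out of the softmax entirely and the remaining weights still sum to 1.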

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):

query , key and value have shape [seq_len, batch_size, d_model]

        seq_len, batch_size, _ = query.shape
        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)

Prepare query, key and value for attention computation. These will then have shape [seq_len, batch_size, heads, d_k].

        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

Compute attention scores $QK^\top$.

This gives a tensor of shape [seq_len, seq_len, batch_size, heads].

        scores = self.get_scores(query, key)

Scale scores: $\frac{QK^\top}{\sqrt{d_k}}$

        scores *= self.scale

Apply mask

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

Softmax attention along the key sequence dimension: $\mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)$

        attn = self.softmax(scores)

Save the attention weights if debugging

        tracker.debug('attn', attn)

Apply dropout

        attn = self.dropout(attn)

Multiply by the values: $\mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$

        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

Save the attention weights for any other calculations

        self.attn = attn.detach()

Concatenate multiple heads

        x = x.reshape(seq_len, batch_size, -1)

Output layer

        return self.output(x)
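Because the module is spread across several fragments, here is the same code collected into one file that runs on its own. Two changes are assumptions of this sketch: the labml tracker.debug call is dropped so that only PyTorch is needed, and the comments are condensed. The toy sizes and the causal mask in the usage lines are illustrative.

```python
import math
from typing import List, Optional

import torch
from torch import nn


class PrepareForMultiHeadAttention(nn.Module):
    """Linear projection, then split the last dimension into heads."""
    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        self.heads = heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor):
        head_shape = x.shape[:-1]
        x = self.linear(x)
        return x.view(*head_shape, self.heads, self.d_k)


class MultiHeadAttention(nn.Module):
    def __init__(self, heads: int, d_model: int,
                 dropout_prob: float = 0.1, bias: bool = True):
        super().__init__()
        self.d_k = d_model // heads
        self.heads = heads
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
        self.softmax = nn.Softmax(dim=1)  # over the key dimension
        self.output = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_prob)
        self.scale = 1 / math.sqrt(self.d_k)
        self.attn = None

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        return torch.einsum('ibhd,jbhd->ijbh', query, key)

    def prepare_mask(self, mask: torch.Tensor,
                     query_shape: List[int], key_shape: List[int]):
        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
        return mask.unsqueeze(-1)  # broadcast over heads

    def forward(self, *, query: torch.Tensor, key: torch.Tensor,
                value: torch.Tensor, mask: Optional[torch.Tensor] = None):
        seq_len, batch_size, _ = query.shape
        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)
        query, key, value = self.query(query), self.key(key), self.value(value)
        scores = self.get_scores(query, key) * self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = self.dropout(self.softmax(scores))  # tracker.debug omitted here
        x = torch.einsum('ijbh,jbhd->ibhd', attn, value)
        self.attn = attn.detach()
        x = x.reshape(seq_len, batch_size, -1)  # concatenate the heads
        return self.output(x)


torch.manual_seed(0)
seq_len, batch_size, d_model, heads = 6, 2, 64, 8
mha = MultiHeadAttention(heads=heads, d_model=d_model, dropout_prob=0.0)
x = torch.randn(seq_len, batch_size, d_model)
causal = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1)  # [seq, seq, 1]
out = mha(query=x, key=x, value=x, mask=causal)
print(out.shape)       # torch.Size([6, 2, 64])
print(mha.attn.shape)  # torch.Size([6, 6, 2, 8])
```

Setting dropout_prob=0.0 in this example keeps every row of the stored mha.attn summing to 1, which makes the saved weights easier to inspect; with dropout enabled, the stored weights are the post-dropout values.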
