Transformer-Based Models - ReMA (RU)

Tutorial 4

Last update: 2024/12/04

Aditya Parikh (aditya.parikh@ru.nl)


In this tutorial, we will enhance our bigram language model by incorporating key elements of the Transformer architecture. Our goal is to introduce advanced features, such as the attention mechanism, which are essential to modern language models. This approach aims to reduce loss and improve the model's overall performance.

The code that follows reproduces the bigram model from our previous tutorial. When you run it, it prints the training and validation loss at regular intervals and then generates a sample of text. The sample is noticeably better than our earlier output, which was entirely random.
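To put the loss values in the training log into context, it helps to compare them with the loss of a model that guesses uniformly at random. The sketch below assumes a vocabulary of 65 characters, which is the size commonly reported for Tiny Shakespeare; in the notebook you would use len(chars) instead.

```python
import math

# Assumption: 65 distinct characters; replace with len(chars) from the notebook.
vocab_size = 65

# Cross-entropy of a model that gives every character the same probability.
uniform_loss = -math.log(1.0 / vocab_size)
print(f"uniform-guess loss: {uniform_loss:.4f}")

# exp(loss) is the perplexity: the effective number of equally likely
# next characters the model is choosing between.
for loss in (4.7305, 2.4879):  # step 0 and step 1000 from the training log
    print(f"loss {loss:.4f} -> perplexity {math.exp(loss):.1f}")
```

The uniform baseline is about 4.17. The loss at step 0 is somewhat higher than that, most likely because the randomly initialised embedding table produces logits that are confidently wrong rather than uniform.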

In [ ]:
# Let's start with importing our Tiny-Shakespeare dataset

!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
--2024-12-05 09:03:37--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’

input.txt           100%[===================>]   1.06M  --.-KB/s    in 0.09s   

2024-12-05 09:03:37 (11.9 MB/s) - ‘input.txt’ saved [1115394/1115394]

In [ ]:
import torch
import torch.nn as nn
from torch.nn import functional as F

#hyperparameters
batch_size = 32 # number of independent sequences processed in parallel
block_size = 8 #context length
max_iters = 10000
eval_interval = 1000
learning_rate = 1e-2
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 32 #embedding size

torch.manual_seed(1337)

#read the data
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Extract all unique characters that occur in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)

#creating mapping from characters to integers
string_to_int = {ch:i for i,ch in enumerate(chars)}
int_to_string = {i:ch for i,ch in enumerate(chars)}

encode = lambda s: [string_to_int[c] for c in s]
decode = lambda l: ''.join([int_to_string[i] for i in l])

# train/validation split (first 90% for training, the rest for validation)
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y


@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):

        # idx and targets are both (B,T) tensor of integers
        logits = self.token_embedding_table(idx) # (B,T,C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx


model = BigramLanguageModel(vocab_size)
m = model.to(device)

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=800)[0].tolist()))
step 0: train loss 4.7305, val loss 4.7241
step 1000: train loss 2.4879, val loss 2.5192
step 2000: train loss 2.4714, val loss 2.4913
step 3000: train loss 2.4610, val loss 2.4869
step 4000: train loss 2.4591, val loss 2.4832
step 5000: train loss 2.4576, val loss 2.4886
step 6000: train loss 2.4528, val loss 2.4922
step 7000: train loss 2.4612, val loss 2.4893
step 8000: train loss 2.4538, val loss 2.4863
step 9000: train loss 2.4601, val loss 2.4952

Aneandchatheth d 'd sotho st elesiritheed ag y oestt utould:
I:

OR:
An yotofoff d
Fodoone ho ENROUK:
OFis aprgr als oreave, yonedilacaia f tllmald rey l ouevert CURord beamaroullaiowonge, s.
My y jomes by, e w hed ay t oull,
He d mewhe sty.
MEREDouer f teay scke ws Lo d matar w beves. ongen
AUK: w a nd wesor cthenousclllac Oumes l hyory wemintha asofothen thanot!

Bat mm t ty, Is, y,
Cou F s h wrt,
Evere ty onen rppe'l totoche or cis aldst nwhoofolobl
NCKENERK:
rir che opr,
QUS tofape win wnt ap
Whisthake m;
preave, hano sprenday atarce,
FLo blon izetlas IFixce'd w.
We f irchisipe.

IOne,

INAndrel love!
Fou olll?
TEThime su was hanghe ld.
Mel t;
An wh e u!
Howithan mfaind gequ ind ot stouss, apours?
Atony yo ls imy se kigausayes t meedo, lit'mor heloutegutls n, meisof: viner Waviner, by 

transformer.png

self2.png

We start with the "Head" module. Your task is to implement its forward pass based on the attention equations shown above. Our previous tutorial (Tutorial 2, where we implemented the attention block) may be a useful reference.

In [ ]:
class Head(nn.Module):
  """ single head of self-attention """
  def __init__(self, head_size):
    super().__init__()
    self.key = nn.Linear(n_embd, head_size, bias=False)
    self.query = nn.Linear(n_embd, head_size, bias=False)
    self.value = nn.Linear(n_embd, head_size, bias=False)
    # These are the linear projections we are going to apply to our nodes
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
    # 'tril' (our lower-triangular matrix) is not a parameter of the module; in PyTorch naming conventions it is called a buffer, so we attach it to the module with register_buffer

  def forward(self,x):
    B,T,C = x.shape
    #This is your attention block

    # WRITE YOUR CODE BELOW.
    return out
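For reference after you have tried the exercise yourself, here is one way the forward pass could be completed. It is a minimal sketch, not the official solution: the class name HeadSketch is hypothetical, and it assumes the usual scaled dot-product attention with a causal (lower-triangular) mask.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

n_embd, block_size = 32, 8  # same values as the hyperparameters above

class HeadSketch(nn.Module):
    """One possible completion of the single-head exercise (hypothetical name)."""
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)
        # attention scores, scaled by 1/sqrt(head_size)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, T)
        # causal mask: a position may not attend to later positions
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        return wei @ v  # (B, T, head_size)

torch.manual_seed(0)
head = HeadSketch(head_size=16)
x = torch.randn(4, block_size, n_embd)
out = head(x)
print(out.shape)
```

Because of the mask, changing the last token of a sequence leaves the outputs at all earlier positions unchanged, which is a quick way to check that the head is causal.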

Now update the BigramLanguageModel class so that it combines positional embeddings with the token embeddings. Check the Transformer architecture diagram to see which operation joins the positional and token embeddings, then carry out the corresponding small task in the forward function below.
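The shapes involved are easy to get wrong, so here is a small standalone sketch of how the two embeddings combine. The vocabulary size of 65 is an assumption used only for illustration; the other values match the hyperparameters above.

```python
import torch
import torch.nn as nn

vocab_size, block_size, n_embd = 65, 8, 32  # vocab_size of 65 is an assumption

token_embedding_table = nn.Embedding(vocab_size, n_embd)
position_embedding_table = nn.Embedding(block_size, n_embd)

idx = torch.randint(vocab_size, (4, block_size))  # (B, T) batch of token ids
B, T = idx.shape

tok_emb = token_embedding_table(idx)                 # (B, T, n_embd)
pos_emb = position_embedding_table(torch.arange(T))  # (T, n_embd)
x = tok_emb + pos_emb                                # broadcast over the batch dimension
print(tok_emb.shape, pos_emb.shape, x.shape)
```

The same position vectors are added to every sequence in the batch, so the model can tell positions apart without the position table depending on the batch size.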

In [ ]:
n_embd = 32

class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_head = Head(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        # First encode the input with token embeddings and positional embeddings
        # idx and targets are both (B,T) tensors of integers
        B, T = idx.shape # T is needed below to build the position indices
        tok_emb = self.token_embedding_table(idx) # (B,T,n_embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,n_embd)
        x = tok_emb + pos_emb
        x = self.sa_head(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            #Crop idx to the last block_size tokens
            # The idx we feed into the model must not exceed block_size tokens; otherwise the positional embedding table, which only has entries up to block_size, would be indexed out of range.

            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

Now assemble the complete code and try to run it again. It should reduce the loss to some extent.
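When assembling the code, the cropping step in generate is worth checking in isolation. The following sketch uses a made-up running sequence of 12 token ids to show what idx[:, -block_size:] keeps.

```python
import torch

block_size = 8

# A running sequence that has grown beyond block_size (illustrative ids 0..11).
idx = torch.arange(12).unsqueeze(0)   # shape (1, 12)
idx_cond = idx[:, -block_size:]       # keep only the most recent block_size tokens
print(idx_cond)

# A sequence shorter than block_size is returned unchanged by the same slice.
short = torch.arange(3).unsqueeze(0)  # shape (1, 3)
print(short[:, -block_size:].shape)
```

Only the model input is cropped; the full idx keeps growing, so the generated text is not truncated.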

Implementing multi-head attention module

Having implemented scaled dot-product attention, we now implement multi-head attention. Put simply, multi-head attention applies several attention heads in parallel and concatenates their results.

We will just create a new class for it.

In [ ]:
class MultiHeadAttention(nn.Module):
  """Multiple attention heads in parallel"""
  # num_heads = how many heads you want
  # head_size = what is the head size of each head

  def __init__(self, num_heads, head_size):
    super().__init__()
    self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)]) # independent heads, all applied to the same input

  def forward(self, x):
    out = torch.cat([h(x) for h in self.heads], dim=-1) #Simply concat all of the outputs
    return out
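To see why the head size is chosen as n_embd divided by the number of heads, the sketch below concatenates stand-in head outputs. The random tensors only stand in for h(x); the point is the resulting shape and channel layout.

```python
import torch

B, T, n_embd, num_heads = 4, 8, 32, 4
head_size = n_embd // num_heads  # 8 channels per head

# Stand-ins for the per-head outputs h(x): random values, realistic shapes.
head_outputs = [torch.randn(B, T, head_size) for _ in range(num_heads)]

# Concatenating along the last dimension restores the embedding width.
out = torch.cat(head_outputs, dim=-1)
print(out.shape)

# Head i occupies channels [i * head_size, (i + 1) * head_size).
second_head_slice = out[..., head_size:2 * head_size]
```

With this choice the output of the multi-head module has the same width as its input, so it can be fed straight into lm_head.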

Now let's plug multi-head attention into the model from the previous step.

In [ ]:
n_embd = 32

class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_head = MultiHeadAttention(4, n_embd//4) # 4 heads of 8 dimensions each
        # Instead of 1 communication channel we now have 4 (4 heads)
        # With 4 communication channels, each head uses 8-dimensional self-attention
        # When concatenated, they give us a 32-dimensional vector
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        # First encode the input with token embeddings and positional embeddings
        # idx and targets are both (B,T) tensors of integers
        B, T = idx.shape # T is needed below to build the position indices
        tok_emb = self.token_embedding_table(idx) # (B,T,n_embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,n_embd)
        #YOUR CODE HERE (Just one line - Hint - adding token embedding with positional embedding)

        x = self.sa_head(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            #Crop idx to the last block_size tokens
            # The idx we feed into the model must not exceed block_size tokens; otherwise the positional embedding table, which only has entries up to block_size, would be indexed out of range.

            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

Now assemble the code again with multi-head attention and rerun it. The loss should drop further. Can you try it?

If you look at the Transformer architecture again, one element on the decoder side is still missing: the feed-forward network. In the paper it is just a simple multi-layer perceptron (MLP), so we can add the MLP computation at the per-node level, meaning it is applied to each token position independently.

feedforward_layer.png

In [ ]:
class FeedForward(nn.Module):
  """ A simple linear layer followed by a ReLU non-linearity """
  def __init__(self,n_embd):
    super().__init__()
    self.net = nn.Sequential(nn.Linear(n_embd,n_embd),nn.ReLU())

  def forward(self,x):
    return self.net(x)
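"Per node" means that the feed-forward layer processes every token position on its own, with no mixing across time. The sketch below checks this on random data; it rebuilds the same Linear + ReLU stack outside the class so that it runs standalone.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_embd = 32
net = nn.Sequential(nn.Linear(n_embd, n_embd), nn.ReLU())  # same layers as FeedForward

x = torch.randn(2, 8, n_embd)   # (B, T, n_embd)
full = net(x)                   # applied to the whole sequence at once
single = net(x[:, 3, :])        # applied to position 3 on its own

# Both routes give the same result for position 3 (up to float rounding).
print(torch.allclose(full[:, 3, :], single, atol=1e-5))
```

Attention is therefore the only place where tokens exchange information; the feed-forward layer lets each token process what it has gathered.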

Let's add this layer to our model as well, right after the attention step.

In [ ]:
n_embd = 32

class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_head = MultiHeadAttention(4, n_embd//4) # 4 heads of 8 dimensions each
        # Instead of 1 communication channel we now have 4 (4 heads)
        # With 4 communication channels, each head uses 8-dimensional self-attention
        # When concatenated, they give us a 32-dimensional vector
        self.ffwd = FeedForward(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        # First encode the input with token embeddings and positional embeddings
        # idx and targets are both (B,T) tensors of integers
        B, T = idx.shape # T is needed below to build the position indices
        tok_emb = self.token_embedding_table(idx) # (B,T,n_embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,n_embd)
        x = tok_emb + pos_emb
        x = self.sa_head(x)
        x = self.ffwd(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            #Crop idx to the last block_size tokens
            # The idx we feed into the model must not exceed block_size tokens; otherwise the positional embedding table, which only has entries up to block_size, would be indexed out of range.

            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

Finally, we need to create multiple replicas of the Transformer blocks and address one more important step: the residual (or skip) connection.

In neural networks, especially deep ones with many layers, information can become distorted or lost as it moves through each layer, much as small misalignments accumulate in a tall tower. To mitigate this, we use residual connections (also known as skip connections). These connections allow the original input to bypass certain layers and be added directly to the output of those layers.

Mathematically, if the input to a layer is 𝑥 and the output after processing is 𝐹(𝑥), a residual connection adds the original input 𝑥 to the output 𝐹(𝑥), resulting in 𝐹(𝑥) + 𝑥. This ensures that even as data is transformed across layers, the original input 𝑥 is preserved and passed along.
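A small numerical sketch makes this concrete. It uses an extreme, artificial case: a linear layer playing the role of 𝐹 whose weights are all zero, so that 𝐹(𝑥) carries no information at all. The layer size and the zero initialisation are assumptions chosen purely for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 4)        # plays the role of F
nn.init.zeros_(layer.weight)   # extreme case: F contributes nothing
nn.init.zeros_(layer.bias)

x = torch.randn(2, 4, requires_grad=True)
y_plain = layer(x)             # F(x): the input is lost entirely
y_resid = layer(x) + x         # F(x) + x: the input passes through unchanged

# The skip path also carries gradient straight back to x.
y_resid.sum().backward()
print(y_plain.abs().max().item())
print(torch.equal(y_resid, x))
print(x.grad)
```

Without the skip path the gradient with respect to x would be zero here; with it, every entry of x.grad is one, which is why residual connections help deep stacks train.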

Lastly, we need to consider layer normalization. Fortunately, PyTorch provides a built-in implementation of layer normalization, which we can use directly.
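As a quick illustration of what the built-in layer does, the sketch below applies nn.LayerNorm to random activations with a deliberately large mean and spread. At initialisation its learnable scale is one and its shift is zero, so each token's feature vector comes out with roughly zero mean and unit variance.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_embd = 32
ln = nn.LayerNorm(n_embd)

# Activations with mean around 3 and standard deviation around 5.
x = torch.randn(4, 8, n_embd) * 5 + 3
y = ln(x)

# Statistics are computed per token, over the n_embd features.
max_abs_mean = y.mean(dim=-1).abs().max().item()
mean_var = y.var(dim=-1, unbiased=False).mean().item()
print(max_abs_mean, mean_var)
```

Normalisation happens over the last dimension only, so it does not depend on the batch size or on the other tokens in the sequence.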

In [ ]:
class Block(nn.Module):
  def __init__(self, n_embd,n_head):
    super().__init__()
    head_size = n_embd // n_head
    self.sa = MultiHeadAttention(n_head, head_size)
    self.ffwd = FeedForward(n_embd)
    self.ln1 = nn.LayerNorm(n_embd)
    self.ln2 = nn.LayerNorm(n_embd)

  def forward(self,x):
    x = x + self.sa(self.ln1(x))  #without residual connection (x = self.sa(x))
    x = x + self.ffwd(self.ln2(x)) #without residual connection (x=self.ffwd(x))
    return x

Now let's add multiple Transformer blocks to our original model.

In [ ]:
n_embd = 32

class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        #self.sa_head = MultiHeadAttention(4, n_embd//4) # 4 heads of 8 dimensions each
        # Instead of 1 communication channel we now have 4 (4 heads)
        # With 4 communication channels, each head uses 8-dimensional self-attention
        # When concatenated, they give us a 32-dimensional vector
        #self.ffwd = FeedForward(n_embd)
        self.blocks = nn.Sequential(
            Block(n_embd,n_head=4),
            Block(n_embd,n_head=4),
            Block(n_embd,n_head=4))
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        # First encode the input with token embeddings and positional embeddings
        # idx and targets are both (B,T) tensors of integers
        B, T = idx.shape # T is needed below to build the position indices
        tok_emb = self.token_embedding_table(idx) # (B,T,n_embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,n_embd)
        x = tok_emb + pos_emb
        x = self.blocks(x) # the stacked blocks replace the separate sa_head and ffwd calls
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            #Crop idx to the last block_size tokens
            # The idx we feed into the model must not exceed block_size tokens; otherwise the positional embedding table, which only has entries up to block_size, would be indexed out of range.

            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

That’s it! Now that we’ve covered all the components of the Transformer architecture, assembling the modules and running the model again should be straightforward. You’ll likely notice a significant improvement in the model's performance, with the loss potentially decreasing from around 2.5 to 2.0. The generated text should also improve noticeably, perhaps even forming coherent English words.

In the next tutorial, we’ll focus on running the fully assembled Transformer blocks and comparing their performance with existing Transformer architectures available on Hugging Face. This comparison will help us identify similarities and gain a better understanding of their functionality.
