Transformer-Based Models - ReMA (RU)
In this tutorial, we will enhance our bigram language model by incorporating key elements of the Transformer architecture. Our goal is to introduce the features that are essential to modern language models, such as the attention mechanism, in order to reduce the loss and improve the model's overall performance.
Below is the bigram model we explored in our previous tutorial. If you run the code, you will see the loss at each stage of training, and at the end the model will generate some text. This output is a clear improvement over our earlier output, which was entirely random, although it is still far from readable English.
# Let's start by downloading the Tiny Shakespeare dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
import torch
import torch.nn as nn
from torch.nn import functional as F
# hyperparameters
batch_size = 32 # number of independent sequences processed in parallel
block_size = 8 # context length: maximum number of tokens used for a prediction
max_iters = 10000
eval_interval = 1000
learning_rate = 1e-2
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 32 # embedding size
torch.manual_seed(1337)
# read the data
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
# extract all unique characters that occur in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create mappings from characters to integers and back
string_to_int = {ch:i for i,ch in enumerate(chars)}
int_to_string = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [string_to_int[c] for c in s] # string -> list of integers
decode = lambda l: ''.join([int_to_string[i] for i in l]) # list of integers -> string
# train/validation split: first 90% for training, remaining 10% for validation
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]
# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y
@torch.no_grad()
def estimate_loss():
    # average the loss over eval_iters random batches for each split
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
    def forward(self, idx, targets=None):
        # idx and targets are both (B,T) tensor of integers
        logits = self.token_embedding_table(idx) # (B,T,C)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx
model = BigramLanguageModel(vocab_size)
m = model.to(device)
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
for iter in range(max_iters):
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    # sample a batch of data
    xb, yb = get_batch('train')
    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=800)[0].tolist()))
step 0: train loss 4.7305, val loss 4.7241
step 1000: train loss 2.4879, val loss 2.5192
step 2000: train loss 2.4714, val loss 2.4913
step 3000: train loss 2.4610, val loss 2.4869
step 4000: train loss 2.4591, val loss 2.4832
step 5000: train loss 2.4576, val loss 2.4886
step 6000: train loss 2.4528, val loss 2.4922
step 7000: train loss 2.4612, val loss 2.4893
step 8000: train loss 2.4538, val loss 2.4863
step 9000: train loss 2.4601, val loss 2.4952

Aneandchatheth d 'd sotho st elesiritheed ag y oestt utould: I: OR: An yotofoff d Fodoone ho ENROUK: OFis aprgr als oreave, yonedilacaia f tllmald rey l ouevert CURord beamaroullaiowonge, s. My y jomes by, e w hed ay t oull, He d mewhe sty. MEREDouer f teay scke ws Lo d matar w beves. ongen AUK: w a nd wesor cthenousclllac Oumes l hyory wemintha asofothen thanot! Bat mm t ty, Is, y, Cou F s h wrt, Evere ty onen rppe'l totoche or cis aldst nwhoofolobl NCKENERK: rir che opr, QUS tofape win wnt ap Whisthake m; preave, hano sprenday atarce, FLo blon izetlas IFixce'd w. We f irchisipe. IOne, INAndrel love! Fou olll? TEThime su was hanghe ld. Mel t; An wh e u! Howithan mfaind gequ ind ot stouss, apours? Atony yo ls imy se kigausayes t meedo, lit'mor heloutegutls n, meisof: viner Waviner, by
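To make it concrete what `get_batch` feeds the model, here is a small pure-Python sketch. The toy string and the `block_size` value are arbitrary and chosen only for illustration. The target sequence is the input sequence shifted one position to the left, so a single chunk of length `block_size` holds `block_size` training examples, one for each context length.

```python
# Character-level encoding, as in the tutorial, applied to a toy string
text = "hello world"
chars = sorted(set(text))
string_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_string = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [string_to_int[c] for c in s]
decode = lambda l: "".join(int_to_string[i] for i in l)

data = encode(text)
block_size = 4
i = 0  # start of one chunk (get_batch picks these offsets at random)
x = data[i:i + block_size]          # inputs
y = data[i + 1:i + block_size + 1]  # targets: the same chunk shifted by one

# every prefix of x is a context, and the matching entry of y is the next character
for t in range(block_size):
    context, target = x[:t + 1], y[t]
    print(f"{decode(context)!r} -> {decode([target])!r}")
```

For this toy string the loop prints four pairs, from `'h' -> 'e'` up to `'hell' -> 'o'`.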
We start with the "Head" module, a single head of self-attention. Its forward pass is left as a small exercise: implement the attention computation using the attention equations. You can also look at our previous tutorial (Tutorial 2, where we implemented the attention block).
class Head(nn.Module):
    """ single head of self-attention """
    def __init__(self, head_size):
        super().__init__()
        # linear projections that map each token's embedding to its key, query and value
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # 'tril' (our lower-triangular mask) is not a parameter of the module; in PyTorch
        # such a tensor is called a buffer, so we attach it with register_buffer
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
    def forward(self, x):
        B,T,C = x.shape
        # This is your attention block
        # WRITE YOUR CODE BELOW: compute `out`, the attention-weighted sum of the values, shape (B, T, head_size)
        return out
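If you want to check your answer to the exercise, here is one possible way to fill in `forward`, written as a self-contained sketch. The class name `CausalHead` is hypothetical, and passing `n_embd` and `block_size` to the constructor instead of reading them from globals is our own change. It computes masked (causal) scaled dot-product attention, softmax(QKᵀ/√d_k)V with d_k = `head_size`. The lower-triangular `tril` buffer sets the scores for future positions to -inf before the softmax, so each token can only attend to itself and to earlier tokens.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

class CausalHead(nn.Module):
    """One head of masked self-attention (sketch; hyperparameters passed explicitly)."""

    def __init__(self, n_embd, head_size, block_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape  # T must not exceed block_size
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        # attention scores, scaled by 1/sqrt(head_size)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, T)
        # forbid attention to future positions
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)  # each row now sums to 1
        v = self.value(x)  # (B, T, head_size)
        out = wei @ v      # (B, T, head_size)
        return out

torch.manual_seed(0)
head = CausalHead(n_embd=32, head_size=16, block_size=8)
x = torch.randn(2, 8, 32)
out = head(x)
print(out.shape)  # torch.Size([2, 8, 16])

# causality check: changing the last token must not change the earlier outputs
x2 = x.clone()
x2[:, -1, :] = torch.randn(2, 32)
print(torch.allclose(head(x2)[:, :-1], out[:, :-1]))  # True
```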
Next, update our bigram language model class so that it combines positional embeddings with the token embeddings. Check the Transformer architecture to see which operation is applied between the positional embeddings and the token embeddings; the forward function below performs it in a single line.
n_embd = 32
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # token ids and positions are each mapped to an n_embd-dimensional vector
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_head = Head(n_embd) # one self-attention head with head_size = n_embd
        self.lm_head = nn.Linear(n_embd, vocab_size) # maps back to vocabulary logits
    def forward(self, idx, targets=None):
        # idx and targets are both (B,T) tensors of integers
        B, T = idx.shape
        # first encode the input with token embeddings and positional embeddings
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.sa_head(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens: the position embedding table only
            # has block_size entries, so a longer context would index past its end
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx
Now assemble the complete code and run it again. The loss should decrease somewhat.
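The addition `x = tok_emb + pos_emb` relies on broadcasting. `tok_emb` has shape (B, T, n_embd) and `pos_emb` has shape (T, n_embd), so the same position vectors are added to every sequence in the batch. The sketch below checks this with arbitrary sizes; `vocab_size = 65` is just a stand-in value. It also shows why `generate` crops the context: the position table has only `block_size` rows, so looking up position `block_size` raises an index error (on CPU).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, vocab_size, n_embd, block_size = 4, 8, 65, 32, 8

token_embedding_table = nn.Embedding(vocab_size, n_embd)
position_embedding_table = nn.Embedding(block_size, n_embd)

idx = torch.randint(vocab_size, (B, T))              # a batch of token ids
tok_emb = token_embedding_table(idx)                 # (B, T, n_embd)
pos_emb = position_embedding_table(torch.arange(T))  # (T, n_embd)
x = tok_emb + pos_emb                                # broadcast to (B, T, n_embd)
print(x.shape)  # torch.Size([4, 8, 32])

# every sequence in the batch received the same position vectors
print(torch.allclose(x - tok_emb, pos_emb.expand(B, T, n_embd), atol=1e-5))  # True

# one position too many: the table has no row for index block_size
raised = False
try:
    position_embedding_table(torch.arange(block_size + 1))
except IndexError as e:
    raised = True
    print("IndexError:", e)
```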
Implementing the multi-head attention module
Now that we have implemented scaled dot-product attention, we can implement multi-head attention. In simple terms, multi-head attention runs several attention heads in parallel and concatenates their outputs.
We will just create a new class for it.
class MultiHeadAttention(nn.Module):
    """Multiple attention heads in parallel"""
    # num_heads: how many heads to use
    # head_size: output dimension of each head
    def __init__(self, num_heads, head_size):
        super().__init__()
        # independent heads; nn.ModuleList registers their parameters with this module
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
    def forward(self, x):
        # run every head on the same input and concatenate along the channel dimension
        out = torch.cat([h(x) for h in self.heads], dim=-1) # (B, T, num_heads*head_size)
        return out
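Splitting `n_embd` across several heads does not change the size of the projections. The quick pure-Python count below uses a helper, `qkv_param_count`, that is our own name and exists only for illustration. It shows that 4 heads of size 8 use exactly as many key/query/value weights as the single head of size 32 used earlier. What changes is that each head computes its own attention pattern over a smaller subspace.

```python
def qkv_param_count(n_embd, num_heads, head_size):
    """Weights in the bias-free key, query and value projections of all heads."""
    per_head = 3 * n_embd * head_size  # three nn.Linear(n_embd, head_size, bias=False)
    return num_heads * per_head

n_embd = 32
single = qkv_param_count(n_embd, num_heads=1, head_size=n_embd)      # like Head(n_embd)
multi = qkv_param_count(n_embd, num_heads=4, head_size=n_embd // 4)  # like MultiHeadAttention(4, n_embd//4)
print(single, multi)  # 3072 3072

# the concatenated output width is unchanged as well: 4 * 8 == 32
print(4 * (n_embd // 4))  # 32
```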
Now let's use multi-head attention in our model.
n_embd = 32
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # token ids and positions are each mapped to an n_embd-dimensional vector
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        # instead of 1 communication channel we now have 4 (4 heads);
        # each head is 8-dimensional, so concatenating them gives a 32-dimensional vector
        self.sa_head = MultiHeadAttention(4, n_embd//4) # 4 heads of 8 dimensions
        self.lm_head = nn.Linear(n_embd, vocab_size)
    def forward(self, idx, targets=None):
        # idx and targets are both (B,T) tensors of integers
        B, T = idx.shape
        # first encode the input with token embeddings and positional embeddings
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C): add token embedding and positional embedding
        x = self.sa_head(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens: the position embedding table only
            # has block_size entries, so a longer context would index past its end
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx
Now assemble the code again with multi-head attention and rerun it. The loss should decrease further. Can you try it?
If you look at the Transformer architecture again, one element on the decoder side is still missing: the feed-forward network. In the paper, it is just a simple multi-layer perceptron (MLP) applied to each position separately and identically, so we can add the MLP computation at a per-node (per-token) level. Here we use a simplified version: a single linear layer followed by a ReLU.
class FeedForward(nn.Module):
    """ A simple linear layer followed by a ReLU non-linearity """
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(n_embd, n_embd), nn.ReLU())
    def forward(self, x):
        return self.net(x)
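"Per node" here means per position. `nn.Linear` acts only on the last dimension, so the feed-forward network transforms each token's vector independently and does no mixing across time; that mixing is the job of attention. The sketch below, using arbitrary sizes, checks this by comparing the batched call with an explicit loop over positions.

```python
import torch
import torch.nn as nn

class FeedForward(nn.Module):
    """A simple linear layer followed by a ReLU non-linearity (as in the tutorial)."""
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(n_embd, n_embd), nn.ReLU())

    def forward(self, x):
        return self.net(x)

torch.manual_seed(0)
B, T, n_embd = 2, 8, 32
ffwd = FeedForward(n_embd)
x = torch.randn(B, T, n_embd)

batched = ffwd(x)  # (B, T, n_embd) in one call
# apply the same network to one position at a time and restack along the time axis
per_position = torch.stack([ffwd(x[:, t, :]) for t in range(T)], dim=1)
print(torch.allclose(batched, per_position, atol=1e-5))  # True
```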
Let's add this layer to our model, right after the self-attention heads.
n_embd = 32
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # token ids and positions are each mapped to an n_embd-dimensional vector
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        # instead of 1 communication channel we now have 4 (4 heads);
        # each head is 8-dimensional, so concatenating them gives a 32-dimensional vector
        self.sa_head = MultiHeadAttention(4, n_embd//4) # 4 heads of 8 dimensions
        self.ffwd = FeedForward(n_embd) # per-token computation after attention
        self.lm_head = nn.Linear(n_embd, vocab_size)
    def forward(self, idx, targets=None):
        # idx and targets are both (B,T) tensors of integers
        B, T = idx.shape
        # first encode the input with token embeddings and positional embeddings
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.sa_head(x) # (B,T,C)
        x = self.ffwd(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens: the position embedding table only
            # has block_size entries, so a longer context would index past its end
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx
Finally, we need to stack several Transformer blocks and address one more important step: the residual (or skip) connection.
In deep neural networks with many layers, information can become distorted or lost as it passes through one layer after another, much like small misalignments adding up in a tall tower. During training, gradients can also shrink on their way back through many layers, which makes such networks hard to optimize. To mitigate this, we use residual connections (also known as skip connections). They let the original input bypass a layer and be added directly to that layer's output.
Mathematically, if the input to a layer is 𝑥 and the output after processing is 𝐹(𝑥), a residual connection adds the original input 𝑥 to the output 𝐹(𝑥), resulting in 𝐹(𝑥) + 𝑥. This ensures that even as data is transformed across layers, the original input 𝑥 is preserved and passed along.
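The residual connection also helps on the backward pass. The derivative of F(x) + x with respect to x is F'(x) + 1, so the gradient always keeps a direct identity path back to earlier layers. Below is a scalar sketch using PyTorch autograd; `F_layer` is an arbitrary stand-in for a layer, chosen only for illustration.

```python
import torch

def F_layer(x):
    # arbitrary stand-in for the transformation a layer performs
    return 3 * x ** 2

x = torch.tensor(2.0, requires_grad=True)
y = F_layer(x) + x  # residual connection: F(x) + x
y.backward()
print(x.grad)  # tensor(13.) = F'(x) + 1 = 6*2 + 1
```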
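Lastly, we need layer normalization. PyTorch provides a built-in implementation, nn.LayerNorm, which we can use directly. Note that the block below applies LayerNorm to the input of each sub-layer, before the attention and before the feed-forward network. This "pre-norm" arrangement is common in modern implementations; the original Transformer paper applied the normalization after the residual addition instead.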
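`nn.LayerNorm(n_embd)` normalizes each token's feature vector separately. It subtracts the mean and divides by the standard deviation over the last dimension, then applies a learnable scale and shift, initialized to 1 and 0. The sketch below, using arbitrary sizes, compares it with the formula written out by hand. LayerNorm uses the biased variance (dividing by n) and adds a small `eps` (1e-5 by default) inside the square root.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 8, 32) * 5 + 3  # (B, T, n_embd), deliberately not centered

ln = nn.LayerNorm(32)  # weight = 1, bias = 0 at initialization
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)  # biased variance, as LayerNorm uses
manual = (x - mean) / torch.sqrt(var + ln.eps)

print(torch.allclose(ln(x), manual, atol=1e-5))  # True
print(ln(x).mean(dim=-1).abs().max() < 1e-5)     # per-token mean is ~0
```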
class Block(nn.Module):
    """ Transformer block: multi-head self-attention followed by a feed-forward network """
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
    def forward(self, x):
        x = x + self.sa(self.ln1(x)) # residual connection; without it: x = self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x)) # residual connection; without it: x = self.ffwd(self.ln2(x))
        return x
Now let's add the multiple transformer blocks in our original model.
n_embd = 32
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # token ids and positions are each mapped to an n_embd-dimensional vector
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        # the single attention layer (sa_head) and feed-forward layer (ffwd) are replaced
        # by a stack of 3 Transformer blocks, each with 4 heads of 8 dimensions
        self.blocks = nn.Sequential(
            Block(n_embd, n_head=4),
            Block(n_embd, n_head=4),
            Block(n_embd, n_head=4))
        self.lm_head = nn.Linear(n_embd, vocab_size)
    def forward(self, idx, targets=None):
        # idx and targets are both (B,T) tensors of integers
        B, T = idx.shape
        # first encode the input with token embeddings and positional embeddings
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens: the position embedding table only
            # has block_size entries, so a longer context would index past its end
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx
That’s it! Now that we’ve covered all the components of the Transformer architecture, assembling the modules and running the model again should be straightforward. You’ll likely notice a significant improvement in the model's performance, with the loss potentially decreasing from around 2.5 to 2.0. The generated text should also show noticeable improvements, perhaps even forming coherent English words.
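As a reference point for the assembly, here is a compact, self-contained sketch of the final model. It is our own condensed version, not the tutorial's exact code. The class names `TinyHead`, `TinyBlock` and `TinyTransformerLM` are hypothetical, the hyperparameters are constructor arguments rather than globals, the multi-head attention and feed-forward layers are inlined into the block, and the training loop is omitted. The example runs a forward pass on random token ids and samples a few tokens from the untrained model. `vocab_size = 65` matches the character vocabulary of Tiny Shakespeare.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

class TinyHead(nn.Module):
    """One head of causal self-attention."""
    def __init__(self, n_embd, head_size, block_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.query(x), self.key(x), self.value(x)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5           # (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # causal mask
        return F.softmax(wei, dim=-1) @ v                             # (B, T, head_size)

class TinyBlock(nn.Module):
    """Pre-norm block: multi-head attention and feed-forward, each with a residual."""
    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        head_size = n_embd // n_head
        self.heads = nn.ModuleList([TinyHead(n_embd, head_size, block_size) for _ in range(n_head)])
        self.ffwd = nn.Sequential(nn.Linear(n_embd, n_embd), nn.ReLU())
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        h = self.ln1(x)
        x = x + torch.cat([head(h) for head in self.heads], dim=-1)  # attention + residual
        x = x + self.ffwd(self.ln2(x))                               # feed-forward + residual
        return x

class TinyTransformerLM(nn.Module):
    def __init__(self, vocab_size, n_embd=32, n_head=4, n_layer=3, block_size=8):
        super().__init__()
        self.block_size = block_size
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[TinyBlock(n_embd, n_head, block_size) for _ in range(n_layer)])
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)
        x = self.token_embedding_table(idx) + self.position_embedding_table(pos)
        logits = self.lm_head(self.blocks(x))  # (B, T, vocab_size)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(B * T, -1), targets.view(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, _ = self(idx[:, -self.block_size:])  # crop the context to block_size
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx = torch.cat((idx, torch.multinomial(probs, num_samples=1)), dim=1)
        return idx

torch.manual_seed(1337)
vocab_size = 65
model = TinyTransformerLM(vocab_size)
xb = torch.randint(vocab_size, (4, 8))
yb = torch.randint(vocab_size, (4, 8))
logits, loss = model(xb, yb)
print(logits.shape, loss.item())
out = model.generate(torch.zeros((1, 1), dtype=torch.long), max_new_tokens=20)
print(out.shape)  # torch.Size([1, 21])
```

To train it, the `get_batch` / AdamW loop from the beginning of this tutorial should carry over unchanged, since `forward` keeps the same `(logits, loss)` interface.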
In the next tutorial, we’ll focus on running the fully assembled Transformer blocks and comparing their performance with existing Transformer architectures available on Hugging Face. This comparison will help us identify similarities and gain a better understanding of their functionality.