Building a Seq2Seq Model with Attention for Language Translation


# Ref from https://github.com/wjdghks950/Seq2seq-with-attention/blob/master/model.py

import random

import os

import re

import unicodedata

import zipfile

 

import matplotlib.pyplot as plt

import numpy as np

import requests

import torch

import torch.nn as nn

import torch.nn.functional as F

import torch.optim as optim

import tokenizers

import tqdm

 

 

#

# Data preparation

#

 

 

# Download dataset provided by Anki: https://www.manythings.org/anki/ with requests

if not os.path.exists("fra-eng.zip"):

    url = "http://storage.googleapis.com/download.tensorflow.org/data/fra-eng.zip"

    response = requests.get(url)

    with open("fra-eng.zip", "wb") as f:

        f.write(response.content)

 

# Normalize text

# Each line of the file is in the format "<english>\t<french>"

# We convert text to lowercase and normalize the unicode (NFKC)

def normalize(line):

    """Normalize a line of text and split it into two at the tab character"""

    line = unicodedata.normalize("NFKC", line.strip().lower())

    eng, fra = line.split("\t")

    return eng.lower().strip(), fra.lower().strip()

 

text_pairs = []

with zipfile.ZipFile("fra-eng.zip", "r") as zip_ref:

    for line in zip_ref.read("fra.txt").decode("utf-8").splitlines():

        eng, fra = normalize(line)

        text_pairs.append((eng, fra))
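
# Quick sanity check (illustrative addition, not part of the original script): confirm
# the pairs loaded and peek at one example. The exact count depends on the dataset version.
print(f"Loaded {len(text_pairs)} sentence pairs")
print(f"Sample pair: {random.choice(text_pairs)}")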

 

#

# Tokenization with BPE

#

 

if os.path.exists("en_tokenizer.json") and os.path.exists("fr_tokenizer.json"):

    en_tokenizer = tokenizers.Tokenizer.from_file("en_tokenizer.json")

    fr_tokenizer = tokenizers.Tokenizer.from_file("fr_tokenizer.json")

else:

    en_tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE())

    fr_tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE())

 

    # Configure pre-tokenizer to split on whitespace and punctuation, add space at beginning of the sentence

    en_tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.ByteLevel(add_prefix_space=True)

    fr_tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.ByteLevel(add_prefix_space=True)

 

    # Configure decoder so that the word boundary symbol "Ġ" will be removed

    en_tokenizer.decoder = tokenizers.decoders.ByteLevel()

    fr_tokenizer.decoder = tokenizers.decoders.ByteLevel()

 

    # Train BPE for English and French using the same trainer

    VOCAB_SIZE = 8000

    trainer = tokenizers.trainers.BpeTrainer(

        vocab_size=VOCAB_SIZE,

        special_tokens=["[start]", "[end]", "[pad]"],

        show_progress=True

    )

    en_tokenizer.train_from_iterator([x[0] for x in text_pairs], trainer=trainer)

    fr_tokenizer.train_from_iterator([x[1] for x in text_pairs], trainer=trainer)

 

    en_tokenizer.enable_padding(pad_id=en_tokenizer.token_to_id("[pad]"), pad_token="[pad]")

    fr_tokenizer.enable_padding(pad_id=fr_tokenizer.token_to_id("[pad]"), pad_token="[pad]")

 

    # Save the trained tokenizers

    en_tokenizer.save("en_tokenizer.json", pretty=True)

    fr_tokenizer.save("fr_tokenizer.json", pretty=True)

 

# Test the tokenizer

print("Sample tokenization:")

en_sample, fr_sample = random.choice(text_pairs)

encoded = en_tokenizer.encode(en_sample)

print(f"Original: {en_sample}")

print(f"Tokens: {encoded.tokens}")

print(f"IDs: {encoded.ids}")

print(f"Decoded: {en_tokenizer.decode(encoded.ids)}")

print()

 

encoded = fr_tokenizer.encode("[start] " + fr_sample + " [end]")

print(f"Original: {fr_sample}")

print(f"Tokens: {encoded.tokens}")

print(f"IDs: {encoded.ids}")

print(f"Decoded: {fr_tokenizer.decode(encoded.ids)}")

print()

 

#

# Create PyTorch dataset for the BPE-encoded translation pairs

#

 

class TranslationDataset(torch.utils.data.Dataset):

    def __init__(self, text_pairs):

        self.text_pairs = text_pairs

 

    def __len__(self):

        return len(self.text_pairs)

 

    def __getitem__(self, idx):

        eng, fra = self.text_pairs[idx]

        return eng, "[start] " + fra + " [end]"

 

 

def collate_fn(batch):

    en_str, fr_str = zip(*batch)

    en_enc = en_tokenizer.encode_batch(en_str, add_special_tokens=True)

    fr_enc = fr_tokenizer.encode_batch(fr_str, add_special_tokens=True)

    en_ids = [enc.ids for enc in en_enc]

    fr_ids = [enc.ids for enc in fr_enc]

    return torch.tensor(en_ids), torch.tensor(fr_ids)

 

 

BATCH_SIZE = 32

dataset = TranslationDataset(text_pairs)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
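
# Illustrative check (an assumption, not in the original): pull a single batch to confirm
# that the collate function returns padded [batch_size, seq_len] integer tensors.
en_batch, fr_batch = next(iter(dataloader))
print(f"English batch: {en_batch.shape}, French batch: {fr_batch.shape}")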

 

 

#

# Create seq2seq model with attention for translation

#

 

class EncoderRNN(nn.Module):

    """An RNN encoder with an embedding layer"""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, dropout=0.1):

        """
        Args:
            vocab_size: The size of the input vocabulary
            embedding_dim: The dimension of the embedding vector
            hidden_dim: The dimension of the hidden state
            dropout: The dropout rate
        """

        super().__init__()

        self.vocab_size = vocab_size

        self.embedding_dim = embedding_dim

        self.hidden_dim = hidden_dim

 

        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        self.rnn = nn.GRU(embedding_dim, hidden_dim, batch_first=True)

        self.dropout = nn.Dropout(dropout)

 

    def forward(self, input_seq):

        # input seq = [batch_size, seq_len] -> embedded = [batch_size, seq_len, embedding_dim]

        embedded = self.dropout(self.embedding(input_seq))

        # outputs = [batch_size, seq_len, hidden_dim]

        # hidden = [1, batch_size, hidden_dim]

        outputs, hidden = self.rnn(embedded)

        return outputs, hidden
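
# Optional shape check for the encoder (a minimal sketch; the tiny sizes below are
# arbitrary assumptions used only for illustration, not the training configuration):
_enc = EncoderRNN(vocab_size=100, embedding_dim=8, hidden_dim=16)
_enc_out, _enc_hid = _enc(torch.randint(0, 100, (2, 5)))
print(f"Encoder outputs: {_enc_out.shape}, hidden: {_enc_hid.shape}")  # [2, 5, 16] and [1, 2, 16]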

 

 

class BahdanauAttention(nn.Module):

    """Bahdanau attention, https://arxiv.org/pdf/1409.0473.pdf

    The forward function takes only a query and keys, which share the hidden size H:
    the query is shaped (B, 1, H) and the keys are shaped (B, S, H).
    """

    def __init__(self, hidden_size):

        super(BahdanauAttention, self).__init__()

        self.Wa = nn.Linear(hidden_size, hidden_size)

        self.Ua = nn.Linear(hidden_size, hidden_size)

        self.Va = nn.Linear(hidden_size, 1)

 

    def forward(self, query, keys):

        """Bahdanau attention

        Args:
            query: [B, 1, H]
            keys: [B, S, H]

        Returns:
            context: [B, 1, H]
            weights: [B, 1, S]
        """

        B, S, H = keys.shape

        assert query.shape == (B, 1, H)

        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))

        scores = scores.transpose(1,2)  # scores = [B, 1, S]

 

        weights = F.softmax(scores, dim=-1)  # normalize over the source positions S

        context = torch.bmm(weights, keys)

        return context, weights
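
# Optional check of the attention shapes (a minimal sketch with arbitrary sizes, not the
# training configuration). The score is v_a^T tanh(W_a q + U_a k_s), softmaxed over the
# source positions s, and the context is the weighted sum of the keys.
_attn = BahdanauAttention(hidden_size=16)
_ctx, _w = _attn(torch.randn(2, 1, 16), torch.randn(2, 7, 16))
print(f"Context: {_ctx.shape}, weights: {_w.shape}")  # [2, 1, 16] and [2, 1, 7]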

 

 

class DecoderRNN(nn.Module):

    def __init__(self, vocab_size, embedding_dim, hidden_dim, dropout=0.1):

        super().__init__()

        self.vocab_size = vocab_size

        self.embedding_dim = embedding_dim

        self.hidden_dim = hidden_dim

 

        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        self.dropout = nn.Dropout(dropout)

        self.attention = BahdanauAttention(hidden_dim)

        self.gru = nn.GRU(embedding_dim + hidden_dim, hidden_dim, batch_first=True)

        self.out_proj = nn.Linear(hidden_dim, vocab_size)

 

    def forward(self, input_seq, hidden, enc_out):

        """Single token input, single token output"""

        # input seq = [batch_size, 1] -> embedded = [batch_size, 1, embedding_dim]

        embedded = self.dropout(self.embedding(input_seq))

        # hidden = [1, batch_size, hidden_dim]

        # context = [batch_size, 1, hidden_dim]

        context, attn_weights = self.attention(hidden.transpose(0, 1), enc_out)

        # rnn_input = [batch_size, 1, embedding_dim + hidden_dim]

        rnn_input = torch.cat([embedded, context], dim=2)  # concatenate along the feature dimension

        # rnn_output = [batch_size, 1, hidden_dim]

        rnn_output, hidden = self.gru(rnn_input, hidden)

        output = self.out_proj(rnn_output)

        return output, hidden
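
# Optional single-step check for the decoder (a minimal sketch; the sizes are arbitrary
# assumptions). One target token plus the previous hidden state and the encoder outputs
# produce one set of logits and an updated hidden state.
_dec = DecoderRNN(vocab_size=100, embedding_dim=8, hidden_dim=16)
_logits, _dec_hid = _dec(torch.randint(0, 100, (2, 1)), torch.zeros(1, 2, 16), torch.randn(2, 7, 16))
print(f"Decoder logits: {_logits.shape}, hidden: {_dec_hid.shape}")  # [2, 1, 100] and [1, 2, 16]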

 

 

class Seq2SeqRNN(nn.Module):

    def __init__(self, encoder, decoder):

        super().__init__()

        self.encoder = encoder

        self.decoder = decoder

 

    def forward(self, input_seq, target_seq):

        """Given the partial target sequence, predict the next token"""

        # input seq = [batch_size, seq_len]

        # target seq = [batch_size, seq_len]

        batch_size, target_len = target_seq.shape

        device = target_seq.device

        # list for storing the output logits

        outputs = []

        # encoder forward pass

        enc_out, hidden = self.encoder(input_seq)

        dec_hidden = hidden

        # decoder forward pass

        for t in range(target_len - 1):

            # during training, use the ground truth token as the input (teacher forcing)

            dec_in = target_seq[:, t].unsqueeze(1)

            # last target token and hidden states -> next token

            dec_out, dec_hidden = self.decoder(dec_in, dec_hidden, enc_out)

            # store the prediction

            outputs.append(dec_out)

        outputs = torch.cat(outputs, dim=1)

        return outputs
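
# Optional end-to-end shape check (a minimal sketch with tiny, arbitrary sizes). With
# teacher forcing the decoder consumes target tokens 0..T-2 and predicts tokens 1..T-1,
# so the model returns target_len - 1 logits per sequence, matching fr_ids[:, 1:] below.
_s2s = Seq2SeqRNN(EncoderRNN(100, 8, 16), DecoderRNN(100, 8, 16))
_out = _s2s(torch.randint(0, 100, (2, 5)), torch.randint(0, 100, (2, 6)))
print(f"Seq2seq logits: {_out.shape}")  # [2, 5, 100] = [batch, target_len - 1, vocab]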

 

 

# Initialize model parameters

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

enc_vocab = len(en_tokenizer.get_vocab())

dec_vocab = len(fr_tokenizer.get_vocab())

emb_dim = 256

hidden_dim = 256

dropout = 0.1

 

# Create model

encoder = EncoderRNN(enc_vocab, emb_dim, hidden_dim, dropout).to(device)

decoder = DecoderRNN(dec_vocab, emb_dim, hidden_dim, dropout).to(device)

model = Seq2SeqRNN(encoder, decoder).to(device)

print(model)

 

print("Model created with:")

print(f"  Input vocabulary size: {enc_vocab}")

print(f"  Output vocabulary size: {dec_vocab}")

print(f"  Embedding dimension: {emb_dim}")

print(f"  Hidden dimension: {hidden_dim}")

print(f"  Dropout: {dropout}")

print(f"  Total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

 

# Optional: explicitly initialize the weight matrices (disabled here; PyTorch defaults are used)

#for name, param in model.named_parameters():

#    if param.dim() > 1:

#        nn.init.normal_(param.data, mean=0, std=0.01)

 

# Train unless model.pth exists

if os.path.exists("seq2seq_attn.pth"):

    model.load_state_dict(torch.load("seq2seq_attn.pth"))

else:

    optimizer = optim.Adam(model.parameters(), lr=0.0005)

    loss_fn = nn.CrossEntropyLoss()  # optionally pass ignore_index=fr_tokenizer.token_to_id("[pad]")

    N_EPOCHS = 100

 

    for epoch in range(N_EPOCHS):

        model.train()

        epoch_loss = 0

        for en_ids, fr_ids in tqdm.tqdm(dataloader, desc="Training"):

            # Move the “sentences” to device

            en_ids = en_ids.to(device)

            fr_ids = fr_ids.to(device)

            # zero the grad, then forward pass

            optimizer.zero_grad()

            outputs = model(en_ids, fr_ids)

            # compute the loss: flatten the logits to [batch*seq, vocab] and the targets to [batch*seq]

            loss = loss_fn(outputs.reshape(-1, dec_vocab), fr_ids[:, 1:].reshape(-1))

            loss.backward()

            optimizer.step()

            epoch_loss += loss.item()

        print(f"Epoch {epoch+1}/{N_EPOCHS}; Avg loss {epoch_loss/len(dataloader)}; Latest loss {loss.item()}")

        torch.save(model.state_dict(), f"seq2seq_attn-epoch-{epoch+1}.pth")

        # Test

        if (epoch+1) % 5 != 0:

            continue

        model.eval()

        epoch_loss = 0

        with torch.no_grad():

            for en_ids, fr_ids in tqdm.tqdm(dataloader, desc="Evaluating"):

                en_ids = en_ids.to(device)

                fr_ids = fr_ids.to(device)

                outputs = model(en_ids, fr_ids)

                loss = loss_fn(outputs.reshape(-1, dec_vocab), fr_ids[:, 1:].reshape(-1))

                epoch_loss += loss.item()

        print(f"Eval loss: {epoch_loss/len(dataloader)}")

    torch.save(model.state_dict(), "seq2seq_attn.pth")

 

# Test for a few samples

model.eval()

N_SAMPLES = 5

MAX_LEN = 60

with torch.no_grad():

    start_token = torch.tensor([fr_tokenizer.token_to_id("[start]")]).to(device)

    for en, true_fr in random.sample(text_pairs, N_SAMPLES):

        en_ids = torch.tensor(en_tokenizer.encode(en).ids).unsqueeze(0).to(device)

        enc_out, hidden = model.encoder(en_ids)

        pred_ids = []

        prev_token = start_token.unsqueeze(0)

        for _ in range(MAX_LEN):

            output, hidden = model.decoder(prev_token, hidden, enc_out)

            output = output.argmax(dim=2)

            pred_ids.append(output.item())

            prev_token = output

            # early stop if the predicted token is the end token

            if pred_ids[-1] == fr_tokenizer.token_to_id("[end]"):

                break

        # Decode the predicted IDs

        pred_fr = fr_tokenizer.decode(pred_ids)

        print(f"English: {en}")

        print(f"French: {true_fr}")

        print(f"Predicted: {pred_fr}")

        print()

