Since my last blog post, Training LeNet on Armenian script, I have made some significant improvements to the training process.

Model simplification

The model takes as input a mean and standard deviation for normalizing pixel intensities. These values are calibrated on the training set before initiating the gradient descent loop for adjusting the weights and biases.

To make things simpler, I hardcoded those parameters. This way the dependency between model and training set only happens in the gradient descent loop. Importing the model, for instance from a Gradio app, becomes:

import numpy as np
import torch

N, num_classes, mean, std = 56, 38, 1 / 3, np.sqrt(2) / 3
model = LeNet(N, num_classes, mean, std)
model.load_state_dict(torch.load("model_state_dict.pt"))

To find a mean and standard deviation that make sense, I arbitrarily decided that a black square occupying one third of the total pixel area would represent the average character. With black pixels coded as 1, the mean and standard deviation of the pixel intensities are respectively $1/3$ (a third of the pixels is black) and $\sqrt{2}/3$. Those values are close to the ones observed on the training set.
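The arithmetic is easy to sanity-check numerically, for instance on a hypothetical 30×30 canvas (the image size is made up for illustration):

```python
import numpy as np

# Hypothetical 30x30 image: a black square (intensity 1) covering
# exactly one third of the 900 pixels, the rest white (intensity 0)
pixels = np.zeros(900)
pixels[:300] = 1.0

print(pixels.mean())  # 1/3
print(pixels.std())   # sqrt(2)/3, about 0.471
```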

I also chose to focus only on the lowercase letters for this project to reduce the number of classes to 38.

Accuracy metrics

Upon inspecting my code, I realized that certain PyTorch functions perform more tasks than initially anticipated. For instance, torch.nn.CrossEntropyLoss accepts integer labels (indices), eliminating the need for one-hot encoding. While this might be convenient in some cases, I believe one-hot encoding the labels for a classification task is significantly more readable.
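Since PyTorch 1.10, torch.nn.CrossEntropyLoss also accepts class probabilities as targets, so the two formats give identical losses. A quick check with made-up logits:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, 0.1],
                       [0.2, 1.5, 0.3]])
idx = torch.tensor([0, 1])                      # integer class indices
onehot = F.one_hot(idx, num_classes=3).float()  # same labels, one-hot encoded

criterion = torch.nn.CrossEntropyLoss()
# Both target formats yield the same loss: with one-hot targets, the
# probability-weighted cross entropy reduces to picking out the true class
print(criterion(logits, idx))
print(criterion(logits, onehot))
```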

Here is my current training loop, with one-hot encoded labels and accuracy score calculated by torchmetrics:

for epoch in range(epochs):

    train_loss, train_acc = 0, 0
    
    for inputs, labels in train_dataloader:

        # Evaluate logits and loss
        logits = model(inputs)
        loss = criterion(logits, labels)

        # Compute metrics
        train_loss += loss.item()
        max_indices = torch.argmax(logits, dim=1)
        preds = F.one_hot(max_indices, num_classes=logits.size(1)).float()
        # multiclass_exact_match expects integer targets, so the one-hot
        # labels are converted back to indices here
        train_acc += multiclass_exact_match(
            preds=preds, target=labels.argmax(dim=1), num_classes=num_classes
        ).item()

        # Gradient descent
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    train_loss /= len(train_dataloader)
    train_acc /= len(train_dataloader)

Extensive logging

I also started logging weights and biases in TensorBoard to check their stability. TensorBoard has become a standard and is very easy to use:

writer = SummaryWriter(log_dir=log_dir)

for epoch in range(epochs):
    
    # train...
    
    with torch.no_grad():
        for name, param in model.named_parameters():
            writer.add_histogram(name, param, epoch)

Additionally, I added an ASCII logger for my character images, which helped me realize that Arial does not support Armenian and was generating the unknown-character symbol □. Here is an example of a rotated, noisy ն character:

7   7  7 7      7 77 7  7  7  7 
  7 7 77  7   77      77777 77  
7777   777  8166077777   77  77 
    7 7 77 510000000 7  77 7 7 7
 77 7  77 750000000000 7   777 7
  7   77 710000000000006 7      
7  7      10000134400006  77 7  
7  7 7    000005  74045777   7 7
7 7777  700000997   7 7    77   
77  77 7700000677 2 7  777777   
7 77 77 00000097  60477  7  77  
 7  77  00000977 7000408 777   7
 7    77000009 7700000067 7    7
7    7 0000009 7730000677 77 77 
 77  7 000009 7  300006  77 777 
  7777 00000 7  00000037 7      
777 770000006   300006   7 777 7
777 77000009 7733000037777 7 77 
   77 100006 77000002    77    7
   7 2000059   0000037      7777
7777 4000057  3000002777  7  777
 77  000000  5000003 77 7 7 7 7 
 77 7300001101000002      77 7 7
 7 77640000000000002 77   7  77 
 77 77 0000000000027    777 7 7 
77777 7334155400006 7 7  7  7 7 
7   7 7 77  70000267 7  7   77  
7 777  77 7 77 003       7      
  7 7 77 77777 7    777  77  77 
 7777 777  77  777  7    7  7777
   7      7777 77 7   77    7  7
7 7 77777   777 7      7 777 7    
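A minimal version of such a logger might look like the sketch below. It is not my exact implementation: it assumes intensities in $[0, 1]$ with 1 for black ink, maps each pixel to a digit (0 = darkest), and leaves fully white pixels blank.

```python
import numpy as np

def ascii_render(img):
    """Render a grayscale image (values in [0, 1], 1 = black) as ASCII digits."""
    # Quantize intensity to a digit: 0 for black ink, 9 for white
    digits = np.clip(np.round((1.0 - img) * 9), 0, 9).astype(int)
    # Blank out fully white pixels so the glyph stands out
    return "\n".join(
        "".join(" " if d == 9 else str(d) for d in row) for row in digits
    )

print(ascii_render(np.array([[1.0, 0.3, 0.0]])))
```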

Bonus

Some additional findings:

  • Notebooks are excellent for visualizing data and analyzing results. However, I found that a robust CLI for preparing, training, and evaluating the model is even more beneficial. I have been using fire, which automatically generates a CLI from Python classes.
  • tbparse helps extract TensorBoard data as a pandas DataFrame for plotting.
  • Hugging Face provides Gradio app hosting for running AI models. Once the model weights are saved, the model can easily be instantiated from a Gradio app and used for inference. You can find the space here.