
Document Classification with LayoutLMv3

Fine-tune a LayoutLMv3 model with PyTorch Lightning to classify document images with imbalanced classes. You will learn how to use the Hugging Face Transformers library, evaluate the model with a confusion matrix, and upload the trained model to the Hugging Face Hub. The tutorial also covers installing and using libraries such as EasyOCR and TorchMetrics. A complete Jupyter (Google Colab) notebook is included.

March 24, 2023
13 min read

In this tutorial, we will explore the task of document classification using layout information and image content. We will use the LayoutLMv3 model, a state-of-the-art model for this task, and PyTorch Lightning, a lightweight PyTorch wrapper for high-performance training.

We will start by preparing the dataset and data loaders, followed by building and training the model. We will then evaluate the performance of our model and analyze the results using a confusion matrix. Finally, we will explore ways to improve the performance of the model on specific classes. By the end of this tutorial, you will have a good understanding of how to use LayoutLMv3 for document classification and how to leverage PyTorch Lightning to train and evaluate deep learning models.

Notebook Setup

We will begin by installing wkhtmltopdf[1], a toolkit that renders HTML files into PDF and image formats (the imgkit library we'll use later calls its wkhtmltoimage binary):

py code
%%bash
wget -q https://github.com/wkhtmltopdf/packaging/releases/download/0.12.6-1/wkhtmltox_0.12.6-1.bionic_amd64.deb
cp wkhtmltox_0.12.6-1.bionic_amd64.deb /usr/bin
apt -qq install /usr/bin/wkhtmltox_0.12.6-1.bionic_amd64.deb
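If you want to confirm the toolkit is available before moving on (an optional check, not part of the original notebook), you can ask the bundled wkhtmltoimage binary for its version:

py code
!wkhtmltoimage --version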

Next, we will proceed to install all the necessary libraries:

py code
!pip install -qqq transformers==4.27.2 --progress-bar off
!pip install -qqq pytorch-lightning==1.9.4 --progress-bar off
!pip install -qqq torchmetrics==0.11.4 --progress-bar off
!pip install -qqq imgkit==1.2.3 --progress-bar off
!pip install -qqq easyocr==1.6.2 --progress-bar off
!pip install -qqq Pillow==9.4.0 --progress-bar off
!pip install -qqq tensorboardX==2.5.1 --progress-bar off
!pip install -qqq huggingface_hub==0.11.1 --progress-bar off
!pip install -qqq --upgrade --no-cache-dir gdown

The essential libraries for this tutorial are:

  • transformers: We'll use its implementation of LayoutLMv3 for our model.
  • pytorch-lightning: It will handle fine-tuning the model.
  • torchmetrics: It provides metrics for classification and other tasks.
  • easyocr: We'll use it to run OCR on the document images.

Let's add all imports that we'll use:

python code
from transformers import LayoutLMv3FeatureExtractor, LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForSequenceClassification
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from sklearn.model_selection import train_test_split
import imgkit
import easyocr
import torchvision.transforms as T
from pathlib import Path
import matplotlib.pyplot as plt
import os
import cv2
from typing import List
import json
from torchmetrics import Accuracy
from huggingface_hub import notebook_login
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
 
%matplotlib inline
pl.seed_everything(42)

The last line sets the seed for PyTorch Lightning to 42. Setting a seed ensures that the random number generator used by PyTorch Lightning (and the underlying PyTorch framework) produces the same sequence of random numbers each time the code is run.

Data

The data is from Kaggle - Financial Documents Clustering[2]. It contains HTML documents (tables) from the publicly available Hexaware Technologies financial annual reports[3]. It has 5 categories:

  • Income Statements (317 files)
  • Balance Sheets (282 files)
  • Cash Flows (36 files)
  • Notes (702 files)
  • Others (1236 files)

Download and extract an exact copy of the Kaggle files from my Google Drive:

py code
!gdown 1tMZXonmajLPK9zhZ2dt-CdzRTs5YfHy0
!unzip -q financial-documents.zip
!mv "TableClassifierQuaterlyWithNotes" "documents"

Convert HTML to Images

The documents are in HTML format, which is not usable for our model. We'll convert them to images.

First, let's change the folder names to "snake_case":

python code
for dir in Path("documents").glob("*"):
  dir.rename(str(dir).lower().replace(" ", "_"))
 
list(Path("documents").glob("*"))
[PosixPath('documents/notes'),
 PosixPath('documents/cash_flow'),
 PosixPath('documents/balance_sheets'),
 PosixPath('documents/income_statement'),
 PosixPath('documents/others')]

We need a directory for the converted images, with a subdirectory for each document class:

py code
for dir in Path("documents").glob("*"):
    image_dir = Path(f"images/{dir.name}")
    image_dir.mkdir(exist_ok=True, parents=True)

To convert the HTML files to images, we'll be utilizing the imgkit package:

py code
def convert_html_to_image(file_path: Path, images_dir: Path, scale: float = 1.0) -> Path:
    file_name = file_path.with_suffix(".jpg").name
    save_path = images_dir / file_path.parent.name / f"{file_name}"
    imgkit.from_file(str(file_path), save_path, options={'quiet': '', 'format': 'jpeg'})
 
    image = Image.open(save_path)
    width, height = image.size
    image = image.resize((int(width * scale), int(height * scale)))
    image.save(str(save_path))
 
    return save_path
py code
document_paths = list(Path("documents").glob("*/*"))
 
for doc_path in tqdm(document_paths):
    convert_html_to_image(doc_path, Path("images"), scale=0.8)

Let's look at a sample document image:

py code
image_paths = sorted(list(Path("images").glob("*/*.jpg")))
 
image = Image.open(image_paths[0]).convert("RGB")
width, height = image.size
image
Sample document

EasyOCR

EasyOCR is a Python library for optical character recognition (OCR), which is the process of extracting text from images. EasyOCR uses deep learning models to recognize text and can handle a wide range of font styles, sizes, and orientations.

py code
reader = easyocr.Reader(['en'])

We'll feed our sample document into the EasyOCR reader and see what it detects:

py code
image_path = image_paths[0]
ocr_result = reader.readtext(str(image_path))

Each entry in ocr_result has the following format:

(text box corner coordinates [[x, y], ...], text, confidence)

Here's the first row from the result:

([[279, 13], [327, 13], [327, 27], [279, 27]], 'In lacs)', 0.46634036192148154)
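Each entry also carries a confidence score, so you could filter out weak detections if you wanted to (an optional step we don't take here; the 0.3 threshold below is arbitrary):

py code
# Keep only reasonably confident detections (threshold chosen arbitrarily).
confident = [(bbox, word, conf) for bbox, word, conf in ocr_result if conf > 0.3]
len(ocr_result), len(confident)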

We'll examine the OCR output overlaid on top of the document image:

py code
def create_bounding_box(bbox_data):
    xs = []
    ys = []
    for x, y in bbox_data:
        xs.append(x)
        ys.append(y)
 
    left = int(min(xs))
    top = int(min(ys))
    right = int(max(xs))
    bottom = int(max(ys))
 
    return [left, top, right, bottom]
 
font_path = Path(cv2.__path__[0]) / "qt/fonts/DejaVuSansCondensed.ttf"
font = ImageFont.truetype(str(font_path), size=12)
 
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(28, 28))
 
left_image = Image.open(image_path).convert("RGB")
right_image = Image.new("RGB", left_image.size, (255, 255, 255))
 
left_draw = ImageDraw.Draw(left_image)
right_draw = ImageDraw.Draw(right_image)
 
for i, (bbox, word, confidence) in enumerate(ocr_result):
    box = create_bounding_box(bbox)
 
    left_draw.rectangle(box, outline="blue", width=2)
    left, top, right, bottom = box
 
    left_draw.text((right + 5, top), text=str(i + 1), fill="red", font=font)
    right_draw.text((left, top), text=word, fill="black", font=font)
 
ax1.imshow(left_image)
ax2.imshow(right_image)
ax1.axis("off");
ax2.axis("off");
Document OCR

We define a helper function create_bounding_box() that takes the four corner coordinates of a text box, finds the minimum and maximum x and y values, and returns the resulting bounding box as a list in the format [left, top, right, bottom].
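For example, applied to the corner points of the first OCR entry shown earlier, it collapses them into a single rectangle:

py code
create_bounding_box([[279, 13], [327, 13], [327, 27], [279, 27]])
[279, 13, 327, 27]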

We can now run OCR on each image and save the results in JSON files:

py code
for image_path in tqdm(image_paths):
    ocr_result = reader.readtext(str(image_path), batch_size=16)
 
    ocr_page = []
    for bbox, word, confidence in ocr_result:
        ocr_page.append({
            "word": word, "bounding_box": create_bounding_box(bbox)
        })
 
    with image_path.with_suffix(".json").open("w") as f:
        json.dump(ocr_page, f)

LayoutLMv3

LayoutLMv3[4] is a state-of-the-art pre-trained language model developed by Microsoft Research Asia. It is designed to handle document analysis tasks that require understanding of both text and layout information, such as document classification, information extraction, and question answering.

The model is built on top of the transformer architecture and pre-trained on large collections of document images and their text. LayoutLMv3 encodes both the textual content and the visual layout of a document, allowing it to deliver strong performance on document analysis tasks.

You can use LayoutLMv3 for various tasks, such as document classification, named entity recognition, and question answering. To apply it, you fine-tune the pre-trained model with a relatively small amount of task-specific data. Hugging Face Transformers and PyTorch provide easy-to-use APIs that will allow us to fine-tune LayoutLMv3 for document classification.

Preprocessing

LayoutLMv3 takes text, bounding boxes, and images as input. To prepare all three, we can use the LayoutLMv3Processor, which combines OCR, tokenization, and image preprocessing:

py code
feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
tokenizer = LayoutLMv3TokenizerFast.from_pretrained(
    "microsoft/layoutlmv3-base"
)
processor = LayoutLMv3Processor(feature_extractor, tokenizer)

The LayoutLMv3FeatureExtractor uses Tesseract OCR by default. However, Tesseract was very slow during my experiments, so we'll supply the results of a custom OCR engine (EasyOCR) instead. Consider Google Cloud Vision or Amazon Textract if you require a faster and more accurate OCR solution.
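For reference, this is roughly what the built-in path would look like; it assumes the tesseract-ocr binary and the pytesseract package are installed in the runtime (they are not installed by this notebook):

py code
# Not used in this tutorial: with apply_ocr=True the feature extractor runs
# Tesseract itself, so the processor only needs the image.
feature_extractor_ocr = LayoutLMv3FeatureExtractor(apply_ocr=True)
processor_ocr = LayoutLMv3Processor(feature_extractor_ocr, tokenizer)
encoding_ocr = processor_ocr(image, return_tensors="pt")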

We'll apply the processor to the sample document. LayoutLMv3 requires that each bounding box be normalized to be on a 0-1000 scale. We'll need the image width and height scale for that:

py code
image_path = image_paths[0]
image = Image.open(image_path).convert("RGB")
width, height = image.size
 
width_scale = 1000 / width
height_scale = 1000 / height

Next, we'll take the OCR and extract words and bounding boxes:

py code
def scale_bounding_box(box: List[int], width_scale: float = 1.0, height_scale: float = 1.0) -> List[int]:
    return [
        int(box[0] * width_scale),
        int(box[1] * height_scale),
        int(box[2] * width_scale),
        int(box[3] * height_scale)
    ]
 
json_path = image_path.with_suffix(".json")
with json_path.open("r") as f:
    ocr_result = json.load(f)
 
words = []
boxes = []
for row in ocr_result:
    boxes.append(scale_bounding_box(row["bounding_box"], width_scale, height_scale))
    words.append(row["word"])
 
len(words), len(boxes)
(174, 174)

We define the function scale_bounding_box() to apply the image scale to each bounding box. Next, we iterate over each row of the OCR results stored in ocr_result, extract the bounding box coordinates and word text for each recognized text region, and scale the bounding box coordinates using scale_bounding_box().
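As a quick sanity check (assuming the same sample box and image size shown earlier, 819 x 1195 pixels), the first box lands on the 0-1000 grid like this:

py code
# First OCR box of the sample document, scaled to the 0-1000 coordinate grid.
scale_bounding_box([279, 13, 327, 27], width_scale, height_scale)
[340, 10, 399, 22]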

py code
encoding = processor(
    image,
    words,
    boxes=boxes,
    max_length=512,
    padding="max_length",
    truncation=True,
    return_tensors="pt"
)
 
print(f"""
input_ids:  {list(encoding["input_ids"].squeeze().shape)}
word boxes: {list(encoding["bbox"].squeeze().shape)}
image data: {list(encoding["pixel_values"].squeeze().shape)}
image size: {image.size}
""")
input_ids:  [512]
word boxes: [512, 4]
image data: [3, 224, 224]
image size: (819, 1195)

We have three pieces of information: input_ids from the tokenizer, bbox for the bounding boxes, and pixel_values for the image. Let's have a look at the encoded image:

py code
image_data = encoding["pixel_values"][0]
transform = T.ToPILImage()
transform(image_data)
Encoded Document Image

The image encoding is a 3-dimensional tensor of shape (channels, height, width), resized to 224 x 224. We convert it to a PIL image using torchvision's ToPILImage transform.
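Note that the processor also normalizes the pixel values, so the rendered crop won't match the original colors exactly. If you want a more faithful preview, you can undo the normalization using the mean and std stored on the feature extractor (a small optional sketch):

py code
# Undo the feature extractor's normalization before rendering the 224x224 crop.
mean = torch.tensor(feature_extractor.image_mean).view(3, 1, 1)
std = torch.tensor(feature_extractor.image_std).view(3, 1, 1)
transform((image_data * std + mean).clamp(0, 1))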

Model

Let's create an instance of LayoutLMv3:

py code
model = LayoutLMv3ForSequenceClassification.from_pretrained(
    "microsoft/layoutlmv3-base", num_labels=2
)

The sequence classification model is loaded from the microsoft/layoutlmv3-base checkpoint. We set num_labels to 2 here only to illustrate the API with a binary classification head; when we train for real, we'll pass the actual number of document classes.

We can run the encoded document through the model and look at the predictions:

py code
outputs = model(**encoding)
outputs.logits
tensor([[0.0644, 0.2629]], grad_fn=<AddmmBackward0>)

Naturally, our model is untrained and lacks the ability to comprehend the documents in our dataset. Let's train it!

Training

To fine-tune LayoutLMv3, we will utilize PyTorch Lightning. This is what we'll do:

  • Split the data into training and testing subsets
  • Create a PyTorch Dataset
  • Generate data loaders
  • Define a LightningModule
  • Use the Trainer from PyTorch Lightning to train our model

Let's start by preparing the data:

py code
train_images, test_images = train_test_split(image_paths, test_size=.2)
DOCUMENT_CLASSES = sorted(list(map(
    lambda p: p.name,
    Path("images").glob("*")
)))
DOCUMENT_CLASSES
[
    'balance_sheets',
    'cash_flow',
    'income_statement',
    'notes',
    'others'
]

First, we split the document images into train and test subsets. Next, we extract the document classes from the document image directory names. This allows us to create a mapping from document image to its class.
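Because the classes are heavily imbalanced, you may prefer a stratified split so both subsets keep roughly the same class proportions. This is optional and not what the split above does; a minimal sketch:

py code
# Optional: stratify on the class (parent directory) name.
split_labels = [path.parent.name for path in image_paths]
train_images, test_images = train_test_split(
    image_paths, test_size=0.2, stratify=split_labels, random_state=42
)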

We have everything needed to create a PyTorch Dataset:

py code
class DocumentClassificationDataset(Dataset):
 
    def __init__(self, image_paths, processor):
        self.image_paths = image_paths
        self.processor = processor
 
    def __len__(self):
        return len(self.image_paths)
 
    def __getitem__(self, item):
 
        image_path = self.image_paths[item]
        json_path = image_path.with_suffix(".json")
        with json_path.open("r") as f:
            ocr_result = json.load(f)
 
            with Image.open(image_path).convert("RGB") as image:
 
                width, height = image.size
                width_scale = 1000 / width
                height_scale = 1000 / height
 
                words = []
                boxes = []
                for row in ocr_result:
                    boxes.append(scale_bounding_box(
                        row["bounding_box"],
                        width_scale,
                        height_scale
                    ))
                    words.append(row["word"])
 
                encoding = self.processor(
                    image,
                    words,
                    boxes=boxes,
                    max_length=512,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt"
                )
 
        label = DOCUMENT_CLASSES.index(image_path.parent.name)
 
        return dict(
            input_ids=encoding["input_ids"].flatten(),
            attention_mask=encoding["attention_mask"].flatten(),
            bbox=encoding["bbox"].flatten(end_dim=1),
            pixel_values=encoding["pixel_values"].flatten(end_dim=1),
            labels=torch.tensor(label, dtype=torch.long)
        )

The class takes two arguments:

  • image_paths: a list of paths to document images
  • processor: an instance of the LayoutLMv3Processor class

The __len__ method returns the number of images in the dataset, and the __getitem__ method loads and preprocesses the image and OCR results at a given index.

We can now create datasets and data loaders for the train and test documents:

py code
train_dataset = DocumentClassificationDataset(train_images, processor)
test_dataset = DocumentClassificationDataset(test_images, processor)
 
train_data_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=2
)
 
test_data_loader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=2
)
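Before building the training loop, it's worth pulling a single batch to confirm the tensor shapes are what LayoutLMv3 expects:

py code
batch = next(iter(train_data_loader))
for key, value in batch.items():
    print(key, list(value.shape))
# input_ids      [8, 512]
# attention_mask [8, 512]
# bbox           [8, 512, 4]
# pixel_values   [8, 3, 224, 224]
# labels         [8]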

Let's implement a LightningModule using PyTorch Lightning. This will wrap all the components and allow us to train our model:

py code
class ModelModule(pl.LightningModule):
    def __init__(self, n_classes:int):
        super().__init__()
        self.model = LayoutLMv3ForSequenceClassification.from_pretrained(
            "microsoft/layoutlmv3-base",
            num_labels=n_classes
        )
        self.model.config.id2label = {k: v for k, v in enumerate(DOCUMENT_CLASSES)}
        self.model.config.label2id = {v: k for k, v in enumerate(DOCUMENT_CLASSES)}
        self.train_accuracy = Accuracy(task="multiclass", num_classes=n_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=n_classes)
 
    def forward(self, input_ids, attention_mask, bbox, pixel_values, labels=None):
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            bbox=bbox,
            pixel_values=pixel_values,
            labels=labels
        )
 
    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        bbox = batch["bbox"]
        pixel_values = batch["pixel_values"]
        labels = batch["labels"]
        output = self(input_ids, attention_mask, bbox, pixel_values, labels)
        self.log("train_loss", output.loss)
        self.log(
            "train_acc",
            self.train_accuracy(output.logits, labels),
            on_step=True,
            on_epoch=True
        )
        return output.loss
 
    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        bbox = batch["bbox"]
        pixel_values = batch["pixel_values"]
        labels = batch["labels"]
        output = self(input_ids, attention_mask, bbox, pixel_values, labels)
        self.log("val_loss", output.loss)
        self.log(
            "val_acc",
            self.val_accuracy(output.logits, labels),
            on_step=False,
            on_epoch=True
        )
        return output.loss
 
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)
        return optimizer

The __init__ method initializes the LayoutLMv3 model for sequence classification with a specified number of classes, and sets up the accuracy metric for both training and validation.

The forward method takes input tensors (input_ids, attention_mask, bbox, and pixel_values) and an optional labels tensor (only used during training), and returns the model output.

The training_step and validation_step methods define the training and validation steps respectively. In each method, the input tensors are passed through the model, and the loss and accuracy are logged. The configure_optimizers method defines an Adam optimizer used for training.
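Since the dataset is imbalanced, one optional improvement (not part of the module above) is to weight the loss by inverse class frequency instead of using the loss returned by the model. A minimal sketch, assuming you compute the weights from the training split:

py code
# Optional sketch: class-weighted cross-entropy for the imbalanced dataset.
# Create the loss in ModelModule.__init__ (so Lightning moves the weight tensor
# to the right device) and apply it to output.logits in training_step.
import torch.nn as nn
from collections import Counter

counts = Counter(path.parent.name for path in train_images)
class_weights = torch.tensor(
    [1.0 / counts[name] for name in DOCUMENT_CLASSES], dtype=torch.float
)
weighted_loss_fn = nn.CrossEntropyLoss(weight=class_weights / class_weights.sum())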

Let's create an instance of our ModelModule:

py code
model_module = ModelModule(len(DOCUMENT_CLASSES))

We'll use Tensorboard to track the training progress:

py code
%load_ext tensorboard
%tensorboard --logdir lightning_logs

Finally, we need to set up the PyTorch Lightning Trainer:

py code
model_checkpoint = ModelCheckpoint(
    filename="{epoch}-{step}-{val_loss:.4f}", save_last=True, save_top_k=3, monitor="val_loss", mode="min"
)
 
trainer = pl.Trainer(
    accelerator="gpu",
    precision=16,
    devices=1,
    max_epochs=5,
    callbacks=[
        model_checkpoint
    ],
)

The ModelCheckpoint callback saves the last checkpoint plus the top 3 checkpoints by validation loss, with a filename that includes the epoch number, training step, and validation loss.

The Trainer will use a single GPU with mixed-precision (16-bit) training and train for 5 epochs.
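Optionally, you could also stop training once the validation loss stops improving by adding PyTorch Lightning's EarlyStopping callback (not used in this run):

py code
# Optional: halt training after 2 epochs without val_loss improvement.
from pytorch_lightning.callbacks import EarlyStopping

early_stopping = EarlyStopping(monitor="val_loss", patience=2, mode="min")
# and pass callbacks=[model_checkpoint, early_stopping] to pl.Trainer(...)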

Let's train:

py code
trainer.fit(model_module, train_data_loader, test_data_loader)
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name           | Type                                | Params
-----------------------------------------------------------------------
0 | model          | LayoutLMv3ForSequenceClassification | 125 M
1 | train_accuracy | MulticlassAccuracy                  | 0
2 | val_accuracy   | MulticlassAccuracy                  | 0
-----------------------------------------------------------------------
125 M     Trainable params
0         Non-trainable params
125 M     Total params
251.843   Total estimated model params size (MB)

We can have a look at the training metrics:

Training Metrics Accuracy
Training Metrics Loss

Our model was able to attain a validation accuracy of 94%, and the overall training process appears to be stable. Let's dive a bit deeper into our model performance.

Evaluation

To evaluate the model's performance, we'll begin by loading the best checkpoint and uploading the model to the Hugging Face Hub:

py code
trained_model = ModelModule.load_from_checkpoint(
    model_checkpoint.best_model_path,
    n_classes=len(DOCUMENT_CLASSES),
    local_files_only=True
)
 
notebook_login()
 
trained_model.model.push_to_hub(
    "layoutlmv3-financial-document-classification"
)

Once the model is uploaded, we can easily download it using its name or ID. We will load the model from the Hub and move it to the GPU (if available) for inference:

py code
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
 
model = LayoutLMv3ForSequenceClassification.from_pretrained(
    "curiousily/layoutlmv3-financial-document-classification"
)
model = model.eval().to(DEVICE)

We'll write a function to do inference for a single document image:

py code
def predict_document_image(
    image_path: Path,
    model: LayoutLMv3ForSequenceClassification,
    processor: LayoutLMv3Processor):
 
    json_path = image_path.with_suffix(".json")
    with json_path.open("r") as f:
        ocr_result = json.load(f)
 
        with Image.open(image_path).convert("RGB") as image:
 
            width, height = image.size
            width_scale = 1000 / width
            height_scale = 1000 / height
 
            words = []
            boxes = []
            for row in ocr_result:
                boxes.append(
                    scale_bounding_box(
                        row["bounding_box"],
                        width_scale,
                        height_scale
                    )
                )
                words.append(row["word"])
 
            encoding = processor(
                image,
                words,
                boxes=boxes,
                max_length=512,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            )
 
    with torch.inference_mode():
        output = model(
            input_ids=encoding["input_ids"].to(DEVICE),
            attention_mask=encoding["attention_mask"].to(DEVICE),
            bbox=encoding["bbox"].to(DEVICE),
            pixel_values=encoding["pixel_values"].to(DEVICE)
        )
 
    predicted_class = output.logits.argmax()
    return model.config.id2label[predicted_class.item()]

This function takes an image path as input, loads the saved OCR results, scales the bounding boxes based on the image size, and preprocesses the image and text data using the previously defined processor. The preprocessed data is then moved to the device and passed through the model with gradient computation disabled. Finally, the function returns the predicted class label for the input image.
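For example, running it on a single test document (the exact output depends on your split):

py code
sample_path = test_images[0]
print(sample_path.parent.name, "->", predict_document_image(sample_path, model, processor))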

We can now execute the function on all test documents:

py code
labels = []
predictions = []
for image_path in tqdm(test_images):
    labels.append(image_path.parent.name)
    predictions.append(
        predict_document_image(image_path, model, processor)
    )

Given that the dataset is imbalanced, relying solely on accuracy as the evaluation metric may not provide a complete picture of the model's performance. Therefore, we will use a confusion matrix to gain deeper insights:

py code
cm = confusion_matrix(labels, predictions, labels=DOCUMENT_CLASSES)
cm_display = ConfusionMatrixDisplay(
    confusion_matrix=cm,
    display_labels=DOCUMENT_CLASSES
)
 
cm_display.plot()
cm_display.ax_.set_xticklabels(DOCUMENT_CLASSES, rotation=45)
cm_display.figure_.set_size_inches(16, 8)
 
plt.show();

Confusion Matrix

There is some confusion between the two most represented classes, others and notes. Could you create an improved model that makes more accurate predictions for those?
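For a per-class breakdown of precision, recall, and F1 (often more informative than overall accuracy on imbalanced data), scikit-learn's classification_report complements the confusion matrix:

py code
from sklearn.metrics import classification_report

print(classification_report(labels, predictions, labels=DOCUMENT_CLASSES, zero_division=0))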

Conclusion

You have successfully built a LayoutLMv3-based model using PyTorch Lightning to classify document images into multiple categories! We demonstrated how to handle an imbalanced dataset, train and evaluate the model, and perform inference on new images. PyTorch Lightning provided an efficient way to train the model and evaluate its performance, while LayoutLMv3 allowed us to fine-tune a state-of-the-art pre-trained model on our specific task.

Overall, this tutorial shows how PyTorch Lightning and LayoutLMv3 can be used together to build a powerful document classification model, even in the face of an imbalanced dataset.

References

Footnotes

  1. wkhtmltopdf - tool to render HTML into PDF/images

  2. Financial Documents Clustering

  3. Hexaware Technologies financial annual reports

  4. LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking