Document Classification with LayoutLMv3

In this tutorial, we will explore document classification using text, layout information, and image content. We will use LayoutLMv3, a state-of-the-art model for this task, together with PyTorch Lightning, a lightweight PyTorch wrapper for high-performance training.

We will start by preparing the dataset and data loaders, then build and train the model. Next, we will evaluate its performance and analyze the results with a confusion matrix, which shows the specific classes where the model still struggles. By the end of this tutorial, you will have a good understanding of how to use LayoutLMv3 for document classification and how to train and evaluate deep learning models with PyTorch Lightning.

In this tutorial, we will use a Jupyter Notebook to run the code. If you prefer to follow along, you can access the notebook here: open the notebook

Notebook Setup

We will begin by installing wkhtmltopdf [1], whose wkhtmltoimage tool (used by imgkit) converts HTML files into images:

%%bash
wget -q https://github.com/wkhtmltopdf/packaging/releases/download/0.12.6-1/wkhtmltox_0.12.6-1.bionic_amd64.deb
apt -qq install ./wkhtmltox_0.12.6-1.bionic_amd64.deb

Next, install the required Python libraries:

!pip install -qqq transformers==4.27.2 --progress-bar off
!pip install -qqq pytorch-lightning==1.9.4 --progress-bar off
!pip install -qqq torchmetrics==0.11.4 --progress-bar off
!pip install -qqq imgkit==1.2.3 --progress-bar off
!pip install -qqq easyocr==1.6.2 --progress-bar off
!pip install -qqq Pillow==9.4.0 --progress-bar off
!pip install -qqq tensorboardX==2.5.1 --progress-bar off
!pip install -qqq huggingface_hub==0.11.1 --progress-bar off
!pip install -qqq --upgrade --no-cache-dir gdown

The essential libraries for this tutorial are:

  • transformers: We'll be using the implementation of LayoutLMv3 from this library for our model.
  • pytorch-lightning: It handles the training loop, checkpointing, and logging while we fine-tune our model.
  • torchmetrics: This library provides us with various metrics for classification and other tasks.
  • easyocr: We'll be using this library to run OCR on the document images.

Let's add all the imports we'll need:

from transformers import LayoutLMv3FeatureExtractor, LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForSequenceClassification
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from sklearn.model_selection import train_test_split
import imgkit
import easyocr
import torchvision.transforms as T
from pathlib import Path
import matplotlib.pyplot as plt
import os
import cv2
from typing import List
import json
from torchmetrics import Accuracy
from huggingface_hub import notebook_login
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
 
%matplotlib inline
pl.seed_everything(42)

The last line seeds the random number generators of Python, NumPy, and PyTorch with 42. Setting a seed makes runs reproducible: the generators produce the same sequence of random numbers each time the code is run, so steps such as the train/test split give the same result.

Data

The data comes from the Kaggle dataset Financial Documents Clustering [2]. It contains HTML documents (tables) from the publicly available annual financial reports of Hexaware Technologies [3]. There are 5 categories:

  • Income Statements (317 files)
  • Balance Sheets (282 files)
  • Cash Flows (36 files)
  • Notes (702 files)
  • Others (1236 files)

Download and extract an exact copy of the Kaggle files from my Google Drive:

!gdown 1tMZXonmajLPK9zhZ2dt-CdzRTs5YfHy0
!unzip -q financial-documents.zip
!mv "TableClassifierQuaterlyWithNotes" "documents"

Convert HTML to Images

The documents are in HTML format, but LayoutLMv3 works on document images (plus the words and bounding boxes extracted from them), so we'll render each HTML file as an image.

First, let's rename the class folders to snake_case:

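# Materialize the glob first so we don't rename folders while iterating over them
for class_dir in list(Path("documents").glob("*")):
    # Lowercase the folder name and replace spaces with underscores
    class_dir.rename(class_dir.with_name(class_dir.name.lower().replace(" ", "_")))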
 
list(Path("documents").glob("*"))
[PosixPath('documents/notes'),
 PosixPath('documents/cash_flow'),
 PosixPath('documents/balance_sheets'),
 PosixPath('documents/income_statement'),
 PosixPath('documents/others')]

We need an output directory for the converted images, with one subdirectory per document class:

for dir in Path("documents").glob("*"):
    image_dir = Path(f"images/{dir.name}")
    image_dir.mkdir(exist_ok=True, parents=True)

To convert the HTML files to images, we'll be utilizing the imgkit package:

def convert_html_to_image(file_path: Path, images_dir: Path, scale: float = 1.0) -> Path:
    # Output path: <images_dir>/<class_name>/<file_name>.jpg
    save_path = images_dir / file_path.parent.name / file_path.with_suffix(".jpg").name
    # Render the HTML page to a JPEG with wkhtmltoimage (via imgkit)
    imgkit.from_file(str(file_path), str(save_path), options={'quiet': '', 'format': 'jpeg'})

    # Resize the rendered page by the given scale factor and overwrite it
    image = Image.open(save_path)
    width, height = image.size
    image = image.resize((int(width * scale), int(height * scale)))
    image.save(str(save_path))

    return save_path

document_paths = list(Path("documents").glob("*/*"))
 
for doc_path in tqdm(document_paths):
    convert_html_to_image(doc_path, Path("images"), scale=0.8)

Let's look at a sample document image:

image_paths = sorted(list(Path("images").glob("*/*.jpg")))
 
image = Image.open(image_paths[0]).convert("RGB")
width, height = image.size
image

Sample document

EasyOCR

EasyOCR is a Python library for optical character recognition (OCR), which is the process of extracting text from images. EasyOCR uses deep learning models to recognize text and can handle a wide range of font styles, sizes, and orientations.

reader = easyocr.Reader(['en'])

We'll feed our sample document into the EasyOCR reader and see what it detects:

image_path = image_paths[0]
ocr_result = reader.readtext(str(image_path))

Each element of ocr_result has the following format:

(text box corner coordinates as four [x, y] points, text, confidence)

Here's the first row from the result:

([[279, 13], [327, 13], [327, 27], [279, 27]], 'In lacs)', 0.46634036192148154)

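Let's visualize the OCR output: on the left, the detected text boxes are drawn (and numbered) on top of the document image; on the right, the recognized words are placed at their detected positions on a blank page: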

def create_bounding_box(bbox_data):
    xs = []
    ys = []
    for x, y in bbox_data:
        xs.append(x)
        ys.append(y)
 
    left = int(min(xs))
    top = int(min(ys))
    right = int(max(xs))
    bottom = int(max(ys))
 
    return [left, top, right, bottom]
 
font_path = Path(cv2.__path__[0]) / "qt/fonts/DejaVuSansCondensed.ttf"
font = ImageFont.truetype(str(font_path), size=12)
 
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(28, 28))
 
left_image = Image.open(image_path).convert("RGB")
right_image = Image.new("RGB", left_image.size, (255, 255, 255))
 
left_draw = ImageDraw.Draw(left_image)
right_draw = ImageDraw.Draw(right_image)
 
for i, (bbox, word, confidence) in enumerate(ocr_result):
    box = create_bounding_box(bbox)
 
    left_draw.rectangle(box, outline="blue", width=2)
    left, top, right, bottom = box
 
    left_draw.text((right + 5, top), text=str(i + 1), fill="red", font=font)
    right_draw.text((left, top), text=word, fill="black", font=font)
 
ax1.imshow(left_image)
ax2.imshow(right_image)
ax1.axis("off");
ax2.axis("off");

Document OCR

The helper function create_bounding_box() takes the four corner points of a text box, finds the minimum and maximum x and y values, and returns the enclosing box as a list in the format [left, top, right, bottom].
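To make the conversion concrete, here is a compact min/max version applied to the first OCR row shown earlier. quad_to_ltrb is a hypothetical name used only for illustration; it computes the same result as create_bounding_box().

```python
from typing import List, Sequence


def quad_to_ltrb(points: Sequence[Sequence[float]]) -> List[int]:
    """Collapse four [x, y] corner points into [left, top, right, bottom].

    Same idea as create_bounding_box(): the axis-aligned rectangle spanning
    the smallest and largest x and y coordinates of the corners.
    """
    xs = [x for x, _ in points]
    ys = [y for _, y in points]
    return [int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))]


# First row of the sample OCR result: (corners, text, confidence)
bbox, word, confidence = ([[279, 13], [327, 13], [327, 27], [279, 27]], "In lacs)", 0.46634036192148154)
print(quad_to_ltrb(bbox))  # [279, 13, 327, 27]

# For a tilted box, the result is the rectangle enclosing all four corners
print(quad_to_ltrb([[10, 20], [50, 10], [55, 30], [15, 40]]))  # [10, 10, 55, 40]
```

Note that for rotated text this rectangle is larger than the text itself, which is a reasonable trade-off for mostly horizontal financial tables.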

Next, we run OCR on every image and save the words and their bounding boxes (dropping the confidence scores) to a JSON file next to each image:

for image_path in tqdm(image_paths):
    ocr_result = reader.readtext(str(image_path), batch_size=16)
 
    ocr_page = []
    for bbox, word, confidence in ocr_result:
        ocr_page.append({
            "word": word, "bounding_box": create_bounding_box(bbox)
        })
 
    with image_path.with_suffix(".json").open("w") as f:
        json.dump(ocr_page, f)

LayoutLMv3

LayoutLMv3 [4] is a state-of-the-art pre-trained multimodal model developed by Microsoft Research Asia. It is designed for document analysis tasks that require understanding both text and layout, such as document classification, information extraction, and question answering.

The model is built on the Transformer architecture and pre-trained on a large collection of unlabeled scanned document images (IIT-CDIP) with self-supervised objectives: masked language modeling, masked image modeling, and word-patch alignment. It encodes a document's text, the position of each word on the page, and the page image itself, which makes it a strong starting point for document analysis tasks.

To use LayoutLMv3 for a specific task, you fine-tune the pre-trained model on a relatively small amount of task-specific data. Hugging Face Transformers and PyTorch provide easy-to-use APIs that let us do exactly that for document classification.

Preprocessing

LayoutLMv3 takes text, bounding boxes, and images as input. The LayoutLMv3Processor prepares all three: it combines a feature extractor (image resizing and normalization, plus optional OCR) with a tokenizer (which also assigns each token the bounding box of its word):

feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
tokenizer = LayoutLMv3TokenizerFast.from_pretrained(
    "microsoft/layoutlmv3-base"
)
processor = LayoutLMv3Processor(feature_extractor, tokenizer)

By default, LayoutLMv3FeatureExtractor runs Tesseract OCR (apply_ocr=True). However, Tesseract OCR was very slow during my experiments, so we disable it and use a different OCR engine (EasyOCR) instead. If you need faster and more accurate OCR, consider Google Cloud Vision or Amazon Textract.

Let's apply the processor to the sample document. LayoutLMv3 expects each bounding box to be normalized to a 0-1000 scale, so we compute scale factors from the image width and height:

image_path = image_paths[0]
image = Image.open(image_path).convert("RGB")
width, height = image.size
 
width_scale = 1000 / width
height_scale = 1000 / height

Next, we'll load the saved OCR results and extract the words and their scaled bounding boxes:

def scale_bounding_box(box: List[int], width_scale: float = 1.0, height_scale: float = 1.0) -> List[int]:
    return [
        int(box[0] * width_scale),
        int(box[1] * height_scale),
        int(box[2] * width_scale),
        int(box[3] * height_scale)
    ]
 
json_path = image_path.with_suffix(".json")
with json_path.open("r") as f:
    ocr_result = json.load(f)
 
words = []
boxes = []
for row in ocr_result:
    boxes.append(scale_bounding_box(row["bounding_box"], width_scale, height_scale))
    words.append(row["word"])
 
len(words), len(boxes)
(174, 174)

The function scale_bounding_box() applies the image scale factors to a bounding box. We then iterate over the saved OCR results in ocr_result, collect the text of each recognized word, and scale its bounding box coordinates with scale_bounding_box().
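As a worked example, take the first OCR box from the sample page, [279, 13, 327, 27] in pixels, on a page of 819 × 1195 pixels (the image size printed by the processor cell). Each x coordinate is multiplied by 1000 / 819 and each y coordinate by 1000 / 1195, then truncated by int(). The scale_and_clip_bounding_box variant is a hypothetical addition, not part of the tutorial's code: it clamps coordinates into 0-1000 in case an OCR box ever extends past the image border.

```python
from typing import List


def scale_bounding_box(box: List[int], width_scale: float = 1.0, height_scale: float = 1.0) -> List[int]:
    # Same logic as the tutorial: x values use width_scale, y values use height_scale
    return [
        int(box[0] * width_scale),
        int(box[1] * height_scale),
        int(box[2] * width_scale),
        int(box[3] * height_scale),
    ]


def scale_and_clip_bounding_box(box: List[int], width_scale: float, height_scale: float) -> List[int]:
    # Hypothetical variant: clamp every scaled coordinate into [0, 1000]
    return [min(max(v, 0), 1000) for v in scale_bounding_box(box, width_scale, height_scale)]


width, height = 819, 1195  # sample page size in pixels
width_scale, height_scale = 1000 / width, 1000 / height

print(scale_bounding_box([279, 13, 327, 27], width_scale, height_scale))  # [340, 10, 399, 22]
print(scale_and_clip_bounding_box([-3, 5, 830, 1200], width_scale, height_scale))  # [0, 4, 1000, 1000]
```

Whether clipping is ever needed depends on the OCR output; it only guards against out-of-range coordinates.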

encoding = processor(
    image,
    words,
    boxes=boxes,
    max_length=512,
    padding="max_length",
    truncation=True,
    return_tensors="pt"
)
 
print(f"""
input_ids:  {list(encoding["input_ids"].squeeze().shape)}
word boxes: {list(encoding["bbox"].squeeze().shape)}
image data: {list(encoding["pixel_values"].squeeze().shape)}
image size: {image.size}
""")
input_ids:  [512]
word boxes: [512, 4]
image data: [3, 224, 224]
image size: (819, 1195)

The encoding contains input_ids and attention_mask from the tokenizer, bbox with one bounding box per token, and pixel_values with the resized and normalized image. Let's have a look at the encoded image:

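image_data = encoding["pixel_values"][0]
# pixel_values are normalized with the feature extractor's mean and std; undo that for display
mean = torch.tensor(feature_extractor.image_mean).view(3, 1, 1)
std = torch.tensor(feature_extractor.image_std).view(3, 1, 1)
transform = T.ToPILImage()
transform(image_data * std + mean)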

Encoded Document Image

pixel_values is a 3-dimensional tensor of shape (channels, height, width): the page has been resized to 224 × 224 and normalized. The ToPILImage transformation from torchvision converts it back into a PIL image so we can display it.
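On the text side, padding="max_length" and truncation=True are why input_ids always has exactly 512 entries, no matter how many words the page contains. The sketch below is a simplified, hypothetical illustration of that behavior on plain lists; pad_or_truncate is not a Transformers function, and the real tokenizer additionally inserts special tokens and splits words into subword tokens. Sequences longer than max_length are cut, shorter ones are filled with a padding ID, and attention_mask marks which positions hold real tokens.

```python
from typing import List, Tuple


def pad_or_truncate(
    token_ids: List[int],
    token_boxes: List[List[int]],
    max_length: int,
    pad_id: int,
    pad_box: Tuple[int, int, int, int] = (0, 0, 0, 0),
) -> Tuple[List[int], List[List[int]], List[int]]:
    """Hypothetical illustration of padding="max_length" with truncation=True."""
    # Truncation: tokens past max_length are dropped, so text at the end of long pages is lost
    ids = token_ids[:max_length]
    boxes = [list(b) for b in token_boxes[:max_length]]
    # 1 marks a real token, 0 marks padding the model should ignore
    mask = [1] * len(ids)
    n_pad = max_length - len(ids)
    ids += [pad_id] * n_pad
    boxes += [list(pad_box) for _ in range(n_pad)]
    mask += [0] * n_pad
    return ids, boxes, mask


ids, boxes, mask = pad_or_truncate([5, 6, 7], [[1, 1, 2, 2]] * 3, max_length=5, pad_id=0)
print(ids)   # [5, 6, 7, 0, 0]
print(mask)  # [1, 1, 1, 0, 0]

long_ids, _, long_mask = pad_or_truncate(list(range(600)), [[0, 0, 0, 0]] * 600, 512, 0)
print(len(long_ids), sum(long_mask))  # 512 512
```

For a page with many OCR words, this is the main limitation to keep in mind: everything beyond the first 512 tokens is invisible to the model.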

Model

Let's create an instance of LayoutLMv3:

model = LayoutLMv3ForSequenceClassification.from_pretrained(
    "microsoft/layoutlmv3-base", num_labels=2
)

The sequence classification model is loaded from the microsoft/layoutlmv3-base checkpoint. Its classification head is newly (randomly) initialized. num_labels=2 is only a placeholder for this quick check; when training, we'll use the actual number of document classes.

We can run the encoded document through the model and look at the predictions:

outputs = model(**encoding)
outputs.logits
tensor([[0.0644, 0.2629]], grad_fn=<AddmmBackward0>)

As expected, the classification head is untrained, so these logits say nothing meaningful about our documents. Let's train it!
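To see how uninformative these logits are, we can turn them into probabilities with a softmax, the same normalization that torch.softmax(outputs.logits, dim=-1) would apply. A minimal pure-Python sketch using the two values printed above:

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([0.0644, 0.2629])  # logits from the untrained model above
print([round(p, 2) for p in probs])  # [0.45, 0.55]
```

A roughly 45/55 split is close to a coin flip, which is what we expect from a randomly initialized head.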

Training

To fine-tune LayoutLMv3, we will utilize PyTorch Lightning. This is what we'll do:

  • Split the data into training and testing subsets
  • Create a PyTorch Dataset
  • Create data loaders
  • Define a LightningModule
  • Use the Trainer from PyTorch Lightning to train our model

Let's start by preparing the data:

train_images, test_images = train_test_split(image_paths, test_size=.2)
DOCUMENT_CLASSES = sorted(list(map(
    lambda p: p.name,
    Path("images").glob("*")
)))
DOCUMENT_CLASSES
[
    'balance_sheets',
    'cash_flow',
    'income_statement',
    'notes',
    'others'
]

First, we split the document images into train and test subsets. Next, we extract the document classes from the image directory names, which lets us map each document image to its class (the index of its directory name in DOCUMENT_CLASSES).
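The label of a document is simply the position of its directory name in the sorted class list, and the same mapping later goes into the model config as id2label and label2id. Since train_test_split is called without stratify, it is also worth checking how each class is distributed across the two subsets: with only 36 cash-flow documents, a random split may leave few of them in the test set. A small sketch, where class_distribution is a hypothetical helper and the paths are made up to stand in for train_images and test_images:

```python
from collections import Counter
from pathlib import PurePosixPath
from typing import Dict, Iterable

DOCUMENT_CLASSES = ["balance_sheets", "cash_flow", "income_statement", "notes", "others"]

# Integer label = index in the sorted class list (what the Dataset computes)
label2id: Dict[str, int] = {name: i for i, name in enumerate(DOCUMENT_CLASSES)}
id2label: Dict[int, str] = {i: name for name, i in label2id.items()}


def class_distribution(paths: Iterable[PurePosixPath]) -> Counter:
    # The class of a document is the name of its parent directory
    return Counter(p.parent.name for p in paths)


# Made-up paths standing in for train_images / test_images
sample = [
    PurePosixPath("images/notes/a.jpg"),
    PurePosixPath("images/notes/b.jpg"),
    PurePosixPath("images/cash_flow/c.jpg"),
]
print(label2id["cash_flow"], id2label[3])  # 1 notes
print(class_distribution(sample))  # Counter({'notes': 2, 'cash_flow': 1})
```

With the real data, comparing class_distribution(train_images) and class_distribution(test_images) shows whether each class was split roughly 80/20; passing stratify=[p.parent.name for p in image_paths] to train_test_split would enforce that.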

We have everything needed to create a PyTorch Dataset:

class DocumentClassificationDataset(Dataset):
 
    def __init__(self, image_paths, processor):
        self.image_paths = image_paths
        self.processor = processor
 
    def __len__(self):
        return len(self.image_paths)
 
    def __getitem__(self, item):
 
        image_path = self.image_paths[item]
        json_path = image_path.with_suffix(".json")
        with json_path.open("r") as f:
            ocr_result = json.load(f)
 
            with Image.open(image_path).convert("RGB") as image:
 
                width, height = image.size
                width_scale = 1000 / width
                height_scale = 1000 / height
 
                words = []
                boxes = []
                for row in ocr_result:
                    boxes.append(scale_bounding_box(
                        row["bounding_box"],
                        width_scale,
                        height_scale
                    ))
                    words.append(row["word"])
 
                encoding = self.processor(
                    image,
                    words,
                    boxes=boxes,
                    max_length=512,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt"
                )
 
        label = DOCUMENT_CLASSES.index(image_path.parent.name)
 
        return dict(
            input_ids=encoding["input_ids"].flatten(),
            attention_mask=encoding["attention_mask"].flatten(),
            bbox=encoding["bbox"].flatten(end_dim=1),
            pixel_values=encoding["pixel_values"].flatten(end_dim=1),
            labels=torch.tensor(label, dtype=torch.long)
        )

The class takes two arguments:

  • image_paths: a list of paths to document images
  • processor: an instance of the LayoutLMv3Processor class

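The __len__ method returns the number of images in the dataset, and the __getitem__ method loads the image and OCR results at a given index, preprocesses them with the processor, and derives the label from the parent directory name. Because return_tensors="pt" adds a batch dimension of size 1, the flatten() calls remove it, so the DataLoader can stack items into batches of shape [8, 512] (input_ids and attention_mask), [8, 512, 4] (bbox), and [8, 3, 224, 224] (pixel_values).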

We can now create datasets and data loaders for the train and test documents:

train_dataset = DocumentClassificationDataset(train_images, processor)
test_dataset = DocumentClassificationDataset(test_images, processor)
 
train_data_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=2
)
 
test_data_loader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=2
)
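As a rough sanity check on epoch length, assume every one of the 2,573 HTML files listed in the Data section converts to an image. train_test_split with test_size=.2 rounds the test size up, giving 515 test and 2,058 training documents, so with batch_size=8 one epoch has 258 training steps and 65 validation steps. The arithmetic in Python:

```python
import math

# Class counts from the Data section (assumes every file converted successfully)
counts = {"income_statement": 317, "balance_sheets": 282, "cash_flow": 36, "notes": 702, "others": 1236}
n_total = sum(counts.values())

# scikit-learn's train_test_split uses ceil() for a float test_size
n_test = math.ceil(0.2 * n_total)
n_train = n_total - n_test

batch_size = 8
# DataLoader keeps the last, partial batch by default (drop_last=False)
train_steps = math.ceil(n_train / batch_size)
val_steps = math.ceil(n_test / batch_size)

print(n_total, n_train, n_test)  # 2573 2058 515
print(train_steps, val_steps)    # 258 65
print(train_steps * 5)           # 1290 training steps over max_epochs=5
```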

Let's implement a LightningModule using PyTorch Lightning. This will wrap all the components and allow us to train our model:

class ModelModule(pl.LightningModule):
    def __init__(self, n_classes:int):
        super().__init__()
        self.model = LayoutLMv3ForSequenceClassification.from_pretrained(
            "microsoft/layoutlmv3-base",
            num_labels=n_classes
        )
        self.model.config.id2label = {k: v for k, v in enumerate(DOCUMENT_CLASSES)}
        self.model.config.label2id = {v: k for k, v in enumerate(DOCUMENT_CLASSES)}
        self.train_accuracy = Accuracy(task="multiclass", num_classes=n_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=n_classes)
 
    def forward(self, input_ids, attention_mask, bbox, pixel_values, labels=None):
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            bbox=bbox,
            pixel_values=pixel_values,
            labels=labels
        )
 
    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        bbox = batch["bbox"]
        pixel_values = batch["pixel_values"]
        labels = batch["labels"]
        output = self(input_ids, attention_mask, bbox, pixel_values, labels)
        self.log("train_loss", output.loss)
        self.log(
            "train_acc",
            self.train_accuracy(output.logits, labels),
            on_step=True,
            on_epoch=True
        )
        return output.loss
 
    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        bbox = batch["bbox"]
        pixel_values = batch["pixel_values"]
        labels = batch["labels"]
        output = self(input_ids, attention_mask, bbox, pixel_values, labels)
        self.log("val_loss", output.loss)
        self.log(
            "val_acc",
            self.val_accuracy(output.logits, labels),
            on_step=False,
            on_epoch=True
        )
        return output.loss
 
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.00001) #1e-5
        return optimizer

The __init__ method loads the LayoutLMv3 model for sequence classification with the specified number of classes, stores the class names in the model config (id2label and label2id), and sets up separate accuracy metrics for training and validation.

The forward method takes the input tensors (input_ids, attention_mask, bbox, and pixel_values) and an optional labels tensor, and returns the model output. When labels are passed, as in both the training and validation steps, the output also includes the cross-entropy loss.

The training_step and validation_step methods pass each batch through the model and log the loss and accuracy. The configure_optimizers method returns the Adam optimizer (learning rate 1e-5) used for training.

Let's create an instance of our ModelModule:

model_module = ModelModule(len(DOCUMENT_CLASSES))

We'll use TensorBoard to track the training progress:

%load_ext tensorboard
%tensorboard --logdir lightning_logs

Finally, we need to set up the PyTorch Lightning Trainer:

model_checkpoint = ModelCheckpoint(
    filename="{epoch}-{step}-{val_loss:.4f}", save_last=True, save_top_k=3, monitor="val_loss", mode="min"
)
 
trainer = pl.Trainer(
    accelerator="gpu",
    precision=16,
    devices=1,
    max_epochs=5,
    callbacks=[
        model_checkpoint
    ],
)

The ModelCheckpoint callback monitors val_loss and keeps the three checkpoints with the lowest validation loss (save_top_k=3), plus the most recent one (save_last=True). Checkpoint file names include the epoch number, training step, and validation loss.

The Trainer will use a single GPU with 16-bit mixed precision and train for 5 epochs.

Let's train:

trainer.fit(model_module, train_data_loader, test_data_loader)
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name           | Type                                | Params
-----------------------------------------------------------------------
0 | model          | LayoutLMv3ForSequenceClassification | 125 M
1 | train_accuracy | MulticlassAccuracy                  | 0
2 | val_accuracy   | MulticlassAccuracy                  | 0
-----------------------------------------------------------------------
125 M     Trainable params
0         Non-trainable params
125 M     Total params
251.843   Total estimated model params size (MB)

We can have a look at the training metrics:

Training Metrics Accuracy

Training Metrics Loss

Our model was able to attain a validation accuracy of 94%, and the overall training process appears to be stable. Let's dive a bit deeper into our model performance.

Evaluation

To evaluate the model's performance, we'll begin by loading the best checkpoint (the one with the lowest validation loss) and uploading the model to the Hugging Face Hub:

trained_model = ModelModule.load_from_checkpoint(
    model_checkpoint.best_model_path,
    # Required: the __init__ arguments were not saved as hyperparameters
    n_classes=len(DOCUMENT_CLASSES)
)
 
notebook_login()
 
trained_model.model.push_to_hub(
    "layoutlmv3-financial-document-classification"
)

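Once uploaded, the model can be downloaded by its repository ID. push_to_hub() creates the repository under your own account, so replace curiousily with your Hugging Face username below (or keep it to use the author's published model). We'll load the model from the Hub and move it to the GPU, if one is available, for inference: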

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
 
model = LayoutLMv3ForSequenceClassification.from_pretrained(
    "curiousily/layoutlmv3-financial-document-classification"
)
model = model.eval().to(DEVICE)

We'll write a function to do inference for a single document image:

def predict_document_image(
    image_path: Path,
    model: LayoutLMv3ForSequenceClassification,
    processor: LayoutLMv3Processor):
 
    json_path = image_path.with_suffix(".json")
    with json_path.open("r") as f:
        ocr_result = json.load(f)
 
        with Image.open(image_path).convert("RGB") as image:
 
            width, height = image.size
            width_scale = 1000 / width
            height_scale = 1000 / height
 
            words = []
            boxes = []
            for row in ocr_result:
                boxes.append(
                    scale_bounding_box(
                        row["bounding_box"],
                        width_scale,
                        height_scale
                    )
                )
                words.append(row["word"])
 
            encoding = processor(
                image,
                words,
                boxes=boxes,
                max_length=512,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            )
 
    with torch.inference_mode():
        output = model(
            input_ids=encoding["input_ids"].to(DEVICE),
            attention_mask=encoding["attention_mask"].to(DEVICE),
            bbox=encoding["bbox"].to(DEVICE),
            pixel_values=encoding["pixel_values"].to(DEVICE)
        )
 
    predicted_class = output.logits.argmax()
    return model.config.id2label[predicted_class.item()]

This function takes an image path, opens the image, loads its saved OCR results, scales the bounding boxes based on the image size, and preprocesses the image and text with the processor. The encoded inputs are then passed to the model on DEVICE (the GPU, if available), and the function returns the label of the class with the highest logit.

We can now execute the function on all test documents:

labels = []
predictions = []
for image_path in tqdm(test_images):
    labels.append(image_path.parent.name)
    predictions.append(
        predict_document_image(image_path, model, processor)
    )

Given that the dataset is imbalanced, relying solely on accuracy as the evaluation metric may not provide a complete picture of the model's performance. Therefore, we will use a confusion matrix to gain deeper insights:

cm = confusion_matrix(labels, predictions, labels=DOCUMENT_CLASSES)
cm_display = ConfusionMatrixDisplay(
    confusion_matrix=cm,
    display_labels=DOCUMENT_CLASSES
)
 
cm_display.plot()
cm_display.ax_.set_xticklabels(DOCUMENT_CLASSES, rotation=45)
cm_display.figure_.set_size_inches(16, 8)
 
plt.show();

There is some confusion between the two largest classes, others and notes. Can you build an improved model that makes more accurate predictions for them?

Confusion Matrix
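The confusion matrix can also be summarized numerically. Overall accuracy hides how the small classes fare, so per-class recall (each diagonal entry divided by its row total) is a useful complement; scikit-learn's classification_report prints it alongside precision and F1. A dependency-free sketch that works on lists shaped like labels and predictions above (accuracy and per_class_recall are hypothetical helpers, and the toy lists are made up):

```python
from collections import Counter
from typing import Dict, List


def accuracy(labels: List[str], predictions: List[str]) -> float:
    return sum(y == p for y, p in zip(labels, predictions)) / len(labels)


def per_class_recall(labels: List[str], predictions: List[str]) -> Dict[str, float]:
    # Recall for class c = documents correctly predicted as c / all documents whose true class is c
    totals = Counter(labels)
    correct = Counter(y for y, p in zip(labels, predictions) if y == p)
    return {c: correct[c] / n for c, n in sorted(totals.items())}


# Made-up example: 9 of 10 correct overall, yet cash_flow recall is only 0.5
labels = ["others"] * 6 + ["notes"] * 2 + ["cash_flow"] * 2
predictions = ["others"] * 6 + ["notes"] * 2 + ["cash_flow", "others"]
print(accuracy(labels, predictions))          # 0.9
print(per_class_recall(labels, predictions))  # {'cash_flow': 0.5, 'notes': 1.0, 'others': 1.0}
```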

Conclusion

You have successfully built a LayoutLMv3-based model with PyTorch Lightning that classifies document images into multiple categories! We prepared the data (HTML rendering and OCR), trained and evaluated the model, used a confusion matrix to look beyond accuracy on an imbalanced dataset, and ran inference on individual document images. PyTorch Lightning provided an efficient way to train and evaluate the model, while LayoutLMv3 let us fine-tune a state-of-the-art pre-trained model on our specific task.

Overall, this tutorial shows how PyTorch Lightning and LayoutLMv3 can be combined to build an accurate document classifier, and how a confusion matrix reveals where it still falls short on an imbalanced dataset.

References

  1. wkhtmltopdf - tool to render HTML into PDF/images

  2. Financial Documents Clustering

  3. Hexaware Technologies financial annual reports

  4. LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking