In this tutorial, we’ll see how to deal with a new type of task using the mid-level API of the fastai library. Our example is a Siamese network, which takes two images and predicts whether they belong to the same class. In particular, we’ll see:

  • how to quickly get DataLoaders from standard PyTorch Datasets
  • how to wrap this in a Transform to get some of fastai’s show features
  • how to add some new behavior to show_batch/show_results for a custom task
  • how to write a custom DataBlock
  • how to create your own model from a pretrained model
  • how to pass along a custom splitter to Learner to take advantage of transfer learning

Preparing the data

To make our data ready for training a model, we need to create a DataLoaders object (note the plural “loaders”).

This is just a wrapper around a training DataLoader and a validation DataLoader.
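To make the “just a wrapper” idea concrete, here is a hypothetical, stripped-down class (TwoLoaders is our own name, not part of fastai) that only stores a training and a validation loader and exposes them as .train and .valid, the same attribute names fastai’s DataLoaders uses. fastai’s real class does much more (device placement, transforms, show_batch, and so on); this is only a sketch of the container idea, using plain lists of batches as stand-ins for real loaders.

```python
from typing import Any, Iterable, List


class TwoLoaders:
    """Hypothetical illustration of the idea behind fastai's DataLoaders.

    It only stores a training loader and a validation loader; it is not
    fastai's implementation.
    """

    def __init__(self, train: Iterable[Any], valid: Iterable[Any]):
        self.loaders: List[Iterable[Any]] = [train, valid]

    def __getitem__(self, i: int) -> Iterable[Any]:
        # dls[0] is the training loader, dls[1] the validation loader
        return self.loaders[i]

    @property
    def train(self) -> Iterable[Any]:
        return self.loaders[0]

    @property
    def valid(self) -> Iterable[Any]:
        return self.loaders[1]


# Plain lists of "batches" stand in for real torch DataLoaders here.
dls = TwoLoaders(train=[[1, 2], [3, 4]], valid=[[5, 6]])
for batch in dls.train:
    print(batch)  # prints [1, 2] then [3, 4]
```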

Purely in PyTorch

To start, we’ll use only PyTorch and PIL to create a Dataset, and then see how to bring it into fastai. The only fastai helper functions we’ll use are untar_data and get_image_files. We’ll use the Oxford-IIIT Pet Dataset.

from fastai.data.external import untar_data,URLs
from fastai.data.transforms import get_image_files

untar_data downloads and decompresses the dataset if needed, then returns a pathlib.Path pointing to its location. In this case, all the images are in the images subfolder:

path = untar_data(URLs.PETS)
files = get_image_files(path/"images")
files[0]
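get_image_files is a fastai helper, so we don’t reimplement it here. As a rough illustration of the general idea (walk a folder and keep the files that look like images), here is a hypothetical stdlib-only helper, list_image_files, with an extension list that is our own assumption rather than fastai’s. It is tried on a throwaway temporary folder.

```python
import tempfile
from pathlib import Path

# Hypothetical helper, NOT fastai's get_image_files: a rough stdlib-only
# approximation that walks a folder recursively and keeps files whose
# extension is in this (assumed) set of image extensions.
IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tif', '.tiff', '.webp'}


def list_image_files(folder):
    folder = Path(folder)
    # rglob('*') visits every entry below `folder`, including subfolders
    return sorted(p for p in folder.rglob('*')
                  if p.is_file() and p.suffix.lower() in IMAGE_EXTS)


# Tiny throwaway folder to try it on
with tempfile.TemporaryDirectory() as d:
    images = Path(d) / 'images'
    (images / 'sub').mkdir(parents=True)
    for name in ['Abyssinian_1.jpg', 'notes.txt', 'sub/beagle_2.PNG']:
        (images / name).write_bytes(b'')
    found_names = [p.name for p in list_image_files(images)]

print(found_names)  # ['Abyssinian_1.jpg', 'beagle_2.PNG']
```

Lower-casing the suffix means files such as beagle_2.PNG are still picked up, while non-image files like notes.txt are skipped.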

Let’s import PIL and open the first image:

import PIL.Image

img = PIL.Image.open(files[0])
img

Now, let’s wrap all of our standard preprocessing into a single helper function:

import torch
import numpy as np

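def open_image(fname, size=224):
    # Load the image and force 3 channels (some images may be grayscale or RGBA)
    img = PIL.Image.open(fname).convert('RGB')
    # Resize to a fixed square so images can be batched together
    img = img.resize((size, size))
    # HWC uint8 array -> tensor
    t = torch.from_numpy(np.array(img))
    # Channels first (CHW), as PyTorch models expect, scaled to [0, 1]
    return t.permute(2,0,1).float()/255.0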

open_image(files[0]).shape
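The permute(2,0,1) call is what turns the height × width × channels layout that np.array produces for a PIL image into the channels-first layout PyTorch expects. To see this without an image on disk, here is the same reordering applied with NumPy to a synthetic 224×224 RGB array (our own stand-in for a loaded image) whose red channel is filled with 255.

```python
import numpy as np

# Synthetic 224x224 RGB image in the HWC uint8 layout that np.array(PIL_image) gives.
hwc = np.zeros((224, 224, 3), dtype=np.uint8)
hwc[..., 0] = 255  # fill the red channel only

# Same reordering as t.permute(2, 0, 1): move the channel axis to the front (CHW),
# then convert to floats in [0, 1] like open_image does.
chw = np.transpose(hwc, (2, 0, 1)).astype(np.float32) / 255.0

print(hwc.shape, '->', chw.shape)  # (224, 224, 3) -> (3, 224, 224)
```

After the transpose, the red channel is chw[0], so indexing the first axis now selects a whole color plane.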

The label of each image is in its filename: it is the part before the last “_” and the trailing number. We can use a regular expression to write a label function:

import re

def label_func(fname):
    # Everything before the final "_<digits>.jpg" is the breed name
    return re.match(r'^(.*)_\d+\.jpg$', fname.name).groups()[0]

label_func(files[0])
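Breed names can themselves contain underscores, so it helps to check that the pattern splits on the last one. Because (.*) is greedy, it keeps as much as possible and only gives back the final “_<digits>.jpg”. The filenames below are made-up examples in the dataset’s naming style, and the pattern escapes the dot before jpg so it matches a literal “.”.

```python
import re
from pathlib import Path


def label_func(fname):
    # Greedy (.*) keeps everything up to the LAST "_" that is followed by digits and ".jpg"
    return re.match(r'^(.*)_\d+\.jpg$', fname.name).groups()[0]


examples = ['Abyssinian_1.jpg', 'great_pyrenees_173.jpg', 'english_cocker_spaniel_12.jpg']
for name in examples:
    print(name, '->', label_func(Path(name)))
# Abyssinian_1.jpg -> Abyssinian
# great_pyrenees_173.jpg -> great_pyrenees
# english_cocker_spaniel_12.jpg -> english_cocker_spaniel
```

Note that re.match returns None for a filename that doesn’t fit the pattern, so label_func would raise an AttributeError on such a file rather than returning a wrong label.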

Next, let’s gather all the unique labels:

labels = list(set(files.map(label_func)))
len(labels)
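set() removes the duplicates, but Python randomizes string hashing per process by default, so the order of list(set(...)) can change from one run to the next. If a stable order matters (for instance, to map each label to an integer index), a minimal variant is to sort the labels. The filenames and the label2idx mapping below are illustrative additions, not part of the original tutorial.

```python
import re
from pathlib import Path


def label_func(fname):
    return re.match(r'^(.*)_\d+\.jpg$', fname.name).groups()[0]


names = ['beagle_3.jpg', 'Abyssinian_1.jpg', 'beagle_1.jpg', 'Bengal_7.jpg']
files = [Path(n) for n in names]

# sorted() gives the same order on every run, unlike iterating over a set
labels = sorted(set(map(label_func, files)))
label2idx = {l: i for i, l in enumerate(labels)}

print(labels)  # ['Abyssinian', 'Bengal', 'beagle']
```

Python’s default string ordering places uppercase letters before lowercase ones, which is why “Bengal” comes before “beagle” here.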

This tells us there are 37 different pet breeds. To build our Siamese dataset, each input will be a tuple of two images, and the target will be True if both images belong to the same class and False otherwise.

lbl2files = {l: [f for f in files if label_func(f) == l] for l in labels}
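The lbl2files dictionary makes it cheap to pick a second image once we know which class we want. Below is a minimal sketch of one possible pairing strategy, assuming a 50/50 split between same-class and different-class pairs; that split, and the helpers group_by_label and sample_pair, are our own hypothetical choices, not necessarily what the rest of this tutorial does. group_by_label builds the same mapping as lbl2files but in a single pass over the files.

```python
import random
import re
from collections import defaultdict
from pathlib import Path


def label_func(fname):
    return re.match(r'^(.*)_\d+\.jpg$', fname.name).groups()[0]


def group_by_label(files):
    # Hypothetical helper: same mapping as lbl2files, built in one pass over files
    groups = defaultdict(list)
    for f in files:
        groups[label_func(f)].append(f)
    return dict(groups)


def sample_pair(f, lbl2files, rng):
    # Hypothetical strategy (an assumption): with probability 0.5 pick the second
    # image from the same class, otherwise from a different, randomly chosen class.
    # The "different" branch requires at least two classes in lbl2files.
    label = label_func(f)
    same = rng.random() < 0.5
    if same:
        other_label = label
    else:
        other_label = rng.choice([l for l in lbl2files if l != label])
    f2 = rng.choice(lbl2files[other_label])
    # A "same" pair may contain f twice, since f itself is in lbl2files[label]
    return (f, f2), same


files = [Path(n) for n in ['Abyssinian_1.jpg', 'Abyssinian_2.jpg', 'beagle_1.jpg', 'beagle_2.jpg']]
lbl2files = group_by_label(files)
rng = random.Random(42)  # seeded so the sampled pairs are reproducible
(img1, img2), target = sample_pair(files[0], lbl2files, rng)
print(img1.name, img2.name, target)
```

Passing an explicit random.Random instance keeps the sampling reproducible; to rule out pairs that use the same file twice, you could filter f out of the same-class candidates.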