In this tutorial, we’ll tackle a new type of task using the middle layer of the fastai library. We’ll build a Siamese network, which takes two images and tries to determine whether they belong to the same class. In particular, we’ll see:
- how to quickly get DataLoaders from standard PyTorch Datasets
- how to adapt this in a Transform to get some of the show features of fastai
- how to add some new behavior to show_batch/show_results for a custom task
- how to write a custom DataBlock
- how to create your own model from a pretrained model
- how to pass along a custom splitter to Learner to take advantage of transfer learning
Preparing the data
To make our data ready for training a model, we need to create a DataLoaders object (note the plural “loaders”). This is just a wrapper around a training DataLoader and a validation DataLoader.
Purely in PyTorch
To start, we’ll use only PyTorch and PIL to create a Dataset and see how to get it inside of fastai. The only helper functions from fastai that we’ll use are untar_data and get_image_files. We’ll use the Oxford-IIIT Pet Dataset.
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import get_image_files
untar_data returns a pathlib.Path object with the location of the decompressed dataset, and in this case, all the images are in a subfolder:
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")
files[0]
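As a small aside on the path/"images" expression above: pathlib overloads the / operator to join path components. The sketch below uses a made-up root directory and filename (assumptions for illustration, not the actual location returned by untar_data) and PurePosixPath so that the result looks the same on every platform.

```python
from pathlib import PurePosixPath

# Hypothetical dataset root, for illustration only
root = PurePosixPath("/data/oxford-iiit-pet")

# "/" joins path components and returns a new path object
images_dir = root / "images"
example = images_dir / "Abyssinian_1.jpg"  # hypothetical filename

print(images_dir)
print(example.name, example.suffix)
```

The name attribute used here is the same one the label function later in this section relies on.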
Let’s import PIL and open the first image:
import PIL.Image  # a bare "import PIL" does not load the Image submodule by itself
img = PIL.Image.open(files[0])
img
Now, let’s wrap all of our standard preprocessing into a single helper function:
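import torch
import numpy as np

def open_image(fname, size=224):
    # Load as RGB and resize to a fixed square so that images can be batched
    img = PIL.Image.open(fname).convert('RGB')
    img = img.resize((size, size))
    t = torch.Tensor(np.array(img))
    # Reorder height x width x channels to channels x height x width, scale to [0, 1]
    return t.permute(2,0,1).float()/255.0

open_image(files[0]).shape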
We can see that the label of our image is in the filename, before the last “_” and a number. We can use a regular expression to create a label function:
import re

def label_func(fname):
    # Capture everything before the final "_<digits>.jpg"; the dot is escaped to match a literal "."
    return re.match(r'^(.*)_\d+\.jpg$', fname.name).groups()[0]

label_func(files[0])
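To see why the pattern stops at the last underscore, here is a minimal standalone sketch run on two made-up filenames that follow the dataset's naming pattern (the filenames are assumptions for illustration). The greedy (.*) group first swallows the whole name, then backtracks just far enough for _\d+\.jpg to match at the end, so breed names that contain underscores stay intact. This version escapes the dot before jpg so that it only matches a literal period.

```python
import re
from pathlib import PurePosixPath

def label_func(fname):
    # Greedy (.*) captures everything up to the LAST "_<digits>.jpg"
    return re.match(r'^(.*)_\d+\.jpg$', fname.name).groups()[0]

# Hypothetical filenames, for illustration only
samples = [
    PurePosixPath("images/great_pyrenees_173.jpg"),
    PurePosixPath("images/Abyssinian_1.jpg"),
]
sample_labels = [label_func(f) for f in samples]
print(sample_labels)
```

Note that re.match returns None for a filename that does not fit the pattern, so calling .groups() on it would raise an AttributeError. That is acceptable here only if every file follows the naming scheme.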
Now, gathering all unique labels:
labels = list(set(files.map(label_func)))
len(labels)
This means that we have 37 different pet breeds. To create our Siamese dataset, we’ll build pairs of images as inputs; the target will be True if the two images are of the same class, and False if they are not. We start by mapping each label to the list of files that carry it:
lbl2files = {l: [f for f in files if label_func(f) == l] for l in labels}
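One way such a label-to-files mapping can be used to build pairs is sketched below. This is a minimal illustration under stated assumptions, not necessarily how the rest of the tutorial does it: the toy lbl2files dictionary, its filenames, and the draw_pair helper are all hypothetical, and only the standard library is used. With probability 0.5 the second image comes from the same class, otherwise from a different, randomly chosen one.

```python
import random

# Toy stand-in for lbl2files: label -> list of (hypothetical) filenames
lbl2files = {
    "beagle": ["beagle_1.jpg", "beagle_2.jpg", "beagle_3.jpg"],
    "pug": ["pug_1.jpg", "pug_2.jpg"],
    "Abyssinian": ["Abyssinian_1.jpg", "Abyssinian_2.jpg"],
}
labels = list(lbl2files)

def draw_pair(fname, label, rng=random):
    """Return (fname, other_fname, same), with a 50% chance of a same-class pair."""
    same = rng.random() < 0.5
    if same:
        other_label = label
    else:
        # Pick any label except the one of the first image
        other_label = rng.choice([l for l in labels if l != label])
    return fname, rng.choice(lbl2files[other_label]), same

rng = random.Random(0)  # fixed seed so that the sketch is reproducible
pairs = [draw_pair("beagle_1.jpg", "beagle", rng) for _ in range(5)]
for first, second, same in pairs:
    print(first, second, same)
```

In this sketch a same-class pair may occasionally contain the same file twice. Whether to exclude that case is a design choice left open here.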