Chapter 6: Practical aspects of Image classification

Generating CAMs

Class activation maps (CAMs) allow us to understand model predictions
The strategy is to compute gradients at a given layer for a chosen class (like cat)
Implementation
Choose class and layer to investigate
Generate activations for the layer, with dimensions (C, H, W), let's say (64, 4, 4)
Back-propagate on the score of the chosen class and keep the gradient of this layer's convolution weights, of dimensions (64, 512, 3, 3)
Average this gradient over all dimensions except the output channels, output dimension is (64)
Multiply each activation map by the average gradient of its channel, dimensions are (64, 4, 4)
Create a heatmap by averaging each weighted activation, output dimensions are (4, 4)
Upscale this heatmap to the image size
Transform
```python
from torchvision import transforms as T

compose_train = T.Compose([
    T.ToPILImage(),
    T.Resize(128),
    T.CenterCrop(128),
    T.ColorJitter(
        brightness=(.95, 1.05),
        contrast=(.95, 1.05),
        saturation=(.95, 1.05),
        hue=0.05,
    ),
    T.RandomAffine(5, translate=(.01, .1)),
    T.ToTensor(),
    T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
])

compose_val = T.Compose([
    T.ToPILImage(),
    T.Resize(128),
    T.CenterCrop(128),
    T.ToTensor(),
    T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
])
```
Data augmentation for the training images allows us to diversify our training set (more on that in the next section)
Both the training and validation transforms apply image resizing, cropping, and normalization
We need to convert the tensor to a PIL image first, before converting it back to a tensor at the end
Normalize needs to be applied to tensors only
Dataset
```python
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import Dataset

device = "cuda" if torch.cuda.is_available() else "cpu"

str2int = {"Uninfected": 1, "Parasitized": 0}
int2str = {1: "Uninfected", 0: "Parasitized"}

class MalariaDataset(Dataset):
    def __init__(self, files, transform):
        super().__init__()
        self.files = files
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file_path = self.files[idx]
        img = cv2.imread(file_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.uint8)
        label = file_path.split("/")[-2]
        return img, label

    def collate_fn(self, buffer):
        _imgs, _classes = list(zip(*buffer))
        imgs = [self.transform(img)[None] for img in _imgs]
        classes = [torch.tensor([str2int[_class]]) for _class in _classes]
        imgs = torch.cat(imgs).to(device)
        classes = torch.cat(classes).to(device)
        return imgs, classes, _imgs

    def load_img(self, idx):
        file_path = self.files[idx]
        img = cv2.imread(file_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        label = str2int[file_path.split("/")[-2]]
        plt.imshow(img)
        plt.title(label)
```
__init__ takes a transformation and a list of files as inputs
__getitem__ applies no transformation; it just loads the image and gets the label. The image needs to be of uint8 type for the later transforms
collate_fn calls the transform in a loop, before concatenating the image and label tensors
Dataloader
```python
from glob import glob
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

files = glob("cell_images/Parasitized/*.png") + glob("cell_images/Uninfected/*.png")
files_train, files_test = train_test_split(files)

ds_train = MalariaDataset(files_train, compose_train)
ds_test = MalariaDataset(files_test, compose_val)
dl_train = DataLoader(ds_train, batch_size=32, shuffle=True, collate_fn=ds_train.collate_fn)
dl_test = DataLoader(ds_test, batch_size=32, shuffle=False, collate_fn=ds_test.collate_fn)
```
Load an image for visualisation
```python
ds_test.load_img(3)
```
Here the cell is healthy (1 means Uninfected)
Model builder
```python
import torch.nn as nn

def conv_layer(ni, no):
    return nn.Sequential(
        nn.Dropout(0.2),
        nn.Conv2d(ni, no, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(no),
        nn.MaxPool2d(2),
    )

class MalariaClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            conv_layer(3, 64),     # output dim: (N, 64, 64, 64)
            conv_layer(64, 64),    # (N, 64, 32, 32)
            conv_layer(64, 128),   # (N, 128, 16, 16)
            conv_layer(128, 256),  # (N, 256, 8, 8)
            conv_layer(256, 512),  # (N, 512, 4, 4)
            conv_layer(512, 64),   # (N, 64, 2, 2)
            nn.Flatten(),
            nn.Linear(256, 256),
            nn.Dropout(.2),
            nn.ReLU(),
            nn.Linear(256, len(str2int)),
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def compute_metrics(self, y_hat, y):
        loss = self.loss_fn(y_hat, y)
        acc = (torch.max(y_hat, 1)[1] == y).float().mean()
        return loss, acc
```
We define our CNN block as a function returning an nn.Sequential instance, so that we can nest nn.Sequential objects. The arguments of this function are the numbers of input and output channels
We add a compute_metrics method to return the loss and accuracy in a handy way
The rest of the training process is similar to before; a minimal sketch follows
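For reference, here is a minimal sketch of such a training loop, assuming the Adam optimizer and an arbitrary 5 epochs (neither is prescribed by the chapter):

```python
# Minimal training-loop sketch; Adam and the epoch count are assumptions
model = MalariaClassifier().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(5):
    model.train()
    for x, y, _ in dl_train:
        optimizer.zero_grad()
        loss, acc = model.compute_metrics(model(x), y)
        loss.backward()
        optimizer.step()

    # Validation accuracy at the end of each epoch
    model.eval()
    with torch.no_grad():
        accs = [model.compute_metrics(model(x), y)[1] for x, y, _ in dl_test]
    print(f"epoch {epoch}: val acc {torch.stack(accs).mean().item():.3f}")
```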
Let’s now focus on the heatmap
```python
def pred2fmap(x, model):
    # Truncated model ending at the last Conv2d (Dropout + Conv2d of the 6th block)
    layers = nn.Sequential(*(
        list(model.model[:5].children())
        + list(model.model[5][:2].children())
    ))
    model.eval()
    logits = model(x)
    idx_pred = logits.max(-1)[-1]
    model.zero_grad()
    logits[0, idx_pred].backward(retain_graph=True)
    activations = layers(x)  # shape: (1, 64, 4, 4)
    # Average the last Conv2d's weight gradient over everything but the output channels
    pooled_grads = model.model[-6][1].weight.grad.data.mean((1, 2, 3))  # shape: (64)
    for idx in range(activations.shape[1]):
        activations[:, idx, :, :] *= pooled_grads[idx]
    heatmap = torch.mean(activations, dim=1)[0].cpu().detach()  # shape: (4, 4)
    value = int2str[idx_pred.item()]
    return heatmap, value
```
We choose to observe the last CNN layer, so we create another model that stops at this last Conv2d block
We get the logits as the output of the full model, and select the index of the class with the maximum probability, idx_pred
We zero the gradients so that we can then back-propagate on this chosen class only. Setting retain_graph to True keeps the computation graph around so it can be reused
We then feed our input through the truncated model to get the activations, and fetch the averaged gradient of the last conv layer from the original model
We then iterate through the 64 output channels of the activations, multiplying each element-wise by the average gradient of its channel (or feature map)
Finally, our heatmap is the average of the 64 feature maps, resulting in a (4, 4) image
We need to upsample this heatmap to the original image size
```python
IM_SIZE = 128

def upsample_img(fmap, img):
    # Min-max normalize the heatmap, then scale to the regular image range (0, 255)
    m, M = fmap.min(), fmap.max()
    fmap = 255 * ((fmap - m) / (M - m))
    fmap = np.uint8(fmap)
    # Upscale to the image size and colorize
    fmap = cv2.resize(fmap, (IM_SIZE, IM_SIZE))
    fmap = cv2.applyColorMap(255 - fmap, cv2.COLORMAP_JET)
    fmap = np.uint8(fmap)
    # Blend the heatmap with the original image
    fmap = np.uint8(0.7 * fmap + 0.3 * img)
    return fmap
```
Apply min-max normalization, so that the values are all between 0 and 1, before multiplying by 255 to bring us back to the regular image range (0, 255)
The color map allows us to better visualise the heatmap
uint8 is the standard format for OpenCV, so we need to convert the float image before using cv2
Run the heatmap
```python
N = 20
_dl_val = DataLoader(ds_test, batch_size=N, shuffle=True, collate_fn=ds_test.collate_fn)
x, y, z = next(iter(_dl_val))

imgs, maps = [], []
for idx in range(N):
    img = cv2.resize(z[idx], (IM_SIZE, IM_SIZE))
    heatmap, pred = pred2fmap(x[idx:idx + 1], model)
    if pred == "Uninfected":
        continue
    heatmap = upsample_img(heatmap, img)
    maps.append(heatmap)
    imgs.append(img)

for jdx in range(len(maps)):
    fig, axes = plt.subplots(1, 2, figsize=(7, 7))
    axes[0].imshow(maps[jdx])
    axes[1].imshow(imgs[jdx])
    plt.show()
```
Create a new dataloader of N elements to be visualised
Fetch the first batch of the dataloader and compute the heatmap for each of the N elements; we only display Parasitized cells
We use x[idx:idx+1] to maintain the input shape for a single element

Regularization: batch norm and data augmentation

The combination of batch norm and data augmentation acts as a regularizer, limiting overfitting
Batch norm also allows the model to converge faster during training, supposedly by removing the internal covariate shift (ICS) between network layers. This initial claim has actually been shown to be wrong: batch norm's real contribution is to make the loss surface smoother (see additional resources below)
Batch norm empirically works better when placed after the activation, as in the sketch below
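For illustration, a sketch of the two orderings (conv_layer above uses the first; the original paper proposed the second):

```python
# Ordering used in conv_layer above: activation, then batch norm
post_activation = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(64),
)

# Ordering from the original batch norm paper: batch norm, then activation
pre_activation = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
)
```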

Practical aspects of implementation

Class imbalance

When a binary class is represented only 1% of the time, a model that always returns the other class will be accurate 99% of the time
A confusion matrix helps to visualize the imbalance
When it comes to training, we can add a weight to the rare class during loss back-propagation, so that the model has more incentive to classify it accurately (see the sketch after this list)
Leveraging pre-trained networks is of great help, as is using data augmentation on the rare classes to boost their representation
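As a minimal sketch, nn.CrossEntropyLoss accepts per-class weights; the 99/1 counts below match the hypothetical imbalance above:

```python
# Weight each class inversely to its frequency (hypothetical 99%/1% split)
class_counts = torch.tensor([99.0, 1.0])
class_weights = class_counts.sum() / (len(class_counts) * class_counts)  # [0.505, 50.0]
loss_fn = nn.CrossEntropyLoss(weight=class_weights.to(device))
```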

Size of an objects within an image

When it comes to classifying small objects within an image, it is often best to consider object detection instead
If classification is still needed, we can split the image into equal-sized tiles and use those as inputs, as sketched below
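A minimal tiling sketch, assuming (H, W, C) numpy images and an arbitrary tile size of 128:

```python
# Split an image into equal-sized, non-overlapping tiles
# (the 128-pixel tile size is an arbitrary choice)
def tile_image(img, tile=128):
    h, w = img.shape[:2]
    return [
        img[i:i + tile, j:j + tile]
        for i in range(0, h - tile + 1, tile)
        for j in range(0, w - tile + 1, tile)
    ]
```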

Difference between training and validation data

When a model performs poorly in a testing environment with new images, possible reasons are:
inference images are not curated like the training data and may not have been preprocessed
differences in camera resolution
differences in lighting conditions
Are the inference subjects representative of the training dataset?
class imbalance, or even a whole new class that the model wasn't trained on
Are the training and validation datasets split correctly?
always shuffle before splitting, to avoid the case where classes are sorted and your splits don't have the same distribution
Always make sure that the training and validation/inference datasets have the same distribution, as in the sketch below
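With scikit-learn this can be enforced at split time; a sketch reusing the files list from above, stratifying on the class name encoded in the file path:

```python
# Shuffle (the default) and stratify so both splits keep the same class proportions
labels = [f.split("/")[-2] for f in files]
files_train, files_test = train_test_split(files, shuffle=True, stratify=labels)
```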

Number of nodes in flatten layer

It is good practice to have at most 500 to 5,000 nodes in the flatten layer, so that the number of parameters in the following dense layer stays limited
This node count is the total number of output values of the preceding layer
e.g. a feature map of dimensions (512, 7, 7) results in 512 * 7 * 7 = 25,088 flattened nodes, which results in a (25,088, 4,096) weight matrix in the first dense layer in the case of VGG16
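The same arithmetic, spelled out:

```python
# VGG16's first dense layer: a (512, 7, 7) feature map flattened into 25,088 nodes
flatten_nodes = 512 * 7 * 7           # 25,088
dense_weights = flatten_nodes * 4096  # 102,760,448 weights in the (25088, 4096) matrix
```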

Image size

There are a few options when working with large images, of dimensions say 1000x2000:
Can we downscale the image? While this will probably work for large objects, it may not be appropriate for fine details such as text
Does the whole image convey information, or can we crop it?
Reduce the batch size so that the data still fits in GPU memory

OpenCV

When it comes to releasing models in production, a simple model with slightly worse accuracy is often preferable to a more complex one
Consider checking whether OpenCV can provide a good baseline for your use case
Rule-based approaches and pattern recognition may be good enough approximations
OpenCV also ships some pre-trained models, like the Haar cascade face detector, as sketched below
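For example, the pre-trained frontal-face Haar cascade bundled with opencv-python (the image path below is hypothetical):

```python
import cv2

# Load the frontal-face Haar cascade that ships with opencv-python
cascade = cv2.CascadeClassifier(
    cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
)
img = cv2.imread("face.jpg")  # hypothetical input image
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
faces = cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5)
```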

Additional Resources

Batch norm and Data augmentation
Batch norm paper
Batch norm doesn’t solve ICS