I am working on an image classification model in PyTorch. The setup is as follows:
My training instances are bags of images (one training instance = one bag), where each bag contains a varying number of images. Each bag carries a single label (0 or 1) indicating whether at least one of its images shows a certain property (in my case, a tumor). The objective is to learn a classifier that labels new bags as accurately as possible. However, the following problem occurs when I try to train my model: when I feed the bags one by one into the CNN architecture, the prediction (the probability that the bag has label 1) is instantly either 0 or 1.
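For concreteness, one training instance looks roughly like this (the bag size and names here are illustrative, not my actual data pipeline):

```python
import torch

# One bag = a variable number of images, stored as a single tensor
# with a leading batch dimension of 1. E.g. a bag of 12 tissue patches:
bag = torch.randn(1, 12, 3, 224, 224)   # (batch=1, num_images, C, H, W)
label = torch.tensor([1.0])             # 1 if at least one image is positive

# Inside the model, squeeze(0) turns the bag into a stack of instances:
instances = bag.squeeze(0)              # shape (12, 3, 224, 224)
```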
Now, before I delve into the code itself, which is quite long: I have a near-identical setup where I use digit images instead of tissue images. Each bag contains a number of MNIST-like images (just pictures of digits), and a bag gets a positive label (i.e. 1) if at least one of its images shows the digit 9. Strangely enough, this digit task works very well (the model clearly learns, and good classification performance is obtained in the end), even though the setup is nearly identical to the tissue one.
Below I post the code sections that differ between these two tasks. The differences are essentially the following: the input shapes of the bags differ, since the digit images are much smaller (28x28) with a single channel, while the tissue images are 224x224 with three channels. Therefore the convolutional layers also vary slightly in their specification.
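As a sanity check on the flattened feature sizes used below, here are the two convolutional stacks applied to dummy inputs of the respective sizes (the layer specifications are copied from the models that follow):

```python
import torch
import torch.nn as nn

# Tissue branch: 3-channel 224x224 input
tissue = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=4), nn.ReLU(), nn.MaxPool2d(2, stride=2),
    nn.Conv2d(4, 8, kernel_size=3), nn.ReLU(), nn.MaxPool2d(2, stride=2),
)
# 224 -> 221 (conv k=4) -> 110 (pool) -> 108 (conv k=3) -> 54 (pool)
print(tissue(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 8, 54, 54])

# Digit branch: 1-channel 28x28 input
digit = nn.Sequential(
    nn.Conv2d(1, 10, kernel_size=5), nn.ReLU(), nn.MaxPool2d(2, stride=2),
    nn.Conv2d(10, 20, kernel_size=5), nn.ReLU(), nn.MaxPool2d(2, stride=2),
)
# 28 -> 24 (conv k=5) -> 12 (pool) -> 8 (conv k=5) -> 4 (pool)
print(digit(torch.randn(1, 1, 28, 28)).shape)  # torch.Size([1, 20, 4, 4])
```

So the `8 * 54 * 54` and `20 * 4 * 4` flatten sizes in the two models are consistent with the input resolutions.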
The first code section is for the tissue images; this is the model that won't learn:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()
        self.L = 500
        self.D = 128
        self.K = 1
        self.feature_extractor_part1 = nn.Sequential(
            nn.Conv2d(3, 4, kernel_size=4),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(4, 8, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2)
        )
        self.feature_extractor_part2 = nn.Sequential(
            nn.Linear(8 * 54 * 54, self.L),
            nn.ReLU(),
        )
        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K)
        )
        self.classifier = nn.Sequential(
            nn.Linear(self.L * self.K, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.squeeze(0)                      # (1, N, C, H, W) -> (N, C, H, W)
        H = self.feature_extractor_part1(x)
        H = H.view(-1, 8 * 54 * 54)           # flatten per-instance features
        H = self.feature_extractor_part2(H)
        A = self.attention(H)                 # (N, K) attention scores
        A = torch.transpose(A, 1, 0)          # (K, N)
        A = F.softmax(A, dim=1)               # normalize over instances
        M = torch.mm(A, H)                    # attention-weighted bag embedding
        print(M.shape)                        # debug output
        Y_prob = self.classifier(M)
        Y_hat = torch.ge(Y_prob, 0.5).float()
        return Y_prob, Y_hat, A
```
The second code section is for the digit images; this one works perfectly:
```python
class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()
        self.L = 500
        self.D = 128
        self.K = 1
        self.feature_extractor_part1 = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(10, 20, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2)
        )
        self.feature_extractor_part2 = nn.Sequential(
            nn.Linear(20 * 4 * 4, self.L),
            nn.ReLU(),
        )
        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K)
        )
        self.classifier = nn.Sequential(
            nn.Linear(self.L * self.K, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.squeeze(0)
        H = self.feature_extractor_part1(x)
        H = H.view(-1, 20 * 4 * 4)
        H = self.feature_extractor_part2(H)
        A = self.attention(H)
        A = torch.transpose(A, 1, 0)
        A = F.softmax(A, dim=1)
        M = torch.mm(A, H)
        Y_prob = self.classifier(M)
        Y_hat = torch.ge(Y_prob, 0.5).float()
        return Y_prob, Y_hat, A
```
What I have tried:
I have tried changing the dimensions of the convolutional layers and the learning rate, without success.
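For what it's worth, the "instantly 0 or 1" behavior would be consistent with the classifier's pre-sigmoid logit being very large in magnitude, since float32 sigmoid saturates to exactly 0 or 1 well before overflow. A minimal illustration (the logit value here is made up):

```python
import torch

# A large-magnitude pre-sigmoid logit saturates completely in float32:
logit = torch.tensor([200.0])
print(torch.sigmoid(logit))   # tensor([1.])
print(torch.sigmoid(-logit))  # tensor([0.])
```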