In this paper, the authors present a deep learning model to detect diseases from chest x-ray images. A convolutional neural network (CNN) is trained to detect disease names; recurrent neural networks (RNNs) are then trained to describe the context of a detected disease, based on the deep CNN features.
CNN Models and Dataset
CNNs encode input images effectively. In this paper, the authors experiment with a Network in Network (NIN) model and a GoogLeNet model. The dataset contains 3,955 radiology reports and 7,470 associated chest x-rays; 71% of the cases are normal (no disease). The dataset was balanced by augmenting the training images: 224x224 crops were taken at random positions from the original 256x256 images.
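The random-crop augmentation described above can be sketched as follows (the function name is ours, not the paper's):

```python
import numpy as np

def random_crop(image, size=224):
    """Randomly crop a size x size patch from a larger (H, W[, C]) image."""
    h, w = image.shape[:2]
    top = np.random.randint(0, h - size + 1)
    left = np.random.randint(0, w - size + 1)
    return image[top:top + size, left:left + size]

# Example: augment one 256x256 "x-ray" into several distinct 224x224 crops.
image = np.zeros((256, 256), dtype=np.float32)
crops = [random_crop(image) for _ in range(4)]
```

Each call picks a different offset, so one source image yields many slightly shifted training examples.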
Adaptability of Transfer Learning
Since this boils down to a classification problem on a small dataset, transfer learning is a natural technique to try. The authors experimented with ImageNet-trained models: the pretrained CNN weights were reused except for the final classification layer, and the network was then trained on the chest x-ray images with a learning rate 1/10 of the default.
Experiments showed that the fine-tuned model could not exceed 15% accuracy, while the same model with random initialization achieved close to 60% validation accuracy. The authors therefore concluded that features trained specifically on chest x-rays are better suited than features reused from ImageNet.
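A rough sketch of this fine-tuning setup (the tiny network, class count, and learning rates below are illustrative stand-ins, not the paper's actual models): the reused layers train at 1/10 of the base learning rate while the replaced classification head trains at the full rate.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for an ImageNet-pretrained CNN: a tiny conv net
# whose final classification layer will be replaced for the new task.
class TinyCNN(nn.Module):
    def __init__(self, num_classes=1000):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1))
        self.classifier = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.classifier(self.features(x).flatten(1))

model = TinyCNN(num_classes=1000)      # pretend these weights are pretrained
model.classifier = nn.Linear(8, 17)    # replace last layer (17 classes is illustrative)

# Reused layers get 1/10 of the base learning rate; the new head gets the full rate.
base_lr = 0.01
optimizer = torch.optim.SGD([
    {"params": model.features.parameters(), "lr": base_lr / 10},
    {"params": model.classifier.parameters(), "lr": base_lr},
])
```

Per-parameter-group learning rates are the standard way to express "fine-tune the body slowly, train the new head normally" in PyTorch.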
Regularization
As with any deep learning model, regularization must be properly configured to avoid overfitting and to force the model to learn useful features. Batch normalization and dropout were employed and increased both training and validation accuracy. The GoogLeNet model showed better results than the NIN model when used with batch normalization and dropout layers. A further increase in accuracy was observed when training images were duplicated rather than randomly cropped for augmentation.
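A hedged sketch of what such a regularized block might look like (layer sizes are illustrative; the 1x1 convolution mimics NIN's "mlpconv" idea):

```python
import torch
import torch.nn as nn

# NIN-style block augmented with batch normalization and dropout,
# the two regularizers the authors report improved accuracy.
block = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, padding=1),
    nn.BatchNorm2d(32),                 # normalize activations per mini-batch
    nn.ReLU(),
    nn.Conv2d(32, 32, kernel_size=1),   # 1x1 "mlpconv" layer as in NIN
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.Dropout2d(p=0.5),                # randomly zero whole feature maps in training
)
out = block(torch.zeros(4, 1, 64, 64))
```

Batch normalization stabilizes training; dropout discourages co-adapted features, which matters on a dataset this small.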
Annotation generation with RNN
Up to this point, this is a standard application of a deep CNN to a classification problem. The authors then introduce RNNs to annotate the images with human-like diagnoses. RNNs are heavily used in NLP systems and machine translation; in this paper they are used to learn the annotations of the chest x-ray images. In the dataset, the majority of annotations contain up to five words, so the RNNs are constrained to unroll for five time steps: longer annotations are truncated and shorter ones are zero-padded. The authors experiment with Long Short-Term Memory (LSTM) and Gated Recurrent Unit (GRU) architectures.
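The truncate-or-pad step described above can be sketched as follows (token ids and the pad id are illustrative):

```python
# Fix every annotation to 5 time steps: longer sequences are truncated
# and shorter ones zero-padded, as described above.
MAX_STEPS = 5
PAD_ID = 0

def fix_length(token_ids, max_steps=MAX_STEPS, pad_id=PAD_ID):
    token_ids = token_ids[:max_steps]                       # drop words beyond 5
    return token_ids + [pad_id] * (max_steps - len(token_ids))

short = fix_length([7, 3])                 # -> [7, 3, 0, 0, 0]
long = fix_length([1, 2, 3, 4, 5, 6, 7])   # -> [1, 2, 3, 4, 5]
```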
The initial state of the RNN is set to the CNN image embedding, and the first annotation word is the initial input. Unlike models that end in fully-connected layers, NIN and GoogLeNet apply average pooling after their convolutional layers, so the input to the RNN is the output of the model's last average-pooling layer. The RNN then outputs the following annotation words. The model is trained by minimizing the negative log likelihood of the output sequences against the true sequences:

L(I, S) = - Σ_{t=1}^{N} log P(y_t = s_t | CNN(I), s_1, ..., s_{t-1})

where y_t is the output of the RNN at step t and s_t is the correct word, CNN(I) is the embedding of input image I, and N is the number of words in the annotation.
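A minimal sketch of this training step (vocabulary size, dimensions, and the sample annotation are illustrative, not the paper's): the LSTM hidden state is initialized with the CNN embedding, and cross-entropy (the negative log likelihood over a softmax) is minimized for each next annotation word.

```python
import torch
import torch.nn as nn

vocab, embed_dim, hidden = 50, 16, 32
cnn_embedding = torch.randn(1, hidden)    # stand-in for CNN(I), the pooled CNN output
words = torch.tensor([[4, 9, 2, 0, 0]])   # one zero-padded 5-word annotation

embed = nn.Embedding(vocab, embed_dim)
lstm = nn.LSTM(embed_dim, hidden, batch_first=True)
to_vocab = nn.Linear(hidden, vocab)
loss_fn = nn.CrossEntropyLoss()           # negative log likelihood over softmax

h0 = cnn_embedding.unsqueeze(0)           # initial hidden state from the image
c0 = torch.zeros_like(h0)
inputs, targets = words[:, :-1], words[:, 1:]   # predict each next word
states, _ = lstm(embed(inputs), (h0, c0))       # unroll over the annotation
loss = loss_fn(to_vocab(states).reshape(-1, vocab), targets.reshape(-1))
```

In a real setup the padded positions would typically be masked out of the loss; that detail is omitted here for brevity.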
Recurrent Feedback Model: Image Labeling with a Joint Image/Text Context
Using this RNN, we obtain annotations for each image. These are then used to create a more diverse set of image labels, from which we can infer information beyond just the name of the disease.
For this, a joint image/text context vector is generated by applying mean pooling over the RNN's state vectors at each step. Note that this state is initialized with the CNN embedding and then updated by unrolling over the annotation sequence.
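The mean pooling above reduces to averaging the per-step state vectors (shapes here are illustrative):

```python
import torch

# Joint image/text context vector: mean-pool the RNN state vectors
# h_1 ... h_T collected while unrolling over the annotation.
T, hidden = 5, 32
state_vectors = torch.randn(T, hidden)       # states from the unrolled RNN
joint_context = state_vectors.mean(dim=0)    # one vector encoding image + text
```

Because the states were initialized from the CNN embedding, the pooled vector carries both image and text information.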
The obtained vector encodes both the image context and the text context describing the image. From it, new image labels are derived that take the disease context into account.
The CNN is then trained once more with the new labels (the previous CNN is fine-tuned after replacing its last classification layer). The RNN is retrained with the new image embeddings, and finally the image annotations are generated.
Detailed results are reported in the paper.
Additional Reading
- Network in Network model: https://arxiv.org/abs/1312.4400
- GoogLeNet: https://ai.google/research/pubs/pub43022
- Course on Sequence Models: https://www.coursera.org/learn/nlp-sequence-models