Saturday, August 12, 2017

Visualizing Intermediate Outputs of a Similarity Network

In my last post, I described an experiment where the addition of a self-attention layer helped a network do better at the task of document classification. However, attention didn't seem to help in another experiment where I was trying to predict sentence similarity. I figured it might be useful to visualize the outputs of the network at each stage, in order to see where exactly it was failing. This post describes that work. The visualizations did give me pointers to what was happening, and I tried out some of these ideas, but so far I haven't been able to get a network with attention to perform better than a network without it on the similarity task.

The diagram below illustrates the structure of the network whose outputs I was trying to visualize. The network is built to predict the similarity between two sentences on a 6-point scale. The training data comes from the Semantic Similarity Task Dataset for 2012. It consists of sentence pairs, each with an associated similarity score (a floating point number) between 0 and 5. For this experiment, I quantize the labels into 6 different similarity classes and attempt to predict the class. Word vectors for each word in both sentences are looked up from pretrained GloVe embeddings. Each sequence of word vectors is then sent through a Bidirectional LSTM to produce an encoded sentence matrix for each sentence in the pair. The two sentence matrices are then sent through an attention layer, which first creates an alignment matrix between them. It then uses the alignment matrix to determine how much to weight each part of the two sentences when producing its output vector. The output vector is then fed into a Fully Connected network to make the final prediction.
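A minimal numpy sketch of this kind of alignment-based attention follows. The post doesn't say exactly how the alignment matrix is turned into weights, so the sketch assumes an attentive-pooling style scheme. It makes several assumptions. The alignment is a tanh-squashed dot product with no learned parameters. Each token is weighted by a softmax over its best alignment score in the other sentence. The two weighted sentence vectors are concatenated into the output vector. The names `align_and_attend` and `softmax` are hypothetical.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - x.max())
    return e / e.sum()


def align_and_attend(left, right):
    """Hypothetical alignment-based attention over two encoded sentences.

    left:  (n, d) encoder output for the left sentence, one row per token
    right: (m, d) encoder output for the right sentence, one row per token
    Returns the (n, m) alignment matrix and a (2 * d,) output vector.
    """
    # Alignment matrix: score every left token against every right token.
    # (Assumption: no learned weight matrix; a trained layer would usually have one.)
    align = np.tanh(left @ right.T)          # (n, m)
    # Each token's importance is its best alignment score in the other
    # sentence, normalized with a softmax over that sentence's tokens.
    w_left = softmax(align.max(axis=1))      # (n,)
    w_right = softmax(align.max(axis=0))     # (m,)
    # Attention-weighted sums of the token rows, one vector per sentence.
    v_left = w_left @ left                   # (d,)
    v_right = w_right @ right                # (d,)
    # Concatenate into the vector that would feed the fully connected classifier.
    return align, np.concatenate([v_left, v_right])


rng = np.random.default_rng(0)
left = rng.normal(size=(7, 4))    # e.g. 7 tokens, 4-dim BiLSTM states
right = rng.normal(size=(6, 4))
align, out = align_and_attend(left, right)
print(align.shape, out.shape)     # (7, 6) (8,)
```

The alignment matrix returned here is the kind of matrix that the (d) heatmaps below visualize. The concatenated vector is what would go to the classifier.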

I wanted to visualize the outputs of the network to see how they changed from one stage to the next. So I first selected three sentence pairs whose label similarity values were approximately equidistant along the label range. For each pair, I computed similarity matrices between the two sentences for (a) the input (one-hot) vector sequences, (b) the word vector sequences after embedding, and (c) the sentence matrices after encoding. I also computed (d) the alignment between the two sentence matrices and (e) the similarity matrix between the aligned sentences. Each of these matrices is represented as a heatmap for visualization. In addition, (f) I also used the alignment between the two embeddings to compute the weighted sentence matrix, to see if that made any difference.
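A similarity matrix between two sequences of vectors can be computed along the lines of the sketch below. The post doesn't name the similarity function, so this sketch assumes cosine similarity between every row of one sentence and every row of the other. It also assumes lowercased tokens with the period kept as its own token. `similarity_matrix` and `one_hot` are hypothetical helper names. The resulting matrix can be rendered as a heatmap with, for example, matplotlib's `imshow`.

```python
import numpy as np


def similarity_matrix(left, right):
    """Cosine similarity between every row of `left` and every row of `right`.

    left: (n, d), right: (m, d); returns an (n, m) matrix suitable for a heatmap.
    """
    # Normalize rows to unit length (the epsilon guards against all-zero rows).
    left = left / (np.linalg.norm(left, axis=1, keepdims=True) + 1e-8)
    right = right / (np.linalg.norm(right, axis=1, keepdims=True) + 1e-8)
    return left @ right.T


def one_hot(tokens, vocab):
    """Stack one-hot row vectors for a token sequence."""
    mat = np.zeros((len(tokens), len(vocab)))
    for i, tok in enumerate(tokens):
        mat[i, vocab[tok]] = 1.0
    return mat


left_tokens = "a man is riding a bicycle .".split()
right_tokens = "a man is riding a bike .".split()
vocab = {w: i for i, w in enumerate(sorted(set(left_tokens + right_tokens)))}
sim = similarity_matrix(one_hot(left_tokens, vocab), one_hot(right_tokens, vocab))
# For one-hot inputs, sim[i, j] is (very nearly) 1.0 where the tokens match, else 0.0.
```

The same function applies unchanged to the embedding outputs (b) and the encoder outputs (c). Only the input matrices change.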

Each heatmap also has a crude measure of "similarity" that divides the sum of the diagonal elements by the sum of all the elements.
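As a concrete illustration, the diagonal-over-total measure is a one-liner in numpy. The sketch below applies it to a token-match matrix for the first example pair below. That matrix is the same as the similarity matrix of one-hot inputs. The tokenization (lowercased, period as a separate token) and the name `diagonal_similarity` are assumptions.

```python
import numpy as np


def diagonal_similarity(mat):
    """Crude similarity: sum of the main diagonal divided by the sum of all elements.

    For a non-square matrix, np.trace sums the main diagonal of the
    leading square part.
    """
    return np.trace(mat) / mat.sum()


left_tokens = "a man is riding a bicycle .".split()
right_tokens = "a man is riding a bike .".split()
# Token-match matrix: 1.0 where the left and right tokens are identical.
match = np.array([[float(a == b) for b in right_tokens] for a in left_tokens])
print(diagonal_similarity(match))   # 0.75
```

Here the diagonal contributes 6 matches (every position except bicycle/bike). The off-diagonal "a"/"a" pairs add 2 more to the total, hence 6 / 8. Note that the measure implicitly assumes non-negative entries. With cosine similarities, which can be negative, the denominator can approach zero.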

The sequence of heatmaps below shows the outputs of a network trained for 10 epochs, with a training accuracy of 0.8, a validation accuracy of 0.7 and a test accuracy of 0.4. The sentence pair that generated these outputs is as follows:

Left: A man is riding a bicycle.
Right: A man is riding a bike.
Score: 5.0

Next, we consider a slightly less similar (according to the score label) sentence pair as follows:

Left: A woman is playing the flute.
Right: A man is playing the flute.
Score: 2.4

Finally, we consider a pair of sentences which are even more dissimilar.

Left: A man is cutting a potato.
Right: A woman is cutting a tomato.
Score: 1.25

In all cases, the heatmap for the input is self-explanatory, since words common to both sentences show up along the diagonal. The output of the embedding step also makes some sense. Bicycle and bike in the first case, man and woman in the second and third cases, and potato and tomato in the third case all show a non-zero resemblance. In all cases, the resulting sentence matrix (the output of the encoding step) is a blurry blob indicating the similarity between the two sentences in the pair. I had expected the alignments to be more meaningful, but in all 3 cases above there doesn't seem to be a meaningful pattern. Since the attention output depends on the alignment, there is no meaningful pattern there either.

Computing the alignment from the embedding outputs instead, and using it to weight the encoding outputs to produce the attention output, results in slightly more meaningful patterns. For example, in all 3 cases, the terminating period seems to be unimportant. Strangely, common words seem to hold less importance than I would have expected. Sadly, though, my crude measure of similarity does not match up with the labels, regardless of which pair of outputs I use for my alignment.
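A rough numpy sketch of this variant follows: align on the word embeddings, but apply the resulting weights to the encoder outputs. It then lists each token next to its weight, which is one way to check observations such as the period being unimportant. As in the earlier attention sketch, the weighting scheme (tanh dot-product alignment, max pooling, softmax) is an assumption. The embeddings and encodings are random stand-ins, not real model outputs. `embedding_attention` and `rank_tokens` are hypothetical names.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - x.max())
    return e / e.sum()


def embedding_attention(emb_left, emb_right, enc_left, enc_right):
    """Align on embeddings, but weight the encoder outputs.

    emb_*: (n, e) / (m, e) word embeddings; enc_*: (n, d) / (m, d) encodings.
    Returns per-token weights for each sentence and the weighted vectors.
    """
    align = np.tanh(emb_left @ emb_right.T)      # (n, m) alignment from embeddings
    w_left = softmax(align.max(axis=1))          # (n,) weights for left tokens
    w_right = softmax(align.max(axis=0))         # (m,) weights for right tokens
    # Apply the embedding-derived weights to the encoder outputs.
    return w_left, w_right, w_left @ enc_left, w_right @ enc_right


def rank_tokens(tokens, weights):
    """Pair each token with its attention weight, most important first."""
    return sorted(zip(tokens, weights), key=lambda tw: -tw[1])


# Random stand-ins for the real embedding and encoder outputs.
rng = np.random.default_rng(1)
left_tokens = "a man is cutting a potato .".split()
right_tokens = "a woman is cutting a tomato .".split()
emb_l, emb_r = rng.normal(size=(7, 50)), rng.normal(size=(7, 50))
enc_l, enc_r = rng.normal(size=(7, 8)), rng.normal(size=(7, 8))
w_l, w_r, v_l, v_r = embedding_attention(emb_l, emb_r, enc_l, enc_r)
for tok, w in rank_tokens(left_tokens, w_l):
    print(f"{tok:>8s} {w:.3f}")
```

With real embeddings and encoder outputs in place of the random arrays, the printed ranking shows which tokens the alignment treats as important.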

Here is the notebook that renders these visualizations, and here is the notebook that builds the pre-trained model on which the visualizations are based. I used model.predict() on sub-networks to generate intermediate outputs. I also extracted the trained weights from the model and applied numpy operations to them to get the remaining results.
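For readers who want to try something similar, here is a minimal sketch of the "extract weights, then replay in numpy" half of that approach. In Keras, intermediate outputs can be obtained by building a sub-model such as `Model(inputs=model.input, outputs=model.get_layer(name).output)` and calling `predict()` on it. A Dense layer's trained weights come back from `get_weights()` as a (kernel, bias) pair. The sketch assumes a final Dense layer with softmax over the 6 similarity classes. The layer name `"dense_out"` is hypothetical, and random arrays stand in for the real weights and attention outputs.

```python
import numpy as np


def softmax_rows(x):
    """Row-wise numerically stable softmax."""
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def dense_forward(x, kernel, bias, activation=None):
    """Replay a fully connected layer in numpy.

    kernel has shape (in_dim, out_dim) and bias has shape (out_dim,), the
    order in which a Keras Dense layer's get_weights() returns them.
    """
    out = x @ kernel + bias
    return out if activation is None else activation(out)


# With a trained Keras model, the weights would come from something like
#   kernel, bias = model.get_layer("dense_out").get_weights()
# where "dense_out" is a hypothetical layer name. Random stand-ins here:
rng = np.random.default_rng(2)
attention_out = rng.normal(size=(3, 8))   # e.g. 3 sentence pairs, 8-dim vectors
kernel, bias = rng.normal(size=(8, 6)), np.zeros(6)
probs = dense_forward(attention_out, kernel, bias, activation=softmax_rows)
print(probs.argmax(axis=1))  # predicted similarity class (0-5) for each pair
```

Replaying layers this way makes it easy to swap in a different intermediate input, such as an alternative alignment, without retraining or rebuilding the Keras model.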

That's all I have for today, hope you found it interesting.