
Visualize Activations of LSTM Network

This example shows how to investigate and visualize the features learned by LSTM networks by extracting the activations.

Load the pretrained network. JapaneseVowelsNet is a pretrained LSTM network trained on the Japanese Vowels dataset as described in [1] and [2]. It was trained on the sequences sorted by sequence length with a mini-batch size of 27.

load JapaneseVowelsNet
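Sorting the sequences by length before forming mini-batches keeps sequences of similar length together, so less padding is needed inside each mini-batch. The following is a minimal sketch of that effect in base MATLAB. It does not use the Japanese Vowels data: the sequence lengths are random, and the observation count, the length range, and the names `batches` and `totalPadding` are assumptions made for illustration.

```matlab
% Illustrative only: random sequence lengths, not the Japanese Vowels data.
rng(0);                                   % reproducible random lengths
numObservations = 270;                    % assumed count, a multiple of the mini-batch size
miniBatchSize = 27;                       % mini-batch size quoted in the text
sequenceLengths = randi([7 29],numObservations,1);

% Arrange lengths so that each column holds one mini-batch.
batches = @(len) reshape(len,miniBatchSize,[]);

% Padding needed when every sequence in a mini-batch is padded
% to the longest sequence in that mini-batch.
totalPadding = @(len) sum(sum(max(batches(len),[],1) - batches(len)));

paddingUnsorted = totalPadding(sequenceLengths);
paddingSorted = totalPadding(sort(sequenceLengths));

fprintf("Padded time steps, unsorted: %d\n",paddingUnsorted);
fprintf("Padded time steps, sorted:   %d\n",paddingSorted);
```

For equally sized mini-batches, the sorted ordering never needs more padding than any other ordering, because it minimizes the sum of the per-batch maximum lengths.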

View the network architecture.

net.Layers
ans = 
  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

Load the test data.

[XTest,YTest] = japaneseVowelsTestData;

Visualize the first time series in a plot. Each line corresponds to a feature.

X = XTest{1};
figure
plot(XTest{1}')
xlabel("Time Step")
title("Test Observation 1")
numFeatures = size(XTest{1},1);
legend("Feature " + string(1:numFeatures),'Location','northeastoutside')

Figure: line plot titled "Test Observation 1" showing the 12 features (Feature 1 through Feature 12) against the time step.

For each time step of the sequences, get the activations output by the LSTM layer (layer 2) for that time step and update the network state.

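sequenceLength = size(X,2);
idxLayer = 2;
outputSize = net.Layers(idxLayer).NumHiddenUnits;

% Preallocate one column of LSTM activations per time step.
features = zeros(outputSize,sequenceLength);

for i = 1:sequenceLength
    % Activations of the LSTM layer for the current time step.
    features(:,i) = activations(net,X(:,i),idxLayer);
    % Classify the time step and carry the updated state to the next step.
    [net, YPred(i)] = classifyAndUpdateState(net,X(:,i));
end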
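The state update matters because the output of an LSTM layer at one time step depends on the hidden and cell states left by the previous steps. The following sketch illustrates this with the standard LSTM equations written out in base MATLAB. It is not the toolbox implementation and does not use the pretrained network: the weights `W`, `R`, and `b` are random, the input is synthetic, and the sequence length `T` is an assumption chosen for illustration.

```matlab
% Hypothetical stand-in for an LSTM layer: random weights, synthetic input.
rng(0);                                   % reproducible random values
numFeatures = 12;                         % same input size as the network above
numHidden = 100;                          % same number of hidden units as above
T = 20;                                   % illustrative sequence length (assumption)
X = randn(numFeatures,T);                 % synthetic sequence, features by time steps

% Input weights, recurrent weights, and bias for the four gates, stacked row-wise.
W = 0.1*randn(4*numHidden,numFeatures);
R = 0.1*randn(4*numHidden,numHidden);
b = zeros(4*numHidden,1);

sigmoid = @(z) 1./(1 + exp(-z));

h = zeros(numHidden,1);                   % hidden state, carried across time steps
c = zeros(numHidden,1);                   % cell state, carried across time steps
features = zeros(numHidden,T);            % one column of activations per time step

for t = 1:T
    z = W*X(:,t) + R*h + b;               % gate pre-activations use the previous h
    inGate = sigmoid(z(1:numHidden));
    forgetGate = sigmoid(z(numHidden+1:2*numHidden));
    candidate = tanh(z(2*numHidden+1:3*numHidden));
    outGate = sigmoid(z(3*numHidden+1:4*numHidden));

    c = forgetGate.*c + inGate.*candidate; % update the cell state
    h = outGate.*tanh(c);                  % update the hidden state
    features(:,t) = h;                     % the activation recorded for this step
end
```

If `h` and `c` were reset to zero at every iteration, each column of `features` would depend only on the current input, which is why the example updates the network state after each time step.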

Visualize the first 10 hidden units using a heatmap.

figure
heatmap(features(1:10,:));
xlabel("Time Step")
ylabel("Hidden Unit")
title("LSTM Activations")

Figure: heatmap titled "LSTM Activations" showing the first 10 hidden units against the time step.

The heatmap shows how strongly each hidden unit activates and highlights how the activations change over time.
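The first 10 hidden units are not necessarily the most informative ones. One possible follow-up, sketched here, is to rank the hidden units by how much their activation changes over time and plot the most variable ones. So that the snippet runs on its own, it builds a synthetic `features` matrix as a stand-in; with the example above, the `features` matrix filled in by the loop would be used instead. The names `activationRange`, `idxSorted`, and `idxTop` are introduced for illustration.

```matlab
% Stand-in activations so the snippet is self-contained (assumption).
% In the example above, skip these two lines and reuse the existing features.
rng(0);
features = tanh(randn(100,20));           % hidden units by time steps

% Range of each hidden unit over time: maximum minus minimum along time.
activationRange = max(features,[],2) - min(features,[],2);

% Sort hidden units from most to least variable and keep the top 10.
[~,idxSorted] = sort(activationRange,'descend');
idxTop = idxSorted(1:10);

% Label each heatmap row with the index of the hidden unit it shows.
figure
heatmap(features(idxTop,:),'YDisplayLabels',string(idxTop));
xlabel("Time Step")
ylabel("Hidden Unit")
title("Most Variable LSTM Activations")
```

The range is only one possible ranking criterion; the variance over time or the mean absolute activation would be reasonable alternatives.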

References

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
