machine learning - How to visualize attention?


In the articles about sequence attention we can see images like this:


[Figure: attention alignment matrix for a French-to-English translation]


Here we see that, while translating from French to English, the network attends to each input state in sequence, but sometimes it attends to two words at a time while producing a single output, as when translating “la Syrie” to “Syria”.


Here is the Mathematica code:



attend = SequenceAttentionLayer[
"Input" -> {"Varying", 2},
"Query" -> {"Varying", 2}
] // NetInitialize

attend[<|
"Input" -> {{1, 2}, {3, 4}, {5, 6}},
"Query" -> {{1, 0}, {0, 1}, {0, 0}, {1, 1}}
|>]



{{2.24154,3.24154},{1.55551,2.55551},{3.,4.},{1.29312,2.29312}}



My question is: how can I visualize it?


Yes, I can use ArrayPlot, but I don't understand the output. Why does it have these dimensions? According to the documentation, this is the expected behavior.


[Figure: screenshot from the documentation]


But I expected a 3x4 or 4x3 output, because the dimension 2 in "Input" and "Query" is just the size of the word embeddings.


Can someone explain this to me?


Note: SequenceAttentionLayer was introduced in V11.1. https://reference.wolfram.com/language/ref/SequenceAttentionLayer.html



Answer




The output of the SequenceAttentionLayer is a weighted sum of the vectors supplied to the "Input" port. If the input is an array of dimension m X d1 (m vectors of size d1) and the query is an array of dimension n X d2, then the output has dimension n X d1: one weighted sum of input vectors per query vector.
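The shape rule can be checked without the layer itself. A minimal sketch, assuming a hypothetical uniform weight matrix wUniform (one row per query vector, one column per input vector) in place of the learned weights:

```mathematica
input = {{1, 2}, {3, 4}, {5, 6}};        (* m x d1 = 3 x 2 *)
wUniform = ConstantArray[1/3, {4, 3}];   (* hypothetical weights, n x m = 4 x 3 *)
Dimensions[wUniform . input]
(* {4, 2} *)
```

The 4X3 array that the question expected is the weight matrix, not the output of the layer.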


In your example there are three vectors in the input and four vectors in the query:


input = {{1, 2}, {3, 4}, {5, 6}};
query = {{1, 0}, {0, 1}, {0, 0}, {1, 1}};

attend = SequenceAttentionLayer["Input" -> {"Varying", 2},
"Query" -> {"Varying", 2}] // NetInitialize
(* SequenceAttentionLayer[ <> ] *)

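The output is a 4X2 matrix (the numbers differ from those in the question because NetInitialize assigns random initial weights):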



output = attend[<|"Input" -> input, "Query" -> query|>]
(* {{4.28436, 5.28436}, {1.82552, 2.82552}, {3., 4.}, {3.18698,
4.18698}} *)

For each query vector (for example {1, 0}), three weights are calculated, one per input vector, and the output for that query vector is


output_{1,0} = w11 * input[[1]] + w12 * input[[2]] + w13 * input[[3]] = {4.28436, 5.28436}

The same is done for all the remaining query vectors, and thus the output has dimension 4X2.


The weights are calculated inside the SequenceAttentionLayer using a "ScoringNet", which can be extracted as


snet = NetExtract[attend, "ScoringNet"]


[Figure: display of the extracted scoring net]


Again, for the first query vector, the three weights are


{w11, w12, w13} = SoftmaxLayer[][
  Table[snet[<|"Input" -> input[[n]], "Query" -> query[[1]]|>], {n, 1, 3}]]
(* {0.0685481, 0.220724, 0.710728} *)
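SoftmaxLayer turns the raw scores from the scoring net into positive weights that sum to 1. A minimal sketch of the same normalization written out by hand, using hypothetical scores rather than values taken from snet:

```mathematica
scores = {0., 1., 2.};                      (* hypothetical raw scores, one per input vector *)
weights = Exp[scores]/Total[Exp[scores]]    (* softmax written out explicitly *)
(* {0.0900306, 0.244728, 0.665241} *)
Total[weights]
(* 1. *)
```

Applying SoftmaxLayer[] to the same scores should agree with this up to single-precision rounding.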

We can see that the first output is indeed the weighted sum of the input vectors using these three weights:



w11*input[[1]] + w12*input[[2]] + w13*input[[3]]
(* {4.28436, 5.28436} *)

In this example, the weight for the third input vector is high, which means that the third input vector has more influence on the outcome. Or in other words, the network is "paying more attention" to the third vector.


To visualize the attention (the weights that each query vector places on the input vectors), we can compute and plot all the weights:


w = Table[SoftmaxLayer[][
Table[snet[<|"Input" -> input[[n]], "Query" -> query[[m]]|>], {n, 1, 3}]], {m, 1, 4}];
ArrayPlot[w, FrameTicks -> {Thread[{Range[4], query}], Thread[{Range[3], input}]}]

[Figure: ArrayPlot of the 4X3 attention weight matrix]
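As a sanity check, and assuming w, input and output are defined as above, multiplying the weight matrix by the input vectors should reproduce the output of the layer up to rounding:

```mathematica
(* w is 4 x 3 and input is 3 x 2, so w . input is 4 x 2, like output *)
Dimensions[w . input]
Max[Abs[w . input - output]]   (* expected to be close to 0 *)
```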



Here the axes are labeled by the input and query vectors. In a language translation model, the labels of the input vectors would be replaced by the words corresponding to those vectors (embeddings), and the labels of the query vectors by the translated words.
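A minimal sketch of such a word-labeled plot, using the same FrameTicks form as above. The token names and the weight matrix are made-up placeholders, not the output of a trained translation model:

```mathematica
inWords = {"in1", "in2", "in3"};              (* hypothetical source-language tokens *)
outWords = {"out1", "out2", "out3", "out4"};  (* hypothetical translated tokens *)
wDemo = {{0.8, 0.1, 0.1},                     (* made-up weights; each row sums to 1 *)
         {0.1, 0.8, 0.1},
         {0.1, 0.45, 0.45},
         {0.05, 0.15, 0.8}};
ArrayPlot[wDemo,
 FrameTicks -> {Thread[{Range[4], outWords}], Thread[{Range[3], inWords}]},
 Mesh -> True]
```

Each row shows how one output token distributes its attention over the input tokens. The third row mimics an output token that attends to two input tokens at once, as in the “la Syrie” example.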

