I heard that recurrent neural networks (RNNs) have been available in Mathematica since version 11.1. Searching online, I find only general information about neural networks in Mathematica, or a list of related functions. My trouble is that this list lumps purely statistical machine learning functions like Classify and Predict together with neural network functions and (I presume) recurrent neural network functions, so it is really hard to tell which ones I actually need for an RNN.
Perhaps there is a resource I missed that shows an explicit example of how to tackle a time series forecasting problem with several inputs, making use of Mathematica's RNN functions?
If none are known, perhaps someone knowledgeable could write a short example, e.g. using example data from here?
Thanks for any suggestion!
Answer
Here is a simple example that may help you get started. In this example, we are going to predict a simple time series: a sine wave.
data = Table[Sin[x], {x, 0, 100, 0.04}];
We will cut the data into overlapping windows of 51 data points. The first 50 points of each window form the input X, and the last point is the target Y. Each point is wrapped in a list so that every input has dimensions {50, 1}.
training =
  RandomSample[
    List /@ Most[#] -> List@Last[#] & /@ Partition[data, 51, 1]];
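To see what the windowing does, here is the same construction on a toy list with windows of 3 instead of 51. The names toy, windows and pairs are only for illustration.

```wolfram
(* Toy data: five points instead of the sine series *)
toy = Range[5];

(* Sliding windows of length 3 with offset 1 *)
windows = Partition[toy, 3, 1]
(* {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}} *)

(* First 2 points (each wrapped in a list) -> last point (wrapped in a list) *)
pairs = List /@ Most[#] -> List@Last[#] & /@ windows
(* {{{1}, {2}} -> {3}, {{2}, {3}} -> {4}, {{3}, {4}} -> {5}} *)
```

With the real data the inputs are 50×1 matrices and the targets are length-1 vectors, which is the shape the network below expects.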
We use a single gated recurrent layer in our neural network
net = NetChain[{
    GatedRecurrentLayer[10],
    LinearLayer[1]},
  "Input" -> {50, 1}, "Output" -> 1]
and train with the training data
trained = NetTrain[net, training]
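Because neighbouring windows overlap in 50 of their 51 points, it can be useful to hold out the last part of the series for validation rather than training on everything. The following is only a sketch of one way to do that: the 90/10 split, the helper toRule and the names early, late and trainedHeldOut are assumptions for illustration and are not part of the example above.

```wolfram
(* Same data and network as in the answer, repeated so this runs on its own *)
data = Table[Sin[x], {x, 0, 100, 0.04}];
net = NetChain[{GatedRecurrentLayer[10], LinearLayer[1]},
   "Input" -> {50, 1}, "Output" -> 1];

(* Hypothetical helper: turn one 51-point window into an input -> target rule *)
toRule = (List /@ Most[#] -> List@Last[#] &);

(* Split the windows chronologically: first 90% for training, last 10% held out *)
windows = Partition[data, 51, 1];
{early, late} = TakeDrop[windows, Floor[0.9 Length[windows]]];

(* Train on the early windows and monitor the loss on the later ones *)
trainedHeldOut = NetTrain[net, RandomSample[toRule /@ early],
   ValidationSet -> (toRule /@ late)]
```

The training progress panel then reports a validation loss alongside the training loss, which gives a rough idea of how the network does on a stretch of the series it has not seen.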
After training, we can use the network to predict the time series. We first feed it 50 true data points, then repeatedly append its output to the window (dropping the oldest point) to generate the next data point. Here is a comparison between the ground truth and our predictions, which agree very well.
ListPlot[{
   (* start from 50 true points, then feed each prediction back in 500 times *)
   Flatten@
    NestList[Append[Rest[#], trained[#]] &,
      List /@ Sin[Range[-49*0.04, 0, 0.04]], 500][[All, -1]],
   (* ground truth on the same grid *)
   Table[Sin[x], {x, 0, 500*0.04, 0.04}]},
  Joined -> True,
  PlotLegends -> {"predicted", "ground truth"}]
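Since the question asks about several inputs, here is a hedged sketch of how the same setup might be extended. The second dimension of "Input" is the number of features per time step, so with two input series it becomes {50, 2}. The second series (a cosine) and the names multi, trainingMulti, netMulti and trainedMulti are made up for illustration.

```wolfram
(* Two input series per time step; the cosine is an arbitrary second input *)
multi = Table[{Sin[x], Cos[x]}, {x, 0, 100, 0.04}];

(* Each window has 51 rows of 2 features: first 50 rows are the input,
   and the first feature of the last row is the target *)
trainingMulti = RandomSample[
   (Most[#] -> {#[[-1, 1]]}) & /@ Partition[multi, 51, 1]];

(* Same architecture as before, but with 2 features per time step *)
netMulti = NetChain[{GatedRecurrentLayer[10], LinearLayer[1]},
   "Input" -> {50, 2}, "Output" -> 1];

trainedMulti = NetTrain[netMulti, trainingMulti]
```

Note that this network only predicts the first series, so it can give one-step-ahead forecasts when both inputs are known. Feeding predictions back in, as in the plot above, would additionally need future values of the second input or a network that outputs both series.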

