In this example, generator[<|"BatchSize" -> 4|>] returns an association of lists, but NetTrain[net, generator] fails.
resource = ResourceObject["MNIST"];
trainingData = ResourceData[resource, "TestData"];
encoder1 = NetChain[{FlattenLayer[], 128, 8},
"Input" -> NetEncoder[{"Image", {28, 28}, ColorSpace -> "Grayscale"}]];
encoder2 = NetChain[{8}, "Input" -> NetEncoder[{"Class", Range[0, 9], "UnitVector"}]];
decoder = NetChain[{128, 28*28, ReshapeLayer[{1, 28, 28}]}, "Output" -> NetDecoder[{"Image", ColorSpace -> "Grayscale"}]];
net = NetGraph[{encoder1, encoder2, ThreadingLayer[#1*(1 - #2) &],
ThreadingLayer[#1*#2 &], ThreadingLayer[Plus],
decoder, ReplicateLayer[8], FlattenLayer[]},
{NetPort["image"] -> 1 -> 3, NetPort["digit"] -> 2 -> 4,
NetPort["switch"] -> 7 -> 8 -> {3, 4} -> 5 -> 6 -> NetPort["Output"]}, "switch" -> 1]
The net can already produce an image, even though it hasn't been trained yet:
NetInitialize[net][<|"image" -> trainingData[[1, 1]],
"digit" -> trainingData[[1, 2]],
"switch" -> 0|>]
The generator also works as expected:
generator =
Function[Block[{data = RandomSample[trainingData, #BatchSize]}, <|
"image" -> data[[All, 1]], "digit" -> data[[All, 2]],
"switch" -> RandomInteger[1, #BatchSize],
"Output" -> data[[All, 1]]|>]];
generator[<|"BatchSize" -> 4|>]
But when I run NetTrain[net, generator], it throws an error.
Weird! What's wrong?
Answer
You get much better error messages if you pass a list of samples instead of a generator to NetTrain:
NetTrain[net, generator[<|"BatchSize" -> 10|>]]
NetTrain::invindim: Data provided to port "switch" should be a list of length-1 vectors.
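This is much more helpful. The "switch" port was declared with "switch" -> 1, so each example must be a length-1 vector such as {0} or {1}, not a bare integer. So we change the generator to create a list of length-1 vectors for "switch":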
generator =
Function[Block[{data = RandomSample[trainingData, #BatchSize]}, <|
"image" -> data[[All, 1]], "digit" -> data[[All, 2]],
"switch" -> ({#} & /@ RandomInteger[1, #BatchSize]),
"Output" -> data[[All, 1]]|>]];
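Mapping {#} & over the integers is one of several ways to get the {n, 1} shape the "switch" port expects. A minimal sketch comparing them, using only built-in functions (the batch size n = 4 is an arbitrary illustrative value):

```mathematica
n = 4;
(* Wrap each random 0/1 in its own list: shape {n, 1} *)
Dimensions[{#} & /@ RandomInteger[1, n]]        (* {4, 1} *)
(* Split a flat list into sublists of length 1: shape {n, 1} *)
Dimensions[Partition[RandomInteger[1, n], 1]]   (* {4, 1} *)
(* Ask RandomInteger for an n x 1 array directly: shape {n, 1} *)
Dimensions[RandomInteger[1, {n, 1}]]            (* {4, 1} *)
(* The original flat form, which the port rejects: shape {n} *)
Dimensions[RandomInteger[1, n]]                 (* {4} *)
```

So RandomInteger[1, {#BatchSize, 1}] could stand in for the mapped expression inside the generator; this is an equivalent rewrite suggested here, not part of the original answer.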
Training with the generator works, too:
NetTrain[net, generator]
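Since NetTrain gives less specific messages for generators, it can help to check the shape of each column of one generated batch before training. A sketch, assuming the generator and trainingData defined above (batch is a hypothetical name used only for this check):

```mathematica
(* Generate one small batch and look at the dimensions of each column *)
batch = generator[<|"BatchSize" -> 4|>];
Dimensions /@ batch
(* Images are atomic expressions, so a list of 4 images has dimensions {4};
   the "switch" column should have dimensions {4, 1} to match the length-1 port *)
```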